243 lines
6.6 KiB
Python
243 lines
6.6 KiB
Python
![]() |
import asyncio
|
||
|
import sys
|
||
|
|
||
|
from urllib.parse import urlparse, parse_qsl
|
||
|
|
||
|
from .log import logger
|
||
|
|
||
|
# Module-level sentinel object (a fresh ``object()`` instance).
_NOTSET = object()

# True when the running interpreter is Python 3.8 or newer.
IS_PY38 = sys.version_info >= (3, 8)

# NOTE: never put anything else here;
#       just these basic types.
# Maps an exact argument type to a callable producing its bytes-like form.
_converters = {
    bytes: lambda val: val,
    bytearray: lambda val: val,
    str: lambda val: val.encode(),
    int: lambda val: b'%d' % val,
    float: lambda val: b'%r' % val,
}
|
||
|
|
||
|
|
||
|
def encode_command(*args, buf=None):
    """Serialize ``args`` as a Redis array of bulk strings.

    The encoded bytes are appended to ``buf`` when one is supplied (a new
    ``bytearray`` is created otherwise) and that buffer is returned.

    Raises TypeError if any argument's exact type is not one of bytearray,
    bytes, float, int, or str.
    """
    out = bytearray() if buf is None else buf
    # Array header: number of bulk strings that follow.
    out.extend(b'*%d\r\n' % len(args))

    for arg in args:
        try:
            # Lookup is by exact type, so subclasses are rejected as well.
            payload = _converters[type(arg)](arg)
        except KeyError:
            raise TypeError("Argument {!r} expected to be of bytearray, bytes,"
                            " float, int, or str type".format(arg))
        out.extend(b'$%d\r\n%s\r\n' % (len(payload), payload))
    return out
|
||
|
|
||
|
|
||
|
def decode(obj, encoding):
    """Decode ``bytes`` (recursing into lists) using ``encoding``.

    ``bytes`` values are decoded to ``str``; lists are rebuilt with every
    element decoded recursively; any other object is returned unchanged.
    """
    if isinstance(obj, list):
        return [decode(item, encoding) for item in obj]
    if isinstance(obj, bytes):
        return obj.decode(encoding)
    return obj
|
||
|
|
||
|
|
||
|
async def wait_ok(fut):
    """Await ``fut`` and report whether the reply is ``OK``.

    A ``QUEUED`` reply (bytes or str) is returned unchanged; otherwise the
    result is True when the reply equals ``OK`` (bytes or str), else False.
    """
    reply = await fut
    queued = reply in (b'QUEUED', 'QUEUED')
    return reply if queued else reply in (b'OK', 'OK')
|
||
|
|
||
|
|
||
|
async def wait_convert(fut, type_, **kwargs):
    """Await ``fut`` and convert the reply with ``type_(reply, **kwargs)``.

    A ``QUEUED`` reply (bytes or str) is returned as-is without conversion.
    """
    reply = await fut
    if reply not in (b'QUEUED', 'QUEUED'):
        return type_(reply, **kwargs)
    return reply
|
||
|
|
||
|
|
||
|
async def wait_make_dict(fut):
    """Await ``fut`` and fold a flat ``[k1, v1, k2, v2, ...]`` reply into a dict.

    A ``QUEUED`` reply (bytes or str) is returned unchanged.
    """
    reply = await fut
    if reply in (b'QUEUED', 'QUEUED'):
        return reply
    # Zipping one iterator with itself consumes items pairwise:
    # even positions become keys, odd positions become values.
    flat = iter(reply)
    return {key: value for key, value in zip(flat, flat)}
|
||
|
|
||
|
|
||
|
class coerced_keys_dict(dict):
    """Dict whose item lookup and membership test coerce the key first.

    Keys that are not ``bytes`` are passed through the ``_converters`` entry
    for their exact type before the underlying ``dict`` is consulted.
    """

    def __getitem__(self, other):
        key = (other if isinstance(other, bytes)
               else _converters[type(other)](other))
        return dict.__getitem__(self, key)

    def __contains__(self, other):
        key = (other if isinstance(other, bytes)
               else _converters[type(other)](other))
        return dict.__contains__(self, key)
|
||
|
|
||
|
|
||
|
class _ScanIter:
|
||
|
|
||
|
__slots__ = ('_scan', '_cur', '_ret')
|
||
|
|
||
|
def __init__(self, scan):
|
||
|
self._scan = scan
|
||
|
self._cur = b'0'
|
||
|
self._ret = []
|
||
|
|
||
|
def __aiter__(self):
|
||
|
return self
|
||
|
|
||
|
async def __anext__(self):
|
||
|
while not self._ret and self._cur:
|
||
|
self._cur, self._ret = await self._scan(self._cur)
|
||
|
if not self._cur and not self._ret:
|
||
|
raise StopAsyncIteration # noqa
|
||
|
else:
|
||
|
ret = self._ret.pop(0)
|
||
|
return ret
|
||
|
|
||
|
|
||
|
def _set_result(fut, result, *info):
|
||
|
if fut.done():
|
||
|
logger.debug("Waiter future is already done %r %r", fut, info)
|
||
|
assert fut.cancelled(), (
|
||
|
"waiting future is in wrong state", fut, result, info)
|
||
|
else:
|
||
|
fut.set_result(result)
|
||
|
|
||
|
|
||
|
def _set_exception(fut, exception):
|
||
|
if fut.done():
|
||
|
logger.debug("Waiter future is already done %r", fut)
|
||
|
assert fut.cancelled(), (
|
||
|
"waiting future is in wrong state", fut, exception)
|
||
|
else:
|
||
|
fut.set_exception(exception)
|
||
|
|
||
|
|
||
|
def parse_url(url):
    """Parse Redis connection URI.

    Parse according to IANA specs:
    * https://www.iana.org/assignments/uri-schemes/prov/redis
    * https://www.iana.org/assignments/uri-schemes/prov/rediss

    Also more rules applied:

    * empty scheme is treated as unix socket path, no further parsing is done.

    * 'unix://' scheme is treated as unix socket path and parsed.

    * Multiple query parameter values and blank values are considered error.

    * DB number specified as path and as query parameter is considered error.

    * Password specified in userinfo and as query parameter is
      considered error.

    Returns an ``(address, options)`` pair: ``address`` is the original
    ``url`` for an empty scheme, the URI path for ``unix``, and a
    ``(host, port)`` tuple otherwise (defaults: ``'localhost'``, ``6379``);
    ``options`` is a dict of connection options.

    Raises AssertionError for every rule violation above. The checks are
    explicit ``raise`` statements rather than ``assert`` so that they are
    still enforced when Python runs with ``-O``.
    """
    r = urlparse(url)

    if r.scheme not in ('', 'redis', 'rediss', 'unix'):
        raise AssertionError(("Unsupported URI scheme", r.scheme))
    if r.scheme == '':
        return url, {}

    query = {}
    # keep_blank_values=True so that blank values are seen and rejected
    # instead of being silently dropped by parse_qsl.
    for p, v in parse_qsl(r.query, keep_blank_values=True):
        if p in query:
            raise AssertionError(
                ("Multiple parameters are not allowed", p, v))
        if not v:
            raise AssertionError(
                ("Empty parameters are not allowed", p, v))
        query[p] = v

    if r.scheme == 'unix':
        if not r.path:
            raise AssertionError(("Empty path is not allowed", url))
        if r.netloc:
            raise AssertionError(
                ("Netlocation is not allowed for unix scheme", r.netloc))
        return r.path, _parse_uri_options(query, '', r.password)

    address = (r.hostname or 'localhost', int(r.port or 6379))
    path = r.path
    if path.startswith('/'):
        # Drop only the leading slash; the remainder is the DB number.
        path = r.path[1:]
    options = _parse_uri_options(query, path, r.password)
    if r.scheme == 'rediss':
        options['ssl'] = True
    return address, options
|
||
|
|
||
|
|
||
|
def _parse_uri_options(params, path, password):
|
||
|
|
||
|
def parse_db_num(val):
|
||
|
if not val:
|
||
|
return
|
||
|
assert val.isdecimal(), ("Invalid decimal integer", val)
|
||
|
assert val == '0' or not val.startswith('0'), (
|
||
|
"Expected integer without leading zeroes", val)
|
||
|
return int(val)
|
||
|
|
||
|
options = {}
|
||
|
|
||
|
db1 = parse_db_num(path)
|
||
|
db2 = parse_db_num(params.get('db'))
|
||
|
assert db1 is None or db2 is None, (
|
||
|
"Single DB value expected, got path and query", db1, db2)
|
||
|
if db1 is not None:
|
||
|
options['db'] = db1
|
||
|
elif db2 is not None:
|
||
|
options['db'] = db2
|
||
|
|
||
|
password2 = params.get('password')
|
||
|
assert not password or not password2, (
|
||
|
"Single password value is expected, got in net location and query")
|
||
|
if password:
|
||
|
options['password'] = password
|
||
|
elif password2:
|
||
|
options['password'] = password2
|
||
|
|
||
|
if 'encoding' in params:
|
||
|
options['encoding'] = params['encoding']
|
||
|
if 'ssl' in params:
|
||
|
assert params['ssl'] in ('true', 'false'), (
|
||
|
"Expected 'ssl' param to be 'true' or 'false' only",
|
||
|
params['ssl'])
|
||
|
options['ssl'] = params['ssl'] == 'true'
|
||
|
|
||
|
if 'timeout' in params:
|
||
|
options['timeout'] = float(params['timeout'])
|
||
|
return options
|
||
|
|
||
|
|
||
|
class CloseEvent:
    """Two-phase close notification built from a pair of asyncio events.

    ``set()`` schedules the ``on_close`` coroutine function once and marks
    the close as initiated; the close is marked done when the scheduled
    task finishes.
    """

    def __init__(self, on_close):
        self._on_close = on_close
        self._close_init = asyncio.Event()
        self._close_done = asyncio.Event()

    async def wait(self):
        """Wait until the close has been initiated and then completed."""
        for phase in (self._close_init, self._close_done):
            await phase.wait()

    def is_set(self):
        """Return True once the close has been initiated or completed."""
        if self._close_done.is_set():
            return True
        return self._close_init.is_set()

    def set(self):
        """Start the close callback; later calls are no-ops."""
        if not self._close_init.is_set():
            closing = asyncio.ensure_future(self._on_close())
            closing.add_done_callback(self._cleanup)
            self._close_init.set()

    def _cleanup(self, task):
        # Done-callback of the close task: drop the callback reference
        # and mark the close as finished.
        self._on_close = None
        self._close_done.set()
|
||
|
|
||
|
|
||
|
# Use asyncio.get_running_loop when the asyncio module provides it;
# otherwise fall back to asyncio.get_event_loop.
get_event_loop = getattr(asyncio, 'get_running_loop', asyncio.get_event_loop)
|