role_based_system/venv/Lib/site-packages/aioredis/util.py

243 lines
6.6 KiB
Python
Raw Normal View History

import asyncio
import sys
from collections import deque
from urllib.parse import urlparse, parse_qsl, unquote

from .log import logger
_NOTSET = object()
IS_PY38 = sys.version_info >= (3, 8)
# NOTE: never put here anything else;
# just this basic types
_converters = {
bytes: lambda val: val,
bytearray: lambda val: val,
str: lambda val: val.encode(),
int: lambda val: b'%d' % val,
float: lambda val: b'%r' % val,
}
def encode_command(*args, buf=None):
    """Encode arguments into a Redis bulk-strings array.

    The output is a ``*<count>\\r\\n`` header followed by one
    ``$<len>\\r\\n<data>\\r\\n`` item per argument.  Each argument is
    converted to bytes with ``_converters``, keyed on its exact type.

    :param args: command name and arguments; each must be of bytearray,
        bytes, float, int, or str type.
    :param buf: optional bytearray to append the encoded command to; a new
        bytearray is created when omitted.
    :returns: the buffer the command was appended to.
    :raises TypeError: if any argument has an unsupported type.  Every
        argument is converted before anything is written, so a
        caller-supplied *buf* is left untouched on error instead of holding
        a partially written command.
    """
    encoded = []
    for arg in args:
        try:
            convert = _converters[type(arg)]
        except KeyError:
            raise TypeError("Argument {!r} expected to be of bytearray, bytes,"
                            " float, int, or str type".format(arg)) from None
        encoded.append(convert(arg))

    if buf is None:
        buf = bytearray()
    buf.extend(b'*%d\r\n' % len(args))
    for barg in encoded:
        buf.extend(b'$%d\r\n%s\r\n' % (len(barg), barg))
    return buf
def decode(obj, encoding):
    """Recursively turn ``bytes`` into ``str``.

    ``bytes`` values are decoded with *encoding*.  Lists, including nested
    lists, are mapped element-wise into new lists.  Any other object is
    returned as-is.
    """
    if isinstance(obj, list):
        return [decode(item, encoding) for item in obj]
    if not isinstance(obj, bytes):
        return obj
    return obj.decode(encoding)
async def wait_ok(fut):
    """Await *fut* and report whether the reply was ``OK``.

    A ``QUEUED`` reply (bytes or str) is returned unchanged.  Otherwise the
    result is True when the reply equals ``b'OK'`` or ``'OK'``, and False
    for anything else.
    """
    reply = await fut
    queued = reply in (b'QUEUED', 'QUEUED')
    return reply if queued else reply in (b'OK', 'OK')
async def wait_convert(fut, type_, **kwargs):
    """Await *fut* and convert its result with ``type_(result, **kwargs)``.

    A ``QUEUED`` reply (bytes or str) is passed through unconverted.
    """
    result = await fut
    if result not in (b'QUEUED', 'QUEUED'):
        result = type_(result, **kwargs)
    return result
async def wait_make_dict(fut):
    """Await *fut* and fold a flat ``[k1, v1, k2, v2, ...]`` reply into a dict.

    A ``QUEUED`` reply (bytes or str) is returned unchanged.  Items are
    paired in order; a trailing unpaired item is dropped.
    """
    reply = await fut
    if reply in (b'QUEUED', 'QUEUED'):
        return reply
    items = iter(reply)
    # zip() pulls from the same iterator twice per step: (key, value).
    return {key: value for key, value in zip(items, items)}
class coerced_keys_dict(dict):
    """dict with ``bytes`` keys that also accepts other key types on lookup.

    ``d[key]``, ``key in d`` and ``d.get(key)`` convert a non-``bytes`` key
    with the module-level ``_converters`` table before consulting the dict.
    For example, ``d['chan']`` finds an entry stored under ``b'chan'``.
    Insertion is not overridden, so entries should be stored with ``bytes``
    keys.
    """

    @staticmethod
    def _coerce(key):
        """Return *key* as hashable bytes, or raise ``KeyError(key)``."""
        if isinstance(key, bytes):
            return key
        try:
            convert = _converters[type(key)]
        except KeyError:
            raise KeyError(key) from None
        # bytearray converts to itself, which is unhashable; freeze it.
        return bytes(convert(key))

    def __getitem__(self, other):
        return dict.__getitem__(self, self._coerce(other))

    def __contains__(self, other):
        try:
            key = self._coerce(other)
        except KeyError:
            # A key of an unconvertible type can never be present;
            # membership tests report False instead of raising.
            return False
        return dict.__contains__(self, key)

    def get(self, key, default=None):
        """Like ``dict.get`` but with the same key coercion as ``d[key]``."""
        try:
            return self[key]
        except KeyError:
            return default
class _ScanIter:
__slots__ = ('_scan', '_cur', '_ret')
def __init__(self, scan):
self._scan = scan
self._cur = b'0'
self._ret = []
def __aiter__(self):
return self
async def __anext__(self):
while not self._ret and self._cur:
self._cur, self._ret = await self._scan(self._cur)
if not self._cur and not self._ret:
raise StopAsyncIteration # noqa
else:
ret = self._ret.pop(0)
return ret
def _set_result(fut, result, *info):
if fut.done():
logger.debug("Waiter future is already done %r %r", fut, info)
assert fut.cancelled(), (
"waiting future is in wrong state", fut, result, info)
else:
fut.set_result(result)
def _set_exception(fut, exception):
if fut.done():
logger.debug("Waiter future is already done %r", fut)
assert fut.cancelled(), (
"waiting future is in wrong state", fut, exception)
else:
fut.set_exception(exception)
def parse_url(url):
    """Parse Redis connection URI.

    Parse according to IANA specs:
    * https://www.iana.org/assignments/uri-schemes/prov/redis
    * https://www.iana.org/assignments/uri-schemes/prov/rediss

    Also more rules applied:
    * empty scheme is treated as unix socket path; no further parsing is done.
    * 'unix://' scheme is treated as unix socket path and parsed.
    * Multiple query parameter values and blank values are considered error.
    * DB number specified as path and as query parameter is considered error.
    * Password specified in userinfo and as query parameter is
      considered error.

    The userinfo password is percent-decoded (RFC 3986).  This lets
    passwords contain reserved characters such as '@' or ':'.

    :returns: ``(address, options)``.  ``address`` is the socket path (str)
        for an empty or ``unix`` scheme.  Otherwise it is a ``(host, port)``
        tuple that defaults to ``('localhost', 6379)``.  ``options`` is the
        dict built by ``_parse_uri_options``; the ``rediss`` scheme forces
        ``options['ssl'] = True``.
    :raises AssertionError: if any rule above is violated.  These checks
        use explicit ``raise`` statements rather than ``assert``, so they
        are not stripped under ``python -O``.
    """
    r = urlparse(url)
    if r.scheme not in ('', 'redis', 'rediss', 'unix'):
        raise AssertionError(("Unsupported URI scheme", r.scheme))
    if r.scheme == '':
        return url, {}

    query = {}
    for p, v in parse_qsl(r.query, keep_blank_values=True):
        if p in query:
            raise AssertionError(
                ("Multiple parameters are not allowed", p, v))
        if not v:
            raise AssertionError(("Empty parameters are not allowed", p, v))
        query[p] = v

    # urlparse leaves userinfo percent-encoded; decode it here.
    password = r.password
    if password:
        password = unquote(password)

    if r.scheme == 'unix':
        if not r.path:
            raise AssertionError(("Empty path is not allowed", url))
        if r.netloc:
            raise AssertionError(
                ("Netlocation is not allowed for unix scheme", r.netloc))
        return r.path, _parse_uri_options(query, '', password)

    address = (r.hostname or 'localhost', int(r.port or 6379))
    # The path (minus its leading '/') carries the DB number.
    path = r.path
    if path.startswith('/'):
        path = path[1:]
    options = _parse_uri_options(query, path, password)
    if r.scheme == 'rediss':
        options['ssl'] = True
    return address, options
def _parse_uri_options(params, path, password):
def parse_db_num(val):
if not val:
return
assert val.isdecimal(), ("Invalid decimal integer", val)
assert val == '0' or not val.startswith('0'), (
"Expected integer without leading zeroes", val)
return int(val)
options = {}
db1 = parse_db_num(path)
db2 = parse_db_num(params.get('db'))
assert db1 is None or db2 is None, (
"Single DB value expected, got path and query", db1, db2)
if db1 is not None:
options['db'] = db1
elif db2 is not None:
options['db'] = db2
password2 = params.get('password')
assert not password or not password2, (
"Single password value is expected, got in net location and query")
if password:
options['password'] = password
elif password2:
options['password'] = password2
if 'encoding' in params:
options['encoding'] = params['encoding']
if 'ssl' in params:
assert params['ssl'] in ('true', 'false'), (
"Expected 'ssl' param to be 'true' or 'false' only",
params['ssl'])
options['ssl'] = params['ssl'] == 'true'
if 'timeout' in params:
options['timeout'] = float(params['timeout'])
return options
class CloseEvent:
    """Two-stage close flag that runs an async ``on_close`` hook once.

    ``set()`` starts closing.  It schedules ``on_close()`` as a task and
    marks closing as initiated.  When that task finishes, the close is
    marked done and the hook reference is dropped.
    """

    def __init__(self, on_close):
        self._on_close = on_close
        # Set as soon as closing has been requested.
        self._close_init = asyncio.Event()
        # Set by the done-callback of the on_close task.
        self._close_done = asyncio.Event()

    async def wait(self):
        """Wait until closing was requested and the hook task finished."""
        for stage in (self._close_init, self._close_done):
            await stage.wait()

    def is_set(self):
        """Return True once closing has been requested, finished or not."""
        return any(flag.is_set()
                   for flag in (self._close_done, self._close_init))

    def set(self):
        """Request closing; calls after the first one do nothing."""
        if not self._close_init.is_set():
            closer = asyncio.ensure_future(self._on_close())
            closer.add_done_callback(self._cleanup)
            self._close_init.set()

    def _cleanup(self, task):
        # Done-callback of the on_close task: release the hook, then wake
        # anyone waiting on the second stage.
        self._on_close = None
        self._close_done.set()
# Alias for asyncio.get_running_loop where the interpreter provides it,
# otherwise asyncio.get_event_loop.
try:
    get_event_loop = asyncio.get_running_loop
except AttributeError:
    get_event_loop = asyncio.get_event_loop