319 lines
9.9 KiB
Python
319 lines
9.9 KiB
Python
import asyncio
|
|
import functools
|
|
|
|
from ..abc import AbcPool
|
|
from ..errors import (
|
|
RedisError,
|
|
PipelineError,
|
|
MultiExecError,
|
|
ConnectionClosedError,
|
|
)
|
|
from ..util import (
|
|
wait_ok,
|
|
_set_exception,
|
|
get_event_loop,
|
|
)
|
|
|
|
|
|
class TransactionsCommandsMixin:
    """Transaction commands mixin.

    For commands details see: https://redis.io/commands/#transactions

    Transactions HOWTO:

    >>> tr = redis.multi_exec()
    >>> result_future1 = tr.incr('foo')
    >>> result_future2 = tr.incr('bar')
    >>> try:
    ...     result = await tr.execute()
    ... except MultiExecError:
    ...     pass    # check what happened
    >>> result1 = await result_future1
    >>> result2 = await result_future2
    >>> assert result == [result1, result2]
    """

    def unwatch(self):
        """Forget about all watched keys.

        Sends ``UNWATCH`` through ``self._pool_or_conn`` and returns the
        awaitable that ``wait_ok`` builds from the reply future.
        """
        fut = self._pool_or_conn.execute(b'UNWATCH')
        return wait_ok(fut)

    def watch(self, key, *keys):
        """Watch the given keys to determine execution of the MULTI/EXEC block.

        :param key: first key to watch (at least one key is required).
        :param keys: any additional keys to watch.

        Sends ``WATCH key [key ...]`` through ``self._pool_or_conn`` and
        returns the awaitable that ``wait_ok`` builds from the reply future.
        """
        # FIXME: when ``self._pool_or_conn`` is a pool, WATCH may be sent
        # through one connection and the later 'multi/exec' through another.
        # Possible fix:
        #   "Remember" the connection that was used for the 'watch' command
        #   and then send 'multi / exec / discard' through it.
        fut = self._pool_or_conn.execute(b'WATCH', key, *keys)
        return wait_ok(fut)

    def multi_exec(self):
        """Returns MULTI/EXEC pipeline wrapper.

        The wrapper is bound to ``self._pool_or_conn``. It uses this object's
        class as the commands factory, so the same command methods are
        available on it.

        Usage:

        >>> tr = redis.multi_exec()
        >>> fut1 = tr.incr('foo')   # NO `await` as it will block forever!
        >>> fut2 = tr.incr('bar')
        >>> result = await tr.execute()
        >>> result
        [1, 1]
        >>> await asyncio.gather(fut1, fut2)
        [1, 1]
        """
        return MultiExec(self._pool_or_conn, self.__class__)

    def pipeline(self):
        """Returns :class:`Pipeline` object to execute bulk of commands.

        It is provided for convenience.
        Commands can be pipelined without it.

        Example:

        >>> pipe = redis.pipeline()
        >>> fut1 = pipe.incr('foo')   # NO `await` as it will block forever!
        >>> fut2 = pipe.incr('bar')
        >>> result = await pipe.execute()
        >>> result
        [1, 1]
        >>> await asyncio.gather(fut1, fut2)
        [1, 1]
        >>> #
        >>> # The same can be done without pipeline:
        >>> #
        >>> fut1 = redis.incr('foo')    # the 'INCR foo' command is already sent
        >>> fut2 = redis.incr('bar')
        >>> await asyncio.gather(fut1, fut2)
        [2, 2]
        """
        return Pipeline(self._pool_or_conn, self.__class__)
|
|
|
|
|
|
class _RedisBuffer:
|
|
|
|
def __init__(self, pipeline, *, loop=None):
|
|
# TODO: deprecation note
|
|
# if loop is None:
|
|
# loop = asyncio.get_event_loop()
|
|
self._pipeline = pipeline
|
|
|
|
def execute(self, cmd, *args, **kw):
|
|
fut = get_event_loop().create_future()
|
|
self._pipeline.append((fut, cmd, args, kw))
|
|
return fut
|
|
|
|
# TODO: add here or remove in connection methods like `select`, `auth` etc
|
|
|
|
|
|
class Pipeline:
    """Commands pipeline.

    Commands called on a pipeline are not sent right away. They are recorded
    in an internal buffer and sent together over one connection when
    :meth:`execute` is awaited. Each command call returns a future/task that
    is resolved while :meth:`execute` runs.

    Usage:

    >>> pipe = redis.pipeline()
    >>> fut1 = pipe.incr('foo')
    >>> fut2 = pipe.incr('bar')
    >>> await pipe.execute()
    [1, 1]
    >>> await fut1
    1
    >>> await fut2
    1
    """

    error_class = PipelineError

    def __init__(self, pool_or_connection, commands_factory=lambda conn: conn,
                 *, loop=None):
        # ``loop`` is accepted for backward compatibility and is ignored.
        # TODO: deprecation note
        self._pool_or_conn = pool_or_connection
        # (waiter, cmd, args, kw) tuples appended by the _RedisBuffer.
        self._pipeline = []
        # Futures/tasks handed back to the caller, in call order.
        self._results = []
        self._buffer = _RedisBuffer(self._pipeline)
        # Command API built on top of the buffer instead of a real connection.
        self._redis = commands_factory(self._buffer)
        self._done = False

    def __getattr__(self, name):
        """Proxy attribute ``name`` from the buffered command API.

        Callable attributes are wrapped so that every call's future/task is
        recorded in ``self._results``. A call that raises synchronously is
        turned into an already-failed future instead of propagating.
        """
        assert not self._done, "Pipeline already executed. Create new one."
        attr = getattr(self._redis, name)
        if callable(attr):

            @functools.wraps(attr)
            def wrapper(*args, **kw):
                try:
                    task = asyncio.ensure_future(attr(*args, **kw))
                except Exception as exc:
                    task = get_event_loop().create_future()
                    task.set_exception(exc)
                self._results.append(task)
                return task
            return wrapper
        return attr

    async def execute(self, *, return_exceptions=False):
        """Execute all buffered commands.

        Any exception that is raised by any command is caught and
        raised later when processing results.

        Exceptions can also be returned in result if
        `return_exceptions` flag is set to True.

        :returns: list with one result (or exception) per recorded call.
        :raises AssertionError: if the pipeline was already executed.
        :raises PipelineError: (``self.error_class``) if any command failed
            and ``return_exceptions`` is false.
        """
        assert not self._done, "Pipeline already executed. Create new one."
        self._done = True

        if not self._pipeline:
            # Nothing was buffered; just collect the futures handed out.
            return await self._gather_result(return_exceptions)
        if isinstance(self._pool_or_conn, AbcPool):
            # Hold a single pooled connection so all commands share it.
            async with self._pool_or_conn.get() as conn:
                return await self._do_execute(
                    conn, return_exceptions=return_exceptions)
        return await self._do_execute(
            self._pool_or_conn, return_exceptions=return_exceptions)

    async def _do_execute(self, conn, *, return_exceptions=False):
        """Send every buffered command on ``conn`` and collect results."""
        # Replies are routed to the waiters by ``_check_result``; errors are
        # collected here and reported by ``_gather_result``.
        await asyncio.gather(*self._send_pipeline(conn),
                             return_exceptions=True)
        return await self._gather_result(return_exceptions)

    async def _gather_result(self, return_exceptions):
        """Await each recorded call in order; raise or return the errors."""
        errors = []
        results = []
        for fut in self._results:
            try:
                res = await fut
                results.append(res)
            except Exception as exc:
                errors.append(exc)
                results.append(exc)
        if errors and not return_exceptions:
            raise self.error_class(errors)
        return results

    def _send_pipeline(self, conn):
        """Write buffered commands to ``conn``; yield their reply futures.

        Must be fully consumed: the ``conn._buffered()`` block stays open
        until the generator is exhausted.
        """
        with conn._buffered():
            for fut, cmd, args, kw in self._pipeline:
                try:
                    result_fut = conn.execute(cmd, *args, **kw)
                except Exception as exc:
                    # The caller may already have cancelled this command's
                    # future. Setting an exception on it would raise
                    # InvalidStateError and abort the whole pipeline.
                    if not fut.done():
                        fut.set_exception(exc)
                else:
                    result_fut.add_done_callback(
                        functools.partial(self._check_result, waiter=fut))
                    yield result_fut

    def _check_result(self, fut, waiter):
        """Copy the outcome of reply future ``fut`` onto ``waiter``."""
        if waiter.done():
            # The caller cancelled this command's future; nothing to deliver
            # (set_result/set_exception would raise InvalidStateError).
            return
        if fut.cancelled():
            waiter.cancel()
        elif fut.exception() is not None:
            waiter.set_exception(fut.exception())
        else:
            waiter.set_result(fut.result())
|
|
|
|
|
|
class MultiExec(Pipeline):
    """Multi/Exec pipeline wrapper.

    Like :class:`Pipeline`, but the buffered commands are wrapped in
    ``MULTI`` ... ``EXEC``. Each command's future is resolved from the
    corresponding element of the ``EXEC`` reply.

    Usage:

    >>> tr = redis.multi_exec()
    >>> f1 = tr.incr('foo')
    >>> f2 = tr.incr('bar')
    >>> # A)
    >>> await tr.execute()
    >>> res1 = await f1
    >>> res2 = await f2
    >>> # or B)
    >>> res1, res2 = await tr.execute()

    and of course try/except:

    >>> tr = redis.multi_exec()
    >>> f1 = tr.incr('1')   # won't raise any exception (why?)
    >>> try:
    ...     res = await tr.execute()
    ... except RedisError:
    ...     pass
    >>> assert f1.done()
    >>> assert f1.result() is res

    >>> tr = redis.multi_exec()
    >>> wait_ok_coro = tr.mset('1')
    >>> try:
    ...     ok1 = await tr.execute()
    ... except RedisError:
    ...     pass    # handle it
    >>> ok2 = await wait_ok_coro
    >>> # for this to work `wait_ok_coro` must be wrapped in Future
    """

    error_class = MultiExecError

    async def _do_execute(self, conn, *, return_exceptions=False):
        """Send MULTI, the buffered commands and EXEC on ``conn``.

        Flow:
        1. Awaits the MULTI reply and every per-command reply.
        2. If ``conn`` is still open, awaits the EXEC reply and resolves the
           QUEUED waiters from it.
        3. Finally returns (or raises) whatever ``_gather_result`` produces.

        If ``conn`` is closed, every pending future is failed with the last
        error, or ``ConnectionClosedError`` when there was none.
        """
        # ``_check_result`` appends here each waiter whose command was
        # answered with QUEUED. ``_resolve_waiters`` later zips this list
        # positionally with the EXEC reply.
        self._waiters = waiters = []
        # MULTI, every buffered command and EXEC are written inside one
        # ``conn._buffered()`` block.
        with conn._buffered():
            multi = conn.execute('MULTI')
            coros = list(self._send_pipeline(conn))
            exec_ = conn.execute('EXEC')
        gather = asyncio.gather(multi, *coros,
                                return_exceptions=True)
        last_error = None
        try:
            # shield: cancelling this coroutine does not cancel ``gather``.
            await asyncio.shield(gather)
        except asyncio.CancelledError:
            # Cancelled while waiting. Wait for the shielded replies anyway;
            # the cancellation is not re-raised here.
            await gather
        except Exception as err:
            last_error = err
            raise
        finally:
            if conn.closed:
                # The EXEC reply is not awaited in this branch; fail every
                # pending future instead.
                if last_error is None:
                    last_error = ConnectionClosedError()
                for fut in waiters:
                    _set_exception(fut, last_error)
                    # fut.cancel()
                for fut in self._results:
                    if not fut.done():
                        fut.set_exception(last_error)
                        # fut.cancel()
            else:
                try:
                    results = await exec_
                except RedisError as err:
                    # EXEC itself failed: every queued command fails with it.
                    for fut in waiters:
                        fut.set_exception(err)
                else:
                    assert len(results) == len(waiters), (
                        "Results does not match waiters", results, waiters)
                    self._resolve_waiters(results, return_exceptions)
            # NOTE(review): this ``return`` sits inside ``finally``. It
            # discards any exception still propagating from the ``try``,
            # including the ``raise`` in the ``except Exception`` branch.
            # The caller therefore sees whatever ``_gather_result`` returns or
            # raises. Python 3.14 warns about return-in-finally (PEP 765);
            # confirm this placement is intended.
            return (await self._gather_result(return_exceptions))

    def _resolve_waiters(self, results, return_exceptions):
        """Resolve queued waiters from the EXEC reply list ``results``.

        A reply element that is a ``RedisError`` becomes that waiter's
        exception; any other element becomes its result.

        :raises MultiExecError: if any element was an error and
            ``return_exceptions`` is false.
        """
        errors = []
        for val, fut in zip(results, self._waiters):
            if isinstance(val, RedisError):
                fut.set_exception(val)
                errors.append(val)
            else:
                fut.set_result(val)
        if errors and not return_exceptions:
            raise MultiExecError(errors)

    def _check_result(self, fut, waiter):
        """Handle one command's immediate reply inside MULTI.

        A QUEUED reply only registers ``waiter`` in ``self._waiters``; the
        real value comes later from EXEC via ``_resolve_waiters``.
        """
        assert waiter not in self._waiters, (fut, waiter, self._waiters)
        assert not waiter.done(), waiter
        if fut.cancelled():  # await gather was cancelled
            waiter.cancel()
        elif fut.exception():  # server replied with error
            waiter.set_exception(fut.exception())
        elif fut.result() in {b'QUEUED', 'QUEUED'}:
            # got result, it should be QUEUED
            self._waiters.append(waiter)
        # NOTE(review): any other reply leaves ``waiter`` pending and out of
        # ``self._waiters``, so nothing in this class resolves it.
|