import logging
from functools import partial
from collections import deque
from cStringIO import StringIO
import gevent
from gevent.queue import Queue
from gevent.event import AsyncResult
from .message import Message
from .encoding import encode_message
from . import spec
from . import exceptions
# Logger named after this module.
log = logging.getLogger(__name__)
class EventRegister(object):
    """Track which callbacks should fire for incoming AMQP methods.

    Two kinds of registration are kept:

    * one-shot waiters (``wait_event``): each fires for the first method it
      matches and is then discarded -- used for synchronous request/response
      exchanges;
    * permanent handlers (``set_handler``): fire every time their method
      arrives, e.g. for connection or channel close.
    """

    def __init__(self):
        # method name -> callback fired on every occurrence of that method
        self.permanent = {}
        # FIFO of (frozenset of method names, on_success, on_error) waiters
        self.once = deque()

    def set_handler(self, method, callback):
        """Register ``callback`` as the permanent handler for ``method``.

        Any previously registered permanent handler for ``method`` is
        replaced.
        """
        self.permanent[method] = callback

    def unset_handler(self, method):
        """Remove the permanent handler for ``method``.

        Raises KeyError if no permanent handler is registered for it.
        """
        self.permanent.pop(method)

    def wait_event(self, methods, on_success, on_error):
        """Queue a one-shot waiter listening for any name in ``methods``."""
        waiter = (frozenset(methods), on_success, on_error)
        self.once.append(waiter)

    def cancel_event(self, methods, on_success, on_error):
        """Withdraw a waiter previously queued with ``wait_event``.

        Does nothing if no equal waiter is queued.
        """
        waiter = (frozenset(methods), on_success, on_error)
        try:
            self.once.remove(waiter)
        except ValueError:
            # Already consumed or never registered.
            pass

    def get_error_handlers(self):
        """Return the error callbacks of all pending one-shot waiters.

        All pending waiters are discarded; callbacks are returned in the
        order the waiters were registered.
        """
        error_callbacks = []
        while self.once:
            error_callbacks.append(self.once.popleft()[2])
        return error_callbacks

    def get(self, method):
        """Return the callbacks that should fire for ``method``.

        The oldest one-shot waiter listening for ``method`` (if any) is
        consumed and its success callback comes first; the permanent handler
        for ``method`` (if any) follows. At most two callbacks are returned.
        """
        callbacks = []
        for position, (names, on_success, _) in enumerate(self.once):
            if method in names:
                # Safe: we stop iterating immediately after mutating.
                del self.once[position]
                callbacks.append(on_success)
                break
        if method in self.permanent:
            callbacks.append(self.permanent[method])
        return callbacks
class Channel(spec.FrameWriter):
    """A communication channel.

    Multiple channels are multiplexed onto the same connection. Incoming
    frames are mapped to callbacks through an EventRegister; those callbacks
    are queued and executed, in order, by a dispatcher greenlet owned by the
    channel.
    """

    def __init__(self, connection, id):
        self.connection = connection
        self.id = id
        # Set once the channel has failed; subsequent sends re-raise it.
        self.exc = None
        # Method frame of a message whose header/body frames are still
        # arriving.
        self._method = None
        # Partial-content state, filled in by _on_headers/_on_body.
        self._headers = None
        self._to_read = 0
        self._body = None
        self.listeners = EventRegister()
        # (callback, args) pairs waiting to be run by the dispatcher.
        self.queue = Queue()
        self.listeners.set_handler('channel.close', self.on_channel_close)
        self.listeners.set_handler('channel.close-ok', self.on_channel_close)
        self.listeners.set_handler('connection.close', self.on_connection_close)
        self.listeners.set_handler(
            'connection.close-ok', self.on_connection_close
        )
        # Start an initial dispatcher
        self.replace_dispatcher()

    def on_connection_close(self, frame):
        """Handle a connection.close or connection.close-ok frame.

        The frame is converted to an exception and handed to the connection.
        """
        exc = exceptions.exception_from_frame(frame)
        # Let the connection class dispatch this to all channels
        self.connection.on_error(exc)

    def on_error(self, exc):
        """Handle an error that is causing this channel to close.

        Detaches the channel from its connection, records ``exc`` so later
        sends raise it, queues every pending synchronous waiter's error
        callback with ``exc``, then asks the dispatcher to stop once those
        callbacks have run.
        """
        self.connection._remove_channel(self.id)
        self.exc = exc
        self.connection = None
        for handler in self.listeners.get_error_handlers():
            self.queue.put((handler, (exc,)))
        self.stop_dispatcher()

    def on_channel_close(self, frame):
        """Handle a received channel.close or channel.close-ok method.

        A received channel.close is acknowledged with channel.close-ok
        before the channel is torn down via on_error.
        """
        exc = exceptions.exception_from_frame(frame)
        if frame.name == 'channel.close':
            self.channel_close_ok()
        self.on_error(exc)

    def _send(self, frame):
        """Send one frame over the channel.

        Raises the stored channel error if the channel has already failed.
        """
        if self.exc is not None:
            raise self.exc
        self.connection._send_frames(self.id, frame.encode())

    def _send_message(self, frame, headers, payload):
        """Send method, header and body frames over the channel.

        Frames are produced by encode_message using the connection's
        frame_max. Raises the stored channel error if the channel has
        already failed.
        """
        if self.exc is not None:
            raise self.exc
        fs = encode_message(frame, headers, payload, self.connection.frame_max)
        self.connection._send_frames(self.id, fs)

    def _on_method(self, frame):
        """Called when the channel has received a method frame.

        Content-bearing methods are held until their header and body frames
        arrive; all others are dispatched immediately.
        """
        if frame.has_content:
            self._method = frame
        else:
            self.dispatch(frame.name, frame)

    def _on_headers(self, size, props):
        """Called when the channel has received a headers frame.

        ``size`` is the number of body bytes still to be read.
        """
        self._headers = props
        self._to_read = size
        self._body = StringIO()
        # An empty body has no body frames, so it may be complete already.
        self._on_content_receive()

    def _on_body(self, payload):
        """Called when the channel has received a body frame."""
        self._body.write(payload)
        self._to_read -= len(payload)
        self._on_content_receive()

    def _on_content_receive(self):
        """Check whether a full message has been read, and if so, dispatch it.

        No payload frame is sent if the body was empty.
        """
        if self._to_read <= 0:
            if self._headers is None or self._method is None:
                return
            m = Message(self,
                self._method,
                self._headers,
                self._body.getvalue()
            )
            self.dispatch(self._method.name, m)
            self._headers = None
            self._method = None
            self._body = None

    def _call_sync(self, method, responses, *args, **kwargs):
        """Call a method, using AsyncResult to wait on the response.

        ``responses`` are the method names that complete the call; the
        first one received is returned, and a channel error is raised.
        """
        result = AsyncResult()
        self.listeners.wait_event(responses, result.set, result.set_exception)
        try:
            method(*args, **kwargs)
        except:
            # Sending failed: withdraw the waiter, then re-raise unchanged.
            self.listeners.cancel_event(
                responses, result.set, result.set_exception
            )
            raise
        self.must_now_block()
        return result.get()

    def dispatch(self, method, *args):
        """Fire the listener for a given method.

        This enqueues all currently registered listeners for eventual dispatch
        by the dispatcher greenlet.

        This means that if the dispatcher greenlet abdicates, the new
        dispatcher can pick up with the next listener to be called.
        """
        callbacks = self.listeners.get(method)
        if not callbacks:
            log.warning("Unhandled method %s", method)
            return
        for callback in callbacks:
            self.queue.put((callback, args))

    def stop_dispatcher(self):
        """Tell the dispatcher to stop after processing all current events"""
        self.queue.put((None, None))

    def start_dispatcher(self):
        """Dispatch callbacks, intended to be run as a separate greenlet.

        Loops until the stop sentinel is received or the current greenlet is
        no longer the dispatcher. In the latter case, don't execute any more
        callbacks, as they will be executed by the real dispatcher.
        """
        while True:
            callback, args = self.queue.get()
            if callback is None:  # None tells the dispatcher to stop
                return
            try:
                callback(*args)
            except Exception:
                # One failing callback must not kill the dispatcher.
                log.exception(
                    "Error in callback %r on channel %s", callback, self.id
                )
            if not self.current_is_dispatcher():
                break

    def must_now_block(self):
        """Signal that something is about to block on this connection.

        If the current greenlet is the dispatcher, we stop being so, and spawn
        a new dispatcher to provide us with the result we need to unblock
        ourselves.
        """
        if self.current_is_dispatcher():
            self.replace_dispatcher()

    def current_is_dispatcher(self):
        """Return True if the calling greenlet is the dispatcher."""
        return gevent.getcurrent() is self.dispatcher

    def replace_dispatcher(self):
        """Spawn a new dispatcher greenlet, replacing the current one.

        This automatically triggers the old dispatcher to stop dispatching
        after processing the current callback.
        """
        self.dispatcher = gevent.spawn(self.start_dispatcher)
class StartChannel(Channel):
    """A channel to handle connection negotiation.

    This is the initial control channel opened by the server; we pre-register
    events so as to handle connection start.

    From the AMQP spec:

    * The server responds with its protocol version and other properties,
      including a list of the security mechanisms that it supports (the Start
      method).
    * The client selects a security mechanism (Start-Ok).
    * The server starts the authentication process, which uses the SASL
      challenge-response model. It sends the client a challenge (Secure).
    * The client sends an authentication response (Secure-Ok). For example
      using the "plain" mechanism, the response consist of a login name and
      password.
    * The server repeats the challenge (Secure) or moves to negotiation,
      sending a set of parameters such as maximum frame size (Tune).
    * The client accepts or lowers these parameters (Tune-Ok).
    * The client formally opens the connection and selects a virtual host
      (Open).
    * The server confirms that the virtual host is a valid choice (Open-Ok).
    """

    def __init__(self, connection, id):
        super(StartChannel, self).__init__(connection, id)
        self.listeners.set_handler('connection.start', self.on_start)
        self.listeners.set_handler('connection.tune', self.on_tune)
        # Overrides the base class's close-ok handling for this channel.
        self.listeners.set_handler('connection.close-ok', self.on_close_ok)

    def on_start(self, frame):
        """Handle connection.start by replying with connection.start-ok.

        Only the PLAIN mechanism is supported; credentials are taken from the
        connection's ``username`` and ``password``. The
        ``consumer_cancel_notify`` capability is requested only if the server
        advertises it.

        Raises ValueError if the server does not offer PLAIN.
        """
        # TODO: support SASL authentication
        if 'PLAIN' not in frame.mechanisms.split():
            # Explicit check instead of assert so it survives ``python -O``.
            raise ValueError("Only PLAIN auth supported.")
        auth = '\0%s\0%s' % (
            self.connection.username, self.connection.password
        )
        scapa = frame.server_properties.get('capabilities', {})
        ccapa = {}
        if scapa.get('consumer_cancel_notify'):
            ccapa['consumer_cancel_notify'] = True
        self.connection_start_ok(
            {'product': 'nucleon.amqp', 'capabilities': ccapa},
            'PLAIN',
            auth,
            'en_US'
        )

    def on_close_ok(self, frame):
        """Handle connection.close-ok by notifying the connection."""
        self.connection._on_normal_disconnect()

    def on_tune(self, frame):
        """Handle the tune message.

        This message signals that we are allowed to open a virtual host.
        The server's frame_max and channel_max are accepted as-is, with the
        heartbeat argument sent as 0.
        """
        self.connection._tune(frame.frame_max, frame.channel_max)
        self.connection_tune_ok(frame.channel_max, frame.frame_max, 0)
        # open the connection
        self.connection_open(self.connection.vhost)
        self.connection._on_connect()
class MessageQueue(Queue):
    """A queue of consumed messages that can also receive exceptions.

    Items that are Exception instances are raised from ``get`` instead of
    being returned.
    """

    def __init__(self, channel, consumer_tag):
        self.consumer_tag = consumer_tag
        self.channel = channel
        super(MessageQueue, self).__init__()

    def get(self, block=True, timeout=None):
        """Remove and return the next message.

        For a blocking get the channel is told first, so that if the caller
        is the channel's dispatcher greenlet a replacement dispatcher is
        spawned to keep delivering messages.

        Raises the queued item instead of returning it if it is an Exception.
        """
        if block:
            self.channel.must_now_block()
        resp = super(MessageQueue, self).get(block=block, timeout=timeout)
        if isinstance(resp, Exception):
            raise resp
        return resp

    def get_nowait(self):
        """Return the next message without blocking.

        As with gevent's Queue, an empty queue raises ``gevent.queue.Empty``.
        """
        return self.get(False)

    def cancel(self):
        """Cancel the consumer that feeds this queue."""
        self.channel.basic_cancel(self.consumer_tag)
class MessageChannel(Channel):
    """A channel that adds useful semantics for publishing and consuming messages.

    Semantics that are added:

    * Support for registering consumers to receive basic.deliver events.
      Consumers also receive errors and are automatically deregistered when
      basic.cancel-ok is received.
    * Support for the RabbitMQ extension confirm_select, which makes
      basic_publish block
    * Can check for messages returned with basic_return
    """

    def __init__(self, connection, id):
        super(MessageChannel, self).__init__(connection, id)
        # Counter used to generate unique consumer tags ('ct-1', 'ct-2', ...).
        self.consumer_id = 1
        # consumer tag -> callable invoked with each delivered message
        self.consumers = {}
        # Holds the most recent basic.return message, if any.
        self.returned = AsyncResult()
        self.listeners.set_handler('basic.deliver', self.on_deliver)
        self.listeners.set_handler('basic.return', self.on_basic_return)
        self.listeners.set_handler('basic.cancel-ok', self.on_cancel_ok)

    def on_deliver(self, message):
        """Called when a message is received.

        Dispatches the message to the registered consumer.
        """
        self.consumers[message.consumer_tag](message)

    def on_basic_return(self, msg):
        """When we receive a basic.return message, store it.

        The value can later be checked using .check_returned().
        """
        self.returned.set(msg)

    def check_returned(self):
        """Raise an error if a message has been returned.

        This also clears the returned frame, with the intention that each
        basic.return message may cause at most one MessageReturned error.
        """
        if self._method and self._method.name == 'basic.return':
            # A basic.return is partially received: wait for its content.
            self.must_now_block()
            returned = self.returned.get()
        else:
            try:
                returned = self.returned.get_nowait()
            except gevent.Timeout:
                return
        self.clear_returned()
        if returned:
            raise exceptions.return_exception_from_frame(returned)

    def clear_returned(self):
        """Discard any returned message."""
        if self.returned.ready():
            # we can only replace returned if it is ready - otherwise anything
            # that was blocked waiting would wait forever.
            self.returned = AsyncResult()

    def on_error(self, exc):
        """Override on_error, to pass error to all consumers.

        Each consumer is queued for the dispatcher with ``exc`` as its single
        argument (queue-based consumers then raise it from ``get``), before
        the base class tears the channel down.
        """
        for consumer in self.consumers.values():
            # The dispatcher calls callback(*args), so args must be a tuple.
            self.queue.put((consumer, (exc,)))
        super(MessageChannel, self).on_error(exc)

    def on_cancel_ok(self, frame):
        """Handle basic.cancel-ok, confirming that a consumer was cancelled.

        We can remove its consumer tag from the registered consumers. Unknown
        tags are ignored.
        """
        self.consumers.pop(frame.consumer_tag, None)

    def basic_consume(self, callback=None, **kwargs):
        """Register a consumer for an AMQP queue.

        If a callback is given, this will be called on any message and the
        underlying basic_consume result is returned. Otherwise a MessageQueue
        receiving the messages is returned.
        """
        tag = 'ct-%d' % self.consumer_id
        self.consumer_id += 1
        kwargs['consumer_tag'] = tag
        if callback is not None:
            self.consumers[tag] = callback
            return super(MessageChannel, self).basic_consume(**kwargs)
        queue = MessageQueue(self, tag)
        self.consumers[tag] = queue.put
        super(MessageChannel, self).basic_consume(**kwargs)
        return queue

    def basic_get(self, *args, **kwargs):
        """Wrap basic_get to return None if the response is basic.get-empty.

        This will be easier for users to check than testing whether a response
        is get-empty.
        """
        r = super(MessageChannel, self).basic_get(*args, **kwargs)
        return r if isinstance(r, Message) else None

    def confirm_select(self, nowait=False):
        """Turn on RabbitMQ's publisher acknowledgements.

        See http://www.rabbitmq.com/confirms.html

        There are two things that need to be done:

        * Swap basic_publish to a version that blocks waiting for the
          corresponding ack.
        * Support nowait (because this method blocks or not depending on that
          argument)
        """
        self.basic_publish = self.basic_publish_with_confirm
        if nowait:
            # With the nowait bit set no confirm.select-ok will be received,
            # so send the frame directly instead of waiting for a reply.
            self._send(spec.FrameConfirmSelect(1))
        else:
            super(MessageChannel, self).confirm_select(nowait=False)

    def basic_publish_with_confirm(self, exchange='', routing_key='',
                                   mandatory=False, immediate=False,
                                   headers=None, body=''):
        """Version of basic publish that blocks waiting for confirm.

        Returns the basic.ack frame. Raises PublishFailed on basic.nack and,
        if ``mandatory`` or ``immediate`` is set, the error produced by
        check_returned when the message was returned.
        """
        if headers is None:
            # Fresh dict per call; avoids a shared mutable default.
            headers = {}
        method = super(MessageChannel, self).basic_publish
        self.clear_returned()
        ret = self._call_sync(
            method, ('basic.ack', 'basic.nack'),
            exchange, routing_key, mandatory, immediate, headers, body
        )
        if ret.name == 'basic.nack':
            raise exceptions.PublishFailed(ret)
        if mandatory or immediate:
            self.check_returned()
        return ret