Source code for nucleon.amqp.connection

import errno
import struct

from contextlib import contextmanager

import gevent
from gevent.event import AsyncResult
from gevent.queue import Queue
from gevent.lock import RLock
from gevent import socket

from .urls import parse_amqp_url

from .exceptions import ConnectionError, AMQPHardError, ChannelError
from . import spec
from .buffers import BufferedReader, DecodeBuffer
from .channels import StartChannel, MessageChannel


FRAME_HEADER = struct.Struct('!BHI')


STATE_DISCONNECTED = 0
STATE_CONNECTING = 1
STATE_CONNECTED = 2
STATE_DISCONNECTING = 3


[docs]class Connection(object): """A connection to an AMQP server. This class deals with establishing and reconnecting errors, and routes messages received off the wire to handlers registered with the corresponding channel. AMQP sends connection-level errors (connection.close) that cause the connection to close; to support this we dispatch such errors to all channels. """ frame_max = 131072 # adjusted by Tune frame channel_max = 65535 # adjusted by Tune method MAX_SEND_QUEUE = 32 # frames def __init__(self, amqp_url='amqp:///', debug=True): self.channel_id = 0 self.channels = {} self.channels_lock = RLock() self.queue = None self.state = STATE_DISCONNECTED self.debug = debug (self.username, self.password, self.vhost, self.host, self.port) = \ parse_amqp_url(str(amqp_url))
[docs] def allocate_channel(self): """Create a new channel.""" with self.channels_lock: for i in xrange(self.channel_max): self.channel_id = self.channel_id % (self.channel_max - 1) + 1 if self.channel_id not in self.channels: break else: raise ChannelError("No available channels!") id = self.channel_id chan = MessageChannel(self, id) self.channels[id] = chan chan.channel_open() return chan
def _remove_channel(self, id): """Remove a channel (presumably because it has closed.)""" with self.channels_lock: del(self.channels[id]) @contextmanager
[docs] def channel(self): """Acquire a channel and later release it.""" channel = self.allocate_channel() try: yield channel channel.check_returned() finally: if channel.connection and self.state == STATE_CONNECTED: channel.channel_close(reply_code=200)
[docs] def connect(self): self.connected_event = AsyncResult() self._connect() v = self.connected_event.wait() # Block until the connection is properly ready # FFS, AsyncResult.set_exception doesn't work in gevent 1.0b4 # so we just use normal setting, but with exception types if isinstance(v, Exception): raise v
def _connect(self): """Connect to the remote server and start reader/writer greenlets.""" self.state = STATE_CONNECTING try: try: addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET6, socket.SOCK_STREAM) except socket.gaierror: addrinfo = socket.getaddrinfo(self.host, self.port, socket.AF_INET, socket.SOCK_STREAM) (family, socktype, proto, canonname, sockaddr) = addrinfo[0] self.sock = socket.socket(family, socktype, proto) #set_ridiculously_high_buffers(self.sd) self.sock.connect(sockaddr) except: self.state = STATE_DISCONNECTED raise # Set up connection greenlets self.queue = Queue(self.MAX_SEND_QUEUE) self.reader = gevent.spawn(self.do_read) self.writer = gevent.spawn(self.do_write) self.reader.link(lambda reader: self.writer.kill()) self.writer.link(lambda writer: self.reader.kill())
[docs] def on_error(self, exc): """Dispatch a connection error to all channels.""" for id, channel in self.channels.items(): channel.on_error(exc)
def _on_connect(self): """Called when the connection is fully open.""" self.state = STATE_CONNECTED self.connected_event.set("Connected!") def _on_abnormal_disconnect(self, exc): """Called when the connection has been abnormally disconnected.""" self.state = STATE_DISCONNECTED self.on_error(exc) while True: try: self._connect() except Exception: gevent.sleep(6) else: break def _on_normal_disconnect(self): """Called when the connection has been abnormally disconnected.""" self.state = STATE_DISCONNECTED self.on_error(ConnectionError("Connection closed.")) self.queue.put(None)
[docs] def do_read(self): """Run a reader greenlet. This method will read a preamble then loop forever reading frames off the wire and dispatch them to channels. """ try: reader = BufferedReader(self.sock) # preamble = reader.read(8) # if preamble != spec.PREAMBLE: # raise ConnectionError("Incorrect protocol header from AMQP server") while self.state != STATE_DISCONNECTED: frame_header = reader.read(FRAME_HEADER.size) frame_type, channel, size = FRAME_HEADER.unpack(frame_header) payload = reader.read(size + 1) assert payload[-1] == '\xCE' if self.debug: self._debug_print('s->c', frame_header + payload) buffer = DecodeBuffer(payload) if frame_type == 0x01: # Method frame method_id, = buffer.read('!I') frame = spec.METHODS[method_id].decode(buffer) self.inbound_method(channel, frame) elif frame_type == 0x02: # header frame class_id, body_size = buffer.read('!HxxQ') props = spec.PROPS[class_id](buffer) self.inbound_props(channel, body_size, props) elif frame_type == 0x03: # body frame self.inbound_body(channel, payload[:-1]) elif frame_type in [0x04, 0x08]: # Heartbeat frame # # Catch it as both 0x04 and 0x08 - see # http://www.rabbitmq.com/amqp-0-9-1-errata.html#section_29 pass else: raise ConnectionError("Unknown frame type") except Exception as e: self.connected_event.set(e) if self.state in [STATE_CONNECTED, STATE_CONNECTING]: self._on_abnormal_disconnect(e) else: self.state = STATE_DISCONNECTED
[docs] def inbound_method(self, channel, frame): """Dispatch an inbound method.""" try: c = self.channels[channel] except KeyError: if frame.name == 'connection.start': c = StartChannel(self, channel) self.channels[channel] = c else: return c._on_method(frame)
[docs] def inbound_props(self, channel, body_size, props): """Dispatch an inbound properties frame.""" try: c = self.channels[channel] except KeyError: return c._on_headers(body_size, props)
[docs] def inbound_body(self, channel, payload): """Dispatch an inbound body frame.""" try: c = self.channels[channel] except KeyError: return c._on_body(payload)
[docs] def do_write(self): """Run a writer greenlet. This greenlet will loop until the connection closes, writing frames from the queue. """ # Write the protocol header self.sock.sendall(spec.PREAMBLE) # Enter a send loop while self.state != STATE_DISCONNECTED: msg = self.queue.get() if msg is None: break if self.debug: self._debug_print('s<-c', msg) self.sock.sendall(msg)
def _debug_print(self, direction, msg): try: # Print method, for debugging type, channel, size = FRAME_HEADER.unpack_from(msg) if type == 1: method_id = struct.unpack_from('!I', msg, FRAME_HEADER.size)[0] print direction, spec.METHODS[method_id].name else: print direction, { 2: '[headers %d bytes]', 3: '[payload %d bytes]', 4: '[heartbeat %d bytes]', }[type] % size except Exception: import traceback traceback.print_exc() def _send_frames(self, channel, frames): """Send a sequence of frames on channel. Each frame will be put onto a queue for the writer to write, which could cause the calling greenlet to block if the queue is full. This should cause large outgoing messages to be spliced together so that no caller is starved of service while a large message is sending. """ assert channel in self.channels for type, payload in frames: fdata = ''.join([ FRAME_HEADER.pack(type, channel, len(payload)), payload, '\xCE' ]) self.queue.put(fdata) def _tune(self, frame_max, channel_max, heartbeat=0): """Adjust connection parameters. Called in response to negotiation with the server. """ frame_max = frame_max if frame_max != 0 else 2**19 # limit the maximum frame size, to ensure messages are multiplexed self.frame_max = min(131072, frame_max) self.channel_max = channel_max if channel_max > 0 else 65535 # TODO: do heartbeat
[docs] def close(self): """TODO: shut down cleanly.""" if self.state in [STATE_CONNECTED, STATE_CONNECTING]: self.state = STATE_DISCONNECTING self.channels[0].connection_close() self.writer.join(timeout=2)
def __del__(self): self.close()
def set_ridiculously_high_buffers(sd): ''' Set large tcp/ip buffers kernel. Let's move the complexity to the operating system! That's a wonderful idea! ''' for flag in [socket.SO_SNDBUF, socket.SO_RCVBUF]: for i in range(10): bef = sd.getsockopt(socket.SOL_SOCKET, flag) try: sd.setsockopt(socket.SOL_SOCKET, flag, bef*2) except socket.error: break aft = sd.getsockopt(socket.SOL_SOCKET, flag) if aft <= bef or aft >= 1024*1024: break