Source code for spartan.rpc.common

'''
Simple RPC library.

The :class:`.Client` and :class:`.Server` classes here work with
sockets which should implement the :class:`.Socket` interface.
'''
from cPickle import PickleError
import cProfile
import weakref
import sys
import threading
import time
import traceback
import types
import cStringIO
import cPickle

from .. import util
from spartan import cloudpickle
from spartan.node import Node


RPC_ID = xrange(1000000000).__iter__()

CLIENT_PENDING = weakref.WeakKeyDictionary()
SERVER_PENDING = weakref.WeakKeyDictionary()

DEFAULT_TIMEOUT = 100

def set_default_timeout(seconds):
  global DEFAULT_TIMEOUT
  DEFAULT_TIMEOUT = seconds
  util.log_info('Set default timeout to %s seconds.', DEFAULT_TIMEOUT)

class RPCException(Node):
  _members = ['py_exc']

class PickledData(Node):
  '''
  Helper class: indicates that this message has already been pickled,
  and should be sent as is, rather than being re-pickled.
  '''
  _members = ['data']

class SocketBase(object):
  def send(self, blob): pass
  def recv(self): pass
  def flush(self): pass
  def close(self): pass

  def register_handler(self, handler):
    'A handler() is called in response to read requests.'
    self._handler = handler

  # client
  def connect(self): pass

  # server
  def bind(self): pass

def capture_exception(exc_info=None):
  if exc_info is None:
    exc_info = sys.exc_info()

  tb = traceback.format_exception(*exc_info)
  return RPCException(py_exc=''.join(tb).replace('\n', '\n:: '))

class Group(tuple):
  pass

def pickle_to(obj, writer):
  try:
    pickled = cPickle.dumps(obj, -1)
    writer.write(pickled)
  except (PickleError, TypeError):
    #util.log_warn('CPICKLE failed: %s (%s)', sys.exc_info(), obj)
    writer.write(cloudpickle.dumps(obj, -1))

def pickle(obj):
  try:
    return cPickle.dumps(obj, -1)
  except (PickleError, TypeError):
    return cloudpickle.dumps(obj, -1)
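
# A minimal usage sketch (hypothetical helper, not part of this module):
# ``pickle`` silently falls back to cloudpickle for objects that cPickle
# rejects, such as lambdas and closures.
def _example_pickle_fallback():
  blob = pickle(lambda x: x + 1)  # cPickle raises, cloudpickle takes over
  fn = cPickle.loads(blob)        # still round-trips with the standard unpickler
  return fn(2)                    # == 3
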
NO_RESULT = object()

class PendingRequest(object):
  '''An outstanding RPC request.

  Call done(result) when a method is finished processing.
  '''
  def __init__(self, socket, rpc_id):
    self.socket = socket
    self.rpc_id = rpc_id
    self.created = time.time()
    self.finished = False
    self.result = NO_RESULT

    SERVER_PENDING[self] = 1

  def wait(self):
    while self.result is NO_RESULT:
      time.sleep(0.001)
    return self.result

  def done(self, result=None):
    # util.log_info('RPC finished in %.3f seconds' % (time.time() - self.created))
    self.finished = True
    self.result = result

    if self.socket is not None:
      header = { 'rpc_id' : self.rpc_id }
      # util.log_info('Finished %s, %s', self.socket.addr, self.rpc_id)
      w = cStringIO.StringIO()
      cPickle.dump(header, w, -1)
      pickle_to(result, w)
      self.socket.send(w.getvalue())

  def __del__(self):
    if not self.finished:
      util.log_error('PendingRequest.done() not called before destruction (likely due to an exception).')
      self.done(result=RPCException(py_exc='done() not called on request.'))
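
# A minimal sketch (hypothetical handler, not part of this module): server-side
# RPC handlers receive the deserialized request and a PendingRequest handle,
# and reply by calling handle.done().
def _example_echo_handler(req, handle):
  handle.done(result=req)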

class RemoteException(Exception):
  '''Wrap an uncaught remote exception.'''
  def __init__(self, tb):
    self._tb = tb

  def __repr__(self):
    return 'RemoteException:\n' + self._tb

  def __str__(self):
    return repr(self)

class FnFuture(object):
  '''Chain ``fn`` to the given future.

  ``self.wait()`` returns ``fn(future.wait())``.
  '''
  def __init__(self, future, fn):
    self.future = future
    self.fn = fn
    self.result = None

  def wait(self):
    result = self.future.wait()
    # util.log_info('Applying %s to %s', self.fn, result)
    self.result = self.fn(result)
    return self.result

class Future(object):
  def __init__(self, addr, rpc_id):
    self.addr = addr
    self.rpc_id = rpc_id
    self.have_result = False
    self.result = None
    self.finished_fn = None
    self._cv = threading.Condition()
    self._start = time.time()
    self._deadline = time.time() + DEFAULT_TIMEOUT

    CLIENT_PENDING[self] = 1

  def _set_result(self, result):
    self._cv.acquire()
    self.have_result = True

    if self.finished_fn is not None:
      self.result = self.finished_fn(result)
    else:
      self.result = result

    self._cv.notify()
    self._cv.release()

  def timed_out(self):
    return self._deadline < time.time()

  def wait(self):
    self._cv.acquire()
    while not self.have_result and not self.timed_out():
      # use a timeout so that ctrl-c works.
      self._cv.wait(timeout=0.1)
    self._cv.release()

    # util.log_info('Result from %s in %f seconds.', self.addr, time.time() - self._start)

    if not self.have_result and self.timed_out():
      util.log_info('timed out!')
      raise Exception('Timed out on remote call (%s %s)' % (self.addr, self.rpc_id))

    if isinstance(self.result, RPCException):
      raise RemoteException(self.result.py_exc)
    return self.result

  def on_finished(self, fn):
    return FnFuture(self, fn)
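
# A minimal sketch (hypothetical, not part of this module): chaining a
# post-processing function onto a Future.  ``client`` is assumed to be a
# connected Client whose server registers a ``compute`` method.
def _example_on_finished(client):
  f = client.compute({'x': 4})
  squared = f.on_finished(lambda resp: resp * resp)
  return squared.wait()  # waits for the RPC, then applies the lambda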

class DummyFuture(object):
  def __init__(self, base=None):
    self.v = base

  def wait(self):
    return self.v

DUMMY_FUTURE = DummyFuture()

class FutureGroup(list):
  def wait(self):
    return [f.wait() for f in self]

def wait_for_all(futures):
  return [f.wait() for f in futures]

class Server(object):
  def __init__(self, socket):
    self._socket = socket
    self._socket.register_handler(self.handle_read)
    self._methods = {}
    self._running = False
    self.register_method('diediedie', self._diediedie)

  def _diediedie(self, req, handle):
    # handlers are invoked as fn(req, handle); see handle_read() below.
    handle.done(None)
    self._socket.flush()
    self.shutdown()

  @property
  def addr(self):
    return self._socket.addr

  def serve(self):
    self.serve_nonblock()
    while self._running:
      time.sleep(0.1)

  def serve_nonblock(self):
    # util.log_info('Running.')
    self._running = True
    self._socket.bind()

  def register_object(self, obj):
    for name in dir(obj):
      if name.startswith('__'):
        continue
      fn = getattr(obj, name)
      if isinstance(fn, types.MethodType):
        self.register_method(name, fn)

  def register_method(self, name, fn):
    self._methods[name] = fn

  def handle_read(self, socket):
    #util.log_info('Reading...')

    data = socket.recv()
    reader = cStringIO.StringIO(data)
    header = cPickle.load(reader)

    #util.log_info('Reading: %s %s', self._socket.addr, header['rpc_id'])

    handle = PendingRequest(socket, header['rpc_id'])
    name = header['method']

    try:
      fn = self._methods[name]
    except KeyError:
      handle.done(capture_exception())
      return

    try:
      req = cPickle.load(reader)
      result = fn(req, handle)
      assert result is None, 'non-None result from RPC handler (use handle.done())'
    except:
      util.log_info('Caught exception in handler.', exc_info=1)
      handle.done(capture_exception())

  def shutdown(self):
    self._running = 0
    self._socket.close()
    del self._socket
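
# A minimal sketch (hypothetical, not part of this module): exposing an
# object's methods over RPC.  ``make_server_socket`` stands in for whichever
# SocketBase implementation is in use; it is not defined here.
def _example_serve(make_server_socket):
  class Calculator(object):
    def add(self, req, handle):
      # handlers follow the (req, handle) convention used by handle_read()
      handle.done(req['a'] + req['b'])

  server = Server(make_server_socket())
  server.register_object(Calculator())
  server.serve_nonblock()
  return server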

class ProxyMethod(object):
  def __init__(self, client, method):
    self.client = client
    self.socket = client._socket
    self.method = method

  def __call__(self, request=None):
    rpc_id = RPC_ID.next()

    header = { 'method' : self.method, 'rpc_id' : rpc_id }

    f = Future(self.socket.addr, rpc_id)
    self.client._futures[rpc_id] = f

    w = cStringIO.StringIO()
    cPickle.dump(header, w, -1)

    if isinstance(request, PickledData):
      w.write(request.data)
    else:
      pickle_to(request, w)

    #util.log_info('Sending %s', self.method)
    # if len(serialized) > 800000:
    #   util.log_info('%s::\n %s; \n\n\n %s', self.method, ''.join(traceback.format_stack()), request)

    self.socket.send(w.getvalue())
    return f

class Client(object):
  def __init__(self, socket):
    self._socket = socket
    self._socket.register_handler(self.handle_read)
    self._socket.connect()
    self._futures = {}

  def __reduce__(self, *args, **kwargs):
    raise cPickle.PickleError('Not pickleable.')

  def addr(self):
    return self._socket.addr

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    self.close()

  def __getattr__(self, method_name):
    return ProxyMethod(self, method_name)

  def handle_read(self, socket):
    data = socket.recv()
    reader = cStringIO.StringIO(data)
    header = cPickle.load(reader)
    resp = cPickle.load(reader)
    rpc_id = header['rpc_id']
    f = self._futures[rpc_id]
    f._set_result(resp)
    del self._futures[rpc_id]

  def close(self):
    self._socket.close()
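
# A minimal sketch (hypothetical, not part of this module): calling a remote
# method through a Client.  Attribute access yields a ProxyMethod; calling it
# sends the request and returns a Future.  ``make_client_socket`` is an
# assumed SocketBase factory.
def _example_call(make_client_socket):
  with Client(make_client_socket()) as client:
    future = client.add({'a': 1, 'b': 2})  # 'add' must be registered server-side
    return future.wait()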

def forall(clients, method, request):
  '''Invoke ``method`` with ``request`` for each client in ``clients``.

  ``request`` is only serialized once, so this is more efficient when
  targeting multiple workers with the same data.

  Returns a future wrapping all of the requests.
  '''
  futures = []
  pickled = PickledData(data=pickle(request))
  for c in clients:
    futures.append(getattr(c, method)(pickled))
  return FutureGroup(futures)
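
# A minimal sketch (hypothetical, not part of this module): broadcasting one
# request to several workers with ``forall``.  ``clients`` is assumed to be a
# list of connected Clients whose servers all register an ``update`` method.
def _example_broadcast(clients):
  group = forall(clients, 'update', {'step': 1})
  return group.wait()  # per-worker results, in client order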