'''
Simple RPC library.

The :class:`.Client` and :class:`.Server` classes here work with
sockets which should implement the :class:`.SocketBase` interface.
'''
import cPickle
import cProfile
import cStringIO
import itertools
import sys
import threading
import time
import traceback
import types
import weakref
from cPickle import PickleError

from .. import util
from spartan import cloudpickle
from spartan.node import Node
# Source of per-process request ids, drawn with next(RPC_ID).
# itertools.count() is unbounded; an xrange(10**9) iterator would raise
# StopIteration once a billion requests had been issued.
RPC_ID = itertools.count()
# Weak-keyed registries of live Futures (client side) and PendingRequests
# (server side); an entry disappears when its object is garbage collected.
CLIENT_PENDING = weakref.WeakKeyDictionary()
SERVER_PENDING = weakref.WeakKeyDictionary()
# Seconds a Future waits for its reply; change with set_default_timeout().
DEFAULT_TIMEOUT = 100
def set_default_timeout(seconds):
    '''Set the reply timeout used by every Future created from now on.

    :param seconds: number of seconds a Future waits before giving up.
    '''
    global DEFAULT_TIMEOUT
    DEFAULT_TIMEOUT = seconds
    util.log_info('Set default timeout to %s seconds.', DEFAULT_TIMEOUT)
class RPCException(Node):
    '''Error reply carried back to an RPC caller.

    ``py_exc`` holds text describing the failure (normally a formatted
    traceback built by ``capture_exception``); ``Future.wait`` re-raises it
    on the client as a ``RemoteException``.
    '''
    _members = ['py_exc']
class PickledData(Node):
    '''Marks ``data`` as an already-serialized request payload.

    ``ProxyMethod.__call__`` writes ``data`` onto the wire unchanged instead of
    pickling the request again, which lets ``forall`` serialize once and send
    the same bytes to many clients.
    '''
    _members = ['data']
class SocketBase(object):
    '''Interface expected of the sockets handed to Client and Server.

    Every method here is a no-op placeholder that subclasses are expected to
    override. Implementations must also provide an ``addr`` attribute, which
    Client, Server and ProxyMethod read.
    '''
    def send(self, blob): pass

    def recv(self): pass

    def flush(self): pass

    def close(self): pass

    def register_handler(self, handler):
        'A handler() is called in response to read requests.'
        self._handler = handler

    # client
    def connect(self): pass

    # server
    def bind(self): pass
def capture_exception(exc_info=None):
    '''Turn an exception into an RPCException suitable for sending back.

    :param exc_info: a ``sys.exc_info()``-style triple; defaults to the
        exception currently being handled.
    :returns: RPCException whose ``py_exc`` is the formatted traceback, with
        every line after the first prefixed by ``':: '``.
    '''
    info = sys.exc_info() if exc_info is None else exc_info
    formatted = ''.join(traceback.format_exception(*info))
    return RPCException(py_exc=formatted.replace('\n', '\n:: '))
class Group(tuple):
    '''Plain tuple subclass that adds no behaviour of its own.

    NOTE(review): not referenced elsewhere in this module -- presumably a
    marker type for callers; confirm before changing.
    '''
def pickle_to(obj, writer):
    '''Serialize ``obj`` and write the bytes to ``writer``.

    cPickle (highest protocol) is tried first; objects it rejects are
    serialized with cloudpickle instead.

    Only serialization sits inside the ``try``. Previously ``writer.write``
    did too, so a TypeError raised by the writer triggered a second
    serialization and a second (possibly duplicate) write.
    '''
    try:
        pickled = cPickle.dumps(obj, -1)
    except (PickleError, TypeError):
        pickled = cloudpickle.dumps(obj, -1)
    writer.write(pickled)
def pickle(obj):
    '''Return ``obj`` serialized as a string.

    Uses cPickle with the highest protocol, falling back to cloudpickle for
    objects cPickle refuses (PickleError or TypeError).
    '''
    try:
        blob = cPickle.dumps(obj, -1)
    except (PickleError, TypeError):
        blob = cloudpickle.dumps(obj, -1)
    return blob
# Sentinel meaning "PendingRequest has no result yet". A private object() is
# used so that None remains a legitimate result value.
NO_RESULT = object()
class PendingRequest(object):
    '''Server-side handle for an RPC request that has not been answered yet.

    A handler finishes the request by calling :meth:`done`; the result is
    then pickled and sent back over ``socket`` (skipped when ``socket`` is
    None).
    '''
    def __init__(self, socket, rpc_id):
        self.socket = socket
        self.rpc_id = rpc_id
        self.created = time.time()
        self.finished = False
        self.result = NO_RESULT
        # Weak-keyed registry of live requests; it does not keep them alive.
        SERVER_PENDING[self] = 1

    def wait(self):
        '''Poll every millisecond until :meth:`done` stores a result; return it.'''
        while True:
            if self.result is not NO_RESULT:
                return self.result
            time.sleep(0.001)

    def done(self, result=None):
        '''Record ``result`` and, when there is a socket, send the reply.

        The reply is a pickled ``{'rpc_id': ...}`` header followed by the
        pickled result.
        '''
        self.finished = True
        self.result = result
        if self.socket is None:
            return
        out = cStringIO.StringIO()
        cPickle.dump({'rpc_id': self.rpc_id}, out, -1)
        pickle_to(result, out)
        self.socket.send(out.getvalue())

    def __del__(self):
        # Safety net: a request collected without an answer still gets an
        # error reply, so the client is not left waiting for its timeout.
        if self.finished:
            return
        util.log_error('PendingRequest.done() not called before destruction (likely due to an exception.)')
        self.done(result=RPCException(py_exc='done() not called on request.'))
class RemoteException(Exception):
    '''Wraps an exception that was raised, and not caught, on the remote side.

    :param tb: the remote failure text (``RPCException.py_exc``), usually a
        formatted traceback.
    '''
    def __init__(self, tb):
        # Forward to Exception so ``args`` is populated. Without this,
        # Python 2 leaves ``args`` empty, and pickling/copying the exception
        # fails because reconstruction calls the class with no arguments.
        super(RemoteException, self).__init__(tb)
        self._tb = tb

    def __repr__(self):
        return 'RemoteException:\n' + self._tb

    def __str__(self):
        return repr(self)
class FnFuture(object):
    '''Chain ``fn`` to the given future.

    ``self.wait()`` returns ``fn(future.wait())``. ``fn`` is applied on every
    call to :meth:`wait`, and the latest value is kept in ``self.result``.
    '''
    def __init__(self, future, fn):
        self.future = future
        self.fn = fn
        self.result = None

    def wait(self):
        '''Block on the wrapped future and return ``fn`` applied to its value.'''
        result = self.future.wait()
        self.result = self.fn(result)
        return self.result

    def on_finished(self, fn):
        '''Chain a further function, mirroring ``Future.on_finished``.

        Returns a new FnFuture whose ``wait()`` returns ``fn(self.wait())``.
        '''
        return FnFuture(self, fn)
class Future(object):
    '''Client-side handle for the reply to one outstanding RPC.

    The client's read handler calls :meth:`_set_result` when the reply for
    ``rpc_id`` arrives. :meth:`wait` blocks until then, or until
    ``DEFAULT_TIMEOUT`` seconds (the value at construction time) have passed.
    '''
    def __init__(self, addr, rpc_id):
        self.addr = addr
        self.rpc_id = rpc_id
        self.have_result = False
        self.result = None
        self.finished_fn = None
        self._cv = threading.Condition()
        self._start = time.time()
        # The deadline is fixed here; later set_default_timeout() calls do not
        # affect futures that already exist.
        self._deadline = self._start + DEFAULT_TIMEOUT
        CLIENT_PENDING[self] = 1

    def _set_result(self, result):
        '''Store ``result`` (through ``finished_fn`` if set) and wake waiters.'''
        with self._cv:
            if self.finished_fn is not None:
                result = self.finished_fn(result)
            # Flag completion only once the value is in place, so a raising
            # finished_fn cannot leave have_result=True with a stale result.
            self.result = result
            self.have_result = True
            self._cv.notify_all()

    def timed_out(self):
        '''Return True once the deadline has passed.'''
        return self._deadline < time.time()

    def wait(self):
        '''Block until the reply arrives and return it.

        :raises Exception: if the deadline passes without a reply.
        :raises RemoteException: if the reply is an RPCException.
        '''
        with self._cv:
            while not self.have_result and not self.timed_out():
                # use a timeout so that ctrl-c works.
                self._cv.wait(timeout=0.1)
        if not self.have_result and self.timed_out():
            util.log_info('timed out!')
            raise Exception('Timed out on remote call (%s %s)' % (self.addr, self.rpc_id))
        if isinstance(self.result, RPCException):
            raise RemoteException(self.result.py_exc)
        return self.result

    def on_finished(self, fn):
        '''Return an FnFuture whose ``wait()`` yields ``fn(self.wait())``.'''
        return FnFuture(self, fn)
class DummyFuture(object):
    '''An already-completed future: ``wait()`` hands back ``base`` at once.'''
    def __init__(self, base=None):
        self.v = base

    def wait(self):
        '''Return the stored value without blocking.'''
        return self.v


# Shared instance whose wait() returns None.
DUMMY_FUTURE = DummyFuture()
class FutureGroup(list):
    '''A list of futures that can be waited on as one.'''
    def wait(self):
        '''Wait on each member in order and return their results as a list.'''
        results = []
        for fut in self:
            results.append(fut.wait())
        return results
def wait_for_all(futures):
    '''Wait on every future in ``futures`` (any iterable), in order.

    :returns: list of the futures' results.
    '''
    collected = []
    for fut in futures:
        collected.append(fut.wait())
    return collected
class Server(object):
    '''Dispatches requests arriving on ``socket`` to registered methods.

    Handlers are called as ``fn(req, handle)``: ``req`` is the unpickled
    request and ``handle`` a PendingRequest. A handler must report its result
    through ``handle.done(...)`` and return None.
    '''
    def __init__(self, socket):
        self._socket = socket
        self._socket.register_handler(self.handle_read)
        self._methods = {}
        self._running = False
        self.register_method('diediedie', self._diediedie)

    def _diediedie(self, req, handle):
        '''Built-in ``diediedie`` RPC: acknowledge, flush, then shut down.'''
        # Parameter order must match the fn(req, handle) call in handle_read.
        handle.done(None)
        self._socket.flush()
        self.shutdown()

    @property
    def addr(self):
        '''Address of the underlying socket.'''
        return self._socket.addr

    def serve(self):
        '''Start serving, then block until shutdown() is called.'''
        self.serve_nonblock()
        while self._running:
            time.sleep(0.1)

    def serve_nonblock(self):
        '''Mark the server running and bind the socket, without blocking.'''
        self._running = True
        self._socket.bind()

    def register_object(self, obj):
        '''Register every bound method of ``obj`` whose name lacks a ``__`` prefix.'''
        for name in dir(obj):
            if name.startswith('__'):
                continue
            fn = getattr(obj, name)
            if isinstance(fn, types.MethodType):
                self.register_method(name, fn)

    def register_method(self, name, fn):
        '''Expose ``fn`` under ``name``, replacing any earlier registration.'''
        self._methods[name] = fn

    def handle_read(self, socket):
        '''Read one request from ``socket`` and dispatch it.

        The message is a pickled header dict (``rpc_id``, ``method``) followed
        by the pickled request. Unknown methods and handler failures are sent
        back to the caller as an RPCException.
        '''
        data = socket.recv()
        reader = cStringIO.StringIO(data)
        header = cPickle.load(reader)
        handle = PendingRequest(socket, header['rpc_id'])
        name = header['method']
        try:
            fn = self._methods[name]
        except KeyError:
            handle.done(capture_exception())
            return
        try:
            req = cPickle.load(reader)
            result = fn(req, handle)
            # Explicit check rather than assert so it also holds under -O.
            if result is not None:
                raise AssertionError('non-None result from RPC handler (use handle.done())')
        except Exception:
            util.log_info('Caught exception in handler.', exc_info=1)
            # A handler that already replied must not send a second reply;
            # the client would receive a response for an rpc_id it retired.
            if not handle.finished:
                handle.done(capture_exception())

    def shutdown(self):
        '''Stop serve() and close the socket. Safe to call more than once.'''
        self._running = False
        sock = getattr(self, '_socket', None)
        if sock is None:
            return
        sock.close()
        del self._socket
class ProxyMethod(object):
    '''Callable stub that invokes remote ``method`` through ``client``.'''
    def __init__(self, client, method):
        self.client = client
        self.socket = client._socket
        self.method = method

    def __call__(self, request=None):
        '''Send ``request`` and return a Future for the reply.

        If ``request`` is a PickledData, its bytes are sent as-is instead of
        being pickled again.
        '''
        rpc_id = next(RPC_ID)
        header = {'method': self.method, 'rpc_id': rpc_id}
        f = Future(self.socket.addr, rpc_id)
        # Register before sending so a fast reply always finds its future.
        futures = self.client._futures
        futures[rpc_id] = f
        try:
            w = cStringIO.StringIO()
            cPickle.dump(header, w, -1)
            if isinstance(request, PickledData):
                w.write(request.data)
            else:
                pickle_to(request, w)
            self.socket.send(w.getvalue())
        except BaseException:
            # Nothing was sent, so no reply will ever arrive; don't leave the
            # future registered forever.
            futures.pop(rpc_id, None)
            raise
        return f
class Client(object):
    '''RPC client: unknown attributes become remote-method stubs.

    ``client.foo(req)`` sends ``req`` to the server's ``foo`` method and
    returns a Future for the reply.
    '''
    def __init__(self, socket):
        # Create the futures table before connecting, so a reply delivered to
        # handle_read can never find it missing (a missing attribute would
        # otherwise fall through to __getattr__).
        self._futures = {}
        self._socket = socket
        self._socket.register_handler(self.handle_read)
        self._socket.connect()

    def __reduce__(self, *args, **kwargs):
        raise cPickle.PickleError('Not pickleable.')

    def addr(self):
        '''Address of the underlying socket (a method here, unlike Server.addr).'''
        return self._socket.addr

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def close(self):
        '''Close the underlying socket.'''
        self._socket.close()

    def __getattr__(self, method_name):
        '''Return a ProxyMethod that calls ``method_name`` on the server.

        Dunder names and the client's own state attributes are refused with
        AttributeError. Protocol probes then behave normally, and a client
        missing ``_socket`` cannot recurse forever through ProxyMethod.
        '''
        if method_name in ('_socket', '_futures') or (
                method_name.startswith('__') and method_name.endswith('__')):
            raise AttributeError(method_name)
        return ProxyMethod(self, method_name)

    def handle_read(self, socket):
        '''Deliver one reply read from ``socket`` to its waiting Future.'''
        data = socket.recv()
        reader = cStringIO.StringIO(data)
        header = cPickle.load(reader)
        resp = cPickle.load(reader)
        rpc_id = header['rpc_id']
        # pop() first so the entry is removed even if _set_result raises.
        f = self._futures.pop(rpc_id, None)
        if f is None:
            util.log_info('Dropping reply for unknown rpc_id %s', rpc_id)
            return
        f._set_result(resp)
def forall(clients, method, request):
    '''Call remote ``method`` with ``request`` on every client in ``clients``.

    The request is pickled a single time and the same bytes are sent to each
    client, which is cheaper than pickling per client when broadcasting.

    :returns: a FutureGroup holding one Future per client, in order.
    '''
    payload = PickledData(data=pickle(request))
    return FutureGroup([getattr(client, method)(payload) for client in clients])