#! /usr/bin/env python
# Author: Marc Claesen
#
# Copyright (c) 2014 KU Leuven, ESAT-STADIUS
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither name of copyright holders nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import print_function
import json
import sys
import socket
import itertools
from . import functions
from . import parallel
import math
import multiprocessing
import threading
__DEBUG = False
__channel_in = sys.stdin
__channel_out = sys.stdout
def _find_replacement(key, kwargs):
    """Finds a replacement for key that doesn't collide with anything in kwargs."""
    key += '_'
    while key in kwargs:
        key += '_'
    return key


def _find_replacements(illegal_keys, kwargs):
    """Finds replacements for all illegal keys listed in kwargs."""
    replacements = {}
    for key in illegal_keys:
        if key in kwargs:
            replacements[key] = _find_replacement(key, kwargs)
    return replacements


def _replace_keys(kwargs, replacements):
    """Replaces illegal keys in kwargs with their replacement value."""
    for key, replacement in replacements.items():
        if key in kwargs:
            kwargs[replacement] = kwargs[key]
            del kwargs[key]
    return kwargs
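
# Example sketch of how the helpers above are used: a keyword that cannot be a
# Python identifier (say the reserved word 'lambda') is mapped to a safe name on
# the Python side, and mapped back before it is sent over the pipe.
# Assuming kwargs = {'lambda': 1.0, 'x': 2.0}:
#
#   replacements = _find_replacements(['lambda'], kwargs)  # {'lambda': 'lambda_'}
#   _replace_keys(kwargs, replacements)                    # {'lambda_': 1.0, 'x': 2.0}
#
# _find_replacement keeps appending underscores until the new name no longer
# collides with an existing key.
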
def json_encode(data):
    """Encodes given data in a JSON string."""
    return json.dumps(data)


def json_decode(data):
    """Decodes given JSON string and returns its data.

    >>> orig = {1: 2, 'a': [1,2]}
    >>> data = json_decode(json_encode(orig))
    >>> data[str(1)]
    2
    >>> data['a']
    [1, 2]

    """
    return json.loads(data)


def send(data):
    """Writes data to channel and flushes."""
    print(data, file=__channel_out)
    __channel_out.flush()


def receive():
    """Reads data from channel."""
    line = __channel_in.readline()[:-1]
    if not line:
        raise EOFError("Unexpected end of communication.")
    return line
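
# Example sketch of the line-based protocol built on send() and receive():
# each message is one JSON document on a single line. With the default
# stdin/stdout channels, a round trip could look like
#
#   send(json_encode({'x': 1.0}))     # writes '{"x": 1.0}' plus a newline and flushes
#   reply = json_decode(receive())    # reads one line, trailing newline stripped
#
# receive() raises EOFError when the other end of the channel is closed.
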
def open_socket(port, host='localhost'):
    """Opens a socket to host:port and reconfigures internal channels."""
    global __channel_in
    global __channel_out
    try:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.bind(('', 0))
        sock.connect((host, port))
        __channel_in = sock.makefile('r')
        __channel_out = sock.makefile('w')
    except (socket.error, OverflowError, ValueError) as e:
        print('Error making socket: ' + str(e), file=sys.stderr)
        sys.exit(1)
def open_server_socket():
    """Opens a server socket on an OS-assigned free port and returns (port, server_socket)."""
    serv_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        serv_sock.bind(('', 0))
    except socket.error as msg:
        print('Bind failed. Error: ' + str(msg), file=sys.stderr)
        sys.exit(1)
    serv_sock.listen(0)
    port = serv_sock.getsockname()[1]
    return port, serv_sock
def accept_server_connection(server_socket):
    """Accepts a connection on server_socket and reconfigures internal channels."""
    global __channel_in
    global __channel_out
    try:
        sock, _ = server_socket.accept()
        __channel_in = sock.makefile('r')
        __channel_out = sock.makefile('w')
    except (socket.error, OverflowError, ValueError) as e:
        print('Error making socket: ' + str(e), file=sys.stderr)
        sys.exit(1)
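
# Example sketch of the server-side setup using the two functions above: the OS
# assigns a free port, which the caller can advertise to the client before
# blocking in accept_server_connection().
#
#   port, server_socket = open_server_socket()
#   # ... tell the client about `port`, e.g. via a command-line argument ...
#   accept_server_connection(server_socket)
#   send(json_encode({'x': 1.0}))   # channels now point at the accepted socket
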
class EvalManager(object):

    def __init__(self, max_vectorized=100, replacements={}):
        """Constructs an EvalManager object.

        :param max_vectorized: the maximum size of a vector evaluation
            larger vectorizations will be chunked
        :type max_vectorized: int
        :param replacements: a mapping of `original:replacement` keyword names
        :type replacements: dict
        """
        # are we doing a parallel function evaluation?
        self._vectorized = False
        self._max_vectorized = max_vectorized
        # the queue used for parallel evaluations
        self._queue = None
        # lock to use to add evaluations to the queue
        self._queue_lock = multiprocessing.Lock()
        # lock to get results
        self._result_lock = multiprocessing.Lock()
        # used to check whether Future started running
        self._semaphore = None
        # used to signal Future's to get their result
        self._processed_semaphore = None
        # keys that must be replaced
        self._replacements = dict((v, k) for k, v in replacements.items())

    @property
    def replacements(self):
        """Keys that must be replaced: keys are current values, values are
        what must be sent through the pipe."""
        return self._replacements

    @property
    def max_vectorized(self):
        return self._max_vectorized

    @property
    def cv(self):
        return self._cv

    @property
    def semaphore(self):
        return self._semaphore

    @property
    def processed_semaphore(self):
        return self._processed_semaphore
    def pmap(self, f, *args):
        """Performs vector evaluations through pipes.

        :param f: the objective function (piped_function_eval)
        :type f: callable
        :param args: function arguments
        :type args: iterables

        The vector evaluation is sent in chunks of size self.max_vectorized.
        """
        argslist = list(zip(*args))
        results = []
        # partition the vector evaluation in chunks <= self.max_vectorized
        num_chunks = int(math.ceil(1.0 * len(argslist) / self.max_vectorized))
        for idx in range(num_chunks):
            chunk = argslist[idx * self.max_vectorized:(idx + 1) * self.max_vectorized]
            results.extend(self._vector_eval(f, *zip(*chunk)))
        return results
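
    # Example sketch of the chunking above: with max_vectorized=100 and 250
    # argument tuples, pmap performs three vector evaluations of sizes 100,
    # 100 and 50, so a single oversized batch is never pushed through the
    # pipe in one message.
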
    def _vector_eval(self, f, *args):
        self._queue = []
        self._vectorized = True
        self._semaphore = multiprocessing.Semaphore(0)
        self._processed_semaphore = multiprocessing.Semaphore(0)

        # make sure correct evaluations lead to counter increments
        def wrap(*args):
            result = f(*args)
            # signal mgr to add next future
            self.semaphore.release()
            return result

        # create list of futures
        futures = []
        for ar in zip(*args):
            futures.append(parallel.Future(wrap, *ar))
            try:
                self.semaphore.acquire()
            except functions.MaximumEvaluationsException:
                if not len(self.queue):
                    # FIXME: some threads may be running atm
                    raise
                break

        # process queue
        self.flush_queue()

        # notify all waiting futures
        for _ in range(len(self.queue)):
            self.processed_semaphore.release()

        # gather results
        results = [f() for f in futures]
        for f in futures:
            f.join()
        return results
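
    # Sketch of the handshake in _vector_eval, as implied by the code above:
    # each Future runs the piped objective, which appends its kwargs to the
    # queue and releases the semaphore; _vector_eval waits on that semaphore
    # before creating the next Future, then flushes the entire queue through
    # the pipe in a single message and releases the processed semaphore once
    # per queued evaluation so every waiting Future can fetch its result.
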
    def pipe_eval(self, **kwargs):
        """Evaluates the objective function through the pipe with the given kwargs."""
        # fix python keywords back to original
        kwargs = _replace_keys(kwargs, self.replacements)
        json_data = json_encode(kwargs)
        send(json_data)
        json_reply = receive()
        decoded = json_decode(json_reply)
        if decoded.get('error', False):  # TODO: allow error handling higher up?
            print('ERROR: ' + decoded['error'], file=sys.stderr)
            sys.exit(1)
        return decoded['value']
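
    # Example sketch of one pipe_eval round trip (values are made up): a single
    # JSON object with the keyword arguments goes out, and the reply must carry
    # either a 'value' or an 'error' field.
    #
    #   sent:     {"x": 1.0, "y": -2.0}
    #   received: {"value": 3.5}            # or {"error": "some message"}
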
    def add_to_queue(self, **kwargs):
        """Adds a keyword argument dict to the queue and returns its index."""
        kwargs = _replace_keys(kwargs, self.replacements)
        try:
            self.queue_lock.acquire()
            self.queue.append(kwargs)
            idx = len(self.queue) - 1
        finally:
            self.queue_lock.release()
        return idx
    @property
    def vectorized(self):
        return self._vectorized

    @property
    def queue(self):
        return self._queue

    @property
    def queue_lock(self):
        return self._queue_lock

    def get(self, number):
        """Returns the result of the queued evaluation with the given index."""
        try:
            self._result_lock.acquire()
            value = self._results[number]
        finally:
            self._result_lock.release()
        return value
    def flush_queue(self):
        """Sends all queued evaluations through the pipe in one message and
        stores the received results."""
        self._results = []
        if self._queue:
            json_data = json_encode(self.queue)
            send(json_data)
            json_reply = receive()
            decoded = json_decode(json_reply)
            if decoded.get('error', False):
                print('ERROR: ' + decoded['error'], file=sys.stderr)
                sys.exit(1)
            self._results = decoded['values']
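
    # Example sketch of a vectorized exchange in flush_queue (values are made
    # up): the whole queue goes out as one JSON list and the reply returns the
    # results in matching order under 'values'.
    #
    #   sent:     [{"x": 1.0}, {"x": 2.0}, {"x": 3.0}]
    #   received: {"values": [0.5, 0.1, 0.9]}   # or {"error": "some message"}
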
def make_piped_function(mgr):
    def piped_function_eval(**kwargs):
        """Evaluates the objective function through the pipe with the given
        keyword arguments, using mgr's queue when a vector evaluation is active."""
        if mgr.vectorized:
            # add to mgr queue
            number = mgr.add_to_queue(**kwargs)
            # signal mgr to continue
            mgr.semaphore.release()
            # wait for results
            mgr.processed_semaphore.acquire()
            return mgr.get(number)
        else:
            return mgr.pipe_eval(**kwargs)
    return piped_function_eval
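
# Example sketch of typical usage on this side of the pipe (parameter names are
# hypothetical):
#
#   mgr = EvalManager(max_vectorized=100, replacements={'lambda': 'lambda_'})
#   f = make_piped_function(mgr)
#   result = f(x=1.0, lambda_=0.5)   # sends {"x": 1.0, "lambda": 0.5} and waits
#
# During a vector evaluation started via mgr.pmap, the same function queues its
# arguments instead and blocks until the whole batch has been flushed.
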
if __name__ == '__main__':
    import doctest
    doctest.testmod()