import numpy as np
import time
import math
import os

import pycuda.tools as tools
import pycuda.driver as driver
class Solver(object):
    """
    'Abstract' base class for GPU-based solver implementations.
    
    """
    # tuning constants: maximum blocks per kernel launch, maximum threads
    # per block, and the warp size used to round the thread count
    _MAXBLOCKSPERDEVICE = 500
    _MAXTHREADSPERBLOCK = 64
    _WARP_SIZE = 32
    _info = False
    # private variables
    _compiledKernel = None
    _completeCode = None
    
    _timepoints = None
    
    _neq = None
    _nsystems = None
    _resultNumber = None
    
    # device used
    # ToDo enable more than default device to be used
    _device = None
#     _maxThreadsPerMP = None
#     _maxBlocksPerMP = None
    generator = None
    def __init__(self):
        """
        Constructor for the Solver.
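
        The target device index is read from the ``CUDA_DEVICE`` environment
        variable and defaults to device 0. A minimal sketch (``MySolver`` is
        a hypothetical concrete subclass)::

            import os
            os.environ["CUDA_DEVICE"] = "1"   # select the second GPU, if present
            solver = MySolver()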
        """
        
        device = os.getenv("CUDA_DEVICE")
        if device is None:
            self._device = 0
        else:
            self._device = int(device)
        
            
#         compatibility = driver.Device(self._device).compute_capability()
#         self._maxThreadsPerMP = utils.getMaxThreadsPerMP(compatibility)
#         self._maxBlocksPerMP = utils.getMaxBlocksPerMP(compatibility)
        
            
    
    # method for calculating optimal number of blocks and threads per block
    def _getOptimalGPUParam(self, compiledKernel=None):
        """
        Returns the block and thread configuration for the given compiled kernel.
        
        Parameters
        ----------
        compiledKernel : SourceModule function
            The kernel used to determine the optimal parameter configuration.
            Defaults to the solver's own compiled kernel.
            
        :returns: blocks, threads
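
        For example, with 64 threads per block and 10000 systems the grid
        needs ``10000 // 64 + 1 = 157`` blocks (illustrative numbers)::

            >>> nsystems, threads = 10000, 64
            >>> nsystems // threads + (0 if nsystems % threads == 0 else 1)
            157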
        
        """
        if compiledKernel is None:
            compiledKernel = self._compiledKernel
        
        # general parameters
#         maxThreadsPerBlock = driver.Device(self._device).max_threads_per_block
        
        # calculate number of threads per block; assuming that registers are the limiting factor
        #maxThreads = min(driver.Device(self._device).max_registers_per_block/compiledKernel.num_regs,maxThreadsPerBlock)
        
        # assume a smaller block size creates less overhead; occupancy is ignored here
        maxThreads = min(driver.Device(self._device).max_registers_per_block // compiledKernel.num_regs, self._MAXTHREADSPERBLOCK)
        
        maxWarps = maxThreads // self._WARP_SIZE
        
        # warp granularity up to compute capability 2.0 is 2; therefore, if maxWarps
        # is odd, only maxWarps - 1 warps can be run
        #if(maxWarps % 2 == 1):
            #maxWarps -= 1
        
        # number of threads per block, rounded down to a whole number of warps
        threads = maxWarps * self._WARP_SIZE
        
        # assign the number of blocks, rounding up so all systems are covered
        if self._nsystems % threads == 0:
            blocks = self._nsystems // threads
        else:
            blocks = self._nsystems // threads + 1
        
        return blocks, threads
    
     
    def solve(self, y0, t, args=None, timing=False, info=False, write_code=False, full_output=False, **kwargs):
        """
        Integrate a system of ordinary differential equations.
        Solves the initial value problem for stiff or non-stiff systems
        of first-order ODEs::
    
            dy/dt = func(y,t0,...)
    
        where y can be a vector.
    
        Parameters
        ----------
        y0 : array
            Initial condition on y (can be a vector).
        t : array
            A sequence of time points for which to solve for y.  The initial
            value point should be the first element of this sequence.
        args : array
            Extra arguments to pass to the function; sliced per system, one
            parameter set per row.
        full_output : boolean
            If True, return a dictionary of optional outputs as the second output.
        use_jacobian : bool
            Flag indicating whether a Jacobian matrix is provided. Requires an
            implementation in the kernel.
        rtol, atol : float
            The input parameters rtol and atol determine the error
            control performed by the solver.  The solver will control the
            vector, e, of estimated local errors in y, according to an
            inequality of the form ``max-norm of (e / ewt) <= 1``,
            where ewt is a vector of positive error weights computed as:
            ``ewt = rtol * abs(y) + atol``
            rtol and atol can be either vectors the same length as y or scalars.
            Defaults to 1.49012e-8; see the example after this parameter list.
        h0 : float, (0: solver-determined)
            The step size to be attempted on the first step.
        mxstep : integer, (0: solver-determined)
            Maximum number of (internally defined) steps allowed for each
            integration point in t.
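
        With scalar tolerances the error weights can be computed directly,
        e.g. (illustrative values)::

            import numpy as np
            rtol = atol = 1.49012e-8
            y = np.array([1.0, 10.0])
            ewt = rtol * np.abs(y) + atol   # per-component error weights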
        
        :returns:
        y : array, shape (len(t), len(y0))
            Array containing the value of y for each desired time in t,
            with the initial value y0 in the first row.
    
        infodict : dict, only returned if full_output == True
            Dictionary containing additional output information
    
            =========  ============================================================
            key        meaning
            =========  ============================================================
            'message'  message representing state of system
            'system'   index of system
            'nst'      cumulative number of time steps
            'nfe'      cumulative number of function evaluations for each time step
            'nje'      cumulative number of jacobian evaluations for each time step
            =========  ============================================================
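
        A minimal usage sketch (``MySolver`` is a hypothetical concrete
        subclass implementing ``_compile`` and ``_solve_internal``)::

            import numpy as np
            solver = MySolver()
            y0 = np.zeros((1000, 2), dtype=np.float32)     # 1000 systems of 2 equations
            t = np.linspace(0.0, 10.0, 101)                # 101 output time points
            params = np.ones((1000, 3), dtype=np.float32)  # one parameter set per system
            trajectories = solver.solve(y0, t, args=params)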
    
        """
        self._info = info
        self._nsystems = y0.shape[0]
        self._neq = y0.shape[1]
        self._timepoints = np.array(t, dtype=np.float32)
        self._resultNumber = len(t)
        
        if self._compiledKernel is None:
            # compile to determine blocks and threads
            if timing:
                start = time.time()
                
            self._completeCode, self._compiledKernel = self._compile(write_code)
            if timing:
                print("CudaPyInt: compiling kernel took: {0} s").format(round((time.time()-start),4))
            
        
        blocks, threads = self._getOptimalGPUParam()
        if info:
            print("CudaPyInt: threads: {0}, blocks: {1}".format(threads, blocks))
            print("CudaPyInt: kernel mem local: {0}, shared: {1}, registers: {2}".format(self._compiledKernel.local_size_bytes, self._compiledKernel.shared_size_bytes, self._compiledKernel.num_regs))
            occ = tools.OccupancyRecord(tools.DeviceData(), threads=threads, shared_mem=self._compiledKernel.shared_size_bytes, registers=self._compiledKernel.num_regs)
            print("CudaPyInt: threadblocks per mp: {0}, limit: {1}, occupancy: {2}".format(occ.tb_per_mp, occ.limited_by, occ.occupancy))
        if timing:
            start = time.time()
        
        # number of device calls
        runs = int(math.ceil(blocks / float(self._MAXBLOCKSPERDEVICE)))
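        # e.g. with 700 blocks and _MAXBLOCKSPERDEVICE = 500 this makes two
        # device calls: one with 500 blocks and one with the remaining 200
        # (illustrative numbers)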
        for i in range(runs):
            # for the last device call, compute the number of remaining blocks to run
            if i == runs - 1:
                runblocks = int(blocks % self._MAXBLOCKSPERDEVICE)
                if runblocks == 0:
                    runblocks = self._MAXBLOCKSPERDEVICE
            else:
                runblocks = int(self._MAXBLOCKSPERDEVICE)
            if info:
                print("CudaPyInt: Run {0} blocks.".format(runblocks))
            minIndex = self._MAXBLOCKSPERDEVICE * i * threads
            maxIndex = minIndex + threads * runblocks
            runParameters = args[minIndex:maxIndex]
            runInitValues = y0[minIndex:maxIndex]
            
            values, outputs = self._solve_internal(runInitValues, runParameters, runblocks, threads, full_output=full_output, **kwargs)
            if i == 0:
                returnValue = values
                returnOutputs = outputs
            else:
                returnValue = np.append(returnValue, values, axis=0)
                returnOutputs = np.append(returnOutputs, outputs, axis=0)
        
        if timing:
            print("CudaPyInt: GPU blocks: {0}, threads: {1}, systems: {2}, running time: {3}s").format(blocks, threads, self._nsystems, round((time.time()-start),4))
        if full_output:
            return returnValue, returnOutputs
        
        return returnValue
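

# A minimal sketch (not part of the library) of the contract a concrete
# subclass is expected to fulfil: ``solve`` above calls ``_compile`` once and
# ``_solve_internal`` once per chunk of blocks. The class name, argument names
# and return shapes here are assumptions for illustration only.
class _ExampleSolver(Solver):

    def _compile(self, write_code):
        # Would generate the CUDA source for the ODE system and return
        # (source_code, compiled_kernel), where compiled_kernel exposes
        # num_regs, local_size_bytes and shared_size_bytes like a pycuda
        # kernel function.
        raise NotImplementedError("kernel generation is model specific")

    def _solve_internal(self, initValues, parameters, blocks, threads,
                        full_output=False, **kwargs):
        # Would copy initValues/parameters to the device, launch the kernel
        # on a (blocks x threads) grid and return (values, outputs).
        raise NotImplementedError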