import numpy as np

class AutoregressiveProcess:
    """Scalar AR(p) process with randomly drawn coefficients.

    y_t = A0 * y_{t-1} + ... + A{p-1} * y_{t-p} + noise_t

    Coefficients ``A0..A{p-1}`` are drawn uniformly from [0, 1) and then
    rescaled by ``_stabilize`` so their L1 norm does not exceed 1.

    All randomness (coefficients, ``variance`` and the per-step noise) comes
    from NumPy's *global* RNG, which ``__init__`` reseeds with ``seed``.
    """

    def __init__(self,
                 num_lags,
                 seed=None):
        """
        Args:
            num_lags: order p of the process (number of past values used).
            seed: seed passed to ``np.random.seed``. This reseeds the global
                NumPy RNG, so it also affects draws made outside this class.
        """
        np.random.seed(seed)

        self.num_lags = num_lags
        # Past values, most recent first: memory[0] = y_{t-1}, memory[1] = y_{t-2}, ...
        self.memory = np.zeros(num_lags)
        self.model_parameters = {f'A{p}': np.random.random() for p in range(num_lags)}
        self.mean = np.array(0)
        self.variance = np.random.random()

        self._stabilize()

    def _stabilize(self, eps=1e-1):
        """Shrink the coefficients when their L1 norm exceeds 1.

        The coefficients are divided by ``norm + eps``, so afterwards their
        L1 norm is strictly below 1. That is a sufficient condition for the
        AR recursion not to diverge.
        """
        param_vector = np.array([*self.model_parameters.values()])
        norm = np.linalg.norm(param_vector, ord=1)

        if norm > 1:
            for key in self.model_parameters:
                self.model_parameters[key] = self.model_parameters[key] / (norm + eps)

    def generator(self, num_samples=100):
        """Yield ``num_samples`` successive values of the process.

        Each yielded value is a fresh length-1 float array. The internal
        ``memory`` is updated as a side effect, so repeated calls continue
        the same trajectory.
        """
        for _ in range(num_samples):
            y = np.zeros(1)
            # A0 * y_{t-1} + ... + A{p-1} * y_{t-p}
            for p in range(self.num_lags):
                y += self.model_parameters[f'A{p}'] * self.memory[p]
            # Noise term. NOTE(review): np.random.normal's second argument is the
            # standard deviation (scale), so ``self.variance`` acts as the noise
            # std here; left unchanged to keep seeded outputs identical.
            y += np.random.normal(self.mean, self.variance)
            # Push the newest value to the front and drop the oldest lag.
            # Uses self.num_lags (not the loop variable) so num_lags == 0 works.
            self.memory = np.concatenate((y, self.memory))[:self.num_lags]

            yield y


def data_points(num_samples,num_lags=1,percentage=None,flag_training=False,mean=0,std=1,process='ar'):
    """Generate a standardized synthetic signal (and time stamps).

    Seeding: the AR model is always built with seed 0. When ``flag_training``
    is True, the global RNG is then reseeded with 42 before sampling. The
    Merton process is seeded with 42 (training) or 0 (otherwise).

    Args:
        num_samples: number of signal values to generate.
        num_lags: AR order; only used when ``process == 'ar'``.
        percentage: ``None`` for uniform time stamps ``1..num_samples``.
            Otherwise ``num_samples`` stamps are drawn without replacement
            from ``1..int((1 + percentage) * num_samples)`` and sorted.
            Must be >= 0. Ignored when ``flag_training`` is True.
        flag_training: if True, standardize with the signal's own mean/std
            and return those statistics.
        mean, std: standardization statistics used when ``flag_training``
            is False.
        process: ``'ar'`` or ``'merton'``.

    Returns:
        ``(mean, std, standardized_values)`` if ``flag_training`` else
        ``(time_stamps, standardized_values)``.

    Raises:
        ValueError: unknown ``process``, or negative ``percentage`` when
            time stamps are requested.
    """
    if process not in ('ar', 'merton'):
        raise ValueError(f"unknown process {process!r}; expected 'ar' or 'merton'")
    # A negative percentage would produce fewer time stamps than signal values.
    if not flag_training and percentage is not None and percentage < 0:
        raise ValueError(f"percentage must be >= 0, got {percentage!r}")

    ## data values
    if process == 'ar':
        # Model coefficients always come from seed 0 (reseeds the global RNG).
        ar = AutoregressiveProcess(num_lags=num_lags, seed=0)
        if flag_training:
            np.random.seed(42)
        signal_values = np.array([*ar.generator(num_samples=num_samples)]).squeeze()
    else:  # process == 'merton'; the model reseeds the global RNG itself
        signal_values = discrete_merton_jump_diffusion_model(
            num_samples, seed=42 if flag_training else 0)

    if flag_training:
        mean = np.mean(signal_values)
        std = np.std(signal_values)

    ### standardization
    standardized_signal_values = (signal_values - mean) / std

    if flag_training:
        return mean, std, standardized_signal_values

    ## time stamps
    if percentage is None:
        ### uniform sampling
        time_stamps = np.linspace(1, num_samples, num_samples)
    else:
        ### non-uniform sampling: random subset of a denser uniform grid
        num_original_samples = int((1 + percentage) * num_samples)
        original_time_stamps = np.linspace(1, num_original_samples, num_original_samples)
        idx = np.random.permutation(num_original_samples)[:num_samples]
        time_stamps = np.sort(original_time_stamps[idx])
    return time_stamps, standardized_signal_values

def discrete_merton_jump_diffusion_model(num_samples, seed=0, *, S=1, T=1, mu=0.1,
                                         sigma=0.3, Lambda=1, a=0.2, b=0.2):
    ''' Merton's jump diffusion model: https://github.com/federicomariamassari/financial-engineering/blob/master/handbook/01-merton-jdm.ipynb
        papers: Merton, R.C. (1976) Option pricing when underlying stock returns are discontinuous, Journal of Financial Economics, 3:125-144
                Glasserman, P. (2003) Monte Carlo Methods in Financial Engineering, Springer Applications of Mathematics, Vol. 53

    Simulate one discretized path of ``num_samples`` points over ``[0, T]``.

    Each step multiplies the previous value by
    ``exp(drift*dt + sigma*sqrt(dt)*Z1 + a*P + |b|*sqrt(P)*Z2)`` with
    ``Z1, Z2 ~ N(0, 1)`` and ``P ~ Poisson(Lambda*dt)`` jumps in the step.
    The jump term is distributed as the sum of ``P`` i.i.d. ``N(a, b^2)``
    log-jumps.

    Args:
        num_samples: number of path points (``dt = T / num_samples``).
        seed: seed for NumPy's global RNG (reseeded on every call).
        S: initial value of the path.
        T: time horizon.
        mu, sigma: drift and volatility of the diffusion part.
        Lambda: jump intensity (expected jumps per unit time).
        a, b: mean and standard deviation of a single log-jump.

    Returns:
        Array of length ``num_samples`` holding the path minus ``S``, so the
        first entry is 0. Empty when ``num_samples == 0``.

    Raises:
        ValueError: if ``num_samples`` is negative.
    '''
    np.random.seed(seed)

    if num_samples < 0:
        raise ValueError(f"num_samples must be >= 0, got {num_samples!r}")
    if num_samples == 0:
        # Avoids T / 0 and the f[0] assignment on an empty array.
        return np.zeros(0)

    Delta_t = T / num_samples

    ## source of randomness (drawn up front, in a fixed order)
    Z1 = np.random.normal(size=num_samples)
    Z2 = np.random.normal(size=num_samples)
    P = np.random.poisson(Lambda * Delta_t, num_samples)

    ## loop-invariant terms (same operation order as the per-step formula)
    drift = (mu - sigma**2 / 2) * Delta_t
    diffusion_scale = sigma * np.sqrt(Delta_t)
    jump_scale = np.sqrt(b**2)  # |b|

    f = np.zeros(num_samples)
    f[0] = S

    for i in range(1, num_samples):
        f[i] = f[i-1] * np.exp(drift + diffusion_scale * Z1[i] + a * P[i]
                               + jump_scale * np.sqrt(P[i]) * Z2[i])

    ## centered around S
    return f - S