import numpy as np
import os
import math
from tqdm import tqdm

class GraphOUGR(object):
    '''
    Euler-Maruyama simulator for a noisy predator-prey (Lotka-Volterra-type)
    system, optionally extended with extra linearly coupled coordinates, that
    also records Girsanov log-likelihood-ratio increments over a parameter grid.

    Drift (see ``__call__``):  f(x) = A x + mu * x[0] * x[1]
    with mu = (-theta2, theta2, 0, ..., 0) and A built in ``__init__``.
    Rows 0 and 1 of A are diagonal, so the first two coordinates follow
        dx0 = ( theta1 * x0 - theta2 * x0 * x1) dt + sigma dW0
        dx1 = (-theta1 * x1 + theta2 * x0 * x1) dt + sigma dW1
    and do not depend on the remaining coordinates. For dimension >= 3 the
    extra coordinates have drift -1 on the diagonal, couplings between
    consecutive extra coordinates, and inputs from x0 and x1.

    GR: for every grid point theta' = (theta1', theta2') the per-step increment
    of  log dP(x | theta') / dP(x | theta_0),  theta_0 = (theta1, theta2),
    is accumulated using the drift difference on the first two coordinates
        du = f_{theta'}(x)[:2] - f_{theta_0}(x)[:2].
    '''

    def __init__(self, theta1, theta2, sigma, dimension=2):
        '''
        Args:
            theta1: linear growth/decay rate of the first two coordinates.
            theta2: strength of the x0 * x1 interaction term.
            sigma: diffusion coefficient (same for every coordinate).
            dimension: state dimension; must be >= 2. Values >= 3 add the
                auxiliary linearly coupled coordinates.

        Raises:
            ValueError: if dimension < 2 (the model needs x0 and x1).
        '''
        if dimension < 2:
            raise ValueError(f'dimension must be at least 2, got {dimension}')
        self.dimension = dimension
        self.theta1 = theta1
        self.theta2 = theta2
        self.sigma = sigma

        # Constants are written with flipped signs and the matrix is negated
        # once at the end, matching the original construction.
        A = np.eye(dimension)
        A[0, 0] = -theta1
        A[1, 1] = theta1
        # Symmetric couplings between consecutive auxiliary coordinates 2..d-1.
        for i in range(2, dimension - 1):
            A[i, i + 1] = -0.1 * i
            A[i + 1, i] = -0.1 * i
        if dimension >= 3:
            # Feed x1 and x0 into the auxiliary block. These are skipped for
            # dimension == 2: there is no row 2, and A[-1, 0] would land in
            # row 1 and couple x0 into x1, which step_GR's 2x2 model omits.
            A[2, 1] = -0.1
            A[-1, 0] = -0.9
        self.A_matrix = -A
        self.mu = np.zeros(dimension)
        self.mu[0] = -theta2
        self.mu[1] = theta2

        # Parameter grid for the Girsanov ratios: 15 x 15 = 225 (theta1', theta2') pairs.
        theta1_unif = np.arange(0.3, 0.5, 0.014)  # shape: (15,)
        theta2_unif = np.arange(0.002, 0.004, 0.00014)  # shape: (15,)
        X, Y = np.meshgrid(theta1_unif, theta2_unif)

        self.theta1_random = X.flatten()
        self.theta2_random = Y.flatten()
        self.number = len(self.theta1_random)
        # One history list per grid point; each traj() step appends one increment.
        self.gr_total = [[] for _ in range(self.number)]
        # Visited states (pre-step), appended by traj().
        self.x = []

    def __call__(self, x):
        '''Return the drift f(x) = A x + mu * x[0] * x[1] at state x.'''
        return self.A_matrix @ x + self.mu * x[0] * x[1]

    def step_GR(self, x, noise, dt):
        '''
        Append one log Girsanov-ratio increment per grid point to self.gr_total.

        For each grid point, with du = drift difference on the first two
        coordinates (see class docstring), the appended value is
            sum_k (du_k / sigma) * noise_k * sqrt(dt) - 0.5 * (du_k / sigma)**2 * dt
        over k in {0, 1}.

        Args:
            x: current state; only x[0] and x[1] are read.
            noise: standard-normal draws for this step; only the first two
                entries are used. They must be the same draws used to advance
                the state, so the ratio is evaluated along the simulated path.
            dt: time step.
        '''
        noise = noise[:2]
        x_core = x[:2]
        sign = np.array([1.0, -1.0])
        # Drift of the 2-D core under each grid parameter, shape (number, 2):
        # diag(theta1', -theta1') @ x[:2] + (-theta2', theta2') * x0 * x1.
        a_grid = self.theta1_random[:, None] * sign
        mu_grid = self.theta2_random[:, None] * -sign
        u_random = a_grid * x_core + mu_grid * x[0] * x[1]
        # Reference drift under theta_0; identical for every grid point.
        u_theta0 = self.A_matrix[:2, :2] @ x_core + self.mu[:2] * x[0] * x[1]
        z = (u_random - u_theta0) / self.sigma
        gr = (z * noise * (dt ** 0.5) - 0.5 * z ** 2 * dt).sum(axis=1)
        for history, value in zip(self.gr_total, gr):
            history.append(value)
        return

    def traj(self, x, total_steps, dt=0.001):
        '''
        Run total_steps Euler-Maruyama steps starting from x.

        Each step records the pre-step state in self.x and the GR increments
        via step_GR, then updates
            x <- x + f(x) * dt + sigma * noise * sqrt(dt).
        Histories accumulate on the instance across calls.

        Returns:
            (self.x, self.gr_total): the list of visited states and the
            per-grid-point lists of GR increments.
        '''
        for _ in tqdm(range(total_steps)):
            self.x.append(x)
            noise = np.random.randn(*x.shape)
            self.step_GR(x, noise, dt)
            # `x + ...` builds a new array, so the stored state is never mutated.
            x = x + self(x) * dt + self.sigma * noise * (dt ** 0.5)

        return self.x, self.gr_total

def main(output_dir):
    '''
    Simulate one trajectory, add observation noise, and save the results.

    Files written into output_dir (an empty string means the current
    working directory):
      - trajectory.npz with keys
          'sample'      (1, total_steps, dimension) simulated states,
          'observation' sample + 5 * standard-normal noise, same shape,
          'theta0'      (theta1, theta2) used for the simulation,
          'theta'       (2, number) parameter grid;
      - gr_total.npy      (number, total_steps) per-step GR increments;
      - theta_random.npy  (2, number) parameter grid (same as 'theta').
    '''
    # Create the directory before the long simulation so a bad path fails fast.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    ou = GraphOUGR(theta1=0.4, theta2=0.003, sigma=math.sqrt(0.1))
    x_init = np.zeros(ou.dimension) + 100.
    dt = 0.001
    # round() guards against 51 / dt landing just below an integer in floating point.
    total_steps = int(round(51. / dt))
    states, gr_total = ou.traj(x_init, total_steps, dt)
    x = np.array(states).reshape(1, total_steps, ou.dimension)
    noise = np.random.randn(*x.shape)
    y = x + 5. * noise
    theta = np.array([ou.theta1_random, ou.theta2_random])
    traj = {
        'sample': x,
        'observation': y,
        'theta0': np.array([ou.theta1, ou.theta2]),
        'theta': theta,
    }
    gr_total = np.array(gr_total).reshape(ou.number, total_steps)

    # os.path.join avoids writing to '/...' when output_dir is empty.
    np.savez(os.path.join(output_dir, 'trajectory.npz'), **traj)
    np.save(os.path.join(output_dir, 'gr_total.npy'), gr_total)
    np.save(os.path.join(output_dir, 'theta_random.npy'), theta)


if __name__ == '__main__':
    # Script entry point: the output directory comes from the command line
    # instead of a hard-coded empty string.
    import argparse

    parser = argparse.ArgumentParser(
        description='Simulate a GraphOUGR trajectory and save it with its Girsanov ratios.')
    parser.add_argument(
        'output_dir', nargs='?', default='.',
        help='directory for trajectory.npz, gr_total.npy and theta_random.npy '
             '(default: current directory)')
    args = parser.parse_args()
    main(args.output_dir)