import numpy as np
import os
import math
from tqdm import tqdm

class GraphOUGR(object):
    def __init__(self, theta1, theta2, sigma, dimension=10,
                 theta1_grid=None, theta2_grid=None):
        '''
        Chain-graph Ornstein-Uhlenbeck process with Girsanov log-ratios.

        SDE: dx = A(theta1) @ (mu(theta2) - x) * dt + sigma * dW
        GR:  log d P(x_t | theta) / d P(x_t | theta_0), theta_0 = (theta1, theta2),
             accumulated step by step for every theta on a (theta1, theta2) grid.
        difference: u = drift(theta) - drift(theta_0)

        Args:
            theta1: coupling on the (0, 1) edge (A[0, 1] = A[1, 0] = -theta1).
            theta2: mean of coordinate 0 (mu[0]).
            sigma: noise scale multiplying dW.
            dimension: number of coordinates; must be >= 2 because theta acts
                on the first two coordinates.
            theta1_grid, theta2_grid: 1-D values for the alternative
                parameters. Default to np.arange(0.1, 1.6, 0.1) and
                np.arange(0.5, 8.0, 0.5) (15 values each, 225 grid points).

        Raises:
            ValueError: if dimension < 2.
        '''
        if dimension < 2:
            raise ValueError(f'dimension must be >= 2, got {dimension}')
        self.dimension = dimension
        self.theta1 = theta1
        self.theta2 = theta2
        self.sigma = sigma

        # Tridiagonal coupling: unit diagonal, -theta1 on the (0, 1) edge and
        # fixed -0.1 * i on every other (i, i + 1) edge.
        self.A_matrix = np.eye(dimension)
        for i in range(dimension - 1):
            val = -theta1 if i == 0 else (-0.1 * i)
            self.A_matrix[i, i + 1] = val
            self.A_matrix[i + 1, i] = val
        # Means 0.5, 1.0, 1.5, ...; only mu[0] is parameterised (by theta2).
        self.mu = (np.arange(dimension) + 1) * 0.5
        self.mu[0] = theta2

        if theta1_grid is None:
            theta1_grid = np.arange(0.1, 1.6, 0.1)  # shape: (15,)
        if theta2_grid is None:
            theta2_grid = np.arange(0.5, 8.0, 0.5)  # shape: (15,)
        X, Y = np.meshgrid(theta1_grid, theta2_grid)

        # Flattened grid (theta1 varies fastest).
        self.theta1_random = X.flatten()
        self.theta2_random = Y.flatten()
        self.number = len(self.theta1_random)
        # Per grid point: list of per-step log-ratio increments.
        self.gr_total = [[] for _ in range(self.number)]
        # Visited states (the state *before* each simulated step).
        self.x = []

    def __call__(self, x):
        '''Return the drift A @ (mu - x) under theta_0 at state x.'''
        return self.A_matrix @ (self.mu - x)

    def step_GR(self, x, noise, dt):
        '''
        Append one Girsanov log-ratio increment per grid point to self.gr_total.

        For each grid theta, with u = drift(theta) - drift(theta_0):
            gr = sum_k (u_k / sigma) * noise_k * sqrt(dt) - 0.5 * (u_k / sigma)**2 * dt
        Only A[0, 1], A[1, 0] and mu[0] depend on theta, so u is nonzero only
        in coordinates 0 and 1 and only columns 0 and 1 of A contribute; the
        remaining terms cancel analytically and are skipped.

        Args:
            x: current state (the first 2 entries are read).
            noise: standard-normal draws driving this step (first 2 read).
            dt: time step.
        '''
        x01 = np.asarray(x)[:2]
        noise01 = np.asarray(noise)[:2]
        mu1 = self.mu[1]

        # Vectorised over the grid: A_i @ [theta2_i - x0, mu1 - x1] with
        # A_i = [[1, -theta1_i], [-theta1_i, 1]].
        dev0 = self.theta2_random - x01[0]          # shape (N,)
        dev1 = mu1 - x01[1]                          # scalar
        u_random = np.stack(
            [dev0 - self.theta1_random * dev1,
             -self.theta1_random * dev0 + dev1],
            axis=1,
        )                                            # shape (N, 2)
        a_theta0 = np.array([[1., -self.theta1], [-self.theta1, 1.]])
        u_theta0 = a_theta0 @ (self.mu[:2] - x01)   # shape (2,)

        z = (u_random - u_theta0) / self.sigma
        gr = (z * noise01 * (dt ** 0.5) - 0.5 * z ** 2 * dt).sum(axis=1)
        for history, value in zip(self.gr_total, gr):
            history.append(value)
        return

    def traj(self, x, total_steps, dt=0.001):
        '''
        Simulate total_steps Euler-Maruyama steps from x under theta_0.

        Each step records the current state and the Girsanov increments
        (computed with the same noise that drives the step), then updates x.
        Results are appended to self.x / self.gr_total, so repeated calls
        accumulate rather than restart.

        Returns:
            (self.x, self.gr_total): recorded states (the final state is not
            included) and per-grid-point increments; each grows by
            total_steps entries per call.
        '''
        sqrt_dt = dt ** 0.5
        for _ in tqdm(range(total_steps)):
            self.x.append(x)
            noise = np.random.randn(*x.shape)
            self.step_GR(x, noise, dt)
            # `x + ...` allocates a new array, so stored states are not aliased.
            x = x + self(x) * dt + self.sigma * noise * sqrt_dt

        return self.x, self.gr_total

def main(output_dir):
    '''
    Simulate one trajectory of the default GraphOUGR model and save results.

    Files written into output_dir (created if needed; '' means the current
    directory):
      - trajectory.npz: 'sample' (1, T, D) states, 'observation' = sample plus
        0.1 * standard-normal noise, 'theta0' (2,), 'theta' (2, number) grid.
      - gr_total.npy: (number, T) per-step Girsanov log-ratio increments.
      - theta_random.npy: (2, number) grid (same array as 'theta' above).
    '''
    ou = GraphOUGR(theta1=0.5, theta2=4., sigma=0.5)
    x0 = np.zeros(ou.dimension)
    dt = 0.001
    # round() guards against 31. / dt landing just below an integer.
    total_steps = int(round(31. / dt))
    states, gr_total = ou.traj(x0, total_steps, dt)
    x = np.array(states).reshape(1, total_steps, ou.dimension)
    noise = np.random.randn(*x.shape)
    y = x + 0.1 * noise
    theta = np.array([ou.theta1_random, ou.theta2_random])
    traj = {
        'sample': x,
        'observation': y,
        'theta0': np.array([ou.theta1, ou.theta2]),
        'theta': theta,
    }
    gr_total = np.array(gr_total).reshape(ou.number, total_steps)

    # os.makedirs('') raises FileNotFoundError and f'{""}/name' would point at
    # the filesystem root, so treat '' as the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    np.savez(os.path.join(output_dir, 'trajectory.npz'), **traj)
    np.save(os.path.join(output_dir, 'gr_total.npy'), gr_total)
    np.save(os.path.join(output_dir, 'theta_random.npy'), theta)

if __name__ == '__main__':
    import argparse

    # An empty output dir made os.makedirs('') crash; take it from the CLI
    # and default to the current directory.
    parser = argparse.ArgumentParser(
        description='Simulate the graph OU process and save Girsanov log-ratios.')
    parser.add_argument('output_dir', nargs='?', default='.',
                        help='directory for the .npz/.npy outputs (default: .)')
    main(parser.parse_args().output_dir)