#!/usr/bin/env python
# coding: utf-8


import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint
#from tqdm.notebook import tqdm # this is for notebook
from tqdm import tqdm
from scipy.linalg import expm


# Reference solution: the functions below compute the true V_\pi of a given system.
def rhs(z,t,A,Q,d):
    """Right-hand side of the matrix ODE, in the flat-vector form odeint expects.

    Args:
        z: Current state as a vector of length ``d*d`` (the matrix P, flattened).
        t: Time argument required by the odeint callback signature; unused.
        A: System matrix, forwarded to ``Riccati``.
        Q: Cost matrix, forwarded to ``Riccati``.
        d: Side length of the square matrix P.

    Returns:
        The derivative dP/dt flattened to a vector of length ``d*d``.
    """
    # Unflatten, evaluate the matrix derivative, and flatten again in one pass.
    return Riccati(A, Q, z.reshape((d, d))).reshape(d * d)

def Riccati(A,Q,P):
    """Evaluate dP = -(A^T P + P A + Q) for the CT finite-horizon LQR.

    Args:
        A: System matrix.
        Q: Cost matrix.
        P: Current value of the matrix being integrated.

    Returns:
        The matrix derivative dP, with the same shape as ``P``.
    """
    left_term = np.dot(A.T, P)
    right_term = np.dot(P, A)
    return -(left_term + right_term + Q)

def get_V(A,Q,d,T,k,W,h,sigma=None):
    """Compute the reference value V by integrating the Riccati ODE.

    The matrix ODE ``dP/dt = -(A^T P + P A + Q)`` is integrated with
    ``scipy.integrate.odeint`` from ``P = 0`` along the grid
    ``np.linspace(0, -T, k)`` (i.e. backwards in time), and the value is the
    rectangle-rule sum ``h * sigma**2 * sum_i trace(W P_i)`` over all ``k``
    grid points.

    Args:
        A: System matrix, shape ``(d, d)``.
        Q: Cost matrix, shape ``(d, d)``.
        d: State dimension.
        T: Time horizon.
        k: Number of grid points used for the ODE solution.
        W: Matrix weighting P inside the trace, shape ``(d, d)``.
        h: Quadrature weight applied to every grid point.
        sigma: Noise scale. When omitted, the module-level ``sigma`` is used
            if one is defined (this is what the function historically read),
            otherwise 1.0.

    Returns:
        The scalar value V.
    """
    if sigma is None:
        # Backward compatibility: earlier versions silently read a global
        # named ``sigma``; keep honouring it for callers that do not pass one.
        sigma = globals().get('sigma', 1.0)

    # Raw string: '\p' is not a valid escape sequence in a normal literal.
    print(r'getting V_\pi')

    t = np.linspace(0, -T, k)  # decreasing grid, since we integrate backwards in time
    # NOTE(review): this grid has spacing T/(k-1), while the weight h is
    # supplied by the caller; confirm the O(1/k) mismatch is acceptable.
    P_T = np.zeros(d * d)  # terminal condition Q_f = 0, flattened because odeint works on vectors
    sol = odeint(rhs, P_T, t, args=(A, Q, d))  # row i is P at t[i], flattened row-major

    # trace(W P) = sum_{a,b} W[a,b] P[b,a] = <vec(W^T), vec(P)>, so the sum of
    # traces over all grid points is one matrix-vector product, not k Python iterations.
    w_transposed_flat = np.asarray(W).T.reshape(d * d)
    return h * sigma**2 * np.sum(sol.dot(w_transposed_flat))



class MonteCarlo():
    """Monte Carlo estimate of the accumulated quadratic cost of a noisy linear system.

    Trajectories are simulated once on a fine grid of ``N_0 = 2**16`` points
    with step ``h_0 = T / N_0`` via
    ``x[k+1] = expm(A * h_0) @ x[k] + w[k]``, where every component of
    ``w[k]`` is drawn from ``N(0, sigma**2 * h_0)``. ``compute_Vhat`` then
    sub-samples those trajectories at a coarser step ``h``.

    ``Q``, ``W`` and ``gamma`` are stored but not used by the methods below
    (the cost is computed as if ``Q = I``); ``imax`` is accepted for
    interface compatibility and ignored.
    """

    def __init__(self, A, Q, T, X, imax, sigma, W, B, d, gamma=1, h_max=1.0):
        """Allocate the noise, the trajectory buffer and the one-step transition.

        Args:
            A: System matrix, shape ``(d, d)``.
            Q: Cost matrix (stored only).
            T: Time horizon.
            X: Initial state, broadcastable to ``(M_0, d)``.
            imax: Unused.
            sigma: Scale of the driving noise.
            W: Noise covariance matrix (stored only).
            B: Budget; the number of trajectories used at step ``h`` is
                ``max(1, int(B / int(T / h)))``.
            d: State dimension.
            gamma: Stored only.
            h_max: Largest step size that will later be passed to
                ``compute_Vhat``. It fixes how many trajectories are
                simulated. Previously this was read from an undeclared
                global ``h``; the default 1.0 matches the value that global
                held at construction time in the driver script.
        """
        self.A = A
        self.Q = Q # assume Q = I for now
        self.T = T
        self.X_0 = X
        self.sigma = sigma
        self.W = W
        self.B = B
        self.d = d
        self.gamma = gamma
        self.h_max = h_max
        self.N_0 = 2**16
        self.h_0 = self.T/self.N_0
        # The coarsest step has the fewest time points and therefore needs
        # the most trajectories; size all buffers for that case.
        coarsest_steps = int(self.T/h_max)
        self.M_0 = max(1,int(self.B/coarsest_steps))
        # Noise increments, shape (N_0 - 1, M_0, d).
        self.w_0 = np.random.normal(scale=sigma * np.sqrt(self.h_0),size=(self.N_0-1,self.M_0,self.d))
        # Trajectory buffer, shape (N_0, M_0, d); this can be large.
        self.x = np.zeros((self.N_0,self.M_0,self.d))
        # Start every trajectory from the supplied initial state.
        self.x[0,:,:] = np.asarray(X, dtype=float)
        # One copy of the one-step transition matrix per trajectory, shape (M_0, d, d).
        self.AA = np.zeros(((self.M_0,self.d,self.d)))
        self.AA[:self.M_0] = expm(self.A*self.h_0)

    def generate_CT_data(self):
        """Fill ``self.x`` by stepping every trajectory along the fine grid."""
        for k in range(self.N_0-1):
            # Batched matrix-vector product over the trajectory axis, plus noise.
            self.x[k+1,:,:] = np.einsum('ijk,ik->ij', self.AA, self.x[k,:,:]) + self.w_0[k,:,:]

    def run(self):
        """Run the experiment: simulate the fine-grid trajectories."""
        self.generate_CT_data()

    def compute_Vhat(self,h):
        """Estimate the cost from trajectories sub-sampled at step ``h``.

        For each trajectory used, the cost is ``h * sum_k |x_k|^2`` over the
        sub-sampled points; the return value is the mean over trajectories.

        Args:
            h: Sampling step. Every ``int(h / h_0)``-th fine-grid point is kept.

        Returns:
            The mean cost over the trajectories actually used.
        """
        h_ratio = int(h/self.h_0)
        N = int(self.T/h)
        M = max(1,int(self.B/N))
        xh = self.x[::h_ratio,:M,:]

        # Sum of squared entries over all kept points and trajectories at once.
        flat_x = xh.ravel()
        J = h*np.inner(flat_x,flat_x)

        # Divide by the number of trajectories the slice really contains: if
        # M exceeds the number simulated, dividing by M would bias the mean.
        return J/xh.shape[1]
        

# ---------------------------------------------------------------------------
# Problem instance
# ---------------------------------------------------------------------------
d = 3                                  # state dimension
C = [-0.25, -0.5, -1.0, -2.0, -4.0]    # negative constants; each run uses A = c * I
Q = np.identity(d)                     # cost matrix
T = 2.0                                # time horizon
N_ode = 10**7                          # grid size used for the true value
h_ode = T / N_ode                      # quadrature step used for the true value
X = np.zeros(d)                        # initial state
sigma = 1.0                            # scale parameter of the Wiener process
W = np.identity(d)                     # covariance matrix of the Wiener process
B = [2**12]                            # values of B handed to MonteCarlo
seed = 0
np.random.seed(seed)

imax = 12    # step sizes tried: h = 2^0, 2^-1, ..., 2^-(imax-1)
trials = 40  # independent repetitions per (A, B) pair
# Squared error of the estimate, indexed [A index, step-size index, B index, trial].
results = np.zeros((len(C), imax, len(B), trials))
V = []
for i, c in enumerate(tqdm(C, desc='Loop on C', position=0)):
    A = c * np.identity(d)  # system matrix for this run
    V_star = get_V(A, Q, d, T, N_ode, W, h_ode)
    for k, b in enumerate(B):
        for v in tqdm(range(trials), desc='Loop on trials', position=1):
            # h is bound at module level, largest step first, so it already
            # holds the coarsest step when the agent is constructed.
            for j, h in enumerate(2**(-e) for e in range(imax)):
                if j == 0:
                    # One fresh set of trajectories per trial, reused for every h.
                    agent = MonteCarlo(A, Q, T, X, imax, sigma, W, b, d, gamma=1)
                    agent.run()
                V_hat = agent.compute_Vhat(h)
                results[i, j, k, v] = (V_hat - V_star)**2
            del agent  # release the large trajectory buffers before the next trial
np.save('lqr_results.npy', results, allow_pickle=True)
