# -*- coding: utf-8 -*-
"""
Created on Sun Feb 11 23:35:55 2024

@author: wsb15
"""
import time
import numpy as np
# import matplotlib.pyplot as plt
import torch

# import torch.nn as nn

# Scalar model parameters. They enter the drift matrix A as -eta/m and the
# input matrix B as 1/m (presumably mass and friction coefficient -- confirm).
m = 1
eta = 0.01
T = 5        # length of the simulated time horizon
dt = 0.01    # time-step of the discretisation

# Use the GPU when one is present; fall back to the CPU so that importing this
# module does not fail on machines without CUDA.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# State equation used by the simulators: x <- x + A x dt + B u dt + C G.
A = torch.tensor(
    [[0, 0, 1, 0],
     [0, 0, 0, 1],
     [0, 0, -eta / m, 0],
     [0, 0, 0, -eta / m]],
    dtype=torch.float32, device=DEVICE)
B = torch.tensor(
    [[0, 0],
     [0, 0],
     [1 / m, 0],
     [0, 1 / m]],
    dtype=torch.float32, device=DEVICE)
C = torch.eye(4, dtype=torch.float32, device=DEVICE)

# Quadratic cost weights; the leading scalars are kept as tuning knobs.
R = 1 * torch.eye(2, dtype=torch.float32, device=DEVICE)   # control weight
Q = 1 * torch.eye(4, dtype=torch.float32, device=DEVICE)   # running state weight
QT = 0 * torch.eye(4, dtype=torch.float32, device=DEVICE)  # terminal state weight

# Number of time steps. round() guards against floating-point quotients such
# as 0.29 / 0.01 == 28.999999999999996, which int() alone would truncate.
N = int(round(T / dt))

# simulate process X until time tN*dt
def sim_X(x,tN,nn,derivatives = False,G = None):
    """Simulate the controlled state for ``tN`` Euler steps.

    Each step applies
    ``x <- x + A x dt + B u(t_n, x) dt + C G[:, n]`` with ``t_n = dt * n``,
    where ``u`` is the output of ``nn`` evaluated on the concatenation of the
    time and the flattened state.

    Parameters
    ----------
    x : torch.Tensor
        Initial state as a column vector (it is written into ``X[:, [0]]``).
    tN : int
        Number of steps to simulate.
    nn : callable
        Control network mapping the 1-D tensor ``[t, x...]`` to the control.
    derivatives : bool, optional
        When True, also propagate the sensitivity
        ``JX <- JX + (A + B Ju) JX dt`` (started from the identity), where
        ``Ju`` is the Jacobian of the control with respect to the state.
    G : torch.Tensor, optional
        Noise increments with one column per step. When omitted,
        ``sqrt(dt)``-scaled standard normal increments are drawn.

    Returns
    -------
    X : torch.Tensor
        Trajectory with ``tN + 1`` columns (the initial state included).
    del_X : torch.Tensor
        Only when ``derivatives`` is True: the sensitivities, stacked along
        the last axis.
    """
    tN = int(tN)
    n_state = len(x)
    B_cols = B.shape[1]
    dev = x.device

    X = torch.zeros((n_state,tN+1),device=dev)
    if derivatives:
        del_X = torch.zeros((n_state,n_state,tN+1),device=dev)
    if G is None:
        G = np.sqrt(dt)*torch.randn((n_state,tN),device=dev).to(torch.float32)
    X[:,[0]] = x
    if derivatives:
        # torch.eye builds the identity matrix (torch has no `identity`).
        JX = torch.eye(n_state,device=dev)
        del_X[:,:,0] = JX

    for n in range(tN):
        t_n = torch.tensor([dt*n],dtype=torch.float32,device=dev)
        nn_in_tx = torch.cat((t_n, x.reshape(-1))).to(torch.float32)
        if derivatives:
            # compute_jacobian_x reads nn_in_tx.grad, so the network input
            # has to be a leaf tensor that requires gradients.
            nn_in_tx = nn_in_tx.detach().requires_grad_(True)
        nn_out_x = nn(nn_in_tx)
        if derivatives:
            Jutx = compute_jacobian_x(nn,nn_in_tx, nn_out_x)
        x = x + A @ x*dt + B @ nn_out_x.reshape((B_cols,1)) *dt + C @ G[:,n:n+1]
        X[:,[n+1]] = x
        if derivatives:
            # Jutx was evaluated at the pre-step state, as the Euler scheme needs.
            JX = JX + (A + B@Jutx)@JX*dt
            del_X[:,:,n+1] = JX

    if derivatives:
        return X, del_X
    return X

# simulate random function Z(tN*dt,x)
def sim_Z(x,tN,nn,G= None):
    """Simulate one sample of Z(tN*dt, x).

    Starting from state ``x`` at step ``tN``, the path is advanced to step
    ``N`` while accumulating

        ``dt * B^T JX^T (Ju^T (R + R^T) u + (Q + Q^T) x)``

    at every visited step, plus the terminal term
    ``B^T JX^T (QT + QT^T) x`` at step ``N``. ``JX`` is the sensitivity of
    the current state with respect to the starting state (identity at the
    start) and ``Ju`` the Jacobian of the control with respect to the state.
    (Presumably this is the B-projected state-gradient of the cost-to-go --
    confirm against the derivation.)

    The policy is queried at the absolute time ``dt * (tN + k)`` of each
    visited step, so it sees the same time index as the rollout that
    produced ``x``; all ``N - tN`` noise columns are consumed so the terminal
    term is evaluated at step ``N``.

    Parameters
    ----------
    x : torch.Tensor
        State at step ``tN`` as a column vector.
    tN : int
        Index of the starting step (``0 <= tN <= N``).
    nn : callable
        Control network mapping the 1-D tensor ``[t, x...]`` to the control.
    G : torch.Tensor, optional
        Noise increments with ``N - tN`` columns. When omitted,
        ``sqrt(dt)``-scaled standard normal increments are drawn.

    Returns
    -------
    torch.Tensor
        Column vector with one entry per control dimension.
    """
    # tN may arrive as a NumPy integer; a plain int keeps dt*(tN+k) a Python
    # float so the time entry is stored as float32 below.
    tN = int(tN)
    B_cols = B.shape[1]
    dev = x.device
    if G is None:
        G = np.sqrt(dt)*torch.randn((len(x),N-tN),device=dev).to(torch.float32)
    JX = torch.eye(len(x),device=dev)
    Z1 = torch.zeros((B_cols,1),device=dev)
    for k in range(N-tN):
        t_k = torch.tensor([dt*(tN+k)],dtype=torch.float32,device=dev)
        # Detached float32 leaf: compute_jacobian_x reads nn_in.grad.
        nn_in = torch.cat((t_k, x.reshape(-1))).to(torch.float32).detach().requires_grad_(True)
        nn_t_x = nn(nn_in)
        Jutx = compute_jacobian_x(nn,nn_in, nn_t_x)
        utx = nn_t_x.reshape((B_cols,1))
        Z1 = Z1 + dt*B.T@JX.T@(Jutx.T@(R+R.T)@utx + (Q+Q.T)@x)
        # Advance the state and its sensitivity with the pre-step control.
        x = x + A @ x*dt + B @ utx *dt + C @ G[:,k:k+1]
        JX = JX + (A + B@Jutx)@JX*dt
    # When tN == N the loop is skipped and only the terminal term remains.
    Z2 = B.T@JX.T@(QT+QT.T)@x
    return Z1+Z2
    # x_data,del_x_data = sim_X(x,N-tN,nn,derivatives = True,G = G)
    # XT_t = x_data[:,[-1]]
    # Xtau2 = x_data[:,[tau2N]]
    # JXT_t = del_x_data[:,:,-1]
    # JXtau2 = del_x_data[:,:,tau2N]
    # nn_in = torch.from_numpy(np.concatenate(([dt*tau2N], Xtau2.reshape(-1)))).to(torch.float32).cuda()
    # nn_tau2_x = nn(nn_in)
    # Ju_tau2_x = compute_jacobian_x(nn,nn_in, nn_tau2_x)
    # u_tau2_x = nn_tau2_x.detach().numpy().reshape((len(B[0,:]),1))
    # Z = (N-tN)*dt*B.T@JXtau2.T@(Ju_tau2_x.T@(R+R.T)@u_tau2_x + (Q+Q.T)@Xtau2)  + \
    #     B.T@JXT_t.T@(QT+QT.T)@XT_t
    # return Z

def grad_est(x,nn):
    """Return one sample of the first gradient estimator.

    A step index ``tau1N`` is drawn uniformly from ``0..N``, the state is
    rolled forward to that step, and the result is

        ``(sim_Z(X_tau, tau1N, nn)^T + u^T (R + R^T)) @ J_theta_u * T``

    where ``u`` is the control at ``(dt * tau1N, X_tau)`` and ``J_theta_u``
    its Jacobian with respect to the network parameters.

    Parameters
    ----------
    x : torch.Tensor
        Initial state as a column vector.
    nn : torch.nn.Module
        Control network mapping the 1-D tensor ``[t, x...]`` to the control.

    Returns
    -------
    torch.Tensor
        Row vector with one entry per network parameter.
    """
    # A plain int keeps dt*tau1N a Python float (a NumPy integer would turn
    # it into a NumPy float64 scalar before it reaches torch.tensor).
    tau1N = int(np.random.choice(N+1))
    B_cols = B.shape[1]
    dev = x.device

    G = np.sqrt(dt)*torch.randn((len(x),tau1N),device=dev).to(torch.float32)
    # Only the value of the state at tau1N is needed; nothing is ever
    # back-propagated through this rollout, so no graph is recorded for it.
    with torch.no_grad():
        for n in range(tau1N):
            t_n = torch.tensor([dt*n],dtype=torch.float32,device=dev)
            nn_in_tx = torch.cat((t_n, x.reshape(-1))).to(torch.float32)
            utx = nn(nn_in_tx).reshape((B_cols,1))
            x = x + A @ x*dt + B @ utx *dt + C @ G[:,n:n+1]

    Xtau1 = x
    t_tau = torch.tensor([dt*tau1N],dtype=torch.float32,device=dev)
    nn_in = torch.cat((t_tau, x.reshape(-1))).to(torch.float32).detach().requires_grad_(True)
    nn_tau1_x = nn(nn_in)

    # Jacobian of the control with respect to the network parameters.
    J_theta_u = compute_jacobian_params(nn,nn_in, nn_tau1_x)
    u_tau1_x = nn_tau1_x.reshape((B_cols,1))
    return (sim_Z(Xtau1,tau1N,nn).T +u_tau1_x.T@(R+R.T) )@J_theta_u*T
    
def grad_est_2(x,nn):
    """Return one sample of the second (pathwise) gradient estimator.

    The state ``x`` and its derivative ``DX`` with respect to the network
    parameters are propagated together over the whole horizon:

        ``x  <- x + A x dt + B u dt + C G[:, n]``
        ``DX <- DX + (A + B Ju) DX dt + B J_theta_u dt``

    A step index ``tauN`` is drawn uniformly from ``0..N``; the running-cost
    derivative is evaluated at that single step and scaled by ``T``, and the
    terminal-cost derivative ``x^T (QT + QT^T) DX`` at step ``N`` is added.

    Parameters
    ----------
    x : torch.Tensor
        Initial state as a column vector.
    nn : torch.nn.Module
        Control network mapping the 1-D tensor ``[t, x...]`` to the control.

    Returns
    -------
    torch.Tensor
        Row vector with one entry per network parameter.
    """
    B_cols = B.shape[1]
    num_param = sum(p.numel() for p in nn.parameters())
    tauN = int(np.random.choice(N+1))
    dev = x.device
    G = np.sqrt(dt)*torch.randn((len(x),N),device=dev).to(torch.float32)
    DX = torch.zeros((len(x),num_param),device=dev)
    # tauN lies in 0..N, so the loop below always assigns v1 (including the
    # tauN == 0 case, which therefore needs no separate pre-computation).
    for n in range(N+1):
        t_n = torch.tensor([dt*n],dtype=torch.float32,device=dev)
        # Detached float32 leaf: compute_jacobian_x reads nn_in_tx.grad.
        nn_in_tx = torch.cat((t_n, x.reshape(-1))).to(torch.float32).detach().requires_grad_(True)
        nn_out_x = nn(nn_in_tx)
        utx = nn_out_x.reshape((B_cols,1))
        Jutx = compute_jacobian_x(nn,nn_in_tx, nn_out_x)
        J_theta_u = compute_jacobian_params(nn,nn_in_tx, nn_out_x)
        if n == tauN:
            v1 = T*x.T@(Q+Q.T)@DX + T*utx.T@(R+R.T)@J_theta_u  +  T*utx.T@(R+R.T)@Jutx @DX
        if n == N:
            # Step N is only visited for the running-cost sample above.
            break
        x = x + A @ x*dt + B @  utx*dt + C @ G[:,n:n+1]
        DX = DX + (A + B@Jutx)@DX*dt + B@J_theta_u*dt
    v2 = x.T@(QT+QT.T)@DX
    return v1+v2
    
    

# def sgd_step(x,nn,lr,disp = False):
#     tau1N = np.random.choice(range(N+1))
#     Xtau1 = sim_X(x,tau1N,nn)[:,[-1]]
#     nn_in = torch.from_numpy(np.concatenate(([dt*tau1N], Xtau1.reshape(-1)))).to(torch.float32).cuda()
#     nn_tau1_x = nn(nn_in)
#     # Compute gradients
#     J_theta_u = compute_jacobian_params(nn,nn_in, nn_tau1_x)
#     u_tau1_x = nn_tau1_x.detach().numpy().reshape((len(B[0,:]),1))
#     grad = (sim_Z(Xtau1,tau1N,nn).T +u_tau1_x.T@(R+R.T) )@J_theta_u*T
    
#     grad = torch.tensor(grad.reshape(-1), dtype=torch.float32)
#     param_index = 0
#     for param in nn.parameters():
#         nparams = torch.numel(param.data)
#         param.data -= lr*grad[param_index:param_index+nparams].reshape(param.data.shape)
#         param_index += nparams
#     if disp:
#         print('stepped, lr = {}'.format(lr))
    
    

    
    
def compute_jacobian_x(model, x,y,trim_T = True):
    """Return the Jacobian of the network output ``y`` w.r.t. its input ``x``.

    Parameters
    ----------
    model : torch.nn.Module
        Network that produced ``y``; its parameter gradients are reset
        before every backward pass.
    x : torch.Tensor
        1-D network input. It must receive gradients (a leaf tensor with
        ``requires_grad=True``) and ``y`` must have been computed from it.
    y : torch.Tensor
        1-D network output ``model(x)``.
    trim_T : bool, optional
        When True (default) the first column is dropped. Callers in this
        file place the time in the first input entry, so the result is then
        the Jacobian with respect to the state only.

    Returns
    -------
    torch.Tensor
        ``(len(y), len(x))`` matrix, or ``(len(y), len(x) - 1)`` when
        ``trim_T`` is True, on the device of ``x``.

    Raises
    ------
    RuntimeError
        If no gradient reaches ``x``.
    """
    num_outputs = len(y)
    num_inputs = len(x)

    # Allocate on the device of the input instead of a hard-coded one.
    jacobian = torch.zeros((num_outputs, num_inputs),dtype=torch.float32,device=x.device)

    # One backward pass per output row.
    for i in range(num_outputs):
        model.zero_grad()
        # backward() accumulates into x.grad and model.zero_grad() does not
        # touch it; without this reset row i would hold the sum of rows 0..i.
        x.grad = None
        # One-hot vector selecting the i-th output.
        grad_output = torch.zeros_like(y)
        grad_output[i] = 1
        y.backward(grad_output, retain_graph=True)
        if x.grad is None:
            raise RuntimeError(
                'compute_jacobian_x: no gradient reached x; pass a leaf '
                'tensor created with requires_grad=True')
        jacobian[i] = x.grad
    if trim_T:
        return jacobian[:,1:]
    return jacobian



def compute_jacobian_params(model, x,y):
    """Return the Jacobian of the output ``y`` w.r.t. the model parameters.

    Parameters
    ----------
    model : torch.nn.Module
        Network that produced ``y``. Its parameter gradients are overwritten
        as a side effect.
    x : torch.Tensor
        Network input; unused here, kept so the signature mirrors
        ``compute_jacobian_x``.
    y : torch.Tensor
        1-D network output ``model(x)``.

    Returns
    -------
    torch.Tensor
        ``(len(y), num_params)`` matrix on the device of ``y``. Columns follow
        the order of ``model.parameters()``, each parameter flattened.
        Parameters that do not influence ``y`` contribute zero columns.
    """
    params = list(model.parameters())
    num_params = sum(p.numel() for p in params)
    num_outputs = len(y)

    # Allocate on the device of the output instead of a hard-coded one.
    jacobian = torch.zeros((num_outputs, num_params),dtype=torch.float32,device=y.device)

    # One backward pass per output; the parameter gradients form one row.
    for i in range(num_outputs):
        model.zero_grad()
        y[i].backward(retain_graph=True)
        offset = 0
        for param in params:
            width = param.numel()
            # A parameter outside the graph of y keeps grad None; its
            # columns stay zero instead of raising on None.flatten().
            if param.grad is not None:
                jacobian[i, offset:offset+width] = param.grad.flatten()
            offset += width
    return jacobian


class NN_Control(torch.nn.Module):
    """Fully connected control network with Softplus activations.

    Three ``Linear -> Softplus`` stages of width ``hidden_dim`` are followed
    by a final ``Linear`` layer of width ``output_dim`` without activation.
    """

    def __init__(self, input_dim=5, hidden_dim=8, output_dim=4):
        super().__init__()
        linear = torch.nn.Linear
        softplus = torch.nn.Softplus
        # Sub-modules are created in the order they are applied.
        self.fc1 = linear(input_dim, hidden_dim)
        self.sp1 = softplus()
        self.fc2 = linear(hidden_dim, hidden_dim)
        self.sp2 = softplus()
        self.fc3 = linear(hidden_dim, hidden_dim)
        self.sp3 = softplus()
        self.fc4 = linear(hidden_dim, output_dim)

    def forward(self, x):
        """Apply the three hidden stages, then the linear output layer."""
        hidden_stages = (
            (self.fc1, self.sp1),
            (self.fc2, self.sp2),
            (self.fc3, self.sp3),
        )
        out = x
        for affine, activation in hidden_stages:
            out = activation(affine(out))
        return self.fc4(out)


if __name__ == '__main__':
    
    # Report the GPU only when one exists: get_device_name raises without
    # CUDA, right after the availability line has been printed.
    has_gpu = torch.cuda.is_available()
    print('Has GPU = {}'.format(has_gpu))
    if has_gpu:
        torch.cuda.empty_cache()
        print(torch.cuda.get_device_name(device='cuda'))
    # for param in nn.parameters():
    #     param.data.fill_(1)
    
    # Initial state as a column vector, placed on the GPU when available.
    x_0 = torch.tensor([[20.0,20.0,-1,1]],device='cuda' if has_gpu else 'cpu').t()
    # with torch.no_grad():  # Temporarily set all the requires_grad flags to false
    #     for param in nn.parameters():
    #         param.fill_(0.001)
    
    
    
    ### SGD finding optimal stationary control
    # n = 1000
    # alpha_0 = 0.05
    # for i in  range(n-1):
    #     if i%20 == 0:
    #         y = sim_X(x_0,N,nn)[0:2,-1]
    #         print('stage {}, terminal position {}'.format(i,y))
    #     lr = alpha_0/(100+i**1)
    #     sgd_step(x_0, nn, lr)
    # m = 100
    # yy = np.zeros((2,m))
    # for i in range(m):
    #     yy[:,i] = sim_X(x_0,N,nn)[0:2,-1]
    #     if i%20 == 0:
    #         print(i)
    # print('means = {}'.format(np.mean(yy,axis = 1)))
    # print('stds of the means = {}'.format( np.sqrt(np.var(yy,axis = 1)/m) ))
    # xx_inspection = sim_X(x_0,N,nn)
    
    # nn = NN_Control(input_dim=5, hidden_dim=8, output_dim=2).cuda()
    # n = 200
    # data1 = []
    # start_time = time.time()
    # for i in range(n):  
    #     data1.append(grad_est(x_0,nn).detach().cpu().numpy())
    # data1 = np.array(data1)
    # print("grad_est n = {m}: {sec:.2f} seconds".format(m = n,sec = time.time() - start_time))
    # m1 = np.mean(data1,axis = 0).reshape(-1)
    # std1 = np.sqrt(np.var(data1,axis = 0)/n).reshape(-1)
    
    # n = 200
    # data2 = []
    # start_time = time.time()
    # for i in range(n):  
    #     data2.append(grad_est_2(x_0,nn).detach().cpu().numpy())
    # data2 = np.array(data2)
    # print("grad_est_2 n = {m}: {sec:.2f} seconds".format(m = n,sec = time.time() - start_time))
    # m2= np.mean(data2,axis = 0).reshape(-1)
    # std2 = np.sqrt(np.var(data2,axis = 0)/n).reshape(-1)
    
    
    
    # plt.figure(figsize=(10, 5))
    # plt.plot(m1,label='mean of est1')
    # plt.plot(m2,label='mean of est2')
    # plt.legend(loc='upper right')
    
    # plt.figure(figsize=(10, 5))
    # plt.plot(m1-m2,label='mean1 - mean2')
    # plt.legend(loc='upper right')
    
    # plt.figure()
    # plt.rcParams["figure.figsize"] = (6.5,4.5)
    # plt.rcParams.update({'font.size': 18})
    # plt.rcParams['font.family'] = 'serif'
    # plt.tick_params(axis='both', which='major', labelsize=16)
    # plt.rc('legend',fontsize=13)
    # plt.plot(std1,label='Generator Gradient',linewidth=2.0)
    # plt.plot(std2,label='IPA',linestyle = ':',linewidth=2.0)
    # plt.legend(loc='upper left')
    # plt.xlabel("Coordinate index of θ")
    # plt.ylabel("Std of the mean")
    # plt.tight_layout()
    # plt.savefig('compare_std_n={m}'.format(m = n), dpi=800)
    
    # plt.figure(figsize=(10, 5))
    # plt.plot(std1-std2,label='std1 - std2')
    # plt.legend(loc='upper right')
    
    
    
    
    # n = 100
    # # hidden_dim_list = np.array([5,20,50,100,200,400,800,1600,2400,3600,4800,6000,7500])
    # # hidden_dim_list = np.array([5,20,50,100])
    # hidden_dim_list = np.array([8])
    # compute_time_data = np.zeros((4,len(hidden_dim_list)))
    # compute_time_data[0,:]=hidden_dim_list
    
    # for idx in range(len(hidden_dim_list)):
    #     nn = NN_Control(input_dim=5, hidden_dim=hidden_dim_list[idx], output_dim=2).cuda()
        
    #     n_params = sum(p.numel() for p in nn.parameters())
    #     compute_time_data[1,idx] = n_params
        
    #     start_time = time.time()
    #     for i in range(n):  
    #         grad_est(x_0,nn)
    #     sec = time.time() - start_time
    #     print("grad_est n = {m}: {sec:.2f} seconds".format(m = n,sec = sec))
    #     compute_time_data[2,idx] = sec/n
    
    #     start_time = time.time()
    #     for i in range(n):  
    #         grad_est_2(x_0,nn)
    #     sec = time.time() - start_time
    #     print("grad_est_2 n = {m}: {sec:.2f} seconds".format(m = n,sec = sec))
    #     compute_time_data[3,idx] = sec/n
    
    # np.save('avg_time_n_eqls_{n}'.format(n = n),compute_time_data)