import numpy as np
from scipy.linalg import expm
import time
from scipy.optimize import least_squares
import random
import csv
import math

##################################################
##          Generate True Parameters            ##
##################################################

d = 3  # size of the x-block: A is d-by-d, x0 is d-by-1
p = 3  # size of the z-block: D is p-by-p, z0 is p-by-1

# Fixed seed; the draws below must stay in the order A, B, D, x0, z0
# so that the same parameter values are reproduced on every run.
np.random.seed(0)
A = np.random.randint(-2, 3, size=(d, d))
# Alternative choice: A = np.eye(d)
B = np.random.randint(-2, 3, size=(d, p))
C = np.zeros((p, d))
# Keep only the strictly upper-triangular part of D
# (the diagonal and everything below it are set to zero).
D = np.triu(np.random.randint(-2, 3, size=(p, p)), k=1)
x0 = np.random.randint(-2, 3, size=(d, 1))
z0 = np.random.randint(-2, 3, size=(p, 1))

# Stacked initial condition (x0 on top of z0) and block matrix M = [[A, B], [C, D]].
xz0 = np.vstack([x0, z0])
M = np.block([[A, B], [C, D]])

print(f'Matrix M: \n {M}')
print(f'Initial condition: \n {xz0}')

##################################################
##      Check Identifiability Conditions        ##
##################################################

def check_identifiability_condition(A, B, D, d, p, x0, z0):
    """Check the rank condition on the matrix built from gamma and powers of A.

    Computes

        gamma = A^p x0 + sum_{j=0}^{p-1} A^(p-1-j) B D^j z0

    and the matrix H = [gamma, A gamma, ..., A^(d-1) gamma], then prints
    whether rank(H) equals d.

    Parameters
    ----------
    A : ndarray, shape (d, d)
    B : ndarray, shape (d, p)
    D : ndarray, shape (p, p)
    d : int
        Number of columns of H and the rank required for the condition.
    p : int
        Power of A applied to x0 and number of terms in the sum.
    x0 : ndarray, shape (d, 1)
    z0 : ndarray, shape (p, 1)

    Returns
    -------
    tuple
        (gamma, H, rank) where rank is numpy's numerical rank of H.
    """
    gamma = np.linalg.matrix_power(A, p) @ x0
    for j in range(p):
        term = np.linalg.matrix_power(A, p-1-j) @ B @ np.linalg.matrix_power(D, j) @ z0
        # Out-of-place addition: an in-place `+=` raises a casting error when
        # gamma is integer-valued (integer A and x0) but the term is float.
        gamma = gamma + term

    # Columns gamma, A gamma, ..., A^(d-1) gamma.
    columns = [gamma] + [np.linalg.matrix_power(A, i) @ gamma for i in range(1, d)]
    H = np.hstack(columns)

    rank = np.linalg.matrix_rank(H)
    if rank == d:
        print('The identifiability condition is satisfied.')
    else:
        print('The identifiability condition is NOT satisfied.' )

    return gamma, H, rank

# Evaluate the identifiability condition for the true system; gamma, H and
# rank are written to the CSV report at the end of the script.
gamma, H, rank = check_identifiability_condition(
    A=A, B=B, D=D, d=d, p=p, x0=x0, z0=z0,
)

##################################################
##      Generate True Observations x(t_i)       ##
##################################################

def x(theta, t):
    """Evaluate the first d state components at the given times.

    theta is a flat parameter sequence laid out as
    (xz0 [d+p entries], A [d*d], B [d*p], D [p*p]); the sizes d and p are
    read from module level. The blocks form M = [[A, B], [0, D]], and for
    each time t[i] the stacked state is expm(M * t[i]) @ xz0.

    Parameters
    ----------
    theta : sequence or ndarray
        Flat parameter vector in the layout described above.
    t : ndarray
        One-dimensional array of time points.

    Returns
    -------
    ndarray, shape (d, len(t))
        Column i holds the first d entries of the stacked state at t[i].
    """
    # Boundaries of the consecutive blocks inside theta.
    init_end = d + p
    A_end = init_end + d * d
    B_end = A_end + d * p

    init_state = np.array(theta[:init_end]).reshape(d + p, 1)
    A_blk = np.array(theta[init_end:A_end]).reshape(d, d)
    B_blk = np.array(theta[A_end:B_end]).reshape(d, p)
    D_blk = np.array(theta[B_end:]).reshape(p, p)
    M = np.block([[A_blk, B_blk], [np.zeros((p, d)), D_blk]])

    traj = np.zeros((d, len(t)))
    for k in range(t.shape[0]):
        # Only the x-part (first d rows) of the stacked state is kept.
        traj[:, k] = (expm(M * t[k]) @ init_state)[:d, 0]

    return traj

T = 1    # right end of the sampling interval [0, T]
n = 10   # sample size
N = 100  # number of replications
ts = np.linspace(0, T, num=n)

# True parameters as one flat list ordered (xz0, A, B, D), the layout x() decodes.
theta_true = np.hstack([block.flatten() for block in (xz0, A, B, D)]).tolist()

# True states/observations x(t_i), one column per time point in ts.
xt = x(theta_true, ts)

##################################################
##   Check Discrete Observations Independence   ##
##################################################
# Stack the first d+p observation columns on top of the rows
# ts**0, ts**1, ..., ts**(p-1) restricted to the same d+p time points,
# then test whether the stacked matrix has rank d+p.
hh = np.vstack([xt[:, :d+p]] + [ts[:d+p]**j for j in range(p)])
if np.linalg.matrix_rank(hh) != d+p:
    print("The expended observations are NOT linearly independent.")
else:
    print("The expended observations are linearly independent.")

##################################################
## Nonlinear Least Squares Parameter Estimation ##
##################################################

# One slot per replication for the error between estimated and true
# 'A', 'x0', 'B', 'D' and 'z0' (squared differences summed and divided by
# the number of entries in the estimation loop below).
l2_A, l2_x0, l2_B, l2_D, l2_z0 = (np.zeros(N) for _ in range(5))

# Per replication: errors for the terms 'Bz0', 'BDz0', 'BD^2z0/2!', ...
l2_BDz0 = []
# Per replication: the full estimated parameter vector.
theta_hat = []

def fun(theta):
    """Return the flattened residuals between x(theta, ts) and the true observations xt.

    This is the residual function handed to scipy's least_squares; ts and xt
    are read from module level.
    """
    residuals = x(theta, ts) - xt
    return residuals.flatten()

random.seed(100)
print('Currently, it is the NLS parameter estimation process, it may take some time, please wait for the results.')
start_time = time.time()

# Positions of each parameter block inside the flat vector (x0, z0, A, B, D).
x0_idx = slice(0, d)
z0_idx = slice(d, d + p)
A_idx = slice(d + p, d*d + d + p)
B_idx = slice(d*d + d + p, d*d + d*p + d + p)
D_idx = slice(d*d + d*p + d + p, None)

# The true blocks are the same in every replication, so decode them once.
theta_ref = np.array(theta_true)
B_true = theta_ref[B_idx].reshape(d, p)
D_true = theta_ref[D_idx].reshape(p, p)
z0_true = theta_ref[z0_idx].reshape(p, 1)
# True terms B D^j z0 / j! for j = 0, ..., p-1.
BDz0_true_terms = [
    B_true @ np.linalg.matrix_power(D_true, j) @ z0_true/math.factorial(j)
    for j in range(p)
]

for i in range(N):

    # Random starting point for the optimizer.
    # NOTE(review): a single scalar drawn from uniform(-0.1, 0.1) is subtracted
    # from every entry of theta_true -- confirm a per-entry perturbation is not intended.
    theta0 = theta_ref - random.uniform(-0.1, 0.1)
    res = least_squares(fun, theta0)
    estimate = res.x

    # Sum of squared differences divided by the number of entries in each block.
    l2_A[i] = np.sum((theta_ref[A_idx] - estimate[A_idx])**2)/(d*d)
    l2_x0[i] = np.sum((theta_ref[x0_idx] - estimate[x0_idx])**2)/d

    B_hat = estimate[B_idx].reshape(d, p)
    D_hat = estimate[D_idx].reshape(p, p)
    z0_hat = estimate[z0_idx].reshape(p, 1)

    l2_B[i] = np.sum((B_true - B_hat)**2)/(d*p)
    l2_D[i] = np.sum((D_true - D_hat)**2)/(p*p)
    l2_z0[i] = np.sum((z0_true - z0_hat)**2)/p

    # Error of each estimated term B D^j z0 / j! against its true counterpart.
    l2_BDz0_j = np.array([
        np.sum((true_term - B_hat @ np.linalg.matrix_power(D_hat, j) @ z0_hat/math.factorial(j))**2)/d
        for j, true_term in enumerate(BDz0_true_terms)
    ])

    l2_BDz0.append(l2_BDz0_j)
    theta_hat.append(estimate)

l2_BDz0 = np.array(l2_BDz0)
end_time = time.time()

# Mean and variance of each error measure across the replications
# (column-wise for the B D^j z0 terms).
l2_norm_mean_A, l2_norm_variance_A = l2_A.mean(), l2_A.var()
l2_norm_mean_x0, l2_norm_variance_x0 = l2_x0.mean(), l2_x0.var()
l2_norm_mean_BDz0, l2_norm_variance_BDz0 = l2_BDz0.mean(axis=0), l2_BDz0.var(axis=0)

l2_norm_mean_B, l2_norm_variance_B = l2_B.mean(), l2_B.var()
l2_norm_mean_D, l2_norm_variance_D = l2_D.mean(), l2_D.var()
l2_norm_mean_z0, l2_norm_variance_z0 = l2_z0.mean(), l2_z0.var()

estimation_time = end_time - start_time

print(f'dimensions d and p: {d} and {p}')

print(f'Mean and variance of l2_A: \n {np.round(l2_norm_mean_A, 6)}, {np.round(l2_norm_variance_A, 6)}')
print(f'Mean and variance of l2_x0: \n {np.round(l2_norm_mean_x0, 6)}, {np.round(l2_norm_variance_x0, 6)}')
print(f'Mean and variance of l2_BDz0: \n {np.round(l2_norm_mean_BDz0, 6)}, {np.round(l2_norm_variance_BDz0, 6)}')

print(f'Estimation time: {estimation_time}')

##################################################
##          Export Estimation Results           ##
##################################################

# Output file name encodes the experiment configuration.
filename = f'd{d}_p{p}_n{n}_N{N}.csv'

# newline='' is required by the csv module: without it the writer's own
# '\r\n' line endings get translated again on platforms such as Windows,
# which shows up as a blank row after every record.
with open(filename, 'w', newline='') as myFile:
    writer = csv.writer(myFile)

    # Experiment configuration.
    writer.writerow(('d', d, 'p', p, 'n', n, 'N', N))

    # True system and identifiability quantities:
    # a label row followed by one CSV row per matrix row.
    for label, matrix in (
        ('M: parameter matrix', M),
        ('xz0: initial condition', xz0),
        ('gamma: new initial condition', gamma),
        ('H: {gamma, Agamma, ..., A^{d-1}gamma}', H),
    ):
        writer.writerow((label,))
        writer.writerows(matrix)

    writer.writerow(('rank(H)', rank))

    # Scalar summaries: a header row, then the (mean, variance) pair.
    for name, mean, variance in (
        ('A', l2_norm_mean_A, l2_norm_variance_A),
        ('x0', l2_norm_mean_x0, l2_norm_variance_x0),
        ('B', l2_norm_mean_B, l2_norm_variance_B),
        ('D', l2_norm_mean_D, l2_norm_variance_D),
        ('z0', l2_norm_mean_z0, l2_norm_variance_z0),
    ):
        writer.writerow((f'l2_norm_mean_{name}', f'l2_norm_variance_{name}'))
        writer.writerow((mean, variance))

    # Summaries for the B D^j z0 terms: one value per row.
    for label, values in (
        ('l2_norm_mean_BDz0', l2_norm_mean_BDz0),
        ('l2_norm_variance_BDz0', l2_norm_variance_BDz0),
    ):
        writer.writerow((label,))
        writer.writerows(values.reshape(-1, 1))

    writer.writerow(('Estimation time', estimation_time))

    # Raw per-replication errors: one replication per row.
    for label, rows in (
        ('l2_A', l2_A.reshape(-1, 1)),
        ('l2_x0', l2_x0.reshape(-1, 1)),
        ('l2_BDz0', l2_BDz0),
        ('l2_B', l2_B.reshape(-1, 1)),
        ('l2_D', l2_D.reshape(-1, 1)),
        ('l2_z0', l2_z0.reshape(-1, 1)),
    ):
        writer.writerow((label,))
        writer.writerows(rows)

    # Estimated parameter vectors: one replication per row.
    writer.writerow(('theta_hat: estimated parameters',))
    writer.writerows(theta_hat)