import numpy as np
from scipy.linalg import expm
import time
from scipy.optimize import least_squares
import random
import csv
import math

##################################################
##          Generate True Parameters            ##
##################################################

# Dimensions: x has d components and z has p components, so the block
# matrix M = [[A, B], [0, D]] is (d + p) x (d + p).
d, p = 10, 5

# Target number of zero entries in the true A (d x d), B (d x p), D (p x p).
n_zeros_A, n_zeros_B, n_zeros_D = 70, 35, 20

def check_num_zeros(M):
    """Return how many entries of the 2-D array ``M`` compare equal to zero."""
    n_rows, n_cols = M.shape[0], M.shape[1]
    return sum(1 for r in range(n_rows) for c in range(n_cols) if M[r, c] == 0)

def set_aij_zeros(n_zeros, M):
    """Return a copy of ``M`` with random entries zeroed until it has ``n_zeros`` zeros.

    Entries are chosen with the stdlib ``random`` module, one row index and
    one column index per draw, so results are reproducible under
    ``random.seed``. If ``M`` already has at least ``n_zeros`` zero entries,
    an unchanged copy is returned.

    Parameters
    ----------
    n_zeros : int
        Required number of zero entries.
    M : 2-D ndarray
        Input matrix. It is not modified.

    Returns
    -------
    ndarray
        Copy of ``M`` (same dtype) with at least ``n_zeros`` zero entries.

    Raises
    ------
    ValueError
        If ``n_zeros`` exceeds the number of entries of ``M``. Without this
        check the loop would never terminate.
    """
    L = np.array(M, copy=True)  # work on a copy; never mutate the caller's array
    d1, d2 = L.shape[0], L.shape[1]
    if n_zeros > L.size:
        raise ValueError(
            f"cannot place {n_zeros} zeros in a matrix with {L.size} entries"
        )
    n0 = int(np.count_nonzero(L == 0))
    while n0 < n_zeros:
        # Same two random draws per iteration as before, so seeded runs reproduce.
        i = random.choice(range(d1))
        j = random.choice(range(d2))
        # Only a previously non-zero entry raises the zero count; update it
        # incrementally instead of rescanning the whole matrix.
        if L[i, j] != 0:
            L[i, j] = 0
            n0 += 1
    return L

# Seed both RNGs: numpy draws the integer entries, `random` (inside
# set_aij_zeros) decides which entries are forced to zero.
np.random.seed(10)
random.seed(10)

# A: d x d, entries in {-2, ..., 2}, then sparsified.
A = set_aij_zeros(n_zeros_A, np.random.randint(-2, 3, size=(d, d)))

# B: d x p, entries in {-2, ..., 2}, then sparsified.
B = set_aij_zeros(n_zeros_B, np.random.randint(-2, 3, size=(d, p)))

# Lower-left block of M is zero: z does not depend on x.
C = np.zeros((p, d))

# D: p x p, keep only the strictly upper-triangular part, then sparsify.
D = set_aij_zeros(n_zeros_D, np.triu(np.random.randint(-2, 3, size=(p, p)), k=1))

x0 = np.random.randint(-2, 3, size=(d, 1))

# z0_k is the k-th standard basis vector, stored as a p x 1 column.
z0_1, z0_2, z0_3, z0_4, z0_5 = (np.eye(p, dtype=int)[:, [k]] for k in range(p))

# Full initial states [x0; z0_k].
xz0_1, xz0_2, xz0_3, xz0_4, xz0_5 = (
    np.vstack((x0, z0_k)) for z0_k in (z0_1, z0_2, z0_3, z0_4, z0_5)
)

M = np.block([[A, B], [C, D]])

print(f'Matrix M: \n {M}')
print(f'Initial conditions: \n {xz0_1},\n {xz0_2},\n {xz0_3},\n {xz0_4},\n {xz0_5}')

##################################################
##      Check Identifiability Conditions        ##
##################################################

def check_identifiability_condition(A, B, D, d, p, x0, z0):
    """Evaluate the rank-based identifiability condition and print the verdict.

    Computes

        gamma = A^p x0 + sum_{j=0}^{p-1} A^(p-1-j) B D^j z0

    and H = [gamma, A gamma, ..., A^(d-1) gamma] (d columns), then checks
    whether rank(H) == d.

    Parameters
    ----------
    A : (d, d) array
    B : (d, p) array
    D : (p, p) array
    d : int
        Number of columns of H, and the rank required for the condition.
    p : int
        Power of A applied to x0 and number of terms in the sum.
    x0 : (d, 1) array
    z0 : (p, 1) array

    Returns
    -------
    gamma : ndarray
        Column vector defined above.
    H : ndarray
        ``hstack`` of A^i @ gamma for i = 0, ..., d-1.
    rank : int
        ``np.linalg.matrix_rank(H)``.
    """
    gamma = np.linalg.matrix_power(A, p) @ x0
    for j in range(p):
        # Out-of-place add so the dtype is promoted as needed: an in-place
        # ``+=`` on an integer gamma fails when B, D or z0 are float.
        gamma = gamma + (
            np.linalg.matrix_power(A, p - 1 - j) @ B @ np.linalg.matrix_power(D, j) @ z0
        )

    H = np.hstack([gamma] + [np.linalg.matrix_power(A, i) @ gamma for i in range(1, d)])

    rank = np.linalg.matrix_rank(H)
    if rank == d:
        print('The identifiability condition is satisfied.')
    else:
        print('The identifiability condition is NOT satisfied.' )

    return (gamma, H, rank)

# Check the condition for every basis initial condition z0_k. Each call prints
# its verdict; gamma, H and rank keep the values from the last call (z0_5).
for z0_k in (z0_1, z0_2, z0_3, z0_4, z0_5):
    gamma, H, rank = check_identifiability_condition(A, B, D, d, p, x0, z0_k)

print(f'rank of B: {np.linalg.matrix_rank(B)}')

##################################################
##      Generate True Observations x(t_i)       ##
##################################################

def x(theta, z0, t):
    """Propagate [x0; z0] with expm(M t) and return the x-part at each time in ``t``.

    ``theta`` is the flat parameter vector laid out as
    [x0 (d), A (d*d), B (d*p), D (p*p)], each matrix flattened row-major
    (numpy's default ``reshape`` order). ``d`` and ``p`` are read from module
    globals. The state at time t_k is expm(M t_k) @ [x0; z0] with
    M = [[A, B], [0, D]]; only its first d components are kept.

    Parameters
    ----------
    theta : sequence of numbers, length d + d*d + d*p + p*p
    z0 : (p, 1) array
        Initial value of the z-block of the state.
    t : 1-D ndarray
        Evaluation times.

    Returns
    -------
    ndarray of shape (d, len(t))
        Column k is the x-part of the state at time t[k].
    """
    x0 = np.array(theta[:d]).reshape(d, 1)
    xz0 = np.vstack((x0, z0))
    A = np.array(theta[d: d*d+d]).reshape(d, d)
    B = np.array(theta[d*d+d: d*d+d*p+d]).reshape(d, p)
    D = np.array(theta[d*d+d*p+d:]).reshape(p, p)
    M = np.vstack((np.hstack((A, B)), np.hstack((np.zeros((p, d)), D))))

    states = np.zeros((d, len(t)))
    for k in range(t.shape[0]):
        # Copy the first d rows of the propagated state in one slice assignment.
        states[:, k] = (expm(M * t[k]) @ xz0)[:d, 0]
    return states

T = 1
n = 10   # number of observation times in [0, T]
N = 50   # number of NLS replications
ts = np.linspace(0, T, n)

# Flat true parameter vector [x0, A, B, D] (row-major), the layout x() expects.
theta_true = np.concatenate([blk.flatten() for blk in (x0, A, B, D)]).tolist()

# Noise-free observations x(t_i), one (d x n) array per basis initial condition z0_k.
xt_1, xt_2, xt_3, xt_4, xt_5 = (
    x(theta_true, z0_k, ts) for z0_k in (z0_1, z0_2, z0_3, z0_4, z0_5)
)

##################################################
## Nonlinear Least Squares Parameter Estimation ##
##################################################

# Per-replication errors for each parameter group: the sum of squared
# differences between estimated and true entries, divided by the number of
# entries in that group (i.e. a mean squared error).
l2_A, l2_x0, l2_B, l2_D = (np.zeros(N) for _ in range(4))
theta_hat = []  # res.x from every replication, in order

def fun(theta):
    """Residual vector for ``least_squares``.

    Simulates x(t_i) under ``theta`` for each of the five basis initial
    conditions z0_1..z0_5. The results are stacked side by side (d x 5n) and
    the true observations xt_1..xt_5 are subtracted. The difference is
    returned flattened row-major.
    """
    predicted = np.hstack([x(theta, z0_k, ts) for z0_k in (z0_1, z0_2, z0_3, z0_4, z0_5)])
    observed = np.hstack((xt_1, xt_2, xt_3, xt_4, xt_5))
    residual = predicted - observed
    return residual.flatten()

random.seed(0)
print('Currently, it is the NLS parameter estimation process, it may take some time, please wait for the results.')
start_time = time.time()

# Positions of each parameter group inside the flat vector (layout of theta_true).
idx_x0 = slice(None, d)
idx_A = slice(d, d*d + d)
idx_B = slice(d*d + d, d*d + d*p + d)
idx_D = slice(d*d + d*p + d, None)
theta_true_arr = np.array(theta_true)

for i in np.arange(N):
    # Starting point: theta_true shifted by ONE uniform draw from [-0.1, 0.1].
    # NOTE(review): the same scalar offset is applied to every parameter; if an
    # independent perturbation per entry was intended, draw a vector instead.
    theta0 = theta_true_arr - random.uniform(-0.1, 0.1)
    res = least_squares(fun, theta0)
    est = res.x

    B_true = theta_true_arr[idx_B].reshape(d, p)
    B_hat = est[idx_B].reshape(d, p)
    D_true = theta_true_arr[idx_D].reshape(p, p)
    D_hat = est[idx_D].reshape(p, p)

    # Mean squared error of each parameter group for this replication.
    l2_A[i] = np.sum((theta_true_arr[idx_A] - est[idx_A])**2)/(d*d)
    l2_x0[i] = np.sum((theta_true_arr[idx_x0] - est[idx_x0])**2)/d
    l2_B[i] = np.sum((B_true - B_hat)**2)/(d*p)
    l2_D[i] = np.sum((D_true - D_hat)**2)/(p*p)

    theta_hat.append(est)

end_time = time.time()

# Mean and variance of each error across the N replications.
l2_norm_mean_A, l2_norm_variance_A = np.mean(l2_A), np.var(l2_A)
l2_norm_mean_x0, l2_norm_variance_x0 = np.mean(l2_x0), np.var(l2_x0)
l2_norm_mean_B, l2_norm_variance_B = np.mean(l2_B), np.var(l2_B)
l2_norm_mean_D, l2_norm_variance_D = np.mean(l2_D), np.var(l2_D)

# Wall-clock seconds spent in the NLS loop.
estimation_time = end_time - start_time

print(f'dimensions d and p: {d} and {p}')

print(f'Mean and variance of l2_A: \n {np.round(l2_norm_mean_A, 6)}, {np.round(l2_norm_variance_A, 6)}')
print(f'Mean and variance of l2_x0: \n {np.round(l2_norm_mean_x0, 6)}, {np.round(l2_norm_variance_x0, 6)}')
print(f'Mean and variance of l2_B: \n {np.round(l2_norm_mean_B, 6)}, {np.round(l2_norm_variance_B, 6)}')
print(f'Mean and variance of l2_D: \n {np.round(l2_norm_mean_D, 6)}, {np.round(l2_norm_variance_D, 6)}')
print(f'Estimation time: {estimation_time}')

##################################################
##          Export Estimation Results           ##
##################################################

filename = f'd{d}_p{p}_n{n}_N{N}.csv'

# The csv module requires newline='' on the file object; otherwise an extra
# '\r' is written per row on platforms with '\r\n' line endings (Windows).
with open(filename, 'w', newline='') as myFile:
    writer = csv.writer(myFile)

    writer.writerow(('d', d, 'p', p, 'n', n, 'N', N))

    writer.writerow(('M: parameter matrix',))
    writer.writerows(M)

    for label, column in (
        ('xz0_1: initial condition', xz0_1),
        ('xz0_2: initial condition', xz0_2),
        ('xz0_3: initial condition', xz0_3),
        ('xz0_4: initial condition', xz0_4),
        ('xz0_5: initial condition', xz0_5),
    ):
        writer.writerow((label,))
        writer.writerows(column)

    # gamma, H and rank hold the values from the last identifiability check
    # (the one run with z0_5).
    writer.writerow(('gamma: new initial condition',))
    writer.writerows(gamma)

    writer.writerow(('H: {gamma, Agamma, ..., A^{d-1}gamma}',))
    writer.writerows(H)

    writer.writerow(('rank(H)', rank))

    for header, values in (
        (('l2_norm_mean_A', 'l2_norm_variance_A'), (l2_norm_mean_A, l2_norm_variance_A)),
        (('l2_norm_mean_x0', 'l2_norm_variance_x0'), (l2_norm_mean_x0, l2_norm_variance_x0)),
        (('l2_norm_mean_B', 'l2_norm_variance_B'), (l2_norm_mean_B, l2_norm_variance_B)),
        (('l2_norm_mean_D', 'l2_norm_variance_D'), (l2_norm_mean_D, l2_norm_variance_D)),
    ):
        writer.writerow(header)
        writer.writerow(values)

    writer.writerow(('Estimation time', estimation_time))

    # Per-replication errors, one value per row.
    for label, errors in (
        ('l2_A', l2_A),
        ('l2_x0', l2_x0),
        ('l2_B', l2_B),
        ('l2_D', l2_D),
    ):
        writer.writerow((label, ))
        writer.writerows(errors.reshape(-1, 1))

    writer.writerow(('theta_hat: estimated parameters',))
    writer.writerows(theta_hat)