import numpy as np
from scipy.linalg import expm
import time
from scipy.optimize import least_squares
import random
import csv
import math

##################################################
##          Generate True Parameters            ##
##################################################

# Dimensions: d = size of the observed state x, p = size of the latent state z.
d, p = 10, 5

# Target number of zero entries when sparsifying the random matrices
# A (d x d), B (d x p) and D (p x p) in generate_true_parameters.
n_zeros_A, n_zeros_B, n_zeros_D = 70, 35, 20

def check_num_zeros(M):
    """Return the number of entries of ``M`` that compare equal to zero.

    Args:
        M: Array-like of numbers (a 2-D numpy array in this script).

    Returns:
        int: Count of zero entries.
    """
    # One vectorized comparison instead of a Python-level double loop.
    return int(np.count_nonzero(np.asarray(M) == 0))

def set_aij_zeros(n_zeros, M):
    """Randomly zero entries of ``M`` until it holds at least ``n_zeros`` zeros.

    Positions are drawn uniformly with the stdlib ``random`` module, so the
    result depends on ``random.seed``. ``M`` is modified in place and also
    returned (callers reassign the result).

    Args:
        n_zeros: Target number of zero entries.
        M: 2-D numpy array to sparsify.

    Returns:
        ``M`` itself, now containing at least ``n_zeros`` zero entries.

    Raises:
        ValueError: If ``n_zeros`` exceeds the number of entries of ``M``;
            the target could never be reached and the loop would not end.
    """
    d1 = M.shape[0]
    d2 = M.shape[1]
    if n_zeros > d1 * d2:
        raise ValueError(
            f"n_zeros={n_zeros} exceeds the number of entries ({d1 * d2})"
        )

    n0 = int(np.count_nonzero(M == 0))
    while n0 < n_zeros:
        i = random.choice(range(d1))
        j = random.choice(range(d2))
        # Count incrementally instead of rescanning the matrix: the zero
        # count only grows when a non-zero entry is hit. The draw sequence
        # (and hence seeded reproducibility) is unchanged.
        if M[i, j] != 0:
            M[i, j] = 0
            n0 += 1
    return M

def generate_true_parameters(d, p, seed):
    """Draw one random set of "true" system parameters.

    Both ``np.random`` and the stdlib ``random`` module are reseeded with
    ``seed``, so the output is a deterministic function of ``seed``.

    Args:
        d: Dimension of the observed state x0.
        p: Dimension of the latent state.
        seed: Seed for both random number generators.

    Returns:
        Tuple ``(A, B, D, x0, z0_1, z0_2, z0_3, z0_4, z0_5)`` where
          - ``A`` is the d x d identity (see note below);
          - ``B`` is a d x p integer matrix with entries in [-2, 2],
            sparsified to at least ``n_zeros_B`` zeros;
          - ``D`` is a p x p integer matrix, strictly upper triangular,
            sparsified to at least ``n_zeros_D`` zeros;
          - ``x0`` is a (d, 1) integer column with entries in [-2, 2];
          - ``z0_k`` is the k-th standard basis vector of length 5 as a
            (5, 1) integer column.
    """
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): this sparse random A is immediately replaced by the
    # identity, so it only advances both RNG streams. It is kept so that
    # B, D and x0 stay identical for a given seed; confirm whether A = I
    # is intended.
    A = np.random.randint(-2, 3, size=(d, d))
    A = set_aij_zeros(n_zeros_A, A)
    A = np.eye(d)

    B = np.random.randint(-2, 3, size=(d, p))
    B = set_aij_zeros(n_zeros_B, B)

    # Zero the diagonal and everything below it (keep the strict upper
    # triangle), then sparsify further.
    D = np.triu(np.random.randint(-2, 3, size=(p, p)), k=1)
    D = set_aij_zeros(n_zeros_D, D)

    x0 = np.random.randint(-2, 3, size=(d, 1))

    # Five initial latent states: standard basis columns of length 5.
    # NOTE(review): fixed at length 5, i.e. this only matches p == 5.
    basis = np.eye(5, dtype=int)
    z0_1, z0_2, z0_3, z0_4, z0_5 = (basis[:, [k]] for k in range(5))

    return (A, B, D, x0, z0_1, z0_2, z0_3, z0_4, z0_5)


##################################################
##      Check Identifiability Conditions        ##
##################################################

def check_identifiability_condition(A, B, D, d, p, x0, z0):
    """Compute the two rank conditions used as the identifiability check.

    With
        gamma = A^p x0 + sum_{j=0}^{p-1} A^(p-1-j) B D^j z0,
        H = [gamma, A gamma, ..., A^(d-1) gamma]      (d columns),
        L = [B; B D; ...; B D^(p-1)]                  (p blocks stacked),
    the condition is reported as satisfied when rank(H) == d and
    rank(L) == p. A message saying whether it holds is printed.

    Args:
        A: (d, d) array.
        B: (d, p) array.
        D: (p, p) array.
        d: Observed-state dimension.
        p: Latent-state dimension.
        x0: (d, 1) initial observed state.
        z0: (p, 1) initial latent state.

    Returns:
        Tuple ``(rank(H), rank(L))``.
    """
    # Out-of-place accumulation: with an in-place ``+=``, an integer gamma
    # (int A and x0) would raise a casting error when a float term is added.
    gamma = np.linalg.matrix_power(A, p) @ x0
    for j in range(p):
        gamma = gamma + (
            np.linalg.matrix_power(A, p - 1 - j) @ B @ np.linalg.matrix_power(D, j) @ z0
        )

    H = np.hstack([gamma] + [np.linalg.matrix_power(A, i) @ gamma for i in range(1, d)])
    L = np.vstack([B] + [B @ np.linalg.matrix_power(D, j) for j in range(1, p)])

    rank_1 = np.linalg.matrix_rank(H)
    rank_2 = np.linalg.matrix_rank(L)

    if rank_1 == d and rank_2 == p:
        print('The identifiability condition is satisfied.')
    else:
        print('The identifiability condition is NOT satisfied.' )

    return (rank_1, rank_2)

# rank = check_identifiability_condition(A, B, D, d, p, x0, z0)

##################################################
##      Generate True Observations x(t_i)       ##
##################################################
def x(theta, z0, t):
    """Evaluate the observed trajectory x(t) of the augmented linear system.

    ``theta`` is a flat vector laid out as
    ``[x0 (d), A (d*d, row-major), B (d*p, row-major), D (p*p, row-major)]``
    using the module-level ``d`` and ``p``. With
    ``M = [[A, B], [0, D]]`` the augmented state at time t is
    ``expm(M * t) @ [x0; z0]`` (the solution of d/dt [x; z] = M [x; z]),
    and the first ``d`` components are returned.

    Args:
        theta: Flat parameter vector (list or 1-D array).
        z0: (p, 1) initial latent state.
        t: 1-D sequence of evaluation times.

    Returns:
        (d, len(t)) float array whose column k is x(t[k]).
    """
    theta = np.asarray(theta)
    a_end = d + d * d
    b_end = a_end + d * p

    x0 = theta[:d].reshape(d, 1)
    A = theta[d:a_end].reshape(d, d)
    B = theta[a_end:b_end].reshape(d, p)
    D = theta[b_end:].reshape(p, p)

    M = np.block([[A, B], [np.zeros((p, d)), D]])
    xz0 = np.vstack((x0, z0))

    t = np.asarray(t)
    states = np.zeros((d, len(t)))
    for k, t_k in enumerate(t):
        # Keep only the observed block of the augmented state.
        states[:, k] = (expm(M * t_k) @ xz0)[:d, 0]
    return states


T = 1   # end of the observation window [0, T]
n = 10  # sample size: number of equally spaced observation times
N = 50  # number of replications
ts = np.linspace(0, T, n)

# Filtering switch: when identify_indicator == 0 every generated parameter
# set is accepted; any other value keeps only the sets whose rank check
# gives rank d and rank p for all five initial latent states.
identify_indicator = 0
true_parameters = []
seed = 0
i = 0
while i < N:
    A, B, D, x0, z0_1, z0_2, z0_3, z0_4, z0_5 = generate_true_parameters(d, p, seed)
    rank_11, rank_12 = check_identifiability_condition(A, B, D, d, p, x0, z0_1)
    rank_21, rank_22 = check_identifiability_condition(A, B, D, d, p, x0, z0_2)
    rank_31, rank_32 = check_identifiability_condition(A, B, D, d, p, x0, z0_3)
    rank_41, rank_42 = check_identifiability_condition(A, B, D, d, p, x0, z0_4)
    rank_51, rank_52 = check_identifiability_condition(A, B, D, d, p, x0, z0_5)
    seed += 1  # every draw consumes a seed, whether accepted or not

    all_identifiable = (
        rank_11 == rank_21 == rank_31 == rank_41 == rank_51 == d
        and rank_12 == rank_22 == rank_32 == rank_42 == rank_52 == p
    )
    if identify_indicator != 0 and not all_identifiable:
        continue

    # Flatten into the layout expected by x(): [x0, A, B, D].
    theta_true = np.hstack((x0.flatten(), A.flatten(), B.flatten(), D.flatten())).tolist() # true parameters
    true_parameters.append(theta_true)
    i += 1

# True states/observations x(t_i) for each replication: the five latent
# initial states' trajectories concatenated side by side.
xt_values = [
    np.hstack([x(theta, z0, ts) for z0 in (z0_1, z0_2, z0_3, z0_4, z0_5)])
    for theta in true_parameters
]

# print(xt_values)

##################################################
## Nonlinear Least Squares Parameter Estimation ##
##################################################

# Per-replication estimation errors. Each entry is the mean squared difference
# between estimated and true entries of one parameter group (squared l2
# distance divided by the group's number of entries).
l2_A, l2_x0, l2_B, l2_D = (np.zeros(N) for _ in range(4))
theta_hat = []  # estimated parameter vectors (res.x), one per replication

def make_fun(xt, ts):
    """Build the residual function handed to ``least_squares``.

    Args:
        xt: Observed states, the five latent-state trajectories concatenated
            horizontally.
        ts: Observation times.

    Returns:
        Callable mapping a flat parameter vector ``theta`` to the flattened
        residual ``model(theta) - xt``. The module-level ``z0_1`` ... ``z0_5``
        are read each time the callable runs.
    """
    def fun(theta):
        predicted = np.hstack(
            [x(theta, z0, ts) for z0 in (z0_1, z0_2, z0_3, z0_4, z0_5)]
        )
        residual = predicted - xt
        return residual.flatten()

    return fun

random.seed(0)
print('Currently, it is the NLS parameter estimation process, it may take some time, please wait for the results.')
start_time = time.time()

# Offsets of each parameter group inside a flat theta vector laid out as
# [x0 (d), A (d*d), B (d*p), D (p*p)].
a_end = d + d * d
b_end = a_end + d * p

for i, (theta_true, xt) in enumerate(zip(true_parameters, xt_values)):
    # Starting point: perturb EACH parameter independently by U(-0.05, 0.05).
    # A single scalar draw added to all entries would only shift the truth
    # uniformly rather than randomize the initial value.
    theta0 = np.array([value + random.uniform(-0.05, 0.05) for value in theta_true])
    res = least_squares(make_fun(xt, ts), theta0)

    truth = np.asarray(theta_true)
    est = res.x

    # Mean squared error per group (squared l2 distance / number of entries).
    l2_x0[i] = np.sum((truth[:d] - est[:d]) ** 2) / d
    l2_A[i] = np.sum((truth[d:a_end] - est[d:a_end]) ** 2) / (d * d)
    l2_B[i] = np.sum((truth[a_end:b_end] - est[a_end:b_end]) ** 2) / (d * p)
    l2_D[i] = np.sum((truth[b_end:] - est[b_end:]) ** 2) / (p * p)

    theta_hat.append(res.x)
end_time = time.time()

l2_norm_mean_A = np.mean(l2_A)
l2_norm_variance_A = np.var(l2_A)
l2_norm_mean_x0 = np.mean(l2_x0)
l2_norm_variance_x0 = np.var(l2_x0)

l2_norm_mean_B = np.mean(l2_B)
l2_norm_variance_B = np.var(l2_B)
l2_norm_mean_D = np.mean(l2_D)
l2_norm_variance_D = np.var(l2_D)

estimation_time = end_time - start_time

print(f'dimensions d and p: {d} and {p}')

print(f'Mean and variance of l2_A: \n {np.round(l2_norm_mean_A, 6)}, {np.round(l2_norm_variance_A, 6)}')
print(f'Mean and variance of l2_x0: \n {np.round(l2_norm_mean_x0, 6)}, {np.round(l2_norm_variance_x0, 6)}')
print(f'Mean and variance of l2_B: \n {np.round(l2_norm_mean_B, 6)}, {np.round(l2_norm_variance_B, 6)}')
print(f'Mean and variance of l2_D: \n {np.round(l2_norm_mean_D, 6)}, {np.round(l2_norm_variance_D, 6)}')
print(f'Estimation time: {estimation_time}')


##################################################
##          Export Estimation Results           ##
##################################################

filename = f'd{d}_p{p}_n{n}_N{N}.csv'

# newline='' is required by the csv module; without it every row gets an
# extra blank line on platforms that translate '\n' (e.g. Windows).
with open(filename, 'w', newline='') as myFile:
    writer = csv.writer(myFile)

    writer.writerow(('d', d, 'p', p, 'n', n, 'N', N))

    # Summary: a header row then a (mean, variance) row per parameter group.
    for group, mean_value, variance_value in (
        ('A', l2_norm_mean_A, l2_norm_variance_A),
        ('x0', l2_norm_mean_x0, l2_norm_variance_x0),
        ('B', l2_norm_mean_B, l2_norm_variance_B),
        ('D', l2_norm_mean_D, l2_norm_variance_D),
    ):
        writer.writerow((f'l2_norm_mean_{group}', f'l2_norm_variance_{group}'))
        writer.writerow((mean_value, variance_value))

    writer.writerow(('Estimation time', estimation_time))

    # Per-replication errors: a label row, then one value per row.
    for group, values in (('A', l2_A), ('x0', l2_x0), ('B', l2_B), ('D', l2_D)):
        writer.writerow((f'l2_{group}', ))
        writer.writerows(values.reshape(-1, 1))

    writer.writerow(('theta_hat: estimated parameters',))
    writer.writerows(theta_hat)