import numpy as np
import csv
import sys
import time
sys.path.append('./')
from est.models.Model2 import Model2
from est.simulator.Simulator import Simulator
from est.TransitionDensity import EulerMaruyamaDensity
from est.fit.MLE import MLE


# ---------------------------------------------------------------------------
# True model parameters used to generate the synthetic data.
# ---------------------------------------------------------------------------
d = 2   # state dimension
m = 2   # dimension of the Wiener process
x0 = np.array([1, -1])  # initial value of the process

# Matrix A: occupies the first d*d entries of the flat parameter vector.
A = np.array([[1, 2], [1, 0]])

# One d x d matrix G_i per Wiener component; stacked to shape (m, d, d).
G1 = np.array([[0.22, 0.34], [0.42, 0.30]])
G2 = np.array([[0.38, -0.24], [-0.77, -0.15]])
Gs = np.stack((G1, G2))

# Flat parameter layout: [A row-major (d*d values), Gs row-major (m*d*d values)].
params = np.concatenate([A.ravel(), Gs.ravel()])
model = Model2(dim=d, m=m)
model.params = params

# Report whether the model's two conditions hold at the initial value x0.
for k in (1, 2):
    satisfied = getattr(model, f'check_condition{k}')(x0)
    print(f'satisfy condition {k}: {satisfied}')


# ---------------------------------------------------------------------------
# Simulate S repetitions of the data set; an MLE is fitted to each one below.
# ---------------------------------------------------------------------------
S = 100  # number of repetitions
N = 50   # number of sample paths per repetition
n = 50   # number of time steps per sample path -> n + 1 observations per path
dt = 1.0 / n
seed = 121  # repetition i is simulated with seed + i
sample = np.zeros((S, N, n + 1, d))  # (repetition, path, observation, state)

# perf_counter is monotonic and high-resolution; time.time() follows the
# wall clock and can jump if the system clock is adjusted mid-run.
sim_start = time.perf_counter()
for i in range(S):
    simulator = Simulator(dim=d, m=m, x0=x0, n=n, dt=dt, model=model).set_seed(seed=seed + i)
    sample[i] = simulator.sim_path(num_paths=N)
sim_end = time.perf_counter()
print(f'simulation time: {sim_end - sim_start}')


# ---------------------------------------------------------------------------
# Fit maximum-likelihood estimators, one per simulated repetition.
# ---------------------------------------------------------------------------

# Initial guess for the optimizer: the true parameters shifted by +2.
A_ini = A + 2
Gs_ini = Gs + 2
params_ini = np.concatenate((A_ini.flatten(), Gs_ini.flatten()), axis=None)

mle_start = time.perf_counter()  # monotonic timer for elapsed time
params_est = []             # estimates from fits whose status == 0
final_likelihood = []       # log-likelihoods matching params_est
params_est_all = []         # estimates from every fit, regardless of status
final_likelihood_all = []   # log-likelihoods matching params_est_all

for i in range(S):
    exact_est = MLE(sample=sample[i], dt=dt, density=EulerMaruyamaDensity(model)).estimate_params(params_ini)
    # Only fits with status == 0 are treated as successful optimizations.
    # NOTE(review): the original comment says the 'BFGS' method is used and
    # status 0 presumably means converged -- confirm against est.fit.MLE.
    if exact_est.status == 0:
        params_est.append(exact_est.params)
        final_likelihood.append(exact_est.log_like)
    params_est_all.append(exact_est.params)
    final_likelihood_all.append(exact_est.log_like)
mle_end = time.perf_counter()
print(f'mle time: {mle_end - mle_start}')

# ---------------------------------------------------------------------------
# Summary statistics of the estimates.
# ---------------------------------------------------------------------------
l1 = len(params_est)  # number of fits with status == 0 (written to the CSV)

# With zero successful fits np.stack below would fail, and with exactly one the
# variance across fits is trivially 0, so fall back to every fit, including
# the unsuccessful ones.
if l1 <= 1:
    params_est = params_est_all
    final_likelihood = final_likelihood_all
params_est = np.stack(params_est, axis=0)
final_likelihood = np.array(final_likelihood)
l = params_est.shape[0]  # number of estimates actually used below

# A: first d*d entries of each parameter vector.
A_hat = params_est[:, :d ** 2].reshape(l, d, d)
A_hat_mean = np.mean(A_hat, axis=0)
A_hat_minus_A = A_hat - A
mse_A = np.mean(A_hat_minus_A ** 2)
var_A = np.mean(np.var(A_hat, axis=0))  # entry-wise variance (ddof=0), averaged

# Gs: remaining m*d*d entries, one d x d matrix per Wiener component.
Gs_hat = params_est[:, d ** 2:].reshape(l, m, d, d)

# The G matrices are compared through Gsx = sum_i (G_i x1)(G_i x1)^T evaluated
# at a fixed point x1 (seeded draw rounded to 2 decimals).
np.random.seed(10)
x1 = np.round(np.random.randn(d, 1), 2)

# matmul broadcasts over the leading axes: Gs @ x1 is (m, d, 1) and
# Gs_hat @ x1 is (l, m, d, 1); the outer products are summed over the m axis.
Gx = Gs @ x1
Gsx = np.sum(Gx @ np.swapaxes(Gx, -1, -2), axis=0)                # (d, d)
Gx_hat = Gs_hat @ x1
Gsx_hat = np.sum(Gx_hat @ np.swapaxes(Gx_hat, -1, -2), axis=1)    # (l, d, d)

Gsx_hat_mean = np.mean(Gsx_hat, axis=0)
Gsx_hat_minus_Gsx = Gsx_hat - Gsx
mse_Gsx = np.mean(Gsx_hat_minus_Gsx ** 2)
var_Gsx = np.mean(np.var(Gsx_hat, axis=0))

final_log_likelihood = np.mean(final_likelihood)  # average over the fits used


# Print the summary statistics next to the true values they estimate.
print(f"MSE of A: {mse_A}")
print(f"var of A: {var_A}")
print(f"MSE of Gsx: {mse_Gsx}")
print(f"var of Gsx: {var_Gsx}")
print(f'final log likelihood: {final_log_likelihood}')
print(f'A_hat_mean:{A_hat_mean}')
print(f'Gsx_hat_mean:{Gsx_hat_mean}')
print(f"x1:\n {x1}")
print(f'True params:\n A:\n {A},\n Gsx:\n {Gsx}')

# Write settings, summary statistics and raw estimates to a CSV file.
filename = f'Model2_d{d}_N{N}_n{n}_S{S}.csv'

# The csv module requires newline='' on the file object; otherwise an extra
# '\r' is emitted per row on Windows and newlines embedded in quoted cells
# are translated.
with open(filename, 'w', newline='', encoding='utf-8') as myFile:
    writer = csv.writer(myFile)
    # 'l' is l1, the number of fits with status == 0.
    writer.writerow(('d', d, 'N', N, 'n', n, 'S', S, 'l', l1,))
    writer.writerow(('initial params',))
    writer.writerow(params_ini)

    writer.writerow(('final log likelihood', final_log_likelihood))
    writer.writerow(('MSE of A', mse_A))
    writer.writerow(('Variance of A', var_A))
    writer.writerow(('MSE of Gsx', mse_Gsx))
    writer.writerow(('Variance of Gsx', var_Gsx))

    writer.writerow(('x0',))
    writer.writerow(x0)

    writer.writerow(('true A',))
    writer.writerows(A)

    writer.writerow(('mean A_hat',))
    writer.writerows(A_hat_mean)

    writer.writerow(('x1',))
    writer.writerows(x1)

    writer.writerow(('true Gsx',))
    writer.writerows(Gsx)

    writer.writerow(('mean Gsx_hat',))
    writer.writerows(Gsx_hat_mean)

    writer.writerow(('log likelihood',))
    writer.writerow(final_likelihood)

    # NOTE(review): for the 3-D/4-D arrays below, each CSV cell is itself a
    # numpy array and is written via str(), so values are limited to numpy's
    # print precision; reshape(len(x), -1) would give one numeric row per fit.
    writer.writerow(('estimated Gsx',))
    writer.writerows(Gsx_hat)

    writer.writerow(('true Gs',))
    writer.writerows(Gs)

    writer.writerow(('estimated A',))
    writer.writerows(A_hat)

    writer.writerow(('estimated Gs',))
    writer.writerows(Gs_hat)