import numpy as np
import csv
import sys
import time
sys.path.append('./')
from est.models.Model1 import Model1
from est.TransitionDensity import ExactDensity
from est.simulator.Simulator import Simulator
from est.fit.MLE import MLE


# ===========================
# True model parameters, used to simulate the process
# ===========================
d, m = 2, 2  # A is (d, d) and G is (d, m)
r = 0.5  # scale applied to the standard-normal draw of G
x0 = np.array([1.87, -0.98])  # initial value of process
A = np.array([
    [1.76, -0.1],
    [0.98, 0],
])
np.random.seed(121)
G = np.round(r * np.random.randn(d, m), 2)
params = np.concatenate((A.ravel(), G.ravel()))  # A first, then G, both row-major
model = Model1(dim=d, m=m)
model.params = params
# Keep redrawing G (A stays fixed) for as long as the model's condition check
# fails at x0. The explicit `== False` comparison is kept on purpose: the
# return type of check_condition is not visible from this script.
while model.check_condition(x0) == False:
    G = np.round(r * np.random.randn(d, m), 2)
    params = np.concatenate((A.ravel(), G.ravel()))
    model.params = params

# ===========================
# Simulate S repetitions (the estimator is fitted to these samples)
# ===========================
S = 100  # number of repetitions
N = 50  # number of sample paths per repetition
n = 50  # number of time steps per sample path, i.e. n + 1 observations per path
dt = 1.0 / n
seed = 121
sample = np.zeros((S, N, n + 1, d))

sim_start = time.time()
# Repetition i is simulated with its own seed, seed + i.
for i, rep_seed in enumerate(range(seed, seed + S)):
    simulator = Simulator(
        dim=d, m=m, x0=x0, n=n, dt=dt, model=model
    ).set_seed(seed=rep_seed)
    sample[i] = simulator.sim_path(num_paths=N)
sim_end = time.time()
print(f'simulation time: {sim_end - sim_start}')

# ===========================
# Fit maximum likelihood estimators
# ===========================

# Initial guess handed to the optimizer: the true parameters shifted by 2
A_ini = A + 2
G_ini = G + 2
params_ini = np.concatenate((A_ini.ravel(), G_ini.ravel()))

mle_start = time.time()
params_est, final_likelihood = [], []  # fits accepted by the status check
params_est_all, final_likelihood_all = [], []  # every fit, accepted or not

for paths in sample:
    density = ExactDensity(model)
    estimator = MLE(sample=paths, dt=dt, density=density)
    exact_est = estimator.estimate_params(params_ini)
    # Only fits with a non-zero optimizer status count as successful; the
    # original note says the optimizer is 'trust-constr'.
    # NOTE(review): confirm against MLE that status 0 is the failure code.
    accepted = exact_est.status != 0
    if accepted:
        params_est.append(exact_est.params)  # estimated params
        final_likelihood.append(exact_est.log_like)
    params_est_all.append(exact_est.params)
    final_likelihood_all.append(exact_est.log_like)

mle_end = time.time()
print(f'mle time: {mle_end - mle_start}')

l1 = len(params_est)  # number of accepted fits
l = l1  # number of fits that enter the statistics below
if not params_est:
    # No fit was accepted: fall back to all S fits.
    params_est = params_est_all
    final_likelihood = final_likelihood_all
    l = S

params_est = np.stack(params_est)
final_likelihood = np.array(final_likelihood)


# The diffusion enters the comparison through H = G G^T rather than G itself.
H = np.matmul(G, G.T)

# First d*d entries of each parameter vector are A (row-major).
A_hat = params_est[:, :d ** 2].reshape(l, d, d)
A_hat_mean = A_hat.mean(axis=0)
A_hat_minus_A = A_hat - A
mse_A = np.mean(np.square(A_hat_minus_A))
var_A = np.mean(A_hat.var(axis=0))

# Remaining d*m entries are G (row-major).
G_hat = params_est[:, d ** 2:].reshape(l, d, m)
H_hat = np.stack([g @ g.T for g in G_hat], axis=0)
H_hat_mean = H_hat.mean(axis=0)
H_hat_minus_H = H_hat - H
mse_H = np.mean(np.square(H_hat_minus_H))
var_H = np.mean(H_hat.var(axis=0))

final_log_likelihood = np.average(final_likelihood)

print(f'mse_A: {mse_A}')
print(f'var_A: {var_A}')
print(f'mse_H: {mse_H}')
print(f'var_H: {var_H}')
print(f'final log likelihoood: {final_log_likelihood}')
print(f'A_hat_mean:{A_hat_mean}')
print(f'H_hat_mean:{H_hat_mean}')

print(f'True params:\n A:\n {A},\n H:\n {H}')

# Results file, named after the experiment dimensions.
filename = f'd{d}_N{N}_n{n}_S{S}.csv'

# newline='' is what the csv module requires of its file object; without it
# the writer's '\r\n' row terminators are translated a second time on Windows
# and every other row comes out blank.
with open(filename, 'w', newline='') as myFile:
    writer = csv.writer(myFile)

    # Header: experiment sizes; 'l' is the number of accepted fits (l1).
    writer.writerow(('d', d, 'N', N, 'n', n, 'S', S, 'l', l1,))
    writer.writerow(('initial params',))
    writer.writerow(params_ini)

    # Scalar summaries
    writer.writerow(('final log likelihood', final_log_likelihood))
    writer.writerow(('MSE of A', mse_A))
    writer.writerow(('Variance of A', var_A))
    writer.writerow(('MSE of H', mse_H))
    writer.writerow(('Variance of H', var_H))

    writer.writerow(('x0',))
    writer.writerow(x0)

    # 2-D matrices: one CSV row per matrix row
    writer.writerow(('true A',))
    writer.writerows(A)

    writer.writerow(('mean A_hat',))
    writer.writerows(A_hat_mean)

    writer.writerow(('true G',))
    writer.writerows(G)

    writer.writerow(('true H',))
    writer.writerows(H)

    writer.writerow(('mean H_hat',))
    writer.writerows(H_hat_mean)

    writer.writerow(('log likelihood',))
    writer.writerow(final_likelihood)

    # Per-fit estimates are 3-D stacks. Passing them to writerows directly
    # would stringify each matrix row into a single cell (e.g. "[1.76 -0.1]"),
    # so each estimate is flattened row-major into one numeric CSV row.
    writer.writerow(('estimated A',))
    writer.writerows(A_hat.reshape(len(A_hat), d * d))

    writer.writerow(('estimated G',))
    writer.writerows(G_hat.reshape(len(G_hat), d * m))