#from model import get_full_model
from train_test_fns import validation_fn, weighted_graph_grad, lag
import argparse
import config
import functools
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import train_test_fns as ttf
import utils
import pickle
import open3d as o3d
import sys

from model import get_full_model

def _load_eigenpairs(data_path, knn, weight, modes, device):
    """Load truncated eigenpairs (and point coordinates) for every case in *data_path*.

    Each entry ``pid`` of *data_path* is presumably a case directory containing
    ``knn=<knn><weight>/HK/vecs.pkl``, ``.../HK/vals.pkl`` and ``x.pkl``; every
    pickle is indexed with ``[0]`` to get its payload.

    Returns:
        ``(eigenpairs, eigvecs, pts)`` where ``eigenpairs[pid] = [vecs, vals]``
        holds the first *modes* eigenvectors (transposed) and eigenvalues as
        tensors on *device*, and ``eigvecs`` / ``pts`` are the full eigenvector
        array and coordinates of the LAST case visited (``None`` if none).
    """
    eigenpairs = {}
    eigvecs = None
    pts = None
    # Sorted so that "the last case" is deterministic across file systems.
    for pid in sorted(os.listdir(data_path)):
        eig_path = os.path.join(data_path, pid, 'knn=' + str(knn) + weight, 'HK')
        # NOTE(security): pickle.load can execute arbitrary code; trusted data only.
        with open(os.path.join(eig_path, 'vecs.pkl'), 'rb') as fp:
            eigvecs = pickle.load(fp)[0]
        with open(os.path.join(eig_path, 'vals.pkl'), 'rb') as fp:
            eigvals = pickle.load(fp)[0]
        eigenpairs[pid] = [
            torch.Tensor(eigvecs[:, :modes].T).to(device),
            torch.Tensor(eigvals[:modes]).to(device),
        ]
        with open(os.path.join(data_path, pid, 'x.pkl'), 'rb') as fp:
            pts = pickle.load(fp)[0]
    return eigenpairs, eigvecs, pts


def _plot_metric_over_time(t, series, ylabel, save_path, ylim=None):
    """Plot every case's per-time-step metric plus their mean, save, and close.

    Args:
        t: time stamps shared by all cases (x axis).
        series: one per-time-step sequence per test case, each of length len(t).
        ylabel: y-axis label.
        save_path: output image path.
        ylim: optional ``(low, high)`` y-axis limits.
    """
    plt.figure()
    for values in series:
        plt.plot(t, values, alpha=0.5)
    plt.plot(t, np.mean(series, axis=0), '--k')
    plt.ylabel(ylabel)
    plt.xlabel("t (sec)")
    if ylim is not None:
        plt.ylim(*ylim)
    plt.savefig(save_path)
    plt.close('all')


def main(inFile, weightsLoc):
    """Evaluate a trained ODE model on the validation set and write plots/metrics.

    Args:
        inFile: path to a pickled namespace holding the run configuration
            (data paths, model hyper-parameters, ``model_save_path``, ...).
        weightsLoc: state-dict file, joined onto ``args.model_save_path``.

    Side effects:
        Writes ``diffs.png``, ``diffs_rae.png``, ``diffs_rse.png``,
        ``rel_grad.png`` and a ``metrics`` text file into
        ``args.model_save_path``, plus per-case dumps/plots for every index in
        ``inds_of_sims_to_show``.
    """
    # NOTE(security): unpickling runs arbitrary code; only load trusted files.
    with open(inFile, 'rb') as fp:
        args = pickle.load(fp)

    # Data-directory suffix for weighted graphs. The comparison is kept as
    # `!= False` so that e.g. None still selects '_weight' as before. Previously
    # `weight` was left undefined (NameError) when args.weight was False; an
    # empty suffix is assumed for unweighted data -- TODO confirm.
    weight = '_weight' if args.weight != False else ''  # noqa: E712

    fibers = args.fibers
    save_higer_time_rez = True

    device = torch.device(config.DEVICE)
    val_loader = ttf.get_test_data_loader(args.data_path_val, 1, knn=args.knn, modes=args.modes, weight=weight,
                                          fibers=fibers, withDiff=args.withDiff, num_samps=None,
                                          newres=args.new_res)  # validation
    dataset_val = val_loader.dataset

    # NOTE(review): eigvecs / pts come from the last case only, yet they are used
    # below when plotting *any* case -- only correct for a single-case dataset.
    eigenpairs_val, eigvecs, pts = _load_eigenpairs(args.data_path_val, args.knn, weight, args.modes, device)

    ode_model = get_full_model(args.modes, args.width, args.bd_conditions, args.withDiff, args.hs_1, args.d,
                               args.hs_2, args.method, args.rtol, args.atol, device=device)
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
    ode_model.load_state_dict(torch.load(os.path.join(args.model_save_path, weightsLoc), map_location=device))
    ode_model.eval()

    loss_fn = nn.MSELoss()

    run_params_val = {'model': ode_model, 'eigenpairs': eigenpairs_val, 'bd_conditions': args.bd_conditions,
                      'device': device}
    if args.withDiff == 'withDiff':
        run_params_val['withDiff'] = args.withDiff

    test_step = functools.partial(ttf.validation_fn, **run_params_val)
    # Only available if train_test_fns provides it; previously this name was
    # referenced without ever being defined.
    higher_time_resolution_fn = getattr(ttf, 'higher_time_resolution_fn', None)
    higher_time_resolution_step = (functools.partial(higher_time_resolution_fn, **run_params_val)
                                   if higher_time_resolution_fn is not None else None)

    # Per-case accumulators (one entry per validation case).
    diffs_over_time = []
    diffs_over_time_rae = []
    diffs_over_time_rse = []
    rel_grad_over_time = []
    all_lags = []
    all_corr = []
    losses = torch.zeros(len(val_loader))

    # Indices of validation cases for which per-case plots / dumps are written.
    inds_of_sims_to_show = {}

    eps = 1.0e-6
    with torch.no_grad():
        for i, dp in enumerate(val_loader):
            u_pd, u = test_step(None, dp)

            loss = loss_fn(u_pd, u)
            losses[i] = loss.item()

            u_pd = u_pd.cpu().detach().numpy()
            u = u.cpu().detach().numpy()

            # Per-time-step errors; k walks the leading axis (plotted against t below).
            diffs = [np.linalg.norm(u[k].reshape(-1) - u_pd[k].reshape(-1)) / (np.linalg.norm(u[k].reshape(-1)) + eps)
                     for k in range(len(u))]
            diffsrae = [np.sum(np.abs(u[k] - u_pd[k])) / np.sum(np.abs(u[k] - np.mean(u[k])))
                        for k in range(len(u))]
            diffsrmse = [np.sum(np.square(u[k] - u_pd[k])) / np.sum(np.square(u[k] - np.mean(u[k])))
                         for k in range(len(u))]

            grad_u_pd = weighted_graph_grad(torch.Tensor(u_pd.squeeze()), dp.edge_index, dp.edge_weights)
            rel_grad = torch.linalg.norm(grad_u_pd - dp.grad_y, dim=0) / torch.linalg.norm(dp.grad_y, dim=0)

            lags = [lag(u[k], u_pd[k]) for k in range(len(u))]

            # NOTE(review): iterates over len(u_pd) but indexes the *transposed*
            # arrays, i.e. correlates rows 0..len(u_pd)-1 of u.squeeze().T --
            # confirm this is the intended axis.
            u_pd_t = u_pd.squeeze().T
            u_t = u.squeeze().T
            correlation_coefficients = np.array([np.corrcoef(u_pd_t[k], u_t[k])[0, 1] for k in range(len(u_pd))])

            diffs_over_time.append(diffs)
            diffs_over_time_rae.append(diffsrae)
            diffs_over_time_rse.append(diffsrmse)
            rel_grad_over_time.append(rel_grad.cpu().numpy())
            all_lags.append(abs(np.asarray(lags)).mean())
            all_corr.append(correlation_coefficients.mean())

            print("test case {:>5d} | test loss: {:>7.12f}".format(i, losses[i]))

            if i in inds_of_sims_to_show:
                print("Plotting...")
                save_dir = os.path.join(args.model_save_path, str(i))
                os.makedirs(save_dir, exist_ok=True)
                utils.plot_3d_object_ortho(t=dataset_val[i].t.cpu().detach().numpy(), coords=pts, fields={
                    "y_pd": u_pd,
                    "y_gt": u,
                    "sq_er": np.square(u - u_pd)
                }, save_path=save_dir)
                np.savetxt(os.path.join(save_dir, 'u.dat'), u.squeeze())
                np.savetxt(os.path.join(save_dir, 'u_pd.dat'), u_pd.squeeze())
                if save_higer_time_rez and higher_time_resolution_step is not None:
                    u_pd_hr = higher_time_resolution_step(None, dp).cpu().detach().numpy().squeeze()
                    np.savetxt(os.path.join(save_dir, 'u_pd_higher_time_rez.dat'), u_pd_hr)

                # Coefficients of prediction / ground truth in the eigvecs basis.
                for j in range(len(u_pd)):
                    tt = dataset_val[i].t[j].cpu().detach().numpy()
                    plt.close('all')
                    plt.figure(10)
                    plt.plot(np.matmul(eigvecs.T, u_pd[j]), '--', linewidth=5)
                    plt.plot(np.matmul(eigvecs.T, u[j]), linewidth=5)
                    plt.ylim([-10, 10])
                    plt.legend(['Predicted', 'Ground Truth'])
                    plt.savefig(os.path.join(save_dir, 'spectrum' + 't={:.4f}.png'.format(tt)))

    print("Plotting diffs...")
    t = dataset_val[0].t.numpy()
    out_dir = args.model_save_path

    _plot_metric_over_time(t, diffs_over_time, "Rel. diff.", os.path.join(out_dir, "diffs.png"))

    diffs_over_time = np.array(diffs_over_time)
    print("diffs_over_time.shape", diffs_over_time.shape)
    print("diffs_over_time.mean", diffs_over_time.mean())
    print("diffs_over_time.mean", diffs_over_time.mean(axis=0))

    _plot_metric_over_time(t, diffs_over_time_rae, "Rel. Abs diff.",
                           os.path.join(out_dir, "diffs_rae.png"), ylim=(0, .5))
    _plot_metric_over_time(t, diffs_over_time_rse, "Rel. Square diff.",
                           os.path.join(out_dir, "diffs_rse.png"), ylim=(0, .5))
    _plot_metric_over_time(t, rel_grad_over_time, "Rel. Grad",
                           os.path.join(out_dir, "rel_grad.png"), ylim=(0, 1))

    with open(os.path.join(out_dir, "metrics"), 'w') as f:
        print("diffs_over_time.shape", diffs_over_time.shape, file=f)
        print("diffs_over_time.mean", diffs_over_time.mean(), file=f)
        print("diffs_over_time.mean", diffs_over_time.mean(axis=0), file=f)
        print('mean_lag = ' + str(np.asarray(all_lags).mean()), file=f)
        print('mean_corr = ' + str(np.asarray(all_corr).mean()), file=f)


pather = 'path to arguments.pkl'
weightLocations = 'path to model weights'
main(pather, weightLocations)
