from model import get_full_model
from train_test_fns import validation_fn
import argparse
import config
import functools
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import train_test_fns as ttf
import utils
import pickle


def _relative_diffs(u, u_pd, eps=1.0e-6):
    """Return the relative L2 error between ground truth and prediction for each entry along axis 0.

    For every k the value is ``||u[k] - u_pd[k]|| / (||u[k]|| + eps)``. Each entry is
    flattened before taking the norm. ``eps`` guards against division by zero.
    """
    return [
        np.linalg.norm(u[k].reshape(-1) - u_pd[k].reshape(-1))
        / (np.linalg.norm(u[k].reshape(-1)) + eps)
        for k in range(len(u))
    ]


def _plot_spectra(sample, u_pd, u, save_dir):
    """Save one ``redfor @ prediction`` vs ``redfor @ ground truth`` plot per entry of ``u_pd``.

    Files are written to ``save_dir`` as ``spectrumt=<time>.png``. The time comes from
    ``sample.t``, which is indexed in parallel with ``u_pd``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Hoisted: the same projection matrix is used for every time step.
    redfor = sample.redfor.cpu().detach().numpy()
    times = sample.t.cpu().detach().numpy()
    for j in range(len(u_pd)):
        plt.close('all')
        plt.figure(10)
        plt.plot(np.matmul(redfor, u_pd[j]), '--', linewidth=5)
        plt.plot(np.matmul(redfor, u[j]), linewidth=5)
        plt.savefig(os.path.join(save_dir, 'spectrum' + 't={:.4f}.png'.format(times[j])))
    plt.close('all')


def _plot_relative_diffs(t, diffs_over_time, save_path):
    """Plot every sample's relative-error curve over ``t`` with the mean dashed, and save to ``save_path``."""
    plt.figure()
    for diff in diffs_over_time:
        plt.plot(t, diff, alpha=0.5)
    # np.mean over axis 0 requires all samples to share the same number of time steps.
    plt.plot(t, np.mean(diffs_over_time, axis=0), '--k')
    plt.ylabel("Rel. diff.")
    plt.xlabel("t (sec)")
    plt.savefig(save_path)


def main(inFile, weightsLoc):
    """Evaluate a trained ODE model on the validation set and write diagnostic plots.

    Steps:
    - Load the pickled training arguments from ``inFile``.
    - Rebuild the model with ``get_full_model``.
    - Load the state dict stored at ``args.model_save_path/weightsLoc``.
    - Run ``validation_fn`` on every validation sample and print the MSE loss.
    - Write field/spectrum plots for a few samples into ``args.model_save_path/<i>``.
    - Write an aggregate relative-error plot to ``args.model_save_path/diffs.png``.

    Args:
        inFile: Path to the pickled argument object saved at training time.
            SECURITY: ``pickle.load`` can execute arbitrary code; only pass trusted files.
        weightsLoc: Weights file name, joined onto ``args.model_save_path``.
    """
    with open(inFile, 'rb') as f:
        args = pickle.load(f)  # trusted input only (see docstring)

    # Previously `weight` was left unbound when args.weight was False, which raised
    # NameError below. '' is assumed to mean "no weight suffix" -- TODO confirm
    # against ttf.get_data_loader.
    weight = '_weight' if args.weight != False else ''  # noqa: E712 (keep original comparison semantics)

    device = torch.device(config.DEVICE)

    val_loader = ttf.get_data_loader(
        args.data_path_val, 1, knn=args.knn, modes=args.modes, weight=weight,
        fibers=args.fibers, bd_conditions=args.bd_conditions, withDiff=args.withDiff,
    )  # validation
    dataset_val = val_loader.dataset
    ode_model = get_full_model(
        args.modes, args.width, args.bd_conditions, args.withDiff, args.hs_1, args.d,
        args.hs_2, args.method, args.rtol, args.atol, device=device,
    )

    # map_location allows weights saved on a GPU to be loaded on a CPU-only machine.
    state_dict = torch.load(os.path.join(args.model_save_path, weightsLoc), map_location=device)
    ode_model.load_state_dict(state_dict)
    ode_model.eval()

    loss_fn = nn.MSELoss()

    run_params = {'model': ode_model, 'bd_conditions': args.bd_conditions, 'device': device}
    if args.withDiff == 'withDiff':
        run_params['withDiff'] = args.withDiff

    test_step = functools.partial(validation_fn, **run_params)

    diffs_over_time = []
    losses = torch.zeros(len(val_loader))

    inds_of_sims_to_show = {0, 1, 2, 3}

    with torch.no_grad():
        for i, dp in enumerate(val_loader):
            u_pd, u = test_step(None, dp)

            loss = loss_fn(u_pd, u)
            losses[i] = loss.item()

            u_pd = u_pd.cpu().detach().numpy()
            u = u.cpu().detach().numpy()

            diffs_over_time.append(_relative_diffs(u, u_pd))

            print("test case {:>5d} | test loss: {:>7.12f}".format(i, losses[i].item()))

            if i in inds_of_sims_to_show:
                # NOTE(review): assumes val_loader does not shuffle, so batch i is
                # dataset_val[i] -- confirm in ttf.get_data_loader.
                sample = dataset_val[i]
                save_dir = os.path.join(args.model_save_path, str(i))
                print("Plotting...")
                utils.plot_fields(
                    t=sample.t.cpu().detach().numpy(),
                    coords=sample.pos.cpu().detach().numpy(),
                    fields={
                        "y_pd": u_pd,
                        "y_gt": u,
                    },
                    save_path=save_dir,
                )
                _plot_spectra(sample, u_pd, u, save_dir)

    if not diffs_over_time:
        print("No validation samples; skipping diff plots.")
        return

    print("Plotting diffs...")
    t = dataset_val[0].t.cpu().detach().numpy()
    _plot_relative_diffs(t, diffs_over_time, os.path.join(args.model_save_path, "diffs.png"))

    diffs_over_time = np.array(diffs_over_time)
    print("diffs_over_time.shape", diffs_over_time.shape)
    print("diffs_over_time.mean", diffs_over_time.mean())
    print("diffs_over_time.mean(axis=0)", diffs_over_time.mean(axis=0))


# Placeholder defaults, kept for backward compatibility; override them on the command line.
pather = 'path to arguments.pkl'
weightLocations = 'path to model weight.pt'

if __name__ == "__main__":
    # Guarded so that importing this module does not start an evaluation run.
    _parser = argparse.ArgumentParser(description="Evaluate a trained model on the validation set.")
    _parser.add_argument("args_file", nargs="?", default=pather,
                         help="pickled training arguments (trusted files only)")
    _parser.add_argument("weights", nargs="?", default=weightLocations,
                         help="weights file name, relative to args.model_save_path")
    _cli = _parser.parse_args()
    main(_cli.args_file, _cli.weights)

