#from model import get_full_model
from train_test_fns import validation_fn
import argparse
import config
import functools
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torch.nn as nn
import train_test_fns as ttf
import utils
import pickle

import sys

from model import get_full_model

def _relative_errors(u, u_pd, eps=1.0e-6):
    """Compute per-step error metrics of a prediction against its reference.

    Walks ``u`` (reference) and ``u_pd`` (prediction) in lockstep along their
    leading axis and returns three lists with one entry per step:

    * relative L2 error: ``||u - u_pd|| / (||u|| + eps)``
    * relative absolute error: ``sum|u - u_pd| / sum|u - mean(u)|``
    * relative squared error: ``sum (u - u_pd)^2 / sum (u - mean(u))^2``

    Only the L2 ratio is regularised by ``eps``; the other two produce
    inf/nan for a step whose reference values are all identical.
    """
    rel_l2, rel_abs, rel_sq = [], [], []
    for truth, pred in zip(u, u_pd):
        residual = truth - pred
        centered = truth - np.mean(truth)
        rel_l2.append(np.linalg.norm(residual.reshape(-1)) / (np.linalg.norm(truth.reshape(-1)) + eps))
        rel_abs.append(np.sum(np.abs(residual)) / np.sum(np.abs(centered)))
        rel_sq.append(np.sum(np.square(residual)) / np.sum(np.square(centered)))
    return rel_l2, rel_abs, rel_sq


def _plot_curves_over_time(t, curves, ylabel, save_path, ylim=None):
    """Plot every curve against ``t`` plus their mean (dashed black) and save the figure."""
    plt.figure()
    for curve in curves:
        plt.plot(t, curve, alpha=0.5)
    plt.plot(t, np.mean(curves, axis=0), '--k')
    plt.ylabel(ylabel)
    plt.xlabel("t (sec)")
    if ylim is not None:
        plt.ylim(*ylim)
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close('all')


def main(inFile, weightsLoc):
    """Evaluate a trained model on the validation set and save diagnostic plots.

    Loads the pickled run arguments, rebuilds the model with
    ``get_full_model``, restores its weights, runs ``validation_fn`` on every
    validation sample and writes per-sample field/spectrum plots as well as
    error-over-time plots below ``args.model_save_path``.

    Args:
        inFile: Path to the pickle file holding the run's argument object.
        weightsLoc: Path of the saved state dict; it is joined onto
            ``args.model_save_path``.
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted argument files.
    with open(inFile, 'rb') as args_file:
        args = pickle.load(args_file)

    # `weight` used to be left unbound when weighting was disabled, which made
    # the data-loader call below raise UnboundLocalError.
    # NOTE(review): '' is assumed to be the "no weighting" value expected by
    # ttf.get_data_loader -- confirm against the training script.
    weight = '_weight' if args.weight != False else ''  # noqa: E712 (keeps the original comparison)

    device = torch.device(config.DEVICE)
    val_loader = ttf.get_data_loader(
        args.data_path_val,
        1,
        knn=args.knn,
        modes=args.modes,
        weight=weight,
        fibers=args.fibers,
        withDiff=args.withDiff,
        num_samps=None,
        newres=args.new_res,
    )  # validation
    dataset_val = val_loader.dataset

    ode_model = get_full_model(
        args.modes, args.width, args.bd_conditions, args.withDiff,
        args.hs_1, args.d, args.hs_2, args.method, args.rtol, args.atol,
        device=device,
    )
    # map_location lets weights saved on another device (e.g. a GPU) load here.
    weights_path = os.path.join(args.model_save_path, weightsLoc)
    ode_model.load_state_dict(torch.load(weights_path, map_location=device))
    ode_model.eval()

    # Loss
    loss_fn = nn.MSELoss()

    run_params = {'model': ode_model, 'bd_conditions': args.bd_conditions, 'device': device}
    if args.withDiff == 'withDiff':
        run_params['withDiff'] = args.withDiff
    test_step = functools.partial(validation_fn, **run_params)

    # Testing
    diffs_over_time = []
    diffs_over_time_rae = []
    diffs_over_time_rse = []
    losses = torch.zeros(len(val_loader))

    # Indices of the validation samples that get detailed plots; use an empty
    # set to skip per-sample plotting.
    inds_of_sims_to_show = {0, 1, 2, 3, 4, 5, 6, 7}

    with torch.no_grad():
        for i, dp in enumerate(val_loader):
            u_pd, u = test_step(None, dp)

            losses[i] = loss_fn(u_pd, u).item()

            u_pd = u_pd.cpu().detach().numpy()
            u = u.cpu().detach().numpy()

            diffs, diffs_rae, diffs_rse = _relative_errors(u, u_pd, eps=1.0e-6)
            diffs_over_time.append(diffs)
            diffs_over_time_rae.append(diffs_rae)
            diffs_over_time_rse.append(diffs_rse)

            print("test case {:>5d} | test loss: {:>7.12f}".format(i, losses[i]))

            if i not in inds_of_sims_to_show:
                continue

            print("Plotting...")
            sim_dir = os.path.join(args.model_save_path, str(i))
            # Fetch the sample's time vector once instead of once per time step.
            sim_times = dataset_val[i].t.cpu().detach().numpy()
            utils.plot_fields(
                t=sim_times,
                coords=dp.pos.cpu().detach().numpy(),
                fields={
                    "y_pd": u_pd,
                    "y_gt": u,
                    "sq_er": np.square(u - u_pd)
                },
                save_path=sim_dir,
            )

            # savefig does not create missing directories.
            os.makedirs(sim_dir, exist_ok=True)
            redfor = dp.redfor.cpu().detach().numpy()
            for j, (pred_j, truth_j) in enumerate(zip(u_pd, u)):
                plt.close('all')
                plt.figure(10)
                plt.plot(np.matmul(redfor, pred_j), '--', linewidth=5)
                plt.plot(np.matmul(redfor, truth_j), linewidth=5)
                plt.ylim([-10, 10])
                plt.savefig(os.path.join(sim_dir, 'spectrum' + 't={:.4f}.png'.format(sim_times[j])))

    print("Plotting diffs...")
    # .cpu().detach() keeps this working when the dataset tensors live on a GPU.
    t = dataset_val[0].t.cpu().detach().numpy()

    _plot_curves_over_time(
        t, diffs_over_time, "Rel. diff.",
        os.path.join(args.model_save_path, "diffs.png"),
    )

    diffs_over_time = np.array(diffs_over_time)
    print("diffs_over_time.shape", diffs_over_time.shape)
    print("diffs_over_time.mean", diffs_over_time.mean())
    print("diffs_over_time.mean", diffs_over_time.mean(axis=0))

    _plot_curves_over_time(
        t, diffs_over_time_rae, "Rel. Abs diff.",
        os.path.join(args.model_save_path, "diffs_rae.png"),
        ylim=(0, .5),
    )

    _plot_curves_over_time(
        t, diffs_over_time_rse, "Rel. Square diff.",
        os.path.join(args.model_save_path, "diffs_rse.png"),
        ylim=(0, .5),
    )



# Placeholder locations: set `pather` to the pickled run arguments and
# `weightLocations` to the weights file (joined onto args.model_save_path by
# main) before running this script.
pather = 'path to arguments.pkl'
weightLocations = 'path to model weights.pt'

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, so that importing
    # this module does not trigger a full validation pass.
    main(pather, weightLocations)


