import os.path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Sequential, GCNConv
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

# from Adam import Adam

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from custom_metrics import RelativeRootMeanSquaredError, RelativeRootMeanSquaredErrorScore
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, Engine
from ignite.metrics import Loss
from model import get_full_model
import argparse
import config
import functools
import ignite.contrib.engines.common as common
from ignite.handlers.param_scheduler import LRScheduler

from torch.optim.lr_scheduler import StepLR
import torch
import torch.nn as nn
import torch.optim as optim
import train_test_fns as ttf
import pickle
import json
import os
import shutil
from utils import direct_mapping, to_reference_neighbors

# torch.manual_seed(0)
# np.random.seed(0)

def _eig_dir(data_root, pid, knn, weight):
    """Return ``<data_root>/<pid>/knn=<knn><weight>/HK``, the folder holding a patient's eigen files."""
    return os.path.join(data_root, pid, 'knn=' + str(knn) + weight, 'HK')


def _load_first_pickle(path):
    """Unpickle the file at ``path`` and return element ``[0]`` of the stored object.

    NOTE(review): unpickling can execute arbitrary code; only point this at trusted data directories.
    """
    with open(path, 'rb') as fp:
        return pickle.load(fp)[0]


def _append_line(path, value):
    """Append ``str(value)`` plus a newline to the text file at ``path`` (created if missing)."""
    with open(path, 'a') as fh:
        fh.write(str(value) + '\n')


def _load_eigenpairs(data_root, knn, weight, modes, device):
    """Load each patient's own truncated eigenbasis from ``data_root``.

    Args:
        data_root: directory whose entries are patient ids.
        knn, weight: select the ``knn=<knn><weight>/HK`` sub-folder.
        modes: number of leading eigenpairs kept.
        device: torch device the tensors are moved to.

    Returns:
        dict ``pid -> [eig_vec, eig_vec.T, eig_val]`` where ``eig_vec`` is the transpose of the
        first ``modes`` columns of ``vecs.pkl`` and ``eig_val`` is a ``(modes, modes)`` diagonal
        matrix built from the first ``modes`` entries of ``vals.pkl``.
    """
    eigenpairs = {}
    for pid in os.listdir(data_root):
        eig_path = _eig_dir(data_root, pid, knn, weight)
        vecs = _load_first_pickle(os.path.join(eig_path, 'vecs.pkl'))
        vals = _load_first_pickle(os.path.join(eig_path, 'vals.pkl'))
        eig_vec = torch.Tensor(vecs[:, :modes].T).to(device)
        eig_val = (torch.Tensor(vals[:modes]) * torch.eye(modes)).to(device)
        # The second slot holds the (pseudo-)inverse of the basis; the 'Direct' path uses
        # pinv, here the transpose is used (presumably valid because the stored eigenvectors
        # are orthonormal -- confirm).
        eigenpairs[pid] = [eig_vec, eig_vec.T, eig_val]
    return eigenpairs


def _mapped_eigenpair(data_root, pid, ref_uac, ref_eig_vec, knn, weight, modes, device):
    """Align one patient's eigenbasis with the reference patient's basis ('Direct' mapping).

    Returns ``[eig_vec, pinv(eig_vec), eig_val]`` as tensors on ``device``. ``eig_val`` is the
    leading ``(modes, modes)`` block of the patient Laplacian projected onto the aligned basis
    (a full matrix, not only its diagonal).
    """
    mapped_uac = _load_first_pickle(os.path.join(data_root, pid, 'UAC.pkl'))
    eig_path = _eig_dir(data_root, pid, knn, weight)
    eig_vec_original = _load_first_pickle(os.path.join(eig_path, 'vecs.pkl'))
    lap = _load_first_pickle(os.path.join(eig_path, 'lap.pkl'))

    # Presumably ``neighs`` lists, per reference node, nearby nodes of this patient
    # (see utils.to_reference_neighbors -- confirm); their eigenvector rows are averaged.
    neighs = to_reference_neighbors(ref_uac, mapped_uac)
    eig_vec_prime = eig_vec_original[neighs].mean(axis=1)

    # Least-squares change of basis onto the reference eigenvectors, then unit-norm columns.
    change_of_basis = np.linalg.pinv(eig_vec_prime) @ ref_eig_vec
    eig_vec = eig_vec_original @ change_of_basis
    eig_vec /= np.linalg.norm(eig_vec, axis=0)
    eig_val = np.linalg.pinv(eig_vec) @ lap @ eig_vec

    eig_vec = torch.Tensor(eig_vec[:, :modes].T).to(device)
    eig_val = torch.Tensor(eig_val[:modes, :modes]).to(device)
    return [eig_vec, torch.linalg.pinv(eig_vec), eig_val]


def _load_direct_eigenpairs(data_path, data_path_val, knn, weight, modes, device):
    """Build train/val eigenpairs with every basis aligned to the first training patient.

    Returns:
        ``(eigenpairs_train, eigenpairs_val)``, each a dict ``pid -> [vec, pinv(vec), val]``.

    Raises:
        ValueError: if ``data_path`` is empty while ``data_path_val`` is not, because no
            reference patient is available to align the validation patients to.
    """
    eigenpairs_train = {}
    ref_uac = ref_eig_vec = None
    for idx, pid in enumerate(os.listdir(data_path)):
        print(pid)
        if idx == 0:
            # NOTE(review): the reference is whichever entry os.listdir yields first; that
            # order is not guaranteed by the OS.
            ref_uac = _load_first_pickle(os.path.join(data_path, pid, 'UAC.pkl'))
            eig_path = _eig_dir(data_path, pid, knn, weight)
            ref_eig_vec = _load_first_pickle(os.path.join(eig_path, 'vecs.pkl'))
            ref_eig_val = _load_first_pickle(os.path.join(eig_path, 'vals.pkl'))[:modes]
            eig_vec = torch.Tensor(ref_eig_vec[:, :modes].T).to(device)
            eig_val = torch.Tensor(np.eye(len(ref_eig_val)) * ref_eig_val).to(device)
            eigenpairs_train[pid] = [eig_vec, torch.linalg.pinv(eig_vec), eig_val]
        else:
            eigenpairs_train[pid] = _mapped_eigenpair(
                data_path, pid, ref_uac, ref_eig_vec, knn, weight, modes, device)

    val_pids = os.listdir(data_path_val)
    if val_pids and ref_eig_vec is None:
        raise ValueError(
            "No patients found in {!r}; 'Direct' mapping needs a reference patient".format(data_path))
    eigenpairs_val = {
        pid: _mapped_eigenpair(data_path_val, pid, ref_uac, ref_eig_vec, knn, weight, modes, device)
        for pid in val_pids
    }
    return eigenpairs_train, eigenpairs_val


def main(args):
    """Build loaders, eigenbases, model and ignite engines from ``args``, then train.

    Side effects: writes per-epoch relative-RMSE logs, the best checkpoints, ``arguments.pkl``,
    ``model_info.txt``, ``args.txt`` and copies of ``model.py``/``config.py`` into
    ``args.model_save_path``.
    """
    device = torch.device(config.DEVICE)

    # Suffix selecting the weighted variant of the precomputed eigen files / loaders.
    weight = '_weight' if args.weight else ''

    # Hard-coded switch: 'Direct' aligns every patient's eigenbasis to a reference patient;
    # any other value loads each patient's own eigenbasis unchanged.
    mapping = 'Not_Direct'

    # Data
    loader_kwargs = dict(knn=args.knn, modes=args.modes, weight=weight, fibers=args.fibers,
                         withDiff=args.withDiff, newres=args.new_res)
    train_loader = ttf.get_data_loader(args.data_path, args.batch_size, num_samps=None, **loader_kwargs)
    # Presumably restricted to 10 training samples via num_samps (confirm in train_test_fns);
    # re-evaluated after every epoch.
    train_val_loader = ttf.get_data_loader(args.data_path, args.batch_size, num_samps=10, **loader_kwargs)
    val_loader = ttf.get_data_loader(args.data_path_val, 1, num_samps=None, **loader_kwargs)

    if mapping == 'Direct':
        eigenpairs_train, eigenpairs_val = _load_direct_eigenpairs(
            args.data_path, args.data_path_val, args.knn, weight, args.modes, device)
    else:
        eigenpairs_train = _load_eigenpairs(args.data_path, args.knn, weight, args.modes, device)
        eigenpairs_val = _load_eigenpairs(args.data_path_val, args.knn, weight, args.modes, device)

    ode_model = get_full_model(args.modes, args.width, args.bd_conditions, args.withDiff, args.hs_1,
                               args.d, args.hs_2, args.method, args.rtol, args.atol, device=device)

    optimizer = optim.Adam(ode_model.parameters(), lr=args.lr)
    # Halve the LR every 50 scheduler steps; the handler is invoked at the start of every epoch.
    torch_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)
    scheduler = LRScheduler(torch_lr_scheduler)

    criterion = nn.MSELoss(reduction='none')
    val_metrics = {
        'mse': Loss(nn.MSELoss()),
        'rel_rmse': RelativeRootMeanSquaredError(),
        'rel_rmse_score': RelativeRootMeanSquaredErrorScore()
    }

    # Train and validation steps
    run_params_train = {'model': ode_model, 'eigenpairs': eigenpairs_train,
                        'bd_conditions': args.bd_conditions, 'device': device}
    run_params_val = {'model': ode_model, 'eigenpairs': eigenpairs_val,
                      'bd_conditions': args.bd_conditions, 'device': device}

    if args.withDiff == 'withDiff':
        run_params_train['withDiff'] = args.withDiff
        run_params_val['withDiff'] = args.withDiff

    train_step = functools.partial(ttf.train_fn, optimizer=optimizer, loss_criterion=criterion, **run_params_train)
    train_validation_step = functools.partial(ttf.train_validation_fn, **run_params_train)
    validation_step = functools.partial(ttf.validation_fn, **run_params_val)

    trainer = Engine(train_step)
    trainer.add_event_handler(Events.EPOCH_STARTED, scheduler)

    train_evaluator = Engine(train_validation_step)
    validation_evaluator = Engine(validation_step)

    for name, metric in val_metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(validation_evaluator, name)

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['mse'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_val_loader)
        rrmse = train_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Training Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        _append_line(os.path.join(args.model_save_path, 'training_rrmse_loss.txt'), rrmse)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(val_loader)
        rrmse = validation_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Validation Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        pbar.n = pbar.last_print_n = 0
        _append_line(os.path.join(args.model_save_path, 'val_rrmse_loss.txt'), rrmse)

    # Make sure the output folder exists before anything is written into it.
    os.makedirs(args.model_save_path, exist_ok=True)

    common.save_best_model_by_val_score(
        args.model_save_path,
        evaluator=validation_evaluator,
        model=ode_model,
        metric_name="rel_rmse_score",
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    # Persist run metadata next to the checkpoints (handles are closed deterministically).
    with open(os.path.join(args.model_save_path, 'arguments.pkl'), 'wb') as fh:
        pickle.dump(args, fh)

    with open(os.path.join(args.model_save_path, 'model_info.txt'), 'w') as openfile:
        print(trainer.__dict__, file=openfile)

    # NOTE(review): these paths are relative to the current working directory.
    shutil.copy('model.py', os.path.join(args.model_save_path, 'model.py'))
    shutil.copy('config.py', os.path.join(args.model_save_path, 'config.py'))

    with open(os.path.join(args.model_save_path, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    # Honour --epochs; fall back to 5000 epochs when it is not given.
    max_epochs = args.epochs if args.epochs is not None else 5000
    trainer.run(train_loader, max_epochs=max_epochs)

def str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no', '1'/'0' or 'on'/'off'.

    ``argparse`` with ``type=bool`` maps every non-empty string, including ``"False"``, to
    ``True``, so ``--weight False`` silently had no effect. Use this as ``type=`` instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Path arguments: data, model save path, etc.
    parser.add_argument('--data_path', type=str, default=config.TRAIN_DATA)
    parser.add_argument('--data_path_val', type=str, default=config.VALIDATION_DATA)
    parser.add_argument('--model_save_path', type=str, default=config.MODELS_DIR)  # output dir for checkpoints/logs

    # PDE-related arguments
    parser.add_argument('--bd_conditions', type=str, default=None)
    parser.add_argument('--withDiff', type=str, default='withDiff')  # withDiff

    # Arguments for the ODE solver used
    parser.add_argument('--rtol', type=float, default=1e-5)  # relative tolerance for the ode integrator
    parser.add_argument('--atol', type=float, default=1e-5)  # absolute tolerance for the ode integrator
    parser.add_argument('--method', type=str, default='euler')  # ode integrator method rk4 euler dopri5

    # Arguments for the MPNN architecture
    parser.add_argument('--hs_1', type=int, default=100)  # hidden layer nodes for message net
    parser.add_argument('--hs_2', type=int, default=100)  # hidden layer nodes for aggr net
    parser.add_argument('--d', type=int, default=40)  # graph feature dimension

    # Arguments for neural network training
    # FIXME: Refactor is needed for batch_size>1 to account for differing time scales
    parser.add_argument('--batch_size', type=int, default=1)  # size of batch
    parser.add_argument('--epochs', type=int, default=None)  # number of epochs for training
    parser.add_argument('--lr', type=float, default=1e-4)  # learning rate

    parser.add_argument('--modes', type=int, default=50)  # number of modes to use in training
    parser.add_argument('--width', type=int, default=50)  # width of the GFT output channel
    parser.add_argument('--knn', type=int, default=6)  # number of knn to use as nodal connections
    parser.add_argument('--weight', type=str2bool, default=True)  # Use weights or not
    parser.add_argument('--fibers', type=str2bool, default=True)  # Use fibers in training?
    # NOTE(review): type=bool turns any non-empty string into True; the "150" hint suggests an
    # integer resolution may be intended -- confirm against ttf.get_data_loader before changing.
    parser.add_argument('--new_res', type=bool, default=None)  # Change Resolution of Original graph? 150

    arguments = parser.parse_args()

    main(arguments)

# ################################################################
# #  configurations
# ################################################################
# ntrain = 1000
# ntest = 100
#
# sub = 2 ** 3  # subsampling rate
# h = 2 ** 13 // sub  # total grid size divided by the subsampling rate
# s = h
#
# batch_size = 20
# learning_rate = 0.001
#
# epochs = 500
# step_size = 50
# gamma = 0.5
#
# modes = 16
# width = 64
#
# ################################################################
# # read data
# ################################################################
#
# # Data is of the shape (number of samples, grid size)
# dataloader = MatReader('data/burgers_data_R10.mat')
# x_data = dataloader.read_field('a')[:, ::sub]
# y_data = dataloader.read_field('u')[:, ::sub]
#
# x_train = x_data[:ntrain, :]
# y_train = y_data[:ntrain, :]
# x_test = x_data[-ntest:, :]
# y_test = y_data[-ntest:, :]
#
# x_train = x_train.reshape(ntrain, s, 1)
# x_test = x_test.reshape(ntest, s, 1)
#
# train_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_train, y_train), batch_size=batch_size,
#                                            shuffle=True)
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=batch_size,
#                                           shuffle=False)
#
# # model
# model = FNO1d(modes, width).cuda()
# print(count_params(model))
#
# ################################################################
# # training and evaluation
# ################################################################
# optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
#
# myloss = LpLoss(size_average=False)
# for ep in range(epochs):
#     model.train()
#     t1 = default_timer()
#     train_mse = 0
#     train_l2 = 0
#     for x, y in train_loader:
#         x, y = x.cuda(), y.cuda()
#
#         optimizer.zero_grad()
#         out = model(x)
#
#         mse = F.mse_loss(out.view(batch_size, -1), y.view(batch_size, -1), reduction='mean')
#         l2 = myloss(out.view(batch_size, -1), y.view(batch_size, -1))
#         l2.backward()  # use the l2 relative loss
#
#         optimizer.step()
#         train_mse += mse.item()
#         train_l2 += l2.item()
#
#     scheduler.step()
#     model.eval()
#     test_l2 = 0.0
#     with torch.no_grad():
#         for x, y in test_loader:
#             x, y = x.cuda(), y.cuda()
#
#             out = model(x)
#             test_l2 += myloss(out.view(batch_size, -1), y.view(batch_size, -1)).item()
#
#     train_mse /= len(train_loader)
#     train_l2 /= ntrain
#     test_l2 /= ntest
#
#     t2 = default_timer()
#     print(ep, t2 - t1, train_mse, train_l2, test_l2)
#
# # torch.save(model, 'model/ns_fourier_burgers')
# pred = torch.zeros(y_test.shape)
# index = 0
# test_loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x_test, y_test), batch_size=1, shuffle=False)
# with torch.no_grad():
#     for x, y in test_loader:
#         test_l2 = 0
#         x, y = x.cuda(), y.cuda()
#
#         out = model(x).view(-1)
#         pred[index] = out
#
#         test_l2 += myloss(out.view(1, -1), y.view(1, -1)).item()
#         print(index, test_l2)
#         index = index + 1
#
# # scipy.io.savemat('pred/burger_test.mat', mdict={'pred': pred.cpu().numpy()})
