import os.path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Sequential, GCNConv
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

# from Adam import Adam

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from custom_metrics import RelativeRootMeanSquaredError, RelativeRootMeanSquaredErrorScore
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, Engine
from ignite.metrics import Loss
from model import get_full_model
import argparse
import config
import functools
import ignite.contrib.engines.common as common
import torch
import torch.nn as nn
import torch.optim as optim
import train_test_fns as ttf
import pickle
import json
import os

def _append_line(path, value):
    """Append ``str(value)`` plus a newline to the text file at *path*."""
    # Append mode keeps the values written for earlier epochs.
    with open(path, "a") as log_file:
        log_file.write(str(value) + '\n')


def main(args):
    """Build the data loaders, model and Ignite engines, then run training.

    After every training epoch the model is evaluated on both the training
    and the validation loader; the relative RMSE of each evaluation is shown
    on the progress bar and appended to ``training_rrmse_loss.txt`` /
    ``val_rrmse_loss.txt`` inside ``args.model_save_path``. The run
    arguments are stored there as well (``arguments.pkl``, ``args.txt``)
    together with a dump of the trainer's attributes (``model_info.txt``).

    Args:
        args: Namespace of run settings. Reads ``data_path``,
            ``data_path_val``, ``batch_size``, ``knn``, ``modes``,
            ``weight``, ``fibers``, ``bd_conditions``, ``withDiff``,
            ``width``, ``hs_1``, ``d``, ``hs_2``, ``method``, ``rtol``,
            ``atol``, ``lr`` and ``model_save_path``. An optional
            ``epochs`` attribute sets the number of training epochs; when it
            is missing or ``None`` the previous default of 5000 is used.
    """
    # Use the GPU when torch can see one; otherwise fall back to the CPU
    # instead of failing on CPU-only machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The log and metadata files below are all written into this directory.
    os.makedirs(args.model_save_path, exist_ok=True)

    # Load Data
    # `weight` and `fibers` come from the parsed arguments; the bare names
    # used previously were never defined and raised NameError.
    loader_options = dict(knn=args.knn, modes=args.modes, weight=args.weight,
                          fibers=args.fibers, bd_conditions=args.bd_conditions,
                          withDiff=args.withDiff)
    train_loader = ttf.get_data_loader(args.data_path, args.batch_size,
                                       num_samps=500, **loader_options)  # training
    val_loader = ttf.get_data_loader(args.data_path_val, 1,
                                     num_samps=50, **loader_options)  # validation

    # Model, optimizer, metrics
    ode_model = get_full_model(args.modes, args.width, args.bd_conditions, args.withDiff,
                               args.hs_1, args.d, args.hs_2, args.method,
                               args.rtol, args.atol, device=device)

    optimizer = optim.Adam(ode_model.parameters(), lr=args.lr)

    criterion = nn.MSELoss()
    val_metrics = {
        'mse': Loss(criterion),
        'rel_rmse': RelativeRootMeanSquaredError(),
        'rel_rmse_score': RelativeRootMeanSquaredErrorScore()
    }

    # Train and validation steps
    run_params = {'model': ode_model, 'bd_conditions': args.bd_conditions, 'device': device}

    # The step functions only receive `withDiff` when it is switched on.
    if args.withDiff == 'withDiff':
        run_params['withDiff'] = args.withDiff

    train_step = functools.partial(ttf.train_fn, optimizer=optimizer,
                                   loss_criterion=criterion, **run_params)
    validation_step = functools.partial(ttf.validation_fn, **run_params)

    trainer = Engine(train_step)
    train_evaluator = Engine(validation_step)
    validation_evaluator = Engine(validation_step)

    for name, metric in val_metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(validation_evaluator, name)

    pbar = ProgressBar(persist=True, bar_format="")
    # NOTE(review): within this function no metric named 'mse' is attached to
    # `trainer` itself (only to the evaluators) - confirm this is intended.
    pbar.attach(trainer, ['mse'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        """Evaluate on the training set and record its relative RMSE."""
        train_evaluator.run(train_loader)
        rrmse = train_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Training Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        _append_line(os.path.join(args.model_save_path, 'training_rrmse_loss.txt'), rrmse)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        """Evaluate on the validation set and record its relative RMSE."""
        validation_evaluator.run(val_loader)
        rrmse = validation_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Validation Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        # Reset the bar's counters before the next epoch starts.
        pbar.n = pbar.last_print_n = 0
        _append_line(os.path.join(args.model_save_path, 'val_rrmse_loss.txt'), rrmse)

    common.save_best_model_by_val_score(
        args.model_save_path,
        evaluator=validation_evaluator,
        model=ode_model,
        metric_name="rel_rmse_score",
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    # Persist the run configuration; `with` guarantees every handle is closed.
    with open(os.path.join(args.model_save_path, 'arguments.pkl'), 'wb') as pickle_file:
        pickle.dump(args, pickle_file)

    with open(os.path.join(args.model_save_path, 'model_info.txt'), 'w') as info_file:
        print(trainer.__dict__, file=info_file)

    with open(os.path.join(args.model_save_path, 'args.txt'), 'w') as json_file:
        json.dump(vars(args), json_file, indent=2)

    # Honour an `epochs` setting when one is given; keep 5000 as the default.
    max_epochs = getattr(args, 'epochs', None)
    if max_epochs is None:
        max_epochs = 5000

    trainer.run(train_loader, max_epochs=max_epochs)

def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so such a flag can
    never be turned off. This converter reads the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``/``"false"``, ``"yes"``/``"no"``, ``"1"``/``"0"``
            (case-insensitive, surrounding whitespace ignored).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays False: with the old `type=bool` it was the only
    # input that produced False, so existing invocations keep working.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Path arguments: data, model save path, etc.
    parser.add_argument('--data_path', type=str, default=config.TRAIN_DATA)  # training data path
    parser.add_argument('--data_path_val', type=str, default=config.VALIDATION_DATA)  # validation data path
    parser.add_argument('--model_save_path', type=str, default=config.MODELS_DIR)  # output directory for models and logs

    # PDE-related arguments
    parser.add_argument('--bd_conditions', type=str, default='neumann')
    parser.add_argument('--withDiff', type=str, default='withDiff')  # pass 'withDiff' to enable it

    # Arguments for the ODE solver used
    parser.add_argument('--rtol', type=float, default=1e-5)  # relative tolerance for the ode integrator
    parser.add_argument('--atol', type=float, default=1e-5)  # absolute tolerance for the ode integrator
    parser.add_argument('--method', type=str, default='euler')  # ode integrator method (original note mentions 'rk4')

    # Arguments for the MPNN architecture
    parser.add_argument('--hs_1', type=int, default=100)  # hidden layer nodes for message net
    parser.add_argument('--hs_2', type=int, default=100)  # hidden layer nodes for aggr net
    parser.add_argument('--d', type=int, default=40)  # graph feature dimension

    # Arguments for neural network training
    # FIXME: Refactor is needed for batch_size>1 to account for differing time scales
    parser.add_argument('--batch_size', type=int, default=1)  # size of batch
    parser.add_argument('--epochs', type=int, default=None)  # number of epochs for training
    parser.add_argument('--lr', type=float, default=5e-5)  # learning rate

    parser.add_argument('--modes', type=int, default=50)  # number of modes to use in training
    parser.add_argument('--width', type=int, default=200)  # width of the GFT output channel
    parser.add_argument('--knn', type=int, default=30)  # number of knn to use as nodal connections
    # str2bool instead of bool so that "--weight False" really disables the option.
    parser.add_argument('--weight', type=str2bool, default=True)  # Use weights or not
    parser.add_argument('--fibers', type=str2bool, default=True)  # Use fibers in training?

    arguments = parser.parse_args()

    main(arguments)


