import os.path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import Sequential, GCNConv
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from custom_metrics import RelativeRootMeanSquaredError, RelativeRootMeanSquaredErrorScore
from ignite.contrib.handlers import ProgressBar
from ignite.engine import Events, Engine
from ignite.metrics import Loss
from model import get_full_model
import argparse
import config
import functools
import ignite.contrib.engines.common as common
from ignite.handlers.param_scheduler import LRScheduler

from torch.optim.lr_scheduler import StepLR
import torch
import torch.nn as nn
import torch.optim as optim
import train_test_fns as ttf
import pickle
import json
import os
import shutil

def _append_line(path, value):
    """Append ``str(value)`` followed by a newline to the text file at *path*.

    The file is opened in append mode (created if missing) and closed even if
    the write raises.
    """
    with open(path, "a") as handle:
        handle.write(str(value) + '\n')


def _save_run_artifacts(args, trainer):
    """Write the run configuration into ``args.model_save_path``.

    Files written (in this order):
      * ``arguments.pkl`` -- the pickled ``args`` namespace,
      * ``model_info.txt`` -- ``trainer.__dict__`` printed as text,
      * ``model.py`` -- a copy of the source file of the imported ``model`` module,
      * ``args.txt`` -- ``args.__dict__`` dumped as indented JSON.
    """
    # ``model`` is already imported at file level; re-importing here only fetches
    # the cached module so its real on-disk location can be used below.
    import model as model_module

    save_dir = args.model_save_path

    with open(os.path.join(save_dir, 'arguments.pkl'), 'wb') as handle:
        pickle.dump(args, handle)

    with open(os.path.join(save_dir, 'model_info.txt'), 'w') as handle:
        print(trainer.__dict__, file=handle)

    # Copy the file the model was actually imported from, instead of whatever
    # 'model.py' happens to exist in the current working directory.
    shutil.copy(model_module.__file__, os.path.join(save_dir, 'model.py'))

    with open(os.path.join(save_dir, 'args.txt'), 'w') as handle:
        json.dump(args.__dict__, handle, indent=2)


def main(args):
    """Train the model built by ``get_full_model`` with an ignite training loop.

    Three loaders are built with ``ttf.get_data_loader``:
      * a training loader over ``args.data_path`` (``num_samps=None``),
      * a second loader over ``args.data_path`` with ``num_samps=100``, used
        for per-epoch evaluation on training data,
      * a validation loader over ``args.data_path_val`` with batch size 1.

    Optimisation uses Adam (``lr=args.lr``) with a StepLR schedule
    (``step_size=25``, ``gamma=0.5``) stepped from an ``EPOCH_STARTED``
    handler. After every epoch, both evaluation loaders are run. Their
    ``rel_rmse`` values are logged and appended to ``training_rrmse_loss.txt``
    and ``val_rrmse_loss.txt`` in ``args.model_save_path``. The two best
    checkpoints by validation ``rel_rmse_score`` are kept via
    ``common.save_best_model_by_val_score``.

    Args:
        args: ``argparse.Namespace`` from this script's parser. Training runs
            for ``args.epochs`` epochs, or 500 when that is ``None`` or absent.
    """
    device = torch.device(config.DEVICE)
    save_dir = args.model_save_path
    default_max_epochs = 500
    epochs = getattr(args, 'epochs', None)
    max_epochs = default_max_epochs if epochs is None else epochs

    # All per-epoch logs and run artifacts are written into this directory.
    os.makedirs(save_dir, exist_ok=True)

    # Suffix forwarded to get_data_loader; presumably selects the weighted
    # graph data variant -- TODO confirm in train_test_fns.
    weight = '_weight' if args.weight else ''

    # Data: each instance is loaded separately, and each feature is loaded as
    # (number of nodes) x (feature length).
    loader_kwargs = dict(
        knn=args.knn,
        modes=args.modes,
        weight=weight,
        fibers=args.fibers,
        withDiff=args.withDiff,
        newres=args.new_res,
    )
    train_loader = ttf.get_data_loader(args.data_path, args.batch_size, num_samps=None, **loader_kwargs)
    train_val_loader = ttf.get_data_loader(args.data_path, args.batch_size, num_samps=100, **loader_kwargs)
    val_loader = ttf.get_data_loader(args.data_path_val, 1, num_samps=None, **loader_kwargs)

    ode_model = get_full_model(args.modes, args.width, args.bd_conditions, args.withDiff, args.hs_1, args.d,
                               args.hs_2, args.method, args.rtol, args.atol, device=device)

    optimizer = optim.Adam(ode_model.parameters(), lr=args.lr)
    # Stepped from an EPOCH_STARTED handler below, so step_size counts epochs.
    torch_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=25, gamma=0.5)
    scheduler = LRScheduler(torch_lr_scheduler)

    criterion = nn.MSELoss()
    val_metrics = {
        'mse': Loss(criterion),
        'rel_rmse': RelativeRootMeanSquaredError(),
        'rel_rmse_score': RelativeRootMeanSquaredErrorScore()
    }

    # Train and validation steps share the model/boundary-condition/device arguments.
    run_params = {'model': ode_model, 'bd_conditions': args.bd_conditions, 'device': device}

    # withDiff is forwarded to the step functions only for the 'withDiff' variant.
    if args.withDiff == 'withDiff':
        run_params['withDiff'] = args.withDiff

    train_step = functools.partial(ttf.train_fn, optimizer=optimizer, loss_criterion=criterion, **run_params)
    validation_step = functools.partial(ttf.validation_fn, **run_params)

    trainer = Engine(train_step)
    trainer.add_event_handler(Events.EPOCH_STARTED, scheduler)

    train_evaluator = Engine(validation_step)
    validation_evaluator = Engine(validation_step)

    for name, metric in val_metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(validation_evaluator, name)

    pbar = ProgressBar(persist=True, bar_format="")
    # NOTE(review): the 'mse' metric is attached only to the evaluators, not to
    # `trainer`, so the bar has no 'mse' entry to display (ignite may warn).
    # Confirm whether an output_transform was intended here.
    pbar.attach(trainer, ['mse'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_val_loader)
        rrmse = train_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Training Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        _append_line(os.path.join(save_dir, 'training_rrmse_loss.txt'), rrmse)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(val_loader)
        rrmse = validation_evaluator.state.metrics['rel_rmse']
        pbar.log_message("Validation Results - Epoch: {}  Rel. RMSE: {:.6f} ".format(engine.state.epoch, rrmse))
        # NOTE(review): resets attributes on the ignite ProgressBar wrapper (a
        # tqdm-style idiom); kept unchanged -- confirm it still has an effect.
        pbar.n = pbar.last_print_n = 0
        _append_line(os.path.join(save_dir, 'val_rrmse_loss.txt'), rrmse)

    common.save_best_model_by_val_score(
        save_dir,
        evaluator=validation_evaluator,
        model=ode_model,
        metric_name="rel_rmse_score",
        n_saved=2,
        trainer=trainer,
        tag="val",
    )

    _save_run_artifacts(args, trainer)

    trainer.run(train_loader, max_epochs=max_epochs)

def str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"`` and ``"0"``) into ``True``. This converter interprets the text
    instead.

    Args:
        value: Raw string from the command line; a ``bool`` is returned unchanged.

    Returns:
        ``True`` for true/t/yes/y/1. ``False`` for false/f/no/n/0 and for the
        empty string. Matching is case-insensitive and ignores surrounding
        whitespace.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # '' stays False for backward compatibility: it was the only input that
    # ``type=bool`` mapped to False.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Path arguments: data, model save path, etc.
    parser.add_argument('--data_path', type=str, default=config.TRAIN_DATA,
                        help='training data path')
    parser.add_argument('--data_path_val', type=str, default=config.VALIDATION_DATA,
                        help='validation data path')
    parser.add_argument('--model_save_path', type=str, default=config.MODELS_DIR,
                        help='directory for checkpoints, logs and run artifacts')

    # PDE-related arguments
    parser.add_argument('--bd_conditions', type=str, default=None)
    parser.add_argument('--withDiff', type=str, default='withDiff')

    # Arguments for the ODE solver used
    parser.add_argument('--rtol', type=float, default=1e-5,
                        help='relative tolerance for the ode integrator')
    parser.add_argument('--atol', type=float, default=1e-5,
                        help='absolute tolerance for the ode integrator')
    parser.add_argument('--method', type=str, default='euler',
                        help='ode integrator method (e.g. euler, rk4)')

    # Arguments for the MPNN architecture
    parser.add_argument('--hs_1', type=int, default=100, help='hidden layer nodes for message net')
    parser.add_argument('--hs_2', type=int, default=100, help='hidden layer nodes for aggr net')
    parser.add_argument('--d', type=int, default=40, help='graph feature dimension')

    # Arguments for neural network training
    # FIXME: Refactor is needed for batch_size>1 to account for differing time scales
    parser.add_argument('--batch_size', type=int, default=1, help='size of batch')
    parser.add_argument('--epochs', type=int, default=None, help='number of epochs for training')
    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')

    parser.add_argument('--modes', type=int, default=75, help='number of modes to use in training')
    parser.add_argument('--width', type=int, default=50, help='width of the GFT output channel')
    parser.add_argument('--knn', type=int, default=30, help='number of knn to use as nodal connections')
    parser.add_argument('--weight', type=str2bool, default=True, help='use weights or not')
    parser.add_argument('--fibers', type=str2bool, default=True, help='use fibers in training?')
    # NOTE(review): still parsed with type=bool, so any non-empty value
    # (including "False") becomes True. Left as-is because the original
    # comment ("150") hints at a numeric resolution -- confirm the expected
    # type of get_data_loader's `newres`.
    parser.add_argument('--new_res', type=bool, default=None,
                        help='change resolution of original graph?')

    arguments = parser.parse_args()

    main(arguments)

