import torch
import torch.nn as nn
from torch_geometric.data import Data
import numpy as np
import os, pickle, random
from torch_geometric.loader import DataLoader
# import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import argparse
import yaml
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import deepspeed

from models import EGATTransformerNetwork
from tools.EarlyStopping import EarlyStopping
from loss_function import loss_function, evaluate_model
from get_data import data_loader, data_loader_index
from save_results import save_model_results, save_loss_results

def train(model, args, cmd_args, model_number):
    """Train ``model`` through a DeepSpeed engine and report final metrics.

    The loaders come from ``data_loader_index(args)``. Every epoch runs one
    pass over the training loader; validation loss is computed and persisted
    every ``args['save_loss_interval']`` epochs and the model is saved every
    ``args['save_model_interval']`` epochs.

    The test-set loss and the relative errors on the train/test loaders are
    reported once, after the last epoch. They were previously recomputed
    inside the epoch loop, which added a full extra pass over the training
    and test data to every epoch without feeding back into training.

    Args:
        model: Network handed to ``deepspeed.initialize``.
        args: Mapping read for ``'n_epochs'``, ``'save_loss_interval'`` and
            ``'save_model_interval'``; also forwarded to the data-loading
            and result-saving helpers.
        cmd_args: Parsed command-line namespace forwarded to DeepSpeed.
        model_number: Identifier forwarded to the result-saving helpers.

    Returns:
        List of ``(epoch, train_loss, val_loss)`` tuples, one per epoch on
        which the loss was recorded.
    """
    # NOTE(review): the early stopper is constructed but never consulted, so
    # training always runs for the full ``n_epochs``. Wire it into the
    # validation step once its call convention is confirmed.
    early_stopper = EarlyStopping(patience=50, min_delta=0.001)
    train_loader, val_loader, test_loader = data_loader_index(args)

    model_engine, _, _, _ = deepspeed.initialize(
        args=cmd_args,
        model=model,
        model_parameters=model.parameters()
    )
    losses = []

    for epoch in tqdm(range(args['n_epochs']), desc="TRAINING"):
        start_time = time.time()
        print('New EPOCH')
        epoch_loss = 0
        model_engine.train()
        for data in train_loader:
            # Rebind the result so the transfer does not depend on the batch
            # object being moved in place.
            data = data.cuda()
            net_E, net_H = model_engine(data)
            loss = loss_function(net_E, net_H, data.y_E, data.y_H)
            epoch_loss += loss.item()
            # The engine performs the backward pass and the optimizer step.
            model_engine.backward(loss)
            model_engine.step()

        if epoch % args['save_loss_interval'] == 0:
            val_loss = evaluate_model(model_engine, val_loader)
            train_loss = epoch_loss / len(train_loader)
            losses.append((epoch, train_loss, val_loss))
            save_loss_results(args, losses, model_number)
            t_cost = time.time() - start_time
            print("Epoch: {} Train loss: {:.2e} Validation loss: {:.2e} Time cost: {:.2e}".format(epoch, train_loss, val_loss, t_cost))

        if epoch % args['save_model_interval'] == 0:
            save_model_results(args, model_engine, model_number, epoch)

    # Final report, computed once on the fully trained model.
    test_loss = evaluate_model(model_engine, test_loader)
    print("Average MSE loss on Test Data: {:.2e}".format(test_loss))
    re_train_loss, re_train_time = relative_error(train_loader, model_engine)
    re_test_loss, re_test_time = relative_error(test_loader, model_engine)
    print(f"re_train_loss:{re_train_loss}, re_train_time:{re_train_time}")
    print(f"re_test_loss:{re_test_loss}, re_test_time:{re_test_time}")

    return losses

def parse_arguments():
    """Build the command-line parser and return the parsed namespace.

    Registers ``--local_rank`` and ``--config_yaml``, then lets DeepSpeed
    add its own configuration arguments before parsing.
    """
    arg_parser = argparse.ArgumentParser(
        description='deepspeed training script.',
    )
    arg_parser.add_argument(
        '--local_rank',
        type=int,
        default=-1,
        help='local rank passed from distributed launcher',
    )
    arg_parser.add_argument(
        '--config_yaml',
        type=str,
        default='./configs/model.yml',
    )
    # DeepSpeed returns the parser with its own arguments attached.
    arg_parser = deepspeed.add_config_arguments(arg_parser)
    return arg_parser.parse_args()

def relative_error(dataset, model_engine):
    """Return the mean relative squared error and mean forward-pass time.

    For every batch the error is
    ``(sum((net_E - y_E)^2) + sum((net_H - y_H)^2)) / (sum(y_E^2) + sum(y_H^2))``
    and the result is averaged over ``len(dataset)`` batches.

    Args:
        dataset: Sized iterable of batches exposing ``cuda()``, ``y_E`` and
            ``y_H``.
        model_engine: Callable returning ``(net_E, net_H)`` for a batch; it
            is switched to evaluation mode and left in that mode.

    Returns:
        Tuple ``(mean_relative_error, mean_forward_seconds)``.
    """
    loss = 0.0
    total_time = 0
    # CUDA work is queued asynchronously, so the wall clock only reflects the
    # real forward time if the device is synchronised around the timed region.
    use_cuda = torch.cuda.is_available()
    model_engine.eval()
    with torch.no_grad():
        for data in dataset:
            # Rebind so the transfer does not rely on in-place mutation.
            data = data.cuda()
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            net_E, net_H = model_engine(data)
            if use_cuda:
                torch.cuda.synchronize()
            end_time = time.time()
            total_time += end_time - start_time
            loss_up = torch.sum((net_E - data.y_E) ** 2) + torch.sum((net_H - data.y_H) ** 2)
            loss_down = torch.sum(data.y_E ** 2) + torch.sum(data.y_H ** 2)
            # Divide on the device and copy a single scalar back to the host.
            loss += (loss_up / loss_down).item()
    return loss / len(dataset), total_time / len(dataset)

if __name__ == '__main__':
    # Entry point: parse CLI arguments, initialise the distributed backend,
    # load the YAML configuration and start training.
    print("program started")
    torch.cuda.empty_cache()

    cmd_args = parse_arguments()
    deepspeed.init_distributed()

    # Start from an empty configuration and replace it when a YAML path
    # was supplied.
    args = {}
    if cmd_args.config_yaml:
        with open(cmd_args.config_yaml, 'r') as config_file:
            args = yaml.safe_load(config_file)

    # Identifier passed through to the result-saving helpers.
    model_number = 4
    gnn = EGATTransformerNetwork()

    print("Training initial GNN...")
    train(gnn, args, cmd_args, model_number)
    
    
# Example launch command:
# deepspeed --num_gpus=2 --master_port=25653 train_model.py --deepspeed --deepspeed_config configs/ds_config.json