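# calculate_relative_error.py
# Loads a trained EGATTransformerNetwork checkpoint through DeepSpeed and logs
# per-sample relative errors for the train and test splits.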
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, random_split, Subset
import os, pickle
import time
import deepspeed
import yaml
import torch.nn.functional as F

from models import EGATTransformerNetwork
from load_data import load_data
from loss_function import loss_function, evaluate_model
from get_data import GNNDataset, GNNDataset_ontime_loader

# ---------------------------------------------------------------------------
# Evaluation setup: build the network, wrap it in a DeepSpeed engine and
# restore the trained weights.
# ---------------------------------------------------------------------------

# Suffix of the checkpoint directory and of the result file that is written.
model_number = 2
# Epoch label written into the result file.
# NOTE(review): this value is not passed to load_checkpoint() as a tag, so it
# is only a label -- confirm it matches the checkpoint that is actually loaded.
epoch = 980
model = EGATTransformerNetwork()
# NOTE(review): presumably mirrors the training-time DeepSpeed config so the
# checkpoint can be restored into a matching engine -- confirm.
deepspeed_config = {
    "train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-08,
            "weight_decay": 3e-07
        }
    },
    "zero_optimization": {
        "stage": 0
    },
    "fp16": {
        "enabled": False
    },
    # NOTE(review): confirm DeepSpeed actually reads this section.
    "deepspeed": {
        "disable_mpi": True
    }
}

# Project configuration; `args['model']` selects the checkpoint sub-directory
# and the whole mapping is later handed to the dataset loader.
config_yaml = '../configs/model.yaml'
with open(config_yaml, 'r') as f:
    args = yaml.safe_load(f)

# Only the engine is needed; the optimizer, dataloader and LR scheduler that
# deepspeed.initialize() also returns are discarded.
model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=deepspeed_config,
)

model_type = args['model']
checkpoint_dir = f'./results/model/{model_type}/model{model_number}'
load_path, _ = model_engine.load_checkpoint(checkpoint_dir)
# load_checkpoint() reports failure by returning None instead of raising.
# Without this check the script would silently evaluate a randomly
# initialised network and write meaningless error figures.
if load_path is None:
    raise RuntimeError(
        f"Failed to load a DeepSpeed checkpoint from '{checkpoint_dir}'"
    )
model_engine.eval()

def relative_error(dataset, idx, prefix):
    """Evaluate ``model_engine`` on every sample and log its relative error.

    For each sample the relative squared error::

        (sum((net_E - y_E)**2) + sum((net_H - y_H)**2))
        / (sum(y_E**2) + sum(y_H**2))

    is computed, printed, and appended to
    ``./results/relative_error/data_loss{model_number}.txt`` together with
    the first five rows of the predicted and reference ``E`` output.

    Inference runs under ``torch.no_grad()``, so the tensors rendered into
    the log carry no ``grad_fn``.

    Args:
        dataset: Indexable, sized collection of samples. Each sample must
            provide ``.cuda()`` and the attributes ``y_E`` and ``y_H``.
        idx: Original dataset indices, parallel to ``dataset``. Used to label
            the log entries and as the averaging denominator.
        prefix: Label written in front of each index (e.g. ``'train'``).

    Returns:
        ``(mean_relative_error, mean_seconds_per_sample)``, both averaged
        over ``len(idx)``. Both are NaN when ``idx`` is empty.
    """
    num_indices = len(idx)
    if num_indices == 0:
        # Nothing to average over; avoid a bare ZeroDivisionError below.
        return float("nan"), float("nan")

    loss = 0.0
    total_time = 0
    log_path = f"./results/relative_error/data_loss{model_number}.txt"
    # Pure evaluation: no autograd graph is needed for any of the work below.
    with open(log_path, "a") as f, torch.no_grad():
        for i, j in zip(range(len(dataset)), idx):
            # The timed window covers sample loading, host-to-device transfer
            # and the forward pass.
            start_time = time.time()
            data = dataset[i].cuda()
            net_E, net_H = model_engine(data)
            # CUDA kernels are launched asynchronously; wait for them to
            # finish so the measurement includes the actual compute time.
            torch.cuda.synchronize()
            end_time = time.time()
            total_time += end_time - start_time

            loss_up = (
                torch.sum((net_E - data.y_E) ** 2)
                + torch.sum((net_H - data.y_H) ** 2)
            ).detach().cpu().numpy()
            loss_down = (
                torch.sum(data.y_E ** 2) + torch.sum(data.y_H ** 2)
            ).detach().cpu().numpy()
            sample_error = loss_up / loss_down
            loss += sample_error

            loss_mse = loss_function(net_E, net_H, data.y_E, data.y_H)
            print(f"loss_mse: {loss_mse}")
            f.write(f"{prefix} Index: {j}\n")
            f.write(f"loss up: {loss_up}\n")
            f.write(f"loss down: {loss_down}\n")
            f.write(f"Relative Error: {sample_error}\n")
            print(f"Relative Error: {sample_error}")
            # Log a small preview (first five rows) of prediction vs. target.
            f.write(f"net output: {net_E[:5, :]}\n")
            f.write(f"real output: {data.y_E[:5, :]}\n")
            f.write(f"error output: {net_E[:5, :] - data.y_E[:5, :]}\n")
    return loss / num_indices, total_time / num_indices

def compare_relative_error(dataset, idx, prefix):
    """Unfinished variant of ``relative_error``; not called in the visible file.

    NOTE(review): this function cannot run as written.
      * ``total_time`` and ``loss`` are updated with ``+=`` but never
        initialised, and ``start_time`` is never set, so the first loop
        iteration fails.
      * The return statement uses ``loss_short``, ``index_short``,
        ``loss_long`` and ``index_long``, none of which is defined in the
        visible source.
      * The result file is opened in append mode but nothing is written to it,
        and ``j`` / ``prefix`` are unused.
    The names suggest the intent was to report the relative error separately
    for a "short" and a "long" group of samples, but the grouping criterion is
    not present here -- TODO: implement it or remove the function.
    """
    with open(f"./results/relative_error/data_loss{model_number}.txt", "a") as f:
        for i in range(len(dataset)):
            j = idx[i]
            data = dataset[i].cuda()
            net_E, net_H = model_engine(data)
            end_time = time.time()
            # NOTE(review): total_time and start_time are not defined at this point.
            total_time += end_time - start_time
            # Numerator / denominator of the relative squared error over the E and H outputs.
            loss_up = (torch.sum((net_E - data.y_E)**2) + torch.sum((net_H - data.y_H)**2)).cpu().detach().numpy()
            loss_down = (torch.sum(data.y_E**2) + torch.sum(data.y_H**2)).cpu().detach().numpy()
            # NOTE(review): loss is not initialised before this accumulation.
            loss += loss_up / loss_down
            loss_mse = loss_function(net_E, net_H, data.y_E, data.y_H)
            print(f"loss_mse: {loss_mse}")
            print(f"Relative Error: {(loss_up / loss_down)}")
    # NOTE(review): none of the four names below is ever assigned.
    return loss_short / index_short, loss_long / index_long, index_short, index_long

# Load the persisted train/test split.
# NOTE: pickle can execute arbitrary code on load -- only use trusted index files.
with open('./indices/indices_100_80_10_10.pkl', 'rb') as f:
    indices = pickle.load(f)
train_idx, test_idx = indices['train_idx'], indices['test_idx']

# Evaluate on the complete splits. To spot-check on a random subset instead,
# sample from train_idx / test_idx at this point.
select_train_idx, select_test_idx = train_idx, test_idx

dataset = GNNDataset_ontime_loader(args)
select_traindataset = Subset(dataset, select_train_idx)
select_testdataset = Subset(dataset, select_test_idx)

# Per-sample details are appended to the result file by relative_error();
# the mean error and mean time per sample come back for each split.
train_loss, train_time = relative_error(
    select_traindataset, select_train_idx, 'train'
)
test_loss, test_time = relative_error(
    select_testdataset, select_test_idx, 'test'
)

# Append the run summary below the per-sample entries.
summary_lines = [
    f"epoch: {epoch}\n",
    f"Train Final Relative Error: {train_loss}\n",
    f"Test Final Relative Error: {test_loss}\n",
]
with open(f"./results/relative_error/data_loss{model_number}.txt", "a") as f:
    f.writelines(summary_lines)
    
# Launch with:
#   deepspeed --num_gpus=1 --master_port=25653 calculate_relative_error.py