import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, random_split, Subset
import os, pickle
import time
import deepspeed
import yaml
import torch.nn.functional as F
import matplotlib.patches as Patch
import seaborn as sns
import scipy.stats as stats
import matplotlib.pyplot as plt

from models import EGATTransformerNetwork
from load_data import load_data
from loss_function import loss_function, evaluate_model
from get_data import GNNDataset, GNNDataset_ontime_loader

# Index of the saved run to evaluate; it selects the checkpoint directory
# ./results/model/<model_type>/model<model_number>.
# NOTE(review): `epoch` is not referenced anywhere in this script.
model_number, epoch = 2, 980
model = EGATTransformerNetwork()
# Configuration handed to deepspeed.initialize when wrapping the model.
# NOTE(review): confirm DeepSpeed actually honours the top-level
# "deepspeed" -> "disable_mpi" entry; it is passed through unchanged here.
deepspeed_config = dict(
    train_batch_size=8,
    gradient_accumulation_steps=4,
    optimizer=dict(
        type="Adam",
        params=dict(
            lr=0.001,
            betas=[0.8, 0.999],
            eps=1e-08,
            weight_decay=3e-07,
        ),
    ),
    zero_optimization=dict(stage=0),
    fp16=dict(enabled=False),
    deepspeed=dict(disable_mpi=True),
)

# Model/experiment settings; only the 'model' key is read in this section.
config_yaml = '../configs/model.yaml'
with open(config_yaml, 'r', encoding='utf-8') as f:
    args = yaml.safe_load(f)

# Wrap the model in a DeepSpeed engine so the checkpoint saved through
# DeepSpeed can be restored with `load_checkpoint`.
model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=deepspeed_config,
)

model_type = args['model']
checkpoint_dir = f'./results/model/{model_type}/model{model_number}'
load_path, _ = model_engine.load_checkpoint(checkpoint_dir)
# DeepSpeed only logs a warning and returns (None, None) when no checkpoint
# can be loaded; without this check every number below would come from the
# randomly initialised model.
if load_path is None:
    raise FileNotFoundError(
        f"No DeepSpeed checkpoint could be loaded from {checkpoint_dir!r}"
    )
model_engine.eval()

# Per-sample relative errors (percent), filled in by J_loss.
relative_error = []

def J_loss(dataset, idx, prefix):
    """Run the loaded model over *dataset* and report per-sample relative errors.

    For each sample the error is the relative squared error of both predicted
    fields, in percent::

        100 * (sum((net_E - y_E)**2) + sum((net_H - y_H)**2))
            / (sum(y_E**2) + sum(y_H**2))

    Each value is printed and appended to the module-level ``relative_error``
    list; the mean over the dataset is printed at the end. The forward pass
    uses the module-level ``model_engine``.

    Args:
        dataset: indexable dataset (a ``Subset`` in this script) whose items
            provide ``.cuda()``, ``y_E`` and ``y_H``.
        idx: original dataset indices aligned position-by-position with
            *dataset*; they are only printed, to trace each error to a sample.
        prefix: label (e.g. ``'test'``) shown in the summary line.

    Returns:
        list[float]: the relative errors computed by this call, in order.
    """
    errors = []
    # Pure inference: skip building autograd graphs for the forward passes.
    with torch.no_grad():
        for i in range(len(dataset)):
            print(idx[i])
            data = dataset[i].cuda()
            net_E, net_H = model_engine(data)
            loss_up = torch.sum((net_E - data.y_E) ** 2) + torch.sum((net_H - data.y_H) ** 2)
            loss_down = torch.sum(data.y_E ** 2) + torch.sum(data.y_H ** 2)
            # A sample whose targets are all zero yields inf/nan here; it is
            # not filtered in this function.
            loss = (loss_up / loss_down * 100).item()
            errors.append(loss)
            relative_error.append(loss)
            print(f"Relative Error: {loss}")

    if errors:
        print(f"[{prefix}] Average Relative Error: {sum(errors) / len(errors)}")
    else:
        print(f"[{prefix}] Average Relative Error: n/a (empty dataset)")
    return errors

# Locally generated train/val/test split. pickle must never be used on
# untrusted files.
with open('../indices/indices_100_80_10_10.pkl', 'rb') as f:
    indices = pickle.load(f)
train_idx = indices['train_idx']
test_idx = indices['test_idx']

N_EVAL_SAMPLES = 900
SAMPLE_SEED = 0  # fixed so the histogram is reproducible between runs

# NOTE(review): samples are drawn from the TRAIN split, although the variable
# names and the plot legend call them "test". Set eval_pool = test_idx for a
# held-out evaluation.
eval_pool = train_idx  # test_idx
# Draw without replacement so no sample is counted twice in the histogram.
# Never ask for more samples than the pool holds.
rng = np.random.default_rng(SAMPLE_SEED)
select_test_idx = rng.choice(
    eval_pool, min(N_EVAL_SAMPLES, len(eval_pool)), replace=False
)

# NOTE(review): the original called `GNNDataset_online_loader`, which is not
# imported (NameError). This uses the name the file imports from get_data.
# Rename both places if the real class is the "online" one.
dataset = GNNDataset_ontime_loader(args)
select_testdataset = Subset(dataset, select_test_idx)
J_loss(select_testdataset, select_test_idx, 'test')

# Keep only errors strictly inside this range (percent). Outliers and
# non-finite values (nan/inf compare False on one side) are dropped, so the
# mean/std below describe the filtered set only.
ERROR_RANGE = (0, 2)
n_total = len(relative_error)
relative_error = [x for x in relative_error if ERROR_RANGE[0] < x < ERROR_RANGE[1]]
print(f"Kept {len(relative_error)} / {n_total} samples with "
      f"{ERROR_RANGE[0]} < error < {ERROR_RANGE[1]} %")
if not relative_error:
    raise ValueError("No relative errors fall inside the plotting range; nothing to plot.")

mean_error = np.mean(relative_error)
std_error = np.std(relative_error)

plt.figure(figsize=(8, 6))
sns.histplot(relative_error, bins=30, color='sandybrown', edgecolor='black', kde=False, stat='density')

sns.kdeplot(relative_error, color='orangered', linewidth=2, label=f'KDE Fit: Mean = {mean_error:.2f}, Std = {std_error:.2f}')

plt.tick_params(axis='both', which='major', width=2, length=6)
plt.tick_params(axis='both', which='minor', width=1.5, length=4)

plt.title('2D-R1F0', fontsize=14)
plt.xlabel('Relative error of field values (%)', fontsize=14)
plt.ylabel('Density', fontsize=14)

# `Patch` is bound to the matplotlib.patches *module* by
# `import matplotlib.patches as Patch`, so the legend-handle class is
# `Patch.Patch`. Calling the module itself raises TypeError.
# NOTE(review): the label says "Test set" while samples may come from the
# train split; keep it in sync with the sampling pool.
hist_patch = Patch.Patch(color='sandybrown', label="Test set distribution")
kde_patch = Patch.Patch(color='orangered', label=f'Mean = {mean_error:.2f}, Std = {std_error:.2f}')

# Explicit handles replace the automatic legend entries (including the
# kdeplot label above).
plt.legend(handles=[hist_patch, kde_patch], loc="upper right", fontsize=12, frameon=True)

plt.grid(True, linestyle='--', alpha=0.6)

plt.savefig('2D-R1F0-histogram')
plt.close()

# Launch with: deepspeed --num_gpus=1 --master_port=25653 plt_histogram.py