import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, random_split, Subset
import os, pickle
import time
import deepspeed
import yaml
import torch.nn.functional as F
import matplotlib.pyplot as plt

from models import EGATTransformerNetwork
from load_data import load_data
from loss_function import loss_function, evaluate_model
from get_data import GNNDataset, GNNDataset_ontime_loader

model_number = 2
# NOTE(review): `epoch` is not referenced anywhere in the visible script;
# presumably it was meant to pick a checkpoint tag — confirm before relying on it.
epoch = 980
model = EGATTransformerNetwork()
# DeepSpeed engine config. This script only runs the engine's forward pass,
# so the batch/optimizer settings below are not exercised here.
deepspeed_config = {
    "train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-08,
            "weight_decay": 3e-07
        }
    },
    "zero_optimization": {
        "stage": 0
    },
    "fp16": {
        "enabled": False
    },
    "deepspeed": {
        "disable_mpi": True
    }
}

# Model-level settings; only args['model'] is used below (to locate the checkpoint).
config_yaml = '../configs/model.yaml'
with open(config_yaml, 'r', encoding='utf-8') as f:
    args = yaml.safe_load(f)

model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=deepspeed_config,
)

model_type = args['model']
checkpoint_dir = f'./results/model/{model_type}/model{model_number}'
load_path, _ = model_engine.load_checkpoint(checkpoint_dir)
# load_checkpoint reports a missing checkpoint by returning load_path=None
# rather than raising; stop here instead of plotting an untrained model.
if load_path is None:
    raise RuntimeError(f'No DeepSpeed checkpoint could be loaded from {checkpoint_dir}')
model_engine.eval()

def plot_2d(X, Y, stage, position, method, fig=None):
    """Draw one filled-contour panel of a field sampled at scattered points.

    The panel is added as subplot ``position`` of a 3x4 grid. It gets a
    20-level ``RdBu_r`` contour plot, a colorbar, axis limits fitted to the
    data and large tick labels.

    Args:
        X: x-coordinates of the sample points.
        Y: y-coordinates of the sample points.
        stage: field value at each (X, Y) point.
        position: 1-based subplot index within the 3x4 grid.
        method: label text; only used as the title of the panel at position 9.
        fig: figure to draw into. Defaults to the current pyplot figure, which
            is the figure just created with ``plt.figure`` in this script, so
            existing callers need not pass it.

    Returns:
        The Axes that was created.
    """
    if fig is None:
        fig = plt.gcf()
    ax = fig.add_subplot(3, 4, position)
    # Panel labels on the left-hand column of the grid.
    if position == 1:
        ax.set_title('(a)', loc='left', fontsize=30, pad=20)
    elif position == 5:
        ax.set_title('(b)', loc='left', fontsize=30, pad=20)
    elif position == 9:
        # NOTE(review): rows 1 and 2 are labelled '(a)'/'(b)', but row 3 is
        # titled with the method text and a smaller pad; possibly meant to be
        # '(c)' — confirm.
        ax.set_title(f'({method})', loc='left', fontsize=30, pad=5)
    # Draw on `ax` explicitly instead of relying on pyplot's current axes.
    cp = ax.tricontourf(X, Y, stage, 20, cmap='RdBu_r')
    fig.colorbar(cp, ax=ax)

    x_lo, x_hi = round(np.min(X), 1), round(np.max(X), 1)
    y_lo, y_hi = round(np.min(Y), 1), round(np.max(Y), 1)
    ax.set_xlim([np.min(X), np.max(X)])
    ax.set_ylim([np.min(Y), np.max(Y)])
    # The x-axis has no tick at its minimum (the y-axis does); presumably this
    # avoids crowding the lower-left corner — confirm if changing.
    ax.set_xticks([(x_lo + x_hi) / 2, x_hi])
    ax.set_yticks([y_lo, (y_lo + y_hi) / 2, y_hi])
    ax.tick_params(axis='both', labelsize=25)
    return ax

# Train/test split indices. pickle.load executes arbitrary code on untrusted
# input; this is only acceptable because the file is produced locally.
with open('../indices/indices_100_80_10_10.pkl', 'rb') as f:
    indices = pickle.load(f)
train_idx = indices['train_idx']
test_idx = indices['test_idx']

# One random test sample per figure row. replace=False guarantees three
# distinct samples (np.random.choice samples with replacement by default).
select_test_idx = np.random.choice(test_idx, 3, replace=False)

fig = plt.figure(figsize=(30, 20))
for i, idx in enumerate(select_test_idx):
    print(idx)
    # Flat sample index -> (simulation, step); assumes 99 steps per simulation.
    simulation_num = idx // 99
    step_num = idx % 99
    data = load_data(simulation_num, step_num, args)
    X = data.points[:, 0]
    Y = data.points[:, 1]
    initial = data.x.numpy()[:, 0]
    plot_2d(X, Y, initial, 4 * i + 1, r'Initial state $E_z$')

    # -----------------------------------------
    ground_truth = data.y_E.squeeze(1).numpy()
    plot_2d(X, Y, ground_truth, 4 * i + 2, f'Ground truth(n={step_num})')

    # -----------------------------------------
    # Reassign rather than rely on .cuda() moving the object in place.
    data = data.cuda()
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        net_E, _ = model_engine(data)
    print(net_E[:5, :])
    net_E = net_E.cpu().detach().squeeze(1).numpy()
    plot_2d(X, Y, net_E, 4 * i + 3, f'GT-MSMW(n={step_num})')

    # -----------------------------------------
    # Pointwise absolute error between prediction and ground truth.
    AE_error = np.abs(net_E - ground_truth)
    plot_2d(X, Y, AE_error, 4 * i + 4, 'AE')

plt.tight_layout()
plt.savefig("2D-R0F0")
plt.close(fig)

# Launch with: deepspeed --num_gpus=1 --master_port=25653 plt_comparison.py