import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, random_split, Subset
import os, pickle
import time
import deepspeed
import yaml
import torch.nn.functional as F
import matplotlib.pyplot as plt

from models import EGATTransformerNetwork
from load_data import load_data
from loss_function import loss_function, evaluate_model
from get_data import GNNDataset, GNNDataset_ontime_loader

# Which trained run to visualise: checkpoints are read from
# ./results/model/<model_type>/model<model_number>.
model_number = 2
epoch = 980  # NOTE(review): not used below; presumably meant as the checkpoint tag — confirm
model = EGATTransformerNetwork()
# DeepSpeed engine settings. This script only runs inference (eval + forward),
# so the optimizer/batch entries do not affect the plots; presumably they
# mirror the training config so the engine can be built — confirm.
deepspeed_config = {
    "train_batch_size": 8,
    "gradient_accumulation_steps": 4,
    "optimizer": {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-08,
            "weight_decay": 3e-07
        }
    },
    "zero_optimization": {
        "stage": 0
    },
    "fp16": {
        "enabled": False
    },
    "deepspeed": {
        "disable_mpi": True
    }
}

# Project model/data configuration; ``args['model']`` selects the results subdirectory.
config_yaml = '../configs/model.yaml'
with open(config_yaml, 'r', encoding='utf-8') as f:
    args = yaml.safe_load(f)

model_engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=deepspeed_config,
)

model_type = args['model']
checkpoint_dir = f'./results/model/{model_type}/model{model_number}'
load_path, _ = model_engine.load_checkpoint(checkpoint_dir)
# DeepSpeed reports a missing checkpoint by returning None rather than
# raising; without this check the figures would silently show an
# untrained (randomly initialised) model.
if load_path is None:
    raise FileNotFoundError(f'No DeepSpeed checkpoint could be loaded from {checkpoint_dir!r}')
model_engine.eval()

def plot_2d(X, Y, stage, position, method, figure=None, grid=(4, 4)):
    """Draw one filled-contour panel (with colorbar) of the evolution figure.

    Args:
        X, Y: Point coordinates. They are passed to ``tricontourf``, which
            triangulates them, so the points need not lie on a regular grid.
        stage: Field value at each (X, Y) point.
        position: 1-based subplot index within ``grid``.
        method: Panel label, shown as ``(method)``. Panels 1 and 9 are
            labelled ``(a)`` and ``(b)`` instead.
        figure: Figure to draw into. Defaults to the module-level ``fig``
            created by the script, preserving the original behaviour.
        grid: ``(nrows, ncols)`` of the subplot layout; defaults to 4x4.

    Returns:
        The newly created ``Axes``.
    """
    # Conditional expression is lazy: the global ``fig`` is only looked up
    # when no explicit figure is given.
    target = fig if figure is None else figure
    ax = target.add_subplot(grid[0], grid[1], position)
    if position == 1:
        ax.set_title('(a)', loc='left', fontsize=30, pad=20)
    elif position == 9:
        ax.set_title('(b)', loc='left', fontsize=30, pad=20)
    else:
        ax.set_title(f'({method})', loc='left', fontsize=30, pad=5)

    # Draw on ``ax`` explicitly rather than through pyplot's "current axes",
    # so the panel lands on ``target`` even if it is not the active figure.
    cp = ax.tricontourf(X, Y, stage, 20, cmap='RdBu_r')
    target.colorbar(cp, ax=ax)

    x_min, x_max = np.min(X), np.max(X)
    y_min, y_max = np.min(Y), np.max(Y)
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])

    x_lo, x_hi = round(x_min, 1), round(x_max, 1)
    y_lo, y_hi = round(y_min, 1), round(y_max, 1)
    # NOTE(review): the x axis has no minimum tick (only mid and max), unlike
    # the y axis; presumably to avoid label clashes between adjacent panels.
    ax.set_xticks([(x_lo + x_hi) / 2, x_hi])
    ax.set_yticks([y_lo, (y_lo + y_hi) / 2, y_hi])
    ax.tick_params(axis='both', labelsize=25)
    return ax

with open('../indices/indices_100_80_10_10.pkl', 'rb') as f:
    # NOTE(review): pickle.load can execute arbitrary code; only load index
    # files produced by this project.
    indices = pickle.load(f)
train_idx = indices['train_idx']
test_idx = indices['test_idx']

# Forecast horizons shown in columns 2-4 of each row pair.
N_STEPS = (20, 50, 90)

# Pick showcase simulations from the held-out test split, without
# replacement. The original drew from range(100), i.e. from every simulation
# including training ones, and could repeat a sample.
# Assumes ``test_idx`` holds simulation numbers accepted by ``load_data`` —
# confirm. A previous draw was [36, 3, 33].
select_test_idx = np.random.choice(test_idx, 3, replace=False)


def _predict(data):
    """Run the model on ``data``; return ``(prediction, ground_truth)`` as NumPy arrays.

    Both arrays are moved to the CPU and have dimension 1 squeezed out.
    """
    data.cuda()
    # Inference only: skip autograd graph construction to save GPU memory.
    with torch.no_grad():
        net_E, _ = model_engine(data)
    prediction = net_E.cpu().detach().squeeze(1).numpy()
    ground_truth = data.y_E.cpu().detach().squeeze(1).numpy()
    return prediction, ground_truth


fig = plt.figure(figsize=(30, 20))
# Only entries 0 and 2 of ``select_test_idx`` are drawn. ``i`` also selects
# the row pair: i=0 -> rows 1-2, i=2 -> rows 3-4 of the 4x4 grid.
for i in range(0, 3, 2):
    simulation_num = select_test_idx[i]
    print(simulation_num)
    gt_row = 4 * i            # subplot offset of the ground-truth row
    pred_row = 4 * (i + 1)    # subplot offset of the prediction / error row

    # The first horizon's sample also provides the coordinates and initial state.
    first = load_data(simulation_num, N_STEPS[0], args)
    X = first.points[:, 0]
    Y = first.points[:, 1]
    initial = first.x.numpy()[:, 0]
    plot_2d(X, Y, initial, gt_row + 1, r'Initial state $E_z$')

    for col, n in enumerate(N_STEPS, start=2):
        data = first if n == N_STEPS[0] else load_data(simulation_num, n, args)
        net_E, ground_truth = _predict(data)
        plot_2d(X, Y, net_E, pred_row + col, f'GT-MSMW(n={n})')
        plot_2d(X, Y, ground_truth, gt_row + col, f'Ground truth(n={n})')

    # Absolute error of the longest-horizon prediction (last loop pass).
    AE_error = np.abs(net_E - ground_truth)
    plot_2d(X, Y, AE_error, pred_row + 1, f'AE for n={N_STEPS[-1]}')

plt.tight_layout()
plt.savefig("2D-R1F0-evolution")
plt.close(fig)

# deepspeed --num_gpus=1 --master_port=25653 plt_evolution.py