from baselines.gnn import GNN_Model
from baselines.pfnet import MPN_simplenet
from baselines.canos import Canos_Model
from baselines.data_preprocess import *
from utils import *
import RL4Grid
import time
import torch
from pypower.api import *

def nan_safe_mean(values):
    """Return the mean of ``values``, or NaN when ``values`` is empty.

    ``np.asarray([]).mean()`` also evaluates to NaN, but it emits a
    "Mean of empty slice" RuntimeWarning. Tasks that are disabled in the
    ``tasks`` list leave their metric lists empty, so the summary print
    would otherwise warn for every disabled task.

    Args:
        values: A sequence of numbers or booleans.

    Returns:
        The NumPy mean of ``values``, or ``float('nan')`` if it is empty.
    """
    arr = np.asarray(values)
    if arr.size == 0:
        return float('nan')
    return arr.mean()


if __name__ == '__main__':
    # NOTE(review): use_lora / float_disc are not referenced anywhere in the
    # visible script; kept in case they are inspected interactively.
    use_lora = True
    float_disc = True
    # When True the script drops into ipdb on an infeasible reference action
    # and once more after the summary is printed (the original behaviour).
    # Set to False for unattended runs.
    interactive_debug = True
    networks = [
        # 'IEEE14',
        # 'IEEE39',
        # 'IEEE57',
        # 'SG126',
        # 'IEEE300',  # not available if using 8B model
        'Texas2000'
    ]
    ppcs = {
        'IEEE14': case14(),
        'IEEE39': case39(),
        'IEEE57': case57(),
        'IEEE118': case118(),
        'SG126': case126(),
        'IEEE300': case300(),
        'Texas2000': case2000()
    }
    tasks = [
        # 'opf',
        'state_est',
        'lmp_pred',
        # 'transient_pred'
    ]
    model_type = 'baselines'
    # Model types that are all served by GNN_Model(type=...).
    gnn_types = ('baselines', 'gat', 'deepopf', 'gin')

    # Per-network result lists. The *_inf_t lists are never filled in this
    # script; they are kept so the names remain available for inspection.
    llm_dones = {name: [] for name in networks}
    target_dones = {name: [] for name in networks}
    optimal_gaps = {name: [] for name in networks}
    opf_inf_t = {name: [] for name in networks}
    state_est_nrmses = {name: [] for name in networks}
    state_est_rsquares = {name: [] for name in networks}
    state_est_inf_t = {name: [] for name in networks}
    transient_pred_nrmses = {name: [] for name in networks}
    transient_pred_rsquares = {name: [] for name in networks}
    pmu_pred_inf_t = {name: [] for name in networks}
    lmp_nrmses = {name: [] for name in networks}
    lmp_rsquares = {name: [] for name in networks}
    lmp_pred_inf_t = {name: [] for name in networks}

    # Regression tasks share the same metric computation; only the
    # destination lists differ.
    regression_metrics = {
        'lmp_pred': (lmp_nrmses, lmp_rsquares),
        'state_est': (state_est_nrmses, state_est_rsquares),
        'transient_pred': (transient_pred_nrmses, transient_pred_rsquares),
    }

    for network in networks:
        for task in tasks:
            # Evaluation
            samples = data_prepare_gcn(network, ppcs, task, is_test=True)
            # Feature sizes are read from the first sample's node features,
            # edge attributes and response, respectively.
            nfeature_dim = samples[0][0].shape[1]
            efeature_dim = samples[0][2].shape[1]
            output_dim = samples[0][3].shape[1]
            hidden_dim = 128
            if model_type in gnn_types:
                model = GNN_Model(nfeature_dim, hidden_dim, output_dim, type=model_type).cuda()
            # elif model_type == 'pfnet':
            #     model = MPN_simplenet(nfeature_dim, efeature_dim, output_dim, hidden_dim, n_gnn_layers=3, K=3, dropout_rate=0.0)
            elif model_type == 'canos':
                model = Canos_Model(nfeature_dim, hidden_dim, efeature_dim, output_dim)
            elif model_type == 'itransformer':
                model = ITransformer_Model(
                    seq_len=nfeature_dim, pred_len=output_dim, use_norm=False
                )
            else:
                raise NotImplementedError
            model.load(
                model_path=f'./saved_models/{model_type}_{network}_{task}.p'
            )
            model.eval()
            model.cuda()

            # The last tuple field is bound to ``_sample_network`` rather
            # than ``network``: rebinding ``network`` here would overwrite the
            # outer loop variable that keys the result dicts, selects the
            # ppc and builds the checkpoint path for the next task.
            for (x, edge_index, edge_attr, response, i, gen_bus, float_max, float_min, ori_response,
                 _sample_network) in samples:
                gen_bus_lst = gen_bus.astype(int).tolist()
                edge_index = edge_index.long().cuda().permute(1, 0)
                x = torch.as_tensor(x).float().cuda()
                edge_attr = edge_attr.float().cuda()
                # Inference only: skip autograd bookkeeping.
                with torch.no_grad():
                    if model_type in gnn_types:
                        outputs = model(x.squeeze(0), edge_index).squeeze(-1)
                    # elif model_type == 'pfnet':
                    #     outputs = model(x.squeeze(0), edge_index, edge_attr.squeeze(0)).squeeze(-1)
                    elif model_type == 'canos':
                        outputs = model(x.squeeze(0), edge_index, edge_attr.squeeze(0)).squeeze(-1)
                    elif model_type == 'itransformer':
                        outputs = model(x.unsqueeze(0).permute(0, 2, 1), x_mark_enc=None).permute(0, 2, 1).squeeze(0)
                    else:
                        raise NotImplementedError

                outputs = outputs.detach().cpu().numpy()
                # Map predictions and targets back through the same affine
                # transform; presumably the inverse of the min-max scaling
                # applied in data_prepare_gcn -- TODO confirm.
                # outputs = (outputs + 1) / 2 * (float_max - float_min) + float_min
                outputs = outputs * (float_max - float_min + 1e-5) + float_min
                # response = (np.asarray(response) + 1) / 2 * (float_max - float_min) + float_min
                response = np.asarray(response) * (float_max - float_min + 1e-5) + float_min

                if task in regression_metrics:
                    nrmse_store, rsquare_store = regression_metrics[task]
                    nrmse = compute_nrmse(response, outputs, percentage=True)
                    rsquare = compute_r2_score(response, outputs)
                    nrmse_store[network].append(nrmse)
                    rsquare_store[network].append(rsquare)

                if task == 'opf':
                    gen_pred = outputs[gen_bus_lst]
                    env = RL4Grid.make_gridsim({network: ppcs[network]}, deterministic=True)

                    # Step once with the reference dispatch to obtain the
                    # reward the prediction is compared against.
                    obs = env.reset(start_sample_idx=i)
                    action_high = obs.action_space['adjust_gen_p'].high
                    action_low = obs.action_space['adjust_gen_p'].low
                    best_a = ori_response - np.asarray(obs.gen_p)
                    best_a = best_a.clip(action_low, action_high)
                    _, best_reward, target_done, info = env.step(best_a)
                    if target_done:
                        if interactive_debug:
                            import ipdb
                            ipdb.set_trace()
                        else:
                            print(f'warning: reference action ended the episode (network={network}, sample={i})')

                    # Same sample again, this time with the model's prediction.
                    obs = env.reset(start_sample_idx=i)
                    llm_a = gen_pred - np.asarray(obs.gen_p)
                    llm_a = llm_a.clip(action_low, action_high)
                    _, reward, llm_done, info = env.step(llm_a)
                    if llm_done:
                        # Fixed penalty reward when the predicted action ends
                        # the episode.
                        reward = -10
                    # NOTE(review): the divisor 3 looks like a reward-scale
                    # normaliser -- confirm against the RL4Grid reward.
                    optimal_gap = (best_reward - reward) / 3
                    llm_dones[network].append(llm_done)
                    target_dones[network].append(target_done)
                    optimal_gaps[network].append(optimal_gap)

    for network in networks:
        print('*********************************************')
        print(f'{model_type}')
        # Metrics of disabled tasks print as "nan".
        print(
            f'network={network}, opf_pass_rate={1 - nan_safe_mean(llm_dones[network]):.3f}, optimal_gap={nan_safe_mean(optimal_gaps[network]):.3f}, '
            f'state_est_nrmse={nan_safe_mean(state_est_nrmses[network]):.3f}, state_est_rsquare={nan_safe_mean(state_est_rsquares[network]):.3f}, '
            f'lmp_nrmse={nan_safe_mean(lmp_nrmses[network]):.3f}, lmp_rsquare={nan_safe_mean(lmp_rsquares[network]):.3f}, '
            f'transient_pred_nrmse={nan_safe_mean(transient_pred_nrmses[network])}, transient_pred_rsquare={nan_safe_mean(transient_pred_rsquares[network])}'
        )
    if interactive_debug:
        # Leave the results in scope for interactive inspection.
        import ipdb
        ipdb.set_trace()