import torch
import pickle
import pypower
from pypower.api import *
from pypower.idx_bus import *
from pypower.idx_gen import *
from pypower.idx_brch import *
from pypower.idx_cost import *
from utils import GLADataset, float_discretization
import numpy as np

# Prepare a GCN dataset; task must be one of 'opf', 'lmp_pred', 'state_est', 'transient_pred'.
def data_prepare_gcn(
        network,
        ppcs,
        task='opf',
        root_path="",
        is_test=False,
        normalize=True,
        holdout_size=5000,
        test_size=100,
):
    """Build GCN samples for one power network from pickled PYPOWER scenarios.

    Loads ``{root_path}/ppc_lst_{network}.pkl`` (a sequence of per-scenario
    PYPOWER case dicts) plus any task-specific auxiliary file, and turns each
    selected scenario into a tuple of graph tensors and targets.

    Args:
        network: Network name used to select data files and the topology in
            ``ppcs`` (e.g. 'IEEE118', 'IEEE300', 'SG126', 'Texas2000').
        ppcs: Mapping from network name to a base PYPOWER case whose
            ``bus``/``branch`` tables define the graph edges. Must contain
            ``network``; for ``task='transient_pred'`` with
            ``network='SG126'`` it must contain 'IEEE118' instead/as well.
        task: One of 'opf', 'lmp_pred', 'state_est', 'transient_pred'.
        root_path: Directory containing the dataset files.
        is_test: If False, use scenarios ``[0, N - holdout_size)``; if True,
            use ``[N - holdout_size, N - holdout_size + test_size)``, where N
            is the number of scenarios in the ppc file.
        normalize: Apply per-sample min-max scaling to inputs and targets.
        holdout_size: Number of trailing scenarios reserved for testing
            (default 5000, the previously hard-coded value).
        test_size: Number of held-out scenarios returned when ``is_test`` is
            True (default 100, the previously hard-coded value).

    Returns:
        A list of 10-tuples ``(x, edge_index, edge_attr, response, index,
        gen_bus, float_max, float_min, raw_target, network)``. ``raw_target``
        is ``ppc['target_gen_p']`` for 'opf' and 'lmp_pred', the un-normalised
        state for 'state_est', and the un-normalised curve tail for
        'transient_pred'.

    Raises:
        ValueError: If ``task`` is unknown, or a scenario's branch count does
            not match the topology's edge count.
        KeyError: If the required topology is missing from ``ppcs``.
    """
    valid_tasks = ('opf', 'lmp_pred', 'state_est', 'transient_pred')
    if task not in valid_tasks:
        raise ValueError(f"Unknown task {task!r}; expected one of {valid_tasks}.")

    # NOTE: pickle.load can execute arbitrary code; root_path must point at
    # trusted dataset files only.
    with open(f"{root_path}/ppc_lst_{network}.pkl", "rb") as file:
        ppc_lst = pickle.load(file)

    state_est_lst = None
    if task == 'state_est':
        se_name = network if network != 'SG126' else 'IEEE126'
        with open(f"{root_path}/StateEstimation/{se_name}_State_estimation.pckl", "rb") as file:
            state_est_lst = pickle.load(file)

    fault_curve_lst = None
    if task == 'transient_pred':
        if network == 'Texas2000':
            fault_curve_lst = np.load(f'{root_path}/Texas2000_Voldataset_2s120hz_ballanced.npy')
        else:
            # SG126 has no curve file of its own; the IEEE118 curves are used.
            curve_net = network if network != 'SG126' else 'IEEE118'
            suffix = '_35k' if network == 'IEEE300' else ''
            with open(f"{root_path}/{curve_net}_Voldataset_2s120hz{suffix}.pckl", "rb") as file:
                fault_curve_lst = pickle.load(file)

    # Distinct loop names so the ``network`` argument is not overwritten
    # (previously it was rebound to the last key of ``ppcs``).
    edge_index_dict = {name: _edge_index_from_ppc(base) for name, base in ppcs.items()}
    if task == 'transient_pred' and network == 'SG126':
        # Curves come from IEEE118, so the graph must be the IEEE118 topology.
        # NOTE(review): edge_attr still comes from the SG126 scenario's branch
        # table; confirm the branch rows line up with IEEE118 edges.
        sample_edge_index = edge_index_dict['IEEE118']
    else:
        sample_edge_index = edge_index_dict[network]
    # Topology whose edge count must match each scenario's branch rows.
    check_edge_index = edge_index_dict.get(network, sample_edge_index)

    total = len(ppc_lst)
    if is_test:
        start, stop = total - holdout_size, total - holdout_size + test_size
    else:
        start, stop = 0, total - holdout_size

    threshold = 200  # time steps [0, threshold) are input, the rest is target
    samples = []
    for i in range(start, stop):
        if i % 100 == 0:
            print(f'{i}/{stop} samples loaded.')
        ppc = ppc_lst[i]

        # Per-branch features: from/to-end P and Q flows plus in-service flag.
        edge_attr = torch.from_numpy(np.asarray(ppc['branch'][:, [PF, QF, PT, QT, BR_STATUS]]))
        if edge_attr.shape[0] != check_edge_index.shape[0]:
            raise ValueError(
                f"Scenario {i}: {edge_attr.shape[0]} branch rows but topology "
                f"has {check_edge_index.shape[0]} edges."
            )

        bus_index = _bus_index_map(ppc)
        gen_bus, pg_qg = _gen_injections(ppc, bus_index)

        if task in ('opf', 'lmp_pred'):
            x = torch.tensor(np.concatenate((ppc['bus'][:, [PD, QD, VM, VA]], pg_qg), axis=1))
            response = ppc['target_gen_p'] if task == 'opf' else ppc['bus'][:, LAM_P]
            response = response.reshape(-1, 1)
            float_max = x.max().item()
            float_min = x.min().item()
            if normalize:  # min-max scaling with the input's range
                scale = float_max - float_min + 1e-5
                x = (x - float_min) / scale
                response = (response - float_min) / scale
            # Fall back to the scenario position when the case has no 'index'.
            sample_id = ppc.get('index', i)
            samples.append((x, sample_edge_index, edge_attr, response, sample_id, gen_bus,
                            float_max, float_min, ppc['target_gen_p'], network))
        elif task == 'state_est':
            record = state_est_lst[i]
            measure = np.concatenate([v.reshape(-1, 1) for v in record['measurement'].values()], axis=1)
            state = np.concatenate([v.reshape(-1, 1) for v in record['state'].values()], axis=1)
            float_max = measure.max().item()
            float_min = measure.min().item()
            response = state
            if normalize:
                scale = float_max - float_min + 1e-5
                measure = (measure - float_min) / scale
                response = (state - float_min) / scale
            samples.append((measure, sample_edge_index, edge_attr, response, i, gen_bus,
                            float_max, float_min, state, network))
        else:  # 'transient_pred'
            curve = fault_curve_lst[i]
            target_pred = curve[:, threshold:]
            # The input window is always min-max scaled, independent of ``normalize``.
            curve = (curve - curve.min()) / (curve.max() - curve.min() + 1e-4)
            float_max, float_min = target_pred.max(), target_pred.min()
            pred = target_pred
            if normalize:
                pred = (target_pred - float_min) / (float_max - float_min + 1e-5)
            # Always defined now (previously unbound when normalize=False).
            pred = torch.tensor(pred).float()
            samples.append((curve[:, :threshold], sample_edge_index, edge_attr, pred, i, gen_bus,
                            float_max, float_min, target_pred, network))

    return samples


def _bus_index_map(ppc):
    """Map each bus number (BUS_I column) to its row position in ``ppc['bus']``.

    The first row wins on duplicates, matching ``list.index`` semantics.
    """
    index = {}
    for row, bus_id in enumerate(ppc['bus'][:, BUS_I].astype(int).tolist()):
        index.setdefault(bus_id, row)
    return index


def _edge_index_from_ppc(ppc):
    """Return a ``(num_branches, 2)`` long tensor of (from, to) bus row positions.

    Every branch is included regardless of BR_STATUS; the status is carried
    as an edge attribute instead.
    """
    bus_index = _bus_index_map(ppc)
    pairs = [[bus_index[int(f_bus)], bus_index[int(t_bus)]]
             for f_bus, t_bus in ppc['branch'][:, [F_BUS, T_BUS]]]
    return torch.tensor(pairs, dtype=torch.long).reshape(-1, 2)


def _gen_injections(ppc, bus_index):
    """Place generator (PG, QG) outputs on their host buses.

    Returns:
        ``(gen_bus, pg_qg)``: ``gen_bus`` is an array with the bus row position of
        each generator in ``ppc['gen']`` order. ``pg_qg`` is a
        ``(num_buses, 2)`` array of PG/QG per bus row.
    """
    pg_qg = np.zeros((ppc['bus'].shape[0], 2))
    gen_bus = []
    for j, bus_id in enumerate(ppc['gen'][:, GEN_BUS].astype(int).tolist()):
        row = bus_index[bus_id]
        gen_bus.append(row)
        # NOTE(review): a later generator on the same bus overwrites an earlier
        # one instead of adding to it; confirm this is intended.
        pg_qg[row] = ppc['gen'][j, [PG, QG]]
    return np.asarray(gen_bus), pg_qg