import torch
import pickle
import pypower
from pypower.api import *
from pypower.idx_bus import *
from pypower.idx_gen import *
from pypower.idx_brch import *
from pypower.idx_cost import *
import numpy as np

# Number of trailing scenarios of each ppc list reserved for the test split.
_TEST_HOLDOUT = 3000

# Networks whose fault-info fields are formatted as-is in the prediction
# prompt; for every other network element [0] of each field is used.
_DIRECT_FAULT_INFO_NETWORKS = ("IEEE118", "SG126", "IEEE300", "Texas2000")


def _load_pickle(path):
    """Return the object unpickled from the file at *path*."""
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # root_path at trusted dataset files.
    with open(path, "rb") as file:
        return pickle.load(file)


def _bus_row_lookup(bus):
    """Map each bus id (column BUS_I cast to int) to its first row index."""
    lookup = {}
    for row, bus_id in enumerate(bus[:, BUS_I].astype(int).tolist()):
        # setdefault keeps the first occurrence, like list.index() would.
        lookup.setdefault(bus_id, row)
    return lookup


def _bus_row(lookup, bus_id):
    """Return the row index of *bus_id*; raise ValueError when it is unknown."""
    try:
        return lookup[bus_id]
    except KeyError:
        raise ValueError(f"bus {bus_id} is not in the bus table") from None


def _build_edge_index(ppc):
    """Return a tensor of [from_row, to_row] pairs, one per branch of *ppc*."""
    lookup = _bus_row_lookup(ppc['bus'])
    branch = ppc['branch']
    pairs = [
        [_bus_row(lookup, f_bus), _bus_row(lookup, t_bus)]
        for f_bus, t_bus in zip(branch[:, F_BUS].tolist(), branch[:, T_BUS].tolist())
    ]
    return torch.tensor(pairs)


def _scale_symmetric(values, axis=None):
    """Min-max scale *values* to about [-1, 1], per column if axis is given."""
    if axis is None:
        low, high = values.min(), values.max()
    else:
        low = values.min(axis, keepdims=True)
        high = values.max(axis, keepdims=True)
    # The 1e-4 term avoids a division by zero for constant inputs.
    return (values - low) / (high - low + 1e-4) * 2 - 1


def _normalize_target(values):
    """Min-max scale *values* to about [0, 1]; return (scaled, max, min)."""
    high, low = values.max(), values.min()
    return (values - low) / (high - low + 1e-4), high, low


def data_prepare_gla_v2(networks, ppcs, tasks, root_path="", is_test=False, float_disc=False, decimals=2, data_amount=None):
    """Load pickled scenarios and turn them into per-task training samples.

    Args:
        networks: Network names; each selects the pickle files to load and
            the entry of ``ppcs`` used for the graph topology.
        ppcs: Mapping from network name to a case dict with ``'bus'`` and
            ``'branch'`` arrays, used to build the edge index.
        tasks: Container of task names. Recognised values are ``'opf'``,
            ``'lmp_pred'``, ``'state_est'``, ``'transient_pred'`` and
            ``'fault_detect'``.
        root_path: Directory prefix of the pickle files.
        is_test: When true, iterate over the last 3000 scenarios of the
            first network's list; otherwise over the leading scenarios.
        float_disc: Unused; kept for interface compatibility.
        decimals: Unused; kept for interface compatibility.
        data_amount: Optional number of leading scenarios to use for
            training. Ignored when ``is_test`` is true.

    Returns:
        A list of tuples ``(features, edge_index, instruction, target,
        float_max, float_min, sample_id)``. ``float_max``/``float_min`` are
        the bounds used to min-max scale the numeric target; for
        ``'fault_detect'`` the target is a string and the bounds are those
        of the transient prediction window. An empty list is returned when
        ``networks`` is empty.

    Raises:
        ValueError: If a branch or generator refers to a bus id missing
            from the bus table.
        KeyError: If a required network is missing from ``ppcs``.
    """
    want_transient = 'fault_detect' in tasks or 'transient_pred' in tasks
    want_state_est = 'state_est' in tasks

    samples, ppc_lsts, fault_curve_lsts, fault_info_lsts, state_est_lsts = [], [], [], [], []
    for network in networks:
        ppc_lsts.append(_load_pickle(f"{root_path}/ppc_lst_{network}.pkl"))
        if want_transient:
            # SG126 has no transient files of its own; the IEEE118 ones are read.
            fault_name = network if network != 'SG126' else 'IEEE118'
            fault_curve_lsts.append(_load_pickle(f"{root_path}/{fault_name}_Voldataset_2s120hz.pckl"))
            fault_info_lsts.append(_load_pickle(f"{root_path}/{fault_name}_faultinfo_2s120hz.pckl"))
        if want_state_est:
            est_name = network if network != 'SG126' else 'IEEE126'
            state_est_lsts.append(_load_pickle(f"{root_path}/StateEstimation/{est_name}_State_estimation.pckl"))

    if not ppc_lsts:
        # No networks requested: nothing to iterate over.
        return samples

    # The scenario range is derived from the first network's list only.
    total = len(ppc_lsts[0])
    # Clamp at 0 so lists shorter than the hold-out never yield negative
    # (wrap-around) indices.
    holdout_start = max(0, total - _TEST_HOLDOUT)
    if is_test:
        first, stop = holdout_start, total
    elif data_amount is not None:
        first, stop = 0, data_amount
    else:
        first, stop = 0, holdout_start

    edge_index_dict = {name: _build_edge_index(case) for name, case in ppcs.items()}

    for i in range(first, stop):
        if i % 100 == 0:
            print(f'{i}/{stop} samples loaded.')
        # data_amount may exceed the list length; reuse the last scenario then.
        idx = min(i, total - 1)
        for n, network in enumerate(networks):
            ppc = ppc_lsts[n][idx]
            edge_index = edge_index_dict[network]

            # Node features shared by the OPF and LMP tasks: bus columns
            # PD, QD, VM, VA plus the PG/QG of the generator at each bus.
            bus, gen = ppc['bus'], ppc['gen']
            bus_rows = _bus_row_lookup(bus)
            pg_qg = np.zeros((bus.shape[0], 2))
            gen_p = np.zeros(bus.shape[0])
            for g, gen_bus in enumerate(gen[:, GEN_BUS].astype(int).tolist()):
                row = _bus_row(bus_rows, gen_bus)
                # Several generators on one bus: the last one wins.
                pg_qg[row] = gen[g, [PG, QG]]
                gen_p[row] = gen[g, PG]
            x = np.concatenate((bus[:, [PD, QD, VM, VA]], pg_qg), axis=1)
            x = torch.tensor(_scale_symmetric(x, axis=0))

            # opf
            if 'opf' in tasks:
                language_input = f"This is a operation scenario in {network} bus system. What is the best active power setpoint of generators?"
                scaled, float_max, float_min = _normalize_target(gen_p)
                response = torch.tensor(scaled).float().unsqueeze(-1)
                try:
                    sample_id = ppc['index']
                except LookupError:
                    # Cases without a stored index fall back to the loop counter.
                    sample_id = i
                samples.append((x, edge_index, language_input, response, float_max, float_min, sample_id))

            # LMP prediction
            if 'lmp_pred' in tasks:
                language_input = f'This is a operation scenario in {network} bus system. What is the locational marginal price?'
                scaled, float_max, float_min = _normalize_target(bus[:, LAM_P])
                response = torch.tensor(scaled).float().unsqueeze(-1)
                samples.append((x, edge_index, language_input, response, float_max, float_min, i))

            # state estimation
            if want_state_est:
                record = state_est_lsts[n][idx]
                measure = np.concatenate([v.reshape(-1, 1) for v in record['measurement'].values()], axis=1)
                state = np.concatenate([v.reshape(-1, 1) for v in record['state'].values()], axis=1)
                measure = torch.tensor(_scale_symmetric(measure, axis=0))
                # The measurement columns are duplicated along the feature axis.
                measure = torch.cat((measure, measure), dim=-1)
                language_input = f"This is measurements of voltage magnitudes, active power injection and reactive power injection in {network} bus system. What are the real states of voltage magnitude and phase angles?"
                scaled, float_max, float_min = _normalize_target(state)
                response = torch.tensor(scaled).float()
                samples.append((measure, edge_index, language_input, response, float_max, float_min, i))

            # transient
            if want_transient:
                # SG126 transient data come from the IEEE118 files, so the
                # IEEE118 name and topology are used for these samples. The
                # lookup is done only here so that other tasks do not need
                # an IEEE118 entry in ppcs.
                fault_network = 'IEEE118' if network == 'SG126' else network
                fault_edges = edge_index_dict[fault_network]
                curve = fault_curve_lsts[n][idx]
                info = fault_info_lsts[n][idx]

                # Number of leading curve steps given as model input.
                threshold = 200
                if fault_network not in _DIRECT_FAULT_INFO_NETWORKS:
                    fault_type, bus1, bus2 = info["type"][0], info["bus1"][0], info["bus2"][0]
                else:
                    fault_type, bus1, bus2 = info["type"], info["bus1"], info["bus2"]
                location = bus1 if bus2 == -1 else [bus1, bus2]
                pred_instruction = f'There is a {fault_type} fault happend at bus {location} in {fault_network} bus system, what are the predictions of the following {curve.shape[1] - threshold} steps of voltage curves?'
                clf_instruction = f'These are fault nodal voltage curves in {fault_network} bus system. What are the fault type and fault location?'

                # The prediction target is taken from the raw curve, before scaling.
                target_pred = curve[:, threshold:]
                # Always use the bounds of the prediction window so that
                # fault_detect samples never carry values left over from
                # another task.
                float_max, float_min = target_pred.max(), target_pred.min()
                curve = _scale_symmetric(curve)

                # NOTE(review): unlike the prediction prompt, this label does
                # not take element [0] of the info fields for the smaller
                # networks - confirm against the fault-info schema.
                if info["bus2"] == -1:
                    target_clf = f'The fault type is {info["type"]}, happended at bus {info["bus1"]}.'
                else:
                    target_clf = f'The fault type is {info["type"]}, happended at bus {[info["bus1"], info["bus2"]]}.'

                if 'transient_pred' in tasks:
                    pred = (target_pred - float_min) / (float_max - float_min + 1e-4)
                    pred = torch.tensor(pred).float()
                    samples.append((torch.tensor(curve[:, :threshold]), fault_edges, pred_instruction, pred, float_max, float_min, i))

                if 'fault_detect' in tasks:
                    samples.append((torch.tensor(curve), fault_edges, clf_instruction, target_clf, float_max, float_min, i))
    return samples