import torch
import pickle
import pypower
from pypower.api import *
from pypower.idx_bus import *
from pypower.idx_gen import *
from pypower.idx_brch import *
from pypower.idx_cost import *
from utils import GLADataset, float_discretization
import numpy as np


# def data_prepare_gla(networks, tasks, data_amount=None, root_path="", is_test=False, float_disc=False, decimals=2):
#     samples, ppc_lsts, fault_curve_lsts, fault_info_lsts, state_est_lsts = [], [], [], [], []
#
#     for network in networks:
#         with open(f"{root_path}/ppc_lst_{network}.pkl", "rb") as file:
#             ppc_lsts.append(pickle.load(file))
#
#     num_scenario = [0, len(ppc_lsts[0]) - 5000] if not is_test else [len(ppc_lsts[0]) - 1000, len(ppc_lsts[0])-900]
#     if data_amount is not None:
#         num_scenario = [0, data_amount] if not is_test else [len(ppc_lsts[0]) - 3000, len(ppc_lsts[0])-2900]
#
#     edge_indexes = []
#     gen_buses = []
#     for n, network in enumerate(networks):
#         ppc = ppc_lsts[n][0]
#         branch_num = ppc['branch'].shape[0]
#         edge_index = []
#         for j in range(branch_num):
#             f_bus = ppc['bus'][:, BUS_I].astype(int).tolist().index(ppc['branch'][j, F_BUS])
#             t_bus = ppc['bus'][:, BUS_I].astype(int).tolist().index(ppc['branch'][j, T_BUS])
#             # if ppc['branch'][j, BR_STATUS]:
#             edge_index.append([f_bus, t_bus])
#         edge_index = torch.tensor(edge_index)
#         edge_indexes.append(edge_index)
#
#         gen_bus = []
#         for j, k in enumerate(ppc['gen'][:, GEN_BUS].astype(int).tolist()):
#             idx = ppc['bus'][:, BUS_I].astype(int).tolist().index(k)
#             gen_bus.append(idx)
#         gen_buses.append(gen_bus)
#
#     for i in range(num_scenario[0], num_scenario[1]):
#         for n, network in enumerate(networks):
#             ppc = ppc_lsts[n][i]
#             # opf
#             bus_num = ppc['bus'].shape[0]
#             branch_num = ppc['branch'].shape[0]
#             edge_index = edge_indexes[n]
#             pg_qg = np.zeros((bus_num, 2))
#             for j, k in enumerate(ppc['gen'][:, GEN_BUS].astype(int).tolist()):
#                 # idx = ppc['bus'][:, BUS_I].astype(int).tolist().index(k)
#                 idx = gen_buses[n][j]
#                 pg_qg[idx] = ppc['gen'][j, [PG, QG]]
#             x = np.concatenate((ppc['bus'][:, [PD, QD, VM, VA]], pg_qg), axis=1)
#             language_input = f"This is a operation scenario of {network} bus system. What is the best active power setpoint of generators?"
#             response = ppc['target_gen_p']
#             float_max = max(abs(response))
#             x = (x - x.min(0, keepdims=True)) / (x.max(0, keepdims=True) - x.min(0, keepdims=True) + 1e-3)
#             x = x * 2 - 1
#             x = torch.tensor(x)
#             if 'opf' in tasks:
#                 if float_disc:
#                     discretized_response = f'The best setpoint is {float_discretization(response, float_max)}'
#                     samples.append((x, edge_index, language_input, discretized_response, float_max, i))
#                 else:
#                     response = f'The best setpoint is {np.around(response, decimals=decimals)}'
#                     samples.append((x, edge_index, language_input, response, float_max, i))
#     return samples


def _bus_index_map(ppc):
    """Return a dict mapping each bus number in ``ppc['bus'][:, BUS_I]`` to its row position.

    Bus numbers are cast with ``astype(int)``. If a number appears twice, the first row
    wins, matching the ``list.index`` lookup this replaces. Building the dict once makes
    each lookup O(1) instead of an O(num_buses) list scan per branch/generator.
    """
    lookup = {}
    for row, bus_id in enumerate(ppc['bus'][:, BUS_I].astype(int).tolist()):
        lookup.setdefault(bus_id, row)
    return lookup


def _bus_position(bus_lookup, bus_id):
    """Return the row position of ``bus_id`` in ``bus_lookup``.

    Raises:
        ValueError: if ``bus_id`` is not a known bus (the same exception type that the
            previous ``list.index`` lookup raised).
    """
    try:
        return bus_lookup[bus_id]
    except KeyError:
        raise ValueError(f"bus {bus_id} is not in the bus table") from None


def _build_edge_index(ppc, bus_lookup):
    """Return an int64 tensor of shape (num_branches, 2) holding [from_row, to_row] per branch.

    Branch endpoints are translated from bus numbers to bus-table row positions via
    ``bus_lookup``. Every branch is included regardless of its BR_STATUS column. The
    ``reshape`` keeps a (0, 2) shape when there are no branches.
    """
    edges = [
        [_bus_position(bus_lookup, f_bus), _bus_position(bus_lookup, t_bus)]
        for f_bus, t_bus in ppc['branch'][:, [F_BUS, T_BUS]].tolist()
    ]
    return torch.tensor(edges, dtype=torch.long).reshape(-1, 2)


def _scale_to_unit_range(values, axis=None):
    """Min-max scale a NumPy array to roughly [-1, 1].

    With ``axis`` given, min/max are taken along that axis (per column for ``axis=0``).
    Otherwise the global min/max are used. The 1e-3 added to the denominator avoids
    division by zero for constant inputs, so the upper end lands slightly below 1.
    """
    if axis is None:
        lo, hi = values.min(), values.max()
    else:
        lo = values.min(axis, keepdims=True)
        hi = values.max(axis, keepdims=True)
    return (values - lo) / (hi - lo + 1e-3) * 2 - 1


def _format_values(values, float_max, float_disc, decimals):
    """Render ``values`` as target text.

    Uses ``float_discretization(values, float_max)`` when ``float_disc`` is set.
    Otherwise uses ``np.around(values, decimals=decimals)``.
    """
    if float_disc:
        return f'{float_discretization(values, float_max)}'
    return f'{np.around(values, decimals=decimals)}'


def _fault_texts(info, network):
    """Return ``(pred_instruction, clf_instruction, target_clf)`` for one fault record.

    For networks outside IEEE118/SG126/IEEE300, the ``type``/``bus1``/``bus2`` entries of
    ``info`` are indexed with ``[0]``. For those three networks they are used as-is.
    A ``bus2`` of -1 means the fault is reported at ``bus1`` alone; otherwise it is
    reported at ``[bus1, bus2]``.
    """
    if network not in ["IEEE118", "SG126", "IEEE300"]:
        fault_type, bus1, bus2 = info["type"][0], info["bus1"][0], info["bus2"][0]
        clf_prefix = 'The fault is'
    else:
        fault_type, bus1, bus2 = info["type"], info["bus1"], info["bus2"]
        clf_prefix = 'The fault type is'
    location = bus1 if bus2 == -1 else [bus1, bus2]
    # The wording (including the spellings "happend"/"happended") is kept verbatim:
    # changing it would change the generated prompts and targets.
    pred_instruction = (
        f'There is a {fault_type} fault happend at bus {location} in {network} system, '
        f'what are the predictions of the following voltage curves?'
    )
    clf_instruction = f'These are fault nodal voltage curves in {network} systems. What are the fault type and fault location?'
    target_clf = f'{clf_prefix} {fault_type}, happended at bus {location}.'
    return pred_instruction, clf_instruction, target_clf


def data_prepare_gla(networks, tasks, data_amount=None, root_path="", is_test=False, float_disc=False, decimals=2):
    """Build text-target samples for the power-grid tasks listed in ``tasks``.

    Files are read from ``root_path`` for each network:
      * ``ppc_lst_{network}.pkl``: list of per-scenario PYPOWER cases (always loaded).
      * ``{net}_Voldataset_2s120hz.pckl`` and ``{net}_faultinfo_2s120hz.pckl``: loaded
        only when ``'transient_pred'`` or ``'fault_detect'`` is requested. SG126 uses
        the IEEE118 files.
      * ``StateEstimation/{net}_State_estimation.pckl``: loaded only when
        ``'state_est'`` is requested. SG126 uses the IEEE126 file.

    The scenario range is relative to ``L``, the length of the first network's case list:
      * default: train ``[0, L - 5000)``, test ``[L - 1000, L - 900)``;
      * with ``data_amount``: train ``[0, data_amount)``, test ``[L - 3000, L - 2900)``.

    Fault samples for SG126 are labelled ``IEEE118`` and use the ``case118()`` topology.

    Args:
        networks: network names used in file names and prompts.
        tasks: container checked with ``in`` for ``'opf'``, ``'lmp_pred'``, ``'state_est'``,
            ``'transient_pred'`` and ``'fault_detect'``.
        data_amount: optional number of training scenarios (see above).
        root_path: directory holding the pickled datasets.
        is_test: select the test scenario range instead of the training range.
        float_disc: render targets with ``float_discretization`` instead of rounding.
        decimals: rounding precision used when ``float_disc`` is False.

    Returns:
        List of ``(features, edge_index, prompt, target_text, float_max, scenario_index)``
        tuples, ordered by scenario, then network, then task (opf, lmp_pred, state_est,
        transient_pred, fault_detect).

    Raises:
        IndexError: if a scenario index is outside a network's case list.
        ValueError: if a branch or generator references a bus missing from the bus table.
    """
    needs_fault = 'transient_pred' in tasks or 'fault_detect' in tasks
    needs_state = 'state_est' in tasks

    samples, ppc_lsts, fault_curve_lsts, fault_info_lsts, state_est_lsts = [], [], [], [], []
    # NOTE: pickle.load can execute arbitrary code; only point root_path at trusted data.
    for network in networks:
        with open(f"{root_path}/ppc_lst_{network}.pkl", "rb") as file:
            ppc_lsts.append(pickle.load(file))
        if needs_fault:
            fault_net = network if network != 'SG126' else 'IEEE118'
            with open(f"{root_path}/{fault_net}_Voldataset_2s120hz.pckl", "rb") as file:
                fault_curve_lsts.append(pickle.load(file))
            with open(f"{root_path}/{fault_net}_faultinfo_2s120hz.pckl", "rb") as file:
                fault_info_lsts.append(pickle.load(file))
        if needs_state:
            state_net = network if network != 'SG126' else 'IEEE126'
            with open(f"{root_path}/StateEstimation/{state_net}_State_estimation.pckl", "rb") as file:
                state_est_lsts.append(pickle.load(file))

    num_scenario = [0, len(ppc_lsts[0]) - 5000] if not is_test else [len(ppc_lsts[0]) - 1000, len(ppc_lsts[0]) - 900]
    if data_amount is not None:
        num_scenario = [0, data_amount] if not is_test else [len(ppc_lsts[0]) - 3000, len(ppc_lsts[0]) - 2900]

    ieee118_edge_index = None  # case118 topology, built once on first use
    for i in range(num_scenario[0], num_scenario[1]):
        for n, network in enumerate(networks):
            try:
                ppc = ppc_lsts[n][i]
            except IndexError as err:
                raise IndexError(
                    f"scenario {i} is out of range for network {network!r} "
                    f"({len(ppc_lsts[n])} scenarios loaded)"
                ) from err

            bus_lookup = _bus_index_map(ppc)
            edge_index = _build_edge_index(ppc, bus_lookup)

            # Node features: PD, QD, VM, VA from the bus table, plus PG, QG of the generator
            # attached to each bus (zeros for buses without a generator).
            # NOTE(review): if several generators share a bus, the last one overwrites the
            # earlier ones instead of being summed. Confirm this is intended.
            pg_qg = np.zeros((ppc['bus'].shape[0], 2))
            for j, gen_bus in enumerate(ppc['gen'][:, GEN_BUS].astype(int).tolist()):
                pg_qg[_bus_position(bus_lookup, gen_bus)] = ppc['gen'][j, [PG, QG]]
            x = np.concatenate((ppc['bus'][:, [PD, QD, VM, VA]], pg_qg), axis=1)
            x = torch.tensor(_scale_to_unit_range(x, axis=0))

            # opf
            if 'opf' in tasks:
                response = ppc['target_gen_p']
                float_max = max(abs(response))
                language_input = f"This is a operation scenario of {network} bus system. What is the best active power setpoint of generators?"
                target = 'The best setpoint is ' + _format_values(response, float_max, float_disc, decimals)
                samples.append((x, edge_index, language_input, target, float_max, i))

            # LMP prediction
            if 'lmp_pred' in tasks:
                response = ppc['bus'][:, LAM_P]
                float_max = max(abs(response))
                language_input = f'This is a operation scenario of {network} bus system. What is the locational marginal price?'
                target = 'The LMP prediction is ' + _format_values(response, float_max, float_disc, decimals)
                samples.append((x, edge_index, language_input, target, float_max, i))

            # state estimation
            if needs_state:
                record = state_est_lsts[n][i]
                measure = np.concatenate([v.reshape(-1, 1) for v in record['measurement'].values()], axis=1)
                state = np.concatenate([v.reshape(-1, 1) for v in record['state'].values()], axis=1)
                measure = torch.tensor(_scale_to_unit_range(measure, axis=0))
                # Columns are duplicated, presumably to match the feature width of `x`.
                # TODO: confirm.
                measure = torch.cat((measure, measure), dim=-1)
                float_max = abs(state).max()
                language_input = f"This is measurements of voltage magnitudes, active power injection and reactive power injection in {network} bus system. What are the real states of voltage magnitude and phase angles?"
                target = 'The best estimation is ' + _format_values(state.T, float_max, float_disc, decimals)
                samples.append((measure, edge_index, language_input, target, float_max, i))

            if not needs_fault:
                continue

            # Fault data for SG126 come from the IEEE118 files, so use case118's topology.
            if network == 'SG126':
                network = 'IEEE118'
                if ieee118_edge_index is None:
                    case = case118()
                    ieee118_edge_index = _build_edge_index(case, _bus_index_map(case))
                # Clone so that samples from different scenarios never alias one tensor.
                edge_index = ieee118_edge_index.clone()

            # transient
            curve = fault_curve_lsts[n][i]
            info = fault_info_lsts[n][i]
            pred_instruction, clf_instruction, target_clf = _fault_texts(info, network)
            # Columns before `threshold` are the model input; the rest are the prediction target.
            threshold = 200
            target_pred = curve[:, threshold:]
            float_max = abs(target_pred).max()
            # Global (not per-curve) min-max scaling, applied after the raw target is sliced.
            curve = _scale_to_unit_range(curve)

            if 'transient_pred' in tasks:
                pred_values = _format_values(target_pred, float_max, float_disc, decimals).replace(',', '')
                pred = 'The predicted curves are ' + pred_values
                samples.append((torch.tensor(curve[:, :threshold]), edge_index, pred_instruction, pred, float_max, i))

            if 'fault_detect' in tasks:
                samples.append((torch.tensor(curve), edge_index, clf_instruction, target_clf, float_max, i))
    return samples