import torch
import pickle
import pypower
from pypower.api import *
from pypower.idx_bus import *
from pypower.idx_gen import *
from pypower.idx_brch import *
from pypower.idx_cost import *
from utils import GLADataset, float_discretization
import numpy as np


# Networks whose fault-info fields are read directly; for every other network
# each field is read through its first element (``info[key][0]``).
_DIRECT_FAULT_INFO_NETWORKS = ("IEEE118", "SG126", "IEEE300")


def _load_pickle(path):
    """Unpickle and return the object stored at *path*."""
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(path, "rb") as file:
        return pickle.load(file)


def _fault_fields(info, network):
    """Return ``(fault_type, location)`` read from one fault-info record.

    ``location`` is ``bus1`` when ``bus2 == -1`` and the list ``[bus1, bus2]``
    otherwise.
    """
    if network in _DIRECT_FAULT_INFO_NETWORKS:
        fault_type, bus1, bus2 = info["type"], info["bus1"], info["bus2"]
    else:
        fault_type, bus1, bus2 = info["type"][0], info["bus1"][0], info["bus2"][0]
    location = bus1 if bus2 == -1 else [bus1, bus2]
    return fault_type, location


def _system_state_text(ppc):
    """Return the discretized per-bus input text for one pypower case.

    Each bus row holds ``[PD, QD, VM, VA, PG, QG]``; PG/QG are zero for buses
    that have no generator row in ``ppc['gen']``.
    """
    bus, gen = ppc['bus'], ppc['gen']
    # Built once per case so the generator loop does not redo the conversion.
    bus_ids = bus[:, BUS_I].astype(int).tolist()
    pg_qg = np.zeros((bus.shape[0], 2))
    for j, gen_bus in enumerate(gen[:, GEN_BUS].astype(int).tolist()):
        # NOTE(review): when several generators share a bus the last row wins
        # (values are assigned, not summed) — confirm this is intended.
        pg_qg[bus_ids.index(gen_bus)] = gen[j, [PG, QG]]
    x = np.concatenate((bus[:, [PD, QD, VM, VA]], pg_qg), axis=1)
    return f'{float_discretization(x, abs(x).max())}'.replace(',', '')


def _stack_columns(mapping):
    """Concatenate the values of *mapping* as columns of one 2-D array."""
    return np.concatenate([v.reshape(-1, 1) for v in mapping.values()], axis=1)


def data_prepare_llm(networks, tasks, root_path="", is_test=False):
    """Build (prompt, response) text samples for the requested tasks.

    For every network the following pickles are loaded from ``root_path``:
    ``ppc_lst_{network}.pkl``, ``{prefix}_Voldataset_2s120hz.pckl``,
    ``{prefix}_faultinfo_2s120hz.pckl`` and
    ``StateEstimation/{prefix}_State_estimation.pckl``.

    Args:
        networks: Sequence of network names (e.g. ``"IEEE118"``, ``"SG126"``).
        tasks: Container of task names; the recognised names are ``'opf'``,
            ``'lmp_pred'``, ``'state_est'``, ``'transient_pred'`` and
            ``'fault_detect'``. Only the requested tasks are computed.
        root_path: Directory prefix of the pickle files.
        is_test: When false, scenarios ``[0, N - 5000)`` are used; when true,
            scenarios ``[N - 5000, N - 4900)``, where ``N`` is the number of
            scenarios of the first network.

    Returns:
        A list of ``(prompt, response, float_max, scenario_index)`` tuples,
        ordered by scenario, then network, then task (in the order listed
        above). ``float_max`` is the scale passed to ``float_discretization``
        for the response. An empty ``networks`` yields an empty list.
    """
    samples = []
    if not networks:
        return samples

    ppc_lsts, fault_curve_lsts, fault_info_lsts, state_est_lsts = [], [], [], []
    for network in networks:
        fault_prefix = network if network != 'SG126' else 'IEEE118'
        # NOTE(review): SG126 maps to 'IEEE118' for the fault files but to
        # 'IEEE126' for state estimation — confirm this is intended.
        state_prefix = network if network != 'SG126' else 'IEEE126'
        ppc_lsts.append(_load_pickle(f"{root_path}/ppc_lst_{network}.pkl"))
        fault_curve_lsts.append(_load_pickle(f"{root_path}/{fault_prefix}_Voldataset_2s120hz.pckl"))
        fault_info_lsts.append(_load_pickle(f"{root_path}/{fault_prefix}_faultinfo_2s120hz.pckl"))
        state_est_lsts.append(_load_pickle(f"{root_path}/StateEstimation/{state_prefix}_State_estimation.pckl"))

    total = len(ppc_lsts[0])
    if is_test:
        first, last = total - 5000, total - 4900
    else:
        first, last = 0, total - 5000

    # Number of leading curve steps given as context for transient prediction.
    threshold = 200
    want_opf = 'opf' in tasks
    want_lmp = 'lmp_pred' in tasks
    want_state = 'state_est' in tasks
    want_transient = 'transient_pred' in tasks
    want_fault = 'fault_detect' in tasks

    for i in range(first, last):
        for n, network in enumerate(networks):
            ppc = ppc_lsts[n][i]
            curve = fault_curve_lsts[n][i]
            info = fault_info_lsts[n][i]
            estimation = state_est_lsts[n][i]

            if want_opf or want_lmp:
                # OPF and LMP prompts share the same system-state text.
                x_text = _system_state_text(ppc)

            if want_opf:
                response = ppc['target_gen_p']
                float_max = max(abs(response))
                language_input = f"This is a operation scenario of {network} bus system. " \
                                 f"The system input states are {x_text}. " \
                                 f"What is the best active power setpoint of generators?"
                discretized_response = f'The best setpoint is {float_discretization(response, float_max)}'
                samples.append((language_input, discretized_response, float_max, i))

            if want_lmp:
                response = ppc['bus'][:, LAM_P]
                float_max = max(abs(response))
                language_input = f'This is a operation scenario of {network} bus system. ' \
                                 f'The system input states are {x_text}. ' \
                                 f'What is the locational marginal price?'
                discretized_response = f'The LMP prediction is {float_discretization(response, float_max)}'
                samples.append((language_input, discretized_response, float_max, i))

            if want_state:
                measure = _stack_columns(estimation['measurement'])
                state = _stack_columns(estimation['state'])
                measure_text = f'{float_discretization(measure, abs(measure).max())}'.replace(',', '')
                float_max = abs(state).max()
                language_input = f"This is measurements of voltage magnitudes, active power injection and reactive power injection in {network} bus system. " \
                                 f"The measurements are {measure_text}. " \
                                 f"What are the real states of voltage magnitude and phase angles?"
                # The response lists the states transposed (one row per state quantity).
                discretized_response = f'The best estimation is {float_discretization(state.T, float_max)}'
                samples.append((language_input, discretized_response, float_max, i))

            if want_transient or want_fault:
                float_max = abs(curve).max()
                fault_type, location = _fault_fields(info, network)

            if want_transient:
                # Context is the first `threshold` steps; the target is the rest.
                history_text = f'{float_discretization(curve[:, :threshold], float_max)}'.replace(',', '')
                pred_values = f'{float_discretization(curve[:, threshold:], float_max)}'.replace(',', '')
                # Parenthesised so all three sentences are always concatenated.
                pred_instruction = (
                    f'The previous {threshold} steps of nodal voltages are {history_text}. '
                    f'There is a {fault_type} fault happend at bus {location} in {network} system, '
                    f'What are predictions of the following voltage curves given a range of voltages?'
                )
                discretized_pred = 'The predicted curves are ' + pred_values
                samples.append((pred_instruction, discretized_pred, float_max, i))

            if want_fault:
                if network in _DIRECT_FAULT_INFO_NETWORKS:
                    target_clf = f'The fault type is {fault_type}, happended at bus {location}.'
                else:
                    target_clf = f'The fault is {fault_type}, happended at bus {location}.'
                curve_text = f'{float_discretization(curve, float_max)}'.replace(',', '')
                clf_instruction = f'The nodal volage curves are {curve_text}. ' \
                                  f'What are the fault type and location given the nodal voltage curves in {network} system?'
                samples.append((clf_instruction, target_clf, float_max, i))
    return samples