from gla_v2.model import GLA_Model_v2
from gla_v2.model_mot import GLA_Model_MoT
from gla_v2.data_preprocess import data_prepare_gla_v2
from utils import *
import RL4Grid
import time
from pypower.api import *


def _hold_thermal_units_at_floor(action, obs, ppc):
    """Zero negative adjustments for thermal units close to their minimum output.

    For every index in ``ppc['thermal_ids']`` whose current output
    ``obs.gen_p[k]`` is less than 1 above ``ppc['min_gen_p'][k]``, a negative
    entry of ``action`` is replaced by 0. ``action`` is modified in place and
    also returned.
    """
    for k in ppc['thermal_ids']:
        near_floor = obs.gen_p[k] - ppc['min_gen_p'][k] < 1
        if near_floor and action[k] < 0:
            action[k] = 0
    return action


def eval_opf(response, target, env, i, max_float, min_float, float_disc):
    """Compare the model's generator set-points with the reference ones in the simulator.

    Both ``response`` and ``target`` are rescaled with
    ``value * (max_float - min_float + 1e-5) + min_float``. Each is turned into
    an adjustment relative to the current ``obs.gen_p``, clipped to the bounds
    of ``obs.action_space['adjust_gen_p']``, and executed with ``env.step``
    after ``env.reset(start_sample_idx=i)``.

    Args:
        response: model output; must support ``.detach().cpu().numpy()``
            (presumably a torch tensor -- confirm against the caller).
        target: reference set-points in the same normalised scale.
        env: simulator wrapper exposing ``env.env.ppc``, ``reset`` and ``step``.
        i: sample index forwarded to ``env.reset(start_sample_idx=i)``.
        max_float, min_float: bounds used to undo the normalisation.
        float_disc: accepted for interface compatibility; not used here.

    Returns:
        Tuple ``(llm_done, target_done, optimal_gap)`` where ``optimal_gap`` is
        ``(best_reward - reward) / 3`` and ``reward`` is replaced by ``-10``
        whenever the model's action cannot be built or ends the episode.
    """
    failure_reward = -10  # reward substituted when the model's action fails
    gap_divisor = 3       # normalisation applied to the reward difference

    ppc = env.env.ppc
    num_gen = ppc['num_gen']
    gen_bus_ids = ppc['gen_bus']
    scale = max_float - min_float + 1e-5

    response = response.detach().cpu().numpy()
    vec = (response.squeeze(-1)[:, gen_bus_ids] * scale + min_float).squeeze(0)
    target_vec = (target * scale + min_float).squeeze(0)

    # --- reference action -------------------------------------------------
    obs = env.reset(start_sample_idx=i)
    adjust_space = obs.action_space['adjust_gen_p']
    action_high = adjust_space.high
    action_low = adjust_space.low

    target_setpoint = target_vec.squeeze()
    if target_vec.shape[0] > num_gen:
        # The target has more entries than generators: keep the generator buses only.
        target_setpoint = target_setpoint[gen_bus_ids]
    best_a = np.asarray(target_setpoint) - np.asarray(obs.gen_p)
    best_a = best_a.clip(action_low, action_high)
    best_a = _hold_thermal_units_at_floor(best_a, obs, ppc)

    _, best_reward, target_done, info = env.step(best_a)
    if target_done:
        # The reference action itself ended the episode; report the power
        # balance instead of dropping into an interactive debugger.
        print(f'total load={obs.load_p.sum():.3f}, total gen_p={(obs.gen_p+best_a).sum():.3f}')

    # --- model action -----------------------------------------------------
    obs = env.reset(start_sample_idx=i)
    try:
        llm_a = np.asarray(vec[:num_gen]) - np.asarray(obs.gen_p)
    except (ValueError, IndexError, TypeError):
        # The prediction cannot be aligned with the generators: score it like
        # a failed episode so the third value is a gap, as in the normal path.
        print('dimension is too small')
        return True, target_done, (best_reward - failure_reward) / gap_divisor
    llm_a = llm_a.clip(action_low, action_high)
    llm_a = _hold_thermal_units_at_floor(llm_a, obs, ppc)

    _, reward, llm_done, info = env.step(llm_a)
    if llm_done:
        reward = failure_reward
    optimal_gap = (best_reward - reward) / gap_divisor
    print(f'llm_done={llm_done}, target_done={target_done}, opt_gap={optimal_gap:.4f}')
    return llm_done, target_done, optimal_gap

def eval_pmu_pred(response, target, max_float, min_float, bus_number):
    """Score a transient (PMU) prediction against its reference.

    Both inputs are rescaled with
    ``value * (max_float - min_float + 1e-5) + min_float`` and laid out as
    ``(bus_number, -1)`` before the metrics are computed.

    Args:
        response: model output; must support ``.squeeze().reshape(...)`` and
            ``.detach().cpu().numpy()`` (presumably a torch tensor).
        target: reference values; after rescaling it must support
            ``.reshape(...).detach().cpu().numpy()``.
        max_float, min_float: bounds used to undo the normalisation.
        bus_number: number of rows both arrays are reshaped to.

    Returns:
        Tuple ``(nrmse, rsquare)`` from ``compute_nrmse(..., percentage=True)``
        and ``compute_r2_score``.
    """
    span = max_float - min_float + 1e-5

    predicted = response.squeeze().reshape(bus_number, -1).detach().cpu().numpy()
    predicted = (predicted * span + min_float).reshape(bus_number, -1)

    reference = target * span + min_float
    reference = reference.reshape(bus_number, -1).detach().cpu().numpy()

    nrmse = compute_nrmse(reference, predicted, percentage=True)
    rsquare = compute_r2_score(reference, predicted)
    print(f'bus_num={bus_number}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare

def eval_state_estimation(response, target, max_float, min_float, bus_num):
    """Score a state-estimation output against its reference.

    ``response`` is reshaped to ``(bus_num, -1)``, converted to a numpy array
    and squeezed; prediction and target are then rescaled with
    ``value * (max_float - min_float + 1e-5) + min_float``.

    Args:
        response: model output supporting ``.reshape(...).detach().cpu().numpy()``.
        target: reference values in the same normalised scale.
        max_float, min_float: bounds used to undo the normalisation.
        bus_num: number of rows ``response`` is reshaped to; also printed.

    Returns:
        Tuple ``(nrmse, rsquare)`` from ``compute_nrmse(..., percentage=True)``
        and ``compute_r2_score``.
    """
    span = max_float - min_float + 1e-5

    estimate = response.reshape(bus_num, -1).detach().cpu().numpy().squeeze()
    estimate = estimate * span + min_float
    truth = target * span + min_float

    nrmse = compute_nrmse(truth, estimate, percentage=True)
    rsquare = compute_r2_score(truth, estimate)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_LMP_prediction(response, target, max_float, min_float, bus_num):
    """Score a locational-marginal-price prediction against its reference.

    Prediction and target are rescaled with
    ``value * (max_float - min_float + 1e-5) + min_float``. Gaussian noise
    with standard deviation 0.001 is added to the rescaled target before
    scoring, so the returned metrics are not deterministic.

    Args:
        response: model output supporting ``.detach().cpu().numpy()``.
        target: reference values in the same normalised scale.
        max_float, min_float: bounds used to undo the normalisation.
        bus_num: only used in the printed summary.

    Returns:
        Tuple ``(nrmse, rsquare)`` from ``compute_nrmse(..., percentage=True)``
        and ``compute_r2_score``.
    """
    span = max_float - min_float + 1e-5

    raw = response.detach().cpu().numpy()
    pred = (raw.squeeze(-1) * span + min_float).squeeze(0)

    truth = (target * span + min_float).squeeze()
    # Small jitter avoids extreme values in nrmse and rsquare; added in place
    # so the dtype of the rescaled target is kept.
    jitter = np.random.randn(*truth.shape) * 0.001
    truth += jitter

    nrmse = compute_nrmse(truth, pred, percentage=True)
    rsquare = compute_r2_score(truth, pred)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_pmu_clf(response, target):
    """Check a decoded fault-classification answer against the reference text.

    The fault type is the text after the first ``'is '`` and before the first
    comma (spaces are stripped from the predicted type only). The fault
    location is the ordered list of every integer found in the string.

    Locations are compared as plain integer lists, so predictions with a
    different number of locations count as a location mismatch instead of
    raising or being broadcast, and a correct fault type is still credited.

    Args:
        response: decoded model answer (string).
        target: reference answer (string).

    Returns:
        Tuple ``(type_correct, location_correct)`` of booleans.
        ``(False, False)`` when either string cannot be parsed.
    """
    try:
        fault_type_pred = response.split(',')[0].split('is ')[1].replace(' ', '')
        fault_loc_pred = [int(tok) for tok in re.findall(r'\d+', response)]
        fault_type_gt = target.split('is ')[1].split(',')[0]
        fault_loc_gt = [int(tok) for tok in re.findall(r'\d+', target)]
    except (AttributeError, IndexError, TypeError, ValueError):
        # Missing 'is ' marker, non-string input, or an unparsable number.
        print('extraction failed')
        return False, False

    type_ok = fault_type_gt == fault_type_pred
    loc_ok = fault_loc_gt == fault_loc_pred
    print(f'type={type_ok}, loc={loc_ok}')
    return type_ok, loc_ok

def _empty_metric_store(names):
    """Return a dict mapping every entry of ``names`` to its own empty list."""
    return {name: [] for name in names}


def _parse_bus_count(instruction):
    """Return the integer that follows a network prefix in ``instruction``.

    The prefixes ``"IEEE"``, ``"SG"`` and ``"Texas"`` are tried in that order
    and the last one present wins; the number is the text between the prefix
    and the following ``' bus'``.

    Raises:
        ValueError: when none of the prefixes occurs, or the text after the
            prefix is not an integer.
    """
    bus_count = None
    for key in ("IEEE", "SG", "Texas"):
        if key in instruction:
            bus_count = int(instruction.split(key)[1].split(' bus', 1)[0])
    if bus_count is None:
        raise ValueError(f'no known network prefix in instruction: {instruction!r}')
    return bus_count


if __name__ == '__main__':

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------
    networks = [
        'IEEE14',
        'IEEE39',
        'IEEE57',
        'SG126',
        'IEEE300',
        # 'Texas2000'
    ]
    ppcs = {
        'IEEE14': case14(),
        'IEEE39': case39(),
        'IEEE57': case57(),
        'IEEE118': case118(),
        'SG126': case126(),
        'IEEE300': case300(),
        'Texas2000': case2000()
    }
    gcn_task = 'opf'  # not referenced elsewhere in this script
    tasks = [
        'opf',
        'fault_detect',
        'state_est',
        'lmp_pred',
    ]

    float_disc = True
    use_diffusion = False
    use_MoT = True
    use_lora = True
    node_parallel = False
    # llm_model = 'Qwen/Qwen3-1.7B'
    # llm_model = 'google/gemma-3-1b-pt'
    # llm_model = 'meta-llama/Llama-3.2-1b'  # meta-llama/Llama-3.2-1b or Llama-3.1-8B or google/gemma-3-1b-pt
    # llm_model = 'meta-llama/Llama-3.2-3b'
    llm_model = 'meta-llama/Llama-3.1-8b'
    graph_input_dim = 64
    graph_hidden_dim = 1024
    data_amount = None
    # Identifier shared by the checkpoint file, the tokenizer directory and the report header.
    prefix = f'{llm_model.split("/")[1]}-{graph_input_dim}-{graph_hidden_dim}' \
             f'{"-diffusion" if use_diffusion else ""}' \
             f'{"-MoT" if use_MoT else ""}' \
             f'{"-"+str(data_amount) if data_amount is not None else ""}' \
             f'-RoPE6'

    # ------------------------------------------------------------------
    # Model and test data
    # ------------------------------------------------------------------
    Model = GLA_Model_MoT if use_MoT else GLA_Model_v2
    model = Model(graph_input_dim=graph_input_dim, graph_hidden_dim=graph_hidden_dim, language_model_name=f'{llm_model}', use_lora=use_lora,
                  float_disc=float_disc, use_diffusion=use_diffusion, node_parallel=node_parallel).cuda()
    model.load(model_path=f'./saved_models/openglav2_2_{prefix}.p',
               tokenizer_path=f'./saved_models/openglav2_tokenizer_2_{prefix}/')
    samples = data_prepare_gla_v2(networks, ppcs, tasks, float_disc=float_disc, is_test=True, data_amount=data_amount)

    # ------------------------------------------------------------------
    # Per-network result lists, one dict per metric
    # ------------------------------------------------------------------
    type_acc, loc_acc, fault_loc_inf_t = (_empty_metric_store(networks) for _ in range(3))
    llm_dones, target_dones, optimal_gaps, opf_inf_t = (_empty_metric_store(networks) for _ in range(4))
    transient_pred_nrmses, transient_pred_rsquares, pmu_pred_inf_t = (_empty_metric_store(networks) for _ in range(3))
    state_est_nrmses, state_est_rsquares, state_est_inf_t = (_empty_metric_store(networks) for _ in range(3))
    lmp_nrmses, lmp_rsquares, lmp_pred_inf_t = (_empty_metric_store(networks) for _ in range(3))

    # ------------------------------------------------------------------
    # Evaluation loop: the task is recognised from the instruction wording
    # ------------------------------------------------------------------
    for (graph_data, edge_index, instruction, target, float_max, float_min, idx) in samples:
        edge_index = edge_index.long().cuda().unsqueeze(0).permute(0, 2, 1).squeeze(0)
        graph_data = graph_data.unsqueeze(0).float().cuda()
        x = time.time()
        outputs = model(graph_data, edge_index, instruction)

        if 'fault type and fault location' in instruction:
            pred = model.decode(outputs)
            if pred == '<|end_of_text|>':
                # Report the empty answer and keep going instead of stopping
                # the whole run in an interactive debugger.
                print(f'warning: empty prediction for sample {idx}')
            network = instruction.split('in ')[1].split(' bus')[0]
            if network == 'IEEE118':
                network = 'SG126'
            elapsed = time.time() - x  # includes the decode step above
            fault_loc_inf_t[network].append(elapsed)
            print(f'gla fault_detect {network} inference time={elapsed}')
            type_pred_res, loc_pred_res = eval_pmu_clf(pred, target)
            type_acc[network].append(type_pred_res)
            loc_acc[network].append(loc_pred_res)

        # NOTE(review): only the fault-detection and transient-prediction
        # branches remap 'IEEE118' to 'SG126'; confirm the other tasks never
        # name IEEE118 in their instructions.
        if 'best active power setpoint' in instruction:
            network = instruction.split('in ')[1].split(' bus')[0]
            elapsed = time.time() - x
            opf_inf_t[network].append(elapsed)
            print(f'gla opf {network} inference time={elapsed}')
            env = RL4Grid.make_gridsim({network: ppcs[network]}, deterministic=True)
            llm_done, target_done, optimal_gap = eval_opf(outputs, target, env, idx, float_max, float_min, float_disc)
            llm_dones[network].append(llm_done)
            target_dones[network].append(target_done)
            optimal_gaps[network].append(optimal_gap)

        if 'what are the predictions of' in instruction:
            network = instruction.split('in ')[1].split(' bus')[0]
            if network == 'IEEE118':
                network = 'SG126'
            elapsed = time.time() - x
            pmu_pred_inf_t[network].append(elapsed)
            print(f'gla transient_pred {network} inference time={elapsed}')
            bus_num = _parse_bus_count(instruction)
            nrmse, rsquare = eval_pmu_pred(outputs, target, float_max, float_min, bus_num)
            transient_pred_nrmses[network].append(nrmse)
            transient_pred_rsquares[network].append(rsquare)

        if 'real states of voltage magnitude and phase angles' in instruction:
            network = instruction.split('in ')[1].split(' bus')[0]
            elapsed = time.time() - x
            state_est_inf_t[network].append(elapsed)
            print(f'gla state_est {network} inference time={elapsed}')
            bus_num = _parse_bus_count(instruction)
            nrmse, rsquare = eval_state_estimation(outputs, target, float_max, float_min, bus_num)
            state_est_nrmses[network].append(nrmse)
            state_est_rsquares[network].append(rsquare)

        if 'locational marginal price' in instruction:
            network = instruction.split('in ')[1].split(' bus')[0]
            elapsed = time.time() - x
            lmp_pred_inf_t[network].append(elapsed)
            print(f'gla lmp_pred {network} inference time={elapsed}')
            bus_num = _parse_bus_count(instruction)
            nrmse, rsquare = eval_LMP_prediction(outputs, target, float_max, float_min, bus_num)
            lmp_nrmses[network].append(nrmse)
            lmp_rsquares[network].append(rsquare)

    # ------------------------------------------------------------------
    # Report: mean of every metric per network (an empty list yields nan)
    # ------------------------------------------------------------------
    for network in networks:
        print('*********************************************')
        print(f'opengla_2_{prefix}')
        print(
            f'network={network}, opf_pass_rate={1 - np.asarray(llm_dones[network]).mean():.3f}, optimal_gap={np.asarray(optimal_gaps[network]).mean():.3f}, opf_inf_t={np.asarray(opf_inf_t[network]).mean():.3f}, '
            f'type_acc={np.asarray(type_acc[network]).mean():.3f}, loc_acc={np.asarray(loc_acc[network]).mean():.3f}, fault_loc_inf_t={np.asarray(fault_loc_inf_t[network]).mean():.3f}, '
            f'state_est_nrmse={np.asarray(state_est_nrmses[network]).mean():.3f}, state_est_rsquare={np.asarray(state_est_rsquares[network]).mean():.3f}, state_est_inf_t={np.asarray(state_est_inf_t[network]).mean():.3f}, '
            f'lmp_nrmse={np.asarray(lmp_nrmses[network]).mean():.3f}, lmp_rsquare={np.asarray(lmp_rsquares[network]).mean():.3f}, lmp_inf_t={np.asarray(lmp_pred_inf_t[network]).mean():.3f}, '
            f'transient_pred_nrmse={np.asarray(transient_pred_nrmses[network]).mean()}, transient_pred_rsquare={np.asarray(transient_pred_rsquares[network]).mean()}, pmu_pred_inf_t={np.asarray(pmu_pred_inf_t[network]).mean():.3f}'
            # f'non_empty_rate={non_empty / len(samples)}'
        )
