from gla.model import GLA_Model
from gla.data_preprocess import data_prepare_gla
from utils import *
import RL4Grid
import time


def eval_opf(response, target, env, i, max_float, float_disc):
    """Score an LLM-proposed generator setpoint against the reference setpoint.

    Both setpoints are converted to ``adjust_gen_p`` actions relative to the
    observation returned by ``env.reset(start_sample_idx=i)`` and each one is
    stepped through ``env`` from that same reset state.

    Args:
        response: Decoded model output containing the proposed setpoint.
        target: Reference text; only the part after
            ``"The best setpoint is "`` is parsed.
        env: Environment wrapper exposing ``env.env.ppc`` (with the keys
            ``'num_gen'``, ``'thermal_ids'`` and ``'min_gen_p'``),
            ``reset(start_sample_idx=...)`` and ``step(action)``.
        i: Sample index forwarded to ``env.reset``.
        max_float: Forwarded to ``extract_integers`` when ``float_disc`` is
            truthy.
        float_disc: Parse numbers with ``extract_integers`` when truthy,
            otherwise with ``extract_floats``.

    Returns:
        ``(llm_done, target_done, optimal_gap)`` with
        ``optimal_gap = (best_reward - reward) / 3``.  When the response does
        not yield a usable setpoint the function returns
        ``(True, target_done, best_reward)`` (the third slot then carries the
        reference reward, as in the original implementation).
    """
    ppc = env.env.ppc
    num_gen = ppc['num_gen']

    # Cut the answer at the end-of-text marker.  The decoded marker compared
    # against elsewhere in this file is '<|end_of_text|>'; the legacy spelling
    # without pipes is still honoured.
    answer = response
    for eos_marker in ('<|end_of_text|>', '<end_of_text>'):
        answer = answer.split(eos_marker)[0]
    vec = extract_integers(answer, max_float) if float_disc else extract_floats(answer)

    target = target.split("The best setpoint is ")[-1]
    target_vec = extract_integers(target, max_float) if float_disc else extract_floats(target)

    # Reference action.
    obs = env.reset(start_sample_idx=i)
    action_high = obs.action_space['adjust_gen_p'].high
    action_low = obs.action_space['adjust_gen_p'].low
    best_a = _setpoint_to_action(target_vec, obs, ppc, action_low, action_high)
    _, best_reward, target_done, _ = env.step(best_a)
    if target_done:
        # Previously this dropped into an interactive debugger, which hangs
        # unattended evaluation runs; report it and carry on instead.
        print(f'warning: reference setpoint ended the episode for sample {i}')

    # LLM action, evaluated from the same reset state.
    obs = env.reset(start_sample_idx=i)
    if len(vec) < num_gen:
        # Explicit check: a single extracted value would otherwise broadcast
        # silently across every generator.
        print('dimension is too small')
        return True, target_done, best_reward
    try:
        llm_a = _setpoint_to_action(vec[:num_gen], obs, ppc, action_low, action_high)
    except (ValueError, TypeError):
        print('dimension is too small')
        return True, target_done, best_reward
    _, reward, llm_done, _ = env.step(llm_a)

    optimal_gap = (best_reward - reward) / 3
    print(f'llm_done={llm_done}, target_done={target_done}, opt_gap={optimal_gap:.4f}')
    return llm_done, target_done, optimal_gap


def _setpoint_to_action(setpoint, obs, ppc, action_low, action_high):
    """Turn an absolute setpoint into a clipped ``adjust_gen_p`` action.

    The action is ``setpoint - obs.gen_p`` clipped to the action bounds.  For
    every index in ``ppc['thermal_ids']`` whose current output is less than 1
    above ``ppc['min_gen_p']``, a negative adjustment is replaced by 0.
    """
    action = (np.asarray(setpoint) - np.asarray(obs.gen_p)).clip(action_low, action_high)
    for k in ppc['thermal_ids']:
        if obs.gen_p[k] - ppc['min_gen_p'][k] < 1 and action[k] < 0:
            action[k] = 0
    return action

def eval_pmu_pred(response, target, max_float, bus_number, float_disc=None):
    """Compute NRMSE (in percent) and R^2 between predicted and reference curves.

    Args:
        response: Decoded model output containing the predicted values.
        target: Reference text.  With ``float_disc`` the values after
            ``'predicted curves are '`` are parsed, otherwise those after
            ``'are: '``.
        max_float: Forwarded to ``extract_integers`` when ``float_disc`` is
            truthy.
        bus_number: Number of rows the flat value list is reshaped into.
        float_disc: Parse with ``extract_integers`` when truthy, with
            ``extract_floats`` otherwise.  When omitted, the module-level
            ``float_disc`` flag is used if it exists (legacy behaviour),
            else ``True``.

    Returns:
        ``(nrmse, rsquare)``.  A response with too few values is right-padded
        with zeros before scoring.
    """
    if float_disc is None:
        # Legacy callers relied on a global set by the script entry point.
        float_disc = globals().get('float_disc', True)

    if float_disc:
        target_pred = extract_integers(target.split('predicted curves are ')[1], max_float)
        # The float branch used to reference an undefined name `text`.
        ori_pred = extract_integers(response, max_float)
    else:
        target_pred = extract_floats(target.split('are: ')[1])
        ori_pred = extract_floats(response)

    target_pred = target_pred.reshape(bus_number, -1)
    pred_length = target_pred.shape[1]
    expected = bus_number * pred_length

    print(f'len_ori_pred={len(ori_pred)}')
    if len(ori_pred) >= expected:
        llm_pred = ori_pred[:expected].reshape(bus_number, pred_length)
    else:
        print(f'dimension fault, llm_pred_shape={ori_pred.shape}, expected={expected}')
        padding = np.zeros(expected - len(ori_pred))
        llm_pred = np.concatenate((ori_pred, padding), axis=0).reshape(bus_number, pred_length)

    squared_error = ((target_pred - llm_pred) ** 2).sum()
    # The 1e-3 terms keep the denominators away from zero.
    nrmse = np.sqrt(squared_error / (bus_number * pred_length)) / (target_pred.mean() + 1e-3) * 100
    rsquare = 1 - squared_error / (((target_pred - target_pred.mean()) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_number}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare

def eval_renewable_pred(response, target, max_float, float_disc=None):
    """Compute the MSE of the wind and solar forecasts found in ``response``.

    The prediction is the text between ``'assistant:'`` and ``'<|im_end|>'``;
    the wind values follow ``'wind power is'`` (up to ``'solar'``) and the
    solar values follow ``'solar power is'``.  Each series is scored over
    6 x 60 values.

    Args:
        response: Decoded model output.
        target: Reference text using the same ``'wind power is'`` /
            ``'solar power is'`` markers.
        max_float: Forwarded to ``extract_integers`` when ``float_disc`` is
            truthy.
        float_disc: Parse with ``extract_integers`` when truthy, with
            ``extract_floats`` otherwise.  When omitted, the module-level
            ``float_disc`` flag is used if it exists (legacy behaviour),
            else ``True``.

    Returns:
        ``(wind_mse, solar_mse)``.  If the response cannot be parsed, both
        predictions are filled with -1; a series that is too short is
        right-padded with -1.
    """
    if float_disc is None:
        # Legacy callers relied on a global set by the script entry point.
        float_disc = globals().get('float_disc', True)

    future_hours = 6
    samples_per_hour = 60
    horizon = future_hours * samples_per_hour

    def parse(text):
        return extract_integers(text, max_float) if float_disc else extract_floats(text)

    try:
        pred = response.split('assistant:')[1].split('<|im_end|>')[0]
        wind_pred = parse(pred.split('wind power is')[1].split('solar')[0])
        solar_pred = parse(pred.split('solar power is')[1])
    except Exception:
        # Best effort on free-form model output: any parsing failure invalidates
        # both series (the extract helpers' failure modes are not visible here).
        print('renewable split error')
        wind_pred = -1 * np.ones(horizon)
        solar_pred = -1 * np.ones(horizon)

    llm_wind_pred = _fit_renewable_horizon(wind_pred, 'wind', future_hours, samples_per_hour)
    llm_solar_pred = _fit_renewable_horizon(solar_pred, 'solar', future_hours, samples_per_hour)

    target_wind_pred = parse(target.split('wind power is')[1].split('solar')[0])
    target_wind_pred = target_wind_pred.reshape(future_hours, samples_per_hour)
    target_solar_pred = parse(target.split('solar power is')[1])
    target_solar_pred = target_solar_pred.reshape(future_hours, samples_per_hour)

    wind_mse = ((target_wind_pred - llm_wind_pred) ** 2).mean()
    solar_mse = ((target_solar_pred - llm_solar_pred) ** 2).mean()
    print(f'wind_mse={wind_mse:.4f}, solar_mse={solar_mse:.4f}')
    return wind_mse, solar_mse


def _fit_renewable_horizon(values, label, rows, cols):
    """Truncate or right-pad (with -1) ``values`` to ``rows * cols`` and reshape."""
    expected = rows * cols
    if len(values) >= expected:
        return values[:expected].reshape(rows, cols)
    print(f'dimension fault, llm_{label}_pred_shape={values.shape}, expected={expected}')
    padding = -1 * np.ones(expected - len(values))
    return np.concatenate((values, padding), axis=0).reshape(rows, cols)

def eval_state_estimation(response, target, max_float, bus_num):
    """Compute NRMSE (in percent) and R^2 for a state-estimation answer.

    The first ``2 * bus_num`` integers extracted from ``response`` are
    reshaped to ``(-1, bus_num)`` and compared with the integers extracted
    from ``target``, reshaped the same way.

    Args:
        response: Decoded model output.
        target: Reference text.
        max_float: Forwarded to ``extract_integers``.
        bus_num: Number of values per row.

    Returns:
        ``(nrmse, rsquare)``.  A response with fewer than ``2 * bus_num``
        values is right-padded with -1 before scoring.
    """
    expected = 2 * bus_num
    values = extract_integers(response, max_float)
    if len(values) >= expected:
        pred = values[:expected].reshape(-1, bus_num)
    else:
        # Pad on *any* shortfall.  The previous try/except only padded when
        # reshape(-1, bus_num) failed, so a response with exactly bus_num (or
        # zero) values slipped through unpadded and was broadcast against the
        # target, or crashed later.
        print(f'dimension fault, pred_shape={len(values)}, expected={expected}')
        padding = -1 * np.ones(expected - len(values))
        pred = np.concatenate((values, padding), axis=0).reshape(-1, bus_num)

    target = extract_integers(target, max_float).reshape(-1, bus_num)

    squared_error = ((target - pred) ** 2).sum()
    # The 1e-3 terms keep the denominators away from zero.
    nrmse = np.sqrt(squared_error / (bus_num * 2)) / (target.mean() + 1e-3) * 100
    rsquare = 1 - squared_error / (((target - target.mean()) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_LMP_prediction(response, target, max_float, bus_num):
    """Compute NRMSE (in percent) and R^2 for an LMP-prediction answer.

    The first ``2 * bus_num`` integers extracted from ``response`` are
    reshaped to ``(-1, bus_num)`` and compared with the integers extracted
    from ``target``, reshaped the same way.
    NOTE(review): the ``2 * bus_num`` layout mirrors ``eval_state_estimation``;
    confirm that LMP targets really carry two rows per sample.

    Args:
        response: Decoded model output.
        target: Reference text.
        max_float: Forwarded to ``extract_integers``.
        bus_num: Number of values per row.

    Returns:
        ``(nrmse, rsquare)``.  A response with fewer than ``2 * bus_num``
        values is right-padded with -1 before scoring.
    """
    expected = 2 * bus_num
    values = extract_integers(response, max_float)
    if len(values) >= expected:
        pred = values[:expected].reshape(-1, bus_num)
    else:
        # Pad on *any* shortfall.  The previous try/except only padded when
        # reshape(-1, bus_num) failed, so a response with exactly bus_num (or
        # zero) values slipped through unpadded and was broadcast against the
        # target, or crashed later.
        print(f'dimension fault, pred_shape={len(values)}, expected={expected}')
        padding = -1 * np.ones(expected - len(values))
        pred = np.concatenate((values, padding), axis=0).reshape(-1, bus_num)

    target = extract_integers(target, max_float).reshape(-1, bus_num)

    squared_error = ((target - pred) ** 2).sum()
    # The 1e-3 terms keep the denominators away from zero.
    nrmse = np.sqrt(squared_error / (bus_num * 2)) / (target.mean() + 1e-3) * 100
    rsquare = 1 - squared_error / (((target - target.mean()) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_pmu_clf(response, target):
    """Check a fault classification answer against the ground truth.

    The fault type is the text after the first ``'is '`` and before the first
    comma; the fault location is the sequence of all integers appearing in the
    text.

    Args:
        response: Decoded model output.
        target: Reference text.

    Returns:
        ``(type_correct, location_correct)`` as booleans.  ``(False, False)``
        is returned when no fault type can be extracted from either text.
    """
    try:
        type_pred = response.split(',')[0].split('is ')[1]
        type_gt = target.split('is ')[1].split(',')[0]
    except (IndexError, AttributeError):
        print('extraction failed')
        return False, False

    # Ignore spaces on both sides; stripping only the prediction made any
    # multi-word ground-truth type impossible to match.
    type_ok = type_gt.replace(' ', '') == type_pred.replace(' ', '')

    # Compare the number sequences as lists: array `==` would broadcast a
    # single predicted value against several targets, and a length mismatch
    # used to raise and wipe out an otherwise correct type result.
    loc_pred = [int(n) for n in re.findall(r'\d+', response)]
    loc_gt = [int(n) for n in re.findall(r'\d+', target)]
    loc_ok = loc_pred == loc_gt

    print(f'type={type_ok}, loc={loc_ok}')
    return type_ok, loc_ok

if __name__ == '__main__':
    # Evaluation driver: loads a saved GLA checkpoint, runs it on the test
    # samples of the selected networks/tasks, scores every answer with the
    # eval_* helpers above and prints the per-network averages.
    use_lora = True
    float_disc = True
    networks = [
        # 'IEEE14',
        # 'IEEE39',
        # 'IEEE57',
        # 'SG126',
        # 'IEEE300',  # not available if using 8B model including transient pred
        'Texas2000'
    ]
    tasks = [
        'opf',
        # 'fault_detect',
        # 'state_est',
        # 'lmp_pred',
        # 'transient_pred',
    ]

    # llm_model = 'google/gemma-3-1b-pt'  # meta-llama/Llama-3.2-1b or Llama-3.1-8B or google/gemma-3-1b-pt
    llm_model = 'meta-llama/Llama-3.2-1b'  # meta-llama/Llama-3.2-1b or Llama-3.1-8B or google/gemma-3-1b-pt
    graph_input_dim = 64
    graph_hidden_dim = 1024
    data_amount = 2048
    prefix = f'{llm_model.split("/")[1]}-{graph_input_dim}-{graph_hidden_dim}-D{data_amount}'

    # Build the model from the same dimensions that name the checkpoint, so
    # the two cannot drift apart.
    model = GLA_Model(graph_input_dim=graph_input_dim, graph_hidden_dim=graph_hidden_dim,
                      language_model_name=llm_model, use_lora=use_lora, float_disc=float_disc).cuda()
    model.load(model_path=f'./saved_models/opengla_2_{prefix}.p',
               tokenizer_path=f'./saved_models/opengla_tokenizer_2_{prefix}/')
    samples = data_prepare_gla(networks, tasks, data_amount=data_amount, float_disc=float_disc, is_test=True,
                               # root_path='/workspace/RL4Grid/'
                               )

    def _new_store():
        """Return a fresh ``{network: []}`` mapping for one metric."""
        return {name: [] for name in networks}

    def _mean(values):
        """Mean of ``values``; NaN (without a NumPy warning) when empty."""
        return np.asarray(values).mean() if len(values) else float('nan')

    type_acc, loc_acc, fault_loc_inf_t = _new_store(), _new_store(), _new_store()
    llm_dones, target_dones, optimal_gaps, opf_inf_t = _new_store(), _new_store(), _new_store(), _new_store()
    transient_pred_nrmses, transient_pred_rsquares, pmu_pred_inf_t = _new_store(), _new_store(), _new_store()
    state_est_nrmses, state_est_rsquares, state_est_inf_t = _new_store(), _new_store(), _new_store()
    lmp_nrmses, lmp_rsquares, lmp_pred_inf_t = _new_store(), _new_store(), _new_store()

    non_empty = 0
    for (graph_data, edge_index, instruction, target, float_max, idx) in samples:
        edge_index = edge_index.long().cuda().unsqueeze(0).permute(0, 2, 1).squeeze(0)
        graph_data = graph_data.unsqueeze(0).float().cuda()
        x = time.time()
        outputs = model(graph_data, edge_index, instruction)
        pred = model.decode(outputs)
        if pred != '<|end_of_text|>':
            non_empty += 1
        else:
            # Used to drop into ipdb, which blocks unattended runs.  The empty
            # answer is still scored below and shows up in non_empty_rate.
            print(f'warning: empty prediction for sample idx={idx}')

        # The task is recognised from a key phrase of the instruction; the
        # network name is parsed out of the instruction text as well.
        if 'fault type and fault location' in instruction:
            network = instruction.split('in ')[1].split(' systems')[0]
            if network == 'IEEE118':
                network = 'SG126'
            fault_loc_inf_t[network].append(time.time() - x)
            print(f'gla fault_detect {network} inference time={time.time()-x:.3f}')
            type_pred_res, loc_pred_res = eval_pmu_clf(pred, target)
            type_acc[network].append(type_pred_res)
            loc_acc[network].append(loc_pred_res)
        if 'best active power setpoint' in instruction:
            network = instruction.split('of ')[1].split(' bus')[0]
            opf_inf_t[network].append(time.time() - x)
            print(f'gla opf {network} inference time={time.time()-x}')
            # NOTE(review): a new environment is built for every sample; reuse
            # per network may be possible but is not verified here.
            env = RL4Grid.make_gridsim(network)
            llm_done, target_done, optimal_gap = eval_opf(pred, target, env, idx, float_max, float_disc)
            llm_dones[network].append(llm_done)
            target_dones[network].append(target_done)
            optimal_gaps[network].append(optimal_gap)
        if 'following voltage curves' in instruction:
            network = instruction.split('in ')[1].split(' system')[0]
            if network == 'IEEE118':
                network = 'SG126'
            pmu_pred_inf_t[network].append(time.time() - x)
            print(f'gla transient_pred {network} inference time={time.time()-x}')
            bus_num = int(instruction.split('IEEE')[1].split(' system')[0])
            nrmse, rsquare = eval_pmu_pred(pred, target, float_max, bus_num)
            transient_pred_nrmses[network].append(nrmse)
            transient_pred_rsquares[network].append(rsquare)
        if 'real states of voltage magnitude and phase angles' in instruction:
            network = instruction.split('in ')[1].split(' bus system')[0]
            state_est_inf_t[network].append(time.time() - x)
            print(f'gla state_est {network} inference time={time.time()-x}')
            try:
                bus_num = int(instruction.split('IEEE')[1].split(' bus')[0])
            except (IndexError, ValueError):
                # Not an 'IEEE<n>' network name; fall back to 'SG<n>'.
                bus_num = int(instruction.split('SG')[1].split(' bus')[0])
            nrmse, rsquare = eval_state_estimation(pred, target, float_max, bus_num)
            state_est_nrmses[network].append(nrmse)
            state_est_rsquares[network].append(rsquare)
        if 'locational marginal price' in instruction:
            network = instruction.split('of ')[1].split(' bus')[0]
            lmp_pred_inf_t[network].append(time.time() - x)
            print(f'gla lmp_pred {network} inference time={time.time()-x}')
            try:
                bus_num = int(instruction.split('IEEE')[1].split(' bus')[0])
            except (IndexError, ValueError):
                # Not an 'IEEE<n>' network name; fall back to 'SG<n>'.
                bus_num = int(instruction.split('SG')[1].split(' bus')[0])
            nrmse, rsquare = eval_LMP_prediction(pred, target, float_max, bus_num)
            lmp_nrmses[network].append(nrmse)
            lmp_rsquares[network].append(rsquare)

    # Guard against an empty test set when reporting the non-empty rate.
    non_empty_rate = non_empty / len(samples) if len(samples) else float('nan')
    for network in networks:
        print('*********************************************')
        # NOTE(review): the label says 'opengla_1_' while the checkpoint loaded
        # above is 'opengla_2_'; confirm which tag is intended.
        print(f'opengla_1_{prefix}')
        print(f'network={network}, opf_pass_rate={1-_mean(llm_dones[network]):.3f}, optimal_gap={_mean(optimal_gaps[network]):.3f}, opf_inf_t={_mean(opf_inf_t[network]):.3f}, '
              f'type_acc={_mean(type_acc[network]):.3f}, loc_acc={_mean(loc_acc[network]):.3f}, fault_loc_inf_t={_mean(fault_loc_inf_t[network]):.3f}, '
              f'state_est_nrmse={_mean(state_est_nrmses[network]):.3f}, state_est_rsquare={_mean(state_est_rsquares[network]):.3f}, state_est_inf_t={_mean(state_est_inf_t[network]):.3f}, '
              f'lmp_nrmse={_mean(lmp_nrmses[network]):.3f}, lmp_rsquare={_mean(lmp_rsquares[network]):.3f}, lmp_inf_t={_mean(lmp_pred_inf_t[network]):.3f}, '
              # f'transient_pred_nrmse={np.asarray(pred_nrmses).mean()}, transient_pred_rsquare={np.asarray(pred_rsquares).mean()}, pmu_pred_inf_t={np.asarray(pmu_pred_inf_t[network]).mean():.3f}'
              f'non_empty_rate={non_empty_rate}')
