from llm.model import LLM_Model
from llm.data_preprocess import *
from utils import *
import RL4Grid
import time


def _gen_p_action(setpoint, obs, env, action_low, action_high):
    """Turn an absolute active-power setpoint into an ``adjust_gen_p`` action.

    The action is ``setpoint - obs.gen_p`` clipped to ``[action_low, action_high]``.
    Any generator listed in ``env.env.ppc['thermal_ids']`` whose current output is
    less than 1 above its ``min_gen_p`` entry is not allowed a negative adjustment.
    """
    action = (np.asarray(setpoint) - np.asarray(obs.gen_p)).clip(action_low, action_high)
    ppc = env.env.ppc
    for k in ppc['thermal_ids']:
        if obs.gen_p[k] - ppc['min_gen_p'][k] < 1 and action[k] < 0:
            action[k] = 0
    return action


def eval_opf(response, target, env, i, max_float, float_disc):
    """Compare the LLM's OPF setpoint with the reference setpoint on sample ``i``.

    Both setpoints are replayed from the same starting state
    (``env.reset(start_sample_idx=i)``), each with a single ``env.step``.

    Args:
        response: Decoded model answer containing the proposed setpoint values.
        target: Reference text; only the part after "The best setpoint is " is parsed.
        env: Grid environment exposing ``reset``, ``step`` and ``env.ppc``.
        i: Sample index passed to ``env.reset``.
        max_float: Forwarded to ``extract_integers`` when ``float_disc`` is truthy.
        float_disc: Selects ``extract_integers`` (truthy) or ``extract_floats``.

    Returns:
        ``(llm_done, target_done, optimal_gap)`` where ``optimal_gap`` is
        ``(best_reward - reward) / 3``. When the answer holds fewer than
        ``num_gen`` values the function returns ``(True, target_done, best_reward)``.
    """
    num_gen = env.env.ppc['num_gen']
    # NOTE(review): the main loop of this script compares against
    # '<|end_of_text|>' (with pipes); confirm which marker the decoder emits.
    answer = response.split('<end_of_text>')[0]
    vec = extract_integers(answer, max_float) if float_disc else extract_floats(answer)
    target = target.split("The best setpoint is ")[-1]
    target_vec = extract_integers(target, max_float) if float_disc else extract_floats(target)

    # --- reference rollout -------------------------------------------------
    obs = env.reset(start_sample_idx=i)
    # Bounds are read once from the first observation and reused for the LLM rollout.
    action_high = obs.action_space['adjust_gen_p'].high
    action_low = obs.action_space['adjust_gen_p'].low
    best_a = _gen_p_action(target_vec, obs, env, action_low, action_high)
    _, best_reward, target_done, _ = env.step(
        {'adjust_gen_p': best_a, 'adjust_gen_v': np.zeros(num_gen)})
    if target_done:
        # NOTE(review): debugging breakpoint kept on purpose; it blocks unattended runs.
        import ipdb
        ipdb.set_trace()

    # --- LLM rollout -------------------------------------------------------
    obs = env.reset(start_sample_idx=i)
    # Explicit length check instead of relying on a broadcasting error: a
    # single-value answer would otherwise be silently broadcast to every generator.
    if len(vec) < num_gen:
        print('dimension is too small')
        return True, target_done, best_reward
    llm_a = _gen_p_action(vec[:num_gen], obs, env, action_low, action_high)
    _, reward, llm_done, _ = env.step(
        {'adjust_gen_p': llm_a, 'adjust_gen_v': np.zeros(num_gen)})
    if llm_done:
        # NOTE(review): debugging breakpoint kept on purpose; it blocks unattended runs.
        import ipdb
        ipdb.set_trace()

    # NOTE(review): the divisor 3 is a scaling constant whose origin is not visible here.
    optimal_gap = (best_reward - reward) / 3
    print(f'llm_done={llm_done}, target_done={target_done}, opt_gap={optimal_gap:.4f}')
    return llm_done, target_done, optimal_gap

def eval_pmu_pred(response, target, max_float, bus_number):
    """Score predicted per-bus curves against the reference curves.

    Reads the module-level ``float_disc`` flag (set in the ``__main__`` block) to
    choose between ``extract_integers`` and ``extract_floats``.

    Args:
        response: Model answer text containing the predicted values.
        target: Reference text; values are taken after 'predicted curves are '
            (discretised mode) or after 'are: ' (float mode).
        max_float: Forwarded to ``extract_integers`` in discretised mode.
        bus_number: Number of buses; the values are reshaped to ``bus_number`` rows.

    Returns:
        ``(nrmse, rsquare)``: RMSE divided by the target mean (in percent) and
        the coefficient of determination, both with a 1e-3 stabiliser.
    """
    if float_disc:
        target_pred = extract_integers(target.split('predicted curves are ')[1], max_float)
        ori_pred = extract_integers(response, max_float)
    else:
        target_pred = extract_floats(target.split('are: ')[1])
        # Fixed: the original parsed an undefined name `text` here (NameError).
        ori_pred = extract_floats(response)

    target_pred = target_pred.reshape(bus_number, -1)
    pred_length = target_pred.shape[1]
    expected = bus_number * pred_length

    print(f'len_ori_pred={len(ori_pred)}')
    if len(ori_pred) < expected:
        # Too few values in the answer: zero-pad so the metrics stay computable.
        print(f'dimension fault, llm_pred_shape={ori_pred.shape}, expected={expected}')
        ori_pred = np.concatenate((ori_pred, np.zeros(expected - len(ori_pred))), axis=0)
    # Surplus values beyond the expected count are ignored.
    llm_pred = ori_pred[:expected].reshape(bus_number, pred_length)

    sq_err = ((target_pred - llm_pred) ** 2).sum()
    nrmse = np.sqrt(sq_err / expected) / (target_pred.mean() + 1e-3) * 100
    rsquare = 1 - sq_err / (((target_pred - target_pred.mean()) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_number}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare

def eval_renewable_pred(response, target, max_float, future_hours=6, samples_per_hour=60):
    """Score wind and solar power forecasts against the reference forecasts.

    Reads the module-level ``float_disc`` flag (set in the ``__main__`` block) to
    choose between ``extract_integers`` and ``extract_floats``.

    Args:
        response: Full decoded text; the answer is taken between 'assistant:'
            and '<|im_end|>'.
        target: Reference text containing 'wind power is ...' and
            'solar power is ...' sections.
        max_float: Forwarded to ``extract_integers`` in discretised mode.
        future_hours: Number of rows of the forecast grid (default 6).
        samples_per_hour: Number of columns of the forecast grid (default 60).

    Returns:
        ``(wind_mse, solar_mse)``: mean squared errors over the forecast grid.
    """
    horizon = future_hours * samples_per_hour

    def parse(text):
        # Value extraction shared by the answer and the reference.
        return extract_integers(text, max_float) if float_disc else extract_floats(text)

    def fit(series, label):
        # Pad short forecasts with -1 and drop surplus values, then grid them.
        if len(series) < horizon:
            print(f'dimension fault, llm_{label}_pred_shape={series.shape}, expected={horizon}')
            series = np.concatenate((series, -1 * np.ones(horizon - len(series))), axis=0)
        return series[:horizon].reshape(future_hours, samples_per_hour)

    try:
        pred = response.split('assistant:')[1].split('<|im_end|>')[0]
        wind_pred = parse(pred.split('wind power is')[1].split('solar')[0])
        solar_pred = parse(pred.split('solar power is')[1])
    except Exception:
        # Best-effort: a malformed answer is scored as an all -1 forecast.
        # `Exception` (not a bare except) so Ctrl-C and SystemExit still propagate.
        print('renewable split error')
        wind_pred = -1 * np.ones(horizon)
        solar_pred = -1 * np.ones(horizon)

    llm_wind_pred = fit(wind_pred, 'wind')
    llm_solar_pred = fit(solar_pred, 'solar')

    target_wind_pred = parse(target.split('wind power is')[1].split('solar')[0])
    target_wind_pred = target_wind_pred.reshape(future_hours, samples_per_hour)
    target_solar_pred = parse(target.split('solar power is')[1])
    target_solar_pred = target_solar_pred.reshape(future_hours, samples_per_hour)

    wind_mse = ((target_wind_pred - llm_wind_pred) ** 2).mean()
    solar_mse = ((target_solar_pred - llm_solar_pred) ** 2).mean()
    print(f'wind_mse={wind_mse:.4f}, solar_mse={solar_mse:.4f}')
    return wind_mse, solar_mse

def eval_state_estimation(response, target, max_float, bus_num):
    """Score a state-estimation answer against the reference states.

    The answer is expected to hold ``2 * bus_num`` integers, arranged as rows of
    ``bus_num`` values (the calling prompt asks for voltage magnitudes and
    phase angles).

    Args:
        response: Model answer text containing the predicted values.
        target: Reference text containing the true values.
        max_float: Forwarded to ``extract_integers``.
        bus_num: Number of buses (row length).

    Returns:
        ``(nrmse, rsquare)``: RMSE divided by the target mean (in percent) and
        the coefficient of determination, both with a 1e-3 stabiliser.
    """
    expected = 2 * bus_num
    pred = extract_integers(response, max_float)
    # Explicit length check: `reshape(-1, bus_num)` alone also accepts a short
    # answer of exactly `bus_num` values, which would then be broadcast against
    # both target rows instead of being padded.
    if len(pred) < expected:
        print(f'dimension fault, pred_shape={len(pred)}, expected={expected}')
        pred = np.concatenate((pred, -1 * np.ones(expected - len(pred))), axis=0)
    # Surplus values beyond the expected count are ignored.
    pred = pred[:expected].reshape(-1, bus_num)
    target = extract_integers(target, max_float).reshape(-1, bus_num)

    sq_err = ((target - pred) ** 2).sum()
    nrmse = np.sqrt(sq_err / (bus_num * 2)) / (target.mean() + 1e-3) * 100
    rsquare = 1 - sq_err / (((target - target.mean()) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_LMP_prediction(response, target, max_float, bus_num):
    """Score a locational-marginal-price answer against the reference values.

    At most the first ``2 * bus_num`` integers of the answer are used, arranged
    as rows of ``bus_num`` values; if that arrangement fails the answer is
    padded with -1 up to ``2 * bus_num`` values.

    Args:
        response: Model answer text containing the predicted values.
        target: Reference text containing the true values.
        max_float: Forwarded to ``extract_integers``.
        bus_num: Number of buses (row length).

    Returns:
        ``(nrmse, rsquare)``: RMSE divided by the target mean (in percent) and
        the coefficient of determination, both with a 1e-3 stabiliser.
    """
    wanted = 2 * bus_num
    try:
        llm_values = extract_integers(response, max_float)[:wanted].reshape(-1, bus_num)
    except:
        # Catch-all kept as in the existing behaviour: any failure above falls
        # back to padding the raw extraction with -1.
        flat = extract_integers(response, max_float)
        print(f'dimension fault, pred_shape={len(flat)}, expected={wanted}')
        filler = -1 * np.ones(wanted - len(flat))
        llm_values = np.concatenate((flat, filler), axis=0).reshape(-1, bus_num)

    reference = extract_integers(target, max_float).reshape(-1, bus_num)
    # NOTE(review): assumes the reference has the same row count as the
    # prediction; otherwise the subtraction broadcasts — confirm for LMP targets.
    squared_error = ((reference - llm_values) ** 2).sum()
    reference_mean = reference.mean()

    nrmse = np.sqrt(squared_error / wanted) / (reference_mean + 1e-3) * 100
    rsquare = 1 - squared_error / (((reference - reference_mean) ** 2).sum() + 1e-3)
    print(f'bus_num={bus_num}, nrmse={nrmse}, rsquare={rsquare}')
    return nrmse, rsquare


def eval_pmu_clf(response, target):
    """Check a fault-classification answer against the reference.

    The fault type is the text after the first 'is ' and before the first
    comma; spaces are stripped from the predicted type only. The fault location
    is the ordered list of all integers appearing in each text.

    Args:
        response: Model answer text.
        target: Reference answer text.

    Returns:
        ``(type_correct, location_correct)`` as plain booleans. Both are False
        when the expected 'is ' pattern cannot be found.
    """
    try:
        fault_type_pred = response.split(',')[0].split('is ')[1].replace(' ', '')
        fault_type_gt = target.split('is ')[1].split(',')[0]
        fault_loc_pred = [int(n) for n in re.findall(r'\d+', response)]
        fault_loc_gt = [int(n) for n in re.findall(r'\d+', target)]
    except (IndexError, ValueError, AttributeError):
        print('extraction failed')
        return False, False

    type_ok = fault_type_gt == fault_type_pred
    # Exact list equality: a different number of locations is simply a wrong
    # location. The previous array comparison either raised (discarding a
    # correct fault type) or broadcast a single number against several.
    loc_ok = fault_loc_gt == fault_loc_pred
    print(f'type={type_ok}, loc={loc_ok}')
    return type_ok, loc_ok

# Script entry point: loads a saved LLM, runs it on the test samples of the
# selected networks/tasks, scores each answer with the eval_* functions above
# and prints the aggregated metrics.
if __name__ == '__main__':
    # NOTE(review): `use_lora` is not referenced below; LLM_Model is called
    # with the literal `use_lora=True`.
    use_lora = True
    # Read as a module-level global by eval_pmu_pred and eval_renewable_pred,
    # and passed explicitly to eval_opf.
    float_disc = True
    networks = [
        'IEEE14',
        'IEEE39',
        'IEEE57',
        'SG126',
        'IEEE300'  # not available if using 8B model
    ]
    # NOTE(review): `gcn_task` is not referenced anywhere in this block.
    gcn_task = 'opf'
    tasks = [
        'opf',
        'fault_detect',
        'state_est',
        'lmp_pred',
        # 'transient_pred',
    ]


    # Evaluate pure LLM
    model = LLM_Model(language_model_name='meta-llama/Llama-3.2-1B', use_lora=True).cuda()
    model.load(model_path='./saved_models/llm.p',
               tokenizer_path='./saved_models/llm_tokenizer/')
    samples = data_prepare_llm(networks, tasks, is_test=True)

    # import ipdb
    # ipdb.set_trace()
    # Per-task accumulators; the regression tasks share the nrmse/rsquare lists.
    type_acc, loc_acc, pred_nrmses, pred_rsquares = [], [], [], []
    llm_dones, target_dones, optimal_gaps = [], [], []
    # Number of answers that are not just the end-of-text token.
    non_empty = 0
    for (instruction, target, float_max, idx) in samples:
        x = time.time()
        outputs = model(instruction)
        pred = model.decode(outputs)
        print(f'llm inference time={time.time()-x}')
        if pred != '<|end_of_text|>':
            non_empty += 1
        # Debugging aid: drop into the debugger on an empty answer.
        # NOTE(review): this blocks unattended runs.
        if pred == '<|end_of_text|>':
            import ipdb
            ipdb.set_trace()
        # The task is identified by a key phrase in the instruction; the text
        # after the first '?' of the decoded output is passed on as the answer.
        # NOTE(review): `pred.split('?')[1]` raises IndexError when the decoded
        # text contains no '?'.
        if 'fault type and fault location' in instruction:
            type_pred_res, loc_pred_res = eval_pmu_clf(pred.split('?')[1], target)
            type_acc.append(type_pred_res)
            loc_acc.append(loc_pred_res)
        if 'best active power setpoint' in instruction:
            # Network name is the text between 'of ' and ' bus' in the instruction.
            network = instruction.split('of ')[1].split(' bus')[0]
            # A new environment is created for every OPF sample.
            env = RL4Grid.make_gridsim(network)
            llm_done, target_done, optimal_gap = eval_opf(pred.split('?')[1], target, env, idx, float_max, float_disc)
            llm_dones.append(llm_done)
            target_dones.append(target_done)
            optimal_gaps.append(optimal_gap)
        if 'following voltage curves' in instruction:
            bus_num = int(instruction.split('IEEE')[1].split(' system')[0])
            nrmse, rsquare = eval_pmu_pred(pred.split('?')[1], target, float_max, bus_num)
            pred_nrmses.append(nrmse)
            pred_rsquares.append(rsquare)
        if 'real states of voltage magnitude and phase angles' in instruction:
            # import ipdb
            # ipdb.set_trace()
            # Bus count follows the 'IEEE' prefix, or the 'SG' prefix as a fallback.
            try:
                bus_num = int(instruction.split('IEEE')[1].split(' bus')[0])
            except:
                bus_num = int(instruction.split('SG')[1].split(' bus')[0])
            nrmse, rsquare = eval_state_estimation(pred.split('?')[1], target, float_max, bus_num)
            pred_nrmses.append(nrmse)
            pred_rsquares.append(rsquare)
        if 'locational marginal price' in instruction:
            # import ipdb
            # ipdb.set_trace()
            # Bus count follows the 'IEEE' prefix, or the 'SG' prefix as a fallback.
            try:
                bus_num = int(instruction.split('IEEE')[1].split(' bus')[0])
            except:
                bus_num = int(instruction.split('SG')[1].split(' bus')[0])
            nrmse, rsquare = eval_LMP_prediction(pred.split('?')[1], target, float_max, bus_num)
            pred_nrmses.append(nrmse)
            pred_rsquares.append(rsquare)

    # Aggregate metrics. NOTE(review): the mean of an empty list is nan (with a
    # NumPy warning), which happens for any task that produced no samples.
    print(f'opf_pass_rate={1 - np.asarray(llm_dones).mean()}, optimal_gap={np.asarray(optimal_gaps).mean()}, '
          f'type_acc={np.asarray(type_acc).mean()}, '
          f'loc_acc={np.asarray(loc_acc).mean()}, pred_nrmse={np.asarray(pred_nrmses).mean()},'
          f'pred_rsquare={np.asarray(pred_rsquares).mean()}, '
          f'non_empty_rate={non_empty / len(samples)}')

    # Final breakpoint for interactive inspection of the results.
    import ipdb
    ipdb.set_trace()