from utils import GPT, load_latest_realista_result_pkl, set_seed, SAVE_RESULT_DICT
from config import SAVE_PATH, RED_BACKGROUND, GREEN_BACKGROUND, YELLOW_BACKGROUND, RESET
from tqdm import tqdm
from my_utils import feasibility_check, get_prompt, obj_fun_with_prompt, projection, latent_from_delta, reconstruct_from_latent, get_full_input_embeds, obj_fun
import torch
import torch.nn as nn

def _select_stage1_candidates(stage1_data_cur, cur_task_dict, mmlu_subject, trial_num, feasibility_checker_LLM):
    '''Pick up to `trial_num` feasible Stage-1 single-concept prompts, best objective first.

    Candidates are ranked by their stored Stage-1 'obj_value' (descending) and
    passed through `feasibility_check` one at a time. Checking stops as soon as
    `trial_num` feasible prompts have been found, to limit LLM checker calls.

    Args:
        stage1_data_cur: mapping concept_key -> dict with at least
            'obj_value' and 'attack_prompt' (Stage-1 results for this question).
        cur_task_dict: task dict; 'question', 'choices' and 'answer' are read.
        mmlu_subject: subject string forwarded to `feasibility_check`.
        trial_num: maximum number of feasible prompts to collect.
        feasibility_checker_LLM: checker model forwarded to `feasibility_check`.

    Returns:
        (top_idx_ls, stage1_result_dict):
        - top_idx_ls: concept indices (positions in `stage1_data_cur`'s key
          order) of the feasible prompts, in descending objective order.
        - stage1_result_dict: summary of the best feasible candidate. It keeps
          the sentinel values (None / -1 / -inf / "") if none was feasible.
    '''
    stage1_obj_best = -float('inf')
    stage1_concept_key_best = None
    stage1_concept_idx_best = -1
    stage1_prompt_best = ""

    top_idx_ls = []

    # Collect all candidates with their obj values
    all_candidates = []
    for concept_idx, concept_key in enumerate(stage1_data_cur.keys()):
        obj_value = stage1_data_cur[concept_key]['obj_value']
        attack_prompt = stage1_data_cur[concept_key]['attack_prompt']
        all_candidates.append((obj_value, concept_idx, concept_key, attack_prompt))

    # Sort by obj_value descending
    all_candidates.sort(reverse=True, key=lambda x: x[0])

    # Check feasibility only until we have trial_num feasible prompts
    feasible_count = 0
    for obj_value, concept_idx, concept_key, attack_prompt in all_candidates:
        if feasible_count >= trial_num:
            break

        print(f"Checking feasibility for concept {concept_key} (obj: {obj_value:.4f})...")
        is_feasible = feasibility_check(
            attack_prompt, cur_task_dict['question'], cur_task_dict['choices'],
            mmlu_subject, cur_task_dict['answer'], feasibility_checker_LLM)

        if is_feasible:
            print(f"{GREEN_BACKGROUND}Feasibility check passed for concept {concept_key}.{RESET}")
            top_idx_ls.append(concept_idx)
            feasible_count += 1

            if obj_value > stage1_obj_best:
                stage1_obj_best = obj_value
                stage1_concept_key_best = concept_key
                stage1_concept_idx_best = concept_idx
                stage1_prompt_best = attack_prompt
                print(f"New best obj value from stage 1: {obj_value:.4f}")
                print(f"Best attack prompt: {attack_prompt}")
        else:
            print(f"{RED_BACKGROUND}Feasibility check failed for concept {concept_key}.{RESET}")

    print(f"\nFound {feasible_count} feasible prompts out of {len(all_candidates)} candidates.")
    stage1_result_dict = {
        'best_concept_key': stage1_concept_key_best,
        'best_concept_idx': stage1_concept_idx_best,
        'best_obj_value': stage1_obj_best,
        'best_attack_prompt': stage1_prompt_best
    }
    return top_idx_ls, stage1_result_dict


def realista(args, model, tokenizer, Z0, D_Z0, concepts, cur_task_dict, target_choice_index):
    '''REALISTA for finding editing strength vector delta.

    Stage 1 loads the latest Stage-1 (single-concept) results for this
    question. It keeps the highest-objective prompts that pass the LLM
    feasibility check.

    Stage 2 runs one projected Langevin-dynamics trial per feasible Stage-1
    seed. Each trial works as follows:
      - `delta` starts as `epsilon` on the seed concept's index and zero
        elsewhere.
      - `delta` is mapped to a latent via `latent_from_delta`.
      - The latent is decoded back to a prompt via `reconstruct_from_latent`.
      - The prompt is scored with `obj_fun`.
      - `delta` is updated with a gradient-ascent step plus Gaussian noise of
        scale sqrt(2 * eta * T), then re-projected with `projection`.
    Each trial restarts from temperature `args.T0` and is annealed by
    `args.annealing_rate` after every successful step.

    Args:
        args: run config. Reads epsilon, eta, max_iter, T0, prompt_len,
            annealing_rate, trial_num, seed, model_type, mmlu_subject and
            mmlu_question_idx.
        model, tokenizer: the target LLM and its tokenizer.
        Z0, D_Z0: base latent and concept-direction data for
            `latent_from_delta`. D_Z0.shape[0] is taken as the number of
            concepts.
        concepts: concept list forwarded to the result recorder.
        cur_task_dict: task dict; 'question', 'choices' and 'answer' are read.
        target_choice_index: index of the answer choice being targeted.

    Returns:
        List of per-trial dicts from `SAVE_RESULT_DICT.get_result_dict()`. It
        holds min(args.trial_num, number of feasible Stage-1 prompts) entries,
        so it may be empty.
    '''
    device = model.device
    # Hyperparameters for Langevin Dynamics
    epsilon = args.epsilon  # 0 ≤ δ_i; ‖δ‖_1≤ ε
    eta = args.eta
    max_iter = args.max_iter
    T0 = args.T0  # initial temperature in langevin dynamics (reset for every trial)
    prompt_len = args.prompt_len  # max length of reconstructed prompt
    annealing_rate = args.annealing_rate
    trial_num = args.trial_num

    feasibility_checker_LLM = GPT("gpt-4.1-mini-2025-04-14")

    print("Langevin dynamics hyperparameters:")
    print(args)

    # ----- Load Stage 1 Results: Single Concept Initialization -----
    stage1_data, _ = load_latest_realista_result_pkl(SAVE_PATH, args.model_type, args.mmlu_subject, 1)
    stage1_data_cur = stage1_data[args.mmlu_question_idx]

    print("Original question and choices:")
    print(cur_task_dict['question'])

    top_idx_ls, stage1_result_dict = _select_stage1_candidates(
        stage1_data_cur, cur_task_dict, args.mmlu_subject, trial_num, feasibility_checker_LLM)

    # Each trial is seeded from a distinct feasible Stage-1 concept. Without
    # this cap, indexing top_idx_ls[trial_idx] would raise IndexError when
    # fewer than trial_num prompts were feasible.
    num_trials = min(trial_num, len(top_idx_ls))
    if num_trials < trial_num:
        print(f"{YELLOW_BACKGROUND}Only {num_trials} feasible Stage-1 prompts; "
              f"running {num_trials} of {trial_num} requested trials.{RESET}")

    print("----- Starting Stage 2: Projected Langevin Dynamics -----")

    n = D_Z0.shape[0]  # number of concepts
    result_dict_ls = []
    for trial_idx in tqdm(range(num_trials), desc="Trials"):
        print(f"=== Trial {trial_idx+1}/{num_trials} ===")
        # Trials are independent restarts: reset the temperature so later
        # trials do not start from the previous trial's annealed value.
        T = T0
        save_result_dict = SAVE_RESULT_DICT(args, cur_task_dict, target_choice_index, epsilon, stage1_result_dict)

        # Step 0: score the unmodified question as the baseline.
        prefix, suffix = get_prompt(cur_task_dict)
        obj_value, probs = obj_fun_with_prompt(args, prefix, suffix, cur_task_dict['question'], target_choice_index, model, tokenizer, cur_task_dict)
        obj_best = obj_value
        # NOTE(review): iter_best is never updated below (the assignment was
        # commented out upstream), so 0 is always recorded here. Confirm
        # whether SAVE_RESULT_DICT.update tracks the best step itself.
        iter_best = 0

        save_result_dict.update(
            args, 0, obj_value, cur_task_dict['question'], nn.Parameter(torch.zeros(n, device=device, dtype=model.dtype)), probs, T, concepts, cur_task_dict, target_choice_index, prefix, suffix, model, tokenizer, obj_best, iter_best
        )  # record for step 0

        set_seed(args.seed + trial_idx)  # different seed per trial

        # Initialise delta as a one-hot (scaled by epsilon) on the seed concept.
        seed_concept_idx = top_idx_ls[trial_idx]
        print(f"Initialize delta with concept from stage 1: {seed_concept_idx}")
        delta = nn.Parameter(
            torch.zeros(n, device=device, dtype=model.dtype).scatter_(
                0, torch.tensor([seed_concept_idx], device=device), epsilon))

        with torch.no_grad():
            delta[:] = projection(delta, epsilon)

        for step in tqdm(range(1, max_iter + 1), desc="Steps"):
            delta.grad = None  # drop the gradient from the previous step

            # forward: delta -> Z_cur -> latent -> objective
            Z_cur = latent_from_delta(Z0, delta, D_Z0)

            cur_embeds, decoded_text, term_id_flag = reconstruct_from_latent(
                model, tokenizer, Z_cur, prompt_len
            )

            if term_id_flag == False:
                # Not able to reconstruct the prompt. Record the step with the
                # last known obj_value, then take a noise-only (gradient-free)
                # step. T is not annealed on this path.
                print(f"Step failed. Add noise to delta and continue.")
                print(f"{YELLOW_BACKGROUND}Reconstructed prompt did not terminate properly.{RESET}")

                save_result_dict.update(
                    args, step, obj_value, decoded_text, delta, probs, T, concepts, cur_task_dict, target_choice_index, prefix, suffix, model, tokenizer, obj_best, iter_best
                )

                with torch.no_grad():
                    noise = (2 * eta * T) ** 0.5 * torch.randn_like(delta)
                    delta += noise
                    delta[:] = projection(delta, epsilon)
                continue

            full_input_embeds, _, _ = get_full_input_embeds(
                model, tokenizer, cur_task_dict, question_embeds=cur_embeds
            )

            # NOTE(review): the second return of obj_fun is discarded, so the
            # `probs` recorded below is still the step-0 value. Confirm this
            # is intended.
            obj_value, _ = obj_fun(args, full_input_embeds, target_choice_index, model, device)

            if torch.isnan(obj_value):
                print(f"{RED_BACKGROUND}Step failed with obj_value: {obj_value}{RESET}")
                print(decoded_text)
                print()
                break

            # Only improvements that also pass the LLM feasibility check count
            # as a new best.
            if obj_value > obj_best:
                print(f"Pass adversarial test. Check feasibility...")
                is_feasible = feasibility_check(
                    decoded_text, cur_task_dict['question'], cur_task_dict['choices'],
                    args.mmlu_subject, cur_task_dict['answer'],
                    feasibility_checker_LLM
                )
                if is_feasible:
                    print(f"{GREEN_BACKGROUND}Feasibility check passed.{RESET}")
                    print(f"New best obj value: {obj_value.item():.4f}")
                    # Detach so the stored best does not keep this step's
                    # autograd graph alive.
                    obj_best = obj_value.detach()
                else:
                    print(f"{RED_BACKGROUND}Feasibility check failed.{RESET}")

            # Record results
            save_result_dict.update(
                args, step, obj_value, decoded_text, delta, probs, T, concepts, cur_task_dict, target_choice_index, prefix, suffix, model, tokenizer, obj_best, iter_best
            )

            loss = -obj_value  # maximise the objective by minimising its negation
            loss.backward()  # computes delta.grad

            # Langevin update: gradient step + sqrt(2*eta*T) Gaussian noise,
            # then project back onto the feasible set and anneal T.
            with torch.no_grad():
                noise = (2 * eta * T) ** 0.5 * torch.randn_like(delta)
                gradient_update = - eta * delta.grad
                delta += gradient_update + noise
                delta[:] = projection(delta, epsilon)
                T = T * annealing_rate

        result_dict = save_result_dict.get_result_dict()

        print("Ground truth index:", result_dict['ground_truth_idx'])
        print("Target choice index:", result_dict['target_choice_index'])
        print(f"Best obj value: {result_dict['obj_best']:.4f} at iteration {result_dict['iter_best']}")
        best_probs = result_dict['probs'][result_dict['iter_best']]
        print(f"A: {best_probs[0]*100:.2f}%, B: {best_probs[1]*100:.2f}%, C: {best_probs[2]*100:.2f}%, D: {best_probs[3]*100:.2f}%")

        result_dict_ls.append(result_dict)

    return result_dict_ls