import json
import torch

def obj_fun(args, full_input_embeds, target_choice_index, model, device):
    """Compute the log-likelihood objective for the open-ended MCQA setting.

    Runs ``model`` on the already-embedded prompt and scores the target
    answer choice using the choice probabilities returned by ``get_probs``.

    Args:
        args: Configuration object forwarded to ``get_probs``.
        full_input_embeds: Embedded prompt passed to the model as
            ``inputs_embeds``.
        target_choice_index: Index into the probabilities returned by
            ``get_probs`` selecting the choice to score.
        model: Callable language model accepting ``inputs_embeds``.
        device: Unused; kept so existing callers keep working.

    Returns:
        tuple: ``(obj_value, probs)`` where ``obj_value`` is
        ``log(probs[target_choice_index] + 1e-10)`` computed in at least
        float32, and ``probs`` is the unmodified output of ``get_probs``.
    """
    outputs = model(inputs_embeds=full_input_embeds)

    probs = get_probs(args, outputs)
    target_prob = probs[target_choice_index]

    # The 1e-10 guard is smaller than the smallest positive float16 value, so
    # adding it to a half-precision probability is a no-op and log(0) = -inf
    # would still be possible.  Promote to at least float32 first (float64
    # inputs are left untouched) so the guard is actually effective.
    safe_dtype = torch.promote_types(target_prob.dtype, torch.float32)
    obj_value = torch.log(target_prob.to(safe_dtype) + 1e-10)
    return obj_value, probs

def get_full_input_embeds(model, tokenizer, cur_task_dict, question_embeds):
    """
    Build the embedded prompt ``[prefix | question_embeds | suffix]``.

    The textual prefix and suffix come from ``get_prompt(cur_task_dict)``;
    both are tokenized without special tokens, embedded with
    ``model.model.embed_tokens`` and concatenated around ``question_embeds``
    along dim 1. The result is cast to ``torch.float16``.

    ``cur_task_dict`` is an MMLU-style record, e.g. ``ds['test'][1]``:
    {'question': 'Infrared (IR) spectroscopy is useful for ... because',
    'subject': 'college_chemistry',
    'choices': ['all molecular bonds absorb IR radiation', ...],
    'answer': 2}

    Returns ``(full_input_embeds, prefix, suffix)``.
    """

    prefix, suffix = get_prompt(cur_task_dict)

    def _embed_text(text):
        # Tokenize one text fragment on the model's device and look up its embeddings.
        token_ids = tokenizer(text, return_tensors="pt", add_special_tokens=False).input_ids
        return model.model.embed_tokens(token_ids.to(model.device))

    # Text that surrounds the (latent) question embeddings.
    pieces = [_embed_text(prefix), question_embeds, _embed_text(suffix)]
    full_input_embeds = torch.cat(pieces, dim=1).to(torch.float16)

    return full_input_embeds, prefix, suffix

def reconstruct_from_latent(model, tokenizer, latent, prompt_len=50, seed=42):
    """Decode a latent embedding sequence back into text by asking the model to repeat it.

    The latent is placed between an embedded "please repeat the following
    message" prefix and suffix. Tokens are then generated one at a time: each
    step samples a one-hot token with a hard (straight-through) Gumbel-softmax
    over the next-token logits and appends the matching embedding row, so the
    appended embeddings stay differentiable with respect to the logits.

    Generation stops early when a sampled token's decoded text contains one of
    the ending characters, or on the token that follows one containing a
    question mark. The token that triggers the stop is NOT kept.

    Args:
        model: Causal LM exposing ``model.embed_tokens``, ``device`` and
            ``get_input_embeddings()``; called with ``inputs_embeds``.
        tokenizer: Tokenizer matching ``model``.
        latent: Embeddings to reconstruct; concatenated along dim 1 with the
            prompt embeddings (presumably ``(1, L, d_model)`` -- TODO confirm).
        prompt_len: Maximum number of tokens to generate.
        seed: Seed passed to ``torch.manual_seed``. NOTE: this reseeds the
            global torch RNG on every call.

    Returns:
        tuple: ``(kept_embeds, decoded_text, term_id_flag)`` where
        ``kept_embeds`` holds the embeddings of the kept generated tokens
        (zero-length along dim 1 when nothing was kept), ``decoded_text`` is
        their concatenated decoding, and ``term_id_flag`` is True iff
        generation stopped on a terminator rather than hitting ``prompt_len``.
    """

    torch.manual_seed(seed)  # set the seed for reproducibility

    prefix_text = """User: Please repeat the following message: {"user_message": """
    suffix_text = """}. Assistant: Sure, I will repeat the User message in the json format: {"user_message": " """

    prefix_ids = tokenizer(prefix_text, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    suffix_ids = tokenizer(suffix_text, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)

    prefix_embeds = model.model.embed_tokens(prefix_ids)
    suffix_embeds = model.model.embed_tokens(suffix_ids)

    full_input_embeds = torch.cat([prefix_embeds, latent, suffix_embeds], dim=1).to(torch.float16)

    # Matching attention mask: one entry per input position, all attended.
    attention_mask = torch.ones(full_input_embeds.size()[:-1], dtype=torch.long).to(model.device)

    def step_autoregressive(cur_embeds, cur_mask):
        # 1) Forward pass over the whole sequence so far (no KV cache).
        out = model(
            inputs_embeds=cur_embeds,
            attention_mask=cur_mask,
            output_hidden_states=False,
        )
        logits = out.logits[:, -1, :]         # logits for the next position

        # 2) Hard Gumbel-softmax: one-hot sample in the forward pass, gradient
        #    of the soft sample in the backward pass (tau is the temperature).
        p_next = torch.nn.functional.gumbel_softmax(
            logits, tau=1.0, hard=True, dim=-1
        )

        # 3) Embedding of the sampled token, expressed as one-hot @ embedding matrix.
        E = model.get_input_embeddings().weight
        next_embed = (p_next @ E).unsqueeze(1)

        # 4) Append the new position and extend the mask by one.
        new_embeds = torch.cat([cur_embeds, next_embed], dim=1)
        new_mask = torch.cat(
            [cur_mask, torch.ones_like(cur_mask[:, :1])],
            dim=1
        )
        return new_embeds, new_mask, p_next

    cur_embeds = full_input_embeds
    cur_mask = attention_mask

    # NOTE(review): the full-width entry below is the LEFT brace although the
    # ASCII entry is the right brace -- possibly meant to be its right-hand
    # counterpart; left unchanged, confirm with the author.
    ENDING_CHARS = [
        # ASCII
        '"', "}",

        # Unicode curly quotes (double)
        "”",

        # Unicode curly quotes (single)
        "’",

        # Prime-like (often appears instead of apostrophes)
        "′", "″", "‴", "⁗",

        # Full-width variants
        "＂", "＇", "｛",
    ]
    QUESTION_MARKS = ['?', '？']

    term_id_flag = False
    question_mark_flag = False
    num_kept = 0
    decoded_pieces = []

    for _ in range(prompt_len):
        cur_embeds_new, cur_mask, p_t = step_autoregressive(cur_embeds, cur_mask)
        token_id = torch.argmax(p_t, dim=-1)
        text_cur = tokenizer.decode(token_id.item(), skip_special_tokens=True)

        # Stop on an ending character, or on the token right after a question
        # mark; the stopping token itself is discarded.
        if any(ch in text_cur for ch in ENDING_CHARS) or question_mark_flag:
            term_id_flag = True
            break
        if any(ch in text_cur for ch in QUESTION_MARKS):
            question_mark_flag = True

        cur_embeds = cur_embeds_new
        num_kept += 1
        decoded_pieces.append(text_cur)

    # Use an explicit start index: a negative slice of length 0 (``-0:``) would
    # return the entire prompt instead of an empty tensor when nothing was kept.
    start = cur_embeds.size(1) - num_kept
    return cur_embeds[:, start:, :], "".join(decoded_pieces), term_id_flag

def latent_from_delta(Z0, d, D_Z0):
    """Shift ``Z0`` by a weighted sum of direction tensors.

    Computes ``Z = Z0 + sum_i d_i * D_Z0[i]``. ``d`` is flattened and
    reshaped to ``(-1, 1, 1)`` so each coefficient scales one slice of
    ``D_Z0`` along dim 0 (presumably ``D_Z0`` is ``(K, L, H)`` and ``Z0`` is
    ``(L, H)`` -- confirm against the caller).
    """
    # One scalar weight per direction, broadcast over the trailing two dims.
    weights = d.view(-1, 1, 1)
    offset = (weights * D_Z0).sum(dim=0)

    return Z0 + offset

def projection(delta, epsilon):
    """
    Project onto {delta >= 0, ||delta||_1 <= epsilon}

    Negative entries are clamped to zero first. If the clamped vector already
    satisfies the L1 budget it is returned as is; otherwise it is projected
    onto the scaled simplex {x >= 0, sum(x) = epsilon} with the usual
    sort-and-threshold procedure.

    Args:
        delta: 1-D tensor to project (the sort / cumulative sum below are only
            consistent with each other for a 1-D input).
        epsilon: Non-negative L1 budget.

    Returns:
        The projected tensor, computed under ``torch.no_grad()``.

    Raises:
        ValueError: If ``epsilon`` is negative (the feasible set is empty).
    """
    with torch.no_grad():
        if epsilon < 0:
            raise ValueError(f"epsilon must be non-negative, got {epsilon}")

        # enforce nonnegativity
        delta = torch.clamp(delta, min=0.0)

        # already feasible
        if delta.sum() <= epsilon:
            return delta

        # With a zero budget the only feasible point is the origin.  The
        # thresholding below would find no active index in that case and fail
        # on an empty selection, so handle it explicitly.
        if epsilon == 0:
            return torch.zeros_like(delta)

        # sort descending
        u, _ = torch.sort(delta, descending=True)

        # cssv[k-1] = (sum of the k largest entries) - epsilon
        cssv = torch.cumsum(u, dim=0) - epsilon
        ind = torch.arange(1, u.numel() + 1, device=delta.device)

        # rho = largest k whose k-th entry stays positive after subtracting
        # the threshold computed from the top-k entries.
        cond = u - cssv / ind > 0
        rho = ind[cond][-1]
        theta = cssv[rho - 1] / rho

        return torch.clamp(delta - theta, min=0.0)


def get_probs(args, outputs):
    '''Return the model's confidence over the answer letters A, B, C, D.

    Picks four hard-coded vocabulary ids (one per letter, chosen by
    ``args.model_type``) out of the logits at the last position of the first
    sequence in ``outputs.logits`` and normalizes just those four with a
    softmax, so the result sums to 1 and keeps its gradient.

    Raises ValueError for a model type without a known id table.'''
    # (model types, ids for A/B/C/D in that order)
    id_tables = (
        (('llama3_8b', 'llama3_3b'), (362, 426, 356, 423)),
        (('qwen2_5_7b', 'qwen2_5_14b'), (362, 425, 356, 422)),
    )
    for model_types, ids in id_tables:
        if args.model_type in model_types:
            letter_ids = ids
            break
    else:
        raise ValueError(f"Unsupported model type: {args.model_type}")

    final_logits = outputs.logits[0, -1, :]          # [V]
    index = torch.tensor(list(letter_ids), device=final_logits.device)
    letter_logits = final_logits[index]              # [4]

    return torch.softmax(letter_logits, dim=0)

def obj_fun_with_prompt(args, prefix, suffix, input_prompt, target_choice_index, model, tokenizer, cur_task_dict, reasoning_target=None, hallucination_evaluator=None):
    """Compute the REALISTA objective for a textual question.

    Embeds ``[prefix | input_prompt | suffix]``, runs the model and returns the
    log-likelihood of the target answer choice.

    Args:
        args: Configuration object forwarded to ``get_probs``.
        prefix: Ignored -- overwritten by ``get_prompt(cur_task_dict)`` below.
        suffix: Ignored -- overwritten by ``get_prompt(cur_task_dict)`` below.
        input_prompt: Question text placed between prefix and suffix.
        target_choice_index: Index into the probabilities returned by
            ``get_probs`` selecting the choice to score.
        model: Causal LM exposing ``model.embed_tokens`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        cur_task_dict: Task record passed to ``get_prompt``.
        reasoning_target: Unused.
        hallucination_evaluator: Unused.

    Returns:
        tuple: ``(obj_value, probs)`` where ``obj_value`` is
        ``log(probs[target_choice_index] + 1e-10)`` computed in at least
        float32, and ``probs`` is the unmodified output of ``get_probs``.
    """
    # NOTE(review): the prefix/suffix arguments are discarded here and rebuilt
    # from the task record; kept as is because callers may rely on it.
    prefix, suffix = get_prompt(cur_task_dict)

    # Tokenize the three text segments without special tokens.
    prefix_ids = tokenizer(prefix, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    suffix_ids = tokenizer(suffix, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    input_ids = tokenizer(input_prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)

    # Look up their embeddings and join them along the sequence dimension.
    prefix_embeds = model.model.embed_tokens(prefix_ids)
    suffix_embeds = model.model.embed_tokens(suffix_ids)
    question_embeds = model.model.embed_tokens(input_ids)

    full_input_embeds = torch.cat([prefix_embeds, question_embeds, suffix_embeds], dim=1).to(torch.float16)

    outputs = model(inputs_embeds=full_input_embeds)

    probs = get_probs(args, outputs)
    target_prob = probs[target_choice_index]

    # The 1e-10 guard is smaller than the smallest positive float16 value, so
    # adding it to a half-precision probability is a no-op and log(0) = -inf
    # would still be possible.  Promote to at least float32 first (float64
    # inputs are left untouched) so the guard is actually effective.
    safe_dtype = torch.promote_types(target_prob.dtype, torch.float32)
    obj_value = torch.log(target_prob.to(safe_dtype) + 1e-10)
    return obj_value, probs


def get_prompt(cur_task_dict, is_reasoning=False):
    '''
    Build the text placed before and after the question in an MCQA prompt.

    Ref: https://github.com/bhaweshiitk/ConformalLLM/blob/main/conformal_llm_scores.py

    cur_task_dict is expected to look like
        {'question': ..., 'answer': ..., 'choices': [...], 'subject': ...}
    of which only 'subject' and the first four 'choices' are read here.

    Returns (prefix, suffix):
      * prefix -- expert persona for the subject plus the instruction line;
        with is_reasoning=False the instruction also asks for the reason.
      * suffix -- a newline, the options lettered A-D, and (only when
        is_reasoning is False) the trailing "The correct answer is option: ".
    '''

    subject_name = cur_task_dict['subject']
    choices = cur_task_dict['choices']

    persona = f"You are the world's best expert in answering questions related to {subject_name.replace('_', ' ')}. "
    if is_reasoning:
        instruction = '''Answer the following question. \n'''
    else:
        instruction = '''Answer the following question and give me the reason. \n'''
    prefix = persona + instruction

    # One indented "X. choice" line per option; the question itself goes
    # between prefix and suffix and is supplied by the caller.
    option_lines = "".join(
        '    ' + letter + '. ' + choices[pos] + '\n'
        for pos, letter in enumerate('ABCD')
    )
    suffix = "\n" + option_lines

    if not is_reasoning:
        suffix += "The correct answer is option: "

    return prefix, suffix

def _first_json_object(text):
    """Parse the first ``{...}`` JSON value embedded in *text*, ignoring surrounding text.

    Raises json.JSONDecodeError when *text* contains no ``{`` or when the text
    starting at the first ``{`` is not valid JSON.
    """
    start = text.find("{")
    if start == -1:
        raise json.JSONDecodeError("no JSON object found", text, 0)
    parsed, _ = json.JSONDecoder().raw_decode(text[start:])
    return parsed


def feasibility_check(
                query_x, query_x0, choices, subject, ground_truth_idx,
                feasibility_evaluator_model
                ):
    '''Check whether the new question is semantically equivalent to the original question.

    The prompt is inspired by 'Text Encoders Lack Knowledge': https://aclanthology.org/2023.gem-1.23.pdf

    Args:
        query_x: The new (rephrased) question.
        query_x0: The original question.
        choices: The four answer options shared by both questions.
        subject: Subject name; underscores are shown as spaces in the prompt.
        ground_truth_idx: Index (0-3) of the correct option.
        feasibility_evaluator_model: Object whose
            ``generate(messages, max_new_tokens=..., temperature=...)`` returns
            the judge's reply as text.

    Returns:
        bool: True iff the reply contains a JSON object whose
        ``equivalence_score`` equals 1. Any parsing problem is printed and
        treated as "not feasible" (best effort, never raises for bad replies).
    '''

    input_prompt = f'''
    You are an expert in {subject.replace('_', ' ')}.

    Assess the **semantic equivalence** between the following two questions:

    Original Question: "{query_x0}"  
    New Question: "{query_x}"

    Both questions share the same answer options:
    A. {choices[0]}  
    B. {choices[1]}  
    C. {choices[2]}  
    D. {choices[3]}

    The correct answer for the Original Question is: {chr(65 + ground_truth_idx)}. {choices[ground_truth_idx]}.

    Determine whether the two questions are **semantically equivalent**, meaning:
    - Each question entails the other.
    - The New Question does not introduce information beyond what’s in the Original Question and the answer options.
    - The New Question does not omit key information from the Original Question.
    - The New Question preserves the original meaning.
    - The New Question would result in the same correct answer.

    Assign a **semantic equivalence score** as 0 or 1:
    - 1: All conditions are fully satisfied.  
    - 0: Any of the condition is not satisfied.

    Return only the score in the following format (no additional explanation):
    `{{"equivalence_score": "1"}}`
    '''

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"{input_prompt}"}
    ]

    feasibility_check_raw = feasibility_evaluator_model.generate(messages, max_new_tokens=30, temperature=1.0)

    try:
        try:
            # Fast path: the reply is exactly one JSON document.
            parsed_result = json.loads(feasibility_check_raw)
        except json.JSONDecodeError:
            # The prompt shows the expected format inside backticks, so replies
            # may come back wrapped in backticks / code fences or followed by
            # extra text; recover the embedded object instead of rejecting.
            parsed_result = _first_json_object(feasibility_check_raw)
        is_feasible = (float(parsed_result['equivalence_score']) == 1)
    except json.JSONDecodeError as e:
        print(f"JSON Decode Error: {e}")
        is_feasible = False
    except Exception as e:  # Deliberate best-effort: any other malformed reply counts as infeasible
        print(f"Unexpected Error: {e}")
        is_feasible = False

    return is_feasible