import torch
import numpy as np
from sklearn.metrics import auc
from utils.prompt import *

def select_rationales(attribution_scores, input_ids, attention_mask, percentage):
    """
    Build a boolean mask marking the highest-scoring fraction of each example's tokens.

    The number of content tokens per example is the number of attended positions minus
    two (for [CLS] and [SEP]). Of those, ``max(1, content_tokens * percentage)`` are marked,
    with the count truncated to an integer. Selected indices are shifted by one so that
    position 0 ([CLS]) is skipped. This assumes [CLS] sits at position 0 and padding is on
    the right.

    Args:
        attribution_scores (List[List[float]]): Per-example attribution scores. Each row has
            either one score per content token, or one per content token plus a leading
            [CLS] and a trailing [SEP] score.
        input_ids (torch.Tensor): Token IDs; only its shape and device are used.
            Shape: [batch_size, seq_length]
        attention_mask (torch.Tensor): Attention mask. Shape: [batch_size, seq_length]
        percentage (float): Fraction of content tokens to mark (between 0 and 1).
            0.0 returns an all-False mask. 1.0 returns an all-True mask, which includes
            special tokens and padding.

    Returns:
        torch.Tensor: Boolean rationale mask on ``input_ids.device``.
            Shape: [batch_size, seq_length]

    Raises:
        ValueError: If a row's score count matches neither the content length nor the
            content length plus two.
    """
    n_rows, _ = input_ids.size()
    if percentage == 0.0:
        return torch.zeros_like(input_ids, dtype=torch.bool).to(input_ids.device)
    if percentage == 1.0:
        return torch.ones_like(input_ids, dtype=torch.bool).to(input_ids.device)

    # attended positions minus the two special tokens ([CLS], [SEP])
    content_lengths = torch.sum(attention_mask, dim=1) - 2
    # per-row selection budget; never fewer than one token
    budgets = (content_lengths * percentage).clamp(min=1).long()
    mask = torch.zeros_like(input_ids, dtype=torch.bool).to(input_ids.device)

    for row in range(n_rows):
        row_scores = attribution_scores[row]
        n_content = content_lengths[row]
        if n_content == len(row_scores):
            # scores already exclude [CLS]/[SEP]
            content_scores = row_scores
        elif n_content == len(row_scores) - 2:
            # drop the [CLS]/[SEP] scores at both ends
            content_scores = row_scores[1:-1]
        else:
            raise ValueError("The length of the attribution score does not match the input length")
        # highest scores first; the +1 shifts indices past [CLS]
        best_first = np.argsort(content_scores)[::-1]
        mask[row, best_first[: int(budgets[row])] + 1] = True

    return mask


def compute_comprehensiveness(model, input_ids, attention_mask, rationale_mask, predicted_ids, orig_probs, mask_token_id):
    """
    Measure how much the predicted-class probability drops when the rationales are hidden.

    Rationale positions are overwritten with ``mask_token_id`` and removed from attention.
    The model is re-run, and the probability of the originally predicted class is
    subtracted from its unperturbed value.

    Args:
        model: Classifier called as ``model(input_ids=..., attention_mask=...)`` whose output
            exposes ``.logits`` of shape [batch_size, num_classes].
        input_ids (torch.Tensor): Input token IDs. Shape: [batch_size, seq_length]
        attention_mask (torch.Tensor): Attention mask. Shape: [batch_size, seq_length]
        rationale_mask (torch.Tensor): Boolean mask, True at rationale positions.
            Shape: [batch_size, seq_length]
        predicted_ids (torch.Tensor | None): Predicted class per example. Shape: [batch_size]
        orig_probs (torch.Tensor | None): Probability of ``predicted_ids`` on the unperturbed
            input. Shape: [batch_size]. If this or ``predicted_ids`` is None, both are
            recomputed from an unperturbed forward pass.
        mask_token_id (int): Token id written over the rationale positions.

    Returns:
        comprehensiveness (torch.Tensor): ``orig_probs - masked_probs``. Shape: [batch_size]
    """
    def overwrite(tensor, where, value):
        # copy so the caller's tensors are left untouched
        out = tensor.clone()
        out[where] = value
        return out

    model.eval()
    with torch.no_grad():
        if predicted_ids is None or orig_probs is None:
            reference = torch.softmax(model(input_ids=input_ids, attention_mask=attention_mask).logits, dim=-1)
            predicted_ids = reference.argmax(dim=1)
            orig_probs = reference.gather(1, predicted_ids.unsqueeze(1)).squeeze(1)

        hidden_ids = overwrite(input_ids, rationale_mask, mask_token_id)
        hidden_attention = overwrite(attention_mask, rationale_mask, 0)
        perturbed = torch.softmax(model(input_ids=hidden_ids, attention_mask=hidden_attention).logits, dim=-1)
        masked_probs = perturbed.gather(1, predicted_ids.unsqueeze(1)).squeeze(1)
        return orig_probs - masked_probs

def compute_sufficiency(model, input_ids, attention_mask, rationale_mask, predicted_ids, orig_probs, mask_token_id):
    """
    Measure how much the predicted-class probability drops when ONLY the rationales are kept.

    Every position NOT in ``rationale_mask`` is overwritten with ``mask_token_id`` and
    removed from attention. The model is re-run, and the probability of the originally
    predicted class is subtracted from its unperturbed value.

    NOTE(review): ``~rationale_mask`` also covers special tokens such as [CLS]/[SEP] and
    padding whenever the mask does not mark them. Those positions are replaced by
    ``mask_token_id`` as well. Confirm this is the intended protocol.

    Args:
        model: Classifier called as ``model(input_ids=..., attention_mask=...)`` whose output
            exposes ``.logits`` of shape [batch_size, num_classes].
        input_ids (torch.Tensor): Input token IDs. Shape: [batch_size, seq_length]
        attention_mask (torch.Tensor): Attention mask. Shape: [batch_size, seq_length]
        rationale_mask (torch.Tensor): Boolean mask, True at rationale positions.
            Shape: [batch_size, seq_length]
        predicted_ids (torch.Tensor | None): Predicted class per example. Shape: [batch_size]
        orig_probs (torch.Tensor | None): Probability of ``predicted_ids`` on the unperturbed
            input. Shape: [batch_size]. If this or ``predicted_ids`` is None, both are
            recomputed from an unperturbed forward pass.
        mask_token_id (int): Token id written over the non-rationale positions.

    Returns:
        sufficiency (torch.Tensor): ``orig_probs - suff_probs``. Shape: [batch_size]
    """
    def overwrite(tensor, where, value):
        # copy so the caller's tensors are left untouched
        out = tensor.clone()
        out[where] = value
        return out

    model.eval()
    with torch.no_grad():
        if predicted_ids is None or orig_probs is None:
            reference = torch.softmax(model(input_ids=input_ids, attention_mask=attention_mask).logits, dim=-1)
            predicted_ids = reference.argmax(dim=1)
            orig_probs = reference.gather(1, predicted_ids.unsqueeze(1)).squeeze(1)

        discard = ~rationale_mask
        kept_ids = overwrite(input_ids, discard, mask_token_id)
        kept_attention = overwrite(attention_mask, discard, 0)
        perturbed = torch.softmax(model(input_ids=kept_ids, attention_mask=kept_attention).logits, dim=-1)
        suff_probs = perturbed.gather(1, predicted_ids.unsqueeze(1)).squeeze(1)
        return orig_probs - suff_probs

def compute_perturbation_auc(percentages, scores):
    """
    Area under a perturbation curve: ``scores`` plotted against ``percentages``.

    Delegates to ``sklearn.metrics.auc``, which integrates with the trapezoidal rule and
    requires ``percentages`` to be monotonic.

    Args:
        percentages (List[float]): Perturbation percentages (x-axis).
        scores (List[float]): Score measured at each percentage (y-axis).

    Returns:
        auc_score (float): The area under the curve.
    """
    return auc(x=percentages, y=scores)

def _find_subsequence_start(haystack, needle):
    """
    Return the first index at which the 1-D tensor ``needle`` occurs as a contiguous run
    inside the 1-D tensor ``haystack``, or -1 if it never does (or ``needle`` is empty).
    """
    haystack_len, needle_len = haystack.size(0), needle.size(0)
    if needle_len == 0:
        return -1
    for start in range(haystack_len - needle_len + 1):
        if torch.equal(haystack[start:start + needle_len], needle):
            return start
    return -1


def select_rationales_decoder(attribution_scores, input_ids, attention_mask, texts, tokenizer, template, prompt, percentage):
    """
    Select the top percentage of input tokens as rationales for a decoder (chat) model and
    project the selection onto the tokenized full chat prompt.

    The mask is always expressed in full-prompt coordinates, including at the 0.0 and 1.0
    extremes. Only positions inside the span occupied by ``input_ids`` can be True, so
    chat-template tokens are never rationales:
      * percentage == 0.0 -> no position is marked;
      * percentage == 1.0 -> every position of the input span is marked;
      * otherwise -> the top ``max(1, real_length * percentage)`` tokens by attribution
        score are marked, with the count truncated to an integer.

    Args:
        attribution_scores (List[List[float]]): One row with one score per attended input token.
        input_ids (torch.Tensor): Token IDs of the input text alone. Shape: [1, seq_length]
        attention_mask (torch.Tensor): Attention mask for ``input_ids``. Shape: [1, seq_length]
        texts (List[str]): ``texts[0]`` is substituted for "[TEST EXAMPLE]" in ``prompt``.
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        template: Passed to ``fill_in_template`` together with the filled prompt
            (presumably yields the chat messages; verify against utils.prompt).
        prompt (str): Prompt text containing the "[TEST EXAMPLE]" placeholder.
        percentage (float): Fraction of input tokens to select (between 0 and 1).

    Returns:
        rationale_mask_full (torch.Tensor): Boolean mask over the full prompt tokens.
            Shape: [1, full_seq_length]

    Raises:
        ValueError: If the attribution score count differs from the attended input length,
            or if ``input_ids`` does not occur contiguously in the tokenized full prompt.
    """
    batch_size, seq_length = input_ids.size()
    assert batch_size == 1, "Only support batch size of 1 for decoder models"
    device = input_ids.device

    # 1) Select rationales in the coordinates of input_ids.
    rationale_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    if percentage == 1.0:
        rationale_mask[:] = True
    elif percentage != 0.0:
        real_length = torch.sum(attention_mask, dim=1)
        # Ensure at least one token is selected
        k = (real_length * percentage).clamp(min=1).long()
        scores = attribution_scores[0]
        if real_length[0] != len(scores):
            raise ValueError("The length of the attribution score does not match the input length")
        topk_indices = np.argsort(scores)[::-1][: int(k[0])]
        # .copy() turns the negative-stride view into an index array torch accepts
        rationale_mask[0, topk_indices.copy()] = True

    # 2) Rebuild the exact prompt the model sees and locate the input span inside it.
    full_text = tokenizer.apply_chat_template(
        fill_in_template(template, prompt.replace("[TEST EXAMPLE]", texts[0])),
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
        date_string="2025-07-01",
    )
    full_input_ids = tokenizer(full_text, return_tensors="pt").to(device).input_ids
    start = _find_subsequence_start(full_input_ids[0], input_ids[0])
    if start < 0:
        raise ValueError("The input_ids cannot be found in the full input_ids")

    # 3) Project onto the full prompt; positions outside the input span stay False.
    rationale_mask_full = torch.zeros_like(full_input_ids, dtype=torch.bool)
    rationale_mask_full[0, start:start + seq_length] = rationale_mask[0]
    return rationale_mask_full

def compute_comprehensiveness_decoder(model, tokenizer, input_ids, attention_mask, rationale_mask, input_text_mask, predicted_ids, orig_probs, mask_token_id):
    """
    Compute comprehensiveness for a decoder model that answers "Yes"/"No" (batch size 1).

    The answer is read from the next-token distribution at the last position, comparing the
    first sub-token of "Yes" (class 1) against the first sub-token of "No" (class 0).
    Comprehensiveness is the drop in the predicted answer token's probability after the
    rationale tokens are removed or masked.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)`` whose output
            exposes ``.logits`` of shape [1, seq_length, vocab_size].
        tokenizer: Tokenizer used to look up the "Yes"/"No" token ids.
        input_ids (torch.Tensor): Token IDs (same shape as ``rationale_mask``). Shape: [1, seq_length]
        attention_mask (torch.Tensor): Attention mask. Shape: [1, seq_length]
        rationale_mask (torch.Tensor): Boolean mask of rationale tokens. Shape: [1, seq_length]
        input_text_mask: Unused; kept for interface compatibility.
        predicted_ids (torch.Tensor | None): Predicted class (1 = "Yes", 0 = "No"). Shape: [1]
        orig_probs (torch.Tensor | None): Probability of the predicted answer token on the
            unperturbed input. Shape: [1]. If this or ``predicted_ids`` is None, both are
            recomputed from an unperturbed forward pass.
        mask_token_id (int | None): If None, rationale tokens are deleted from the sequence.
            Otherwise they are replaced by this id and their attention is set to 0.

    Returns:
        comprehensiveness (torch.Tensor): ``orig_probs - masked_probs``. Shape: [1]
    """
    model.eval()
    positive_token_id = tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    negative_token_id = tokenizer("No", add_special_tokens=False)["input_ids"][0]
    device = input_ids.device

    def answer_probs(ids, mask):
        # Softmax in float32: half-precision (notably bfloat16) probabilities cannot be
        # converted to numpy, and a vocabulary-wide softmax loses precision in half precision.
        logits = model(input_ids=ids, attention_mask=mask).logits[:, -1, :]
        return torch.softmax(logits.float(), dim=-1)

    with torch.no_grad():
        if orig_probs is None or predicted_ids is None:
            # Original prediction
            probs = answer_probs(input_ids, attention_mask)
            is_positive = bool(probs[0, positive_token_id] > probs[0, negative_token_id])
            predicted_ids = torch.tensor([1 if is_positive else 0], device=device)
            answer_id = positive_token_id if is_positive else negative_token_id
            orig_probs = probs[0, answer_id].reshape(1).to(device)
        predicted_token_id = positive_token_id if predicted_ids.item() == 1 else negative_token_id

        if mask_token_id is None:
            # remove the rationale tokens from the sequence entirely
            keep = ~rationale_mask[0]
            masked_input_ids = input_ids[:, keep]
            masked_attention_mask = attention_mask[:, keep]
        else:
            # replace rationale tokens and hide them from attention
            masked_input_ids = input_ids.clone()
            masked_input_ids[rationale_mask] = mask_token_id
            masked_attention_mask = attention_mask.clone()
            masked_attention_mask[rationale_mask] = 0

        masked_probs = answer_probs(masked_input_ids, masked_attention_mask)
        masked_probs = masked_probs[0, predicted_token_id].reshape(1).to(device)
        comprehensiveness = orig_probs - masked_probs

    return comprehensiveness

def compute_sufficiency_decoder(model, tokenizer, input_ids, attention_mask, rationale_mask, input_mask_task, predicted_ids, orig_probs, mask_token_id):
    """
    Compute sufficiency for a decoder model that answers "Yes"/"No" (batch size 1).

    The answer is read from the next-token distribution at the last position, comparing the
    first sub-token of "Yes" (class 1) against the first sub-token of "No" (class 0).
    Sufficiency is the drop in the predicted answer token's probability when only the
    rationales are kept. Only positions that are inside the input text (``input_mask_task``)
    and are not rationales are removed or masked; every other position is left unchanged.

    Args:
        model: Causal LM called as ``model(input_ids=..., attention_mask=...)`` whose output
            exposes ``.logits`` of shape [1, seq_length, vocab_size].
        tokenizer: Tokenizer used to look up the "Yes"/"No" token ids.
        input_ids (torch.Tensor): Token IDs (same shape as ``rationale_mask``). Shape: [1, seq_length]
        attention_mask (torch.Tensor): Attention mask. Shape: [1, seq_length]
        rationale_mask (torch.Tensor): Boolean mask of rationale tokens. Shape: [1, seq_length]
        input_mask_task (torch.Tensor): Non-zero at positions belonging to the input text.
            Shape: [1, seq_length]
        predicted_ids (torch.Tensor | None): Predicted class (1 = "Yes", 0 = "No"). Shape: [1]
        orig_probs (torch.Tensor | None): Probability of the predicted answer token on the
            unperturbed input. Shape: [1]. If this or ``predicted_ids`` is None, both are
            recomputed from an unperturbed forward pass.
        mask_token_id (int | None): If None, removed tokens are deleted from the sequence.
            Otherwise they are replaced by this id and their attention is set to 0.

    Returns:
        sufficiency (torch.Tensor): ``orig_probs - suff_probs``. Shape: [1]
    """
    model.eval()
    positive_token_id = tokenizer("Yes", add_special_tokens=False)["input_ids"][0]
    negative_token_id = tokenizer("No", add_special_tokens=False)["input_ids"][0]
    device = input_ids.device

    def answer_probs(ids, mask):
        # Softmax in float32: half-precision (notably bfloat16) probabilities cannot be
        # converted to numpy, and a vocabulary-wide softmax loses precision in half precision.
        logits = model(input_ids=ids, attention_mask=mask).logits[:, -1, :]
        return torch.softmax(logits.float(), dim=-1)

    with torch.no_grad():
        if orig_probs is None or predicted_ids is None:
            # Original prediction
            probs = answer_probs(input_ids, attention_mask)
            is_positive = bool(probs[0, positive_token_id] > probs[0, negative_token_id])
            predicted_ids = torch.tensor([1 if is_positive else 0], device=device)
            answer_id = positive_token_id if is_positive else negative_token_id
            orig_probs = probs[0, answer_id].reshape(1).to(device)
        predicted_token_id = positive_token_id if predicted_ids.item() == 1 else negative_token_id

        # only non-rationale tokens inside the input text are removed
        remove_mask = ~rationale_mask & input_mask_task.bool().to(device)
        if mask_token_id is None:
            keep = ~remove_mask[0]
            suff_input_ids = input_ids[:, keep]
            suff_attention_mask = attention_mask[:, keep]
        else:
            suff_input_ids = input_ids.clone()
            suff_input_ids[remove_mask] = mask_token_id
            suff_attention_mask = attention_mask.clone()
            suff_attention_mask[remove_mask] = 0

        suff_probs = answer_probs(suff_input_ids, suff_attention_mask)
        suff_probs = suff_probs[0, predicted_token_id].reshape(1).to(device)
        sufficiency = orig_probs - suff_probs

    return sufficiency