import logging
import torch
import numpy as np
from torch.utils.data import DataLoader

from tqdm import tqdm
from collections import defaultdict
import random

from evals.metrics.utils import (
    aggregate_to_1D,
    evaluate_probability,
    eval_text_similarity,
    run_batchwise_evals,
    tokenwise_vocab_logprobs,
    eval_accuracy_api,
)
from evals.metrics.base import unlearning_metric

logger = logging.getLogger("evaluator")

# Suppress the INFO messages that rouge_scorer emits through absl's logger
# while ROUGE is being computed; only WARNING and above get through.
logging.getLogger("absl").setLevel(level=logging.WARNING)


@unlearning_metric(name="probability")
def probability(model, **kwargs):
    """Score every data point with ``evaluate_probability`` and average the results.

    Keyword args read: ``data``, ``collators`` (used as the DataLoader's
    ``collate_fn``) and ``batch_size``.

    Returns:
        dict: ``agg_value`` is the mean of all non-``None`` ``"prob"`` entries
        (after ``aggregate_to_1D``); ``value_by_index`` is the per-index result
        mapping produced by ``run_batchwise_evals``, returned unfiltered.
    """
    data, collator, batch_size = (
        kwargs["data"],
        kwargs["collators"],
        kwargs["batch_size"],
    )
    loader = DataLoader(data, batch_size=batch_size, collate_fn=collator)

    per_index = run_batchwise_evals(
        model, loader, evaluate_probability, {}, "Calculating loss"
    )

    # Instances that could not be scored carry prob=None; leave them out of
    # the average but keep them in the per-index output.
    all_probs = [item["prob"] for item in per_index.values()]
    kept_probs = [p for p in all_probs if p is not None]
    mean_prob = np.mean(aggregate_to_1D(np.array(kept_probs)))
    return {"agg_value": mean_prob, "value_by_index": per_index}


@unlearning_metric(name="probability_w_options")
def probability_w_options(model, **kwargs):
    """Normalize correct-answer probabilities against wrong-answer probabilities.

    For open-ended datasets, each index gets
    ``p_correct / (p_correct + sum(p_wrong))``, where the wrong-answer
    probabilities are summed over every axis after the first (i.e. over all
    wrong options for that index).

    Keyword args read:
        pre_compute: mapping with ``"correct"`` and ``"wrong"`` entries, each a
            result dict whose ``"value_by_index"`` maps data indices to dicts
            carrying a ``"prob"`` value.

    Indices whose entry is ``None``, or whose ``"prob"`` is (or contains)
    ``None`` on either side, are skipped entirely.

    Returns:
        dict with ``agg_value`` (mean normalized probability over the kept
        indices) and ``value_by_index`` mapping each kept index to
        ``{"prob": normalized_prob}``.
    """
    correct_answer_results = kwargs["pre_compute"]["correct"]["value_by_index"]
    wrong_answer_results = kwargs["pre_compute"]["wrong"]["value_by_index"]

    correct_indices = list(correct_answer_results.keys())
    wrong_indices = list(wrong_answer_results.keys())
    assert correct_indices == wrong_indices

    def _is_complete(value):
        # False for None or for a (possibly nested) list/tuple holding a None.
        if value is None:
            return False
        if isinstance(value, (list, tuple)):
            return all(_is_complete(item) for item in value)
        return True

    # Skip indices with missing results on either side. Per-instance scorers
    # report unscorable instances as {"prob": None}, so the inner value must
    # be checked as well as the entry itself.
    filtered_indices = [
        idx
        for idx in correct_indices
        if correct_answer_results[idx] is not None
        and wrong_answer_results[idx] is not None
        and _is_complete(correct_answer_results[idx]["prob"])
        and _is_complete(wrong_answer_results[idx]["prob"])
    ]
    correct = np.array(
        [correct_answer_results[idx]["prob"] for idx in filtered_indices]
    )
    all_wrong = np.array(
        [wrong_answer_results[idx]["prob"] for idx in filtered_indices]
    )
    # Sum over every axis but the first (the per-index axis); with a 1-D
    # array this is an empty axis tuple, i.e. no reduction.
    wrong = np.sum(all_wrong, axis=tuple(range(1, all_wrong.ndim)))
    probs = correct / (correct + wrong + 1e-10)

    # `probs` is aligned with `filtered_indices`, not `correct_indices`:
    # pairing with the unfiltered list would mislabel scores after any skip.
    value_by_index = {
        idx: {"prob": val} for idx, val in zip(filtered_indices, probs)
    }
    return {"agg_value": np.mean(probs), "value_by_index": value_by_index}


@unlearning_metric(name="rouge")
def rouge(model, **kwargs):
    """Generate answers, score them with ``eval_text_similarity``, and average one ROUGE variant.

    Keyword args read: ``tokenizer``, ``data``, ``collators`` (DataLoader
    ``collate_fn``), ``batch_size``, ``generation_args`` and ``rouge_type``.
    The per-index score stored under ``rouge_type`` is averaged; entries whose
    score is ``None`` are skipped.

    Returns:
        dict with ``agg_value`` (mean of the selected score) and
        ``value_by_index`` (all per-index results from ``run_batchwise_evals``).
    """
    tokenizer, data, collator, batch_size, generation_args = (
        kwargs["tokenizer"],
        kwargs["data"],
        kwargs["collators"],
        kwargs["batch_size"],
        kwargs["generation_args"],
    )

    results = run_batchwise_evals(
        model,
        DataLoader(data, batch_size=batch_size, collate_fn=collator),
        eval_text_similarity,
        {"tokenizer": tokenizer, "generation_args": generation_args},
        "Calculating text similarity",
    )

    selected = []
    for entry in results.values():
        score = entry[kwargs["rouge_type"]]
        if score is not None:
            selected.append(score)

    mean_score = np.mean(aggregate_to_1D(np.array(selected)))
    return {
        "agg_value": mean_score,
        "value_by_index": results,
    }


@unlearning_metric(name="truth_ratio")
def truth_ratio(model, **kwargs):
    """Compute per-index truth ratios (wrong / correct likelihood) and aggregate them.

    For each index, the precomputed average losses of the correct and wrong
    answers are reduced with ``aggregate_to_1D``, turned into likelihoods via
    ``exp(-avg_loss)``, and combined as ``wrong_prob / correct_prob``.

    Keyword args read:
        aggregator: ``"closer_to_1_better"`` (forget data: best when the ratio
            is near 1) or ``"true_better"`` (non-forget data: best when the
            ratio is low).
        pre_compute: mapping with ``"correct"`` and ``"wrong"`` entries, each a
            result dict whose ``"value_by_index"`` maps data indices to dicts
            carrying an ``"avg_loss"``.

    Indices whose entry is ``None``, or whose ``"avg_loss"`` is (or contains)
    ``None`` on either side, are skipped entirely.

    Returns:
        dict with ``agg_value`` (aggregated truth ratio over kept indices) and
        ``value_by_index`` mapping each kept index to ``{"score": ratio}``.

    Raises:
        ValueError: if ``aggregator`` is not one of the supported names.
    """

    # Forget data: It is better if false and true are equally likely,
    # i.e., tr=false/true is closest to 1.
    def closer_to_1_better(arr):
        return np.mean(np.minimum(arr, 1 / (arr + 1e-10)))

    # Non-forget data: It is better if tr=false/true is lower, i.e.,
    # 1-tr is higher.
    def true_better(arr):
        return np.mean(np.maximum(0, 1 - arr))

    if kwargs["aggregator"] == "closer_to_1_better":
        aggregator = closer_to_1_better
    elif kwargs["aggregator"] == "true_better":
        aggregator = true_better
    else:
        raise ValueError(f"Invalid truth ratio aggregator: {kwargs['aggregator']}")

    correct_answer_results = kwargs["pre_compute"]["correct"]["value_by_index"]
    wrong_answer_results = kwargs["pre_compute"]["wrong"]["value_by_index"]

    correct_indices = list(correct_answer_results.keys())
    wrong_indices = list(wrong_answer_results.keys())
    assert correct_indices == wrong_indices

    def _is_complete(value):
        # False for None or for a (possibly nested) list/tuple holding a None.
        if value is None:
            return False
        if isinstance(value, (list, tuple)):
            return all(_is_complete(item) for item in value)
        return True

    # Skip indices with missing results on either side, including entries that
    # exist but carry a None loss (presumably instances with no valid target
    # tokens — TODO confirm against evaluate_probability).
    filtered_indices = [
        idx
        for idx in correct_indices
        if correct_answer_results[idx] is not None
        and wrong_answer_results[idx] is not None
        and _is_complete(correct_answer_results[idx]["avg_loss"])
        and _is_complete(wrong_answer_results[idx]["avg_loss"])
    ]
    correct_avg_losses = aggregate_to_1D(
        np.array([correct_answer_results[idx]["avg_loss"] for idx in filtered_indices])
    )
    wrong_avg_losses = aggregate_to_1D(
        np.array([wrong_answer_results[idx]["avg_loss"] for idx in filtered_indices])
    )

    correct_prob = np.exp(-correct_avg_losses)
    wrong_prob = np.exp(-wrong_avg_losses)

    truth_ratios = wrong_prob / (correct_prob + 1e-10)
    # `truth_ratios` is aligned with `filtered_indices`, not `correct_indices`:
    # pairing with the unfiltered list would mislabel scores after any skip.
    value_by_index = {
        idx: {"score": val} for idx, val in zip(filtered_indices, truth_ratios)
    }
    return {"agg_value": aggregator(truth_ratios), "value_by_index": value_by_index}


@unlearning_metric(name="exact_memorization")
def exact_memorization(model, **kwargs):
    """Compute exact memorization (EM) per data point and report the mean.

    EM for one instance is the fraction of target tokens for which the argmax
    of the per-token log-probabilities (from ``tokenwise_vocab_logprobs``)
    equals the label. Instances with no target tokens get ``None`` and are
    excluded from the mean.

    Keyword args read: ``data``, ``collators`` (DataLoader ``collate_fn``) and
    ``batch_size``.

    Returns:
        dict with ``agg_value`` (mean EM) and ``value_by_index`` (per-index
        ``{"score": em}`` results from ``run_batchwise_evals``).
    """
    data = kwargs["data"]
    collator = kwargs["collators"]
    batch_size = kwargs["batch_size"]
    loader = DataLoader(data, batch_size=batch_size, collate_fn=collator)

    def _token_match_rates(model, batch):
        log_probs_batch, labels_batch = tokenwise_vocab_logprobs(
            model, batch, grad=False, return_labels=True
        )
        rates = []
        for log_probs, labels in zip(log_probs_batch, labels_batch):
            n_targets = len(labels)
            if not n_targets:
                # Tokenization can occasionally leave no valid target tokens
                # (see preprocess_chat_instance()); EM is undefined then, so
                # the instance is recorded as None.
                logger.warning(
                    "EM score for an instance is marked None, due to "
                    "tokenization issues that resulted in no valid target tokens."
                )
                rates.append({"score": None})
                continue
            matches = torch.argmax(log_probs, dim=-1) == labels
            rates.append({"score": (matches.sum() / n_targets).item()})
        return rates

    per_index = run_batchwise_evals(
        model, loader, _token_match_rates, {}, "Calculating EM"
    )
    all_scores = [item["score"] for item in per_index.values()]
    em_values = aggregate_to_1D(np.array([s for s in all_scores if s is not None]))
    return {"agg_value": np.mean(em_values), "value_by_index": per_index}


@unlearning_metric(name="extraction_strength")
def extraction_strength(model, **kwargs):
    """Compute extraction strength (ES) per data point and report the mean.

    For an instance with ``n`` target tokens, let ``k`` be the smallest
    position such that the argmax predictions (from the per-token
    log-probabilities of ``tokenwise_vocab_logprobs``) equal the labels at
    every position from ``k`` to the end. ES is ``1 - k / n``: 1 when every
    token matches, and 0 when even the final token is mispredicted
    (``k == n``, the empty suffix). Instances with no target tokens get
    ``None`` and are excluded from the mean.

    Keyword args read: ``data``, ``collators`` (DataLoader ``collate_fn``) and
    ``batch_size``.

    Returns:
        dict with ``agg_value`` (mean ES) and ``value_by_index`` (per-index
        ``{"score": es}`` results from ``run_batchwise_evals``).
    """
    data = kwargs["data"]
    collator = kwargs["collators"]
    batch_size = kwargs["batch_size"]
    dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collator)

    def _extraction_strength(model, batch):
        log_probs_batch, labels_batch = tokenwise_vocab_logprobs(
            model, batch, grad=False, return_labels=True
        )
        es_batch = []
        for log_probs, labels in zip(log_probs_batch, labels_batch):
            valid_len = len(labels)
            if valid_len == 0:
                # Rarely, tokenization can result in a mismatch with no valid target
                # tokens for loss computation (see preprocess_chat_instance() for
                # reference). Since this condition makes no sense in terms of
                # computing ES, we set ES=None (as the warning states), which
                # excludes the instance from the aggregate, matching EM.
                logger.warning(
                    "ES score for an instance is marked None, due to "
                    "tokenization issues that resulted in no valid target tokens."
                )
                es_batch.append({"score": None})
                continue
            preds = torch.argmax(log_probs, dim=-1)
            # The longest matching suffix starts right after the last
            # mispredicted position; with no mismatch it is the whole sequence
            # (k=0). If the last token is wrong, k == valid_len and ES is 0.
            mismatches = torch.nonzero(preds != labels, as_tuple=True)[0]
            k = int(mismatches[-1]) + 1 if mismatches.numel() > 0 else 0
            es_batch.append({"score": 1 - (k / valid_len)})
        return es_batch

    fun_args = {}
    scores_by_index = run_batchwise_evals(
        model, dataloader, _extraction_strength, fun_args, "Calculating ES"
    )
    es_values = np.array(
        [
            evals["score"]
            for evals in scores_by_index.values()
            if evals["score"] is not None
        ]
    )
    es_values = aggregate_to_1D(es_values)
    return {"agg_value": np.mean(es_values), "value_by_index": scores_by_index}


# Seed for the random.Random instance that `accuracy` uses to pick its
# validation subset of retain-model generations when `is_validation` is set.
TOFU_VALIDATION_SEED = 53
# Fraction of retain-model generations `accuracy` samples in validation mode.
VALIDATION_SPLIT = 0.15

@unlearning_metric(name="accuracy")
def accuracy(model, **kwargs):
    """Score answers to original and perturbed questions against retain-model generations.

    Keyword args read: ``tokenizer``, ``data``, ``collators`` (DataLoader
    ``collate_fn``), ``batch_size``, ``generation_args``, ``reference_logs``
    and, optionally, ``is_validation`` (default ``False``).

    ``data`` must hold exactly two datasets. The one whose key contains
    ``"pert"`` is the perturbed variant; the other is the original. They are
    batched in lockstep, and each batch pair is passed to ``eval_accuracy_api``
    together with the matching slice of retain-model generations. Those
    generations are taken, in insertion order, from
    ``reference_logs["retain_model_logs"]["retain"]["value_by_index"]``.

    With ``is_validation``, a seeded random subset of the generations is used
    (``VALIDATION_SPLIT`` of them, at least 20 when that many exist).
    ``data`` must have the same length as that subset (asserted below).
    Presumably the caller subsets it identically — TODO confirm.

    Returns:
        dict with ``agg_is_not_truth``, ``agg_is_retain`` and
        ``agg_is_retain_or_truth``. Each is the mean of the corresponding
        per-index value returned by ``eval_accuracy_api`` (after
        ``aggregate_to_1D``). Also includes ``value_by_index``, the
        per-index results.
    """
    tokenizer = kwargs["tokenizer"]
    data = kwargs["data"]
    collator = kwargs["collators"]
    batch_size = kwargs["batch_size"]
    generation_args = kwargs["generation_args"]
    retain_model_logs = kwargs["reference_logs"]["retain_model_logs"]["retain"][
        "value_by_index"
    ]
    # NOTE(review): generations are consumed sequentially below, so this relies
    # on the retain logs' insertion order matching the order of `data`.
    retain_model_gens = [v["generation"] for v in retain_model_logs.values()]
    is_validation = kwargs.get("is_validation", False)

    if is_validation:
        population = len(retain_model_gens)
        # Cap at the population size: random.sample raises ValueError when
        # asked for more items than exist (e.g. fewer than 20 generations).
        n_samples = min(population, max(20, int(VALIDATION_SPLIT * population)))
        rng = random.Random(TOFU_VALIDATION_SEED)
        selected_indices = rng.sample(range(population), n_samples)
        retain_model_gens = [retain_model_gens[idx] for idx in selected_indices]
        print(f"Accuracy metric: Using {len(retain_model_gens)} samples for validation.")

    assert len(data) == 2, "Data must contain exactly two elements."
    keys = list(data.keys())
    pert_keys = [key for key in keys if "pert" in key]
    # Exactly one key may contain "pert"; otherwise the original/perturbed
    # split below is ambiguous.
    assert len(pert_keys) == 1, "Exactly one key must include 'pert'."
    orig_key = next(key for key in keys if key != pert_keys[0])

    data_pert = data[pert_keys[0]]
    data = data[orig_key]

    print(f"Accuracy metric: Using {len(data)} original and {len(data_pert)} perturbed data points.")

    assert len(data) == len(data_pert), "Data and perturbed data must have the same length."
    assert len(data) == len(retain_model_gens), "Data and retain model gens must have the same length."

    dataloader = DataLoader(data, batch_size=batch_size, collate_fn=collator)
    dataloader_pert = DataLoader(data_pert, batch_size=batch_size, collate_fn=collator)

    fun_args = {"tokenizer": tokenizer, "generation_args": generation_args}

    scores_by_index = defaultdict(dict)
    st_idx = 0  # offset of the current batch within retain_model_gens
    for batch, batch_pert in tqdm(zip(dataloader, dataloader_pert), total=len(dataloader)):
        assert "input_ids" in batch
        # A single perturbed batch is wrapped into the {option_key: batch}
        # layout so both shapes are handled uniformly below.
        if "input_ids" in batch_pert:
            batch_pert = {"0": batch_pert}

        data_indices = batch.pop("index").cpu().numpy().tolist()  # data item indices
        # Perturbed batches carry their own "index" entries; they are dropped
        # (not cross-checked) before evaluation.
        for pert_batch in batch_pert.values():
            pert_batch.pop("index")

        end_idx = st_idx + len(data_indices)
        batch_evals = eval_accuracy_api(
            model=model,
            batch=batch,
            batch_pert=batch_pert,
            retain_model_gens=retain_model_gens[st_idx:end_idx],
            **fun_args,
        )
        scores_by_index |= dict(zip(data_indices, batch_evals))
        st_idx = end_idx

    is_not_truth_scores = np.array(
        [result["is_not_truth"] for result in scores_by_index.values()]
    )
    is_retain_scores = np.array(
        [result["is_retain"] for result in scores_by_index.values()]
    )
    is_retain_or_truth_scores = np.array(
        [result["is_retain_or_truth"] for result in scores_by_index.values()]
    )
    is_not_truth_score = aggregate_to_1D(is_not_truth_scores)
    is_retain_score = aggregate_to_1D(is_retain_scores)
    is_retain_or_truth_score = aggregate_to_1D(is_retain_or_truth_scores)
    return {
        "agg_is_not_truth": np.mean(is_not_truth_score),
        "agg_is_retain": np.mean(is_retain_score),
        "agg_is_retain_or_truth": np.mean(is_retain_or_truth_score),
        "value_by_index": scores_by_index,
    }
