import argparse
import math

import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import apply_chat_template

from fair_gpt.utils import (
    LeftPaddingCompatibleDataCollatorForLM,
    per_example_loss,
    per_token_loss,
)

# Fixed user-turn instruction. gpu_computation uses it as the prompt of every
# prompt/completion example, with the biography text as the completion.
PROMPT = "Write a biography of a famous person in one paragraph."


@torch.no_grad()
def get_top_p_entropies(sorted_log_p, p):
    """Measure the entropy of the top-p (nucleus) part of sorted distributions.

    A token belongs to the nucleus when the cumulative probability of all
    tokens ranked before it is at most ``p``; the first token is therefore
    always kept, as is the token that crosses the threshold.

    Args:
        sorted_log_p: Log-probabilities sorted in descending order along the
            last dimension. Converted to float32 internally.
        p: Probability mass defining the nucleus. ``math.log(p)`` is taken, so
            ``p`` must be strictly positive.

    Returns:
        Tuple ``(threshold_index, top_entropies, entropies_ratio,
        cumulative_probs)`` where, along the last dimension of the input:

        * ``threshold_index`` is the number of tokens kept in the nucleus
          (equivalently the index of the first excluded token).
        * ``top_entropies`` is the entropy, in nats, of the nucleus after
          renormalisation.
        * ``entropies_ratio`` is ``top_entropies / log(threshold_index + 1)``
          with NaN mapped to 0 and +inf mapped to 1.
        * ``cumulative_probs`` is the running log-cumulative probability
          (same shape as the input).
    """
    sorted_log_p = sorted_log_p.float()
    cumulative_probs = torch.logcumsumexp(sorted_log_p, dim=-1)

    # Shift the "still below p" mask right by one position into a *separate*
    # tensor. An in-place `mask[..., 1:] |= mask[..., :-1]` reads and writes
    # overlapping views of one buffer, which PyTorch either rejects or
    # evaluates in an unspecified order.
    below_threshold = cumulative_probs <= math.log(p)
    keep = torch.ones_like(below_threshold)  # first token is always kept
    keep[..., 1:] = below_threshold[..., :-1]

    # Counting kept tokens equals "index of the first excluded token" and,
    # unlike argmax over the inverted mask, stays correct when nothing is
    # excluded.
    threshold_index = keep.sum(dim=-1)

    # Excluded tokens get the lowest *finite* value rather than -inf so that
    # their contribution below is 0 * finite = 0 instead of 0 * -inf = NaN.
    lowest = torch.finfo(sorted_log_p.dtype).min
    normalized_top_p = sorted_log_p.masked_fill(~keep, lowest)
    normalized_top_p = torch.nn.functional.log_softmax(normalized_top_p, dim=-1)

    top_entropies = -torch.sum(normalized_top_p.exp() * normalized_top_p, dim=-1)

    # NOTE(review): the normaliser is log(kept + 1), not log(kept), i.e. not
    # exactly the entropy of a uniform nucleus. Kept as in the original so
    # existing results stay comparable -- confirm this is intended.
    uniform_entropies = torch.log(threshold_index.float() + 1)
    entropies_ratio = top_entropies / uniform_entropies
    entropies_ratio = torch.nan_to_num(
        entropies_ratio, nan=0.0, posinf=1.0, neginf=None
    )

    return threshold_index, top_entropies, entropies_ratio, cumulative_probs


def entropy_uniform_distribution(n_classes):
    """Return ``log(k)`` for ``k = 1 .. n_classes`` as a float32 tensor.

    Element ``k - 1`` is the entropy, in nats, of a uniform distribution over
    ``k`` outcomes. The result has shape ``(n_classes,)``.
    """
    class_counts = torch.arange(1, n_classes + 1, dtype=torch.float32)
    return class_counts.log()


@torch.no_grad()
def compute_per_token_entropy(log_p):
    """Return the Shannon entropy (nats) over the last axis of ``log_p``.

    Args:
        log_p: Log-probabilities; the last dimension indexes the outcomes.

    Returns:
        Tensor with the last dimension of ``log_p`` reduced away, holding
        ``-sum(exp(log_p) * log_p)`` for each remaining position.
    """
    probs = torch.exp(log_p)
    # Contract the outcome axis of probs against log_p, keeping leading axes.
    weighted_sum = torch.einsum("...v,...v->...", probs, log_p)
    return weighted_sum.neg()


def compute_full_distribution_entropy(log_p, mask):
    """Summarise the entropy of the full next-token distribution.

    Relies on the module-level ``tokenizer`` for the number of candidate
    class counts (``len(tokenizer)``).

    Args:
        log_p: Log-probabilities; indexed below as two leading dimensions plus
            an outcome axis (presumably ``(batch, seq, vocab)`` -- confirm
            against the caller).
        mask: Tensor broadcastable against the per-token entropies; multiplied
            into the log-entropies so that masked positions become 0.

    Returns:
        Dict of CPU tensors:

        * ``"distribution_entropies"``: float32 per-token entropy.
        * ``"distribution_max_log_ratio"``: float32 maximum over ``k`` of
          ``log(H) * mask - log(log(k))`` for ``k = 2 .. len(tokenizer)``.
        * ``"distribution_max_log_ratio_idx"``: position of that maximum
          within the ``k = 2 ..`` range.
    """
    token_entropy = compute_per_token_entropy(log_p)
    log_class_counts = entropy_uniform_distribution(len(tokenizer)).to(log_p.device)

    # mask is 1 for valid tokens and 0 for padding tokens
    log_entropy = torch.log(token_entropy) * mask

    # Skip k = 1, whose uniform entropy is log(1) = 0 and has no finite log.
    log_uniform = torch.log(log_class_counts[1:].unsqueeze(0))

    # NOTE(review): log(log(k)) grows with k, so the maximum of this
    # difference appears to always sit at the first candidate (k = 2) while
    # still materialising a (batch, seq, len(tokenizer) - 1) tensor. Confirm
    # whether a closest-match (e.g. argmin of |.|) was intended.
    log_ratio = log_entropy.unsqueeze(2) - log_uniform
    best_value, best_idx = log_ratio.max(dim=-1)

    return {
        "distribution_entropies": token_entropy.to(torch.float32).cpu(),
        "distribution_max_log_ratio": best_value.to(torch.float32).cpu(),
        "distribution_max_log_ratio_idx": best_idx.cpu(),
    }


def compute_top_p_entropies(sorted_log_p, p):
    """Collect the top-p statistics for one value of ``p`` as CPU tensors.

    Thin wrapper around ``get_top_p_entropies``: it discards the cumulative
    log-probabilities and labels the remaining outputs with ``p``.

    Args:
        sorted_log_p: Descending-sorted log-probabilities, passed through
            unchanged.
        p: Probability mass threshold; also embedded in the result keys.

    Returns:
        Dict with keys ``top_{p}_entropies`` and ``top_{p}_entropies_ratio``
        (float32) and ``top_{p}_idx`` (the index tensor as returned).
    """
    threshold_idx, nucleus_entropy, nucleus_ratio, _ = get_top_p_entropies(
        sorted_log_p=sorted_log_p, p=p
    )

    def as_cpu_float(tensor):
        return tensor.to(torch.float32).cpu()

    return {
        f"top_{p}_entropies": as_cpu_float(nucleus_entropy),
        f"top_{p}_entropies_ratio": as_cpu_float(nucleus_ratio),
        f"top_{p}_idx": threshold_idx.cpu(),
    }


def _tokenize_example(title, article_content):
    """Format one biography as a prompt/completion pair and tokenize it.

    Uses the module-level ``tokenizer``. Returns a dict with ``"labels"``
    (prompt positions set to -100, completion positions set to their token
    ids) and ``"input_ids"`` (prompt ids followed by completion ids).
    """
    user_message = {"role": "user", "content": PROMPT}
    assistant_message = {
        "role": "assistant",
        "content": f"{title}\n\n{article_content}",
    }

    if getattr(tokenizer, "chat_template", None) is not None:
        formatted_text = apply_chat_template(
            {"prompt": [user_message], "completion": [assistant_message]},
            tokenizer=tokenizer,
        )
    else:
        # Fallback to simple concatenation for models without chat template.
        # A tokenizer may define no BOS/EOS token; interpolating None would
        # put the literal text "None" into the scored sequence.
        bos = tokenizer.bos_token or ""
        eos = tokenizer.eos_token or ""
        formatted_text = {
            "prompt": f"{bos}{user_message['content']}\n\n",
            "completion": f"{assistant_message['content']}{eos}",
        }

    # Special tokens are already part of the formatted text, so the tokenizer
    # must not add them again.
    prompt_tok = tokenizer(formatted_text["prompt"], add_special_tokens=False)[
        "input_ids"
    ]
    completion_tok = tokenizer(
        formatted_text["completion"], add_special_tokens=False
    )["input_ids"]

    return {
        "labels": [-100] * len(prompt_tok) + completion_tok,
        "input_ids": prompt_tok + completion_tok,
    }


def gpu_computation(batch):
    """Score a batch of biographies and return per-token statistics.

    Reads the module-level ``tokenizer``, ``collator``, ``model`` and
    ``TOP_P_LIST``.

    Args:
        batch: Mapping with parallel ``"title"`` and ``"content"`` sequences,
            one entry per article.

    Returns:
        Dict mapping each metric name to a list with one entry per example.
        For every metric except ``"loss"`` the entry is the list of per-token
        values restricted to positions whose shifted label is not -100 (the
        completion tokens). Metrics: ``tokens_loss``, ``max_z``/``min_z``
        (largest/smallest logit), the ``distribution_*`` entries, the
        ``top_{p}_*`` entries for each ``p`` in ``TOP_P_LIST``, and
        ``labels``. ``"loss"`` holds the output of ``per_example_loss``
        converted to a Python list.
    """
    chats = [
        _tokenize_example(title, article_content)
        for title, article_content in zip(batch["title"], batch["content"])
    ]
    model_inputs = collator(chats)
    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}
    dict_torch = {}

    with torch.no_grad():
        outputs = model(**model_inputs)

    # Drop the final position: it predicts a token past the end of the
    # sequence and has no label after the shift below.
    logits = outputs.logits[:, :-1].contiguous().to(torch.float32)
    # NOTE(review): per_example_loss / per_token_loss are project helpers;
    # tokens_loss is assumed to be aligned with the shifted labels -- confirm.
    loss = per_example_loss(outputs, model_inputs)[0]
    tokens_loss = per_token_loss(outputs, model_inputs)[0]
    dict_torch["tokens_loss"] = tokens_loss.to(torch.float32).cpu()

    # Release the model output before the vocabulary-sized work below.
    del outputs

    labels = model_inputs["labels"][:, 1:].contiguous()
    label_mask = labels != -100
    dict_torch["max_z"] = logits.max(dim=-1).values.cpu().to(torch.float32)
    dict_torch["min_z"] = logits.min(dim=-1).values.cpu().to(torch.float32)

    log_p = torch.nn.functional.log_softmax(logits, dim=-1)
    del logits  # keep only one vocabulary-sized tensor alive at a time

    dict_torch.update(compute_full_distribution_entropy(log_p, label_mask))

    # Sort once; every top-p threshold reuses the same sorted tensor.
    log_p = torch.sort(log_p, descending=True).values
    for p in TOP_P_LIST:
        dict_torch.update(compute_top_p_entropies(log_p, p))
    dict_torch["labels"] = labels

    # Keep only completion positions and convert to plain Python lists.
    label_mask_np = label_mask.cpu().numpy()
    dict_metrics = {
        name: [
            values[row].cpu().numpy()[keep].tolist()
            for row, keep in enumerate(label_mask_np)
        ]
        for name, values in dict_torch.items()
    }
    dict_metrics["loss"] = loss.to(torch.float32).cpu().numpy().tolist()

    return dict_metrics


if __name__ == "__main__":
    # ---- command-line arguments -------------------------------------------
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_name",
        type=str,
        required=True,
        help="Name of the model under models_dir",
    )
    cli.add_argument("--n_samples", type=int, default=None)
    cli.add_argument(
        "--top_p_list", type=float, nargs="+", default=[0.9, 0.95, 0.99, 0.999]
    )
    cli.add_argument("--batch_size", type=int, default=2)
    args = cli.parse_args()

    # Model checkpoint location and where the resulting dataset is written.
    model_path = f"models_dir/{args.model_name}"
    output_dir = f"checkpoints/check-distribution-wikibio/{args.model_name}"

    # NOTE: `model`, `tokenizer`, `collator` and `TOP_P_LIST` are read as
    # module-level globals by the functions above, so these names must stay.
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    print(f"Loaded model: {args.model_name}")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_from_disk("data_dir/wikipedia_bio")
    collator = LeftPaddingCompatibleDataCollatorForLM(tokenizer=tokenizer)
    print("Loaded dataset")

    # Optionally restrict the run to the first n_samples rows.
    if args.n_samples is not None:
        dataset = dataset.select(range(args.n_samples))
    TOP_P_LIST = args.top_p_list

    print("Starting computation")
    per_batch_metrics = []
    for start in tqdm(range(0, len(dataset), args.batch_size)):
        stop = start + args.batch_size
        per_batch_metrics.append(gpu_computation(dataset[start:stop]))
    print("Finished computation")

    # Concatenate the per-batch lists into one column per metric.
    merged = {name: [] for name in per_batch_metrics[0]}
    for batch_metrics in per_batch_metrics:
        for name, values in batch_metrics.items():
            merged[name] += values

    results = Dataset.from_dict(merged)
    results.save_to_disk(output_dir)
    print(f"Results saved to {output_dir}")
