import argparse
import math
import os
from multiprocessing import set_start_method

import torch
from datasets import Dataset, load_from_disk
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

os.environ["TOKENIZERS_PARALLELISM"] = "false"


@torch.no_grad()
def get_top_p_entropies(sorted_log_p, p):
    """Entropy statistics of the top-``p`` (nucleus) part of each distribution.

    The nucleus is the shortest prefix of the descending-sorted distribution
    whose cumulative probability exceeds ``p``: every token whose cumulative
    probability is still ``<= p`` plus the token that crosses the threshold.
    The first (most likely) token is therefore always part of the nucleus.

    Args:
        sorted_log_p: Log-probabilities sorted in descending order along the
            last axis. Any number of leading axes is accepted.
        p: Probability mass threshold; ``math.log(p)`` must be defined, so
            ``p`` has to be strictly positive.

    Returns:
        A tuple ``(threshold_index, top_entropies, entropies_ratio,
        cumulative_probs)`` where

        * ``threshold_index`` is the 0-based index of the last token kept in
          the nucleus (nucleus size minus one),
        * ``top_entropies`` is the entropy (nats) of the nucleus after
          renormalisation,
        * ``entropies_ratio`` is ``top_entropies`` divided by the entropy of a
          uniform distribution over the nucleus; a single-token nucleus
          (``0 / 0``) is reported as ``0.0``,
        * ``cumulative_probs`` is the log of the cumulative probability along
          the last axis.
    """
    sorted_log_p = sorted_log_p.float()
    cumulative_probs = torch.logcumsumexp(sorted_log_p, dim=-1)

    below_threshold = cumulative_probs <= math.log(p)
    # Shift the mask right by one so the token that crosses the threshold is
    # kept too. The result is written into a fresh tensor: updating a tensor
    # in place from an overlapping view of itself is not well defined.
    # Position 0 stays True, otherwise a top-1 token with probability > p
    # would leave the nucleus empty.
    nucleus = torch.ones_like(below_threshold)
    nucleus[..., 1:] = below_threshold[..., 1:] | below_threshold[..., :-1]

    nucleus_size = nucleus.sum(dim=-1)
    threshold_index = nucleus_size - 1

    masked_log_p = sorted_log_p.masked_fill(~nucleus, float("-inf"))
    normalized_top_p = torch.nn.functional.log_softmax(masked_log_p, dim=-1)
    top_p_probs = normalized_top_p.exp()
    # Convention 0 * log(0) = 0: zero-probability entries (including every
    # token outside the nucleus) must not contribute NaN to the sum.
    plogp = (top_p_probs * normalized_top_p).masked_fill(top_p_probs == 0, 0.0)
    top_entropies = -plogp.sum(dim=-1)

    # Maximum entropy attainable with `nucleus_size` tokens.
    uniform_entropies = torch.log(nucleus_size.float())
    entropies_ratio = top_entropies / uniform_entropies
    entropies_ratio = torch.nan_to_num(
        entropies_ratio, nan=0.0, posinf=1.0, neginf=None
    )

    return threshold_index, top_entropies, entropies_ratio, cumulative_probs


def entropy_uniform_distribution(n_classes):
    """Return ``log(k)`` for ``k = 1 .. n_classes`` as a float32 tensor.

    Entry ``k - 1`` is the entropy (in nats) of a uniform distribution over
    ``k`` outcomes.
    """
    support_sizes = torch.arange(1, n_classes + 1, dtype=torch.float32)
    return support_sizes.log()


@torch.no_grad()
def compute_per_token_entropy(log_p):
    """Shannon entropy (nats) over the last axis of log-probabilities ``log_p``.

    The result has the shape of ``log_p`` without its last axis.
    """
    # H = -sum_i p_i * log(p_i), contracted over the final axis only.
    return -torch.einsum("...i,...i->...", log_p.exp(), log_p)


def compute_full_distribution_entropy(log_p, mask):
    """Entropy of the full distribution and its best log-ratio to a uniform one.

    Args:
        log_p: Log-probabilities with the class (vocabulary) axis last.
        mask: Tensor broadcastable to ``log_p.shape[:-1]``; 1 for valid tokens
            and 0 for padding tokens. Masked positions get a log-entropy of 0.

    Returns:
        A dict of CPU tensors shaped like ``log_p`` without its last axis:

        * ``distribution_entropies``: per-token entropy in nats,
        * ``distribution_max_log_ratio``: ``max_k [log H - log log k]`` for
          ``k = 2 .. n_classes``, with ``H`` the (masked) token entropy,
        * ``distribution_max_log_ratio_idx``: index into ``k = 2 .. n_classes``
          at which that maximum is attained.
    """
    per_token_entropy = compute_per_token_entropy(log_p)

    # The number of classes is read from the distribution itself, so the
    # function does not depend on a module-level `model` being defined.
    n_classes = log_p.shape[-1]
    uniform_entropies = entropy_uniform_distribution(n_classes).to(log_p.device)
    # k = 1 is skipped: its entropy is log(1) = 0, whose log is -inf.
    log_uniform_entropies = torch.log(uniform_entropies[1:])

    masked_log_entropies = (
        torch.log(per_token_entropy) * mask
    )  # mask is 1 for valid tokens and 0 for padding tokens

    # max_k (a - c_k) == a - min_k c_k, and c_k does not depend on the token
    # position. Reducing the reference vector once avoids materialising a
    # tensor with an extra (n_classes - 1)-sized axis per token.
    # NOTE(review): log(log(k)) is increasing in k, so the minimum is always
    # at index 0 (k = 2) and the reported index is constant -- confirm this
    # metric is what was intended.
    reference = log_uniform_entropies.min(dim=-1)
    max_entropy_ratio_value = masked_log_entropies - reference.values
    max_entropy_ratio_idx = reference.indices.expand(
        masked_log_entropies.shape
    ).clone()

    dict_metrics = {
        "distribution_entropies": per_token_entropy.to(torch.float32).cpu(),
        "distribution_max_log_ratio": max_entropy_ratio_value.to(torch.float32).cpu(),
        "distribution_max_log_ratio_idx": max_entropy_ratio_idx.cpu(),
    }

    return dict_metrics


def compute_top_p_entropies(sorted_log_p, p):
    """Collect the top-``p`` metrics of ``get_top_p_entropies`` as CPU tensors.

    Args:
        sorted_log_p: Log-probabilities forwarded to ``get_top_p_entropies``.
        p: Probability mass threshold; it is also embedded in the result keys.

    Returns:
        A dict with the keys ``top_{p}_entropies``, ``top_{p}_entropies_ratio``
        (both float32) and ``top_{p}_idx``.
    """
    # The fourth return value (cumulative log-probabilities) is not reported.
    nucleus_idx, nucleus_entropy, nucleus_ratio, _ = get_top_p_entropies(
        sorted_log_p=sorted_log_p, p=p
    )
    return {
        f"top_{p}_entropies": nucleus_entropy.to(torch.float32).cpu(),
        f"top_{p}_entropies_ratio": nucleus_ratio.to(torch.float32).cpu(),
        f"top_{p}_idx": nucleus_idx.cpu(),
    }


def gpu_computation(batch):
    """Run the model on a batch and gather entropy metrics for completion tokens.

    Relies on the module-level ``model``, ``tokenizer`` and ``TOP_P_LIST``.

    Args:
        batch: Mapping with the parallel lists ``"description"`` (user prompt)
            and ``"solution"`` (assistant answer).

    Returns:
        A dict mapping each metric name to one list per sample. Every inner
        list holds the metric values at the positions whose (shifted) label is
        not ``-100``, i.e. the positions the completion-only collator keeps.
    """
    chats = []
    for description, solution in zip(batch["description"], batch["solution"]):
        messages = [
            {"role": "user", "content": description},
            {"role": "assistant", "content": solution},
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            continue_final_message=True,
        )
        chats.append(text + tokenizer.eos_token)

    # The chat template is expected to emit its own special tokens (BOS
    # included), so the tokenizer must not add them a second time.
    tokenized_texts = tokenizer(chats, add_special_tokens=False)["input_ids"]

    collator = DataCollatorForCompletionOnlyLM(
        response_template="<|start_header_id|>assistant<|end_header_id|>\n\n",
        tokenizer=tokenizer,
    )
    # NOTE(review): the collator receives bare id lists, so presumably no
    # attention_mask is produced; that is harmless only with right padding
    # under causal attention -- confirm tokenizer.padding_side.
    model_inputs = collator(tokenized_texts, return_tensors="pt")
    model_inputs = {k: v.to(model.device) for k, v in model_inputs.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)
    # Position t predicts token t + 1: drop the last logit and the first label.
    log_p = outputs.logits[:, :-1].contiguous().to(torch.float32)
    del outputs  # release the full logits tensor as early as possible
    labels = model_inputs["labels"][:, 1:].contiguous()
    label_mask = labels != -100

    dict_torch = {
        # Extremes of the raw logits, taken before normalisation.
        "max_z": log_p.max(axis=-1).values.cpu().to(torch.float32),
        "min_z": log_p.min(axis=-1).values.cpu().to(torch.float32),
    }

    log_p = torch.nn.functional.log_softmax(log_p, dim=-1)
    dict_torch.update(compute_full_distribution_entropy(log_p, label_mask))

    log_p = torch.sort(log_p, descending=True).values
    for p in TOP_P_LIST:
        dict_torch.update(compute_top_p_entropies(log_p, p))
    dict_torch["labels"] = labels

    # Keep only the unmasked positions of every sample, as plain Python lists.
    label_mask_np = label_mask.cpu().numpy()
    dict_metrics = {}
    for key, values in dict_torch.items():
        values_np = values.cpu().numpy()
        dict_metrics[key] = [
            row[row_mask].tolist() for row, row_mask in zip(values_np, label_mask_np)
        ]

    return dict_metrics


if __name__ == "__main__":
    # Script entry point: load the model and dataset, compute the metrics
    # batch by batch and save them as a new dataset in --output_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_samples", type=int, default=1000)
    parser.add_argument(
        "--top_p_list", type=float, nargs="+", default=[0.9, 0.95, 0.99, 0.999]
    )
    parser.add_argument(
        "--output_dir", default="checkpoints/check-distribution-wp/llama3.1"
    )
    parser.add_argument("--batch_size", type=int, default=2)
    args = parser.parse_args()

    # Reject unusable arguments before the expensive model load.
    if args.n_samples < 1:
        parser.error("--n_samples must be a positive integer")
    if args.batch_size < 1:
        parser.error("--batch_size must be a positive integer")
    if any(not 0.0 < p <= 1.0 for p in args.top_p_list):
        parser.error("--top_p_list values must lie in (0, 1]")

    set_start_method("spawn")

    # `model`, `tokenizer` and `TOP_P_LIST` are module-level names read by the
    # functions defined in this file.
    model = AutoModelForCausalLM.from_pretrained(
        "models_dir/llama3.1",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    print("Loaded model")

    tokenizer = AutoTokenizer.from_pretrained("models_dir/llama3.1")
    tokenizer.pad_token = tokenizer.eos_token

    dataset = load_from_disk("data_dir/writingPrompts-dist")
    print("Loaded dataset")
    # Selecting more rows than the dataset holds raises, so clamp the request.
    n_samples = min(args.n_samples, len(dataset))
    if n_samples < args.n_samples:
        print(f"Dataset only has {n_samples} samples; using all of them")
    dataset = dataset.select(range(n_samples))
    TOP_P_LIST = args.top_p_list
    print("Starting computation")
    list_dicts = []
    for start in tqdm(range(0, len(dataset), args.batch_size)):
        batch = dataset[start : start + args.batch_size]
        list_dicts.append(gpu_computation(batch))

    print("Finished computation")

    if not list_dicts:
        raise SystemExit("No samples were processed; nothing to save")

    # Concatenate the per-batch lists of every metric into one column.
    result_dict = {k: [] for k in list_dicts[0]}
    for batch_metrics in list_dicts:
        for k, values in batch_metrics.items():
            result_dict[k].extend(values)

    dataset = Dataset.from_dict(result_dict)
    dataset.save_to_disk(args.output_dir)
    # dataset = dataset.map(
    #     gpu_computation,
    #     batched=True,
    #     with_rank=True,
    #     num_proc=torch.cuda.device_count(),
    #     batch_size=64,
    #     keep_in_memory=True,
    #     load_from_cache_file=False,
    # )
