import concurrent.futures
from typing import List, Optional

import click
import datasets as dt
import numpy as np
import pandas as pd
import torch
from evaluate import logging
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_perplexity(
    predictions,
    model_id,
    batch_size: int = 16,
    add_start_token: bool = True,
    device=None,
    max_length=None,
):
    """Compute the perplexity of each text in ``predictions`` under a causal LM.

    Texts are tokenized once without padding and then right-padded batch by
    batch. Each batch is therefore only as wide as its own longest text, and
    ``batch_size=1`` never needs a pad token.

    Args:
        predictions: List of texts to score.
        model_id: Model name or path, loaded with ``AutoModelForCausalLM`` and
            ``AutoTokenizer``.
        batch_size: Number of texts per forward pass.
        add_start_token: If True, prepend the tokenizer's BOS token to every
            text so that the first real token is also scored.
        device: Torch device string. Defaults to ``"cuda"`` when available,
            otherwise ``"cpu"``.
        max_length: Optional cap on the scored sequence length in tokens,
            including the BOS token when ``add_start_token`` is True.

    Returns:
        A dict with ``"perplexities"`` (one value per input text, in input
        order) and ``"mean_perplexity"`` (their mean; NaN if there are none).

    Raises:
        AssertionError: If ``add_start_token`` is True but the tokenizer has no
            BOS token, or if any text is too short to be scored.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(model_id)
    model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(model_id)

    if add_start_token:
        # Checked whenever BOS is prepended (not only when max_length is set);
        # otherwise a missing BOS id would only fail inside the batch loop.
        assert (
            tokenizer.bos_token is not None
        ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"

    if add_start_token and max_length:
        # Leave room for the <BOS> token that is prepended below.
        max_tokenized_len = max_length - 1
    else:
        max_tokenized_len = max_length

    # No padding here: padding is applied per batch in the loop below.
    token_ids = tokenizer(
        predictions,
        add_special_tokens=False,
        padding=False,
        truncation=bool(max_tokenized_len),
        max_length=max_tokenized_len,
        return_attention_mask=False,
    )["input_ids"]

    # Every text needs at least one predicted token. With BOS prepended, a
    # single text token suffices; without it, the first token is only context.
    if add_start_token:
        assert all(
            len(ids) >= 1 for ids in token_ids
        ), "Each input text must be at least one token long."
    else:
        assert all(
            len(ids) >= 2 for ids in token_ids
        ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

    # Pad positions get attention_mask 0 and are multiplied out of the loss,
    # so the fill id only has to be a valid vocabulary index.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 0

    ppls = []
    loss_fct = CrossEntropyLoss(reduction="none")

    for start_index in logging.tqdm(range(0, len(token_ids), batch_size)):
        batch_ids = token_ids[start_index : start_index + batch_size]
        if add_start_token:
            batch_ids = [[tokenizer.bos_token_id, *ids] for ids in batch_ids]

        # Always pad on the right so each row reads [BOS, text..., PAD...]
        # regardless of the tokenizer's configured padding side.
        width = max(len(ids) for ids in batch_ids)
        encoded_batch = torch.full(
            (len(batch_ids), width), pad_id, dtype=torch.long
        )
        attn_mask = torch.zeros((len(batch_ids), width), dtype=torch.long)
        for row, ids in enumerate(batch_ids):
            encoded_batch[row, : len(ids)] = torch.tensor(ids, dtype=torch.long)
            attn_mask[row, : len(ids)] = 1
        encoded_batch = encoded_batch.to(device)
        attn_mask = attn_mask.to(device)

        with torch.no_grad():
            out_logits = model(encoded_batch, attention_mask=attn_mask).logits

        # Logits at position t are scored against the token at t + 1.
        shift_logits = out_logits[..., :-1, :].contiguous()
        shift_labels = encoded_batch[..., 1:].contiguous()
        shift_attention_mask_batch = attn_mask[..., 1:].contiguous()

        # exp(mean token NLL per row), counting only unmasked positions.
        perplexity_batch = torch.exp(
            (
                loss_fct(shift_logits.transpose(1, 2), shift_labels)
                * shift_attention_mask_batch
            ).sum(1)
            / shift_attention_mask_batch.sum(1)
        )

        ppls += perplexity_batch.tolist()

    return {
        "perplexities": ppls,
        "mean_perplexity": np.mean(ppls) if ppls else float("nan"),
    }


def _load_preference_dataset(path: str) -> dt.Dataset:
    """Load ``path`` as a ``save_to_disk`` directory, else as a Hub dataset.

    A ``DatasetDict`` loaded from disk is flattened by concatenating all of
    its splits. From the Hub, only the ``train`` split is loaded.
    """
    try:
        ds = dt.load_from_disk(path)
    except FileNotFoundError:
        # Not a dataset directory written by save_to_disk: try the Hub.
        ds = dt.load_dataset(path, split="train")

    if isinstance(ds, dt.DatasetDict):
        ds = dt.concatenate_datasets([ds[split] for split in ds.keys()])
    return ds


def _format_completions(ds, tokenizer, on: str) -> List[str]:
    """Render each row as a chat: user ``prompt`` followed by assistant ``on``."""
    bos = tokenizer.bos_token
    completions = []
    for row in tqdm(ds, desc="Generating completions"):
        text = tokenizer.apply_chat_template(
            [
                {"role": "user", "content": row["prompt"]},
                {"role": "assistant", "content": row[on]},
            ],
            tokenize=False,
        )
        # compute_perplexity prepends BOS itself by default, so strip any BOS
        # the chat template rendered. Skip this when the tokenizer has no BOS:
        # str.replace(None, "") would raise TypeError.
        if bos:
            text = text.replace(bos, "")
        completions.append(text)
    return completions


@click.command()
@click.option("--model", required=True, type=str, help="Model name or path")
@click.option(
    "--preference-ds",
    required=True,
    type=str,
    help="Dataset name or path with preferences",
)
@click.option(
    "--on", required=True, type=str, help="Column name of the preference"
)
@click.option(
    "--batch-size",
    required=True,
    type=click.IntRange(min=1),
    help="Batch size for evaluation",
)
@click.option(
    "--n-gpus",
    default=None,
    type=click.IntRange(min=1),
    help="Number of GPUs to use",
)
@click.option(
    "--n-samples",
    default=None,
    type=click.IntRange(min=0),
    help="Number of samples to evaluate",
)
def main(
    model: str,
    preference_ds: str,
    on: str,
    batch_size: int,
    n_gpus: Optional[int] = None,
    n_samples: Optional[int] = None,
):
    """Print summary statistics of per-sample perplexity on a preference dataset.

    Each row is rendered with the model's chat template as a user turn
    (``prompt``) followed by an assistant turn (column ``--on``). The
    perplexity of each rendered text is then computed, either in-process or
    split across ``--n-gpus`` worker processes with one chunk per GPU.
    """
    ds = _load_preference_dataset(preference_ds)

    missing = {"prompt", "chosen", "rejected", on} - set(ds.column_names)
    if missing:
        raise click.UsageError(
            "Dataset is missing required column(s): "
            + ", ".join(sorted(missing))
        )

    # Rows are ordered by prompt, so --n-samples takes the first rows in
    # prompt order. Clamp so asking for more rows than exist uses them all.
    ds = ds.sort("prompt")
    if n_samples is not None:
        ds = ds.select(range(min(n_samples, len(ds))))

    tokenizer = AutoTokenizer.from_pretrained(model)
    all_completions = _format_completions(ds, tokenizer, on)
    if not all_completions:
        raise click.UsageError("No samples to evaluate.")

    if n_gpus is None:
        results = compute_perplexity(
            predictions=all_completions, model_id=model, batch_size=batch_size
        )
        print(pd.DataFrame(results["perplexities"]).describe())
        return

    # Ceiling division yields at most n_gpus non-empty chunks, so each chunk
    # gets its own GPU and the step is never 0.
    chunk_size = (len(all_completions) + n_gpus - 1) // n_gpus
    completion_chunks = [
        all_completions[i : i + chunk_size]
        for i in range(0, len(all_completions), chunk_size)
    ]

    with concurrent.futures.ProcessPoolExecutor(max_workers=n_gpus) as executor:
        futures = [
            executor.submit(
                compute_perplexity,
                predictions=chunk,
                model_id=model,
                batch_size=batch_size,
                device=f"cuda:{gpu}",
            )
            for gpu, chunk in enumerate(completion_chunks)
        ]
        for future in tqdm(
            concurrent.futures.as_completed(futures), total=len(futures)
        ):
            future.result()  # re-raise a worker's error as soon as it finishes

    # Gather in submission order so the perplexities follow the rows of ds.
    all_perplexities = [
        ppl for future in futures for ppl in future.result()["perplexities"]
    ]
    print(pd.DataFrame(all_perplexities).describe())


if __name__ == "__main__":
    # Run the click CLI only when executed as a script, not on import. Under
    # the "spawn" start method, the guard also stops ProcessPoolExecutor
    # workers that re-import this module from launching the CLI again.
    main()
