import os
import time
from dataclasses import dataclass, field
from typing import Literal

import pandas as pd
import torch
from accelerate import Accelerator
from peft import PeftModel  # type: ignore
from torch import Tensor
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorWithPadding,
    HfArgumentParser,
)
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

from datasets import Dataset
from yang.load_eval_datasets import load_eval_dataset


def calculate_sequence_logprobs(logits: Tensor, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
    """Return the mean per-token log-probability of each sequence.

    Token ``t`` is scored with the logits at position ``t - 1`` (next-token
    shift). Only target positions whose ``attention_mask`` entry is non-zero
    contribute, and the sum is divided by the number of such positions, so
    the score is length-normalised.

    Args:
        logits: Model outputs indexed as ``[..., seq, vocab]``.
        input_ids: Token ids aligned with ``logits``, indexed as ``[..., seq]``.
        attention_mask: Same shape as ``input_ids``; non-zero entries mark
            positions that are scored.

    Returns:
        A tensor of shape ``input_ids.shape[:-1]`` with one score per sequence.
        A sequence with no scored target tokens gets 0.0 rather than NaN.
    """
    shift_logits = logits[..., :-1, :]
    shift_labels = input_ids[..., 1:]
    shift_mask = attention_mask[..., 1:].bool()

    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    token_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze(-1)

    # masked_fill instead of multiplying by the mask: a -inf log-prob at a
    # padded position times 0 is NaN and would poison the whole sum.
    token_log_probs = token_log_probs.masked_fill(~shift_mask, 0.0)
    # Clamp so an all-padding sequence returns 0 instead of 0 / 0 = NaN.
    num_tokens = shift_mask.sum(dim=-1).clamp(min=1)
    return token_log_probs.sum(dim=-1) / num_tokens


@dataclass
class ScriptArgs:
    """Command-line options for the preference-accuracy evaluation script."""

    # Path or hub name of the model to evaluate; required (no default).
    model: str = field(metadata={"help": "Path or name of the model to evaluate"})
    # Number of preference pairs per batch on each device.
    per_device_eval_batch_size: int = 8
    # Written to tokenizer.model_max_length before evaluation.
    max_length: int = 1024
    # Output directory for logs; not read by main() at present.
    log_dir: str = "./eval_ppo"
    # Forwarded as attn_implementation to AutoModelForCausalLM.from_pretrained.
    attn_implementation: str = "flash_attention_2"


def _load_model_and_tokenizer(config: ScriptArgs, device: int):
    """Load the tokenizer and fp16 causal LM named by ``config.model``; return ``(model, tokenizer)``."""
    tokenizer = AutoTokenizer.from_pretrained(
        config.model,
        padding_side="left",
        use_fast=True,
        clean_up_tokenization_spaces=True,
        trust_remote_code=False,
    )
    # use_fast=False? (Yang et al. 2024)
    tokenizer.model_max_length = config.max_length
    if tokenizer.pad_token is None:
        print("pad_token is None, adding '[PAD]' as a new pad token")
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        print("chat_template is None, replacing chat template with SIMPLE_CHAT_TEMPLATE")
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

    model = AutoModelForCausalLM.from_pretrained(
        config.model,
        device_map=device,
        use_cache=False,
        attn_implementation=config.attn_implementation,
        torch_dtype=torch.float16,
    )
    # Grow the embedding matrix in case a '[PAD]' token was added above.
    model.resize_token_embeddings(len(tokenizer))
    model.config.pad_token_id = tokenizer.pad_token_id

    # To evaluate a PEFT adapter checkpoint, wrap here with
    # PeftModel.from_pretrained(model, config.model) before merging.
    if hasattr(model, "merge_and_unload"):
        model = model.merge_and_unload()

    model.eval()
    return model, tokenizer


def _make_pair_collate_fn(tokenizer):
    """Build a collate_fn that pads chosen and rejected sequences independently."""
    # The collator holds no per-batch state, so build it once, not per batch.
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")

    def collate(batch: list[dict[str, Tensor]]):
        chosen_batch = data_collator(
            {  # type: ignore
                "input_ids": [item["input_ids"] for item in batch],
                "attention_mask": [item["attention_mask_chosen"] for item in batch],
            }
        )
        rejected_batch = data_collator(
            {  # type: ignore
                "input_ids": [item["input_ids_rejected"] for item in batch],
                "attention_mask": [item["attention_mask_rejected"] for item in batch],
            }
        )
        return {
            "source_id": torch.stack([item["source_id"] for item in batch]),
            "input_ids": chosen_batch["input_ids"],
            "attention_mask_chosen": chosen_batch["attention_mask"],
            "input_ids_rejected": rejected_batch["input_ids"],
            "attention_mask_rejected": rejected_batch["attention_mask"],
        }

    return collate


def _evaluate_task(model, data_loader, accelerator: Accelerator):
    """Score every chosen/rejected pair in ``data_loader``.

    Returns ``(accuracy, results)``. ``accuracy`` is the fraction of pairs whose
    chosen sequence has the higher mean log-probability (NaN if the loader is
    empty). ``results`` is a list of per-batch dicts of numpy arrays, gathered
    across processes so the rows stay aligned.
    """
    total_correct = 0.0
    total_samples = 0
    all_results = []

    progress_bar = tqdm(data_loader, desc="Eval", disable=not accelerator.is_local_main_process)
    with torch.no_grad():
        for batch in progress_bar:
            chosen = batch["input_ids"].to(model.device)
            chosen_mask = batch["attention_mask_chosen"].to(model.device)
            rejected = batch["input_ids_rejected"].to(model.device)
            rejected_mask = batch["attention_mask_rejected"].to(model.device)

            chosen_logits = model(input_ids=chosen, attention_mask=chosen_mask).logits
            rejected_logits = model(input_ids=rejected, attention_mask=rejected_mask).logits
            chosen_log_probs = calculate_sequence_logprobs(chosen_logits, chosen, chosen_mask)
            rejected_log_probs = calculate_sequence_logprobs(rejected_logits, rejected, rejected_mask)
            correct = (chosen_log_probs > rejected_log_probs).float()

            # Gather all per-example tensors together so they stay row-aligned;
            # gather_for_metrics also drops the padding duplicates of the last batch.
            source_ids, chosen_log_probs, rejected_log_probs, correct = accelerator.gather_for_metrics(  # type: ignore
                (batch["source_id"], chosen_log_probs, rejected_log_probs, correct)
            )

            total_correct += correct.sum().item()
            total_samples += len(correct)
            all_results.append(
                {
                    "source_ids": source_ids.cpu().numpy(),
                    "chosen_log_probs": chosen_log_probs.cpu().numpy(),
                    "rejected_log_probs": rejected_log_probs.cpu().numpy(),
                    "correct": correct.cpu().numpy(),
                }
            )
            if total_samples:
                progress_bar.set_postfix({"accuracy": f"{total_correct / total_samples:.3f}"})

    accuracy = total_correct / total_samples if total_samples else float("nan")
    return accuracy, all_results


def main(config: ScriptArgs):
    """Evaluate ``config.model`` on several preference tasks and print accuracies.

    A pair counts as correct when the chosen response has a higher
    length-normalised log-probability than the rejected one. On the main
    process the accuracies are printed as ``{a}{b}{c}`` (percent, one decimal)
    in the order of ``tasks``.
    """
    accelerator = Accelerator()
    # Reuse the one Accelerator instead of constructing a second instance.
    device = accelerator.local_process_index

    model, tokenizer = _load_model_and_tokenizer(config, device)
    collate_fn = _make_pair_collate_fn(tokenizer)

    tasks = ["unified", "hhh", "mtbench"]
    accuracies = {}

    for task in tasks:
        eval_dataset: Dataset = load_eval_dataset(task, tokenizer)  # type: ignore
        accelerator.print("Size of test dataset:", len(eval_dataset))

        # drop_last=False so every example is scored; the prepared loader pads
        # the final batch and gather_for_metrics removes those duplicates.
        eval_data_loader = DataLoader(
            eval_dataset,  # type: ignore
            batch_size=config.per_device_eval_batch_size,
            drop_last=False,
            collate_fn=collate_fn,
        )
        eval_data_loader = accelerator.prepare(eval_data_loader)

        # Per-example results are returned for optional export (e.g. CSV under
        # config.log_dir) but are not persisted here.
        accuracies[task], _ = _evaluate_task(model, eval_data_loader, accelerator)

    if accelerator.is_main_process:
        print("".join("{" + f"{accuracies[task] * 100:.1f}" + "}" for task in tasks))


if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArgs)  # type: ignore
    config = parser.parse_args_into_dataclasses()[0]

    # perf_counter is monotonic, so the reported duration is unaffected by
    # system clock adjustments during a long evaluation run.
    start_time = time.perf_counter()
    print("Starting...")
    main(config)
    print("Finished.")
    print(f"Took {time.perf_counter() - start_time:.2f} s")
