import os
import time
from dataclasses import dataclass, field
from typing import Any, cast

import pandas as pd
import torch
from accelerate import Accelerator
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset
from tqdm import tqdm
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)

from datasets import Dataset as HfDataset

# Default "gold" reward model used to score the completions.
GOLD_MODEL = "Ray2333/reward-model-Mistral-7B-instruct-Unified-Feedback"
# Token budget per prompt+response pair; inputs are padded/truncated to this length.
max_length = 1024
# Number of worker processes used when tokenizing a dataset with `Dataset.map`.
dataset_num_proc = 30


# Always resolve to a concrete device string so `.to(device)` never receives None.
device = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class EvalCompletionRewardConfig:
    """Command-line options for scoring logged completions with a reward model.

    Only ``dataset`` is required; every other option has a default. The ``help``
    metadata of each field is the text shown for the matching CLI flag.
    """

    # Directory holding the per-step completion CSV files (required).
    dataset: str = field(
        metadata={"help": "Path to the directory containing the completions across training"},
    )
    # Reward model used as the scorer; defaults to the module-level GOLD_MODEL.
    reward_model: str = field(
        default=GOLD_MODEL,
        metadata={"help": "Name or path of the reward model to evaluate the completions with"},
    )
    # Number of examples scored per forward pass.
    batch_size: int = field(
        default=8,
        metadata={"help": "Batch size"},
    )
    # Spacing, in training steps, between two logged completion files.
    interval: int = field(
        default=25,
        metadata={"help": "How often the completions were logged from the eval dataset"},
    )
    # Total number of training steps of the run being evaluated.
    num_steps: int = field(
        default=313,
        metadata={"help": "Number of steps completed during training"},
    )


def eval_dataset_with_rm(config: EvalCompletionRewardConfig):
    """Score logged completions with a reward model and save gold vs. proxy scores.

    For every logged training step the file
    ``epoch{E}step{SSSSS}_completions.csv`` is read from ``config.dataset``. Each
    file must provide ``prompts``, ``responses`` and ``scores`` columns. Every
    prompt/response pair is rendered with the tokenizer's chat template, scored by
    ``config.reward_model``, and the results are written to ``allscores.csv`` in
    the same directory, with one ``gold_{step}`` and one ``proxy_{step}`` column
    per step.

    Args:
        config: Evaluation settings (dataset directory, reward model, batch size,
            logging interval and number of training steps).

    Raises:
        ValueError: If a completions file lacks one of the required columns.
    """

    # NOTE: multi-process evaluation (e.g. via accelerate) is not wired in;
    # everything runs in a single process on the module-level `device`.

    # Step 0 plus every `interval`-th step. Going through a set drops the duplicate
    # 0 that `range` would produce when `interval == 1`.
    steps = sorted({0, *range(config.interval - 1, config.num_steps, config.interval)})

    ################
    # Model & tokenizer
    ################

    tokenizer = AutoTokenizer.from_pretrained(config.reward_model, use_fast=False)
    tokenizer.model_max_length = max_length
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForSequenceClassification.from_pretrained(config.reward_model, num_labels=1, torch_dtype=torch.bfloat16)

    # Loop-invariant: every example is padded/truncated to exactly `max_length`
    # tokens, so batches can be stacked without a padding collator.
    tokenizer_kwargs = {"padding": "max_length", "truncation": True, "max_length": max_length, "return_tensors": "pt"}

    def formatting_func(example: dict[str, Any]):
        """Render one prompt/response pair with the chat template and tokenize it."""
        prompt = example.get("prompts")
        response = example.get("responses")
        # NOTE(review): empty CSV cells presumably load as None - confirm. They are
        # coalesced to "" so the chat template always receives a string.
        messages = [
            {"role": "user", "content": "" if prompt is None else prompt},
            {"role": "assistant", "content": "" if response is None else response},
        ]
        prompt_plus_response = tokenizer.apply_chat_template(messages, tokenize=False)
        tokens = tokenizer(prompt_plus_response, **tokenizer_kwargs)  # type: ignore

        # The tokenizer returns a batch of one; strip the batch dimension.
        return {"input_ids": tokens["input_ids"][0], "attention_mask": tokens["attention_mask"][0]}  # type: ignore

    ################
    # Dataset
    ################

    required_columns = ("prompts", "responses", "scores")
    datasets: list[HfDataset] = []
    proxy_log: dict[int, list] = {}

    for step in steps:
        # Only the final training step is tagged as epoch 1 in the file name.
        epoch = 1 if step == config.num_steps - 1 else 0
        dataset_path = os.path.join(config.dataset, f"epoch{epoch}step{step:05d}_completions.csv")
        print(f"Loading dataset '{dataset_path}'...")

        raw_dataset: HfDataset = HfDataset.from_csv(dataset_path)  # type: ignore

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        missing = [name for name in required_columns if name not in raw_dataset.column_names]
        if missing:
            raise ValueError(f"Dataset '{dataset_path}' is missing required column(s): {missing}")

        # Keep the proxy scores as a plain list before the columns are dropped below.
        proxy_log[step] = list(raw_dataset["scores"])

        raw_dataset = raw_dataset.map(formatting_func, remove_columns=raw_dataset.column_names, num_proc=dataset_num_proc)
        raw_dataset.set_format(type="torch")

        datasets.append(raw_dataset)

    ################
    # Eval
    ################

    model.to(device)

    # Pinned host memory only helps host-to-GPU copies; skip it without CUDA.
    pin_memory = torch.cuda.is_available()
    score_columns: dict[str, list] = {}

    for step, dataset in zip(steps, datasets):
        data_loader = DataLoader(dataset, batch_size=config.batch_size, pin_memory=pin_memory)  # type: ignore

        gold_scores: list[float] = []

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Eval Gold Scores"):
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                logits: Tensor = model(input_ids, attention_mask=attention_mask).logits
                # One device-to-host transfer per batch instead of one `.item()` per
                # example; the bfloat16 -> float32 upcast is lossless.
                gold_scores.extend(logits.reshape(-1).float().cpu().tolist())

        ################
        # Results
        ################

        score_columns[f"gold_{step}"] = gold_scores
        score_columns[f"proxy_{step}"] = proxy_log[step]

    # Build the frame once rather than inserting columns one at a time.
    scores_log = pd.DataFrame(score_columns)

    save_path = os.path.join(config.dataset, "allscores.csv")
    scores_log.to_csv(save_path)
    print(f"Saved to '{save_path}'.")


if __name__ == "__main__":
    # Parse the command-line flags into the evaluation config dataclass.
    parser = HfArgumentParser(EvalCompletionRewardConfig)  # type: ignore
    config = parser.parse_args_into_dataclasses()[0]

    # perf_counter is monotonic, so the reported duration is not skewed by
    # system clock adjustments during a long evaluation run.
    start_time = time.perf_counter()
    print("Starting...")
    eval_dataset_with_rm(config)
    print("Finished.")
    print(f"Took {time.perf_counter() - start_time:.2f} s")
