import os
import torch
import torch.amp as amp
import torch.nn as nn
import torch.nn.functional as F
from contextlib import nullcontext
from dataclasses import dataclass, asdict
from typing import Any, Callable, Literal
from torch import Tensor
from accelerate import PartialState
from accelerate.utils import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalLoopOutput
from trl import DPOTrainer
from trl.data_utils import maybe_extract_prompt
from trl.trainer.utils import pad

from args import NCAArguments
from utils import get_logger
from .base import BaseRouterTrainer


logger = get_logger(__name__)

@dataclass
class BertNCACollator(DataCollatorMixin):
    """Data collator for NCA data with several candidate responses per prompt.

    Each prompt is concatenated with every one of its candidate responses; the resulting sequences
    are dynamically padded to the maximum length of the batch and optionally truncated.

    Args:
        pad_token_id (`int`):
            Token ID to use for padding `all_input_ids`.
        mask_token_id (`int`, *optional*):
            Not read by `torch_call`; accepted so existing callers can keep passing it.
        model_token_ids (`list[int]`, *optional*):
            Not read by `torch_call`; accepted so existing callers can keep passing it.
        mask_length (`int`, *optional*):
            Not read by `torch_call`; accepted so existing callers can keep passing it.
        packing (`bool`, *optional*, defaults to `False`):
            Not read by `torch_call`.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
            Type of Tensor to return. Only `"pt"` is currently supported.
        max_length (`int`, *optional*):
            When set, the padded tensors are cut to at most this many positions.
    """

    pad_token_id: int
    # The next three fields are unused by `torch_call`, so they default to None and the collator
    # can be built from `pad_token_id` alone.
    mask_token_id: int | None = None
    model_token_ids: list[int] | None = None
    mask_length: int | None = None
    packing: bool = False
    return_tensors: str = "pt"
    # `torch_call` reads this attribute; without the field every call would raise AttributeError.
    max_length: int | None = None

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a batch of examples into a batch of tensors.

        Args:
            examples: List of examples. Each one provides `"prompt_input_ids"` (list of ids),
                `"response_input_ids"` (one list of ids per candidate response), `"scores"` and,
                optionally, `"ref_logps"`. The number of candidates is taken from the first example.

        Returns:
            A dictionary with the following items:
                all_input_ids: tensor of shape (batch_size * n_responses, length), prompt + response ids
                attention_mask: tensor of the same shape, 1 on real tokens and 0 on padding
                loss_mask: float tensor of the same shape, 1 on response tokens and 0 elsewhere
                rewards: tensor of shape (batch_size, n_responses) built from `"scores"`
                ref_logps: tensor of shape (batch_size, n_responses); only present when the first
                    example carries `"ref_logps"`
        """
        n_candidates = len(examples[0]["response_input_ids"])  # Number of responses per prompt

        # One (prompt, response) pair per output row, ordered example-major so that row
        # `i * n_candidates + j` holds response `j` of example `i`.
        pairs = [
            (example["prompt_input_ids"], example["response_input_ids"][j])
            for example in examples
            for j in range(n_candidates)
        ]

        # Convert to tensor
        input_ids = [torch.tensor(prompt_ids + response_ids) for prompt_ids, response_ids in pairs]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        loss_mask = [
            torch.cat((torch.zeros(len(prompt_ids)), torch.ones(len(response_ids))))
            for prompt_ids, response_ids in pairs
        ]

        # Pad
        output = {
            "all_input_ids": pad(input_ids, padding_value=self.pad_token_id),
            "attention_mask": pad(attention_mask, padding_value=0),
            "loss_mask": pad(loss_mask, padding_value=0),
            "rewards": torch.tensor([example["scores"] for example in examples]),
        }
        if "ref_logps" in examples[0]:
            output["ref_logps"] = torch.tensor([example["ref_logps"] for example in examples])

        # Truncate
        if self.max_length is not None:
            for key in ("all_input_ids", "attention_mask", "loss_mask"):
                output[key] = output[key][:, : self.max_length]

        return output


class BertNCARouterTrainer(BaseRouterTrainer):
    CollatorType = BertNCACollator

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: NCAArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        """Prepare the datasets and initialise the underlying trainer.

        Args:
            model: Policy model (or model identifier), forwarded to the base trainer.
            processing_class: Tokenizer; handed to `_apply_chat_template` and to the base trainer.
            model_ids: Accepted for interface compatibility; not read by this constructor.
            args: Training arguments. `nca_temperature_alpha` and `dataset_num_proc` are read here.
            data_collator: Optional collator. When `None`, a `BertNCACollator` is installed once the
                base trainer has been initialised.
            train_dataset: Training dataset (required).
            eval_dataset: Optional evaluation dataset.
            **kwargs: Optional base-trainer arguments. `ref_model`, `model_init`, `compute_metrics`,
                `callbacks`, `optimizers`, `preprocess_logits_for_metrics` and `peft_config` are
                picked out explicitly; any remaining entries are forwarded to the base class as-is.

        Raises:
            AssertionError: If `train_dataset` is `None`.
        """
        self.args = args
        self.nca_temperature_alpha = args.nca_temperature_alpha

        # The optional trainer arguments arrive through **kwargs. Pull them out with explicit
        # defaults so the names below are always bound (they were previously undefined).
        ref_model = kwargs.pop("ref_model", None)
        model_init = kwargs.pop("model_init", None)
        compute_metrics = kwargs.pop("compute_metrics", None)
        callbacks = kwargs.pop("callbacks", None)
        optimizers = kwargs.pop("optimizers", (None, None))
        preprocess_logits_for_metrics = kwargs.pop("preprocess_logits_for_metrics", None)
        peft_config = kwargs.pop("peft_config", None)

        assert train_dataset is not None, "`train_dataset` is required for training."
        # Let the local main process run the `map` calls first so the other processes can reuse
        # the cached result instead of recomputing it.
        with PartialState().local_main_process_first():
            # `_apply_chat_template` is not defined in this class; presumably inherited — TODO confirm.
            train_dataset = train_dataset.map(
                self._apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to train dataset",
            )
            if eval_dataset is not None:
                # NOTE(review): prompt extraction is applied to the eval split only — confirm that
                # the train split does not need it as well.
                eval_dataset = eval_dataset.map(
                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
                )
                eval_dataset = eval_dataset.map(
                    self._apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                    desc="Applying chat template to eval dataset",
                )

        super().__init__(
            model=model,
            ref_model=ref_model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,  # type: ignore
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            **kwargs,
        )
        if data_collator is None:  # Replace default data collator with NCACollator
            self.data_collator = BertNCACollator(pad_token_id=self.padding_value)  # type: ignore

    def plot(self):
        """Plots the training state."""
        state = asdict(self.state)
        log_keys = list(state["log_history"][0].keys())
        plot_trainer_state(state, ["loss"], self.args.output_dir, "train_loss")
        plot_trainer_state(
            state, [k for k in log_keys if "logps/A" in k], self.args.output_dir, "train_logp"
        )
        plot_trainer_state(
            state,
            [k for k in log_keys if "rewards/A" in k] + ["rewards/margins"],
            self.args.output_dir,
            "train_reward",
        )
        plot_trainer_state(
            state, ["rewards/accuracies"], self.args.output_dir, "train_reward_accuracy"
        )

    def get_train_dataloader(self) -> DataLoader:
        """Return the training dataloader, attaching precomputed reference log-probs first.

        When `precompute_ref_log_probs` is enabled and the training dataset has no `"ref_logps"`
        column yet, the column is computed once via `_precompute_ref_logps` and stored on
        `self.train_dataset` before delegating to the parent implementation.
        """
        missing_ref_logps = self.precompute_ref_log_probs and "ref_logps" not in self.train_dataset.column_names
        if missing_ref_logps:
            computed = self._precompute_ref_logps(self.train_dataset)
            self.train_dataset = self.train_dataset.add_column(name="ref_logps", column=computed.tolist())
            self._precomputed_train_ref_log_probs = True
        return super().get_train_dataloader()

    def get_eval_dataloader(self, eval_dataset: Dataset | None = None) -> DataLoader:
        """Subclass to precompute `ref_log_probs`."""

        eval_dataset = eval_dataset or self.eval_dataset
        if eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        if self.precompute_ref_log_probs and not "ref_logps" in eval_dataset.column_names:
            ref_logps = self._precompute_ref_logps(eval_dataset)
            eval_dataset = eval_dataset.add_column(name="ref_logps", column=ref_logps.tolist())
            self._precomputed_eval_ref_log_probs = True
        return super().get_eval_dataloader(eval_dataset=eval_dataset)

    def compute_ref_log_probs(self, batch: dict[str, Tensor]) -> Tensor:
        """Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.

        Args:
            batch: The batch of inputs.

        Returns:
            The log probabilities of shape (batch_size, n_responses).
        """
        compte_ref_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
        with torch.no_grad(), compte_ref_context_manager:
            if self.ref_model is None:
                with self.null_ref_context():
                    logps = self.concatenated_forward(self.model, batch)
            else:
                logps = self.concatenated_forward(self.ref_model, batch)
        return logps

    def nca_loss(self, policy_logps: Tensor, reference_logps: Tensor, rewards: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the NCA loss for a batch of policy and reference model log probabilities.

        Args:
            policy_logps: The log probabilities of the policy model for the batch, shape=(batch_size, n_responses).
            reference_logps: The log probabilities of the reference model for the batch.
            rewards: The rewards for the batch.

        Returns:
            A tuple of two tensors: the NCA loss and the model rewards (shape=(batch_size, n_responses)).
        """
        model_rewards = (policy_logps - reference_logps) * self.beta  # shape=(batch_size, n_responses)
        # ? The definition of temperature_alpha here is different from that in the paper (temperature_alpha = 1 / paper_alpha)
        rewards /= self.nca_temperature_alpha
        soft_label = rewards.softmax(dim=-1)  # shape=(batch_size, n_responses)

        if self.loss_type == "InfoNCA":
            ratio_logits_p = model_rewards.log_softmax(dim=-1)
            losses = -(soft_label * ratio_logits_p).sum(dim=-1)
        elif self.loss_type == "NCA":
            losses = (
                -F.logsigmoid(-model_rewards).mean() - (soft_label * F.logsigmoid(model_rewards)).sum(dim=-1).mean()
            )
        else:
            raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['InfoNCA', 'NCA']")

        return losses, model_rewards.detach()

    def get_batch_loss_metrics(
        self, model: PreTrainedModel, batch: dict[str, Tensor], train_eval: Literal["train", "eval"] = "train"
    ):
        """Compute the NCA loss and other metrics for the given batch of inputs for train or test."""
        policy_logps = self.concatenated_forward(model, batch)
        reference_logps = batch["ref_logps"] if "ref_logps" in batch else self.compute_ref_log_probs(batch)
        losses, model_rewards = self.nca_loss(policy_logps, reference_logps, batch["rewards"])
        reward_accuracies = (model_rewards[:, :-1] > model_rewards[:, 1:]).float().mean().cpu()
        mean_model_rewards = model_rewards.mean(dim=0).cpu()  # shape=(n_responses,)
        prefix = "eval_" if train_eval == "eval" else ""
        metrics = {}
        metrics[f"{prefix}rewards/accuracies"] = reward_accuracies
        metrics[f"{prefix}rewards/margins"] = (model_rewards[:, 0:1] - model_rewards[:, 1:]).mean().cpu()
        for i in range(policy_logps.shape[1]):
            metrics[f"{prefix}rewards/A{i}"] = mean_model_rewards[i]
            metrics[f"{prefix}logps/A{i}"] = policy_logps[:, i].mean().cpu()
        return losses.mean(), metrics

    def compute_loss(
        self,
        model: PreTrainedModel,
        inputs: dict[str, Any],
        return_outputs=False,
        num_items_in_batch=None,
    ) -> Tensor | tuple[Tensor, dict[str, Tensor]]:
        loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if self.model_accepts_loss_kwargs is not None:
            loss /= self.args.gradient_accumulation_steps

        if return_outputs:
            return (loss, metrics)
        return loss

    def prediction_step(
        self,
        model: PreTrainedModel,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[Tensor, Tensor, Tensor] | tuple[Tensor, dict[str, Tensor], Tensor]:
        if ignore_keys is None:
            if hasattr(model, "config"):
                ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
            else:
                ignore_keys = []

        prediction_context_manager = amp.autocast("cuda") if self._peft_has_been_casted_to_bf16 else nullcontext()
        with torch.no_grad(), prediction_context_manager:
            loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")

        # force log the metrics
        self.store_metrics(metrics, train_eval="eval")

        if prediction_loss_only:
            return loss.detach(), None, None

        # logits for the chosen and rejected samples from model
        logits_dict = {k: v for k, v in metrics.items() if "logits" in k}
        logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
        logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
        labels = torch.zeros(logits.shape[0], device=self.accelerator.device)

        return (loss.detach(), logits, labels)

    def concatenated_forward(self, model: nn.Module, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Run the given model on the given batch of inputs.

        Args:
            model: The model to run.
            batch: The batch of inputs.

        Returns:
            A tuple of two tensors: the log probabilities of shape (batch_size, n_responses).
        """

        batch_size, n_responses = batch["rewards"].shape
        input_ids = batch["all_input_ids"]
        loss_mask = batch["loss_mask"]
        outputs = model(input_ids=input_ids, attention_mask=batch["attention_mask"])
        logits = outputs.logits[:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:].bool()
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        per_token_logps[~loss_mask] = 0
        all_logps = per_token_logps.sum(-1)
        return all_logps.reshape(batch_size, n_responses)

    def _precompute_ref_logps(self, dataset: Dataset) -> Tensor:
        """Compute (or load from cache) the reference log probabilities for a whole dataset.

        When `self.args.ref_log_prob_path` is set, results are cached in
        `<ref_log_prob_path>/<dataset fingerprint>/ref_logp.pt` and reused on later calls.

        Args:
            dataset: The dataset to run the reference model on.

        Returns:
            A CPU tensor with the concatenated outputs of `compute_ref_log_probs` for every batch.
        """
        cache_path = (
            os.path.join(self.args.ref_log_prob_path, dataset._fingerprint, "ref_logp.pt")
            if self.args.ref_log_prob_path
            else None
        )
        if cache_path and os.path.exists(cache_path):
            # The cache is written from a CPU tensor; pin the load to CPU as well.
            return torch.load(cache_path, map_location="cpu")

        batch_size = self.args.precompute_ref_batch_size or self.args.per_device_train_batch_size
        dataloader_params = {
            "batch_size": batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,  # keep dataset order so rows line up with the added column
        }
        data_loader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))

        chunks = []
        for padded_batch in tqdm(iterable=data_loader, desc="Dataset reference log probs"):
            logps: Tensor = self.compute_ref_log_probs(padded_batch)
            logps = self.accelerator.gather_for_metrics(logps)  # type: ignore
            chunks.append(logps.cpu())
        ref_logps = torch.cat(chunks)

        # One writer per node is enough; letting every process write the same file concurrently
        # risks a torn cache. Writing to a temporary file and renaming keeps the final path
        # either absent or complete.
        if cache_path and self.accelerator.is_local_main_process:
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            tmp_path = f"{cache_path}.tmp{self.accelerator.process_index}"
            torch.save(ref_logps, tmp_path)
            os.replace(tmp_path, cache_path)
        return ref_logps

    def _set_signature_columns_if_needed(self):
        if self._signature_columns is None:
            self._signature_columns = ["prompt_input_ids", "response_input_ids", "ref_logps", "scores"]

    @staticmethod
    def tokenize_row(
        features: dict[str, str],
        processing_class: PreTrainedTokenizerBase,
        max_prompt_length: int | None,
        max_completion_length: int | None,
        add_special_tokens: bool,
    ):
        """
        Tokenize a row of the dataset.

        Args:
            features: Row of the dataset, should contain the keys `"prompt"` and `"responses"`.
            processing_class: Processing class used to process the data.
            max_prompt_length: Maximum length of the prompt sequence. If `None`, the prompt sequence is not truncated.
            max_completion_length: Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
            add_special_tokens: Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If `True`, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.

        Returns:
            Tokenized sequences with the keys `"prompt_input_ids"` and `"response_input_ids".
        """
        tokenizer = processing_class  # the processing class is a tokenizer
        prompt_input_ids: list[int] = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
        response_input_ids: list[list[int]] = [
            tokenizer(response, add_special_tokens=False)["input_ids"] for response in features["responses"]
        ]

        # Add special tokens (typically for encoder-decoder models)
        if add_special_tokens:
            if tokenizer.bos_token_id is not None:
                prompt_input_ids = [tokenizer.bos_token_id] + prompt_input_ids
            if tokenizer.eos_token_id is not None:
                prompt_input_ids = prompt_input_ids + [tokenizer.eos_token_id]
        for input_ids in response_input_ids:
            input_ids.append(tokenizer.eos_token_id)

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        if max_completion_length is not None:
            response_input_ids = [input_ids[:max_completion_length] for input_ids in response_input_ids]

        return {
            "prompt_input_ids": prompt_input_ids,
            "response_input_ids": response_input_ids,
        }
