import os
import torch
import torch.nn.functional as F
from dataclasses import dataclass, asdict
from typing import Any, Literal, override
from torch import Tensor
from accelerate.utils import tqdm
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.integrations import deepspeed_init
from transformers.trainer_pt_utils import nested_detach
from trl.trainer.utils import pad

from args import NCAArguments
from utils import plot_trainer_state, get_logger
from model.causal import load_model_with_reward_head
from .utils import compute_routing_loss
from .base import BaseRouterTrainer

logger = get_logger(__name__)


@dataclass
class CausalNCACollator(DataCollatorMixin):
    """Collate NCA examples into flat (prompt + response) sequences for a causal LM.

    Every example holds one prompt and ``n_responses`` responses. Each (prompt, response) pair
    becomes one row, so the batch has ``batch_size * n_responses`` rows ordered prompt-major
    (all responses of example 0, then all responses of example 1, ...).
    """

    pad_token_id: int
    max_length: int | None = None  # truncate every row to this many tokens (None: no truncation)
    use_logits_to_keep: bool = True  # trim loss_mask to the trailing columns that can carry loss
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a batch of examples into a batch of tensors.

        Args:
            examples: List of examples. Each has ``prompt_input_ids`` (list of ints),
                ``response_input_ids`` (list of ``n_responses`` lists of ints), ``scores``
                (``n_responses`` numbers) and optionally ``ref_logps`` (``n_responses`` numbers).
                ``n_responses`` is taken from the first example, so all examples must have the
                same number of responses.

        Returns:
            A dictionary with:
                input_ids: (batch_size * n_responses, seq_len) prompt+response ids, padded with
                    ``pad_token_id`` via trl's ``pad``.
                attention_mask: (batch_size * n_responses, seq_len), 1 for real tokens, 0 for padding.
                scores: (batch_size, n_responses).
                model_token_indices: (batch_size * n_responses,) index of the last prompt token.
                loss_mask: bool, True on response tokens; (batch_size * n_responses, seq_len), or
                    only its last ``logits_to_keep`` columns when ``use_logits_to_keep`` is set.
                ref_logps: (batch_size, n_responses), only if the examples carry ``ref_logps``.
                logits_to_keep: int, only if ``use_logits_to_keep`` is set.
        """
        n_responses = len(examples[0]["response_input_ids"])  # responses per prompt

        # One (prompt, response) pair per output row, prompt-major.
        pairs = [
            (example["prompt_input_ids"], example["response_input_ids"][j])
            for example in examples
            for j in range(n_responses)
        ]
        input_ids = [torch.tensor(prompt + response) for prompt, response in pairs]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        loss_mask = [
            torch.cat(
                (
                    torch.zeros(len(prompt), dtype=torch.bool),
                    torch.ones(len(response), dtype=torch.bool),
                )
            )
            for prompt, response in pairs
        ]

        output = {
            "input_ids": pad(input_ids, padding_value=self.pad_token_id),
            "attention_mask": pad(attention_mask, padding_value=0),
            "scores": torch.tensor([example["scores"] for example in examples]),
            # Index of the last prompt token of every row.
            "model_token_indices": torch.tensor([len(prompt) - 1 for prompt, _ in pairs]),
            "loss_mask": pad(loss_mask, padding_value=0),
        }
        if "ref_logps" in examples[0]:
            output["ref_logps"] = torch.tensor([example["ref_logps"] for example in examples])

        if self.max_length is not None:
            for key in ("input_ids", "attention_mask", "loss_mask"):
                output[key] = output[key][:, : self.max_length]

        if self.use_logits_to_keep:
            seq_len = output["loss_mask"].shape[1]
            min_prompt_length = min(len(example["prompt_input_ids"]) for example in examples)
            # The trainer predicts token t+1 from position t and applies loss_mask[:, 1:]. Keeping
            # the last (seq_len - min_prompt_length + 1) columns therefore still covers the first
            # response token of the shortest prompt. Clamp to [1, seq_len].
            logits_to_keep = min(max(seq_len - min_prompt_length + 1, 1), seq_len)
            output["loss_mask"] = output["loss_mask"][:, -logits_to_keep:]
            output["logits_to_keep"] = logits_to_keep
        return output


class CausalNCATrainer(BaseRouterTrainer):
    """Router trainer that adds an (Info)NCA alignment loss to a causal LM with a routing head.

    Each example is one prompt with ``n_candidates`` scored responses. The routing output is
    trained with ``compute_routing_loss`` on the scores. When ``args.nca_loss_weight`` is non-zero,
    the sequence log probabilities of the responses are also trained with the NCA / InfoNCA
    objective (``args.loss_type``).
    """

    CollatorType = CausalNCACollator

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: NCAArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        self.args = args
        preprocess_args = {
            "fn_kwargs": {
                "tokenizer": processing_class,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
            },
            "batched": False,
            "num_proc": args.dataset_num_proc,
        }
        if data_collator is None:
            data_collator = CausalNCACollator(processing_class.pad_token_id, args.max_length, args.use_logits_to_keep)
        super().__init__(
            model,
            processing_class,
            model_ids,
            args,
            data_collator,
            train_dataset,
            eval_dataset,
            preprocess_args,
            **kwargs,
        )

        # Precompute reference log probs once, with the weights the trainer starts from, and store
        # them as a ``ref_logps`` dataset column that the collator forwards to compute_loss.
        if args.precompute_ref_log_probs:
            if self.is_deepspeed_enabled and self.deepspeed is None:
                _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
            # NOTE(review): the wrapped model is not used below (compute_ref_log_probs runs
            # self.model); the call is kept for any side effects of wrapping — confirm it's needed.
            self._wrap_model(self.model, training=False)
            if self.train_dataset is not None:
                ref_logps = self._precompute_ref_logps(self.train_dataset)
                self.train_dataset = self.train_dataset.add_column(name="ref_logps", column=ref_logps.tolist())
            if self.eval_dataset is not None:
                ref_logps = self._precompute_ref_logps(self.eval_dataset)
                self.eval_dataset = self.eval_dataset.add_column(name="ref_logps", column=ref_logps.tolist())

    def plot(self):
        """Plot loss, per-response log probs, rewards and reward accuracy into ``args.output_dir``."""
        assert self.args.output_dir is not None, "output_dir must be set to plot the training state."
        state = asdict(self.state)
        if not state["log_history"]:
            logger.warning("No log history to plot.")
            return
        # Metric names are taken from the first log entry.
        log_keys = list(state["log_history"][0].keys())
        plot_trainer_state(state, ["loss"], self.args.output_dir, "train_loss")
        plot_trainer_state(state, [k for k in log_keys if "logps/A" in k], self.args.output_dir, "train_logp")
        plot_trainer_state(
            state,
            [k for k in log_keys if "rewards/A" in k] + ["rewards/margins"],
            self.args.output_dir,
            "train_reward",
        )
        plot_trainer_state(state, ["rewards/accuracies"], self.args.output_dir, "train_reward_accuracy")

    def compute_ref_log_probs(self, batch: dict[str, Tensor]) -> Tensor:
        """Compute gradient-free log probabilities for one collated batch.

        This runs ``self.model`` (there is no separate reference model here). The result is a
        true reference only when computed before training, as in ``_precompute_ref_logps``.

        Args:
            batch: The collated batch of inputs; it is not modified.

        Returns:
            The summed response log probabilities, shape (batch_size, n_responses).
        """
        with torch.no_grad():
            logps, _ = self.concatenated_forward(batch)
        return logps

    def nca_loss(self, policy_logps: Tensor, reference_logps: Tensor, rewards: Tensor) -> tuple[Tensor, Tensor]:
        """Compute the NCA / InfoNCA loss for a batch of policy and reference log probabilities.

        Args:
            policy_logps: Policy log probabilities, shape=(batch_size, n_responses).
            reference_logps: Reference log probabilities, same shape as ``policy_logps``.
            rewards: Scores of the responses, shape=(batch_size, n_responses). Not modified.

        Returns:
            A tuple ``(losses, model_rewards)``. ``losses`` holds one loss per example,
            shape=(batch_size,). ``model_rewards`` are the detached implicit rewards
            ``beta * (policy - reference)``, shape=(batch_size, n_responses).

        Raises:
            ValueError: If ``args.loss_type`` is neither "InfoNCA" nor "NCA".
        """
        model_rewards = (policy_logps - reference_logps) * self.args.nca_beta  # shape=(batch_size, n_responses)
        # ? The definition of temperature_alpha here is different from that in the paper (temperature_alpha = 1 / paper_alpha)
        # Out-of-place on purpose: `rewards` is the caller's scores tensor (it shares storage with
        # the labels returned by prediction_step) and may be an integer tensor.
        scaled_rewards = rewards / self.args.nca_temperature_alpha
        soft_label = scaled_rewards.softmax(dim=-1)  # shape=(batch_size, n_responses)

        if self.args.loss_type == "InfoNCA":
            losses = -(soft_label * model_rewards.log_softmax(dim=-1)).sum(dim=-1)
        elif self.args.loss_type == "NCA":
            losses = -F.logsigmoid(-model_rewards).mean(dim=-1) - (soft_label * F.logsigmoid(model_rewards)).sum(dim=-1)
        else:
            raise ValueError(f"Unknown loss type: {self.args.loss_type}. Should be one of ['InfoNCA', 'NCA']")

        return losses, model_rewards.detach()

    @override
    def compute_loss(self, inputs: dict[str, Any], return_outputs=False):
        """Compute the routing loss plus, optionally, the weighted NCA loss for one batch.

        Args:
            inputs: Collated batch from ``CausalNCACollator``. ``scores`` and, if present,
                ``ref_logps`` are popped from it.
            return_outputs: If True, also return the routing logits.

        Returns:
            The scalar loss, or ``(loss, routing_logits)`` when ``return_outputs`` is True.
        """
        scores = inputs.pop("scores")  # shape=(batch_size, n_candidates)
        # Taken out up front; None means reference log probs were not precomputed.
        reference_logps = inputs.pop("ref_logps", None)
        policy_logps, routing_logits = self.concatenated_forward(inputs)
        routing_loss = compute_routing_loss(routing_logits.view(-1, self.n_candidates), scores, self.args)
        is_training = routing_loss.requires_grad
        prefix = "" if is_training else "eval_"
        metrics = {f"{prefix}routing_loss": routing_loss.detach().cpu()}

        if self.args.nca_loss_weight != 0:
            if reference_logps is None:
                # NOTE(review): without precomputed values the "reference" is the current policy
                # (compute_ref_log_probs runs self.model), so the NCA rewards are close to zero.
                reference_logps = self.compute_ref_log_probs(inputs)
            losses, model_rewards = self.nca_loss(policy_logps, reference_logps, scores)
            nca_loss = losses.mean()
            loss = routing_loss + self.args.nca_loss_weight * nca_loss

            detached_logps = policy_logps.detach()  # metrics must not keep the autograd graph alive
            mean_model_rewards = model_rewards.mean(dim=0).cpu()  # shape=(n_responses,)
            metrics[f"{prefix}nca_loss"] = nca_loss.detach().cpu()
            metrics[f"{prefix}loss"] = loss.detach().cpu()
            # NOTE(review): accuracy/margins assume responses are ordered best-first — confirm.
            metrics[f"{prefix}rewards/accuracies"] = (
                (model_rewards[:, :-1] > model_rewards[:, 1:]).float().mean().cpu()
            )
            metrics[f"{prefix}rewards/margins"] = (model_rewards[:, 0:1] - model_rewards[:, 1:]).mean().cpu()
            for i in range(detached_logps.shape[1]):
                metrics[f"{prefix}rewards/A{i}"] = mean_model_rewards[i]
                metrics[f"{prefix}logps/A{i}"] = detached_logps[:, i].mean().cpu()
        else:
            loss = routing_loss

        # Scale the *total* loss, so the routing and NCA terms keep their relative weight
        # under gradient accumulation.
        if is_training and self.model_accepts_loss_kwargs:
            loss = loss / self.args.gradient_accumulation_steps

        # force log the metrics
        if self.accelerator.is_main_process:
            self.store_metrics(metrics, train_eval="train")

        if return_outputs:
            return (loss, routing_logits)
        return loss

    @override
    def prediction_step(
        self,
        model: PreTrainedModel,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ):
        """Evaluate one batch.

        ``model`` and ``ignore_keys`` are accepted for ``Trainer`` API compatibility. The forward
        pass always uses ``self.model``.

        Returns:
            ``(loss, None, None)`` if ``prediction_loss_only``, else ``(loss, routing_logits, scores)``
            with routing logits reshaped to (-1, n_candidates).
        """
        inputs = self._prepare_inputs(inputs)
        labels = nested_detach(inputs["scores"])

        with torch.no_grad():
            with self.compute_loss_context_manager():
                # This class's compute_loss takes the batch only (no model argument).
                loss, routing_logits = self.compute_loss(inputs, return_outputs=True)
            loss = loss.mean().detach()

        if prediction_loss_only:
            return (loss, None, None)

        routing_logits = nested_detach(routing_logits.view(-1, self.n_candidates))
        return (loss, routing_logits, labels)  # type: ignore

    def concatenated_forward(self, inputs: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
        """Run the model on all (prompt, response) rows and sum the response-token log probs.

        Args:
            inputs: Collated batch; reads ``input_ids``, ``attention_mask`` and ``loss_mask``.
                ``inputs`` is not modified, so the same batch can be forwarded more than once.

        Returns:
            ``(logps, routing_logits)``. ``logps`` has shape (batch_size, n_candidates) and holds
            the summed log probabilities of the tokens selected by ``loss_mask``.
        """
        input_ids = inputs["input_ids"]
        loss_mask = inputs["loss_mask"]
        # NOTE(review): attention_mask is one column longer than the model input; kept as-is
        # because the model's use of it is not visible here.
        lm_logits, _, routing_logits = self.model(input_ids[:, :-1], inputs["attention_mask"])
        # Position t predicts token t+1, so loss_mask[:, 1:] marks which targets count.
        target_mask = loss_mask[:, 1:].bool()
        # The collator may trim loss_mask to its trailing columns (logits_to_keep). Align labels
        # and logits to the same trailing window; it's a no-op for an untrimmed mask.
        n_keep = target_mask.shape[1]
        labels = input_ids[:, input_ids.shape[1] - n_keep :]
        logits = lm_logits[:, lm_logits.shape[1] - n_keep :]
        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        per_token_logps = per_token_logps.masked_fill(~target_mask, 0)
        all_logps = per_token_logps.sum(-1)
        batch_size = input_ids.shape[0] // self.n_candidates
        return all_logps.reshape(batch_size, self.n_candidates), routing_logits

    def _precompute_ref_logps(self, dataset: Dataset) -> Tensor:
        """Compute (or load from cache) reference log probs for every row of ``dataset``.

        Results are cached under ``args.ref_log_prob_path/<dataset fingerprint>/ref_logp.pt`` when
        that path is set.

        Returns:
            Tensor of shape (len(dataset), n_candidates), gathered across processes.
        """
        # NOTE(review): the cache key is the dataset fingerprint only, not the model — a stale
        # cache is reused if the model changes.
        self.model.eval()
        cache_path = (
            os.path.join(self.args.ref_log_prob_path, dataset._fingerprint, "ref_logp.pt")
            if self.args.ref_log_prob_path
            else None
        )
        if cache_path and os.path.exists(cache_path):
            return torch.load(cache_path)

        dataloader_params = {
            "batch_size": self.args.per_device_eval_batch_size,
            "collate_fn": self.data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "shuffle": False,
        }
        data_loader = self.accelerator.prepare(DataLoader(dataset, **dataloader_params))
        chunks = []
        for padded_batch in tqdm(iterable=data_loader, desc="Dataset reference log probs"):
            logps: Tensor = self.compute_ref_log_probs(padded_batch)
            logps = self.accelerator.gather_for_metrics(logps)  # type: ignore
            chunks.append(logps.cpu())
        ref_logps = torch.cat(chunks)
        # Every process holds the same gathered tensor; only one needs to write the cache file.
        if cache_path and self.accelerator.is_main_process:
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            torch.save(ref_logps, cache_path)
        return ref_logps

    @override
    def _init_model(self, model: str, model_ids: list[str], pad_token_id: int, **kwargs) -> PreTrainedModel:
        """Load ``model`` with a reward/routing head that has one label per entry of ``model_ids``."""
        id2label = dict(enumerate(model_ids))
        label2id = {label: i for i, label in id2label.items()}
        return load_model_with_reward_head(
            model,
            num_labels=len(model_ids),
            id2label=id2label,
            label2id=label2id,
            pad_token_id=pad_token_id,
            use_cache=False,
            use_fixed_liger=False,
            **kwargs,
        )

    @override
    def _set_signature_columns_if_needed(self):
        """Declare the dataset columns the collator consumes (so they are not dropped)."""
        if self._signature_columns is None:
            self._signature_columns = ["prompt_input_ids", "response_input_ids", "ref_logps", "scores"]

    @staticmethod
    def tokenize_row(
        features: dict[str, Any],
        tokenizer: PreTrainedTokenizerBase,
        max_prompt_length: int | None,
        max_completion_length: int | None,
    ):
        """Tokenize one row into prompt ids and one list of completion ids per response.

        Args:
            features: Row with a ``prompt`` string and a ``responses`` list of strings.
            tokenizer: Tokenizer whose chat template renders the conversation.
            max_prompt_length: Keep only the last this-many prompt tokens (None: no limit).
            max_completion_length: Keep only the first this-many tokens of each response (None: no limit).

        Returns:
            ``{"prompt_input_ids": list[int], "response_input_ids": list[list[int]]}``.
        """
        messages = [{"role": "user", "content": features["prompt"]}]
        prompt_input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
        prompt_len = len(prompt_input_ids)
        # The response ids are what the full conversation adds after the prompt. No generation
        # prompt here: the conversation already ends with the assistant turn, and
        # add_generation_prompt=True would append an extra assistant header to every response.
        # NOTE(review): assumes the templated prompt is an exact token prefix of the full
        # conversation — confirm for the tokenizers in use.
        response_input_ids = [
            tokenizer.apply_chat_template(
                messages + [{"role": "assistant", "content": response}],
                add_generation_prompt=False,
            )[prompt_len:]
            for response in features["responses"]
        ]

        # Truncate prompt (keep its end) and completions (keep their start)
        if max_prompt_length is not None:
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        if max_completion_length is not None:
            response_input_ids = [input_ids[:max_completion_length] for input_ids in response_input_ids]

        return {
            "prompt_input_ids": prompt_input_ids,
            "response_input_ids": response_input_ids,
        }
