import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from typing import Any, Literal
from dataclasses import dataclass, asdict
from transformers import Trainer, PreTrainedModel, PreTrainedTokenizerBase, AutoModelForSequenceClassification
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_utils import PredictionOutput
from transformers.trainer_pt_utils import nested_detach
from datasets import Dataset
from typing_extensions import override
from trl.trainer.utils import pad

from args import RDArguments
from utils import plot_trainer_state, get_logger

# Module-level logger, created through the project's `get_logger` helper.
logger = get_logger(__name__)


@dataclass
class RDCollator(DataCollatorMixin):
    """Collate tokenized prompts and candidate responses into padded tensors.

    For every example two groups of sequences are built, one sequence per
    candidate response in each group:

    * reward-model ("rm") inputs: ``prompt + response + EOS``;
    * routing inputs: ``prompt + model token + EOS``, where the model token of
      the i-th response is ``model_token_ids[i]``.

    All sequences are left-padded, so the EOS token is always the last position.
    """

    # Token ID used to pad the input ID tensors.
    pad_token_id: int
    # Token ID appended at the end of every sequence.
    eos_token_id: int
    # One token ID per candidate model, indexed by the position of its response.
    model_token_ids: list[int]
    # Maximum length (EOS included) of the rm sequences; None or 0 disables truncation.
    max_length: int | None = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        """Collate a batch of examples into a batch of tensors.

        Args:
            examples: List of examples. Each example is read for the keys
                ``prompt_input_ids`` (list of token IDs), ``response_input_ids``
                (one list of token IDs per candidate) and ``scores`` (one score
                per candidate).

        Returns:
            A dictionary of 5 items: rm_input_ids, rm_attention_mask, routing_input_ids,
            routing_attention_mask and scores.

        Raises:
            IndexError: If an example has more responses than there are ``model_token_ids``.
        """
        # A falsy max_length (None or 0) means "no limit".
        max_length = self.max_length or sys.maxsize
        # Reserve one position for the EOS token so that the final sequence
        # (EOS included) never exceeds max_length.
        max_body_length = max(max_length - 1, 0)

        # Concatenate prompt and response, truncate on the right, then terminate with EOS
        rm_input_ids = [
            torch.tensor((example["prompt_input_ids"] + response_input_ids)[:max_body_length] + [self.eos_token_id])
            for example in examples
            for response_input_ids in example["response_input_ids"]
        ]
        rm_attention_mask = [torch.ones_like(ids) for ids in rm_input_ids]

        # Concatenate prompt, model ID and EOS. These sequences are not truncated by
        # max_length here; their length is prompt length + 2.
        routing_input_ids = [
            torch.tensor(example["prompt_input_ids"] + [self.model_token_ids[i], self.eos_token_id])
            for example in examples
            for i in range(len(example["response_input_ids"]))
        ]
        routing_attention_mask = [torch.ones_like(ids) for ids in routing_input_ids]

        # Pad on the left so that the last position of every row is a real token (EOS).
        # Shapes below assume every example has the same number of candidates (n_candidates).
        output = {
            "rm_input_ids": pad(rm_input_ids, self.pad_token_id, "left"),  # shape=(batch_size*n_candidates, seq_length)
            "rm_attention_mask": pad(rm_attention_mask, 0, "left"),  # shape=(batch_size*n_candidates, seq_length)
            "routing_input_ids": pad(
                routing_input_ids, self.pad_token_id, "left"
            ),  # shape=(batch_size*n_candidates, seq_length)
            "routing_attention_mask": pad(
                routing_attention_mask, 0, "left"
            ),  # shape=(batch_size*n_candidates, seq_length)
            # Force a floating dtype: integer scores would otherwise yield a Long tensor,
            # which regression losses such as MSE reject.
            "scores": torch.tensor(
                [example["scores"] for example in examples], dtype=torch.float32
            ),  # shape=(batch_size, n_candidates)
        }

        return output


class RDRouterTrainer(Trainer):
    """Trainer combining a reward loss, a routing loss and a hidden-state distillation loss.

    Every batch is run through the model twice: once on ``prompt + response``
    sequences (the "rm" pass) and once on ``prompt + model token`` sequences
    (the "routing" pass). Both sets of logits are fitted to the target scores,
    and the last hidden state of the routing pass is pulled towards the
    (detached) last hidden state of the rm pass.
    """

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: RDArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        """Tokenize the datasets, build the default collator and initialise the base trainer.

        Args:
            model: A model instance, or a name/path that is loaded with
                `AutoModelForSequenceClassification.from_pretrained(..., num_labels=1)`.
            processing_class: Tokenizer used to tokenize prompts/responses and to
                look up the model tokens.
            model_ids: One token per candidate model; converted to IDs with
                `convert_tokens_to_ids` (presumably these tokens were added to the
                tokenizer vocabulary beforehand — verify against the caller).
            args: Training arguments. `label_names` is overwritten with `["scores"]`.
            data_collator: Optional collator; defaults to an `RDCollator`.
            train_dataset: Optional raw training dataset with `prompt` and `responses` columns.
            eval_dataset: Optional raw evaluation dataset with `prompt` and `responses` columns.
            **kwargs: Forwarded to `Trainer.__init__`.
        """
        self.args = args
        self.model_ids = model_ids
        self.n_candidates = len(model_ids)
        self.model_token_ids: list[int] = processing_class.convert_tokens_to_ids(model_ids)
        if isinstance(model, str):
            # A single output logit per sequence: it is used as the score of one candidate.
            model = AutoModelForSequenceClassification.from_pretrained(
                model, num_labels=1, pad_token_id=processing_class.pad_token_id
            )
        # Kept on the instance so that `predict` can tokenize test datasets the same way.
        self.preprocess_args = {
            "fn_kwargs": {
                "tokenizer": processing_class,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
            },
            "batched": False,
            "num_proc": args.dataset_num_proc,
        }
        with args.main_process_first(desc="preprocess dataset"):
            if train_dataset is not None:
                train_dataset = train_dataset.map(self.tokenize_row, **self.preprocess_args)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(self.tokenize_row, **self.preprocess_args)
        if data_collator is None:
            data_collator = RDCollator(
                processing_class.pad_token_id, processing_class.eos_token_id, self.model_token_ids, args.max_length
            )
        args.label_names = ["scores"]
        # Per-split ("train"/"eval") lists of metric values accumulated between two `log` calls.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, processing_class, **kwargs)

    def plot(self):
        """Plots the training state."""
        plot_trainer_state(asdict(self.state), ["loss", "reward_loss", "distill_loss"], self.args.output_dir, "loss")
        plot_trainer_state(asdict(self.state), ["accuracy"], self.args.output_dir, "accuracy")
        plot_trainer_state(asdict(self.state), ["reward"], self.args.output_dir, "reward")

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append each metric value to the running list of its split.

        Args:
            metrics: Mapping from metric name to the value of the current step.
            train_eval: Split the values belong to.
        """
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    @override
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`dict[str, float]`):
                The values to log.
            start_time (`float` or `None`, *optional*, defaults to `None`):
                Start time of the training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        # Reset the accumulator of this split for the next logging window
        del self._stored_metrics[train_eval]
        super().log(logs, start_time)

    @override
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the weighted sum of routing, reward and distillation losses.

        Args:
            model: Model called with `input_ids`/`attention_mask`; its output must expose
                `logits` and `hidden_states`.
            inputs: Batch with the keys `rm_input_ids`, `rm_attention_mask`,
                `routing_input_ids`, `routing_attention_mask` and `scores`.
            return_outputs: If True, also return the routing logits reshaped to
                `(-1, n_candidates)`.
            num_items_in_batch: Unused; accepted for compatibility with the base signature.

        Returns:
            The scalar loss, or a `(loss, routing_logits)` tuple when `return_outputs` is True.

        Raises:
            ValueError: If `args.loss_type` is not one of "MSE", "CE", "ForwardKL", "ReverseKL".
        """
        # Read without popping so that the caller's batch dict is left untouched.
        scores = inputs["scores"]  # shape=(batch_size, n_candidates)
        kwargs = {"output_hidden_states": True, "return_dict": True}

        # First forward: prompt + response
        outputs = model(input_ids=inputs["rm_input_ids"], attention_mask=inputs["rm_attention_mask"], **kwargs)
        rm_logits = outputs.logits
        # Last layer, last position of every sequence
        rm_hidden_state = outputs.hidden_states[-1][:, -1, :]

        # Second forward: prompt + model_id
        outputs = model(input_ids=inputs["routing_input_ids"], attention_mask=inputs["routing_attention_mask"], **kwargs)
        routing_logits = outputs.logits
        routing_hidden_state = outputs.hidden_states[-1][:, -1, :]

        # Compute loss
        # The rm hidden state is detached: it acts as a fixed target for the routing pass.
        distill_loss = F.mse_loss(routing_hidden_state, rm_hidden_state.detach(), reduction="mean")
        if self.args.loss_type == "MSE":
            # Flatten both sides to the same 1-D shape. Comparing (batch*n, 1) logits with a
            # (batch*n,) target would silently broadcast to a (batch*n, batch*n) matrix and
            # average the error over every logit/score pair instead of the matching ones.
            flat_scores = scores.view(-1)
            reward_loss = F.mse_loss(rm_logits.reshape(-1), flat_scores, reduction="mean")
            routing_loss = F.mse_loss(routing_logits.reshape(-1), flat_scores, reduction="mean")
        elif self.args.loss_type == "CE":
            # Classify the best-scoring candidate of each example
            target_indices = scores.argmax(1)
            reward_loss = F.cross_entropy(rm_logits.view(-1, self.n_candidates), target_indices, reduction="mean")
            routing_loss = F.cross_entropy(routing_logits.view(-1, self.n_candidates), target_indices, reduction="mean")
        else:
            log_score_distribution = F.log_softmax(scores / self.args.temperature, dim=1)
            log_reward_distribution = F.log_softmax(rm_logits.view(-1, self.n_candidates), dim=1)
            log_routing_distribution = F.log_softmax(routing_logits.view(-1, self.n_candidates), dim=1)
            # F.kl_div(input, target) computes KL(target || input).
            if self.args.loss_type == "ForwardKL":
                # KL(score distribution || predicted distribution)
                reward_loss = F.kl_div(
                    log_reward_distribution, log_score_distribution, reduction="batchmean", log_target=True
                )
                routing_loss = F.kl_div(
                    log_routing_distribution, log_score_distribution, reduction="batchmean", log_target=True
                )
            elif self.args.loss_type == "ReverseKL":
                # KL(predicted distribution || score distribution)
                reward_loss = F.kl_div(
                    log_score_distribution, log_reward_distribution, reduction="batchmean", log_target=True
                )
                routing_loss = F.kl_div(
                    log_score_distribution, log_routing_distribution, reduction="batchmean", log_target=True
                )
            else:
                raise ValueError(f"Unsupported loss type: {self.args.loss_type}")
        # Gradients are only tracked during training (evaluation runs under `torch.no_grad()`).
        is_training = reward_loss.requires_grad
        prefix = "" if is_training else "eval_"
        metrics = {
            f"{prefix}routing_loss": routing_loss.detach().cpu(),
            f"{prefix}reward_loss": reward_loss.detach().cpu(),
            f"{prefix}distill_loss": distill_loss.detach().cpu(),
        }
        self.store_metrics(metrics, train_eval="train" if is_training else "eval")
        # The routing loss receives whatever weight is left by the two configured weights.
        loss = (
            1 - self.args.reward_loss_weight - self.args.distill_loss_weight
        ) * routing_loss + self.args.reward_loss_weight * reward_loss + self.args.distill_loss_weight * distill_loss
        return (loss, routing_logits.view(-1, self.n_candidates)) if return_outputs else loss

    @override
    def predict(
        self, test_dataset: Dataset, ignore_keys: list[str] | None = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """Tokenize `test_dataset` like the training data, then run the base `predict`.

        Args:
            test_dataset: Raw dataset with `prompt` and `responses` columns.
            ignore_keys: Forwarded to the base `predict`.
            metric_key_prefix: Forwarded to the base `predict`.

        Returns:
            The output of the base `predict`.
        """
        with self.args.main_process_first(desc="preprocess dataset"):
            test_dataset = test_dataset.map(self.tokenize_row, **self.preprocess_args)
        torch.cuda.empty_cache()
        outputs = super().predict(test_dataset, ignore_keys, metric_key_prefix)
        return outputs

    @override
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Run one evaluation step and return the loss, routing logits and target scores.

        Args:
            model: Model to evaluate.
            inputs: Collated batch; must contain `scores`.
            prediction_loss_only: If True, only the loss is returned.
            ignore_keys: Unused; accepted for compatibility with the base signature.

        Returns:
            `(loss, None, None)` when `prediction_loss_only` is True, otherwise
            `(loss, rewards, labels)` where `rewards` are the routing logits of shape
            `(-1, n_candidates)` and `labels` are the target scores.
        """
        inputs = self._prepare_inputs(inputs)
        labels = nested_detach(inputs["scores"])

        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss, rewards, *_ = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()

        if prediction_loss_only:
            return (loss, None, None)

        rewards = nested_detach(rewards.view(-1, self.n_candidates))
        return (loss, rewards, labels)  # type: ignore

    @override
    def _set_signature_columns_if_needed(self):
        """Declare the dataset columns consumed by the collator.

        The collator, not the model's forward signature, determines which columns are
        needed (presumably this keeps them from being dropped as unused — verify against
        the base `Trainer`).
        """
        if self._signature_columns is None:
            self._signature_columns = ["prompt_input_ids", "response_input_ids", "scores"]

    @staticmethod
    def tokenize_row(
        features: dict[str, Any],
        tokenizer: PreTrainedTokenizerBase,
        max_prompt_length: int | None = None,
        max_completion_length: int | None = None,
    ):
        """Tokenize the prompt and every response of a single dataset row.

        Args:
            features: Row with a `prompt` string and a `responses` list of strings.
            tokenizer: Tokenizer used for encoding; no special tokens are added.
            max_prompt_length: If not None, only the last `max_prompt_length` prompt
                tokens are kept. Note that 0 leaves the prompt untruncated (`[-0:]`).
            max_completion_length: If not None, only the first `max_completion_length`
                tokens of each response are kept.

        Returns:
            A dictionary with `prompt_input_ids` (list of token IDs) and
            `response_input_ids` (one list of token IDs per response).
        """
        prompt_input_ids: list[int] = tokenizer.encode(features["prompt"], add_special_tokens=False)  # type: ignore
        response_input_ids: list[list[int]] = [
            tokenizer.encode(response, add_special_tokens=False) for response in features["responses"]
        ]  # type: ignore

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            # Keep the end of the prompt (truncate from the left)
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        if max_completion_length is not None:
            # Keep the beginning of each response (truncate from the right)
            response_input_ids = [input_ids[:max_completion_length] for input_ids in response_input_ids]

        return {
            "prompt_input_ids": prompt_input_ids,
            "response_input_ids": response_input_ids,
        }
