import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from typing import Any, Literal
from dataclasses import dataclass, asdict
from transformers import Trainer, PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_utils import PredictionOutput
from transformers.trainer_pt_utils import nested_detach
from datasets import Dataset
from typing_extensions import override
from trl.trainer.utils import pad

from args import TokenPredictionArguments
from utils import plot_trainer_state, get_logger
from model.with_lm_head import load_router_with_lm_head

# Module-level logger obtained from the project's `utils.get_logger` helper.
logger = get_logger(__name__)


@dataclass
class TokenPredictionCollator(DataCollatorMixin):
    """Collate tokenized examples into a padded batch for `TokenPredictionTrainer`.

    Each example must provide ``input_ids`` (prompt token ids),
    ``common_response_token_ids`` (token ids to mark as positive targets) and
    ``scores`` (one score per candidate; all examples must have the same count).

    Attributes:
        pad_token_id: Value used to left-pad ``input_ids``.
        token_space_size: Length of the multi-hot ``tokens`` target vector.
        return_tensors: Framework selector used by `DataCollatorMixin`; ``"pt"``
            routes calls to `torch_call`.
    """

    pad_token_id: int
    token_space_size: int
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        """Build a batch dict of PyTorch tensors from a list of examples.

        Returns:
            dict with ``input_ids`` and ``attention_mask`` (left-padded,
            shape ``(batch_size, seq_length)``), ``tokens`` (float32 multi-hot,
            shape ``(batch_size, token_space_size)``) and ``scores`` (float32,
            shape ``(batch_size, n_candidates)``).
        """
        # Create tensors
        input_ids = [torch.tensor(example["input_ids"]) for example in examples]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        tokens = []
        for example in examples:
            # Multi-hot target: 1.0 at every listed token id, 0.0 elsewhere.
            token_labels = torch.zeros(self.token_space_size, dtype=torch.float32)
            token_labels[example["common_response_token_ids"]] = 1
            tokens.append(token_labels)
        tokens = torch.stack(tokens)
        # Force float32: integer scores would otherwise yield an int64 tensor, which
        # breaks regression losses (e.g. `F.mse_loss` backward) downstream.
        scores = torch.tensor([example["scores"] for example in examples], dtype=torch.float32)

        # Pad (left side: each prompt's final token lands in the last column)
        output = {
            "input_ids": pad(input_ids, self.pad_token_id, "left"),  # shape=(batch_size, seq_length)
            "attention_mask": pad(attention_mask, 0, "left"),  # shape=(batch_size, seq_length)
            "tokens": tokens,  # shape=(batch_size, token_space_size)
            "scores": scores,  # shape=(batch_size, n_candidates)
        }
        return output


class TokenPredictionTrainer(Trainer):
    """Trainer for a router model with an auxiliary token-prediction head.

    The model is called as ``model(**inputs)`` and must return a 3-tuple whose last
    two items are routing logits (one per candidate) and token logits (compared
    against the multi-hot ``tokens`` target). The routing loss is selected by
    ``args.loss_type`` (``"MSE"``, ``"CE"``, ``"ForwardKL"`` or ``"ReverseKL"``).
    When ``args.token_loss_weight`` is non-zero, it is mixed with a binary
    cross-entropy on the token logits.

    Raw datasets must contain ``prompt``, ``responses`` and ``scores`` columns.
    They are converted by `tokenize_row` before use.
    """

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: TokenPredictionArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        """Set up the trainer, loading the model and preprocessing datasets if needed.

        Args:
            model: A model instance, or a name/path passed to
                `load_router_with_lm_head` together with ``args.model_init_kwargs``.
            processing_class: Tokenizer used to encode prompts and responses.
            model_ids: Candidate model identifiers; only their count is used here
                (as the number of routing outputs).
            args: Training arguments; also supply preprocessing and loss settings.
            data_collator: Batch collator; defaults to `TokenPredictionCollator`.
            train_dataset: Raw training dataset, mapped through `tokenize_row`.
            eval_dataset: Raw evaluation dataset, mapped through `tokenize_row`.
            **kwargs: Forwarded to `transformers.Trainer.__init__`.
        """
        self.args = args
        self.n_candidates = len(model_ids)
        if isinstance(model, str):
            model = load_router_with_lm_head(model, self.n_candidates, **args.model_init_kwargs)
        # `Dataset.map` arguments shared by the train/eval splits and by `predict`.
        self.preprocess_args = {
            "fn_kwargs": {
                "tokenizer": processing_class,
                "max_length": args.max_length,
                "good_threshold": args.good_threshold,
            },
            "batched": False,
            "num_proc": args.dataset_num_proc,
        }
        # Let the main process map first so other processes can reuse its result.
        with args.main_process_first(desc="preprocess dataset"):
            if train_dataset is not None:
                train_dataset = train_dataset.map(self.tokenize_row, **self.preprocess_args)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(self.tokenize_row, **self.preprocess_args)
        if data_collator is None:
            data_collator = TokenPredictionCollator(processing_class.pad_token_id, model.vocab_size)
        # Declare which batch keys are labels for the base Trainer.
        args.label_names = ["tokens", "scores"]
        # {"train" | "eval": {metric_name: [per-step values]}}; averaged and flushed in `log`.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, processing_class, **kwargs)

    def plot(self):
        """Plots the training state (loss, accuracy, reward) into `args.output_dir`."""
        plot_trainer_state(asdict(self.state), ["loss", "routing_loss", "token_loss"], self.args.output_dir, "loss")
        plot_trainer_state(asdict(self.state), ["accuracy"], self.args.output_dir, "accuracy")
        plot_trainer_state(asdict(self.state), ["reward"], self.args.output_dir, "reward")

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Append per-step metric values to the buffer for the given split."""
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    @override
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`dict[str, float]`):
                The values to log.
            start_time (`float` or `None`, *optional*, defaults to `None`):
                Start time of the training.
        """
        # logs either has 'loss' or 'eval_loss'
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs, then clear that split's buffer
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        super().log(logs, start_time)

    @override
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the (optionally token-weighted) routing loss for one batch.

        Pops ``scores`` and ``tokens`` from ``inputs`` and forwards the remaining
        entries to the model. Also records per-step metrics for `log`.

        Returns:
            The scalar loss, or ``(loss, routing_logits)`` when ``return_outputs``.

        Raises:
            ValueError: If ``args.loss_type`` is not a supported loss type.
        """
        scores = inputs.pop("scores")  # shape=(batch_size, n_candidates)
        tokens = inputs.pop("tokens")  # shape=(batch_size, token_space_size)
        _, routing_logits, token_logits = model(**inputs)

        # Compute routing loss
        if self.args.loss_type == "MSE":
            # Regress logits onto raw scores in float32: integer scores or a
            # half-precision model would otherwise give `mse_loss` mismatched dtypes.
            routing_loss = F.mse_loss(routing_logits.float(), scores.float(), reduction="mean")
        elif self.args.loss_type == "CE":
            # Treat the best-scoring candidate as the class label.
            target_indices = scores.argmax(dim=1)
            routing_loss = F.cross_entropy(routing_logits, target_indices, reduction="mean")
        else:
            scores = scores / self.args.temperature
            log_score_distribution = F.log_softmax(scores, dim=1)
            log_predicted_distribution = F.log_softmax(routing_logits.view(-1, self.n_candidates), dim=1)
            # F.kl_div(input, target) computes KL(target || input).
            if self.args.loss_type == "ForwardKL":
                # KL(score distribution || predicted distribution)
                routing_loss = F.kl_div(
                    log_predicted_distribution, log_score_distribution, reduction="batchmean", log_target=True
                )
            elif self.args.loss_type == "ReverseKL":
                # KL(predicted distribution || score distribution)
                routing_loss = F.kl_div(
                    log_score_distribution, log_predicted_distribution, reduction="batchmean", log_target=True
                )
            else:
                raise ValueError(f"Unsupported loss type: {self.args.loss_type}")

        # Under `torch.no_grad()` (as in `prediction_step`) no grad is tracked: evaluation.
        is_training = routing_loss.requires_grad
        prefix = "" if is_training else "eval_"
        metrics = {f"{prefix}routing_loss": routing_loss.item()}
        if self.args.token_loss_weight != 0:
            token_loss = F.binary_cross_entropy_with_logits(token_logits, tokens, reduction="mean")
            metrics[f"{prefix}token_loss"] = token_loss.item()
            loss = (1 - self.args.token_loss_weight) * routing_loss + self.args.token_loss_weight * token_loss
        else:
            loss = routing_loss
        self.store_metrics(metrics, train_eval="train" if is_training else "eval")
        return (loss, routing_logits) if return_outputs else loss

    @override
    def predict(
        self, test_dataset: Dataset, ignore_keys: list[str] | None = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """Preprocess a raw test dataset with `tokenize_row`, then run `Trainer.predict`."""
        torch.cuda.empty_cache()
        with self.args.main_process_first(desc="preprocess dataset"):
            test_dataset = test_dataset.map(self.tokenize_row, **self.preprocess_args)
        outputs = super().predict(test_dataset, ignore_keys, metric_key_prefix)
        return outputs

    @override
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Evaluate one batch without gradients.

        Returns:
            ``(loss, routing_logits, scores)``, or ``(loss, None, None)`` when
            ``prediction_loss_only``.
        """
        inputs = self._prepare_inputs(inputs)
        # Read the labels now: `compute_loss` pops "scores" from `inputs`.
        scores = nested_detach(inputs["scores"])
        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss, logits = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()

        if prediction_loss_only:
            return (loss, None, None)
        else:
            return (loss, nested_detach(logits), scores)  # type: ignore

    @override
    def _set_signature_columns_if_needed(self):
        """Keep the columns produced by `tokenize_row` when Trainer filters datasets."""
        if self._signature_columns is None:
            self._signature_columns = ["input_ids", "common_response_token_ids", "scores"]

    @staticmethod
    def tokenize_row(
        features: dict[str, Any],
        tokenizer: PreTrainedTokenizerBase,
        max_length: int | None = None,
        good_threshold: float = 1,
    ):
        """Tokenize one example and find the token ids that mark its good responses.

        A response is "good" when its score is at least ``good_threshold * max(scores)``
        and "bad" when it is below that cutoff. The result is the set of token ids
        that appear in every good response and in no bad response. If no response
        qualifies as good, the result is empty rather than an error. This can happen
        with a threshold above 1 and positive scores, or below 1 and negative scores.

        Args:
            features: Row with ``prompt`` (str), ``responses`` (list of str) and
                ``scores`` (numbers aligned with ``responses``).
            tokenizer: Tokenizer used for both prompt and responses.
            max_length: Truncation length for the prompt.
            good_threshold: Fraction of the best score a response must reach.

        Returns:
            dict with ``input_ids``, ``common_response_token_ids`` and ``scores``.
        """
        input_ids = tokenizer.encode(features["prompt"], truncation=True, max_length=max_length)
        cutoff = good_threshold * max(features["scores"])

        good_response_token_ids: list[set[int]] = []
        bad_token_ids: set[int] = set()
        for response, score in zip(features["responses"], features["scores"]):
            if score >= cutoff:
                good_response_token_ids.append(set(tokenizer.encode(response, add_special_tokens=False)))
            elif score < cutoff:  # NaN scores compare False both ways and join neither group.
                bad_token_ids |= set(tokenizer.encode(response, add_special_tokens=False))

        # `set.intersection()` with no arguments raises TypeError, so guard the empty case.
        common_token_ids = set.intersection(*good_response_token_ids) if good_response_token_ids else set()
        good_token_ids = common_token_ids - bad_token_ids

        return {
            "input_ids": input_ids,
            "common_response_token_ids": list(good_token_ids),
            "scores": features["scores"],
        }
