import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from typing import Any, Literal
from dataclasses import dataclass, asdict
from transformers import Trainer, PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_utils import PredictionOutput
from transformers.trainer_pt_utils import nested_detach
from datasets import Dataset
from typing_extensions import override
from trl.trainer.utils import pad

from args import ReconstructArguments
from utils import plot_trainer_state, get_logger
from model.reconstruct import load_model_with_hidden_variable_and_reward_head

# Module-level logger obtained from the project's `get_logger` helper, named after this module.
logger = get_logger(__name__)


@dataclass
class ReconstructCollator(DataCollatorMixin):
    """Collate tokenized prompt / candidate-response examples into left-padded tensors.

    Each example is expected to provide ``prompt_input_ids`` (a list of token ids),
    ``response_input_ids`` (one token-id list per candidate) and ``scores`` (one value per
    candidate). Every (prompt, candidate) pair becomes one row of ``sequence_input_ids``:
    ``prompt + [PAD] + response``, truncated to ``max_length`` and then followed by ``eos``.
    A sequence row can therefore be ``max_length + 1`` tokens long.

    Attributes:
        pad_token_id: Token used for left padding and as the hidden-variable placeholder.
        eos_token_id: Token appended to every prompt row and every sequence row.
        max_length: Truncation length applied before ``eos`` is appended; ``None`` or ``0``
            disables truncation.
        use_logits_to_keep: If true, drop leading label columns that are ignored (-100) in
            every row and return the remaining width as ``logits_to_keep``.
        return_tensors: Framework selector consumed by ``DataCollatorMixin``; only
            ``torch_call`` is implemented here.
    """

    pad_token_id: int
    eos_token_id: int
    max_length: int | None = None
    use_logits_to_keep: bool = True
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        """Build a padded batch from ``examples``.

        Returns:
            A dict with:
            - ``prompt_input_ids`` / ``prompt_attention_mask``: prompt + eos, shape
              ``(batch_size, prompt_len)``.
            - ``sequence_input_ids`` / ``sequence_attention_mask``: shape
              ``(batch_size * n_candidates, seq_len)``.
            - ``labels``: response tokens, -100 for prompt, placeholder and padding. Labels
              mirror the sequence row *without* its trailing eos. Both tensors are
              left-padded, so they stay right-aligned.
            - ``scores``: shape ``(batch_size, n_candidates)``.
            - ``hidden_variable_positions``: index of the ``[PAD]`` placeholder in each
              sequence row, shape ``(batch_size * n_candidates,)``. Only valid when
              ``len(prompt_input_ids) + 1 <= max_length``.
            - ``logits_to_keep`` (only if ``use_logits_to_keep``): width of the trimmed labels.
        """
        # Concatenate prompt, model ID and response
        max_length = self.max_length or sys.maxsize
        prompt_input_ids = [torch.tensor(example["prompt_input_ids"] + [self.eos_token_id]) for example in examples]
        prompt_attention_mask = [torch.ones_like(ids) for ids in prompt_input_ids]
        # [PAD] is a placeholder that the model replaces with the hidden variable in its forward pass.
        # Truncation must not reach it: make sure len(example["prompt_input_ids"]) + 1 <= max_length.
        sequence_input_ids = [
            torch.tensor(
                (example["prompt_input_ids"] + [self.pad_token_id] + response_input_ids)[:max_length]
                + [self.eos_token_id]
            )
            for example in examples
            for response_input_ids in example["response_input_ids"]
        ]
        sequence_attention_mask = [torch.ones_like(ids) for ids in sequence_input_ids]
        labels = [
            torch.tensor(([-100] * (len(example["prompt_input_ids"]) + 1) + response_input_ids)[:max_length])
            for example in examples
            for response_input_ids in example["response_input_ids"]
        ]

        # Pad
        output = {
            "prompt_input_ids": pad(prompt_input_ids, self.pad_token_id, "left"),  # shape=(batch_size, seq_length)
            "prompt_attention_mask": pad(prompt_attention_mask, 0, "left"),  # shape=(batch_size, seq_length)
            "sequence_input_ids": pad(
                sequence_input_ids, self.pad_token_id, "left"
            ),  # shape=(batch_size*n_candidates, seq_length)
            "sequence_attention_mask": pad(
                sequence_attention_mask, 0, "left"
            ),  # shape=(batch_size*n_candidates, seq_length)
            "labels": pad(labels, -100, "left"),  # shape=(batch_size*n_candidates, seq_length)
            "scores": torch.tensor([example["scores"] for example in examples]),  # shape=(batch_size, n_candidates)
        }
        # Each label row is exactly one token shorter than its sequence row, so both receive the same
        # amount of left padding. Counting -100 (padding + prompt + placeholder) minus one therefore
        # gives the placeholder's index in sequence_input_ids.
        output["hidden_variable_positions"] = (output["labels"] == -100).sum(
            dim=1
        ) - 1  # shape=(batch_size*n_candidates,)

        if self.use_logits_to_keep:
            # Rows are right-aligned, so the earliest real label of any row sits at column
            # `n_columns - max_n_labels`; every column before it is -100 in all rows.
            # Keep one extra (all -100) leading column in case the loss shifts labels by one
            # position (standard causal-LM convention). NOTE(review): confirm against the model's loss.
            labels_tensor = output["labels"]
            n_columns = labels_tensor.shape[1]
            max_n_labels = int((labels_tensor != -100).sum(dim=1).max())
            start_idx = max(n_columns - max_n_labels - 1, 0)
            output["labels"] = labels_tensor[:, start_idx:]
            output["logits_to_keep"] = output["labels"].shape[1]
        return output


class ReconstructRouterTrainer(Trainer):
    """`Trainer` that learns a reward per candidate model, optionally with a reconstruction loss.

    Each dataset row provides a ``prompt``, one ``responses`` entry per model in ``model_ids``
    and one ``scores`` entry per candidate. The model is called with the collated batch and
    returns ``(reconstruct_loss, rewards)``. The reward loss compares ``rewards`` with ``scores``
    according to ``args.loss_type``. It is blended with the reconstruction loss via
    ``args.reconstruct_loss_weight``.
    """

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: ReconstructArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        """Tokenize the datasets, build the default collator and initialise `Trainer`.

        Args:
            model: A loaded model, or a name/path passed to
                ``load_model_with_hidden_variable_and_reward_head``.
            processing_class: Tokenizer used for ``tokenize_row`` and for the pad/eos token ids.
            model_ids: Candidate model identifiers; only their count is used here.
            args: Training arguments (also mutated: ``label_names`` is set to ``["scores"]``).
            data_collator: Defaults to ``ReconstructCollator``.
            train_dataset / eval_dataset: Raw datasets with ``prompt`` and ``responses`` columns
                (mapped through ``tokenize_row``) plus a ``scores`` column.
            **kwargs: Forwarded to ``Trainer.__init__``.
        """
        self.args = args
        self.n_candidates = len(model_ids)
        if isinstance(model, str):
            model = load_model_with_hidden_variable_and_reward_head(
                model,
                self.n_candidates,
                use_fixed_liger=args.use_liger_kernel,
                pad_token_id=processing_class.pad_token_id,
                # Tolerate an unset (None) kwargs dict instead of failing on `**None`.
                **(args.model_init_kwargs or {}),
            )
        self.preprocess_args = {
            "fn_kwargs": {
                "tokenizer": processing_class,
                "max_prompt_length": args.max_prompt_length,
                "max_completion_length": args.max_completion_length,
            },
            "batched": False,
            "num_proc": args.dataset_num_proc,
        }
        with args.main_process_first(desc="preprocess dataset"):
            if train_dataset is not None:
                train_dataset = train_dataset.map(self.tokenize_row, **self.preprocess_args)
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(self.tokenize_row, **self.preprocess_args)
        if data_collator is None:
            data_collator = ReconstructCollator(
                pad_token_id=processing_class.pad_token_id,
                eos_token_id=processing_class.eos_token_id,
                max_length=args.max_length,
                use_logits_to_keep=args.use_logits_to_keep,
            )
        # `scores` is what prediction_step returns as labels.
        args.label_names = ["scores"]
        # {"train" | "eval": {metric_name: [values...]}}, flushed on every `log` call.
        self._stored_metrics = defaultdict(lambda: defaultdict(list))
        # Keyword arguments so the call does not depend on Trainer's positional parameter order.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            **kwargs,
        )

    def plot(self):
        """Plots the training state (loss, accuracy and reward curves) into ``args.output_dir``."""
        plot_trainer_state(asdict(self.state), ["loss", "reward_loss", "reconstruct_loss"], self.args.output_dir, "loss")
        plot_trainer_state(asdict(self.state), ["accuracy"], self.args.output_dir, "accuracy")
        plot_trainer_state(asdict(self.state), ["reward"], self.args.output_dir, "reward")

    def store_metrics(self, metrics: dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
        """Buffer per-step metric values; they are averaged and emitted by the next `log` call."""
        for key, value in metrics.items():
            self._stored_metrics[train_eval][key].append(value)

    @override
    def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
        """Log `logs` on the various objects watching training, including stored metrics.

        Args:
            logs (`dict[str, float]`):
                The values to log.
            start_time (`float` or `None`, *optional*, defaults to `None`):
                Start time of the training.
        """
        # Training logs carry 'loss'; everything else is treated as an evaluation log.
        train_eval = "train" if "loss" in logs else "eval"
        # Add averaged stored metrics to logs
        for key, metrics in self._stored_metrics[train_eval].items():
            logs[key] = torch.tensor(metrics).mean().item()
        del self._stored_metrics[train_eval]
        super().log(logs, start_time)

    @override
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the blended reward / reconstruction loss.

        Pops ``scores`` from ``inputs`` (and ``labels`` when ``reconstruct_loss_weight == 0``),
        so the caller's dict is mutated.

        Returns:
            ``loss``, or ``(loss, rewards)`` when ``return_outputs`` is true. For the "CE" loss the
            returned ``rewards`` are reshaped to ``(-1, n_candidates)``.

        Raises:
            ValueError: If ``args.loss_type`` is not "MSE", "CE", "ForwardKL" or "ReverseKL".
        """
        scores = inputs.pop("scores")  # shape=(batch_size, n_candidates)
        if self.args.reconstruct_loss_weight == 0:
            inputs.pop("labels")
        reconstruct_loss, rewards = model(**inputs, num_items_in_batch=num_items_in_batch)
        loss_type = self.args.loss_type
        if loss_type == "MSE":
            # Flatten rewards like the other branches do: a (N, 1) tensor against a (N,) target would
            # broadcast to (N, N). Compute in float32 so a half-precision reward head and float32/integer
            # scores cannot hit a dtype mismatch.
            reward_loss = F.mse_loss(rewards.view(-1).float(), scores.view(-1).float(), reduction="mean")
        elif loss_type == "CE":
            # Target is the best-scoring candidate of each prompt.
            target_indices = scores.argmax(1)
            rewards = rewards.view(-1, self.n_candidates)
            reward_loss = F.cross_entropy(rewards, target_indices, reduction="mean")
        elif loss_type in ("ForwardKL", "ReverseKL"):
            log_score_distribution = F.log_softmax(scores / self.args.temperature, dim=1)
            log_predicted_distribution = F.log_softmax(rewards.view(-1, self.n_candidates), dim=1)
            if loss_type == "ForwardKL":
                # F.kl_div(input, target) = KL(target || input): here KL(score || predicted).
                reward_loss = F.kl_div(
                    log_predicted_distribution, log_score_distribution, reduction="batchmean", log_target=True
                )
            else:
                # KL(predicted || score).
                reward_loss = F.kl_div(
                    log_score_distribution, log_predicted_distribution, reduction="batchmean", log_target=True
                )
        else:
            raise ValueError(f"Unsupported loss type: {self.args.loss_type}")
        is_training = reward_loss.requires_grad
        prefix = "" if is_training else "eval_"
        metrics = {f"{prefix}reward_loss": reward_loss.detach().cpu()}
        if is_training and self.model_accepts_loss_kwargs:
            # The per-batch mean reward loss is scaled down by the accumulation steps here.
            # Assumption: Trainer skips its own division when the model accepts loss kwargs; confirm for
            # the installed transformers version. Out-of-place on purpose: on CPU, `.detach().cpu()`
            # above shares storage with reward_loss, so `/=` would also rescale the stored metric.
            reward_loss = reward_loss / self.args.gradient_accumulation_steps
        if self.args.reconstruct_loss_weight != 0:
            metrics[f"{prefix}reconstruct_loss"] = (
                reconstruct_loss.detach().cpu() * self.args.gradient_accumulation_steps
                if is_training and self.model_accepts_loss_kwargs
                else reconstruct_loss.detach().cpu()
            )
            loss = (
                1 - self.args.reconstruct_loss_weight
            ) * reward_loss + self.args.reconstruct_loss_weight * reconstruct_loss
        else:
            loss = reward_loss
        self.store_metrics(metrics, train_eval="train" if is_training else "eval")
        return (loss, rewards) if return_outputs else loss

    @override
    def predict(
        self, test_dataset: Dataset, ignore_keys: list[str] | None = None, metric_key_prefix: str = "test"
    ) -> PredictionOutput:
        """Tokenize a raw ``test_dataset`` with ``tokenize_row`` and run `Trainer.predict` on it."""
        with self.args.main_process_first(desc="preprocess dataset"):
            test_dataset = test_dataset.map(self.tokenize_row, **self.preprocess_args)
        torch.cuda.empty_cache()
        outputs = super().predict(test_dataset, ignore_keys, metric_key_prefix)
        return outputs

    @override
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, torch.Tensor | Any],
        prediction_loss_only: bool,
        ignore_keys: list[str] | None = None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Return ``(loss, rewards, scores)`` for one evaluation batch.

        ``rewards`` is reshaped to ``(-1, n_candidates)``. ``ignore_keys`` is accepted for
        interface compatibility but has no effect, because the model output is not a dict.
        """
        inputs = self._prepare_inputs(inputs)
        # Read before compute_loss pops "scores" from `inputs`.
        labels = nested_detach(inputs["scores"])

        with torch.no_grad():
            with self.compute_loss_context_manager():
                loss, rewards = self.compute_loss(model, inputs, return_outputs=True)
            loss = loss.mean().detach()

        if prediction_loss_only:
            return (loss, None, None)

        rewards = nested_detach(rewards.view(-1, self.n_candidates))
        return (loss, rewards, labels)  # type: ignore

    @override
    def _set_signature_columns_if_needed(self):
        """Keep only the tokenized columns and ``scores`` when unused columns are removed."""
        if self._signature_columns is None:
            self._signature_columns = ["prompt_input_ids", "response_input_ids", "scores"]

    @staticmethod
    def tokenize_row(
        features: dict[str, Any],
        tokenizer: PreTrainedTokenizerBase,
        max_prompt_length: int | None = None,
        max_completion_length: int | None = None,
    ):
        """Tokenize one row's ``prompt`` and each of its ``responses`` without special tokens.

        The prompt is left-truncated (its last ``max_prompt_length`` tokens are kept). Each
        response is right-truncated to ``max_completion_length`` tokens. ``None`` disables
        the corresponding truncation.
        """
        prompt_input_ids: list[int] = tokenizer.encode(features["prompt"], add_special_tokens=False)
        response_input_ids: list[list[int]] = [
            tokenizer.encode(response, add_special_tokens=False) for response in features["responses"]
        ]

        # Truncate prompt and completion sequences
        if max_prompt_length is not None:
            # Explicit start index: `ids[-max_prompt_length:]` would keep everything when the limit is 0.
            prompt_input_ids = prompt_input_ids[max(len(prompt_input_ids) - max_prompt_length, 0) :]
        if max_completion_length is not None:
            response_input_ids = [input_ids[:max_completion_length] for input_ids in response_input_ids]

        return {
            "prompt_input_ids": prompt_input_ids,
            "response_input_ids": response_input_ids,
        }
