from dataclasses import asdict
import random
import torch
from typing import Any
from torch import Tensor
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from datasets import Dataset
from typing_extensions import override

from args import ContrastArguments
from utils import get_logger
from model.with_llm_embeddings import load_router_with_llm_embeddings
from utils.plotting import plot_trainer_state
from .base import BaseRouterTrainer

# Module-level logger obtained from the project's ``utils.get_logger``, keyed by this module's ``__name__``.
logger = get_logger(__name__)


class ContrastRouterTrainer(BaseRouterTrainer):
    """Router trainer optimizing two contrastive objectives.

    * sample-LLM loss (``_compute_sample_llm_loss``): for each sample, the negative log-softmax of
      the similarity to each of its best-scoring LLMs against the similarities to its
      worst-scoring LLMs (computed from ``routing_logits`` and ``scores``);
    * sample-sample loss (``_compute_sample_sample_loss``): for each sample, the negative
      log-softmax of the similarity to a same-cluster sample against similarities to samples of
      other clusters (computed from ``query_embedding`` and ``cluster_id``), weighted by
      ``args.sample_loss_weight``. Only added when the batch carries ``cluster_id``.
    """

    def __init__(
        self,
        model: PreTrainedModel | str,
        processing_class: PreTrainedTokenizerBase,
        model_ids: list[str],
        args: ContrastArguments,
        data_collator=None,
        train_dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        **kwargs,
    ):
        """Build the trainer.

        Args:
            model: A model instance or a name/path passed to ``_init_model``.
            processing_class: Tokenizer used by ``tokenize_row`` during preprocessing.
            model_ids: Identifiers of the candidate LLMs; ``len(model_ids)`` is the number of targets.
            args: Training arguments.
            data_collator: Optional collator forwarded to the base trainer.
            train_dataset: Optional training dataset.
            eval_dataset: Optional evaluation dataset; shuffled with ``args.seed`` before use.
            **kwargs: Forwarded to the base trainer.
        """
        # Assigned before ``super().__init__`` — presumably because the base constructor invokes
        # hooks such as ``_init_model`` that read ``self.args``; TODO confirm against BaseRouterTrainer.
        self.args = args
        # Keyword arguments for the per-row preprocessing map (see ``tokenize_row``).
        preprocess_args = {
            "fn_kwargs": {
                "tokenizer": processing_class,
                "max_length": args.max_length,
            },
            "batched": False,
            "num_proc": args.dataset_num_proc,
        }
        super().__init__(
            model,
            processing_class,
            model_ids,
            args,
            data_collator,
            train_dataset,
            eval_dataset.shuffle(seed=args.seed) if eval_dataset else None,
            preprocess_args,
            **kwargs,
        )

    @override
    def plot(self):
        """Plot the logged training metrics, one ``plot_trainer_state`` call per metric.

        Raises:
            ValueError: If ``args.output_dir`` is not set.
        """
        if self.args.output_dir is None:
            raise ValueError("output_dir must be set to plot the training state.")
        # NOTE(review): "reward" is not stored by this class; presumably logged elsewhere — confirm.
        for metric in ("loss", "sample_llm_loss", "sample_sample_loss", "reward", "accuracy"):
            # A fresh ``asdict`` copy per call, in case ``plot_trainer_state`` mutates its input.
            plot_trainer_state(asdict(self.state), [metric], self.args.output_dir, metric)

    @override
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined contrastive loss for one batch and store per-batch metrics.

        ``inputs`` is mutated: ``scores`` and (if present) ``cluster_id`` are popped before the
        remaining entries are passed to ``model``. ``num_items_in_batch`` is accepted but unused.

        Returns:
            ``loss``, or ``(loss, routing_logits)`` when ``return_outputs`` is true.
        """
        scores = inputs.pop("scores")
        cluster_ids = inputs.pop("cluster_id", None)
        outputs = model(**inputs)
        routing_logits = outputs.routing_logits
        # Heuristic: logits only require grad when the forward ran with autograd enabled (training).
        is_training = routing_logits.requires_grad

        # The sample-LLM loss only needs logits and scores, so it is computed for every batch;
        # gating it on ``cluster_id`` would yield a constant, grad-less zero loss otherwise.
        sample_llm_loss = self._compute_sample_llm_loss(routing_logits, scores)
        loss = sample_llm_loss
        metrics = {"sample_llm_loss": sample_llm_loss.detach().cpu()}
        if cluster_ids is not None:
            sample_sample_loss = self._compute_sample_sample_loss(outputs.query_embedding, cluster_ids)
            loss = loss + self.args.sample_loss_weight * sample_sample_loss
            metrics["sample_sample_loss"] = sample_sample_loss.detach().cpu()

        if is_training:
            with torch.no_grad():
                target_indices = scores.argmax(dim=1)
                train_accuracy = (routing_logits.argmax(dim=1) == target_indices).float().mean()
            metrics["accuracy"] = train_accuracy.detach().cpu()
        self.store_metrics(metrics, "train" if is_training else "eval")
        return (loss, routing_logits) if return_outputs else loss

    def _init_model(self, model: str, model_ids: list[str], pad_token_id: int, **kwargs) -> PreTrainedModel:
        """Load the router via ``load_router_with_llm_embeddings``.

        Passes ``len(model_ids)`` as the number of targets and ``args.similarity_function``.
        ``pad_token_id`` is unused here.
        """
        return load_router_with_llm_embeddings(model, len(model_ids), self.args.similarity_function, **kwargs)

    def _compute_sample_llm_loss(self, logits: Tensor, scores: Tensor) -> Tensor:
        """Compute contrastive learning loss between samples and LLMs.

        For each of the ``top_k_llms`` highest-scoring targets of a sample (the positive), the
        negative log-softmax of the positive's similarity among ``[positive, negatives]`` is taken,
        where negatives are the ``last_k_llms`` lowest-scoring targets. Negatives scoring above 0.5
        are excluded, and positives with score <= 0 contribute zero. Per-position losses are
        averaged over the full batch (zeroed rows included) and summed over the top-k positions.

        Args:
            logits: Similarity between sample embeddings and llm embeddings, shape=(batch_size, num_targets).
            scores: Scores of the targets, shape=(batch_size, num_targets).

        Returns:
            A scalar tensor.
        """
        total = logits.new_zeros(())
        positive_scores, positive_indices = scores.topk(self.args.top_k_llms, dim=-1)
        negative_scores, negative_indices = scores.topk(self.args.last_k_llms, dim=-1, largest=False)
        negative_similarity = logits.gather(dim=1, index=negative_indices)  # shape=(batch_size, last_k)
        # Mask out good targets: they must not appear in the denominator as negatives.
        negative_similarity = negative_similarity.masked_fill(negative_scores > 0.5, float("-inf"))
        for i in range(self.args.top_k_llms):
            positive_similarity = logits.gather(dim=1, index=positive_indices[:, i : i + 1])  # shape=(batch_size, 1)
            candidates = torch.cat([positive_similarity, negative_similarity], dim=-1)  # shape=(batch_size, 1+last_k)
            log_prob = torch.log_softmax(candidates, dim=-1)[:, 0]  # shape=(batch_size,)
            # Mask out bad targets (score <= 0); ``where`` avoids NaN from ``-inf * 0``.
            is_valid = positive_scores[:, i] > 0  # shape=(batch_size,)
            log_prob = torch.where(is_valid, log_prob, torch.zeros_like(log_prob))
            total = total + log_prob.mean()
        return -total

    def _compute_sample_sample_loss(self, hidden_state: Tensor, cluster_ids: Tensor) -> Tensor:
        """Compute contrastive learning loss between samples.

        Every sample acts as an anchor. One positive is drawn at random from the samples sharing
        its cluster id (possibly the anchor itself), and ``last_k_samples`` negatives are drawn
        without replacement from samples with a different cluster id. The loss is the mean
        negative log-softmax of the positive's similarity among ``[positive, negatives]``.
        Anchors whose cluster has fewer than ``last_k_samples`` negatives in the batch are skipped.

        Args:
            hidden_state: Hidden state of the samples, shape=(batch_size, hidden_size).
            cluster_ids: Cluster id of the samples, shape=(batch_size,).

        Returns:
            A scalar tensor; a zero tensor (no grad) when every anchor was skipped.
        """
        # Assumed shape=(batch_size, batch_size): row i holds anchor i's similarity to each sample.
        similarity = self.model.compute_similarity(hidden_state, hidden_state)
        num_negatives = self.args.last_k_samples
        anchor_rows: list[int] = []
        all_indices: list[Tensor] = []
        warned_clusters = set()
        for row, cluster_id in enumerate(cluster_ids):
            positive_indices = torch.nonzero(cluster_ids == cluster_id)  # shape=(num_pos, 1)
            positive_index = random.choice(positive_indices)  # shape=(1,)
            negative_indices = torch.nonzero(cluster_ids != cluster_id)  # shape=(num_neg, 1)
            if len(negative_indices) < num_negatives:
                cluster_key = cluster_id.item()
                if cluster_key not in warned_clusters:  # warn once per cluster, not once per row
                    warned_clusters.add(cluster_key)
                    logger.warning(
                        "cluster %s has fewer than %d negative samples in this batch; skipping its samples",
                        cluster_key,
                        num_negatives,
                    )
                continue
            sampled = random.sample(range(len(negative_indices)), num_negatives)
            negative_indices = negative_indices[sampled].view(-1)
            anchor_rows.append(row)
            all_indices.append(torch.cat([positive_index, negative_indices]))
        if not all_indices:
            return similarity.new_zeros(())
        index = torch.stack(all_indices).to(similarity.device)  # shape=(num_anchors, 1+last_k_samples)
        # Gather from each anchor's own row; skipped anchors must not shift later rows upward.
        rows = torch.tensor(anchor_rows, device=similarity.device)
        anchor_similarity = similarity[rows].gather(dim=1, index=index)
        log_prob = torch.log_softmax(anchor_similarity, dim=-1)[:, 0]
        return -log_prob.mean()

    @override
    def _set_signature_columns_if_needed(self):
        """Restrict the columns kept from the dataset to the ones ``compute_loss`` consumes."""
        if self._signature_columns is None:
            self._signature_columns = ["input_ids", "attention_mask", "scores", "cluster_id"]

    @staticmethod
    def tokenize_row(features: dict[str, Any], tokenizer: PreTrainedTokenizerBase, max_length: int | None = None):
        """Tokenize ``features["prompt"]``, truncating to ``max_length`` tokens."""
        return tokenizer(features["prompt"], truncation=True, max_length=max_length)
