from __future__ import annotations

from functools import partial
from typing import Callable, override

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn import init
from transformers import PretrainedConfig, PreTrainedModel

from mow.modules.gcn import GCN
from mow.modules.mlp import MLP

# Name of a marker attribute that GraphRouter sets (to True) on every submodule
# of its per-layer embedding projection. GraphRouter._init_weights checks for it
# to give those Linear layers a tiny Xavier gain (1e-5) instead of 1.0.
_IS_EMBEDDING_PROJECTION = "_is_embedding_projection"


class GraphRouterConfig(PretrainedConfig):
    """Configuration for ``GraphRouter``.

    All arguments are keyword-only and stored verbatim as attributes:

    * ``hidden_size``, ``context_size``, ``embed_dim``, ``output_dim`` --
      sizes used to build the router; each may be ``None``.
    * ``use_mlp`` -- whether the router ends with an MLP head.
    * ``num_layers`` -- number of graph layers.
    * ``aggregate_layers`` -- indices of the layers that receive the real
      adjacency matrix; ``None`` means every layer, i.e.
      ``list(range(num_layers))``.
    * ``dropout`` -- dropout value handed to the graph layers.
    * ``use_embedding_projection`` -- whether to build per-layer projections
      that are applied before similarity comparisons.

    Remaining keyword arguments are passed on to ``PretrainedConfig``.
    """

    model_type = "graph-router"

    def __init__(
        self,
        *,
        hidden_size: int | None = None,
        context_size: int | None = None,
        embed_dim: int | None = None,
        output_dim: int | None = None,
        use_mlp: bool = True,
        num_layers: int = 4,
        aggregate_layers: list[int] | None = None,
        dropout: float = 0.2,
        use_embedding_projection: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Default: every layer aggregates over the graph. A caller-supplied
        # list is stored as-is (no copy).
        if aggregate_layers is None:
            aggregate_layers = list(range(num_layers))

        # Sizes.
        self.hidden_size = hidden_size
        self.context_size = context_size
        self.embed_dim = embed_dim
        self.output_dim = output_dim

        # Architecture switches.
        self.num_layers = num_layers
        self.aggregate_layers = aggregate_layers
        self.use_mlp = use_mlp
        self.use_embedding_projection = use_embedding_projection

        # Regularisation.
        self.dropout = dropout


class _GraphRouterModule(nn.Module):
    """Stack of relational GCN layers followed by an optional MLP head.

    The ``num_layers`` GCN layers are applied in sequence. A layer whose index
    is listed in ``aggregate_layers`` receives the real adjacency matrix; every
    other layer receives an all-zero adjacency matrix of the same shape
    (presumably so it does not mix information between nodes -- TODO confirm
    against ``GCN``).

    Each layer's output is averaged over the node axis, giving one embedding
    per layer. The head (``MLP`` or identity) is applied to the last layer's
    embedding to produce the logits.
    """

    def __init__(
        self,
        *,
        num_layers: int,
        aggregate_layers: list[int],
        hidden_size: int | None = None,
        context_size: int | None = None,
        embed_dim: int,
        output_dim: int | None = None,
        use_mlp: bool = True,
        dropout: float,
    ):
        """Build the GCN stack and the head.

        Args:
            num_layers: Number of GCN layers.
            aggregate_layers: Indices of layers that get the real adjacency.
            hidden_size: Forwarded to every ``GCN`` layer.
            context_size: Forwarded to every ``GCN`` layer.
            embed_dim: Width of the GCN outputs and of the head's input.
            output_dim: Head output width; defaults to ``embed_dim``. Ignored
                when ``use_mlp`` is False.
            use_mlp: Use an ``MLP`` head, otherwise ``nn.Identity``.
            dropout: Forwarded to every ``GCN`` layer.
        """
        super().__init__()

        self.num_layers = num_layers
        self.aggregate_layers = aggregate_layers
        self.gcn = nn.ModuleList(
            GCN(
                hidden_size=hidden_size,
                context_size=context_size,
                embed_dim=embed_dim,
                dropout=dropout,
                batch_norm=False,
                layer_norm=False,
                relational=True,
            )
            for _ in range(num_layers)
        )
        if use_mlp:
            output_dim = output_dim or embed_dim
            self.mlp = MLP(input_dim=embed_dim, output_dim=output_dim)
        else:
            self.mlp = nn.Identity()

    def forward(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None,
        context: torch.Tensor | None = None,
    ):
        """Run the GCN stack and the head.

        Args:
            hidden_states: Node features fed to the first GCN layer.
            adjacency_matrix: Graph adjacency used by aggregating layers.
            relation_matrix: Passed unchanged to every GCN layer (``GraphRouter``
                may pass ``None``).
            context: Optional context passed to every GCN layer.

        Returns:
            ``(embedding, logits)``: ``embedding`` stacks the node-averaged
            output of every layer along axis -2; ``logits`` is the head applied
            to the last layer's entry of that axis.
        """
        # Built once per call instead of once per iteration.
        aggregate = set(self.aggregate_layers)
        # The zero adjacency is loop-invariant: build it at most once, and only
        # if some layer actually needs it.
        zero_adjacency: torch.Tensor | None = None

        embeddings: list[torch.Tensor] = []
        h = hidden_states

        for i, layer in enumerate(self.gcn):
            if i in aggregate:
                adjacency = adjacency_matrix
            else:
                if zero_adjacency is None:
                    zero_adjacency = torch.zeros_like(adjacency_matrix)
                adjacency = zero_adjacency

            # (batch_size, num_nodes, embed_dim)
            h = layer(h, adjacency, relation_matrix, context=context)
            embeddings.append(h)

        # (batch_size, num_layers, num_nodes, embed_dim)
        embedding = torch.stack(embeddings, dim=-3)

        # (batch_size, num_layers, embed_dim)
        embedding = torch.mean(embedding, dim=-2)

        # (batch_size, num_adapters)
        logits = self.mlp(embedding[..., -1, :])

        return embedding, logits


class GraphRouter(PreTrainedModel):
    """Hugging Face model wrapper around a graph router module.

    On top of the router module (built from ``router_module_class``) it adds:

    * an optional loss in ``forward`` (``compute_loss`` or cross-entropy),
    * ``embedding_set``: named, non-trainable reference embeddings,
    * ``get_similarities``: softmax-normalised similarity of a new input's
      embedding to the stored ones, optionally through a learned per-layer
      ``embedding_projection``.
    """

    config_class = GraphRouterConfig
    base_model_prefix = "graph_router"
    # Router module class instantiated in ``__init__``; subclasses may override.
    router_module_class = _GraphRouterModule

    def __init__(
        self,
        config: GraphRouterConfig,
        *,
        compute_loss: (
            Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | None
        ) = None,
    ):
        """Build the router from ``config``.

        Args:
            config: Model configuration. ``embed_dim`` falls back to 128 when
                it is unset.
            compute_loss: Optional ``(logits, labels) -> loss`` callable used by
                ``forward``; ``F.cross_entropy`` is used when it is ``None``.
        """
        super().__init__(config)

        # Effective embedding width (the config may leave it as None).
        embed_dim = config.embed_dim or 128

        # Go through the class attribute so that subclasses overriding
        # ``router_module_class`` actually get their module.
        self.module = self.router_module_class(
            num_layers=config.num_layers,
            aggregate_layers=config.aggregate_layers,
            hidden_size=config.hidden_size,
            context_size=config.context_size,
            embed_dim=embed_dim,
            output_dim=config.output_dim,
            use_mlp=config.use_mlp,
            dropout=config.dropout,
        )
        # Reference embeddings keyed by name. Filled by ``update_embedding_set``
        # or restored by ``_load_from_state_dict``; never trained.
        self.embedding_set = nn.ParameterDict()
        if config.use_embedding_projection:
            # One projection per layer; ``__proj`` applies them layer-wise.
            self.embedding_projection = nn.ModuleList(
                MLP(
                    input_dim=embed_dim,
                    output_dim=embed_dim,
                    use_bias=True,
                )
                for _ in range(config.num_layers)
            )
            # Tag every submodule so ``_init_weights`` uses a tiny gain.
            for module in self.embedding_projection.modules():
                setattr(module, _IS_EMBEDDING_PROJECTION, True)
        else:
            self.embedding_projection = None

        self.compute_loss = compute_loss

        self.post_init()

        self.hidden_size = config.hidden_size
        self.context_size = config.context_size
        self.embed_dim = config.embed_dim
        self.output_dim = config.output_dim
        self.use_mlp = config.use_mlp
        self.num_layers = config.num_layers
        self.aggregate_layers = config.aggregate_layers
        self.dropout = config.dropout
        self.use_embedding_projection = config.use_embedding_projection

    def _init_weights(self, module: nn.Module):
        """Xavier-initialise ``nn.Linear`` weights and zero their biases.

        Linear layers tagged with ``_IS_EMBEDDING_PROJECTION`` use a gain of
        ``1e-5`` (weights start near zero); all others use a gain of ``1.0``.
        """
        if isinstance(module, nn.Linear):
            if getattr(module, _IS_EMBEDDING_PROJECTION, False):
                gain = 1e-5
            else:
                gain = 1.0
            init.xavier_uniform_(module.weight, gain)
            if module.bias is not None:
                init.zeros_(module.bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
        labels: torch.LongTensor | None = None,
    ):
        """Run the router and optionally compute a loss.

        Returns:
            A dict with ``"loss"`` (``None`` unless ``labels`` is given),
            ``"logits"`` and ``"embedding"`` as produced by the router module.
        """
        embedding, logits = self.module(
            hidden_states=hidden_states,
            adjacency_matrix=adjacency_matrix,
            relation_matrix=relation_matrix,
            context=context,
        )

        loss = None
        if labels is not None:
            loss = (self.compute_loss or F.cross_entropy)(logits, labels)

        return {
            "loss": loss,
            "logits": logits,
            "embedding": embedding,
        }

    def get_embedding(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ):
        """Return only the embedding (first output) of the router module."""
        return self.module(
            hidden_states=hidden_states,
            adjacency_matrix=adjacency_matrix,
            relation_matrix=relation_matrix,
            context=context,
        )[0]

    def update_embedding_set(
        self,
        name: str,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
    ):
        """Store (or overwrite) the reference embedding ``name``.

        The input's embedding is averaged over its leading axes until it is
        2-D (presumably ``(num_layers, embed_dim)`` -- confirm against the
        router module), then stored as a non-trainable parameter. An existing
        entry is overwritten in place.
        """
        # The stored value is never differentiated through (it becomes a
        # requires_grad=False parameter / is copied into ``.data``), so skip
        # building an autograd graph.
        with torch.no_grad():
            embedding = self.get_embedding(
                hidden_states=hidden_states,
                adjacency_matrix=adjacency_matrix,
                relation_matrix=relation_matrix,
                context=context,
            )

            while embedding.ndim > 2:
                embedding = torch.mean(embedding, dim=0)

        if name not in self.embedding_set:
            self.embedding_set[name] = nn.Parameter(
                embedding, requires_grad=False
            )
        else:
            self.embedding_set[name].data.copy_(embedding)

    def get_similarities(
        self,
        hidden_states: torch.Tensor,
        adjacency_matrix: torch.Tensor,
        relation_matrix: torch.Tensor | None = None,
        context: torch.Tensor | None = None,
        temperature: float = 0.1,
        similarity_fn: Callable[
            [torch.Tensor, torch.Tensor], torch.Tensor
        ] = partial(torch.cosine_similarity, dim=-1),
        keys: list[str] | None = None,
        top_k: int | None = None,
        flatten: bool = False,
    ):
        """Score the input against the stored reference embeddings.

        Args:
            temperature: Softmax temperature; must be positive.
            similarity_fn: ``(query, reference) -> scores``, applied after the
                optional embedding projection.
            keys: Restrict scoring to these stored names. Names that are not
                stored are ignored.
            top_k: If set, only the ``top_k`` highest scores along the key axis
                survive; the rest are set to ``-inf`` before the softmax.
                Values larger than the number of candidates are clamped.
            flatten: With ``top_k``, replace the surviving scores by 1 so the
                softmax is uniform over the selected keys.

        Returns:
            Dict mapping each scored name to its softmax weight tensor; the
            softmax is taken across keys.

        Raises:
            ValueError: ``flatten`` without ``top_k``, non-positive
                ``temperature``, or no stored embedding matches ``keys``.
        """
        if not top_k and flatten:
            raise ValueError("top_k must be specified when flatten is True.")
        if temperature <= 0:
            raise ValueError("temperature must be positive.")

        # Label results with exactly the names they are computed for
        # (stored-set order, skipping requested names that are not stored).
        # Zipping a caller-ordered ``keys`` list would mislabel scores.
        selected = [
            name
            for name in self.embedding_set.keys()
            if keys is None or name in keys
        ]
        if not selected:
            raise ValueError("No stored embeddings match the requested keys.")

        embedding = self.get_embedding(
            hidden_states=hidden_states,
            adjacency_matrix=adjacency_matrix,
            relation_matrix=relation_matrix,
            context=context,
        )
        # The query projection does not depend on the reference: do it once.
        query = self.__proj(embedding)
        similarities = torch.stack(
            [
                similarity_fn(query, self.__proj(self.embedding_set[name]))
                for name in selected
            ],
            dim=0,
        )
        if top_k:
            k = min(top_k, similarities.size(0))
            top_k_vals, top_k_idx = torch.topk(similarities, k, dim=0)
            if flatten:
                top_k_vals = torch.ones_like(top_k_vals)
            similarities = torch.full_like(similarities, float("-inf"))
            similarities.scatter_(0, top_k_idx, top_k_vals)
        similarities = torch.softmax(similarities / temperature, dim=0)
        return dict(zip(selected, similarities))

    def __proj(self, embedding: torch.Tensor):
        """Apply the per-layer embedding projection, if enabled.

        The i-th projection is applied to ``embedding[..., i, :]`` and the
        results are re-stacked along axis -2. Returns the input unchanged when
        projection is disabled.
        """
        if not self.use_embedding_projection:
            return embedding
        if self.embedding_projection is None:
            raise ValueError("embedding_projection is not set, but called.")
        return torch.stack(
            [
                module(embedding[..., i, :])
                for i, module in enumerate(self.embedding_projection)
            ],
            dim=-2,
        )

    @override
    def _load_from_state_dict(
        self,
        state_dict,
        prefix: str,
        local_metadata,
        strict: bool,
        missing_keys: list[str],
        unexpected_keys: list[str],
        error_msgs: list[str],
    ):
        """Register ``embedding_set`` entries from ``state_dict`` before loading.

        ``embedding_set`` is filled at runtime, so a freshly built model has no
        parameters for the saved names; they are created here so the regular
        loading can copy values into them. New entries are placed on the device
        of the router module's parameters (unless those live on the ``meta``
        device). Otherwise, loading a CPU state dict into a model already moved
        to an accelerator would leave these entries on the CPU.
        """
        set_prefix = prefix + "embedding_set."
        reference = next(self.module.parameters(), None)
        target_device = (
            None
            if reference is None or reference.is_meta
            else reference.device
        )

        to_register = [k for k in state_dict if k.startswith(set_prefix)]
        for full_key in to_register:
            param_tensor = state_dict[full_key]
            sub_name = full_key[len(set_prefix) :]
            if target_device is not None:
                # No-op (same tensor) when already on the target device.
                param_tensor = param_tensor.to(target_device)

            self.embedding_set[sub_name] = nn.Parameter(
                param_tensor, requires_grad=False
            )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )
