import torch
import torch.nn as nn
import torch.nn.functional as F
from mammoth import Mammoth
from transformers import PretrainedConfig, PreTrainedModel
from transformers import AutoConfig, AutoModel

# Identifier used as ``model_type`` by the ABMIL Hugging Face config class.
MODEL_TYPE = "abmil"

import torch.nn as nn

def create_mlp(
        in_dim=768,
        hid_dims=(512, 512),
        out_dim=512,
        act=None,
        dropout=0.,
        end_with_fc=True,
        end_with_dropout=False,
        bias=True
    ):
    """
    Build a feed-forward stack as an ``nn.Sequential``.

    For every width in ``hid_dims`` the stack gets ``Linear -> act -> Dropout``; a final
    ``Linear(..., out_dim)`` follows. With an empty ``hid_dims`` the result is a single
    Linear layer (plus the optional trailing activation / dropout).

    Args:
        in_dim: input feature size of the first Linear layer.
        hid_dims: iterable of hidden-layer widths (empty for no hidden layers).
        out_dim: output feature size of the last Linear layer.
        act: activation module inserted after every hidden Linear, and after the final
            Linear when ``end_with_fc`` is False. The same instance is reused at every
            position. ``None`` (default) creates a fresh ``nn.ReLU()`` for this call.
        dropout: probability used by every Dropout layer.
        end_with_fc: if False, append ``act`` after the final Linear.
        end_with_dropout: if True, append a Dropout after everything else.
        bias: whether the hidden Linear layers carry a bias term.

    Returns:
        nn.Sequential: the assembled MLP.
    """
    # A module as a default argument would be one object shared by every MLP built
    # from this function; create a fresh one per call instead.
    if act is None:
        act = nn.ReLU()

    layers = []
    for hid_dim in hid_dims:
        layers.extend([nn.Linear(in_dim, hid_dim, bias=bias), act, nn.Dropout(dropout)])
        in_dim = hid_dim
    # NOTE(review): the output layer ignores ``bias`` and always has one. Kept as-is so
    # checkpoints produced with bias=False keep matching state-dict keys.
    layers.append(nn.Linear(in_dim, out_dim))
    if not end_with_fc:
        layers.append(act)
    if end_with_dropout:
        layers.append(nn.Dropout(dropout))
    return nn.Sequential(*layers)


#
# Attention networks
#
class GlobalAttention(nn.Module):
    """
    Non-gated attention scorer: Linear -> Tanh -> Dropout -> Linear (2 fc layers).

    args:
        L: size of each input feature vector
        D: width of the hidden projection
        dropout: dropout probability applied after the Tanh
        num_classes: number of attention scores produced per instance
    """

    def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
        super().__init__()
        # The layer positions (module.0 ... module.3) become state-dict keys, so the
        # order of this Sequential must stay fixed.
        self.module = nn.Sequential(
            nn.Linear(L, D),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(D, num_classes),
        )

    def forward(self, x):
        scores = self.module(x)
        return scores  # (..., num_classes)


class GlobalGatedAttention(nn.Module):
    """
    Gated attention scorer (3 fc layers).

    Two branches see the same input, Linear -> Tanh -> Dropout and
    Linear -> Sigmoid -> Dropout. Their outputs are multiplied element-wise and
    projected to ``num_classes`` scores.

    args:
        L: size of each input feature vector
        D: width of both gating branches
        dropout: dropout probability at the end of each branch
        num_classes: number of attention scores produced per instance
    """

    def __init__(self, L=1024, D=256, dropout=0., num_classes=1):
        super().__init__()
        # Attribute names and layer order define the state-dict keys; keep them stable.
        self.attention_a = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(dropout))
        self.attention_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(dropout))
        self.attention_c = nn.Linear(D, num_classes)

    def forward(self, x):
        tanh_branch = self.attention_a(x)
        sigmoid_branch = self.attention_b(x)
        gated = tanh_branch * sigmoid_branch
        return self.attention_c(gated)  # (..., num_classes)



class ABMIL(nn.Module):
    """
    ABMIL (Attention-based Multiple Instance Learning) model.

    Processing steps for each bag of instance features:
      1. A patch-embedding module embeds every instance. It is an MLP built by
         ``create_mlp``, or a ``Mammoth`` module when ``moe_args`` requests experts.
      2. A global (optionally gated) attention network scores every instance.
      3. The embeddings are pooled into one bag vector by a softmax-weighted sum.
      4. A linear head, if present, classifies the bag vector.

    Args:
        in_dim (int): Input feature dimension for each instance (default: 1024).
        embed_dim (int): Embedding dimension after patch embedding (default: 512).
        num_fc_layers (int): Number of fully connected layers in the patch embedding MLP (default: 1).
        dropout (float): Dropout rate applied in the MLP and attention layers (default: 0.25).
        attn_dim (int): Hidden dimension of the attention network (default: 384).
        gate (bool): Whether to use gated attention (True) or standard attention (False) (default: True).
        num_classes (int): Number of output classes for the classification head (default: 2).
            If ``num_classes <= 0`` no linear head is created and ``forward_head`` returns
            the pooled bag features unchanged.
        moe_args (dict, optional): Keyword arguments forwarded to ``Mammoth``. They are used
            only when ``moe_args['num_experts'] > 0``; otherwise the MLP patch embedding is
            built. ``None`` (default) is treated as an empty dict.
    """

    def __init__(
            self,
            in_dim: int = 1024,
            embed_dim: int = 512,
            num_fc_layers: int = 1,
            dropout: float = 0.25,
            attn_dim: int = 384,
            gate: bool = True,
            num_classes: int = 2,
            moe_args: dict = None,
    ):
        super().__init__()
        # ``None`` rather than a ``{}`` default so no dict object is shared between calls.
        if moe_args is None:
            moe_args = {}
        self.in_dim = in_dim
        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.num_fc_layers = num_fc_layers
        self.dropout = dropout
        self.attn_dim = attn_dim
        self.gate = gate

        if moe_args.get('num_experts', 0) > 0:
            # NOTE(review): Mammoth is built from moe_args alone (in_dim/embed_dim are not
            # passed). Its output width must equal embed_dim for the layers below; confirm
            # this against moe_args.
            self.patch_embed = Mammoth(**moe_args)
        else:
            # num_fc_layers Linear layers in total: (num_fc_layers - 1) hidden layers of
            # width embed_dim plus the output layer, with the activation appended after the
            # last Linear (end_with_fc=False).
            self.patch_embed = create_mlp(
                in_dim=in_dim,
                hid_dims=[embed_dim] * (num_fc_layers - 1),
                dropout=dropout,
                out_dim=embed_dim,
                end_with_fc=False,
            )

        attn_func = GlobalGatedAttention if gate else GlobalAttention
        # One attention score per instance (num_classes=1), i.e. K == 1 below.
        self.global_attn = attn_func(
            L=embed_dim,
            D=attn_dim,
            dropout=dropout,
            num_classes=1
        )

        if num_classes > 0:
            self.classifier = nn.Linear(embed_dim, num_classes)
        else:
            # Parameter-free pass-through: the model then acts as a bag-feature extractor
            # instead of failing on a missing ``classifier`` attribute.
            self.classifier = nn.Identity()

    def forward_attention(self, h: torch.Tensor, attn_mask=None, attn_only=True):
        """
        Compute the attention scores (and optionally the embedded features) for the input instances.

        Args:
            h (torch.Tensor): Input tensor of shape [B, M, D], where B is the batch size,
                M is the number of instances (patches), and D is the input feature dimension.
            attn_mask (torch.Tensor, optional): Mask of shape [B, M]. 1 (or True) marks a valid
                position and 0 (or False) a masked one. Bool, integer and float masks are
                accepted. Masked positions receive the most negative finite value of the
                score dtype, so they get ~0 weight after softmax.
            attn_only (bool, optional): If True, return only the attention scores (A).
                If False, return a tuple (h, A) where h is the embedded features and A is the attention scores.

        Returns:
            torch.Tensor or tuple: If attn_only is True, the pre-softmax attention scores of
                shape [B, K, M] with K == 1. Otherwise a tuple (h, A) where h is the embedded
                features of shape [B, M, D'] and A the scores above.
        """
        h = self.patch_embed(h)
        A = self.global_attn(h)  # B x M x K
        A = torch.transpose(A, -2, -1)  # B x K x M
        if attn_mask is not None:
            # Cast the mask to A's dtype. A bool mask cannot be subtracted from 1, and
            # integer or wider-float masks could promote A to a dtype different from h,
            # which breaks the later bmm.
            mask = attn_mask.to(A.dtype)
            A = A + (1 - mask).unsqueeze(dim=1) * torch.finfo(A.dtype).min

        if attn_only:
            return A
        return h, A

    def forward_features(self, h: torch.Tensor, attn_mask=None, return_attention: bool = True) -> tuple:
        """
        Compute bag-level features using attention pooling.

        Args:
            h (torch.Tensor): [B, M, D] input features.
            attn_mask (torch.Tensor, optional): Attention mask of shape [B, M] (see ``forward_attention``).
            return_attention (bool): If True, include the pre-softmax attention scores in the log dict.

        Returns:
            Tuple[torch.Tensor, dict]: Bag features of shape [B, D'], and a dict whose
                ``'attention'`` entry is the pre-softmax scores [B, 1, M] or None.
        """
        h, A_base = self.forward_attention(h, attn_mask=attn_mask, attn_only=False)  # A == B x K x M
        A = F.softmax(A_base, dim=-1)  # softmax over the M instances
        h = torch.bmm(A, h).squeeze(dim=1)  # B x K x D' --> B x D'
        log_dict = {'attention': A_base if return_attention else None}
        return h, log_dict

    def forward_head(self, h: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: [B x D]-dim torch.Tensor.

        Returns:
            logits: [B x num_classes]-dim torch.Tensor, or ``h`` unchanged when the model
                was built with ``num_classes <= 0``.
        """
        logits = self.classifier(h)
        return logits

    def forward(self, h: torch.Tensor,
                loss_fn: nn.Module = None,
                label: torch.LongTensor = None,
                attn_mask=None,
                return_attention: bool = False,
                return_slide_feats: bool = False) -> tuple:
        """
        Forward pass for ABMIL.

        Args:
            h: [B, M, D] input features.
            loss_fn: Optional loss function called as ``loss_fn(logits, label)``.
            label: Optional labels; the loss is computed only when both loss_fn and label are given.
            attn_mask: Optional attention mask of shape [B, M].
            return_attention: If True, ``log_dict['attention']`` holds the pre-softmax scores.
            return_slide_feats: If True, ``log_dict['slide_feats']`` holds the pooled bag features.

        Returns:
            Tuple of (results_dict, log_dict). ``results_dict`` holds ``'logits'`` and
            ``'loss'`` (a tensor, or None). ``log_dict`` holds ``'attention'``, ``'loss'``
            (a Python float, or -1 when no loss was computed) and optionally ``'slide_feats'``.
        """
        wsi_feats, log_dict = self.forward_features(h, attn_mask=attn_mask, return_attention=return_attention)
        logits = self.forward_head(wsi_feats)
        if loss_fn is not None and label is not None:
            cls_loss = loss_fn(logits, label)
        else:
            cls_loss = None
        results_dict = {'logits': logits, 'loss': cls_loss}
        log_dict['loss'] = cls_loss.item() if cls_loss is not None else -1
        if return_slide_feats:
            log_dict['slide_feats'] = wsi_feats
        return results_dict, log_dict


class ABMILGatedBaseConfig(PretrainedConfig):
    """
    Configuration class for the ABMIL Gated Base model.

    This class stores the configuration parameters required to instantiate an ABMIL model
    with gated attention. It is compatible with Hugging Face's Transformers library and
    can be used to save, load, and share model configurations.

    Args:
        gate (bool): Whether to use gated attention (default: True).
        embed_dim (int): Embedding dimension after patch embedding (default: 512).
        attn_dim (int): Dimension of the attention mechanism (default: 384).
        num_fc_layers (int): Number of fully connected layers in the patch embedding MLP (default: 1).
        dropout (float): Dropout rate applied in the MLP and attention layers (default: 0.25).
        in_dim (int): Input feature dimension for each instance (default: 1024).
        num_classes (int): Number of output classes for the classification head (default: 2).
        moe_args (dict, optional): Keyword arguments for the mixture-of-experts patch
            embedding. ``None`` (default) means an empty dict, created fresh for this config.
        **kwargs: Additional keyword arguments passed to the PretrainedConfig base class.

    Attributes:
        model_type (str): The model type identifier ("abmil").
        gate (bool): Whether to use gated attention.
        embed_dim (int): Embedding dimension after patch embedding.
        attn_dim (int): Dimension of the attention mechanism.
        num_fc_layers (int): Number of fully connected layers in the patch embedding MLP.
        dropout (float): Dropout rate applied in the MLP and attention layers.
        in_dim (int): Input feature dimension for each instance.
        num_classes (int): Number of output classes for the classification head.
        moe_args (dict): Keyword arguments for the mixture-of-experts patch embedding.
        auto_map (dict): Mapping for Hugging Face AutoConfig and AutoModel registration.
    """

    model_type = MODEL_TYPE

    def __init__(self,
                 gate: bool = True,
                 embed_dim: int = 512,
                 attn_dim: int = 384,
                 num_fc_layers: int = 1,
                 dropout: float = 0.25,
                 in_dim: int = 1024,
                 num_classes: int = 2,
                 moe_args: dict = None,
                 **kwargs):
        super().__init__(**kwargs)
        self.gate = gate
        self.embed_dim = embed_dim
        self.attn_dim = attn_dim
        self.num_fc_layers = num_fc_layers
        self.dropout = dropout
        self.in_dim = in_dim
        self.num_classes = num_classes
        # A fresh dict per instance. With a shared ``{}`` default, mutating one config's
        # moe_args would silently change every other default-constructed config.
        self.moe_args = {} if moe_args is None else moe_args
        # Always (re)set, so AutoConfig/AutoModel can resolve these classes from remote code.
        self.auto_map = {
            "AutoConfig": "modeling_abmil.ABMILGatedBaseConfig",
            "AutoModel": "modeling_abmil.ABMILModel",
        }


class ABMILModel(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around ``ABMIL``.

    The network is stored as ``self.model``. Its ``forward``, ``forward_attention``,
    ``forward_features`` and ``forward_head`` bound methods are re-exposed as attributes
    of this wrapper.
    """

    config_class = ABMILGatedBaseConfig

    def __init__(self, config: ABMILGatedBaseConfig, **kwargs):
        """
        Build the wrapped ABMIL network from ``config``.

        Args:
            config: model configuration.
            **kwargs: attribute overrides written onto ``config`` (mutating it in place)
                before the base class and the network are initialised.
        """
        # Overrides go first so both the base-class init and ABMIL see them.
        for attr_name, attr_value in kwargs.items():
            setattr(config, attr_name, attr_value)

        # The base-class initialiser stores ``config`` as ``self.config``.
        super().__init__(config)

        self.model = ABMIL(
            in_dim=config.in_dim,
            embed_dim=config.embed_dim,
            num_fc_layers=config.num_fc_layers,
            dropout=config.dropout,
            attn_dim=config.attn_dim,
            gate=config.gate,
            num_classes=config.num_classes,
            moe_args=config.moe_args,
        )

        # Expose the inner network's entry points; calls route to ``self.model``.
        for method_name in ('forward', 'forward_attention', 'forward_features', 'forward_head'):
            setattr(self, method_name, getattr(self.model, method_name))

if __name__ == '__main__':
    # Smoke test: build an ABMIL whose patch embedding is a Mammoth MoE module and run one
    # random bag of 10,000 instances with 1024-dim features through it.
    moe_kwargs = {
        'input_dim': 1024,
        'dim': 512,
        'num_experts': 30,
        'num_slots': 1,
        'num_heads': 16,
        'lora_rank': 13,
        'lora_method': 'linear',
        'share_lora_weights': True,
        'slot_dim': 256,
    }
    config = ABMILGatedBaseConfig(moe_args=moe_kwargs)
    model = ABMILModel(config)
    print(model)
    # eval() disables dropout so the demo output depends only on the input.
    # no_grad() avoids keeping activations for 10k instances alive for a backward pass
    # that never happens.
    model.eval()
    x = torch.randn(1, 10000, 1024)
    with torch.no_grad():
        print(model(x))
