import torch
from torch import nn
from torchvision.models import resnet18
from transformers import AutoModel


class MLP(nn.Module):
    """Fully-connected regression/scoring head with a single output unit.

    The input is flattened from dimension 1 onward, then passed through one
    ``Linear -> ReLU`` pair per entry of ``hidden_layer_sizes`` and a final
    ``Linear`` layer that maps to one output. No normalization layers are used.
    """

    def __init__(self, dim: int, hidden_layer_sizes: list[int], **kwargs):
        """Build the layer stack.

        Args:
            dim: number of input features per example after flattening.
            hidden_layer_sizes: width of each hidden layer, in order. An empty
                list yields a single linear map from ``dim`` to 1.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.n_layers = len(hidden_layer_sizes)

        # widths[i] -> widths[i + 1] describes the i-th hidden Linear layer.
        widths = [dim, *hidden_layer_sizes]
        stack: list[nn.Module] = [nn.Flatten(start_dim=1)]
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output projection: last hidden width (or ``dim`` if none) to 1.
        stack.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the network.

        Args:
            x: input tensor; everything after the first dimension is flattened.

        Returns:
            the output of the final linear layer, one value per example.
        """
        return self.net(x)

    def __repr__(self):
        # Short identifier such as "MLP_64_32" (class name + hidden widths).
        suffix = "_".join(map(str, self.hidden_layer_sizes))
        return f"{type(self).__name__}_{suffix}"


class ResNet18(nn.Module):
    """torchvision ResNet-18 with a single-output head and no normalization.

    The backbone is created without pretrained weights, ``nn.Identity`` is
    passed as its ``norm_layer``, and the final fully-connected layer is
    replaced by one that produces a single value.
    """

    def __init__(self):
        """Create the backbone and swap its classification head."""
        super().__init__()
        backbone = resnet18(weights=None, norm_layer=nn.Identity)
        # Keep the backbone's feature width; only the output size changes.
        n_features = backbone.fc.in_features
        backbone.fc = nn.Linear(n_features, 1)
        self.resnet = backbone

    def forward(self, x):
        """Run the backbone and return its output flattened to 1-D.

        Args:
            x: input batch accepted by the ResNet backbone.

        Returns:
            a 1-D tensor holding every element of the backbone output.
        """
        scores = self.resnet(x)
        return torch.flatten(scores)

    def __repr__(self):
        # Bare class name, used as a short model identifier.
        return type(self).__name__


class RBertClassifier(nn.Module):
    """Entity-marker relation scorer on top of a BioBERT encoder.

    For every example the encoder's hidden state at position 0 is concatenated
    with the mean hidden state of the tokens between ``[E1]``/``[/E1]`` and
    between ``[E2]``/``[/E2]``. The concatenation goes through dropout and a
    linear layer that produces one logit per example.
    """

    # Token ids of the four entity markers, keyed by
    # "e1_open" / "e1_close" / "e2_open" / "e2_close".
    marker_ids: dict[str, int]

    def __init__(self, dropout: float = 0.1, tokenizer=None):
        """Load the encoder and register the entity-marker token ids.

        Args:
            dropout: dropout probability applied to the concatenated features.
            tokenizer: tokenizer that already contains the ``[E1]``, ``[/E1]``,
                ``[E2]`` and ``[/E2]`` tokens. Required despite the ``None``
                default. It must support ``len()`` and
                ``convert_tokens_to_ids``.

        Raises:
            ValueError: if ``tokenizer`` is ``None``.
        """
        # Fail fast: reject a missing tokenizer before the pretrained encoder
        # is loaded, instead of loading it and then raising.
        if tokenizer is None:
            raise ValueError("Tokenizer must be provided to set marker IDs.")

        super().__init__()
        self.encoder = AutoModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

        # Grow the embedding matrix so ids of added special tokens are valid.
        self.encoder.resize_token_embeddings(len(tokenizer))
        # NOTE(review): if the tokenizer does not know a marker token,
        # convert_tokens_to_ids presumably returns its unknown-token id and
        # marker pooling silently degrades - confirm the markers were added.
        marker_tokens = {
            "e1_open": "[E1]",
            "e1_close": "[/E1]",
            "e2_open": "[E2]",
            "e2_close": "[/E2]",
        }
        self.marker_ids = {key: tokenizer.convert_tokens_to_ids(token) for key, token in marker_tokens.items()}

        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        # Features are [position-0 state (H), pooled E1 (H), pooled E2 (H)] => 3H.
        self.classifier = nn.Linear(hidden_size * 3, 1)

    @staticmethod
    def _first_index(mask: torch.Tensor) -> torch.Tensor:
        """Return the position of the first ``True`` in each row.

        Args:
            mask: boolean tensor of shape [B, T].

        Returns:
            long tensor of shape [B]; ``-1`` for rows without any ``True``.
        """
        # argmax returns 0 for an all-False row, so guard it with any().
        any_true = mask.any(dim=1)
        idx = mask.float().argmax(dim=1)  # position of the first True
        idx = torch.where(any_true, idx, torch.full_like(idx, -1))
        return idx.long()

    def _pool_between(
        self,
        last_hidden: torch.Tensor,
        open_ids: int,
        close_ids: int,
        input_ids: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """Mean-pool the hidden states strictly between two marker tokens.

        Args:
            last_hidden: encoder output of shape [B, T, H].
            open_ids: token id of the opening marker.
            close_ids: token id of the closing marker.
            input_ids: token ids of shape [B, T] that produced ``last_hidden``.
                When omitted, ``self._cached_input_ids`` is used so that older
                three-argument callers keep working.

        Returns:
            tensor of shape [B, H] with the dtype of ``last_hidden``. Rows in
            which either marker is missing, or the first closing marker does
            not come after the first opening marker, fall back to the hidden
            state at position 0.
        """
        if input_ids is None:
            # Backward-compatible path for callers that stash ids on the module.
            input_ids = self._cached_input_ids  # [B, T]
        _, seq_len, _ = last_hidden.shape

        open_idx = self._first_index(input_ids == open_ids)  # [B]
        close_idx = self._first_index(input_ids == close_ids)  # [B]

        # Valid span: opening marker found and closing marker after it
        # (close_idx > open_idx >= 0 already implies the closer was found).
        valid = (open_idx >= 0) & (close_idx > open_idx)

        # Broadcast a [1, T] position row against [B, 1] bounds to select
        # "open < pos < close" for every example.
        pos = torch.arange(seq_len, device=last_hidden.device).unsqueeze(0)  # [1, T]
        between = (pos > open_idx.unsqueeze(1)) & (pos < close_idx.unsqueeze(1))  # [B, T]

        # Masked sum / masked count. The float32 mask keeps the accumulation
        # in at least float32 even for a half-precision encoder.
        between_f = between.float().unsqueeze(-1)  # [B, T, 1]
        summed = (last_hidden * between_f).sum(dim=1)  # [B, H]
        # clamp avoids division by zero. An empty span (closer directly after
        # opener) is still "valid" and therefore pools to a zero vector.
        count = between_f.sum(dim=1).clamp_min(1.0)  # [B, 1]

        # Cast back so the pooled vectors match the encoder dtype; otherwise a
        # half-precision encoder would yield float32 features that no longer
        # match the fallback state or the classifier weights.
        pooled = (summed / count).to(last_hidden.dtype)  # [B, H]

        # Rows without a valid span fall back to the position-0 hidden state.
        cls = last_hidden[:, 0, :]  # [B, H]
        return torch.where(valid.unsqueeze(1), pooled, cls)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Score a batch of marker-annotated sequences.

        Args:
            input_ids: token ids of shape [B, T].
            attention_mask: attention mask forwarded to the encoder.
            token_type_ids: optional segment ids forwarded to the encoder.

        Returns:
            logits of shape [B].

        Raises:
            ValueError: if any token id is >= the encoder's vocabulary size.
        """
        # Input sanity check. This is a reduction over every id and the `if`
        # has to read the result, so it is cheap but not free.
        if input_ids.max() >= self.encoder.config.vocab_size:
            raise ValueError(
                "Found token id >= vocab_size. Did you call resize_token_embeddings "
                "after adding special tokens to the tokenizer?"
            )

        outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True
        )
        last_hidden = outputs.last_hidden_state  # [B, T, H]
        cls = last_hidden[:, 0, :]  # [B, H]

        # input_ids is passed explicitly rather than stored on the module, so
        # forward keeps no per-batch state and holds no reference to the batch.
        mid = self.marker_ids
        e1_vecs = self._pool_between(last_hidden, mid["e1_open"], mid["e1_close"], input_ids=input_ids)
        e2_vecs = self._pool_between(last_hidden, mid["e2_open"], mid["e2_close"], input_ids=input_ids)

        feat = torch.cat([cls, e1_vecs, e2_vecs], dim=-1)  # [B, 3H]
        feat = self.dropout(feat)
        logits = self.classifier(feat)  # [B, 1]
        # Drop only the trailing unit dimension so the result is [B].
        return logits.squeeze(-1)

    def __repr__(self):
        # Bare class name, used as a short model identifier.
        return f"{self.__class__.__name__}"
