# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import List, Union

import torch
from torch import nn
from transformers import activations


class EmbeddingReplacement(nn.Module):
    """MLP of ``Linear -> activation`` blocks; the activation also follows the final block.

    Blocks are created in this order, which also fixes the order in which their
    parameters are constructed and their ``layers.<i>.0.*`` state-dict keys:

    1. ``Linear(num_inputs, hidden_dim)`` + ``act``
    2. ``num_layers - 2`` repeats of ``Linear(hidden_dim, hidden_dim)`` + ``act``
    3. ``Linear(hidden_dim, num_outputs)`` + ``act``

    NOTE(review): blocks 1 and 3 are always created, so any ``num_layers``
    below 2 still produces two blocks; confirm callers never pass < 2.

    Args:
        num_inputs: Size of the input's last dimension.
        hidden_dim: Width of every hidden block.
        num_outputs: Size of the output's last dimension.
        num_layers: Total number of Linear layers requested (see note above).
        act: Activation name passed to ``transformers.activations.get_activation``;
            ``get_activation`` is called once per block.
    """

    def __init__(
        self,
        num_inputs: int,
        hidden_dim: int,
        num_outputs: int,
        num_layers: int,
        act: str = 'silu'
    ):
        super().__init__()

        def linear_then_act(fan_in: int, fan_out: int) -> nn.Sequential:
            # One Linear followed by its own requested activation module.
            return nn.Sequential(
                nn.Linear(fan_in, fan_out),
                activations.get_activation(act),
            )

        blocks = [linear_then_act(num_inputs, hidden_dim)]
        blocks += [linear_then_act(hidden_dim, hidden_dim) for _ in range(num_layers - 2)]
        blocks.append(linear_then_act(hidden_dim, num_outputs))

        self.layers = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in order (each Linear acts on the last dimension)."""
        output = self.layers(x)
        return output


class HeadReplacement(nn.Module):
    """MLP head whose output is split into one or more chunks along the last dimension.

    Layer layout:

    1. ``Linear(num_inputs, hidden_dim)`` + ``act``
    2. ``num_layers - 2`` repeats of ``Linear(hidden_dim, hidden_dim)`` + ``act``
    3. ``Linear(hidden_dim, sum(num_outputs))`` with no activation

    NOTE(review): layers 1 and 3 are always created, so any ``num_layers``
    below 2 still yields two Linear layers; confirm callers never pass < 2.

    Args:
        num_inputs: Size of the input's last dimension.
        hidden_dim: Width of the hidden layers.
        num_outputs: A single output size, or a sequence of sizes (one per
            chunk returned by ``forward``). The sequence is copied.
        num_layers: Total number of Linear layers requested (see note above).
        act: Activation name passed to ``transformers.activations.get_activation``.

    Raises:
        ValueError: If ``num_outputs`` is an empty sequence.
    """

    def __init__(
        self,
        num_inputs: int,
        hidden_dim: int,
        num_outputs: Union[int, List[int]],
        num_layers: int,
        act: str = 'silu'
    ):
        super().__init__()

        # Copy rather than alias: if the caller later mutates its list, the
        # split sizes must still match the width of the final Linear.
        if isinstance(num_outputs, int):
            split_sizes = [num_outputs]
        else:
            split_sizes = list(num_outputs)
        if not split_sizes:
            raise ValueError("num_outputs must contain at least one output size")

        num_outputs_total = sum(split_sizes)

        layers = [
            nn.Sequential(
                nn.Linear(num_inputs, hidden_dim),
                activations.get_activation(act),
            )
        ]

        for _ in range(num_layers - 2):
            layers.append(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    activations.get_activation(act),
                )
            )

        # Final projection is left linear (no activation), unlike the hidden layers.
        layers.append(nn.Linear(hidden_dim, num_outputs_total))

        self.layers = nn.Sequential(*layers)
        self.num_outputs = split_sizes
        # NOTE(review): placeholder buffer named ``weight``; presumably lets code
        # that reads ``head.weight`` (e.g. for dtype/device) keep working after
        # this module replaces an original head — confirm. It is a persistent
        # buffer, so it is saved in ``state_dict``; don't drop it without
        # migrating checkpoints.
        self.register_buffer('weight', torch.zeros(1))

    def forward(self, x: torch.Tensor) -> tuple:
        """Apply the MLP, then split its output along the last dimension.

        Returns:
            Tuple from ``torch.split`` with one tensor per entry of
            ``self.num_outputs`` (a 1-tuple when a single int was given).
        """
        x = self.layers(x)
        return torch.split(x, self.num_outputs, dim=-1)


def disable_dropout(model: nn.Module) -> None:
    """Set every dropout probability reachable from ``model`` to zero, in place.

    For ``model`` and each of its submodules (``model.modules()``):

    * A module that is itself a dropout layer (``nn.Dropout``,
      ``nn.Dropout1d/2d/3d``, ``nn.AlphaDropout``, ``nn.FeatureAlphaDropout``)
      gets ``p = 0.0``. This applies whatever name it is stored under, including
      positional children of ``nn.Sequential`` / ``nn.ModuleList``.
    * Every attribute whose name contains "drop" (case-insensitive) and whose
      value is a ``float`` or ``int`` (``bool`` included, since it subclasses
      ``int``) is overwritten with ``0.0``. An example is a rate stored as a
      plain number.
    * A dropout layer referenced by such a "drop" attribute also gets
      ``p = 0.0``.

    Args:
        model: Module tree to modify. It is mutated in place; nothing is returned.
    """
    # Resolved by name so torch versions lacking a class (e.g. nn.Dropout1d)
    # still work; these classes are siblings, not subclasses of nn.Dropout.
    dropout_types = tuple(
        cls
        for cls in (
            getattr(nn, name, None)
            for name in (
                "Dropout",
                "Dropout1d",
                "Dropout2d",
                "Dropout3d",
                "AlphaDropout",
                "FeatureAlphaDropout",
            )
        )
        if isinstance(cls, type)
    )

    for module in model.modules():
        # Catch dropout layers no matter which attribute name holds them.
        if isinstance(module, dropout_types):
            module.p = 0.0

        for attr_name in dir(module):
            if "drop" not in attr_name.lower():
                continue
            attr_value = getattr(module, attr_name)
            if isinstance(attr_value, (float, int)):
                setattr(module, attr_name, 0.0)
            elif isinstance(attr_value, dropout_types):
                # Also covers dropout layers referenced here but not registered
                # as submodules (registered ones were handled above).
                attr_value.p = 0.0

