# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig, activations
from transformers.models.mamba.modeling_mamba import MambaCausalLMOutput, MambaCache

from .base import EmbeddingReplacement, HeadReplacement
from mamba_ssm import MambaLMHeadModel


def custom_forward(
        self,
        input: torch.Tensor,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional["MambaCache"] = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the backbone and the two-output head, returning ``(mu, sigma)``.

    Written to be bound onto a model instance as its ``forward``; ``self``
    must expose ``backbone`` and ``lm_head``.

    Args:
        input: Tensor passed unchanged to ``self.backbone``.
        inputs_embeds: Unsupported; must be ``None``.
        cache_params: Unsupported; must be ``None``.
        labels: Optional; if given it must be 2-D. It is only validated here,
            no loss is computed from it.
        output_hidden_states: Accepted for signature compatibility; ignored.
        return_dict: Accepted for signature compatibility; ignored (a plain
            tuple is always returned).
        use_cache: Unsupported; must be ``None`` (``False`` is rejected too).
        **kwargs: Accepted and ignored.

    Returns:
        A ``(mu, sigma)`` tuple: the first head output as-is, and the second
        head output after ``softplus``.

    Raises:
        AssertionError: If an unsupported argument is supplied or ``labels``
            is not 2-D.
    """
    # Explicit raises instead of ``assert`` statements so these guards are not
    # stripped under ``python -O``. AssertionError and the messages are kept so
    # callers observe the same failure as before.
    if inputs_embeds is not None:
        raise AssertionError("inputs_embeds not supported")
    if cache_params is not None:
        raise AssertionError("cache_params not supported")
    if labels is not None and labels.dim() != 2:
        raise AssertionError("labels must be a 2D tensor")
    if use_cache is not None:
        raise AssertionError("use_cache not supported")

    hidden_states = self.backbone(input)

    # Cast to the head's dtype before projecting.
    # NOTE(review): relies on ``self.lm_head`` exposing a ``weight`` attribute;
    # confirm the replacement head provides one.
    mu, sigma = self.lm_head(hidden_states.to(self.lm_head.weight.dtype))
    # softplus maps the raw second output to a strictly positive value.
    sigma = nn.functional.softplus(sigma)

    return mu, sigma


def count_parameters(model):
    """Return the total element count of *model*'s trainable parameters.

    Only parameters whose ``requires_grad`` is truthy are counted; the
    result is ``0`` when there are none.
    """
    trainable_elements = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_elements += param.numel()
    return trainable_elements


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    hidden_size: Optional[int] = None,
    num_hidden_layers: Optional[int] = None,
):
    """Build a ``MambaLMHeadModel`` with replaced embedding, head and forward.

    The config is loaded with ``AutoConfig.from_pretrained``, optionally
    resized, and used to construct the model. The backbone embedding and the
    LM head are then swapped for ``EmbeddingReplacement`` / ``HeadReplacement``
    and ``custom_forward`` is bound as the model's ``forward``. The replaced
    objects stay reachable as ``backbone.original_embedding``,
    ``original_lm_head`` and ``original_forward``.

    Args:
        model_name: Identifier handed to ``AutoConfig.from_pretrained``.
        pretrained: NOTE(review): never read in this function; the model is
            constructed from the config only and no checkpoint-loading call
            appears here. Confirm whether weight loading was intended.
        num_parallel_input: First argument of ``EmbeddingReplacement`` and
            both entries of the list given to ``HeadReplacement``.
        num_embedding_layers: Last argument of ``EmbeddingReplacement``.
        num_head_layers: Last argument of ``HeadReplacement``.
        cache_dir: Forwarded to ``AutoConfig.from_pretrained`` unless it is
            ``None`` or the string ``"none"`` (case-insensitive).
        hidden_size: If given, overrides ``hidden_size``/``d_model`` and sets
            ``intermediate_size`` to twice and ``time_step_rank`` to one
            sixteenth (floor) of it.
        num_hidden_layers: If given, overrides ``num_hidden_layers`` and
            ``n_layer``.

    Returns:
        ``(model, config)``.
    """
    # Only pass cache_dir through when it names a real location.
    load_kwargs = {}
    if cache_dir is not None and cache_dir.lower() != 'none':
        load_kwargs['cache_dir'] = cache_dir
    config = AutoConfig.from_pretrained(model_name, **load_kwargs)

    if hidden_size is not None:
        # Keep both spellings of the model width in sync, then the sizes
        # derived from it.
        config.hidden_size = config.d_model = hidden_size
        config.intermediate_size = hidden_size * 2
        config.time_step_rank = hidden_size // 16

    if num_hidden_layers is not None:
        # Same value under both attribute names for the layer count.
        for depth_attr in ('num_hidden_layers', 'n_layer'):
            setattr(config, depth_attr, num_hidden_layers)

    model = MambaLMHeadModel(config)
    width = config.d_model

    # Swap the embedding, keeping the original reachable on the backbone.
    backbone = model.backbone
    backbone.original_embedding = backbone.embedding
    backbone.embedding = EmbeddingReplacement(
        num_parallel_input, width, width, num_embedding_layers
    )

    # Swap the head; it is given the same size twice (unpacked by
    # custom_forward as mu and sigma).
    model.original_lm_head = model.lm_head
    head_outputs = [num_parallel_input] * 2
    model.lm_head = HeadReplacement(width, width, head_outputs, num_head_layers)

    # Bind custom_forward to this instance so calls go through it.
    model.original_forward = model.forward
    model.forward = custom_forward.__get__(model, type(model))

    print(f"Model has {count_parameters(model)/1e9:.3f}B parameters")

    return model, config

