# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import activations

from .base import EmbeddingReplacement, HeadReplacement
# Note: This requires the custom Mamba2 with MLP implementation (included separately)
from modeling_mamba2mlp import (
    Mamba2ForCausalLM, Mamba2Model, Mamba2Config, Mamba2Cache, Mamba2CausalLMOutput
)


def custom_forward(
        self,
        input: torch.Tensor,
        input_embeds: Optional[torch.FloatTensor] = None,
        cache_params = None,
        labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the backbone and return the ``(mu, sigma)`` pair from the head.

    Intended to be bound to a model instance in place of its regular
    ``forward``. Only ``input`` and ``return_dict`` influence the result;
    the remaining keyword parameters exist for signature compatibility and
    must be left at ``None``.

    Args:
        input: Tensor passed positionally to ``self.backbone``.
        input_embeds, cache_params, labels, output_hidden_states, use_cache,
            cache_position, attention_mask: Unsupported; must be ``None``.
        return_dict: Forwarded to the backbone; defaults to
            ``self.config.use_return_dict`` when ``None``.
        logits_to_keep: Accepted but ignored.
        **kwargs: Accepted but ignored.

    Returns:
        ``(mu, sigma)`` as produced by ``self.lm_head``, with softplus
        applied to ``sigma``.

    Raises:
        AssertionError: If any unsupported argument is not ``None``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Explicit raises instead of `assert` statements: asserts are removed
    # under `python -O`, which would let unsupported arguments reach the
    # backbone silently. The exception type and messages are unchanged.
    unsupported_args = (
        ("input_embeds", input_embeds),
        ("cache_params", cache_params),
        ("labels", labels),
        ("output_hidden_states", output_hidden_states),
        ("use_cache", use_cache),
        ("cache_position", cache_position),
        ("attention_mask", attention_mask),
    )
    for arg_name, arg_value in unsupported_args:
        if arg_value is not None:
            raise AssertionError(f"{arg_name} not supported")

    outputs = self.backbone(
        input,
        cache_params=cache_params,
        input_embeds=input_embeds,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        use_cache=use_cache,
        cache_position=cache_position,
        attention_mask=attention_mask,
    )
    # Index 0 of the backbone output is used as the hidden states.
    hidden_states = outputs[0]

    # Cast to the head's weight dtype before applying the head.
    mu, sigma = self.lm_head(hidden_states.to(self.lm_head.weight.dtype))
    # Softplus maps the raw second output to a strictly positive value.
    sigma = nn.functional.softplus(sigma)

    return mu, sigma


def count_parameters(model):
    """Return the total element count of ``model``'s trainable parameters.

    Parameters whose ``requires_grad`` flag is false are not counted.
    """
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    hidden_size: Optional[int] = None,
    num_hidden_layers: Optional[int] = None,
):
    """Build a Mamba2 causal LM with replaced embeddings, head and forward.

    A ``Mamba2Config`` is created from the preset named ``model_name``, a
    ``Mamba2ForCausalLM`` is instantiated from it, and then:

    * ``model.backbone.embeddings`` is swapped for an ``EmbeddingReplacement``
      (the previous module is kept as ``model.backbone.original_embeddings``);
    * ``model.lm_head`` is swapped for a ``HeadReplacement`` (the previous
      module is kept as ``model.original_lm_head``);
    * ``model.forward`` is rebound to ``custom_forward`` (the previous bound
      method is kept as ``model.original_forward``).

    Args:
        model_name: Preset key, e.g. ``'mamba2-mini'`` or ``'mamba2-mid'``.
        pretrained: Currently unused; the model is always freshly
            constructed from the config.
        num_parallel_input: First argument to ``EmbeddingReplacement`` and
            both entries of the list passed to ``HeadReplacement``.
        num_embedding_layers: Last argument to ``EmbeddingReplacement``.
        num_head_layers: Last argument to ``HeadReplacement``.
        cache_dir: Currently unused.
        hidden_size: Optional override for the preset's hidden size; also
            sets ``intermediate_size`` to twice this value.
        num_hidden_layers: Optional override for the preset's layer count.

    Returns:
        ``(model, config)``.

    Raises:
        ValueError: If ``model_name`` is not a known preset.
    """
    # Presets are stored as keyword arguments so that only the requested
    # configuration object is constructed (and the name is validated first).
    presets = {
        'mamba2-mini': dict(
            hidden_size=512, num_heads=16, num_hidden_layers=12,
            residual_in_fp32=False
        ),
        'mamba2-micro': dict(
            hidden_size=768, num_heads=24, num_hidden_layers=24,
            residual_in_fp32=False
        ),
        'mamba2-small': dict(
            hidden_size=1024, num_heads=32, num_hidden_layers=48,
            residual_in_fp32=False
        ),
        'mamba2-medium': dict(
            hidden_size=1536, num_heads=48, num_hidden_layers=48,
            residual_in_fp32=False
        ),
        'mamba2-large': dict(
            hidden_size=2048, num_heads=64, num_hidden_layers=48,
            residual_in_fp32=False
        ),
        'mamba2-mid': dict(
            hidden_size=1536, num_heads=48, num_hidden_layers=21,
            model_type="mamba2mlp", residual_in_fp32=False
        ),
    }

    if model_name not in presets:
        raise ValueError(f"Invalid model name: {model_name}. Choose from: {list(presets.keys())}")

    config = Mamba2Config(**presets[model_name])
    # Mirror the sizes under the d_model / n_layer attribute names as well.
    config.d_model = config.hidden_size
    config.n_layer = config.num_hidden_layers

    if hidden_size is not None:
        config.hidden_size = hidden_size
        config.d_model = hidden_size
        config.intermediate_size = hidden_size * 2
        # NOTE(review): num_heads keeps the preset value here; confirm the
        # custom Mamba2 implementation accepts that combination.

    if num_hidden_layers is not None:
        config.num_hidden_layers = num_hidden_layers
        config.n_layer = num_hidden_layers

    model = Mamba2ForCausalLM(config)

    # Keep a handle on each replaced component before overwriting it.
    model.backbone.original_embeddings = model.backbone.embeddings
    model.backbone.embeddings = EmbeddingReplacement(
        num_parallel_input,
        config.d_model,
        config.d_model,
        num_embedding_layers
    )

    model.original_lm_head = model.lm_head
    model.lm_head = HeadReplacement(
        config.d_model,
        config.d_model,
        [num_parallel_input, num_parallel_input],
        num_head_layers
    )

    model.original_forward = model.forward
    # Bind custom_forward to this instance so it receives `model` as `self`.
    model.forward = custom_forward.__get__(model, model.__class__)

    print(f"Model has {count_parameters(model)/1e9:.3f}B parameters")

    return model, config

