# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers import AutoConfig
from transformers.models.qwen3.modeling_qwen3 import Qwen3Config
from transformers.modeling_outputs import CausalLMOutputWithPast

from .base import EmbeddingReplacement, HeadReplacement
from ..architectures.qwen_flex_attn import Qwen3ForCausalLM, Qwen3Model


def custom_forward(
        self,
        input: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        logits_to_keep: Optional[int] = 0,
        **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Replacement ``forward`` that is bound onto a patched causal-LM instance.

    Runs the backbone ``self.model`` with ``input`` as ``input_ids`` and feeds
    the resulting ``last_hidden_state`` (cast to the dtype of
    ``self.lm_head.weight``) into ``self.lm_head``. The head must return a
    two-element ``(mu, sigma)`` pair. ``sigma`` is passed through softplus
    before it is returned.

    Args:
        input: Tensor forwarded to the backbone as ``input_ids``.
        attention_mask, position_ids, past_key_values, inputs_embeds, labels,
        use_cache, cache_position: Unsupported. Each must be ``None``; this
            is checked with ``assert``, in this order. Apart from ``labels``,
            they are still forwarded to ``self.model`` unchanged.
        logits_to_keep: Accepted only for signature compatibility; ignored.
        **kwargs: Forwarded verbatim to ``self.model``.

    Returns:
        ``(mu, sigma)``, where ``sigma`` is ``softplus`` of the head's second
        output.
    """
    # Checked in declaration order so the first offending argument is the one
    # reported.
    unsupported_args = (
        ("attention_mask", attention_mask),
        ("position_ids", position_ids),
        ("past_key_values", past_key_values),
        ("inputs_embeds", inputs_embeds),
        ("labels", labels),
        ("use_cache", use_cache),
        ("cache_position", cache_position),
    )
    for arg_name, arg_value in unsupported_args:
        assert arg_value is None, f"{arg_name} not supported"

    backbone_out = self.model(
        input_ids=input,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        cache_position=cache_position,
        **kwargs,
    )

    # The head may run in a different precision than the backbone, so align
    # the dtype before calling it.
    head_dtype = self.lm_head.weight.dtype
    mu, raw_sigma = self.lm_head(backbone_out.last_hidden_state.to(head_dtype))
    return mu, nn.functional.softplus(raw_sigma)


def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters.

    Only parameters with ``requires_grad`` set are counted. A model with no
    trainable parameters yields ``0``.
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        total += param.numel()
    return total


# Qwen3Config fields shared by every preset accepted by
# get_model_with_parallel_input.
_QWEN3_BASE_CONFIG = dict(
    attention_bias=False,
    attention_dropout=0.0,
    head_dim=128,
    hidden_act='silu',
    initializer_range=0.02,
    max_position_embeddings=40960,
    rms_norm_eps=1e-06,
    rope_theta=1000000,
    use_cache=True,
    use_sliding_window=False,
    vocab_size=151936,
)

# Size-specific Qwen3Config fields, keyed by the ``model_name`` values accepted
# by get_model_with_parallel_input.
_QWEN3_PRESETS = {
    'qwen3-small': dict(
        hidden_size=1024,
        intermediate_size=3072,
        max_window_layers=24,
        num_attention_heads=16,
        num_hidden_layers=24,
        num_key_value_heads=8,
    ),
    'qwen3-micro': dict(
        hidden_size=768,
        intermediate_size=2304,
        max_window_layers=12,
        num_attention_heads=12,
        num_hidden_layers=12,
        num_key_value_heads=6,
    ),
    'qwen3-mini': dict(
        hidden_size=512,
        intermediate_size=1536,
        max_window_layers=8,
        num_attention_heads=8,
        num_hidden_layers=8,
        num_key_value_heads=4,
    ),
}


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    hidden_size: Optional[int] = None,
    num_hidden_layers: Optional[int] = None,
    flash_attention: bool = False,
    **kwargs
):
    """Build a freshly initialised Qwen3 causal LM with replaced input/output layers.

    The token embedding is swapped for an ``EmbeddingReplacement`` and the LM
    head for a ``HeadReplacement``. ``forward`` is rebound to
    ``custom_forward``, which returns ``(mu, sigma)`` instead of logits. The
    replaced modules and the original bound ``forward`` are kept on the model
    as ``model.model.original_embed_tokens``, ``model.original_lm_head`` and
    ``model.original_forward``.

    Args:
        model_name: One of ``'qwen3-mini'``, ``'qwen3-micro'`` or
            ``'qwen3-small'``. Selects the architecture preset.
        pretrained: Currently unused. The model is always built from a config,
            so no weights are loaded.
        num_parallel_input: Passed to ``EmbeddingReplacement``. It is also used
            as both entries of the ``HeadReplacement`` output spec.
        num_embedding_layers: Depth passed to ``EmbeddingReplacement``.
        num_head_layers: Depth passed to ``HeadReplacement``.
        cache_dir: Currently unused.
        hidden_size: If given, overrides the preset ``hidden_size`` and sets
            ``intermediate_size`` to ``3 * hidden_size``.
        num_hidden_layers: If given, overrides the preset layer count.
        flash_attention: Use ``'flash_attention_2'`` instead of ``'eager'``.
        **kwargs: Extra attributes set with ``setattr`` on the config after it
            is constructed, and after the ``d_model``/``n_layer`` aliases are
            filled in.

    Returns:
        ``(model, config)``.

    Raises:
        ValueError: If ``model_name`` is not a known preset.
    """
    attn_implementation = 'flash_attention_2' if flash_attention else 'eager'
    print(f"Using attention implementation: {attn_implementation}")

    if model_name not in _QWEN3_PRESETS:
        raise ValueError(f"Invalid model name: {model_name}. Use 'qwen3-mini', 'qwen3-micro', or 'qwen3-small'")

    # Merge into a fresh dict so the module-level presets are never mutated.
    config_fields = {**_QWEN3_BASE_CONFIG, **_QWEN3_PRESETS[model_name]}

    # Apply size overrides *before* constructing the config. Qwen3Config may
    # derive fields from these values inside __init__ (newer transformers
    # releases build a per-layer attention-type list of length
    # num_hidden_layers). Mutating the attributes afterwards would leave such
    # derived fields sized for the preset instead of the override.
    if hidden_size is not None:
        config_fields['hidden_size'] = hidden_size
        config_fields['intermediate_size'] = hidden_size * 3
    if num_hidden_layers is not None:
        config_fields['num_hidden_layers'] = num_hidden_layers

    config = Qwen3Config(**config_fields, attn_implementation=attn_implementation)

    # Alias attributes; d_model is what the replacement layers below are sized from.
    config.d_model = config.hidden_size
    config.n_layer = config.num_hidden_layers

    if kwargs:
        print(f"Modifying config with: {kwargs.keys()}")
        for key, value in kwargs.items():
            print(f"  Setting {key} = {value}")
            setattr(config, key, value)

    model = Qwen3ForCausalLM(config)

    # NOTE(review): if embed_tokens / lm_head are nn.Module instances (presumably
    # they are), assigning them to attributes registers them as submodules. In
    # that case the original modules stay in parameters() / state_dict() and in
    # the count printed below. Confirm this is intended.
    model.model.original_embed_tokens = model.model.embed_tokens
    model.model.embed_tokens = EmbeddingReplacement(
        num_parallel_input,
        config.d_model,
        config.d_model,
        num_embedding_layers
    )

    model.original_lm_head = model.lm_head
    model.lm_head = HeadReplacement(
        config.d_model,
        config.d_model,
        [num_parallel_input, num_parallel_input],
        num_head_layers
    )

    # Bind custom_forward as an instance-level method, shadowing the class's forward.
    model.original_forward = model.forward
    model.forward = custom_forward.__get__(model, model.__class__)

    print(f"Model has {count_parameters(model)/1e9:.3f}B parameters")

    return model, config

