# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple, Any
import sys
import warnings

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

from .base import EmbeddingReplacement, HeadReplacement, disable_dropout
from ..architectures.mistral_flex_attn import MistralForCausalLM


def custom_forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the ``(mu, sigma)`` outputs of a model whose head was replaced.

    Meant to be bound as ``forward`` on the model built by
    ``get_model_with_parallel_input``. It reads ``self.config``,
    ``self.model`` and ``self.lm_head``.

    Only ``input_ids``/``inputs_embeds``, ``position_ids`` and ``return_dict``
    are honoured. The cache/attention/output options must be left as ``None``.
    ``labels`` is accepted for signature compatibility but ignored, and no
    loss is computed.

    Args:
        input_ids: Input passed through ``self.model.embed_tokens`` when
            ``inputs_embeds`` is not given. NOTE(review): the replaced
            embedding may expect non-integer features despite the
            ``LongTensor`` annotation; confirm.
        inputs_embeds: Precomputed embeddings. If both ``input_ids`` and
            ``inputs_embeds`` are given, ``inputs_embeds`` is what reaches
            the model.
        position_ids: Defaults to ``arange(seq_len)`` with a leading dim of
            1. ``seq_len`` is dim 1 of ``input_ids``, or of ``inputs_embeds``
            when ``input_ids`` is None.
        return_dict: Forwarded to the inner model. Falls back to
            ``self.config.use_return_dict``. It does not change this
            function's return type.

    Returns:
        ``(mu, sigma)`` as produced by ``self.lm_head``, with ``sigma``
        mapped through softplus so it is non-negative.

    Raises:
        AssertionError: if an unsupported argument is not None.
        ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # NOTE: these checks are ``assert``s and disappear under ``python -O``.
    assert past_key_values is None, "past_key_values not supported"
    assert attention_mask is None, "attention_mask not supported"
    assert token_type_ids is None, "token_type_ids not supported"
    assert use_cache is None, "use_cache not supported"
    assert output_attentions is None, "output_attentions not supported"
    assert output_hidden_states is None, "output_hidden_states not supported"

    if input_ids is None and inputs_embeds is None:
        raise ValueError("Either input_ids or inputs_embeds must be provided")

    # Derive the sequence length and device from whichever input is present.
    # Previously this always read ``input_ids.shape`` and crashed when only
    # ``inputs_embeds`` was supplied.
    reference = input_ids if input_ids is not None else inputs_embeds

    if position_ids is None:
        position_ids = torch.arange(
            0, reference.shape[1], dtype=torch.long, device=reference.device
        ).unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)

    # input_ids=None: the embeddings were already computed above, so the inner
    # model must consume them instead of running its embedding layer again.
    outputs = self.model(
        input_ids=None,
        attention_mask=attention_mask,
        position_ids=position_ids,
        inputs_embeds=inputs_embeds,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    # outputs[0] is the last hidden state for both tuple and dict-style outputs.
    hidden_states = outputs[0]
    # Cast to the head's dtype so a mixed-precision trunk can feed the head.
    mu, sigma = self.lm_head(hidden_states.to(self.lm_head.weight.dtype))
    sigma = nn.functional.softplus(sigma)

    return mu, sigma


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    flash_attention: bool = False,
    **kwargs
):
    """Build a freshly initialised Mistral model with replaced input/output layers.

    Steps performed:

    1. Choose the attention implementation: ``"flash_attention_2"`` when
       ``flash_attention`` is true and the platform is Linux, otherwise
       ``"eager"``.
    2. Load the config for ``model_name`` with ``AutoConfig.from_pretrained``.
    3. Set every non-callable config attribute whose name contains ``"drop"``
       to ``0.0``.
    4. Apply ``kwargs`` as config overrides.
    5. Instantiate ``MistralForCausalLM`` from that config (no pretrained
       weights) and call ``disable_dropout`` on it.
    6. Replace ``model.model.embed_tokens`` with an ``EmbeddingReplacement``
       and ``model.lm_head`` with a ``HeadReplacement`` that has two output
       groups of size ``num_parallel_input``.
    7. Bind ``custom_forward`` as the instance's ``forward``.

    The replaced modules and the original forward are kept as
    ``original_embed_tokens``, ``original_lm_head`` and ``original_forward``.

    Args:
        model_name: Name or path passed to ``AutoConfig.from_pretrained``.
        pretrained: Must be False. Loading pretrained weights is not supported.
        num_parallel_input: Input width of the embedding replacement and the
            size of each of the two head output groups.
        num_embedding_layers: Layer count passed to ``EmbeddingReplacement``.
        num_head_layers: Layer count passed to ``HeadReplacement``.
        cache_dir: Hub cache directory. ``None`` or the string ``"none"``
            (any case) means "use the default".
        flash_attention: Request FlashAttention-2. Only honoured on Linux;
            elsewhere a warning is emitted and eager attention is used.
        **kwargs: Attributes set on the config via ``setattr`` before the
            model is built. They are applied after dropout zeroing, so they
            can override it on the config.

    Returns:
        ``(model, config)``.

    Raises:
        NotImplementedError: if ``pretrained`` is true (the default).
    """
    if pretrained:
        # Explicit raise instead of ``assert False``. Asserts are stripped under
        # ``python -O``, which previously let execution fall through to an
        # UnboundLocalError on ``model``. Raising first also skips the config
        # download.
        raise NotImplementedError("Pretrained models not tested for L-CUBE")

    if flash_attention and sys.platform.startswith("linux"):
        attn_implementation = "flash_attention_2"
    else:
        if flash_attention:
            warnings.warn(
                f"flash_attention=True is only honoured on Linux (platform is "
                f"{sys.platform!r}); falling back to eager attention."
            )
        attn_implementation = "eager"
    print(f"Using attention implementation: {attn_implementation}")

    config_kwargs = {"attn_implementation": attn_implementation}
    # Treat a literal 'none' string like None. Presumably this lets a string
    # setting (e.g. from a CLI/config file) disable the cache dir; confirm.
    if cache_dir is not None and cache_dir.lower() != 'none':
        config_kwargs["cache_dir"] = cache_dir
    config = AutoConfig.from_pretrained(model_name, **config_kwargs)

    # Zero every dropout-like config field. ``dir`` also lists methods, so
    # skip callables rather than overwriting them with a float.
    for key in dir(config):
        if 'drop' not in key.lower():
            continue
        if callable(getattr(config, key, None)):
            continue
        setattr(config, key, 0.0)

    if kwargs:
        print(f"Modifying config with: {kwargs.keys()}")
        for key, value in kwargs.items():
            print(f"  Setting {key} = {value}")
            setattr(config, key, value)

    model = MistralForCausalLM(config)

    disable_dropout(model)

    # NOTE(review): if the original layers are nn.Module instances, assigning
    # them as attributes registers them as submodules. Their parameters then
    # remain in model.parameters() / state_dict. Confirm that is intended.
    model.model.original_embed_tokens = model.model.embed_tokens
    model.model.embed_tokens = EmbeddingReplacement(
        num_parallel_input,
        config.hidden_size,
        config.hidden_size,
        num_embedding_layers
    )

    model.original_lm_head = model.lm_head
    model.lm_head = HeadReplacement(
        config.hidden_size,
        config.hidden_size,
        [num_parallel_input, num_parallel_input],
        num_head_layers
    )

    # Bind custom_forward to this instance only; the class's forward is untouched.
    model.original_forward = model.forward
    model.forward = custom_forward.__get__(model, model.__class__)

    return model, config

