# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple, Union
import sys
import warnings

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import CausalLMOutputWithCrossAttentions

from .base import EmbeddingReplacement, HeadReplacement, disable_dropout
from ..architectures.gpt2_flex_attn import GPT2LMHeadModel


def custom_forward(
        self,
        input: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Replacement ``forward`` that returns a ``(mu, sigma)`` pair.

    ``input`` is embedded with ``self.transformer.wte``, the embeddings are
    run through ``self.transformer`` together with position ids
    ``0 .. input.shape[1] - 1``, and the first element of the backbone
    output is handed to ``self.lm_head``. The head must return two values;
    the second one is passed through softplus before being returned.

    The signature mirrors the usual causal-LM ``forward``, but most of the
    optional arguments are rejected: each of ``past_key_values``,
    ``attention_mask``, ``token_type_ids``, ``position_ids``, ``head_mask``,
    ``encoder_hidden_states``, ``encoder_attention_mask``, ``use_cache``,
    ``output_attentions`` and ``output_hidden_states`` must be ``None``.
    A caller-supplied ``inputs_embeds`` is not used (embeddings are always
    recomputed from ``input``) and ``labels`` is never read.

    Args:
        input: Tensor fed to ``self.transformer.wte``; dimension 1 is used
            as the sequence length (assumes a ``(batch, seq, ...)`` layout
            -- TODO confirm against ``EmbeddingReplacement``).
        return_dict: Forwarded to the backbone; falls back to
            ``self.config.use_return_dict`` when ``None``.

    Returns:
        Tuple ``(mu, sigma)`` where ``sigma`` is the softplus of the second
        head output.

    Raises:
        AssertionError: If any of the unsupported arguments is not ``None``.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Checked in declaration order so the first offending argument is the
    # one reported.
    rejected = (
        (past_key_values, "past_key_values not supported"),
        (attention_mask, "attention_mask not supported"),
        (token_type_ids, "token_type_ids not supported"),
        (position_ids, "position_ids not supported"),
        (head_mask, "head_mask not supported"),
        (encoder_hidden_states, "encoder_hidden_states not supported"),
        (encoder_attention_mask, "encoder_attention_mask not supported"),
        (use_cache, "use_cache not supported"),
        (output_attentions, "output_attentions not supported"),
        (output_hidden_states, "output_hidden_states not supported"),
    )
    for value, message in rejected:
        assert value is None, message

    # Embeddings always come from `input`; positions are a plain 0..seq-1 ramp
    # on the same device as the input.
    token_embeddings = self.transformer.wte(input)
    positions = torch.arange(input.shape[1], dtype=torch.long, device=input.device)

    backbone_kwargs = dict(
        input_ids=None,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=positions,
        head_mask=head_mask,
        inputs_embeds=token_embeddings,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    backbone_out = self.transformer(**backbone_kwargs)

    # Index 0 is valid for both tuple and dict-style backbone outputs.
    hidden = backbone_out[0]
    # Cast to the head's parameter dtype before projecting.
    mu, raw_sigma = self.lm_head(hidden.to(self.lm_head.weight.dtype))

    return mu, nn.functional.softplus(raw_sigma)


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    flash_attention: bool = False,
    **kwargs
):
    """Build a causal LM and swap in a custom embedding, head and forward.

    Steps performed:

    1. Pick the attention implementation: ``"flash_attention_2"`` only when
       ``flash_attention`` is true and ``sys.platform`` is Linux, otherwise
       ``"eager"``. The choice is printed.
    2. Load a config. For the names ``"openai-community/gpt2-micro"`` and
       ``"openai-community/gpt2-mini"`` the ``"openai-community/gpt2"``
       config is loaded and its hidden size / layer count / head count are
       overridden; any other name is loaded as-is.
    3. Create the model: ``AutoModelForCausalLM.from_pretrained`` when
       ``pretrained`` is true, otherwise ``GPT2LMHeadModel(config)`` after
       applying ``kwargs`` to the config with ``setattr``.
    4. Call ``disable_dropout`` on the model, replace ``transformer.wte``
       with an ``EmbeddingReplacement`` and ``lm_head`` with a
       ``HeadReplacement``, and bind ``custom_forward`` as the instance's
       ``forward``. The replaced objects stay reachable as
       ``transformer.original_wte``, ``original_lm_head`` and
       ``original_forward``.

    Args:
        model_name: Hub identifier / path, or one of the two reduced
            variant names listed above.
        pretrained: Load pretrained weights instead of a fresh model.
        num_parallel_input: First argument of ``EmbeddingReplacement`` and
            used (twice) as the output-size list of ``HeadReplacement``.
        num_embedding_layers: Layer count passed to ``EmbeddingReplacement``.
        num_head_layers: Layer count passed to ``HeadReplacement``.
        cache_dir: Cache directory for every hub download made here.
            ``None`` or the string ``"none"`` (any case) means the library
            default.
        flash_attention: Request flash attention (Linux only).
        **kwargs: Config overrides; only applied when ``pretrained`` is
            false, otherwise ignored with a warning.

    Returns:
        Tuple ``(model, config)``.
    """
    if sys.platform in ["linux", "linux2"] and flash_attention:
        attn_implementation = "flash_attention_2"
    else:
        attn_implementation = "eager"
    print(f"Using attention implementation: {attn_implementation}")

    # Resolve the cache directory once so every hub call below agrees on it.
    hub_kwargs = {}
    if cache_dir is not None and cache_dir.lower() != 'none':
        hub_kwargs["cache_dir"] = cache_dir

    # Reduced variants derived from the base GPT-2 config:
    # name -> (hidden size, number of layers, number of attention heads).
    compact_variants = {
        "openai-community/gpt2-micro": (512, 12, 8),
        "openai-community/gpt2-mini": (512, 6, 8),
    }

    if model_name in compact_variants:
        hidden_size, num_layers, num_heads = compact_variants[model_name]
        # The base config now honours cache_dir like every other download.
        # NOTE(review): attn_implementation is still not applied to this
        # config (unchanged from before) -- confirm whether it should be.
        config = AutoConfig.from_pretrained("openai-community/gpt2", **hub_kwargs)
        config.n_embd = config.hidden_size = hidden_size
        config.n_layer = config.num_hidden_layers = num_layers
        config.n_head = config.num_attention_heads = num_heads
    else:
        config = AutoConfig.from_pretrained(
            model_name, attn_implementation=attn_implementation, **hub_kwargs
        )

    if pretrained:
        # NOTE(review): model_name is passed through unchanged, so with a
        # reduced variant name this relies on that identifier being
        # resolvable by from_pretrained -- confirm with callers.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            config=config,
            attn_implementation=attn_implementation,
            **hub_kwargs
        )
        if kwargs:
            warnings.warn(f"Ignoring kwargs for pretrained model: {kwargs.keys()}")
    else:
        if kwargs:
            print(f"Modifying config with: {kwargs.keys()}")
            for key, value in kwargs.items():
                print(f"  Setting {key} = {value}")
                setattr(config, key, value)

        model = GPT2LMHeadModel(config)

    disable_dropout(model)

    # Keep the original modules reachable before overwriting them.
    model.transformer.original_wte = model.transformer.wte
    model.transformer.wte = EmbeddingReplacement(
        num_parallel_input,
        config.hidden_size,
        config.hidden_size,
        num_embedding_layers
    )

    model.original_lm_head = model.lm_head
    model.lm_head = HeadReplacement(
        config.hidden_size,
        config.hidden_size,
        [num_parallel_input, num_parallel_input],
        num_head_layers
    )

    # Bind custom_forward to this instance only; the class is left untouched.
    model.original_forward = model.forward
    model.forward = custom_forward.__get__(model, model.__class__)

    return model, config

