# =============================================================================
# CONFIDENTIAL - FOR REVIEW ONLY
# This code is submitted as supplementary material for paper review.
# DO NOT DISTRIBUTE - Pending patent application.
# =============================================================================

from typing import Optional, Tuple
import warnings

import torch
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM

from .base import EmbeddingReplacement, HeadReplacement, disable_dropout


def custom_forward(
        self,
        input: torch.Tensor,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        state: Optional[Tuple[torch.Tensor]] = None,
        return_dict: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the RWKV core on embedded inputs and return ``(mu, sigma)``.

    Intended to be bound to a model instance as its ``forward`` method; it
    reads ``self.config``, ``self.rwkv`` and ``self.head``.

    Args:
        input: Tensor handed to ``self.rwkv.embeddings`` to produce the
            embeddings fed to the core.
        inputs_embeds: Optional precomputed embeddings. When given, they are
            passed to the core as-is and ``self.rwkv.embeddings`` is not
            called; when ``None`` the embeddings are computed from ``input``.
        use_cache: Unsupported; must be ``None``.
        output_attentions: Unsupported; must be ``None``.
        output_hidden_states: Unsupported; must be ``None``.
        state: Unsupported; must be ``None``.
        return_dict: Forwarded to the core; defaults to
            ``self.config.use_return_dict`` when ``None``.

    Returns:
        A ``(mu, sigma)`` tuple taken from ``self.head``; ``sigma`` has been
        passed through softplus.

    Raises:
        ValueError: If any unsupported argument is not ``None``.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Explicit raises instead of `assert`: assertions are removed when Python
    # runs with -O, which would let unsupported options reach the core.
    unsupported = (
        (state, "state is not supported"),
        (use_cache, "use_cache not supported"),
        (output_attentions, "output_attentions not supported"),
        (output_hidden_states, "output_hidden_states not supported"),
    )
    for value, message in unsupported:
        if value is not None:
            raise ValueError(message)

    # Honour caller-supplied embeddings rather than silently discarding them.
    if inputs_embeds is None:
        inputs_embeds = self.rwkv.embeddings(input)

    core_outputs = self.rwkv(
        input_ids=None,
        inputs_embeds=inputs_embeds,
        state=state,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    # The first element of the core output is used as the hidden states.
    hidden_states = core_outputs[0]

    # Cast to the head's parameter dtype before applying the head.
    mu, sigma = self.head(hidden_states.to(self.head.weight.dtype))
    # Softplus maps the raw second head output to a non-negative value.
    sigma = nn.functional.softplus(sigma)

    return mu, sigma


def get_model_with_parallel_input(
    model_name: str,
    pretrained: bool = True,
    num_parallel_input: int = 1,
    num_embedding_layers: int = 2,
    num_head_layers: int = 2,
    cache_dir: Optional[str] = None,
    **kwargs
):
    """Build a causal LM whose embeddings, head and forward are swapped out.

    The config is loaded with ``AutoConfig.from_pretrained``. The model is
    then created either from pretrained weights or from the config alone,
    after which ``disable_dropout`` is applied and the embedding layer, head
    and ``forward`` method are replaced. The originals are kept on the model
    as ``rwkv.original_embeddings``, ``original_head`` and
    ``original_forward``.

    Args:
        model_name: Identifier passed to the ``from_pretrained`` loaders.
        pretrained: Load pretrained weights when true; otherwise instantiate
            the model from the (optionally modified) config.
        num_parallel_input: First argument of ``EmbeddingReplacement``; also
            used for both entries of the list given to ``HeadReplacement``.
        num_embedding_layers: Last argument of ``EmbeddingReplacement``.
        num_head_layers: Last argument of ``HeadReplacement``.
        cache_dir: Forwarded to the loaders unless it is ``None`` or the
            string ``"none"`` (compared case-insensitively).
        **kwargs: Config attributes to set when ``pretrained`` is false; with
            ``pretrained`` true they are ignored and a warning is issued.

    Returns:
        A ``(model, config)`` tuple.
    """
    # Forward cache_dir only when it names a real location.
    hub_kwargs = {}
    if cache_dir is not None and cache_dir.lower() != 'none':
        hub_kwargs["cache_dir"] = cache_dir

    config = AutoConfig.from_pretrained(model_name, **hub_kwargs)

    if not pretrained:
        # Overrides are applied to the config before the model is built.
        if kwargs:
            print(f"Modifying config with: {kwargs.keys()}")
            for key, value in kwargs.items():
                print(f"  Setting {key} = {value}")
                setattr(config, key, value)
        model = AutoModelForCausalLM.from_config(config)
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name, config=config, **hub_kwargs
        )
        if kwargs:
            warnings.warn(f"Ignoring kwargs for pretrained model: {kwargs.keys()}")

    disable_dropout(model)

    # Swap the embedding layer, keeping a handle on the original.
    backbone = model.rwkv
    backbone.original_embeddings = backbone.embeddings
    width = config.hidden_size
    backbone.embeddings = EmbeddingReplacement(
        num_parallel_input, width, width, num_embedding_layers
    )

    # Swap the output head, keeping a handle on the original.
    model.original_head = model.head
    head_outputs = [num_parallel_input, num_parallel_input]
    model.head = HeadReplacement(width, width, head_outputs, num_head_layers)

    # Bind custom_forward to this instance so it receives the model as `self`.
    model.original_forward = model.forward
    model.forward = custom_forward.__get__(model, model.__class__)

    return model, config

