# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    logging,
    replace_return_docstrings,
)
from .configuration_autoencoder import AutoencoderConfig
import random
import pdb
from transformers.models.llama.modeling_llama import LlamaPreTrainedModel,LlamaRMSNorm,LlamaMLP

logger = logging.get_logger(__name__)

# Register LlamaRMSNorm in transformers' shared list of layer-norm classes. The list is
# module-global and may already contain it (another module can register it first), so
# only append when missing to keep each class listed at most once.
if LlamaRMSNorm not in ALL_LAYERNORM_LAYERS:
    ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm)



class AELayer(nn.Module):
    """
    Pre-norm residual MLP block computing ``x + mlp(layernorm(x))``.

    Args:
        config: configuration providing ``hidden_size`` and ``rms_norm_eps``
            (plus whatever ``LlamaMLP`` reads from it).
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mlp = LlamaMLP(config)
        self.layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """
        Apply the residual MLP block.

        Args:
            hidden_states: input tensor whose last dimension is ``hidden_size``.

        Returns:
            ``hidden_states + mlp(layernorm(hidden_states))`` -- a single tensor,
            same shape as the input assuming ``mlp`` preserves the last dimension.
        """
        # Fully connected sub-layer with a residual connection around it.
        residual = hidden_states
        hidden_states = self.mlp(self.layernorm(hidden_states))
        return residual + hidden_states

class Encoder(LlamaPreTrainedModel):
    """
    Token-patch encoder built from residual MLP blocks ([`AELayer`]).

    Tokens are grouped into patches of ``config.patch_size`` tokens. For each patch:
    - embed its tokens;
    - run the first ``num_encoder_layers // 2`` layers;
    - flatten the patch and squeeze it into one ``hidden_size`` vector;
    - run the next ``num_encoder_layers // 2`` layers;
    - apply RMS norm and project to ``2 * latent_size`` values.

    Note: with an odd ``num_encoder_layers`` the last layer in ``encoder_layers`` is
    created but never used by ``forward``.

    Args:
        config: model configuration. Reads ``patch_size``, ``pad_token_id``, ``vocab_size``,
            ``latent_size``, ``hidden_size``, ``num_encoder_layers`` and ``rms_norm_eps``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.patch_size = config.patch_size
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.latent_size = config.latent_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder_layers = nn.ModuleList([AELayer(config) for _ in range(config.num_encoder_layers)])
        self.num_enclayer_stage = config.num_encoder_layers // 2
        self.hidden_to_latent = nn.Linear(config.hidden_size, config.latent_size * 2)
        self.squeeze_layer = nn.Linear(self.patch_size * config.hidden_size, config.hidden_size)

        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        """
        Encode token patches into latent statistics.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask, position_ids, past_key_values: accepted for API
                compatibility; ignored.
            inputs_embeds: optional ``(batch, seq_len, hidden_size)`` embeddings. When
                given, they are used instead of looking up ``input_ids``.

        Returns:
            Tensor of shape ``(batch, seq_len // patch_size, 2 * latent_size)``.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given, or if
                ``seq_len`` is not a positive multiple of ``patch_size``.
        """
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        if seq_length == 0 or seq_length % self.patch_size != 0:
            raise ValueError(
                f"Sequence length ({seq_length}) must be a positive multiple of "
                f"patch_size ({self.patch_size})"
            )
        latent_length = seq_length // self.patch_size

        if seq_length != self.patch_size:
            # Fold every patch into its own batch row: (B, L*P) -> (B*L, P).
            batch_size = batch_size * latent_length
            seq_length = self.patch_size
            if input_ids is not None:
                input_ids = input_ids.reshape(batch_size, seq_length)
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds.reshape(batch_size, seq_length, -1)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if self.training:
            # NOTE(review): hard-coded bf16 cast during training; presumably the model
            # runs in bf16 / under autocast -- confirm.
            inputs_embeds = inputs_embeds.to(dtype=torch.bfloat16)

        hidden_states = inputs_embeds

        for stage in range(2):
            start = stage * self.num_enclayer_stage
            for encoder_idx in range(start, start + self.num_enclayer_stage):
                hidden_states = self.encoder_layers[encoder_idx](hidden_states)

            if stage == 0:
                # Concatenate the patch's P token vectors and squeeze them to one vector:
                # (N, P, H) -> (N, 1, P*H) -> (N, 1, H).
                merged_length = hidden_states.size(1) // self.patch_size
                hidden_states = hidden_states.view(batch_size, merged_length, -1)
                hidden_states = self.squeeze_layer(hidden_states)

        hidden_states = self.norm(hidden_states)
        latent_states = self.hidden_to_latent(hidden_states)
        # Regroup patches per sequence: (B*L, 1, 2*latent) -> (B, L, 2*latent).
        latent_states = latent_states.reshape(batch_size // latent_length, latent_length, latent_states.shape[-1])

        return latent_states


class Decoder(LlamaPreTrainedModel):
    """
    Patch decoder built from residual MLP blocks ([`AELayer`]).

    For each latent vector:
    - project it to ``hidden_size``;
    - run the first ``num_decoder_layers // 2`` layers;
    - expand it into ``patch_size`` hidden vectors;
    - run the next ``num_decoder_layers // 2`` layers;
    - apply RMS norm and project onto the vocabulary with ``self.lm_head_weight``.

    ``lm_head_weight`` is NOT created here. The owning model must assign it before
    ``forward`` is called; ``Autoencoder`` ties it to the encoder's token embedding.
    With an odd ``num_decoder_layers`` the last layer is created but never used.
    NOTE(review): unlike ``Encoder`` this does not call ``post_init()``; presumably the
    owning model's ``post_init()`` initializes these weights -- confirm.
    """

    def __init__(self, config):
        super().__init__(config)
        self.patch_size = config.patch_size
        self.num_declayer_stage = config.num_decoder_layers // 2

        self.latent_to_hidden = nn.Linear(config.latent_size, config.hidden_size)
        self.decoder_layers = nn.ModuleList([AELayer(config) for _ in range(config.num_decoder_layers)])
        self.expand_layer = nn.Linear(config.hidden_size, self.patch_size * config.hidden_size)
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        latent_states: torch.Tensor,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
    ) -> torch.Tensor:
        """
        Decode latent vectors into vocabulary logits.

        Args:
            latent_states: ``(batch, num_patches, latent_size)`` latent vectors.
            past_key_values: accepted for API compatibility; ignored.

        Returns:
            Logits of shape ``(batch, num_patches * patch_size, V)``, where ``V`` is the
            number of rows of ``self.lm_head_weight``.
        """
        batch_size, num_patches, latent_size = latent_states.shape
        # Decode every patch independently: (B, L, latent) -> (B*L, 1, latent).
        flat_batch = batch_size * num_patches
        latent_states = latent_states.reshape(flat_batch, 1, latent_size)
        hidden_states = self.latent_to_hidden(latent_states)

        for stage in range(2):
            start = stage * self.num_declayer_stage
            for decoder_idx in range(start, start + self.num_declayer_stage):
                hidden_states = self.decoder_layers[decoder_idx](hidden_states)

            if stage == 0:
                # Expand each latent into patch_size token vectors: (N, 1, P*H) -> (N, P, H).
                hidden_states = self.expand_layer(hidden_states)
                hidden_states = hidden_states.reshape(
                    flat_batch, self.patch_size, hidden_states.shape[-1] // self.patch_size
                )

        hidden_states = self.norm(hidden_states)
        # Regroup patches per sequence: (B*L, P, H) -> (B, L*P, H).
        hidden_states = hidden_states.reshape(
            batch_size, num_patches * self.patch_size, hidden_states.shape[-1]
        )
        return F.linear(hidden_states, self.lm_head_weight)


class Autoencoder(LlamaPreTrainedModel):
    """
    Variational patch autoencoder: ``Encoder`` -> reparameterized sample -> ``Decoder``.

    The decoder's output projection (``decoder.lm_head_weight``) is the same tensor as
    the encoder's token embedding matrix (``encoder.embed_tokens.weight``).
    """

    # NOTE(review): no parameter is literally named "lm_head.weight"; the tied parameter
    # is "decoder.lm_head_weight" (assigned in __init__). transformers treats these
    # entries as regex patterns, where "." also matches "_" -- confirm this is intended.
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
        # Share (not copy) the embedding matrix as the decoder's output projection.
        self.decoder.lm_head_weight = self.encoder.embed_tokens.weight
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.encoder.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding and keep the decoder's output projection tied to it."""
        self.encoder.embed_tokens = value
        # Without re-tying, the decoder would keep projecting with the old matrix
        # (e.g. a stale vocabulary size after resizing the embeddings).
        self.decoder.lm_head_weight = value.weight

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Reconstruct ``input_ids`` through the variational bottleneck.

        Args:
            input_ids: token ids. The element count must be a multiple of ``patch_size``;
                the ids are regrouped into rows of ``patch_size`` tokens.
            labels: target ids. They are flattened and compared position-by-position with
                the flattened logits, so they normally have the same layout as
                ``input_ids``. Optional: when omitted, ``loss`` is ``None``.
            attention_mask, position_ids, past_key_values, inputs_embeds, use_cache,
            output_attentions, output_hidden_states, return_dict: accepted for API
                compatibility; ignored.

        Returns:
            ``CausalLMOutputWithPast`` with:
            - ``logits`` flattened to ``(num_tokens, vocab_size)``;
            - ``loss`` set to the cross-entropy. In training mode it is instead
              ``CE * patch_size + 1e-3 * KL``.

        Raises:
            ValueError: if ``input_ids`` is not given.
        """
        if input_ids is None:
            raise ValueError("Autoencoder.forward requires input_ids")
        # One row per patch.
        input_ids = input_ids.reshape(-1, self.encoder.patch_size)
        if self.training:
            # Token dropout: ~15% of ids are replaced by id 0.
            # NOTE(review): assumes id 0 is a suitable "masked" token -- confirm.
            mask = torch.rand_like(input_ids.float()) > 0.15
            input_ids = input_ids * mask.long()

        latent_states = self.encoder(
            input_ids=input_ids,
        )
        # The encoder emits [mean, log_std] concatenated along the last dimension.
        mean, log_std = torch.chunk(latent_states, 2, dim=-1)
        std = torch.exp(log_std)
        # Reparameterization trick. Note: a sample is drawn in eval mode too, not just the mean.
        eps = torch.randn_like(mean)
        latent_states = mean + eps * std
        latent_states = F.dropout(latent_states, p=0.15, training=self.training)

        logits = self.decoder(
            latent_states=latent_states,
        )
        logits = logits.float()
        logits = logits.view(-1, self.config.vocab_size)

        loss = None
        if labels is not None:
            labels = labels.view(-1).to(logits.device)
            loss = CrossEntropyLoss()(logits, labels)
            if self.training:
                # Closed-form KL(N(mean, std^2) || N(0, 1)) per latent dimension.
                # Each dimension's term is clamped to at least 0.5, then summed over
                # latent dims and averaged over patches.
                kl_loss = 0.5 * (torch.pow(mean, 2) + torch.pow(std, 2) - 1 - log_std * 2)
                kl_loss = torch.clamp(kl_loss, min=0.5)
                kl_loss = torch.mean(torch.sum(kl_loss, dim=-1))
                loss = loss * self.encoder.patch_size + kl_loss * 1e-3

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )


