#    Copyright 2018- The Hugging Face team. All rights reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.
# ------------------------------------------------------------------------
# Modified from CLIP (https://github.com/huggingface/transformers)
# Copyright 2024 Jiachen Li
# ------------------------------------------------------------------------

import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from moe_model.model.moe.register import get_moe


class CLIPAttention(nn.Module):
    """Multi-head self-attention over a single sequence.

    Queries, keys and values are all projected from the same input tensor.
    No attention mask is applied. Dropout with probability
    ``config.attention_dropout`` is applied to the attention probabilities
    while the module is in training mode.
    """

    def __init__(self, config):
        """Build the four projection layers from ``config``.

        Reads ``config.hidden_size``, ``config.num_attention_heads`` and
        ``config.attention_dropout``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of
                ``num_attention_heads``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        # Queries are pre-multiplied by 1/sqrt(head_dim) in forward().
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order (k, v, q, out) is kept so that parameter registration
        # order and seeded initialisation stay the same.
        width = self.embed_dim
        self.k_proj = nn.Linear(width, width)
        self.v_proj = nn.Linear(width, width)
        self.q_proj = nn.Linear(width, width)
        self.out_proj = nn.Linear(width, width)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed_dim)`` into ``(bsz, num_heads, seq_len, head_dim)``."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel

        Returns a tensor with the same shape as ``hidden_states``.

        Raises:
            ValueError: if an intermediate tensor does not have the expected
                size.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()
        # Batch and head axes are folded together so torch.bmm can be used.
        folded = bsz * self.num_heads

        scaled_queries = self.q_proj(hidden_states) * self.scale
        keys = self._shape(self.k_proj(hidden_states), -1, bsz)
        values = self._shape(self.v_proj(hidden_states), -1, bsz)

        queries = self._shape(scaled_queries, tgt_len, bsz).view(folded, -1, self.head_dim)
        keys = keys.view(folded, -1, self.head_dim)
        values = values.view(folded, -1, self.head_dim)

        src_len = keys.size(1)
        attn_weights = torch.bmm(queries, keys.transpose(1, 2))

        if attn_weights.size() != (folded, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, values)

        if attn_output.size() != (folded, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Unfold the heads and merge them back into the channel axis.
        merged = (
            attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
            .transpose(1, 2)
            .reshape(bsz, tgt_len, embed_dim)
        )
        return self.out_proj(merged)

class CLIPMLP(nn.Module):
    """Two-layer feed-forward block: linear, activation, linear."""

    def __init__(self, config):
        """Create the layers from ``config``.

        Reads ``config.hidden_act`` (looked up in ``ACT2FN``),
        ``config.hidden_size`` and ``config.intermediate_size``.
        """
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        outer_dim = config.hidden_size
        inner_dim = config.intermediate_size
        self.fc1 = nn.Linear(outer_dim, inner_dim)
        self.fc2 = nn.Linear(inner_dim, outer_dim)

    def forward(self, hidden_states):
        """Apply ``fc1``, the activation, then ``fc2`` to ``hidden_states``."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
class CLIPEncoderMoELayer(nn.Module):
    """Transformer encoder layer whose feed-forward part is a mixture-of-experts module.

    Both sub-blocks normalise their input first and add the result back onto
    the un-normalised input (residual connection).
    """

    def __init__(self, config, args):
        """Build attention, layer norms and the MoE module.

        Reads ``hidden_size``, ``layer_norm_eps``, ``moe_name``,
        ``num_experts`` and ``num_selected`` from ``config``. ``args`` is
        forwarded unchanged to the MoE constructor returned by ``get_moe``.
        """
        super().__init__()
        print("Using CLIP Encoder MoE Layer")
        width = config.hidden_size
        self.embed_dim = width
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.layer_norm2 = nn.LayerNorm(width, eps=config.layer_norm_eps)

        # The MoE class is resolved by name through the project registry.
        moe_cls = get_moe(config.moe_name)
        self.moelayer = moe_cls(
            in_embed_dim=width,
            out_embed_dim=width,
            num_of_experts=config.num_experts,
            num_selected=config.num_selected,
            # A single CLIPMLP instance is handed over; how it is turned into
            # several experts is up to the MoE class (not visible here).
            expert=CLIPMLP(config),
            args=args,
        )

    def forward(
        self,
        hidden_states,
        return_id_experts = False
    ):
        """Run attention and the MoE block, each with a residual connection.

        ``return_id_experts`` is passed through to the MoE module as its
        second positional argument.

        Returns:
            A tuple ``(hidden_states, auxiliary_loss, id_experts)``; the last
            two entries are whatever the MoE module returned.
        """
        attended = self.self_attn(self.layer_norm1(hidden_states))
        after_attention = hidden_states + attended

        moe_input = self.layer_norm2(after_attention)
        results, auxiliary_loss, id_experts = self.moelayer(moe_input, return_id_experts)

        return (after_attention + results, auxiliary_loss, id_experts)


class CLIPEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``CLIPEncoderMoELayer`` modules."""

    def __init__(self, config , args):
        """Copy the MoE settings from ``args`` onto ``config`` and build the layers.

        Note that ``config`` is modified in place: ``moe_name`` and
        ``num_experts`` are always overwritten from ``args``.
        """
        super().__init__()
        config.moe_name = args.moe_name
        config.num_experts = args.num_experts
        # Each layer also reads ``config.num_selected``. Fill it in from
        # ``args`` when the caller has not set it, so the encoder can be built
        # on its own; a value already present on ``config`` is left alone.
        if not hasattr(config, "num_selected") and hasattr(args, "num_selected"):
            config.num_selected = args.num_selected
        self.config = config
        self.layers = nn.ModuleList(
            [CLIPEncoderMoELayer(config=config, args=args) for _ in range(config.num_hidden_layers)]
        )

    def forward(
        self,
        inputs_embeds,
        return_id_experts = False
    ):
        """Run all layers in order.

        Args:
            inputs_embeds: tensor fed to the first layer.
            return_id_experts: when truthy, it is forwarded to every layer and
                the third element of each layer's output is recorded.

        Returns:
            A tuple ``(encoder_states, auxiliary_losses, stored_history_experts)``:

            * ``encoder_states`` - tuple holding the *input* of every layer,
              in order. The output of the final layer is not appended, so
              ``encoder_states[-1]`` is the tensor that entered the last layer.
            * ``auxiliary_losses`` - list with the second element of each
              layer's output.
            * ``stored_history_experts`` - dict mapping layer index to the
              third element of that layer's output; empty unless
              ``return_id_experts`` is truthy.
        """
        hidden_states = inputs_embeds
        # Collected in a list and converted once, instead of rebuilding a
        # tuple on every iteration.
        encoder_states = []
        auxiliary_losses = []
        stored_history_experts = {}

        for idx, encoder_layer in enumerate(self.layers):
            encoder_states.append(hidden_states)
            if return_id_experts:
                layer_outputs = encoder_layer(hidden_states, return_id_experts)
                stored_history_experts[idx] = layer_outputs[2]
            else:
                layer_outputs = encoder_layer(hidden_states = hidden_states)
            hidden_states = layer_outputs[0]
            auxiliary_losses.append(layer_outputs[1])

        # NOTE(review): the last layer's output is deliberately left out here
        # to keep the existing behaviour - confirm this is intended.
        return tuple(encoder_states), auxiliary_losses, stored_history_experts

class CLIPVisionEmbeddings(nn.Module):
    """Turn an image batch into a token sequence: class token + patch tokens + position embeddings."""

    def __init__(self, config):
        """Create the class token, patch projection and position table.

        Reads ``hidden_size``, ``image_size``, ``patch_size`` and
        ``num_channels`` from ``config``.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size equals stride.
        self.patch_embedding = nn.Conv2d(
            config.num_channels,
            self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        # One extra position for the class token.
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)

        # Registered as non-persistent, so it is not part of the state dict.
        all_positions = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions, persistent=False)

    def forward(self, pixel_values):
        """Embed ``pixel_values`` and return the token sequence.

        The input is cast to the dtype of the patch projection weights before
        the convolution. The class token is placed first along dimension 1.
        """
        num_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # Conv output is (batch, embed_dim, grid_h, grid_w); flatten the grid
        # and move it in front of the channel axis.
        patch_tokens = self.patch_embedding(pixel_values.to(dtype=weight_dtype))
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        class_tokens = self.class_embedding.expand(num_images, 1, -1)
        tokens = torch.cat([class_tokens, patch_tokens], dim=1)
        return tokens + self.position_embedding(self.position_ids)

class CLIPSMoEVisionTransformer(nn.Module):
    """Vision transformer: patch embeddings, a layer norm, then the MoE encoder."""

    def __init__(self, config, args):
        """Build embeddings, the pre-encoder layer norm and the encoder.

        ``config`` is modified in place: ``num_of_experts`` and
        ``num_selected`` are set from ``args`` before the sub-modules are
        created.
        """
        super().__init__()
        self.config = config
        config.num_of_experts = args.num_experts
        config.num_selected = args.num_selected

        self.embeddings = CLIPVisionEmbeddings(config)
        # The attribute name (including its spelling) is part of the parameter
        # names in the state dict, so it is left as is.
        self.pre_layrnorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config=config, args=args)
        # NOTE: no post-encoder layer norm is created by this module.

    def forward(self, pixel_values, return_id_experts = False):
        """Encode ``pixel_values``.

        Returns:
            A tuple ``(features, auxiliary_loss, stored_history_experts)``:

            * ``features`` - the last entry of the state tuple returned by
              the encoder.
            * ``auxiliary_loss`` - mean over the stacked per-layer auxiliary
              losses.
            * ``stored_history_experts`` - returned by the encoder unchanged.
        """
        tokens = self.pre_layrnorm(self.embeddings(pixel_values))
        encoder_outputs, auxiliary_losses, stored_history_experts = self.encoder(tokens, return_id_experts)
        mean_auxiliary_loss = torch.stack(auxiliary_losses).mean()
        return encoder_outputs[-1], mean_auxiliary_loss, stored_history_experts
