from time import time
from copy import deepcopy

import numpy as np 

import torch
from torch import nn
from torch.nn import functional as F
from typing import Optional, Tuple, Union

from transformers import CLIPVisionConfig, CLIPVisionModel
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from transformers.utils import torch_int


##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################


class CLIPAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input sequence.
    The hidden dimension is split evenly across ``config.num_attention_heads``
    heads, attention is computed per head, and the heads are merged back and
    passed through an output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # The hidden size has to split into equally sized heads.
        self.head_dim, leftover = divmod(self.embed_dim, self.num_heads)
        if leftover != 0:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        # Creation order is kept (k, v, q, out) so seeded initialisation is unchanged.
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape ``(bsz, seq_len, embed_dim)`` into a contiguous ``(bsz, num_heads, seq_len, head_dim)``."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def _add_mask(self, scores: torch.Tensor, mask: torch.Tensor, bsz: int, tgt_len: int, src_len: int):
        """Add an additive mask of shape ``(bsz, 1, tgt_len, src_len)`` to the flattened-head scores."""
        expected = (bsz, 1, tgt_len, src_len)
        if mask.size() != expected:
            raise ValueError(f"Attention mask should be of size {expected}, but is {mask.size()}")
        # Un-flatten the head axis so the mask broadcasts over heads, then flatten again.
        masked = scores.view(bsz, self.num_heads, tgt_len, src_len) + mask
        return masked.view(bsz * self.num_heads, tgt_len, src_len)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run self-attention over ``hidden_states`` of shape ``(batch, time, channel)``.

        Args:
            hidden_states: input sequence, ``(batch, time, channel)``.
            attention_mask: optional additive mask, ``(batch, 1, time, time)``.
            causal_attention_mask: optional additive mask of the same shape; it is
                applied before ``attention_mask``.
            output_attentions: when truthy, also return the post-softmax weights.

        Returns:
            ``(attn_output, attn_weights)`` where ``attn_output`` has the input's shape
            and ``attn_weights`` is ``(batch, num_heads, time, time)`` or ``None``.

        Raises:
            ValueError: if a mask or an intermediate tensor has an unexpected shape.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()
        flat_heads = bsz * self.num_heads

        # The query is pre-scaled so the score matmul already yields scaled dot products.
        scaled_query = self.q_proj(hidden_states) * self.scale
        keys = self._shape(self.k_proj(hidden_states), -1, bsz)
        values = self._shape(self.v_proj(hidden_states), -1, bsz)

        # Fold batch and head axes together so a single bmm handles every head.
        queries = self._shape(scaled_query, tgt_len, bsz).view(flat_heads, -1, self.head_dim)
        keys = keys.view(flat_heads, -1, self.head_dim)
        values = values.view(flat_heads, -1, self.head_dim)

        src_len = keys.size(1)
        attn_weights = torch.bmm(queries, keys.transpose(1, 2))

        if attn_weights.size() != (flat_heads, tgt_len, src_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
                f" {attn_weights.size()}"
            )

        # Causal mask first, then the regular attention mask.
        for mask in (causal_attention_mask, attention_mask):
            if mask is not None:
                attn_weights = self._add_mask(attn_weights, mask, bsz, tgt_len, src_len)

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_weights_reshaped = None
        if output_attentions:
            # The returned per-head view and the tensor used below must be the same
            # autograd node, hence the view / view-back round trip.
            attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights_reshaped.view(flat_heads, tgt_len, src_len)

        attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        attn_output = torch.bmm(attn_probs, values)

        if attn_output.size() != (flat_heads, tgt_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge heads back: (bsz * heads, T, hd) -> (bsz, T, heads * hd).
        merged = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim).transpose(1, 2)
        merged = merged.reshape(bsz, tgt_len, embed_dim)

        return self.out_proj(merged), attn_weights_reshaped


class QuickGELUActivation(nn.Module):
    """
    Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``. Cheaper than exact GELU at the cost of
    some accuracy. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The gate is a sigmoid of the input scaled by the fixed constant 1.702.
        gate = torch.sigmoid(1.702 * input)
        return input * gate


class CLIPMLP(nn.Module):
    """Position-wise feed-forward block: Linear -> QuickGELU -> Linear.

    Expands from ``config.hidden_size`` to ``config.intermediate_size`` and projects
    back, so the output has the same trailing dimension as the input.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = QuickGELUActivation()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.fc1 = nn.Linear(hidden, inner)
        self.fc2 = nn.Linear(inner, hidden)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the two-layer MLP to the last dimension of ``hidden_states``."""
        expanded = self.activation_fn(self.fc1(hidden_states))
        return self.fc2(expanded)


class CLIPEncoderLayer(nn.Module):
    """One pre-norm transformer block: LayerNorm -> self-attention -> residual, LayerNorm -> MLP -> residual."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): additive attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative
                values. May be `None`.
            causal_attention_mask (`torch.FloatTensor`): additive mask of the same size, forwarded to the
                attention module alongside `attention_mask`. May be `None`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, or `(hidden_states, attn_weights)` when `output_attentions` is truthy.
        """
        # Attention sub-block: normalise, attend, add back onto the un-normalised input.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Opt-in flag; it only has an effect while the module is in training mode.
        self.gradient_checkpointing = False

    def _resolve_checkpoint_fn(self):
        """Return the callable used to run one layer under activation checkpointing.

        An externally injected `_gradient_checkpointing_func` attribute takes precedence.
        Nothing in this module defines that attribute, so when it is absent we fall back to
        `torch.utils.checkpoint.checkpoint` instead of failing with `AttributeError`.
        """
        injected = getattr(self, "_gradient_checkpointing_func", None)
        if injected is not None:
            return injected

        from torch.utils.checkpoint import checkpoint

        def _run(layer_call, *args):
            # The non-reentrant variant accepts `None` / non-tensor positional arguments.
            return checkpoint(layer_call, *args, use_reentrant=False)

        return _run

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded input sequence; it is fed to the first layer unchanged.
            attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive mask passed unchanged to every layer; the attention module adds it to the raw
                attention scores and rejects any other shape.
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Second additive mask with the same contract; the attention module applies it before
                `attention_mask`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. Defaults to
                `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers (the input plus the output of every
                layer). Defaults to `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Defaults to
                `config.use_return_dict`.

        Returns:
            A `BaseModelOutput`, or a tuple `(last_hidden_state[, hidden_states][, attentions])` with the
            `None` entries dropped when `return_dict` is falsy.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # Resolve the checkpoint runner once rather than once per layer.
        use_checkpointing = self.gradient_checkpointing and self.training
        checkpoint_fn = self._resolve_checkpoint_fn() if use_checkpointing else None

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input of each layer; the final output is appended after the loop.
                encoder_states = encoder_states + (hidden_states,)
            if use_checkpointing:
                layer_outputs = checkpoint_fn(
                    encoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )


class CLIPVisionEmbeddings(nn.Module):
    """Prepend a learned class token to a sequence of already-embedded patches.

    No patch projection and no position embedding are applied here: the position
    embedding table and its addition are disabled, so `forward` only concatenates.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        self.num_positions = self.num_patches + 1  # patches + class token

        # Non-persistent index buffer; `forward` does not use it because position
        # embeddings are disabled.
        position_ids = torch.arange(self.num_positions).expand((1, -1))
        self.register_buffer("position_ids", position_ids, persistent=False)

    def forward(self, patch_embeds: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
        """Return `patch_embeds` with the class embedding inserted at sequence index 0.

        `patch_embeds` must be 3-D `(batch, num_tokens, embed_dim)`; the output has
        `num_tokens + 1` tokens. `interpolate_pos_encoding` is accepted but unused.
        """
        batch_size, _num_tokens, _width = patch_embeds.shape

        cls_tokens = self.class_embedding.expand(batch_size, 1, -1)
        return torch.cat((cls_tokens, patch_embeds), dim=1)


class CLIPVisionTransformer(nn.Module):
    """Class-token embedding -> pre-LayerNorm -> transformer encoder -> post-LayerNorm on the class token."""

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        # Attribute name (including its spelling) is kept as-is: it is part of the state-dict keys.
        self.pre_layrnorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(width, eps=config.layer_norm_eps)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = False,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode a token sequence and pool it through the class token.

        Args:
            pixel_values: passed straight to the embeddings module as its `patch_embeds`
                argument, which unpacks a 3-D shape `(batch, num_tokens, hidden_size)`.
            output_attentions / output_hidden_states / return_dict: fall back to the
                corresponding `config` attribute when `None`.
            interpolate_pos_encoding: forwarded to the embeddings module.

        Returns:
            `BaseModelOutputWithPooling`, or the tuple
            `(last_hidden_state, pooled_output, *extra_encoder_outputs)` when `return_dict` is falsy.

        Raises:
            ValueError: if `pixel_values` is `None`.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedded = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
        encoder_outputs = self.encoder(
            inputs_embeds=self.pre_layrnorm(embedded),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        # Pool by taking the token at index 0 (the prepended class token) and normalising it.
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=last_hidden_state,
                pooler_output=pooled_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )
        return (last_hidden_state, pooled_output) + encoder_outputs[1:]

##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    """Sample from the Gumbel-Softmax distribution over ``dim`` of ``logits``.

    Args:
        logits: unnormalised log-probabilities.
        tau: temperature dividing the perturbed logits before the softmax.
        hard: if True, return one-hot samples whose gradient is that of the soft
            sample (straight-through estimator); otherwise return the soft sample.
        dim: dimension along which the softmax is taken.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    # Noise is drawn from an explicit Gumbel(0, 1) distribution object, which is
    # more stable, see https://github.com/pytorch/pytorch/issues/41663
    loc = torch.tensor(0.0, device=logits.device, dtype=logits.dtype)
    scale = torch.tensor(1.0, device=logits.device, dtype=logits.dtype)
    noise = torch.distributions.gumbel.Gumbel(loc, scale).sample(logits.shape)

    perturbed = (logits + noise) / tau  # ~Gumbel(logits, tau)
    y_soft = perturbed.softmax(dim)

    if not hard:
        # Reparametrization trick: the soft sample is differentiable as-is.
        return y_soft

    # Straight-through: forward value is the one-hot arg-max, backward flows through y_soft.
    winner = y_soft.max(dim, keepdim=True).indices
    y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, winner, 1.0)
    return y_hard - y_soft.detach() + y_soft

    
class VectorQuantizer(nn.Module):
    """Per-token binary gate between the original input vector and one shared learned vector.

    A linear head maps every latent token to two logits; a hard Gumbel-Softmax sample
    picks, per token, either the corresponding ``input`` vector (choice 0) or the single
    learned vector ``forma`` (choice 1).

    Reference:
    [1] https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.linear = nn.Linear(embedding_dim, 2)

        self.embedding_dim = embedding_dim
        # Standard-normal init scaled by 1 / sqrt(embedding_dim).
        initial = torch.nn.init.normal_(torch.zeros(self.embedding_dim)) / embedding_dim**0.5
        self.forma = nn.Parameter(initial)

    def forward(self, latents, input):
        """Gate ``input`` token-wise using decisions computed from ``latents``.

        Args:
            latents: ``[B, N, embedding_dim]`` tensor the gate logits are computed from.
            input: tensor gated by the "keep" decision; assumed to be ``[B, N, embedding_dim]``
                like ``latents`` — TODO confirm against the caller.

        Returns:
            ``(quantized_latents, val_dict)`` where ``val_dict`` holds the mean keep rate
            (``input_vector_amount``), a per-sample tuple of kept token indices
            (``encoding_inds``) and the L2 norm of every input token (``input_norm``).
        """
        logits = self.linear(latents)  # [B, N, 2]
        choice = gumbel_softmax(logits, hard=True)  # [B, N, 2], one-hot in the forward pass
        keep_gate = choice[:, :, :1]  # [B, N, 1]
        replace_gate = choice[:, :, 1:]  # [B, N, 1]

        # Broadcast the shared vector to the shape (and dtype) of `latents`.
        forma_batch = torch.zeros_like(latents).add_(self.forma[None, None, :])  # [B, N, D]

        quantized_latents = input * keep_gate + forma_batch * replace_gate
        input_vector_amount = torch.mean(keep_gate)

        # Collect, per sample, the token positions whose input vector was kept.
        num_tokens = choice.shape[1]
        positions = torch.arange(num_tokens, device=choice.device)[None].repeat(latents.shape[0], 1)
        kept = choice[:, :, 0] > 0.5
        kept_per_sample = kept.sum(dim=1)
        encoding_inds = positions[kept].split(kept_per_sample.tolist())

        val_dict = {
            'input_vector_amount': input_vector_amount,
            'encoding_inds': encoding_inds,
            'input_norm': input.norm(dim=-1),
        }
        return quantized_latents, val_dict

##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
##########################################################################################################
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with a skip connection.

    The activation stored in ``self.relu`` is a ``QuickGELUActivation`` (the attribute
    name is historical). Spatial size is preserved (stride 1, padding 1); the channel
    count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # First conv stage.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = QuickGELUActivation()

        # Second conv stage.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path needs a 1x1 projection only when the channel count changes.
        if in_channels != out_channels:
            self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.skip_conv = nn.Identity()

    def forward(self, x):
        """Return ``act(bn2(conv2(act(bn1(conv1(x))))) + skip(x))``."""
        shortcut = self.skip_conv(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Merge the skip path, then apply the final activation.
        out += shortcut
        return self.relu(out)


class Decoder(nn.Module):
    """Convolutional decoder: four (2x transposed-conv upsample + residual block) stages, then a 3-channel conv.

    Each stage halves the channel count and doubles the spatial size, so a feature map
    of ``emb_size`` channels at H x W becomes a 3-channel map at 16H x 16W
    (e.g. 14 x 14 -> 224 x 224).
    """

    def __init__(self, emb_size):
        super().__init__()
        half, quarter, eighth, sixteenth = (emb_size // divisor for divisor in (2, 4, 8, 16))

        # Stage 1: emb_size -> emb_size // 2 channels, spatial size x2.
        self.upconv1 = nn.ConvTranspose2d(emb_size, half, kernel_size=4, stride=2, padding=1)
        self.resblock1 = ResidualBlock(half, half)

        # Stage 2: -> emb_size // 4 channels, spatial size x2.
        self.upconv2 = nn.ConvTranspose2d(half, quarter, kernel_size=4, stride=2, padding=1)
        self.resblock2 = ResidualBlock(quarter, quarter)

        # Stage 3: -> emb_size // 8 channels, spatial size x2.
        self.upconv3 = nn.ConvTranspose2d(quarter, eighth, kernel_size=4, stride=2, padding=1)
        self.resblock3 = ResidualBlock(eighth, eighth)

        # Stage 4: -> emb_size // 16 channels, spatial size x2.
        self.upconv4 = nn.ConvTranspose2d(eighth, sixteenth, kernel_size=4, stride=2, padding=1)
        self.resblock4 = ResidualBlock(sixteenth, sixteenth)

        # Output head: reduce to 3 channels without changing the spatial size.
        self.final_conv = nn.Conv2d(sixteenth, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Run the four upsample/refine stages in order and project to 3 channels."""
        stages = (
            (self.upconv1, self.resblock1),
            (self.upconv2, self.resblock2),
            (self.upconv3, self.resblock3),
            (self.upconv4, self.resblock4),
        )
        for upsample, refine in stages:
            x = refine(upsample(x))
        return self.final_conv(x)


class VQVAE(nn.Module):
    """Feature-map autoencoder: transformer encoder -> Gumbel token gate -> two transformers -> conv decoder.

    The encoder decides per token whether to keep the original feature vector or to
    replace it by a shared learned vector (see `VectorQuantizer`). A second transformer
    reconstructs the feature map from the gated tokens, a third one prepares it for the
    convolutional `Decoder`, which produces a 3-channel image.
    """

    def __init__(self, iva_factor, encoder_layers, pretrained_path, cache_dir):
        """
        Args:
            iva_factor: threshold used in the upper hinge term on the kept-token fraction
                (see `loss_function`).
            encoder_layers: number of layers of the first (gating) transformer.
            pretrained_path: forwarded to `CLIPVisionConfig.from_pretrained`; only the
                config is loaded here, no weights.
            cache_dir: forwarded to `CLIPVisionConfig.from_pretrained`.
        """
        super(VQVAE, self).__init__()
        self.iva_factor = iva_factor
        self.low_iva_factor = 0.1
        self.config = CLIPVisionConfig.from_pretrained(pretrained_path, cache_dir=cache_dir)

        # NOTE: one config object is shared by all three transformers and mutated between
        # constructions. Each transformer is built with the layer count in effect at that
        # moment, but after __init__ the shared `config.num_hidden_layers` reads 2.
        self.config.num_hidden_layers = encoder_layers
        self.config.hidden_size = 1024
        self.config.intermediate_size = 4096

        self.encoder = CLIPVisionTransformer(self.config)
        self.vq_layer = VectorQuantizer(self.config.hidden_size)

        self.config.num_hidden_layers = 4
        self.transformer_features = CLIPVisionTransformer(self.config)
        self.config.num_hidden_layers = 2
        self.transformer_images = CLIPVisionTransformer(self.config)

        self.decoder = Decoder(self.config.hidden_size)

    @staticmethod
    def _unit_normalize(x, dim=1):
        """Scale `x` to unit L2 norm along `dim`; all-zero vectors stay zero instead of becoming NaN.

        The norm is clamped at the smallest normal value of its dtype, so the result is
        identical to a plain `x / x.norm(...)` for every vector with a non-zero norm.
        """
        norm = x.norm(dim=dim, keepdim=True)
        return x / norm.clamp(min=torch.finfo(norm.dtype).tiny)

    def forward(self, input, image):
        """Reconstruct the feature map and the image from `input`.

        Args:
            input: feature map `[B, C, H, W]`. C has to match `config.hidden_size`
                (forced to 1024 in `__init__`), because the tokens are layer-normalised
                with that width; H * W has to be a perfect square.
                NOTE(review): older comments in this file mention 768 channels — confirm.
            image: target image `[B, 3, H_img, W_img]`, used only for the `in_out_cos`
                diagnostic; it must match the decoder output shape (16x the feature
                map's spatial size).

        Returns:
            `[recovered_image, recovered_features, input, val_dict]`. `val_dict` is the
            dictionary returned by the gate layer, extended with `in_out_cos` (per-pixel
            cosine similarity between reconstruction and `image`) and `transposed_input`
            (the reconstructed image).
        """
        # [B, C, H, W] -> [B, H*W, C]: one token per spatial location.
        features = input.view(input.shape[0], input.shape[1], -1).permute(0, 2, 1)

        # `[:, 1:]` drops the class token every transformer prepends.
        encoding = self.encoder(features).last_hidden_state[:, 1:]
        quantized_latents, val_dict = self.vq_layer(encoding, features)
        recovered_features = self.transformer_features(quantized_latents).last_hidden_state[:, 1:]
        features = self.transformer_images(recovered_features).last_hidden_state[:, 1:]

        # Back to a square feature map: [B, H*W, C] -> [B, C, H, W].
        H = W = int(features.shape[1]**0.5)
        recovered_features = recovered_features.view(features.shape[0], H, W, features.shape[2]).permute(0, 3, 1, 2)
        features = features.view(features.shape[0], H, W, features.shape[2]).permute(0, 3, 1, 2)
        recovered_image = self.decoder(features)

        # Per-pixel cosine similarity over the channel axis; zero-norm pixels give 0, not NaN.
        in_out_cos = torch.einsum('bdij,bdij->bij',
                                  self._unit_normalize(recovered_image),
                                  self._unit_normalize(image))
        val_dict['in_out_cos'] = in_out_cos
        val_dict['transposed_input'] = recovered_image
        return [recovered_image, recovered_features, input, val_dict]

    def loss_function(self, input, input_features,
                            rec, rec_features, input_vector_amount):
        """Compute the training loss and a few diagnostics.

        The loss is `mse(rec, input) + mse(rec_features, input_features)` plus two hinge
        terms on the kept-token fraction: `0.1 * max(iva_factor, iva)` penalises fractions
        above `iva_factor`, and `-0.1 * min(low_iva_factor, iva)` rewards fractions up to
        `low_iva_factor`.

        Args:
            input: target image `[B, C, H, W]`.
            input_features: target feature map.
            rec: reconstructed image, same shape as `input`.
            rec_features: reconstructed feature map, same shape as `input_features`.
            input_vector_amount: scalar tensor, fraction of tokens whose input was kept.

        Returns:
            dict with `loss` (the value to back-propagate), `msei`, `msef`, `cos`
            (mean per-pixel cosine similarity between `rec` and `input`; a metric only,
            not part of `loss`) and `iva`.
        """
        rec_loss = F.mse_loss(rec, input)
        rec_features_loss = F.mse_loss(rec_features, input_features)
        iva_factor = torch.tensor(self.iva_factor, device=input_vector_amount.device)
        low_iva_factor = torch.tensor(self.low_iva_factor, device=input_vector_amount.device)
        loss = (rec_loss + rec_features_loss) + \
            0.1 * torch.max(iva_factor, input_vector_amount) - \
            0.1 * torch.min(low_iva_factor, input_vector_amount)

        # Cosine similarity over the channel axis; all-zero pixels contribute 0, not NaN.
        normed_recons = self._unit_normalize(rec)
        normed_input = self._unit_normalize(input)
        latent_cross_product = torch.einsum('bdij,bdij->bij', normed_recons, normed_input)
        cos_loss = latent_cross_product.mean()

        return {
            'loss': loss,
            'msei': rec_loss,
            'msef': rec_features_loss,
            'cos': cos_loss,
            'iva': input_vector_amount
        }
