
import collections
from dataclasses import dataclass
from typing import Optional, Tuple

from einops import rearrange
from transformers import ViTConfig, PreTrainedModel, VisionEncoderDecoderConfig
from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel, ViTModel, GPT2Model
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.vit.modeling_vit import ViTEmbeddings
import requests
from PIL import Image

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from timm.models import register_model
from transformers.utils import ModelOutput

import torch


@dataclass
class EncoderDecoderOutput(ModelOutput):
    """Output container returned by ``ViT3dDecoderModel.forward``.

    Field order matters: ``ModelOutput`` exposes the non-None fields positionally
    in declaration order.
    """

    # Projection of the encoder's first ([CLS]) token through ``classifier``.
    logits: Optional[torch.FloatTensor] = None
    # Final encoder hidden states as reported by the wrapped encoder-decoder model.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    # Decoder key/value cache forwarded from the wrapped encoder-decoder model.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor], ...]] = None
    # Not populated by ``ViT3dDecoderModel.forward``; kept for API symmetry.
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # MSE between ``logits`` and target vision embeddings; may be the float ``0.``
    # when no vision targets were supplied.
    vision_loss: Optional[torch.FloatTensor] = None
    # Captioning (language-model) loss from the decoder.
    text_loss: Optional[torch.FloatTensor] = None
    # Combined training objective.
    loss: Optional[torch.FloatTensor] = None


@dataclass
class Vit3dModelOutput(BaseModelOutputWithPooling):
    """Encoder output extended with projection heads and their losses.

    Produced by ``ViT3dWithProjectionModel.forward``. The inherited fields
    (``last_hidden_state``, ``pooler_output``, ``hidden_states``, ``attentions``)
    come straight from the underlying ViT encoder.
    """

    # [CLS] token passed through the main projection head.
    vision_embeds: Optional[torch.FloatTensor] = None
    # [CLS] token passed through the VAE projection head; the producer stores the
    # float ``0.`` here when the VAE head is disabled.
    vae_embeds: Optional[torch.FloatTensor] = None
    # MSE between ``vision_embeds`` and the provided labels (None without labels).
    loss: Optional[torch.FloatTensor] = None
    # MSE between ``vae_embeds`` and flattened VAE labels; ``0.`` when the VAE head
    # is disabled, None when it is enabled but no VAE labels were given.
    vae_loss: Optional[torch.FloatTensor] = None


class ViTPatchEmbeddings3D(nn.Module):
    """Convert a 3D volume into a sequence of patch embeddings.

    Config attributes read:
        image_size, patch_size: int or 3-sequence (scalars are broadcast to 3D).
        num_channels, hidden_size: input channels and embedding width.
        token_ids (optional): indices of the patches to keep; ``None``/absent keeps all.
        flatten (optional, default False): if True the volume is flattened to one
            axis and patches are taken with a strided ``Conv1d`` instead of ``Conv3d``.
    """

    def __init__(self, config):
        super().__init__()
        image_size, patch_size = config.image_size, config.patch_size
        num_channels, hidden_size = config.num_channels, config.hidden_size
        self.flatten = getattr(config, "flatten", False)

        # Normalise scalar sizes to 3-tuples *before* they are indexed below.
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size, patch_size)

        token_ids = getattr(config, "token_ids", None)
        if token_ids is not None:
            # Registered as a buffer so it follows the module across devices and
            # is saved in the state dict.
            self.register_buffer("token_ids", torch.tensor(token_ids, dtype=torch.long))
            num_patches = len(self.token_ids)
        else:
            self.token_ids = None
            num_patches = (
                (image_size[0] // patch_size[0])
                * (image_size[1] // patch_size[1])
                * (image_size[2] // patch_size[2])
            )

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        if not self.flatten:
            self.projection = nn.Conv3d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
        else:
            # One patch == prod(patch_size) consecutive voxels of the flattened volume.
            patch_volume = int(np.prod(patch_size))
            self.projection = nn.Conv1d(
                num_channels,
                hidden_size,
                kernel_size=patch_volume,
                stride=patch_volume,
            )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` into ``(batch, num_patches, hidden_size)``.

        A 4D input is given a channel axis at dim 1 (assumed ``(b, h, w, d)``
        single-channel volumes — confirm against callers). ``interpolate_pos_encoding``
        is accepted for signature compatibility and ignored.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(1)

        if self.flatten:
            # (b, c, h, w, d) -> (b, c, h*w*d) for the Conv1d projection.
            pixel_values = pixel_values.flatten(start_dim=2)

        embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
        if self.token_ids is not None:
            embeddings = torch.index_select(embeddings, -2, self.token_ids)
        return embeddings


class ViTEmbeddings3D(ViTEmbeddings):
    """ViT embedding layer (patch tokens + [CLS] + positions) for 3D volumes.

    Reuses the parent's ``cls_token``, ``mask_token`` and ``dropout`` but swaps in
    ``ViTPatchEmbeddings3D`` and a position table sized for its patch count.
    """

    def __init__(self, config):
        super().__init__(config)
        self.patch_embeddings = ViTPatchEmbeddings3D(config)
        # One slot per kept patch plus one for the [CLS] token.
        sequence_length = self.patch_embeddings.num_patches + 1
        self.position_embeddings = nn.Parameter(
            torch.randn(1, sequence_length, config.hidden_size)
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Return ``dropout(cat([CLS], patches) + position_embeddings)``.

        ``bool_masked_pos`` (if given) swaps masked patch tokens for ``mask_token``.
        ``interpolate_pos_encoding`` is accepted for signature compatibility but
        is not used.
        """
        n_batch = pixel_values.shape[0]
        patch_tokens = self.patch_embeddings(pixel_values, interpolate_pos_encoding=False)

        if bool_masked_pos is not None:
            replacement = self.mask_token.expand(n_batch, patch_tokens.shape[1], -1)
            # 1.0 where masked, 0.0 elsewhere -> blend original and mask tokens.
            is_masked = bool_masked_pos.unsqueeze(-1).type_as(replacement)
            patch_tokens = patch_tokens * (1.0 - is_masked) + replacement * is_masked

        cls_prefix = self.cls_token.expand(n_batch, -1, -1)
        tokens = torch.cat((cls_prefix, patch_tokens), dim=1) + self.position_embeddings
        return self.dropout(tokens)


class ViT3dModel(ViTModel):
    """``ViTModel`` whose embedding layer consumes 3D volumes.

    Everything except ``self.embeddings`` is inherited unchanged. The embeddings
    are replaced with ``ViTEmbeddings3D`` after the parent constructor has run.
    """

    def __init__(self, config):
        super(ViT3dModel, self).__init__(config)
        # NOTE(review): if the parent constructor applies weight initialisation,
        # this replacement does not receive it — confirm this is intended.
        self.embeddings = ViTEmbeddings3D(config)


class ViT3dWithProjectionModel(ViT3dModel):
    """3D ViT encoder with a projection head (and optional VAE head) on the [CLS] token.

    Config attributes read beyond the base ViT ones:
        patch_size, projection_dim,
        with_vae (optional, default False): also build ``vae_projection``,
        vae_dim (optional, default 4 * 96 * 96): output width of the VAE head,
            i.e. the flattened size expected of ``vae_labels``.
    """
    config_class = ViTConfig

    def __init__(self, config):
        super().__init__(config)
        self.loss_fn = nn.MSELoss()
        self.patch_size = config.patch_size
        self.with_vae = getattr(config, "with_vae", False)

        if self.with_vae:
            vae_dim = getattr(config, "vae_dim", 4 * 96 * 96)
            self.vae_projection = nn.Sequential(
                nn.Linear(config.hidden_size, 1024),
                nn.GELU(),
                # Input width must match the 1024 features produced above.
                nn.Linear(1024, 1024),
                nn.GELU(),
                nn.Linear(1024, vae_dim),
            )
            self.projection = nn.Sequential(
                nn.Linear(config.hidden_size, config.projection_dim),
                nn.GELU(),
                nn.Linear(config.projection_dim, config.projection_dim),
            )
        else:
            self.projection = nn.Linear(config.hidden_size, config.projection_dim)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        vae_labels: Optional[torch.Tensor] = None,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ):
        """Encode ``pixel_values`` and project the [CLS] token.

        Args:
            pixel_values: input volumes (required despite the None default).
            labels: target embeddings for the projection head (MSE loss).
            vae_labels: VAE targets; flattened from dim 1 before the MSE loss.
            return_dict: accepted for API compatibility; a ``Vit3dModelOutput``
                is always returned.

        Returns:
            Vit3dModelOutput. ``loss`` is None without ``labels``. ``vae_embeds`` and
            ``vae_loss`` are ``0.`` when the VAE head is disabled, and ``vae_loss``
            is None when the head is enabled but ``vae_labels`` is not given.
        """
        # Match input dtype to the model weights (e.g. under half precision).
        reference = self.embeddings.position_embeddings
        pixel_values = pixel_values.type_as(reference)
        if labels is not None:
            labels = labels.type_as(reference)
        if vae_labels is not None:
            vae_labels = vae_labels.type_as(reference)

        # Always request a ModelOutput: the attribute access below needs it.
        output = super().forward(
            pixel_values=pixel_values,
            bool_masked_pos=bool_masked_pos,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            interpolate_pos_encoding=interpolate_pos_encoding,
            return_dict=True,
        )
        cls_hidden = output.last_hidden_state[:, 0, :]

        logits = self.projection(cls_hidden)
        loss = self.loss_fn(logits, labels) if labels is not None else None

        vae_loss = 0.
        vae_logits = 0.
        if self.with_vae:
            vae_logits = self.vae_projection(cls_hidden)
            vae_loss = self.loss_fn(vae_logits, vae_labels.flatten(1)) if vae_labels is not None else None

        return Vit3dModelOutput(
            last_hidden_state=output.last_hidden_state,
            pooler_output=output.pooler_output,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
            vision_embeds=logits,
            vae_embeds=vae_logits,
            loss=loss,
            vae_loss=vae_loss
        )


class ViT3dDecoderModel(PreTrainedModel):
    """3D ViT encoder plugged into a pretrained ViT-GPT2 captioning decoder.

    The pretrained encoder of ``nlpconnect/vit-gpt2-image-captioning`` is replaced by
    a ``ViT3dModel`` built from ``config.encoder``. The decoder weights are frozen.
    A linear ``classifier`` maps the encoder's [CLS] token to
    ``config.encoder.projection_dim`` for an auxiliary embedding-regression loss.
    """
    config_class = VisionEncoderDecoderConfig

    def __init__(
        self,
        config
    ):
        super().__init__(config)

        encoder = ViT3dModel(config.encoder)
        # NOTE: loads (and may download) a pretrained checkpoint at construction time.
        vit_decoder = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
        vit_decoder.encoder = encoder  # replace the encoder with the custom one
        for param in vit_decoder.decoder.parameters():
            param.requires_grad = False

        self.vit_decoder = vit_decoder
        self.classifier = nn.Sequential(
            nn.Linear(config.encoder.hidden_size, config.encoder.projection_dim),
        )
        self.loss_fn = nn.MSELoss()

    def forward(self, pixel_values, labels=None, vision_embeds=None, **kwargs):
        """Run the encoder and, when caption labels are given, the frozen decoder.

        Args:
            pixel_values: input volumes for the 3D encoder.
            labels: optional tokenizer output with ``"input_ids"`` and
                ``"attention_mask"`` entries for the target captions.
            vision_embeds: optional target embeddings for the [CLS] projection.
            **kwargs: ignored.

        Returns:
            EncoderDecoderOutput. With labels, ``loss = 0.25 * text_loss + vision_loss``.
            Without labels, the decoder is skipped, ``text_loss`` is None, and ``loss``
            is the vision loss, or None if ``vision_embeds`` was not given either.
        """
        encoder_outputs = self.vit_decoder.encoder(pixel_values)
        logits = self.classifier(encoder_outputs[0][:, 0, :])

        vision_loss = self.loss_fn(logits, vision_embeds) if vision_embeds is not None else 0.

        if labels is None:
            # No caption targets: report only the vision branch.
            return EncoderDecoderOutput(
                logits=logits,
                encoder_last_hidden_state=encoder_outputs[0],
                vision_loss=vision_loss,
                loss=vision_loss if vision_embeds is not None else None,
            )

        outputs = self.vit_decoder(
            encoder_outputs=encoder_outputs,
            labels=labels["input_ids"],
            decoder_attention_mask=labels["attention_mask"]
        )

        # Captioning loss is down-weighted relative to the embedding loss.
        loss = 0.25 * outputs.loss + vision_loss

        return EncoderDecoderOutput(
            logits=logits,
            past_key_values=outputs.past_key_values,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            vision_loss=vision_loss,
            text_loss=outputs.loss,
            loss=loss,
        )

    def generate(self, pixel_values, max_new_tokens=77, **kwargs):
        """Generate caption tokens and also return the [CLS] projection.

        Note: the encoder runs twice, once inside ``vit_decoder.generate`` and once
        here for the classifier logits.

        Returns:
            (tokens, logits)
        """
        tokens = self.vit_decoder.generate(pixel_values, max_new_tokens=max_new_tokens, **kwargs)
        logits = self.classifier(self.vit_decoder.encoder(pixel_values)[0][:, 0, :])

        return tokens, logits
