import os
import collections.abc
import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import ViTConfig, AutoImageProcessor

from .vit_pretrain import ViTModel


class ViTDepthPatchEmbeddings(nn.Module):
    """
    Patch-embedding layer for a single-channel depth map.

    A strided convolution cuts `pixel_values` of shape `(batch_size, 1, height, width)` into
    non-overlapping patches and projects each to `config.hidden_size`. The result has shape
    `(batch_size, num_patches, hidden_size)`.

    Only `config.image_size`, `config.patch_size` and `config.hidden_size` are read. The
    channel count is always 1, whatever the config says.
    """

    def __init__(self, config):
        super().__init__()

        def _as_pair(value):
            # Scalars become a square (value, value); iterables are kept exactly as given.
            if isinstance(value, collections.abc.Iterable):
                return value
            return (value, value)

        size_hw = _as_pair(config.image_size)
        patch_hw = _as_pair(config.patch_size)
        patches_per_col = size_hw[0] // patch_hw[0]
        patches_per_row = size_hw[1] // patch_hw[1]

        self.image_size = size_hw
        self.patch_size = patch_hw
        self.num_channels = 1
        self.num_patches = patches_per_row * patches_per_col

        # kernel == stride, so every output position corresponds to exactly one patch.
        self.projection = nn.Conv2d(
            self.num_channels, config.hidden_size, kernel_size=patch_hw, stride=patch_hw
        )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """
        Embed a depth batch.

        Args:
            pixel_values: 4-D tensor `(batch, channels, height, width)`; `channels` must be 1.
            interpolate_pos_encoding: when True, skip the check that `(height, width)`
                equals the configured image size.

        Returns:
            Tensor of shape `(batch, num_patches_for_this_input, hidden_size)`.

        Raises:
            ValueError: on a channel-count mismatch, or on a spatial-size mismatch when
                `interpolate_pos_encoding` is False.
        """
        _, channels, height, width = pixel_values.shape
        if channels != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
                f" Expected {self.num_channels} but got {channels}."
            )
        if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
            raise ValueError(
                f"Input image size ({height}*{width}) doesn't match model"
                f" ({self.image_size[0]}*{self.image_size[1]})."
            )
        patch_grid = self.projection(pixel_values)  # (batch, hidden, rows, cols)
        return patch_grid.flatten(2).transpose(1, 2)

class ViTModelRGBD(pl.LightningModule):
    """
    RGB-D encoder that feeds a 4-channel image (RGB + depth) through a ViT backbone.

    How depth is combined with RGB is selected by ``depth_channel``:

    * ``"conv"``: a 1x1 convolution compresses the 4 RGB-D channels to 3 before the backbone.
    * ``"embed"``: RGB is given to the backbone as pixels. Depth is embedded separately by
      ``ViTDepthPatchEmbeddings`` and passed to the backbone as ``depth_embeddings``.

    ``forward`` returns hidden-state token 0 of the backbone's ``last_hidden_state``. When
    ``frozen`` is True, that token is projected to ``hidden_size`` by a linear head.

    Args:
        input_channels: Number of input channels; must be 4 (RGB + depth).
        image_size: Spatial size of incoming images. When it differs from the ViT config's
            ``image_size``, inputs are routed through the Hugging Face image processor.
        hidden_size: Output width of the linear head. The head is only created when
            ``frozen`` is True.
        pretrain: Load backbone weights from the local ``vit-base-patch16-224-in21k``
            directory instead of building a freshly initialised ``ViTModel``.
        frozen: Set ``requires_grad=False`` on every backbone parameter and add a
            trainable linear head ``vit_out``.
        depth_channel: ``"conv"`` or ``"embed"``.

    Raises:
        AssertionError: if ``input_channels != 4``.
        ValueError: if ``depth_channel`` is not ``"conv"`` or ``"embed"``.
    """

    def __init__(self, input_channels, image_size, hidden_size, pretrain=True, frozen=True, depth_channel="embed"):
        super().__init__()

        # NOTE(review): ``assert`` is stripped under ``python -O``. It is kept so that the
        # exception type callers may see stays the same.
        assert input_channels == 4, "input_channels must be 4 (rgb + depth)"

        self.image_size = image_size
        self.hidden_size = hidden_size
        self.pretrain = pretrain
        self.frozen = frozen
        self.depth_channel = depth_channel

        # The checkpoint directory is resolved relative to the parent of the current
        # working directory, so construction depends on where the process is launched.
        vit_path = os.path.join(os.path.dirname(os.getcwd()), "vit-base-patch16-224-in21k")
        # A default ViTConfig is used for the depth embeddings, the output head and the
        # image-size comparison in ``forward``, regardless of ``pretrain``.
        self.config = ViTConfig()
        self.image_processor = AutoImageProcessor.from_pretrained(vit_path)

        if depth_channel == "conv":
            # 1x1 conv: 4 RGB-D channels -> 3 channels for the RGB backbone.
            self.conv_compress = nn.Conv2d(4, 3, 1, stride=1)
        elif depth_channel == "embed":
            self.depth_patch_embeddings = ViTDepthPatchEmbeddings(self.config)
        else:
            raise ValueError(f"depth_channel must be 'conv' or 'embed', but got {depth_channel}")

        if pretrain:
            self.vit = ViTModel.from_pretrained(vit_path)
        else:
            self.vit = ViTModel(self.config)

        if frozen:
            for param in self.vit.parameters():
                param.requires_grad = False

            # NOTE(review): the projection to ``hidden_size`` only exists when frozen. With
            # frozen=False, ``forward`` returns width ``self.config.hidden_size`` and ignores
            # ``hidden_size``. Confirm this is intended.
            self.vit_out = nn.Linear(self.config.hidden_size, self.hidden_size)

    def forward(self, inputs):
        """
        Encode a batch of RGB-D images into one vector per image.

        Args:
            inputs: Tensor indexed as ``(batch, channel, height, width)``, with RGB in
                channels 0-2 and depth in channel 3. The caller's tensor is not modified.

        Returns:
            Token 0 of the backbone's last hidden state. Its width is ``hidden_size`` when
            ``frozen`` is True and ``self.config.hidden_size`` otherwise.
        """
        # Rescale a copy, not the caller's tensor. Writing in place would mutate the
        # caller's batch, so repeated calls on it would compound the scaling. It would also
        # raise under autograd when ``inputs`` is a leaf that requires grad.
        inputs = inputs.clone()
        inputs[:, :3, :, :] = (inputs[:, :3, :, :] + 2) / 4
        inputs[:, 3, :, :] = inputs[:, 3, :, :] / 5

        if self.image_size != self.config.image_size:
            # NOTE(review): the image processor presumably works outside autograd (its
            # output is moved back to ``inputs.device``). Also confirm its rescale/normalize
            # settings suit inputs that were already rescaled above.
            rgb_resize = self.image_processor(inputs[:, :3, :, :],
                                              return_tensors="pt")["pixel_values"].to(inputs.device)
            # The processor requires 3 channels, so channels 1-3 (G, B, depth) are passed in.
            # Only the last output channel (depth) is kept; G and B are placeholders.
            dep_resize = self.image_processor(inputs[:, 1:, :, :],
                                              return_tensors="pt")["pixel_values"][:, -1:, :, :].to(inputs.device)
            rgbd_resize = torch.cat((rgb_resize, dep_resize), dim=1)
        else:
            # NOTE(review): this path bypasses the image processor, so RGB is not normalized
            # the same way as in the branch above. Confirm this is intended.
            rgb_resize = inputs[:, :3, :, :]
            dep_resize = inputs[:, 3:, :, :]
            rgbd_resize = inputs

        if self.depth_channel == "conv":
            input_conv = self.conv_compress(rgbd_resize)
            hidden_states = self.vit(input_conv).last_hidden_state
        elif self.depth_channel == "embed":
            depth_embeddings = self.depth_patch_embeddings(dep_resize)
            hidden_states = self.vit(rgb_resize, depth_embeddings=depth_embeddings).last_hidden_state

        # Token 0 of the sequence is used as the per-image summary (CLS position).
        cls_hidden_states = hidden_states[:, 0, :]
        if self.frozen:
            cls_hidden_states = self.vit_out(cls_hidden_states)

        return cls_hidden_states


class ImageDecoder(pl.LightningModule):
    """
    Reconstruction ("MSR") head that maps encoder hidden states back to pixels.

    The head is selected by ``args.image_encoder``:

    * ``"vit-patches"``: a linear layer maps each ``args.n_embd``-wide token to a
      flattened ``patch_size * patch_size * 4`` patch.
    * ``"vit-cls"`` / ``"resnet"``: a stack of transposed convolutions upsamples one
      ``args.n_embd`` vector, treated as a 1x1 map, to an ``(out_channels, 384, 384)``
      image. Here ``out_channels = args.feature_list[0].channels``.

    Args:
        args: Namespace providing ``n_embd``, ``feature_list`` and ``image_encoder``. It may
            also provide ``patch_size``, which is used only for ``"vit-patches"`` and
            defaults to 16.
    """

    def __init__(self, args):
        super().__init__()

        self.hidden_size = args.n_embd
        self.out_channels = args.feature_list[0].channels
        self.args = args

        if args.image_encoder == "vit-patches":
            # ``self.patch_size`` used to be read here without ever being assigned, so this
            # branch always failed with AttributeError. Take it from ``args`` when present,
            # otherwise fall back to 16 (the ViT-B/16 patch size).
            # NOTE(review): confirm 16 matches the encoder actually used.
            self.patch_size = getattr(args, "patch_size", 16)
            self.msr_head = nn.Linear(self.hidden_size, self.patch_size**2 * 4)
        elif args.image_encoder in ["vit-cls", "resnet"]:
            # A simple deconvolutional head turns the hidden vector back into an image of
            # the padded size. Spatial size per ConvTranspose2d layer is
            # (in - 1) * stride - 2 * padding + kernel_size.
            self.msr_head = nn.Sequential(
                # Input: (-1, hidden, 1, 1)
                nn.ConvTranspose2d(self.hidden_size, self.hidden_size//2, kernel_size = 4, padding = 0), # (-1, hidden/2, 4, 4)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//2, self.hidden_size//2, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/2, 10, 10)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//2, self.hidden_size//4, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/4, 22, 22)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//4, self.hidden_size//4, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/4, 46, 46)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//4, self.hidden_size//8, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/8, 94, 94)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//8, self.hidden_size//8, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/8, 190, 190)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//8, self.hidden_size//16, kernel_size = 4, stride = 2, padding = 0), # (-1, hidden/16, 382, 382)
                nn.SiLU(),
                nn.ConvTranspose2d(self.hidden_size//16, self.out_channels, kernel_size = 3, stride = 1, padding = 0), # (-1, out_channels, 384, 384)
                nn.SiLU(),
                # Output: (-1, out_channels, 384, 384)
            )

    def forward(self, hidden_states):
        """
        Decode hidden states into pixel predictions.

        Args:
            hidden_states: Tensor whose total size is a multiple of ``n_embd``. It is
                flattened into ``n_embd``-wide rows.

        Returns:
            For ``"vit-patches"``: a tensor of shape ``(N, patch_size**2 * 4)``, one row per
            token. For ``"vit-cls"`` / ``"resnet"``: a tensor of shape
            ``(N, out_channels * 384 * 384)``, one row per hidden vector.

        Raises:
            ValueError: if ``args.image_encoder`` has no decoder head (previously this
                silently returned None).
        """
        encoder = self.args.image_encoder
        if encoder == "vit-patches":
            # reshape rather than view, so non-contiguous inputs (e.g. sliced tokens) work.
            tokens = hidden_states.reshape(-1, self.hidden_size)
            return self.msr_head(tokens)  # (-1, patch_size**2 * 4)
        if encoder in ["vit-cls", "resnet"]:
            # DeConv MSR head: each hidden vector becomes a 1x1 spatial map.
            hidden_states = hidden_states.reshape(-1, self.hidden_size, 1, 1)  # (-1, hidden, 1, 1)
            pred_img = self.msr_head(hidden_states)  # (-1, out_channels, 384, 384)
            return pred_img.view(hidden_states.size(0), -1)
        raise ValueError(f"ImageDecoder has no head for image_encoder={encoder!r}")