import collections
import collections.abc
import logging
from typing import Optional

import numpy as np
import torch
from einops import rearrange
from torch import nn
from transformers import CLIPVisionConfig, Trainer, ViTConfig
from transformers import ViTModel
from transformers.models.vit.modeling_vit import ViTEmbeddings

def calc_image_size(patch_size, base_size=(91, 109, 91)):
    """Round every dimension of ``base_size`` up to a multiple of ``patch_size``.

    Args:
        patch_size: Positive integer edge length that each returned dimension
            must be divisible by.
        base_size: Iterable of unpadded per-axis sizes. Defaults to
            ``(91, 109, 91)``.
            NOTE(review): presumably the voxel grid of the volumes this model
            is trained on -- confirm against the data pipeline.

    Returns:
        A tuple with one entry per element of ``base_size``; each entry is the
        smallest multiple of ``patch_size`` that is >= the original dimension.

    Raises:
        ValueError: If ``patch_size`` is not strictly positive.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # (-dim) % patch_size is the amount of padding needed to reach the next
    # multiple (0 when dim is already divisible).
    return tuple(dim + (-dim) % patch_size for dim in base_size)

class ViTPatchEmbeddings3D(nn.Module):
    """Project a 3D volume into a sequence of patch embeddings.

    The projection used depends on the config:

    * ``config.flatten`` truthy: the three spatial axes are flattened and a
      strided ``nn.Conv1d`` embeds consecutive runs of
      ``prod(patch_size)`` voxels.
    * ``config.patch_type == "perceptron"``: the volume is cut into patches
      with an einops ``Rearrange`` and each flattened patch goes through an
      ``nn.Linear``.
    * otherwise: a strided ``nn.Conv3d`` whose kernel and stride equal
      ``patch_size``.

    If ``config.token_ids`` is given, only the patch tokens at those indices
    are kept (in that order) and ``num_patches`` equals ``len(token_ids)``.

    Config attributes read: ``image_size``, ``patch_size``, ``num_channels``,
    ``hidden_size`` (required); ``flatten``, ``token_ids``, ``patch_type``
    (optional).
    """

    def __init__(self, config):
        super().__init__()
        num_channels, hidden_size = config.num_channels, config.hidden_size
        self.flatten = getattr(config, "flatten", False)

        # Normalise scalar sizes to 3-tuples *before* anything indexes them;
        # the patch count below needs per-axis values.
        image_size, patch_size = config.image_size, config.patch_size
        if not isinstance(image_size, collections.abc.Iterable):
            image_size = (image_size, image_size, image_size)
        if not isinstance(patch_size, collections.abc.Iterable):
            patch_size = (patch_size, patch_size, patch_size)

        token_ids = getattr(config, "token_ids", None)
        if token_ids is not None:
            token_ids = torch.tensor(token_ids, dtype=torch.long)
            # Registered as a buffer so it follows the module across devices.
            self.register_buffer("token_ids", token_ids)
            num_patches = len(token_ids)
        else:
            self.token_ids = None
            num_patches = (
                (image_size[0] // patch_size[0])
                * (image_size[1] // patch_size[1])
                * (image_size[2] // patch_size[2])
            )

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.num_patches = num_patches

        # Plain Python int: np.prod returns a NumPy scalar.
        patch_volume = int(np.prod(patch_size))

        if self.flatten:
            self.projection = nn.Conv1d(
                num_channels,
                hidden_size,
                kernel_size=patch_volume,
                stride=patch_volume,
            )
        elif getattr(config, "patch_type", None) == "perceptron":
            from einops.layers.torch import Rearrange

            axes_len = dict(zip(("p1", "p2", "p3"), patch_size))
            self.projection = nn.Sequential(
                Rearrange(
                    "b c (h p1) (w p2) (d p3) -> b (h w d) (p1 p2 p3 c)",
                    **axes_len,
                ),
                # The rearranged feature axis holds every voxel of the patch
                # for every channel, hence the num_channels factor.
                nn.Linear(patch_volume * num_channels, hidden_size),
            )
        else:
            self.projection = nn.Conv3d(
                num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
            )

        logging.getLogger(__name__).debug(
            "ViTPatchEmbeddings3D: num_patches=%d, projection=%s",
            num_patches,
            self.projection,
        )

    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` into patch tokens.

        Args:
            pixel_values: 5-D tensor, or a 4-D tensor that gets a singleton
                axis inserted at position 1 before projection.
                NOTE(review): assumes layout (batch, channels, h, w, d) --
                confirm against callers.
            interpolate_pos_encoding: Accepted but not used by this module.

        Returns:
            Tensor with the token axis at position 1 and the hidden axis
            last; restricted to ``token_ids`` when those are configured.
        """
        if pixel_values.dim() == 4:
            pixel_values = pixel_values.unsqueeze(1)

        if self.flatten:
            # Same as einops "b c h w d -> b c (h w d)" for a 5-D tensor.
            pixel_values = pixel_values.flatten(2)

        embeddings = self.projection(pixel_values)
        if isinstance(self.projection, (nn.Conv3d, nn.Conv1d)):
            # Convolutions emit (batch, hidden, ...); move hidden last so the
            # token axis is dim 1, matching the perceptron path.
            embeddings = embeddings.flatten(2).transpose(1, 2)
        if self.token_ids is not None:
            embeddings = torch.index_select(embeddings, -2, self.token_ids)
        return embeddings


class ViTEmbeddings3D(ViTEmbeddings):
    """ViT embedding layer that uses ``ViTPatchEmbeddings3D`` for volumes.

    After the parent constructor runs, ``patch_embeddings`` is replaced by a
    ``ViTPatchEmbeddings3D`` and ``position_embeddings`` by a fresh parameter
    of shape ``(1, num_patches + 1, hidden_size)`` (the extra slot is for the
    CLS token prepended in ``forward``). ``cls_token``, ``mask_token`` and
    ``dropout`` are whatever the parent class created.
    """

    def __init__(self, config, use_mask_token: bool = False):
        """Build the embeddings.

        Args:
            config: Model config; passed to the parent class and to
                ``ViTPatchEmbeddings3D``.
            use_mask_token: Forwarded to the parent constructor so that a
                mask token exists when ``bool_masked_pos`` is going to be
                used. Defaults to ``False``.
        """
        super().__init__(config, use_mask_token=use_mask_token)
        self.patch_embeddings = ViTPatchEmbeddings3D(config)
        self.position_embeddings = nn.Parameter(
            torch.randn(1, self.patch_embeddings.num_patches + 1, config.hidden_size)
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Embed a batch of volumes and add CLS token and position embeddings.

        Args:
            pixel_values: Input batch; its first axis is the batch size.
            bool_masked_pos: Optional per-token mask; where true, the patch
                embedding is replaced by the mask token.
                NOTE(review): assumes shape (batch, num_tokens) -- confirm.
            interpolate_pos_encoding: Accepted but ignored; the patch
                embedder is always called with ``False``.

        Returns:
            The embeddings after dropout, with the CLS token at index 0 of
            the token axis.

        Raises:
            ValueError: If ``bool_masked_pos`` is given but no mask token was
                created (construct with ``use_mask_token=True``).
        """
        batch_size = pixel_values.shape[0]
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=False)

        if bool_masked_pos is not None:
            if self.mask_token is None:
                raise ValueError(
                    "bool_masked_pos was provided but this module has no mask token; "
                    "construct it with use_mask_token=True"
                )
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # Replace the masked visual tokens by mask_tokens.
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # Prepend the [CLS] token to the embedded patch tokens.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        embeddings = embeddings + self.position_embeddings

        return self.dropout(embeddings)

class ViT3dModel(ViTModel):
    """``ViTModel`` whose embedding layer is swapped for ``ViTEmbeddings3D``.

    Everything except ``self.embeddings`` is built by the parent constructor.
    """

    def __init__(self, config, add_pooling_layer=True):
        """Build the parent model, then install the 3D embeddings.

        Args:
            config: Model config, passed both to ``ViTModel`` and to
                ``ViTEmbeddings3D``.
            add_pooling_layer: Forwarded unchanged to ``ViTModel``.
        """
        super().__init__(config, add_pooling_layer=add_pooling_layer)
        # NOTE(review): the replacement happens after the parent constructor
        # has finished, so any weight initialisation it performs presumably
        # does not touch the new embeddings -- confirm.
        volumetric_embeddings = ViTEmbeddings3D(config)
        self.embeddings = volumetric_embeddings