import math
import os
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import ViTForImageClassification, ViTImageProcessor

from data.tinyimagenet import *

def get_tinyimagenet_labels_from_dataset(dataset_root):
    """
    Extract labels for Tiny ImageNet dataset using the TinyImageNet class.

    The labels are the names of the sub-directories found in the folder that
    ``TinyImageNet(root=dataset_root, split="train").split_folder`` points to.

    Args:
        dataset_root (str): Path to the Tiny ImageNet root directory.

    Returns:
        list: Sorted list of class labels for Tiny ImageNet.
    """
    tiny_imagenet = TinyImageNet(root=dataset_root, split="train")
    train_folder = tiny_imagenet.split_folder

    # Class labels are folder names. Keep directories only, so that stray
    # files in the split folder (e.g. ".DS_Store" or cache files) are not
    # mistaken for extra classes and do not inflate the label count.
    with os.scandir(train_folder) as entries:
        return sorted(entry.name for entry in entries if entry.is_dir())


class ViT(nn.Module):
    """
    ViT model: a thin wrapper around ``ViTForImageClassification`` whose
    forward pass returns only the classification logits.

    Arguments:
        dataset (str): name of the dataset to be used. Supported values are
            ``"CIFAR10"`` and ``"TINYIMAGENET"`` (compared case-sensitively).
        model_name_or_path (str): pretrained weights to be used; passed
            unchanged to ``ViTForImageClassification.from_pretrained``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset, model_name_or_path):
        super().__init__()
        if dataset == "CIFAR10":
            labels = [
                "airplane", "automobile", "bird", "cat", "deer",
                "dog", "frog", "horse", "ship", "truck"
            ]
        elif dataset == "TINYIMAGENET":
            # Labels are discovered on disk, under "<current working dir>/data/".
            labels = get_tinyimagenet_labels_from_dataset(os.getcwd()+"/data/")
        else:
            # ValueError is the conventional type for a bad argument value; it
            # is still an Exception subclass, so existing handlers keep working.
            raise ValueError("Oops, this dataset cannot be combined with a ViT!")

        # NOTE: no ViTImageProcessor is created here; forward() hands its
        # input directly to the model.
        self.vit = ViTForImageClassification.from_pretrained(
            model_name_or_path,
            # The number of outputs and the label mappings follow the dataset.
            num_labels=len(labels),
            id2label={str(i): c for i, c in enumerate(labels)},
            label2id={c: str(i) for i, c in enumerate(labels)},
        )

    def forward(self, x):
        """
        Run the wrapped model on ``x`` and return the ``logits`` attribute of
        its output.

        ``x`` is passed to the model as-is (no device move, no preprocessing).
        NOTE(review): presumably a batch of already-normalised pixel values in
        the layout the checkpoint expects -- confirm against the data pipeline.
        """
        return self.vit(x).logits


# # split images into patches of patch_size x patch_size pixels
# def img_to_patch(x, patch_size, flatten_channels=True):
#     """
#     Args:
#         x: Tensor representing the image of shape [B, C, H, W]
#         patch_size: Number of pixels per dimension of the patches (integer)
#         flatten_channels: If True, the patches will be returned in a flattened format
#                            as a feature vector instead of an image grid.
#     """
#     B, C, H, W = x.shape
#     x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
#     x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
#     x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]
#     if flatten_channels:
#         x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]
#     return x

# class AttentionBlock(nn.Module):
#     def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
#         """Attention Block.

#         Args:
#             embed_dim: Dimensionality of input and attention feature vectors
#             hidden_dim: Dimensionality of hidden layer in feed-forward network
#                          (usually 2-4x larger than embed_dim)
#             num_heads: Number of heads to use in the Multi-Head Attention block
#             dropout: Amount of dropout to apply in the feed-forward network

#         """
#         super().__init__()

#         self.layer_norm_1 = nn.LayerNorm(embed_dim)
#         self.attn = nn.MultiheadAttention(embed_dim, num_heads)
#         self.layer_norm_2 = nn.LayerNorm(embed_dim)
#         self.linear = nn.Sequential(
#             nn.Linear(embed_dim, hidden_dim),
#             nn.GELU(),
#             nn.Dropout(dropout),
#             nn.Linear(hidden_dim, embed_dim),
#             nn.Dropout(dropout),
#         )

#     def forward(self, x):
#         inp_x = self.layer_norm_1(x)
#         x = x + self.attn(inp_x, inp_x, inp_x)[0]
#         x = x + self.linear(self.layer_norm_2(x))
#         return x

# class VisionTransformer(nn.Module):
#     def __init__(
#         self,
#         embed_dim,
#         hidden_dim,
#         num_channels,
#         num_heads,
#         num_layers,
#         num_classes,
#         patch_size,
#         num_patches,
#         dropout=0.0,
#     ):
#         """Vision Transformer.

#         Args:
#             embed_dim: Dimensionality of the input feature vectors to the Transformer
#             hidden_dim: Dimensionality of the hidden layer in the feed-forward networks
#                          within the Transformer
#             num_channels: Number of channels of the input (3 for RGB)
#             num_heads: Number of heads to use in the Multi-Head Attention block
#             num_layers: Number of layers to use in the Transformer
#             num_classes: Number of classes to predict
#             patch_size: Number of pixels that the patches have per dimension
#             num_patches: Maximum number of patches an image can have
#             dropout: Amount of dropout to apply in the feed-forward network and
#                       on the input encoding

#         """
#         super().__init__()

#         self.patch_size = patch_size

#         # Layers/Networks
#         self.input_layer = nn.Linear(num_channels * (patch_size**2), embed_dim)
#         self.transformer = nn.Sequential(
#             *(AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers))
#         )
#         self.mlp_head = nn.Sequential(nn.LayerNorm(embed_dim), nn.Linear(embed_dim, num_classes))
#         self.dropout = nn.Dropout(dropout)

#         # Parameters/Embeddings
#         self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
#         self.pos_embedding = nn.Parameter(torch.randn(1, 1 + num_patches, embed_dim))

#     def forward(self, x):
#         # Preprocess input
#         x = img_to_patch(x, self.patch_size)
#         B, T, _ = x.shape
#         x = self.input_layer(x)

#         # Add CLS token and positional encoding
#         cls_token = self.cls_token.repeat(B, 1, 1)
#         x = torch.cat([cls_token, x], dim=1)
#         x = x + self.pos_embedding[:, : T + 1]

#         # Apply Transformer
#         x = self.dropout(x)
#         x = x.transpose(0, 1)
#         x = self.transformer(x)

#         # Perform classification prediction
#         cls = x[0]
#         out = self.mlp_head(cls)
#         return out

