"""
Note this file is a direct modification from the pytorch library's implementation of ViT.
"""

import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, Dict, List, NamedTuple, Optional

import torch
import torch.nn as nn

# from ..ops.misc import Conv2dNormActivation, MLP
from torchvision.ops.misc import MLP
# from ..transforms._presets import ImageClassification, InterpolationMode
# from ..utils import _log_api_usage_once
# from ._api import register_model, Weights, WeightsEnum
# from ._meta import _IMAGENET_CATEGORIES
# from ._utils import _ovewrite_named_param, handle_legacy_interface




class ConvStemConfig(NamedTuple):
    """Settings for one conv + norm + activation stage of a convolutional stem.

    NOTE(review): no active code in this file reads these fields; the only
    consumer visible here is the commented-out conv-stem construction in
    ``TempRepsVisionTransformer.__init__``. Confirm external users before
    changing the field order or defaults.
    """

    # Output channel count of the stage's convolution.
    out_channels: int
    # Kernel size of the stage's convolution.
    kernel_size: int
    # Stride of the stage's convolution.
    stride: int
    # Factory for the normalization layer that follows the convolution.
    norm_layer: Callable[..., nn.Module] = nn.BatchNorm2d
    # Factory for the activation applied after normalization.
    activation_layer: Callable[..., nn.Module] = nn.ReLU


class MLPBlock(MLP):
    """Feed-forward sub-layer of a Transformer block.

    Built on the ``MLP`` base with hidden sizes ``[mlp_dim, in_dim]``, GELU
    activation and the given dropout probability. Every ``nn.Linear`` inside is
    re-initialised with Xavier-uniform weights and near-zero normal biases.

    Args:
        in_dim: Feature size of the input (and of the output).
        mlp_dim: Feature size of the hidden layer.
        dropout: Dropout probability forwarded to the ``MLP`` base.
    """

    # Bumped to 2 when the legacy ``linear_1`` / ``linear_2`` layout was replaced.
    _version = 2

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        super().__init__(in_dim, [mlp_dim, in_dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        linear_layers = (layer for layer in self.modules() if isinstance(layer, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, std=1e-6)

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        """Load parameters, renaming keys from pre-version-2 checkpoints first.

        Legacy checkpoints store the two linear layers as ``linear_1`` and
        ``linear_2``; the current layout keeps them at indices ``0`` and ``3``.
        Matching keys are moved in place inside ``state_dict`` before the
        regular loading logic runs.
        See https://github.com/pytorch/vision/pull/6053
        """
        version = local_metadata.get("version", None)
        is_legacy = version is None or version < 2

        if is_legacy:
            for idx in range(2):
                for param_name in ("weight", "bias"):
                    legacy_key = f"{prefix}linear_{idx + 1}.{param_name}"
                    current_key = f"{prefix}{3 * idx}.{param_name}"
                    if legacy_key in state_dict:
                        state_dict[current_key] = state_dict.pop(legacy_key)

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class EncoderBlock(nn.Module):
    """Single pre-norm Transformer encoder layer.

    Applies normalisation -> multi-head self-attention -> dropout with a
    residual connection, followed by normalisation -> MLP with a second
    residual connection.

    Args:
        num_heads: Number of attention heads.
        hidden_dim: Feature size of every token.
        mlp_dim: Hidden size of the feed-forward sub-layer.
        dropout: Dropout probability after attention and inside the MLP.
        attention_dropout: Dropout probability passed to the attention module.
        norm_layer: Factory called with ``hidden_dim`` to build each norm layer.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Self-attention sub-layer (batch dimension first).
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = nn.MultiheadAttention(
            hidden_dim,
            num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        """Run the layer on a ``(batch_size, seq_length, hidden_dim)`` tensor.

        Returns a tensor of the same shape. A non-3-D input fails the
        ``torch._assert`` check.
        """
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        normed = self.ln_1(input)
        # Attention weights are not needed, only the attended values.
        attended = self.self_attention(query=normed, key=normed, value=normed, need_weights=False)[0]
        after_attention = self.dropout(attended) + input

        feed_forward = self.mlp(self.ln_2(after_attention))
        return after_attention + feed_forward


class Encoder(nn.Module):
    """Stack of Transformer encoder blocks with learned positional embeddings.

    The input gets a positional embedding added, passes through dropout, then
    through ``num_layers`` ``EncoderBlock`` modules, and finally a norm layer.

    Args:
        seq_length: Number of positions the positional embedding covers.
        num_layers: Number of encoder blocks to stack.
        num_heads: Attention heads per block.
        hidden_dim: Feature size of every token.
        mlp_dim: Hidden size of each block's feed-forward sub-layer.
        dropout: Dropout probability applied to the embedded input and in blocks.
        attention_dropout: Dropout probability inside each attention module.
        norm_layer: Factory called with ``hidden_dim`` to build norm layers.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.seq_length = seq_length
        # Shape (1, seq_length, hidden_dim): the leading 1 broadcasts over the
        # batch, which comes first because the blocks use batch_first=True.
        # The std of 0.02 follows BERT.
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.Sequential(
            OrderedDict(
                (
                    f"encoder_layer_{idx}",
                    EncoderBlock(num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer),
                )
                for idx in range(num_layers)
            )
        )
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        """Encode a ``(batch_size, seq_length, hidden_dim)`` tensor.

        Sequences shorter than ``self.seq_length`` use only the leading slice
        of the positional embedding; returns a tensor of the input's shape.
        """
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")

        # Slicing up to the input length handles variable-length sequences; the
        # slice is the whole embedding when the input is at least seq_length long.
        num_tokens = input.shape[1]
        positions = self.pos_embedding[:, :num_tokens, :]
        embedded = self.dropout(input + positions)
        return self.ln(self.layers(embedded))


class TempRepsVisionTransformer(nn.Module):
    """Transformer classifier over pre-embedded token sequences.

    Adapted from the Vision Transformer (https://arxiv.org/abs/2010.11929). The
    patchify / conv stem is disabled: ``forward`` takes tokens that already
    have ``hidden_dim`` features.

    The sequence is pooled in one of two ways, selected by ``linear_proj``:

    * falsy (default): a learnable class token is prepended (so the encoder is
      built for ``seq_length + 1`` positions) and its encoded state feeds the
      classification head.
    * truthy: all encoded tokens are flattened and projected from
      ``seq_length * hidden_dim`` to ``hidden_dim`` by ``linear_out_head``.

    Args:
        seq_length: Number of input tokens per sequence (excluding the class token).
        num_layers: Number of encoder blocks.
        num_heads: Attention heads per block.
        hidden_dim: Feature size of every token.
        mlp_dim: Hidden size of each block's feed-forward sub-layer.
        dropout: Dropout probability used by the encoder.
        attention_dropout: Dropout probability inside attention.
        num_classes: Size of the output layer.
        representation_size: If given, a ``pre_logits`` linear layer of this
            size plus ``Tanh`` is inserted before the final linear layer.
        norm_layer: Factory called with ``hidden_dim`` to build norm layers.
        conv_stem_configs: Accepted for signature compatibility; currently ignored.
        linear_proj: Selects the pooling mode described above.
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        num_classes: int = 1000,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        conv_stem_configs: Optional[List["ConvStemConfig"]] = None,
        linear_proj=False
    ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.num_classes = num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        self.linear_proj = linear_proj
        # Created unconditionally (and before the other sub-modules) so that
        # state_dict keys and seeded initialisation do not depend on linear_proj.
        # Sized with the caller's seq_length, i.e. without the class token.
        self.linear_out_head = nn.Linear(hidden_dim * seq_length, hidden_dim)

        if not self.linear_proj:
            # Learnable class token, shape (1, 1, hidden_dim); it takes one
            # extra position in the encoder.
            self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
            seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        # Classification head.
        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, num_classes)
        self.heads = nn.Sequential(heads_layers)

        # Head initialisation happens only after every layer exists, so the
        # order in which random numbers are drawn stays fixed.
        if representation_size is not None:
            pre_logits = self.heads.pre_logits
            fan_in = pre_logits.in_features
            nn.init.trunc_normal_(pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(pre_logits.bias)

        # The final layer starts at zero, so initial logits are all zero.
        nn.init.zeros_(self.heads.head.weight)
        nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Placeholder for the disabled patchify stem.

        The model keeps no ``patch_size`` or ``conv_proj``; inputs must already
        be token embeddings.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            "Input patchification is disabled: forward() expects tokens already shaped "
            "(batch_size, seq_length, hidden_dim)."
        )

    def forward(self, x: torch.Tensor):
        """Classify a batch of token sequences.

        Args:
            x: Tensor of shape ``(batch_size, seq_length, hidden_dim)``.

        Returns:
            The output of the classification head, ``(batch_size, num_classes)``.
        """
        torch._assert(x.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {x.shape}")

        if self.linear_proj:
            # Pool by flattening every encoded token and projecting back to hidden_dim.
            x = self.encoder(x)
            x = torch.flatten(x, start_dim=1)
            x = self.linear_out_head(x)
        else:
            # Prepend the class token to every sequence in the batch.
            batch_class_token = self.class_token.expand(x.shape[0], -1, -1)
            x = torch.cat([batch_class_token, x], dim=1)

            x = self.encoder(x)

            # Keep only the encoded class token (position 0).
            x = x[:, 0]

        return self.heads(x)