import math
from collections import OrderedDict
from functools import partial
from typing import Any, Callable, NamedTuple, Optional

import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with separate Q/K/V projections.

    Args:
        hidden_dim: Size of the input and output features.
        num_heads: Number of attention heads. Each head uses
            ``hidden_dim // num_heads`` features, so the concatenated heads span
            ``all_head_size = num_heads * (hidden_dim // num_heads)`` features.
        dropout: Dropout probability applied both to the attention probabilities
            and to the output of the final projection.
        bias: Whether the Q/K/V and output projections have a bias term.
        batch_first: Stored on the module but not read by it; inputs are always
            treated as ``(batch, seq, hidden_dim)``.
    """

    def __init__(self, hidden_dim, num_heads, dropout=0.0, bias=True, batch_first=True):
        super().__init__()
        self.num_attention_heads = num_heads
        self.attention_head_size = hidden_dim // num_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.batch_first = batch_first
        # Creation order (query, value, key, out) is kept so random initialisation
        # draws from the RNG in the same sequence.
        self.query = nn.Linear(hidden_dim, self.all_head_size, bias=bias)
        self.value = nn.Linear(hidden_dim, self.all_head_size, bias=bias)
        self.key = nn.Linear(hidden_dim, self.all_head_size, bias=bias)
        # The output projection consumes the concatenated heads, which is narrower
        # than hidden_dim whenever hidden_dim is not a multiple of num_heads.
        self.out = nn.Linear(self.all_head_size, hidden_dim, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (batch, seq, all_head_size) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, seq, hidden_dim); returns the same shape."""
        mixed_query_layer = self.query(x)
        mixed_key_layer = self.key(x)
        mixed_value_layer = self.value(x)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # (batch, heads, seq, seq) scores, scaled by 1/sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)

        # Merge heads back: (batch, heads, seq, head_size) -> (batch, seq, all_head_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)

        return attention_output
    
class MLPBlock(nn.Sequential):
    """Transformer feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Layers are registered by name (``linear_1``, ``act``, ``dropout_1``,
    ``linear_2``, ``dropout_2``) and ``nn.Sequential`` applies them in that order.
    Both linear layers get Xavier-uniform weights and near-zero normal biases
    (std 1e-6).
    """

    def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
        # Build the stages in execution order so the default random init of the
        # two linears draws from the RNG in a fixed sequence.
        stages = OrderedDict()
        stages["linear_1"] = nn.Linear(in_dim, mlp_dim)
        stages["act"] = nn.GELU()
        stages["dropout_1"] = nn.Dropout(dropout)
        stages["linear_2"] = nn.Linear(mlp_dim, in_dim)
        stages["dropout_2"] = nn.Dropout(dropout)
        super().__init__(stages)

        linears = (self.linear_1, self.linear_2)
        # All weights first, then all biases, matching the established init order.
        for layer in linears:
            nn.init.xavier_uniform_(layer.weight)
        for layer in linears:
            nn.init.normal_(layer.bias, std=1e-6)

class EncoderBlock(nn.Module):
    """Transformer encoder block (pre-norm).

    Computes ``x = input + Dropout(Attention(LN1(input)))`` followed by
    ``x + MLP(LN2(x))``. Inputs are batch-first: (batch_size, seq_length, hidden_dim).

    Args:
        num_heads: Number of attention heads.
        hidden_dim: Feature size of the token embeddings.
        mlp_dim: Hidden size of the feed-forward block.
        dropout: Dropout after attention and inside the MLP.
        attention_dropout: Dropout passed to the attention module.
        norm_layer: Factory for the two normalisation layers.
    """

    def __init__(
        self,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = Attention(hidden_dim, num_heads, dropout=attention_dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        # The attention module treats dim 0 as batch, so the expected layout is batch-first.
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        x = self.ln_1(input)
        x = self.self_attention(x)
        x = self.dropout(x)
        x = x + input

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y


class Encoder(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Adds a learned positional embedding, applies dropout, runs ``num_layers``
    ``EncoderBlock``s (registered as ``encoder_layer_0`` ... ``encoder_layer_{n-1}``)
    and finishes with a normalisation layer. Inputs are batch-first:
    (batch_size, seq_length, hidden_dim).
    """

    def __init__(
        self,
        seq_length: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float,
        attention_dropout: float,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
    ):
        super().__init__()
        # Batch is dim 0 (the custom Attention module is batch-first), so the
        # embedding has a leading singleton dim that broadcasts over the batch.
        # std=0.02 normal init (noted in the original code as coming from BERT).
        self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))
        self.dropout = nn.Dropout(dropout)
        blocks = OrderedDict(
            (
                f"encoder_layer_{idx}",
                EncoderBlock(num_heads, hidden_dim, mlp_dim, dropout, attention_dropout, norm_layer),
            )
            for idx in range(num_layers)
        )
        self.layers = nn.Sequential(blocks)
        self.ln = norm_layer(hidden_dim)

    def forward(self, input: torch.Tensor):
        torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
        hidden = input + self.pos_embedding
        hidden = self.dropout(hidden)
        hidden = self.layers(hidden)
        return self.ln(hidden)


class ViT(nn.Module):
    """Vision Transformer as per https://arxiv.org/abs/2010.11929.

    The shared encoder feeds two classifiers: ``heads`` for the target task
    (``target_num_classes`` outputs, optionally preceded by a tanh "pre-logits"
    layer of size ``representation_size``) and ``src_head`` for the source task
    (``source_num_classes`` outputs). ``src_head`` is frozen (its parameters have
    ``requires_grad=False``).
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        num_layers: int,
        num_heads: int,
        hidden_dim: int,
        mlp_dim: int,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        source_num_classes: int = 1000,
        target_num_classes: int = 100,
        representation_size: Optional[int] = None,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6)
    ):
        super().__init__()
        torch._assert(image_size % patch_size == 0, "Input shape indivisible by patch size!")
        self.image_size = image_size
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.mlp_dim = mlp_dim
        self.attention_dropout = attention_dropout
        self.dropout = dropout
        self.source_num_classes = source_num_classes
        self.target_num_classes = target_num_classes
        self.representation_size = representation_size
        self.norm_layer = norm_layer
        # Non-overlapping patch embedding: each patch_size x patch_size patch -> hidden_dim.
        self.conv_proj = nn.Conv2d(
            in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size
        )

        seq_length = (image_size // patch_size) ** 2

        # Add a class token
        self.class_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        seq_length += 1

        self.encoder = Encoder(
            seq_length,
            num_layers,
            num_heads,
            hidden_dim,
            mlp_dim,
            dropout,
            attention_dropout,
            norm_layer,
        )
        self.seq_length = seq_length

        heads_layers: OrderedDict[str, nn.Module] = OrderedDict()
        if representation_size is None:
            heads_layers["head"] = nn.Linear(hidden_dim, target_num_classes)
        else:
            heads_layers["pre_logits"] = nn.Linear(hidden_dim, representation_size)
            heads_layers["act"] = nn.Tanh()
            heads_layers["head"] = nn.Linear(representation_size, target_num_classes)

        self.heads = nn.Sequential(heads_layers)

        # Source-task head, frozen. The flag is ``requires_grad``; assigning a
        # misspelled attribute (e.g. ``require_grad``) would be silently ignored.
        self.src_head = nn.Linear(hidden_dim, source_num_classes)
        self.src_head.requires_grad_(False)

        # Init the patchify stem: truncated normal with std = sqrt(1 / fan_in), zero bias.
        fan_in = self.conv_proj.in_channels * self.conv_proj.kernel_size[0] * self.conv_proj.kernel_size[1]
        nn.init.trunc_normal_(self.conv_proj.weight, std=math.sqrt(1 / fan_in))
        if self.conv_proj.bias is not None:
            nn.init.zeros_(self.conv_proj.bias)

        # "pre_logits" only exists when a representation layer was requested.
        if representation_size is not None:
            fan_in = self.heads.pre_logits.in_features
            nn.init.trunc_normal_(self.heads.pre_logits.weight, std=math.sqrt(1 / fan_in))
            nn.init.zeros_(self.heads.pre_logits.bias)

        # Zero-init the target classifier.
        nn.init.zeros_(self.heads.head.weight)
        nn.init.zeros_(self.heads.head.bias)

    def _process_input(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify images (n, c, image_size, image_size) into tokens (n, n_patches, hidden_dim)."""
        n, c, h, w = x.shape
        p = self.patch_size
        torch._assert(h == self.image_size, f"Wrong image height! Expected {self.image_size} but got {h}!")
        torch._assert(w == self.image_size, f"Wrong image width! Expected {self.image_size} but got {w}!")
        n_h = h // p
        n_w = w // p

        # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
        x = self.conv_proj(x)
        # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w))
        x = x.reshape(n, self.hidden_dim, n_h * n_w)

        # (n, hidden_dim, (n_h * n_w)) -> (n, (n_h * n_w), hidden_dim)
        # The self attention layer expects inputs in the format (N, S, E)
        # where S is the source sequence length, N is the batch size, E is the
        # embedding dimension
        x = x.permute(0, 2, 1)

        return x

    def forward(self, x: torch.Tensor, ds: str = "target"):
        """Classify a batch of images.

        Args:
            x: Images of shape (n, 3, image_size, image_size).
            ds: ``"source"`` routes the class token through ``src_head``; any other
                value routes it through ``heads``.

        Returns:
            Logits of shape (n, source_num_classes) or (n, target_num_classes).
        """
        # Reshape and permute the input tensor
        x = self._process_input(x)
        n = x.shape[0]

        # Expand the class token to the full batch
        batch_class_token = self.class_token.expand(n, -1, -1)
        x = torch.cat([batch_class_token, x], dim=1)

        x = self.encoder(x)

        # Classifier "token" as used by standard language architectures
        x = x[:, 0]

        if ds == "source":
            x = self.src_head(x)
        else:
            x = self.heads(x)

        return x


def ViT_B_16(source_num_classes: int = 1000, target_num_classes: int = 100):
    """Build a ViT-B/16 initialised from torchvision's pretrained ``vit_b_16`` weights.

    The torchvision state dict is converted to this module's parameter names:

    * the fused attention ``in_proj_weight`` / ``in_proj_bias`` are split row-wise
      into ``query`` / ``key`` / ``value``;
    * ``out_proj`` becomes ``out``;
    * MLP sub-layers ``0`` / ``3`` become ``linear_1`` / ``linear_2``.

    The pretrained classifier is copied into ``src_head``. ``heads.head`` is
    replaced by a freshly constructed ``nn.Linear``, whose default random init
    overrides the zero init that ``ViT`` applies to that layer.

    Args:
        source_num_classes: Output size of ``src_head``. It receives the pretrained
            classifier, so it must match that checkpoint's class count (presumably
            1000 for the DEFAULT ImageNet weights; confirm).
        target_num_classes: Output size of the new target head.

    Returns:
        The ``ViT`` model with the converted weights loaded.
    """
    depth, width = 12, 768
    model = ViT(
        image_size=224,
        patch_size=16,
        num_layers=depth,
        num_heads=12,
        hidden_dim=width,
        mlp_dim=3072,
        source_num_classes=source_num_classes,
        target_num_classes=target_num_classes
    )

    checkpoint = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).state_dict()

    # Row ranges of the fused in_proj tensors holding the query, key and value projections.
    qkv_rows = (
        ("query", 0, width),
        ("key", width, 2 * width),
        ("value", 2 * width, None),
    )
    # Per-layer parameters that only need renaming: (torchvision suffix, local suffix).
    renamed = (
        ("self_attention.out_proj.weight", "self_attention.out.weight"),
        ("self_attention.out_proj.bias", "self_attention.out.bias"),
        ("mlp.0.weight", "mlp.linear_1.weight"),
        ("mlp.0.bias", "mlp.linear_1.bias"),
        ("mlp.3.weight", "mlp.linear_2.weight"),
        ("mlp.3.bias", "mlp.linear_2.bias"),
    )

    for layer_idx in range(depth):
        prefix = f"encoder.layers.encoder_layer_{layer_idx}."
        fused_weight = checkpoint.pop(prefix + "self_attention.in_proj_weight")
        fused_bias = checkpoint.pop(prefix + "self_attention.in_proj_bias")
        for proj, start, stop in qkv_rows:
            checkpoint[f"{prefix}self_attention.{proj}.weight"] = fused_weight[start:stop, :]
            checkpoint[f"{prefix}self_attention.{proj}.bias"] = fused_bias[start:stop]
        for old_suffix, new_suffix in renamed:
            checkpoint[prefix + new_suffix] = checkpoint.pop(prefix + old_suffix)

    # Reuse the pretrained classifier as the source head; give the target head a fresh init.
    checkpoint["src_head.weight"] = checkpoint["heads.head.weight"]
    checkpoint["src_head.bias"] = checkpoint["heads.head.bias"]
    fresh_head = nn.Linear(width, target_num_classes)
    checkpoint["heads.head.weight"] = fresh_head.weight
    checkpoint["heads.head.bias"] = fresh_head.bias

    model.load_state_dict(state_dict=checkpoint)
    return model
