from functools import partial
from typing import Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from torch.nn.init import trunc_normal_
from .fm_utils import pair, DropPath


class ClassificationPoolingHead(nn.Module):
    """Output head that replaces a tokenized embedding module for classification task (to be used with pure ViT models)

    The incoming tokens are pooled (mean over tokens, or first token only), normalized,
    and mapped to class logits by a single linear layer.

    Args:
        num_classes: Number of output classes. If <= 0, the classifier is an nn.Identity
            and the normalized pooled embedding is returned instead of logits.
        patch_size: Int or tuple of the patch size over the full image size.
        dim_tokens: Dimension of the input tokens. If None, the layers are not built
            until the init method is called.
        image_size: Default image size. Together with patch_size it fixes the number of
            tokens per sample that forward expects.
        norm_layer: Callable that builds the normalization layer from the token dimension.
        init_scale: Factor the freshly initialized classifier weight and bias are multiplied by.
        use_cls_token: If True, use the first token instead of the mean over all tokens.
    """
    def __init__(self,
                 num_classes: int,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 dim_tokens: Optional[int] = None,
                 image_size: Union[int, Tuple[int, int]] = 224,
                 norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6),
                 init_scale: float = 1.0,
                 use_cls_token: bool = False,
                 **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.patch_size = pair(patch_size)
        self.dim_tokens = dim_tokens
        self.image_size = pair(image_size)
        self.num_patches = (self.image_size[0] // self.patch_size[0]) * (self.image_size[1] // self.patch_size[1])
        self.norm_layer = norm_layer
        self.init_scale = init_scale
        self.use_cls_token = use_cls_token

        if self.dim_tokens is not None:
            self.init(dim_tokens=dim_tokens)

    def init(self, dim_tokens: int = 768):
        """
        Initialize parts of module that are dependent on dimension of tokens.

        Args:
            dim_tokens: Dimension of tokens
        """
        self.dim_tokens = dim_tokens

        # Classification layer
        self.norm = self.norm_layer(self.dim_tokens)
        if self.num_classes > 0:
            self.head = nn.Linear(self.dim_tokens, self.num_classes)
            # Only a real linear layer has parameters to rescale
            with torch.no_grad():
                self.head.weight.mul_(self.init_scale)
                self.head.bias.mul_(self.init_scale)
        else:
            # nn.Identity has no weight/bias, so there is nothing to scale
            self.head = nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Returns the set of parameter names excluded from weight decay (currently none)."""
        # TODO: Check if excluding embeddings works
        return set()

    def reset_classifier(self, num_classes=None):
        """Resets the classifier weights
        FourM init can override the init scale so calling this afterwards
        will ensure that the classifier head has its proper value

        Args:
            num_classes: New number of classes. If None, the current number is kept.
        """
        if num_classes is not None:
            self.num_classes = num_classes
        if self.dim_tokens is None:
            # Layers were never built; they pick up num_classes once init is called
            return
        self.init(dim_tokens=self.dim_tokens)

    def forward_logits(self, x):
        """
        Forward pass through the classification head, pooling all embeddings and outputting a class.

        :param x: Output tokens from the decoder
        :return: Tensor with one row per sample: class logits, or the normalized
            pooled embedding if the head is an nn.Identity
        """
        # NOTE(review): the reshape expects exactly num_patches tokens per sample, also
        # when use_cls_token is set - confirm that callers never pass an extra token.
        x = x.reshape(-1, self.num_patches, self.dim_tokens)
        if self.use_cls_token:
            # Use the first token only
            x = x[:, 0]
        else:
            # Average pooling
            x = x.mean(1)
        x = self.head(self.norm(x))
        return x

    def forward(self, x):
        """Same as forward_logits."""
        return self.forward_logits(x)


class ConvNeXtBlock(nn.Module):
    """Residual ConvNeXt block, based on https://github.com/facebookresearch/ConvNeXt

    The residual branch can be written in two equivalent ways:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    Variant (2) is the one implemented here, as it was found slightly faster in PyTorch.

    Args:
        dim (int): Number of input channels.
        drop_path: Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 0 (disabled for isotropic ConvNeXt).

    Code from: https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=0.):
        super().__init__()
        hidden_dim = 4 * dim

        # Depthwise 7x7 conv; padding=3 keeps the spatial size
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = nn.LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1) convs expressed as linear layers on channels-last input
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)

        # Learnable per-channel layer scale, only when a positive init value is given
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.gamma = None

        # Stochastic depth on the residual branch, only when a positive rate is given
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Applies the block to a (N, C, H, W) tensor and returns a tensor of the same layout."""
        shortcut = x

        # (N, C, H, W) -> (N, H, W, C) so that norm and linear layers act on channels
        branch = self.dwconv(x).permute(0, 2, 3, 1)
        branch = self.norm(branch)
        branch = self.pwconv2(self.act(self.pwconv1(branch)))
        if self.gamma is not None:
            branch = self.gamma * branch
        # (N, H, W, C) -> (N, C, H, W)
        branch = branch.permute(0, 3, 1, 2)

        return shortcut + self.drop_path(branch)

class ConvNeXtHead(nn.Module):
    """Output head with ConvNext blocks for semantic segmentation

    Tokens are projected to embed_dim, every token is unfolded into a small square of
    feature vectors, the resulting feature map is passed through the ConvNeXt blocks and
    a 1x1 conv to get per-class maps, and these are interpolated to the image size.

    The token projection is only created by the init method, so init has to be called
    before the first forward pass.

    Args:
        num_classes: Number of classes
        img_size: Image size (converted with pair). The output is interpolated to this size.
        embed_dim: Token dimension after projection, and before reshaping operation.
        preds_per_patch: Increases size of feature map by reshaping each patch. Each patch gets reshaped
            from embed_dim x 1 x 1 to (embed_dim / preds_per_patch) x (preds_per_patch ** 0.5) x (preds_per_patch ** 0.5).
            Must be a positive perfect square that divides embed_dim.
        patch_size: Size of patches
        depth: Number of ConvNeXt blocks
        interpolate_mode: Interpolation mode for final upsampling

    Raises:
        ValueError: If preds_per_patch is not a positive perfect square or does not divide embed_dim.
    """

    def __init__(
            self,
            num_classes: int,
            img_size: Union[int, Tuple[int, int]],
            embed_dim: int = 6144,
            preds_per_patch: int = 16,
            patch_size: int = 16,
            depth: int = 4,
            interpolate_mode: str = 'bilinear',
            **kwargs,
    ):
        # Fail fast on settings that forward could never reshape, before building any layer
        if preds_per_patch <= 0:
            raise ValueError(f"preds_per_patch must be positive, got {preds_per_patch}")
        side = int(round(preds_per_patch ** 0.5))
        if side * side != preds_per_patch:
            raise ValueError(f"preds_per_patch must be a perfect square, got {preds_per_patch}")
        if embed_dim % preds_per_patch != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by preds_per_patch ({preds_per_patch})"
            )

        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.preds_per_patch = preds_per_patch
        # Channels of each of the preds_per_patch feature vectors a token is split into
        self.class_dim = embed_dim // preds_per_patch
        self.num_classes = num_classes
        self.image_size = pair(img_size)
        self.interpolate_mode = interpolate_mode

        self.blocks = nn.Sequential(*[
            ConvNeXtBlock(dim=self.class_dim)
            for _ in range(depth)
        ])
        self.final_layer = nn.Conv2d(self.class_dim, self.num_classes, 1)
        self.apply(self._init_weights)

    def init(self, dim_tokens: int = 768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.

        Args:
            dim_tokens: Dimension of tokens coming from backbone
        """
        self.in_channels = dim_tokens

        # Projection of encoder tokens to the patch dimension
        self.proj_dec = nn.Linear(self.in_channels, self.embed_dim)
        self._init_weights(self.proj_dec)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights; zero biases and unit weights for layer norms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        """
        Predicts per-class maps from backbone tokens.

        :param x: Tokens of shape (batch, N_H * N_W, dim_tokens), where N_H and N_W are the
            image height and width divided by the patch size
        :return: Tensor of shape (batch, num_classes, H, W) with (H, W) the image size
        """
        H, W = self.image_size[0], self.image_size[1]
        N_H, N_W = H // self.patch_size, W // self.patch_size
        # Side length of the square each token is unfolded into
        side = int(round(self.preds_per_patch ** 0.5))

        x = self.proj_dec(x)
        # Split every token into preds_per_patch vectors of class_dim channels
        x = rearrange(x, "b n (p c) -> b (n p) c", n=N_H * N_W, p=self.preds_per_patch, c=self.class_dim)
        # Lay the vectors out on a 2D grid that is `side` times finer than the patch grid
        x = rearrange(x, "b (nh nw ph pw) c -> b c (nh ph) (nw pw)",
                      nh=N_H, nw=N_W,
                      ph=side,
                      pw=side)
        x = self.blocks(x)
        x = self.final_layer(x)

        # Interpolate to semseg res
        x = F.interpolate(x, size=(H, W), mode=self.interpolate_mode)

        return x