import collections.abc
import math
import sys
from itertools import repeat

import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from torch import nn
from torchvision.models.vision_transformer import Encoder 


from typing import Tuple
from functools import partial
from collections.abc import Iterable   # import directly from collections for Python < 3.3


def plot_fbank(fbank, title=None, save_path=None, **kwargs):
    """Plot up to four channels of a filter bank as vertically stacked heat maps.

    Args:
        fbank: Array-like whose first axis indexes channels; ``fbank[channel].T``
            is drawn with ``imshow``. Presumably shaped (channels, time, mel),
            given the axis labels used below -- confirm against callers.
        title: Optional text appended to every subplot title.
        save_path: If truthy, the figure is written to this path.
        **kwargs: Only ``vmin`` and ``vmax`` are read; both are forwarded to
            ``imshow`` so that all channels share one color scale.

    Returns:
        The matplotlib figure. It has already been closed via ``plt.close``.

    Raises:
        ValueError: If ``fbank`` has no channels.
    """
    n_channels = min(4, fbank.shape[0])  # at most four channels are drawn
    if n_channels < 1:
        raise ValueError("plot_fbank needs at least one channel to plot")

    # squeeze=False always returns a 2-D array of Axes, so the single-channel
    # case needs no special handling.
    fig, axs = plt.subplots(
        n_channels, 1, sharex=True, sharey=True, squeeze=False
    )
    axs = axs[:, 0]
    vmin, vmax = kwargs.get("vmin"), kwargs.get("vmax")
    # Avoid printing a literal "None" when no title was supplied.
    title_suffix = "" if title is None else f", {title}"

    for channel, ax in enumerate(axs):
        ax.set_title(f"Filter bank channel {channel}{title_suffix}")
        im = ax.imshow(fbank[channel].T, aspect="auto", vmin=vmin, vmax=vmax)
        ax.set_ylabel("mel")
        ax.set_xlabel("time")

    # The y axis is shared, so inverting one subplot inverts all of them.
    axs[-1].invert_yaxis()
    plt.tight_layout()
    fig.colorbar(im, ax=axs.ravel().tolist())

    # Save before showing: with an interactive backend the figure may be torn
    # down when its window is closed, and saving afterwards can yield an empty
    # image.
    if save_path:
        fig.savefig(save_path)
    plt.show()
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)
    return fig


# Adapted from PyTorch internals: builds a function that converts a scalar or an iterable into a tuple.
def _ntuple(n):
    """Return a converter that turns its argument into a tuple.

    The returned function maps a non-string iterable to ``tuple(x)`` (its
    length is not checked against ``n``) and anything else, strings included,
    to a tuple holding ``n`` copies of the value.
    """

    def parse(x):
        is_iterable = isinstance(x, collections.abc.Iterable)
        # Strings are iterable but are treated as scalars here.
        if isinstance(x, str) or not is_iterable:
            return tuple(repeat(x, n))
        return tuple(x)

    return parse


class PatchEmbed(nn.Module):
    """Embed an image as a sequence of non-overlapping patch tokens.

    A single ``nn.Conv2d`` whose kernel size equals its stride projects every
    patch to an ``embed_dim``-dimensional vector.

    Args:
        img_size: Input size, either one int or a pair; expanded with
            ``_ntuple(2)``.
        patch_size: Patch size, either one int or a pair; expanded with
            ``_ntuple(2)``.
        in_chans: Number of input channels of the convolution.
        embed_dim: Number of output channels, i.e. the token width.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        to_pair = _ntuple(2)
        self.img_size = to_pair(img_size)
        self.patch_size = to_pair(patch_size)

        # Number of whole patches that fit along each of the two axes.
        patches_axis0 = self.img_size[0] // self.patch_size[0]
        patches_axis1 = self.img_size[1] // self.patch_size[1]
        self.num_patches = patches_axis1 * patches_axis0

        # kernel_size == stride, so the patches do not overlap.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    # Subclasses are expected to override this.
    def forward(self, x):
        """Project ``x`` to patches and return them as (batch, patches, embed_dim)."""
        feature_map = self.proj(x)
        # Merge all dimensions from index 2 onwards into one patch axis.
        tokens = feature_map.flatten(2)
        # Put the patch axis before the channel axis.
        return tokens.transpose(1, 2)


def get_sinusoid_encoding(n_position, d_hid):
    """Build a fixed sinusoidal position-encoding table.

    Args:
        n_position: Number of positions (rows) in the table.
        d_hid: Width of each encoding vector.

    Returns:
        Float tensor of shape (1, n_position, d_hid). Even hidden indices hold
        the sine and odd hidden indices the cosine of
        ``position / 10000 ** (2 * (j // 2) / d_hid)``.
    """
    # One divisor per hidden index j; indices 2i and 2i + 1 share a divisor.
    divisors = [np.power(10000, 2 * (j // 2) / d_hid) for j in range(d_hid)]

    angles = np.array(
        [[position / divisor for divisor in divisors] for position in range(n_position)]
    )
    angles[:, 0::2] = np.sin(angles[:, 0::2])  # dim 2i
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # dim 2i+1

    # Prepend a leading axis of size one.
    encoding = torch.FloatTensor(angles)
    return encoding.unsqueeze(0)


def create_pretrained_model(model_size):
    """Build the transformer backbone for ``model_size`` and report its width.

    Despite the name, no pretrained weights are loaded here: the timm models
    are created with ``pretrained=False`` and the "base" encoder is freshly
    constructed.

    Args:
        model_size: One of "tiny", "small", "base" or "base_nokd".

    Returns:
        Tuple ``(v, hidden_dim)``: the model and the hidden width reported
        for it.

    Raises:
        SystemExit: If ``model_size`` is not recognised. The process exits
            with a non-zero status and the message is written to stderr.
    """
    if model_size == "tiny":
        v = timm.create_model("deit_tiny_distilled_patch16_224", pretrained=False)
        # DeiT-tiny has embedding width 192 (3 heads x 64); this was 182 before.
        hidden_dim = 192

    elif model_size == "small":
        v = timm.create_model("deit_small_distilled_patch16_224", pretrained=False)
        hidden_dim = 384

    elif model_size == "base":
        print("Using Flash Attention")
        v = Encoder(
            seq_length=0,  # Only used for pos_embeddings and we set them later!
            num_layers=12,
            num_heads=12,
            hidden_dim=768,
            mlp_dim=3072,
            dropout=0.0,
            attention_dropout=0.0,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
        )
        hidden_dim = 768

    elif model_size == "base_nokd":
        v = timm.create_model("deit_base_patch16_384", pretrained=False)
        hidden_dim = 768

    else:
        # A string argument makes sys.exit print to stderr and exit with
        # status 1, so an invalid size is no longer reported as success.
        sys.exit(
            f"Wrong model size! Got {model_size!r}; expected one of "
            "'tiny', 'small', 'base', 'base_nokd'."
        )

    return v, hidden_dim


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with samples from a truncated normal.

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: Tensor that is overwritten in place.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower cutoff.
        b: Upper cutoff.

    Returns:
        The same ``tensor`` object, now holding values within ``[a, b]``.
    """
    sqrt_two = math.sqrt(2.0)

    def standard_normal_cdf(z):
        # Cumulative distribution function of the standard normal.
        return (1.0 + math.erf(z / sqrt_two)) / 2.0

    # CDF values at the two cutoffs, expressed in standard units.
    cdf_low = standard_normal_cdf((a - mean) / std)
    cdf_high = standard_normal_cdf((b - mean) / std)

    # Sample uniformly from [2l-1, 2u-1], then apply the inverse error
    # function to obtain a truncated standard normal (up to a sqrt(2) factor).
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1).erfinv_()

    # Rescale and shift to the requested std and mean.
    tensor.mul_(std * sqrt_two).add_(mean)

    # Final clamp guarantees that every value lies inside [a, b].
    return tensor.clamp_(min=a, max=b)


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill ``tensor`` in place with values from a truncated normal distribution.

    The values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    restricted to :math:`[a, b]`. Sampling is delegated to
    ``_trunc_normal_``, which uses an inverse-CDF transform followed by a
    clamp, and runs with gradient tracking disabled. The method works best
    when :math:`a \leq \text{mean} \leq b`.

    NOTE: as in the PyTorch ``trunc_normal_``, the bounds ``[a, b]`` are
    absolute values, not multiples of ``std`` around ``mean``, so adjust
    ``a`` and ``b`` whenever non-default ``mean``/``std`` are used.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor`` after it has been filled.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    # In-place initialization must not be recorded by autograd.
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


def expand_index_like(index: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """Broadcast a per-token index so it also spans the feature dimension.

    Args:
        index:
            Index tensor with shape (batch_size, idx_length) where each entry is
            an index in [0, sequence_length).
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim). Only
            the size of its last dimension is used.

    Returns:
        Index tensor with shape (batch_size, idx_length, dim) in which every
        original index is repeated dim times along the last dimension.

    """
    feature_dim = tokens.shape[-1]
    # Append a trailing axis, then expand it to the feature width.
    with_trailing_axis = index[..., None]
    return with_trailing_axis.expand(-1, -1, feature_dim)
    
def set_at_index(
    tokens: torch.Tensor, index: torch.Tensor, value: torch.Tensor
) -> torch.Tensor:
    """Write ``value`` into ``tokens`` at the given sequence positions.

    The write uses the out-of-place ``scatter``, so the result is returned as
    a new tensor.

    Args:
        tokens:
            Tokens tensor with shape (batch_size, sequence_length, dim).
        index:
            Index tensor with shape (batch_size, index_length).
        value:
            Value tensor with shape (batch_size, index_length, dim).

    Returns:
        Tokens tensor with shape (batch_size, sequence_length, dim) containing
        the new values.

    """
    # scatter needs one index per written element, so widen the index first.
    expanded_index = expand_index_like(index, tokens)
    # Dimension 1 is the sequence dimension.
    return tokens.scatter(1, expanded_index, value)




def repeat_token(token: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
    """Tile a single token across a batch and a sequence.

    Args:
        token:
            Token tensor with shape (1, 1, dim).
        size:
            (batch_size, sequence_length) tuple.

    Returns:
        Tensor with shape (batch_size, sequence_length, dim) containing copies
        of the input token.

    """
    n_batch, n_sequence = size
    # The last dimension is repeated once, i.e. left unchanged.
    repeats = (n_batch, n_sequence, 1)
    return token.repeat(repeats)