import torch
import math
import torch.nn as nn
from itertools import repeat
import collections.abc
from functools import partial
from natten.functional import na1d_qk, na1d_av


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))
    return parse


# Converters that broadcast a non-iterable value to a 1- or 2-element tuple;
# non-string iterables are passed through as tuples.
to_1tuple, to_2tuple = _ntuple(1), _ntuple(2)

class Attention(nn.Module):
    """Multi-head 1D neighborhood attention built on ``na1d_qk`` / ``na1d_av``.

    Queries and keys are projected from ``x`` by a single linear layer; values
    are projected by a separate linear layer from ``h`` (or from ``x`` when
    ``h`` is not given).

    Args:
        dim: Feature size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the q/k and v projections carry a bias term.
        qk_norm: If true, ``norm_layer(head_dim)`` is applied to the per-head
            queries and keys; otherwise they pass through ``nn.Identity``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        kernel_size: Neighborhood size forwarded to the NATTEN ops.
            NOTE(review): some NATTEN releases only accept odd kernel sizes;
            the default is kept unchanged for compatibility -- confirm.
        dilation: Dilation forwarded to the NATTEN ops.
        norm_layer: Normalisation class used when ``qk_norm`` is true.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            kernel_size: int = 10,
            dilation: int = 4,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(d_head) factor for scaled dot-product attention.
        self.scale = self.head_dim ** -0.5

        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.kernel_size = kernel_size
        self.dilation = dilation

    def forward(self, x: torch.Tensor, h: torch.Tensor = None) -> torch.Tensor:
        """Attend over 1D neighborhoods.

        Args:
            x: 3-D tensor ``(B, N, C)`` with ``C == dim``; source of queries
                and keys.
            h: Optional source of values. It is reshaped with ``x``'s ``B``
                and ``N``, so it must have the same shape as ``x``. Defaults
                to ``x``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape
        # (B, N, 2*C) -> (2, B, heads, N, head_dim), then split into q and k.
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k = qk.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        # Apply the 1/sqrt(head_dim) scaling here: the logits returned by
        # na1d_qk go straight into softmax and no scale is passed to the op.
        # Previously self.scale was computed but never used.
        q = q * self.scale
        if h is None:
            h = x
        # (B, N, C) -> (B, heads, N, head_dim)
        v = self.v(h).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        attn_1d = na1d_qk(
            q, k,
            kernel_size=self.kernel_size, dilation=self.dilation,
        )
        attn_1d = attn_1d.softmax(dim=-1)
        # Honour the attn_drop argument (identity when p == 0 or in eval mode);
        # the dropout module used to be constructed but never applied.
        attn_1d = self.attn_drop(attn_1d)
        x = na1d_av(
            attn_1d, v,
            kernel_size=self.kernel_size, dilation=self.dilation,
        )
        # Merge heads back: assumes na1d_av returns (B, heads, N, head_dim).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):
    """Two-layer feed-forward block (fc1 -> act -> drop -> norm -> fc2 -> drop).

    Follows the MLP layout used in Vision Transformer, MLP-Mixer and related
    networks. ``hidden_features`` and ``out_features`` fall back to
    ``in_features`` when falsy. ``bias`` and ``drop`` may each be a single
    value (used for both layers) or a pair (first layer, second layer).
    With ``use_conv=True`` the two projections are 1x1 ``nn.Conv2d`` layers
    instead of ``nn.Linear``.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        if not out_features:
            out_features = in_features

        # Per-layer settings: index 0 is for fc1, index 1 is for fc2.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)

        if use_conv:
            make_projection = partial(nn.Conv2d, kernel_size=1)
        else:
            make_projection = nn.Linear

        self.fc1 = make_projection(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        if norm_layer is None:
            self.norm = nn.Identity()
        else:
            self.norm = norm_layer(hidden_features)
        self.fc2 = make_projection(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """Run ``x`` through the six stages in order and return the result."""
        stages = (self.fc1, self.act, self.drop1, self.norm, self.fc2, self.drop2)
        for stage in stages:
            x = stage(x)
        return x

class DiGBlock(torch.nn.Module):
    """Residual block with two neighborhood-attention branches and an MLP.

    ``forward`` adds three residual updates to ``x`` in sequence:

    1. ``attn_h``: queries/keys come from normalised ``h``, values from
       normalised ``x``;
    2. ``attn_x``: self-attention on normalised ``x``;
    3. ``mlp`` on normalised ``x``.

    Both LayerNorms are parameter-free (``elementwise_affine=False``).
    Extra keyword arguments are forwarded to both ``Attention`` modules.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn_x = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.attn_h = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            # tanh-approximated GELU, instantiated by Mlp.
            act_layer=partial(nn.GELU, approximate="tanh"),
            drop=0,
        )

    def forward(self, x, h):
        """Update ``x`` using ``h``; ``h`` itself is not modified or returned.

        Per the original author's note, ``x`` is expected to be
        (batch, nodes, hidden_size); ``h`` must be shaped like ``x`` because
        both feed the same attention module.
        """
        # Branch 1: h supplies queries/keys, x supplies values.
        from_h = self.attn_h(self.norm2(h), self.norm1(x))
        x = x + from_h
        # Branch 2: self-attention over x.
        from_x = self.attn_x(self.norm1(x))
        x = x + from_x
        # Branch 3: position-wise MLP.
        x = x + self.mlp(self.norm1(x))
        return x


def modulate(x, shift, scale):
    """Apply an affine modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` each get a singleton axis inserted at dim 1 before
    broadcasting against ``x`` (presumably x is (batch, tokens, dim) and
    shift/scale are (batch, dim) -- confirm against callers).
    """
    gain = 1 + scale.unsqueeze(1)
    offset = shift.unsqueeze(1)
    return x * gain + offset

class FinalLayer(torch.nn.Module):
    """
    Output head: a single biased linear map from ``hidden_size`` features to
    ``output_size`` features, applied to the last axis of the input.
    """
    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.linear = nn.Linear(hidden_size, output_size, bias=True)

    def forward(self, x):
        """Project ``x`` with the linear head and return the result."""
        return self.linear(x)

class LatentEmbedder(nn.Module):
    """
    Lookup-table embedding: maps integer indices in ``[0, inputsize)`` to
    learned vectors of length ``hidden_size`` via ``nn.Embedding``.
    """
    def __init__(self, inputsize, hidden_size):
        super().__init__()
        self.embedding_table = nn.Embedding(inputsize, hidden_size)

    def forward(self, latent):
        """Return the embedding vectors for the indices in ``latent``."""
        return self.embedding_table(latent)

class DilatedGraph(torch.nn.Module):
    """
    Stack of ``DiGBlock`` layers with layer-dependent dilation.

    ``x`` and ``h`` are each embedded by a linear layer into ``hidden_size``
    features; ``depth`` blocks then update ``x`` using the (fixed) embedded
    ``h``, and a linear head maps ``x`` back to ``x_input_size`` features.
    Block ``i`` (0-based) uses dilation ``i * dilation + 1``.

    ``patch_size``, ``in_channels`` and ``learn_sigma`` are only stored as
    attributes (plus the derived ``out_channels``); ``dtype`` is accepted but
    not used by this class.
    """
    def __init__(
        self,
        x_input_size=3,
        h_input_size=4,
        patch_size=2,
        in_channels=4,

        hidden_size=512,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        learn_sigma=True,
        kernel_size=9,
        dilation=4,
        dtype=torch.float32,
    ):
        super().__init__()
        # Bookkeeping attributes; not consumed by forward().
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        if learn_sigma:
            self.out_channels = in_channels * 2
        else:
            self.out_channels = in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.output_size = x_input_size + h_input_size

        # Input embeddings for the two streams.
        self.x_embedder = nn.Linear(x_input_size, hidden_size, bias=True)
        self.h_embedder = nn.Linear(h_input_size, hidden_size, bias=True)

        # Dilation grows linearly with the layer index: 1, dilation + 1, ...
        self.x_blocks = nn.ModuleList([
            DiGBlock(
                hidden_size,
                num_heads,
                mlp_ratio=mlp_ratio,
                kernel_size=kernel_size,
                dilation=layer_idx * dilation + 1,
            )
            for layer_idx in range(depth)
        ])
        self.x_final_layer = FinalLayer(hidden_size, x_input_size)
        self.initialize_weights()

    def initialize_weights(self):
        """Re-initialise all parameters.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias; the two
        input embedders are then overwritten with N(0, 0.02) weights, and the
        output head is zeroed.
        """
        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        self.apply(_init_linear)

        # Embedders: small normal init (x first, then h).
        for embedder in (self.x_embedder, self.h_embedder):
            nn.init.normal_(embedder.weight, std=0.02)

        # Zero the output head.
        head = self.x_final_layer.linear
        nn.init.constant_(head.weight, 0)
        nn.init.constant_(head.bias, 0)

    def forward(self, x, h, pro_mol_cutoff, node_mask):
        """Embed ``x``/``h``, run the block stack, and project back.

        ``node_mask`` is multiplied into ``x`` after every block and after the
        output head, so it must broadcast against both the hidden and the
        output tensors (presumably (batch, nodes, 1) -- confirm against
        callers). ``pro_mol_cutoff`` is accepted but not used here.
        """
        x = self.x_embedder(x)
        h = self.h_embedder(h)
        for block in self.x_blocks:
            # Re-apply the mask after each block.
            x = block(x, h) * node_mask
        return self.x_final_layer(x) * node_mask


    @torch.no_grad()
    def simulate(self, x, h, pro_mol_cutoff, node_mask):
        """Run ``forward`` with gradient tracking disabled."""
        return self.forward(x, h, pro_mol_cutoff, node_mask)
