import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange
from functools import partial
from timm.layers import DropPath
# from flash_cosine_sim_attention import flash_cosine_sim_attention
import math
try:
    from .xshared_modules2_comp import RelativePositionBias, ContinuousPositionBias1D, MLP
except:
    from xshared_modules2_comp import RelativePositionBias, ContinuousPositionBias1D, MLP

def build_time_block(params):
    """Build a factory for the time-mixing block described by ``params``.

    Args:
        params: Config object exposing ``time_type``. When ``time_type`` is
            ``'attention'`` it must also expose ``embed_dim``, ``num_heads``
            and ``bias_type``.

    Returns:
        A ``functools.partial`` of ``AttentionBlock`` with ``hidden_dim``
        (from ``embed_dim``), ``num_heads`` and ``bias_type`` already bound.
        Any remaining ``AttentionBlock`` arguments (e.g. ``drop_path``) can be
        passed when the partial is called.

    Raises:
        NotImplementedError: If ``params.time_type`` is not ``'attention'``.
    """
    if params.time_type != 'attention':
        # Name the rejected value so config mistakes are easy to spot.
        raise NotImplementedError(
            f"Unsupported time_type {params.time_type!r}; only 'attention' is implemented"
        )
    return partial(
        AttentionBlock,
        params.embed_dim,
        params.num_heads,
        bias_type=params.bias_type,
    )

class InstanceNormNd(nn.Module):
    """Instance normalisation for ``(N, C, *)`` inputs of any spatial rank.

    Used in place of ``nn.InstanceNorm1d/2d/3d``. It wraps an
    ``nn.GroupNorm`` whose group count equals the channel count, so every
    channel of every sample is normalised over all of its trailing
    dimensions. No running statistics are tracked.

    Args:
        num_channels: Number of channels ``C`` in the input.
        eps: Value added to the variance for numerical stability.
        affine: If true, learn a per-channel scale and shift.
    """

    def __init__(self, num_channels, eps=1e-5, affine=True):
        super().__init__()
        # One group per channel is exactly instance normalisation.
        groups = num_channels
        self.norm = nn.GroupNorm(groups, num_channels, eps=eps, affine=affine)

    def forward(self, x):
        """Normalise ``x`` (shape ``(N, C, *)``); output has the same shape."""
        normed = self.norm(x)
        return normed

    

class AttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention along the time axis.

    Every spatial location of every batch element attends independently over
    its ``T`` time steps. The attention output is projected, optionally scaled
    per channel by a learnable layer-scale vector ``gamma``, passed through
    ``drop_path`` and added back to the block input as a residual.

    Args:
        hidden_dim: Channel count ``C`` of the inputs. Must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        drop_path: Drop probability for the residual branch (via timm's
            ``DropPath``). ``0`` (the default) disables it.
        layer_scale_init_value: Initial value of the per-channel layer scale
            ``gamma``. If ``<= 0``, no layer scale is used.
        bias_type: ``'none'`` for no attention bias, ``'continuous'`` for
            ``ContinuousPositionBias1D``; any other value selects
            ``RelativePositionBias``.

    Raises:
        ValueError: If ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim=768, num_heads=12, drop_path=0, layer_scale_init_value=1e-6, bias_type='rel'):
        super().__init__()
        if hidden_dim % num_heads != 0:
            # Otherwise the q/k LayerNorms and the head rearrange disagree,
            # and the failure only appears later as a cryptic shape error.
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        head_dim = hidden_dim // num_heads
        self.norm1 = InstanceNormNd(hidden_dim, affine=True)
        self.norm2 = InstanceNormNd(hidden_dim, affine=True)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((hidden_dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        # Per head, the 3*head_dim output channels are split into q, k, v in forward().
        self.input_head = nn.Linear(hidden_dim, 3 * hidden_dim)
        self.output_head = nn.Linear(hidden_dim, hidden_dim)
        self.qnorm = nn.LayerNorm(head_dim)
        self.knorm = nn.LayerNorm(head_dim)
        if bias_type == 'none':
            self.rel_pos_bias = lambda x, y: None
        elif bias_type == 'continuous':
            self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads)
        else:
            self.rel_pos_bias = RelativePositionBias(n_heads=num_heads)
        # Honour the drop_path argument (it was previously ignored). The default
        # of 0 keeps the old behaviour of an identity residual branch.
        self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity()

    def forward(self, x_list):
        """Apply temporal self-attention to each tensor in ``x_list``.

        Args:
            x_list: Iterable of tensors, each shaped ``(T, B, C, *spatial)``
                with ``C == hidden_dim``. Each tensor is processed
                independently, so spatial ranks and sizes may differ.

        Returns:
            A list of tensors, one per input, with the same shapes as the inputs.
        """
        new_x_list = []
        for x in x_list:
            T, _, _, *spatial = x.shape
            D = len(spatial)
            # Name each spatial axis (s0, s1, ...) so einops can fold and unfold them.
            axes = {f"s{i}": v for i, v in enumerate(spatial)}
            init = " ".join(axes.keys())
            # No in-place op follows, so the original tensor is a safe residual.
            residual = x

            # Merge time into batch and pre-normalise each channel.
            x = rearrange(x, 't b c ... -> (t b) c ...')
            x = self.norm1(x)

            # Channel-last linear projection to q, k, v.
            x = rearrange(x, 'b c ... -> b ... c')
            x = self.input_head(x)
            x = rearrange(x, 'b ... c -> b c ...')

            # One attention sequence of length T per (batch, spatial location).
            x = rearrange(x, f'(t b) (he c) {init} ->  (b {init}) he t c', t=T, he=self.num_heads)
            q, k, v = x.tensor_split(3, dim=-1)
            q, k = self.qnorm(q), self.knorm(k)
            # NOTE(review): the bias is assumed to broadcast against
            # (batch*spatial, heads, T, T); confirm against the bias modules.
            rel_pos_bias = self.rel_pos_bias(T, T)
            if rel_pos_bias is not None:
                x = F.scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
            else:
                x = F.scaled_dot_product_attention(q.contiguous(), k.contiguous(), v.contiguous())

            # Undo the attention layout, then normalise and project back.
            x = rearrange(x, f'(b {init}) he t c -> (t b) (he c) {init}', **axes)
            x = self.norm2(x)
            x = rearrange(x, 'b c ... -> b ... c')
            x = self.output_head(x)
            x = rearrange(x, 'b ... c -> b c ...')
            x = rearrange(x, '(t b) c ... -> t b c ...', t=T)

            # gamma is None when layer scale is disabled (layer_scale_init_value <= 0).
            if self.gamma is not None:
                x = x * self.gamma.view(1, 1, -1, *([1] * D))
            new_x_list.append(self.drop_path(x) + residual)
        return new_x_list
