import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
import math
from functools import partial
from timm.layers import DropPath
try:
    from .xshared_modules2 import RelativePositionBias, ContinuousPositionBias1D, MLP
except:
    from xshared_modules2 import RelativePositionBias, ContinuousPositionBias1D, MLP
    

# Param builder func
    
def build_space_block(params):
    """Return a factory for the spatial-mixing block described by ``params``.

    ``params`` must expose ``space_type``. For ``'axial_attention'`` it must also
    expose ``embed_dim``, ``num_heads`` and ``bias_type``.

    The result is a ``functools.partial`` of ``AxialAttentionBlock``:
    ``hidden_dim`` and ``num_heads`` are bound positionally and ``bias_type`` is
    bound by keyword. The caller can still supply the remaining constructor
    arguments, such as ``drop_path``.

    Raises:
        NotImplementedError: if ``params.space_type`` is not a supported type.
    """
    if params.space_type == 'axial_attention':
        return partial(AxialAttentionBlock, params.embed_dim, params.num_heads, bias_type=params.bias_type)
    raise NotImplementedError(
        f"Unsupported space_type {params.space_type!r}; supported: 'axial_attention'"
    )

### Space utils

class RMSInstanceNormNd(nn.Module):
    """Scale-only instance normalisation for (B, C, *spatial) tensors.

    Each (sample, channel) slice is divided by its standard deviation taken over
    every spatial axis, plus ``eps``. The data is NOT mean-centred. When
    ``affine`` is set, a learned per-channel gain is applied afterwards.
    """
    def __init__(self, dim, affine=True, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(dim))
            # Never used by forward(); kept only so existing pretrained checkpoints load.
            self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        _, _, *spatial = x.shape
        n_spatial = len(spatial)
        spatial_axes = tuple(range(2, n_spatial + 2))
        std, _ = torch.std_mean(x, dim=spatial_axes, keepdims=True)
        normed = x / (std + self.eps)
        if not self.affine:
            return normed
        gain = self.weight.view(1, -1, *([1] * n_spatial))
        return normed * gain

class RMSInstanceNorm2d(nn.Module):
    """Scale-only instance normalisation over the last two axes.

    Divides by the standard deviation over axes (-2, -1), plus ``eps``, without
    mean-centring. When ``affine`` is set, a per-channel gain is applied on
    axis 1.
    """
    def __init__(self, dim, affine=True, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(dim))
            # Never used by forward(); kept only so existing pretrained checkpoints load.
            self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        std, _ = torch.std_mean(x, dim=(-2, -1), keepdims=True)
        out = x / (std + self.eps)
        if self.affine:
            out = out * self.weight.view(1, -1, 1, 1)
        return out

class RMSInstanceNorm1d(nn.Module):
    """Scale-only instance normalisation over the last axis.

    Divides by the standard deviation over axis -1, plus ``eps``, without
    mean-centring. When ``affine`` is set, a per-channel gain is applied on
    axis 1.
    """
    def __init__(self, dim, affine=True, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = nn.Parameter(torch.ones(dim))
            # Never used by forward(); kept only so existing pretrained checkpoints load.
            self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        std, _ = torch.std_mean(x, dim=-1, keepdims=True)
        out = x / (std + self.eps)
        if self.affine:
            out = out * self.weight.view(1, -1, 1)
        return out
    
class SubsampledLinear(nn.Module):
    """Linear layer restricted to a subset of a fixed variable vocabulary.

    This is a cross between ``nn.Linear`` and ``nn.EmbeddingBag``. The full
    weight covers every state variable in the vocabulary, and each call uses
    only the rows/columns named by ``labels``. The layer acts on the last axis
    of its input, i.e. (..., C).

    With ``subsample_in`` the selected *input* columns are used, and the result
    is rescaled by sqrt(dim_in / n_selected). Otherwise the selected *output*
    rows (and matching bias entries) are used.
    """
    def __init__(self, dim_in, dim_out, subsample_in=True):
        super().__init__()
        self.subsample_in = subsample_in
        self.dim_in = dim_in
        self.dim_out = dim_out
        # Borrow nn.Linear's default initialisation for the full-vocabulary parameters.
        reference = nn.Linear(dim_in, dim_out)
        self.weight = nn.Parameter(reference.weight)
        self.bias = nn.Parameter(reference.bias)

    def forward(self, x, labels):
        # Only the first sample's label list is consulted, so every sample in the
        # batch is assumed to share the same set of state variables.
        active = labels[0]
        if not self.subsample_in:
            return F.linear(x, self.weight[active], self.bias[active])
        # Compensates for using only a subset of the inputs, as if the init had
        # been chosen for that subset. The bias is scaled too.
        gain = (self.dim_in / len(active)) ** .5
        return gain * F.linear(x, self.weight[:, active], self.bias)

def adaptive_avg_pool_nd(x, output_size):
    """Adaptive average pooling that picks the 1d/2d/3d op from ``x``'s rank.

    ``x`` is (B, C, *spatial) with 1 to 3 spatial axes. ``output_size`` is
    passed through unchanged to the matching ``F.adaptive_avg_pool{1,2,3}d``.

    Raises:
        ValueError: if ``x`` does not have 1, 2 or 3 spatial axes.
    """
    n_spatial = x.ndim - 2
    pool = {
        1: F.adaptive_avg_pool1d,
        2: F.adaptive_avg_pool2d,
        3: F.adaptive_avg_pool3d,
    }.get(n_spatial)
    if pool is None:
        raise ValueError(f"Unsupported number of spatial dimensions: {n_spatial}")
    return pool(x, output_size)

class LayerNorm(nn.Module):
    """Channel-first wrapper around ``nn.LayerNorm``.

    Moves axis 1 to the end, applies ``nn.LayerNorm`` there, and moves it back.
    This lets (B, C, *spatial) tensors be normalised over C. All constructor
    arguments are forwarded unchanged to ``nn.LayerNorm``.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.layernorm = nn.LayerNorm(*args, **kwargs)

    def forward(self, x):
        channels_last = x.movedim(1, -1)
        normed = self.layernorm(channels_last)
        return normed.movedim(-1, 1)

class hMLP_stem(nn.Module):
    """ Image to Patch Embedding

    Hierarchical patch embedding built from three strided 2D convolutions
    (kernel = stride = p1, then p2, then p3). Each pair of convolutions is
    separated by a channel-first LayerNorm and a GELU.

    Only ``patch_size[0]`` is used. It is factored as
    p3 = p2 = int(patch_size ** 0.25) and p1 = patch_size // p2 // p3; for
    example, 16 gives p1=4, p2=2, p3=2.

    ``forward`` takes a list of (B, in_chans, *spatial) tensors and embeds each
    one independently with ``single_forward``.
    """
    def __init__(self, patch_size=(16,16), in_chans=3, embed_dim =768):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        patch_size = patch_size[0]
        p3 = int(patch_size**(1/4))
        p2 = p3
        p1 = patch_size // p2 // p3
        self.in_proj1 = nn.Conv2d(in_chans, embed_dim//4, kernel_size=p1, stride=p1, bias=False)
        self.in_proj2 = LayerNorm(embed_dim//4)
        self.in_proj3 = nn.GELU()
        self.in_proj4 = nn.Conv2d(embed_dim//4, embed_dim//4, kernel_size=p2, stride=p2, bias=False)
        self.in_proj5 = LayerNorm(embed_dim//4)
        self.in_proj6 = nn.GELU()
        self.in_proj7 = nn.Conv2d(embed_dim//4, embed_dim, kernel_size=p3, stride=p3, bias=False)
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

        self.out_norm = LayerNorm(embed_dim)

    def single_forward(self, x):
        """Embed one (B, C, *spatial) tensor into (B, embed_dim, *reduced).

        Each spatial size is floor-divided by p1, then p2, then p3.

        For every spatial axis ``j`` other than the last, the 2D conv stack runs
        on the (axis j, last axis) plane, with all remaining spatial axes folded
        into the batch. After each stage, adaptive average pooling resizes those
        folded axes to match. The per-plane results are combined with an
        element-wise maximum and passed through ``out_norm``.

        A 1D input is first lifted to 2D by placing the signal on the diagonal
        of an (L, L) plane. The extra axis is averaged away at the end.

        Raises:
            ValueError: if ``x`` has no spatial axes.
        """
        _, _, *H = x.shape
        is_1d = len(H) == 1
        if is_1d:
            # Diagonal embedding: out[..., i, k] = x[..., i] * (i == k).
            eye = torch.eye(x.size(-1), device=x.device)
            x = x.unsqueeze(-1) * eye.unsqueeze(0)
            H = [H[0], H[0]]
        D = len(H)
        if D < 2:
            raise ValueError(
                f"hMLP_stem expects at least one spatial axis, got input of shape {tuple(x.shape)}"
            )

        axes = {f"s{i}": v for i, v in enumerate(H)}
        keys = list(axes.keys())
        init = " ".join(keys)
        out_size1 = [h//self.p1 for h in H]
        out_size2 = [h//self.p2 for h in out_size1]
        out_size3 = [h//self.p3 for h in out_size2]

        # Seeded from the first plane's result, not a float32 -inf tensor, so the
        # accumulator keeps the dtype the convolutions actually produce.
        x_max = None
        for j in range(D-1):
            node = keys[-1]
            nhbr = keys[j]
            rest_keys = keys[:j] + keys[j+1:-1]
            rest_axes = {k: axes[k] for k in rest_keys}
            rest = " ".join(rest_keys)
            to_planes = f"b c {init} -> (b {rest}) c {nhbr} {node}"
            from_planes = f"(b {rest}) c {nhbr} {node} -> b c {init}"

            # Stage 1: conv over the (nhbr, node) plane, then pool the folded axes.
            n = rearrange(x, to_planes)
            n = self.in_proj1(n)
            n = rearrange(n, from_planes, **rest_axes)
            n = adaptive_avg_pool_nd(n, out_size1)

            # Stage 2: the folded axes were pooled to size // p1.
            rest_axes = {k: v//self.p1 for k, v in rest_axes.items()}
            n = rearrange(n, to_planes)
            n = self.in_proj2(n)
            n = self.in_proj3(n)
            n = self.in_proj4(n)
            n = rearrange(n, from_planes, **rest_axes)
            n = adaptive_avg_pool_nd(n, out_size2)

            # Stage 3: the folded axes were pooled again, by p2.
            rest_axes = {k: v//self.p2 for k, v in rest_axes.items()}
            n = rearrange(n, to_planes)
            n = self.in_proj5(n)
            n = self.in_proj6(n)
            n = self.in_proj7(n)
            n = rearrange(n, from_planes, **rest_axes)
            n = adaptive_avg_pool_nd(n, out_size3)

            x_max = n if x_max is None else torch.maximum(x_max, n)
        x_max = self.out_norm(x_max)
        if is_1d:
            x_max = x_max.mean(-1)
        return x_max

    def forward(self, x_list):
        # Lifting: embed every input tensor independently.
        return [self.single_forward(x) for x in x_list]
    
    
class hMLP_output(nn.Module):
    """ Patch to Image De-bedding

    Mirror of the patch embedding. Three transposed-conv stages upsample the
    last two spatial axes of each patch-level map by p3, then p2, then p1.
    After each stage, adaptive average pooling resizes any leading spatial
    axes to match.

    ``forward`` takes a list of feature maps:

    * Map ``i`` is decoded with the ``out_*`` stack.
    * Every other map ``j`` is decoded with the ``nhbr_*`` stack. Its axes are
      then permuted with two ``swapaxes`` calls ((j+2, -1), then (i+2, -1)),
      and it is merged into map ``i`` by element-wise maximum.

    The last transposed convolution of both stacks is kept as raw parameters.
    This lets its output channels be selected per call with ``state_labels``.
    """
    def __init__(self, patch_size=(16,16), out_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        patch_size = patch_size[0]
        p3 = int(patch_size**(1/4))
        p2 = p3
        p1 = patch_size // p2 // p3
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3
        self.out_proj1 = nn.ConvTranspose2d(embed_dim, embed_dim//4, kernel_size=p3, stride=p3, bias=False)
        self.out_proj2 = LayerNorm(embed_dim//4)
        self.out_proj3 = nn.GELU()
        self.out_proj4 = nn.ConvTranspose2d(embed_dim//4, embed_dim//4, kernel_size=p2, stride=p2, bias=False)
        self.out_proj5 = LayerNorm(embed_dim//4)
        self.out_proj6 = nn.GELU()

        # Weight layout is (in_channels, out_chans, p1, p1); forward slices axis 1.
        out_head = nn.ConvTranspose2d(embed_dim//4, out_chans, kernel_size=p1, stride=p1)
        self.out_kernel = nn.Parameter(out_head.weight)
        self.out_bias = nn.Parameter(out_head.bias)

        self.nhbr_proj1 = nn.ConvTranspose2d(embed_dim, embed_dim//4, kernel_size=p3, stride=p3, bias=False)
        self.nhbr_proj2 = LayerNorm(embed_dim//4)
        self.nhbr_proj3 = nn.GELU()
        self.nhbr_proj4 = nn.ConvTranspose2d(embed_dim//4, embed_dim//4, kernel_size=p2, stride=p2, bias=False)
        self.nhbr_proj5 = LayerNorm(embed_dim//4)
        self.nhbr_proj6 = nn.GELU()

        nhbr_head = nn.ConvTranspose2d(embed_dim//4, out_chans, kernel_size=p1, stride=p1)
        self.nhbr_kernel = nn.Parameter(nhbr_head.weight)
        self.nhbr_bias = nn.Parameter(nhbr_head.bias)

    @staticmethod
    def _lift_1d(x):
        """Place a (B, C, L) signal on the diagonal of a (B, C, L, L) plane."""
        eye = torch.eye(x.size(-1), device=x.device)
        return x.unsqueeze(-1) * eye.unsqueeze(0)

    def _decode(self, x, layers, kernel, bias, state_labels, out_sizes):
        """Run one three-stage transposed-conv decoder on a (B, C, *spatial) map.

        ``layers`` is the six-module sequence (conv_t, norm, act, conv_t, norm,
        act) of either the ``out_*`` or the ``nhbr_*`` stack. ``kernel`` and
        ``bias`` are the matching final transposed-conv parameters; only the
        output channels selected by ``state_labels`` are used. ``out_sizes``
        holds the full spatial size after each of the three stages.
        """
        conv1, norm1, act1, conv2, norm2, act2 = layers
        out_size1, out_size2, out_size3 = out_sizes
        _, _, *spatial = x.shape
        # All spatial axes except the last two are folded into the batch for the
        # 2D transposed convolutions.
        rest_axes = {f"s{k}": spatial[k] for k in range(len(spatial) - 2)}
        rest = " ".join(rest_axes.keys())
        flatten = f"b c {rest} h w -> (b {rest}) c h w"
        unflatten = f"(b {rest}) c h w -> b c {rest} h w"

        n = rearrange(x, flatten)
        n = conv1(n)
        n = rearrange(n, unflatten, **rest_axes)
        n = adaptive_avg_pool_nd(n, out_size1)

        # The folded axes were resized by p3 in the pooling above.
        rest_axes = {k: v*self.p3 for k, v in rest_axes.items()}
        n = rearrange(n, flatten)
        n = norm1(n)
        n = act1(n)
        n = conv2(n)
        n = rearrange(n, unflatten, **rest_axes)
        n = adaptive_avg_pool_nd(n, out_size2)

        rest_axes = {k: v*self.p2 for k, v in rest_axes.items()}
        n = rearrange(n, flatten)
        n = norm2(n)
        n = act2(n)
        n = F.conv_transpose2d(
            n,
            kernel[:, state_labels],
            bias[state_labels],
            stride=self.p1
        )
        n = rearrange(n, unflatten, **rest_axes)
        return adaptive_avg_pool_nd(n, out_size3)

    def forward(self, x_list, state_labels=None):
        """Decode every map in ``x_list`` back to full resolution.

        Args:
            x_list: list of (B, embed_dim, *spatial) tensors. They are assumed
                to share ``x_list[0]``'s shape. 1D inputs are lifted to 2D with
                a diagonal embedding, and the extra axis is averaged away at
                the end.
            state_labels: indices into the ``out_chans`` output channels to
                produce. ``None`` (the default) produces all ``out_chans``
                channels.

        Returns:
            A list with one decoded tensor per input. Each spatial size is
            multiplied by p3 * p2 * p1.
        """
        if state_labels is None:
            # Indexing with None would insert a new axis into the kernel and
            # break conv_transpose2d; treat it as "all output channels".
            state_labels = slice(None)
        _, _, *H = x_list[0].shape
        is_1d = len(H) == 1
        if is_1d:
            # Lift every map, including those used only as neighbours.
            H = [H[0], H[0]]
            x_list = [self._lift_1d(x) for x in x_list]

        out_size1 = [s*self.p3 for s in H]
        out_size2 = [s*self.p2 for s in out_size1]
        out_size3 = [s*self.p1 for s in out_size2]
        out_sizes = (out_size1, out_size2, out_size3)

        out_layers = (self.out_proj1, self.out_proj2, self.out_proj3,
                      self.out_proj4, self.out_proj5, self.out_proj6)
        nhbr_layers = (self.nhbr_proj1, self.nhbr_proj2, self.nhbr_proj3,
                       self.nhbr_proj4, self.nhbr_proj5, self.nhbr_proj6)

        # A neighbour's decoding does not depend on which map it is merged into.
        # Compute it once per map instead of once per (i, j) pair.
        nhbr_decoded = []
        if len(x_list) > 1:
            nhbr_decoded = [
                self._decode(y, nhbr_layers, self.nhbr_kernel, self.nhbr_bias, state_labels, out_sizes)
                for y in x_list
            ]

        new_x_list = []
        for i, x in enumerate(x_list):
            x_max = self._decode(x, out_layers, self.out_kernel, self.out_bias, state_labels, out_sizes)
            for j, y in enumerate(nhbr_decoded):
                if i == j:
                    continue
                y = torch.swapaxes(y, -1, j+2)
                y = torch.swapaxes(y, i+2, -1)
                x_max = torch.maximum(x_max, y)
            if is_1d:
                x_max = x_max.mean(-1)
            new_x_list.append(x_max)
        return new_x_list

class UpsampledLinear(nn.Module):
    """Linear layer whose output rows are selected on every call.

    The full (dim_out, dim_in) weight is created once. ``forward`` uses only the
    rows, and matching bias entries, indexed by ``state_labels``.
    """
    def __init__(self, dim_in, dim_out):
        super().__init__()
        # Borrow nn.Linear's default initialisation.
        temp_linear = nn.Linear(dim_in, dim_out)
        self.weight = nn.Parameter(temp_linear.weight)
        self.bias = nn.Parameter(temp_linear.bias)

    def forward(self, x, state_labels):
        """Apply the selected rows of the linear map over the last axis of ``x``.

        Args:
            x: tensor whose last axis has size ``dim_in``.
            state_labels: index (list or tensor of ints) of the output rows to
                compute.

        Returns:
            A tensor like ``x``, with the last axis replaced by the selected
            outputs.
        """
        # Debug prints were removed: they fired on every call, and
        # torch.cuda.current_device() raises on machines without CUDA.
        return F.linear(x, self.weight[state_labels], self.bias[state_labels])
   
class AxialAttentionBlock(nn.Module):
    """Axial multi-head self-attention followed by an MLP, for (B, C, *spatial) maps.

    Attention branch:

    * ``norm1`` normalises the input, which is then projected to per-head
      q/k/v. q and k are normalised with ``qnorm`` and ``knorm``.
    * Attention runs separately along each spatial axis, with the other axes
      folded into the batch. The per-axis results are averaged.
    * The average passes through ``norm2`` and ``output_head``.

    MLP branch: ``mlp`` is applied over the channel axis and followed by
    ``mlp_norm``.

    Both branches are residual. Each is scaled by an optional per-channel
    layer-scale gain and passed through ``drop_path``.
    """
    def __init__(self, hidden_dim=768, num_heads=12, drop_path=0, layer_scale_init_value=1e-6, bias_type='rel'):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = LayerNorm(hidden_dim)
        self.norm2 = LayerNorm(hidden_dim)
        # Layer-scale gains. They are None (disabled) when layer_scale_init_value <= 0.
        self.gamma_att = nn.Parameter(layer_scale_init_value * torch.ones((hidden_dim)), 
                            requires_grad=True) if layer_scale_init_value > 0 else None
        self.gamma_mlp = nn.Parameter(layer_scale_init_value * torch.ones((hidden_dim)), 
                            requires_grad=True) if layer_scale_init_value > 0 else None
        
        self.input_head = nn.Linear(hidden_dim, 3*hidden_dim)
        self.output_head = nn.Linear(hidden_dim, hidden_dim)
        self.qnorm = nn.LayerNorm(hidden_dim//num_heads)
        self.knorm = nn.LayerNorm(hidden_dim//num_heads)
        if bias_type == 'none':
            # single_forward calls this with three positional args
            # (n_query, n_key, bcs[0, 0]), so the no-op must accept them.
            self.rel_pos_bias = lambda *args, **kwargs: None
        elif bias_type == 'continuous':
            self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads)
        else:
            self.rel_pos_bias = RelativePositionBias(n_heads=num_heads)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = MLP(hidden_dim)
        self.mlp_norm = LayerNorm(hidden_dim)

    def single_forward(self, x, bcs):
        """Apply the block to one (B, C, *spatial) tensor.

        ``bcs`` is only used as ``bcs[0, 0]``, which is passed to the relative
        position bias.
        """
        _, _, *H = x.shape
        D = len(H)

        # --- Axial attention branch ---
        # No copy needed: nothing below modifies x in place.
        residual = x
        x = self.norm1(x)
        x = rearrange(x, 'b c ... -> b ... c')
        x = self.input_head(x)
        x = rearrange(x, 'b ... c -> b c ...')

        # Within each head the channels are laid out as [q | k | v].
        x = rearrange(x, 'b (he c) ... ->  b he ... c', he=self.num_heads)
        q, k, v = x.tensor_split(3, dim=-1)
        q, k = self.qnorm(q), self.knorm(k)

        axes = {f"s{i}": size for i, size in enumerate(H)}
        keys = list(axes.keys())
        init = " ".join(keys)
        x_mean = 0
        for i in range(D):
            node = keys[i]
            rest = " ".join(keys[:i] + keys[i+1:])
            # Attend along axis `node`; every other spatial axis is folded into the batch.
            qx, kx, vx = (rearrange(t, f'b he {init} c ->  (b {rest}) he {node} c') for t in (q, k, v))
            rel_pos_bias_x = self.rel_pos_bias(H[i], H[i], bcs[0, 0])
            # The functional API does not return the attention weights.
            if rel_pos_bias_x is not None:
                xx = F.scaled_dot_product_attention(qx, kx, vx, attn_mask=rel_pos_bias_x)
            else:
                xx = F.scaled_dot_product_attention(qx.contiguous(), kx.contiguous(), vx.contiguous())
            xx = rearrange(xx, f'(b {rest}) he {node} c -> b (he c) {init}', **axes)
            x_mean += xx / D

        # Combine heads and project.
        x = self.norm2(x_mean)
        x = rearrange(x, 'b c ... -> b ... c')
        x = self.output_head(x)
        x = rearrange(x, 'b ... c -> b c ...')
        if self.gamma_att is not None:
            x = x * self.gamma_att.view(1, -1, *([1]*D))
        x = self.drop_path(x) + residual

        # --- MLP branch ---
        # Defensive copy, kept from the original: MLP is defined elsewhere and
        # its in-place behaviour cannot be checked here.
        residual = x.clone()
        x = rearrange(x, 'b c ... -> b ... c')
        x = self.mlp(x)
        x = rearrange(x, 'b ... c -> b c ...')
        x = self.mlp_norm(x)
        if self.gamma_mlp is not None:
            x = self.gamma_mlp.view(1, -1, *([1]*D)) * x
        return residual + self.drop_path(x)

    def forward(self, x_list, bcs):
        # Each tensor in the list is processed independently with the same bcs.
        return [self.single_forward(x, bcs) for x in x_list]