import math
import numpy as np
import torch
import torch.nn as nn
# import dgl
from einops import repeat, rearrange
from torch.nn import functional as F
from torch.nn import GELU, ReLU, Tanh, Sigmoid
from torch.nn.utils.rnn import pad_sequence
from kappamodules.layers import ContinuousSincosEmbed, LinearProjection
from kappamodules.transformer import PerceiverPoolingBlock, Mlp, DitBlock, PrenormBlock, PerceiverBlock

from util.misc import MultipleTensors
from models.mlp import MLP
from functools import partial


class encoder_perceiver(nn.Module):
    """Embed input rows with an MLP and pool them with a perceiver block.

    The tensor given to ``forward`` is passed through an ``Mlp``
    (``ndim -> 4 * dim -> dim``) and the result is handed to a
    ``PerceiverPoolingBlock`` as its ``kv`` input. Optionally a learnable
    type token of shape ``(1, 1, dim)`` is added to the pooled output.

    Args:
        ndim: Size of the last axis expected by the embedding MLP.
        dim: Width of the embedding and of the pooling block.
        num_attn_heads: Number of attention heads of the pooling block.
        num_output_tokens: Number of query tokens of the pooling block.
        add_type_token: If true, create a learnable token and add it to the
            output of ``forward``.
        init_weights: Init scheme name forwarded to the MLP and the
            perceiver block.
        init_last_proj_zero: Forwarded to the perceiver block.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
            self,
            ndim=2,
            dim=128,
            num_attn_heads=4,
            num_output_tokens=256,
            add_type_token=False,
            init_weights="xavier_uniform",
            init_last_proj_zero=False,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.ndim = ndim
        self.dim = dim
        self.num_attn_heads = num_attn_heads
        self.num_output_tokens = num_output_tokens
        self.add_type_token = add_type_token

        # input embedding: raw features are fed to the MLP directly
        # (no sin/cos positional embedding in front of it)
        self.mlp = Mlp(in_dim=ndim, hidden_dim=dim * 4, out_dim=dim, init_weights=init_weights)

        # perceiver
        self.block = PerceiverPoolingBlock(
            dim=dim,
            num_heads=num_attn_heads,
            num_query_tokens=num_output_tokens,
            perceiver_kwargs=dict(
                init_weights=init_weights,
                init_last_proj_zero=init_last_proj_zero,
            ),
        )

        if add_type_token:
            # torch.empty only allocates memory; without an explicit init the
            # token would start from arbitrary values (possibly NaN/inf).
            self.type_token = nn.Parameter(torch.empty(size=(1, 1, dim,)))
            nn.init.trunc_normal_(self.type_token, std=0.02)
        else:
            self.type_token = None

        # output shape
        # presumably the per-sample shape returned by forward -- TODO confirm
        self.output_shape = (num_output_tokens, dim)

    def forward(self, mesh_pos):
        """Embed ``mesh_pos`` and pool it into latent tokens.

        Args:
            mesh_pos: Input tensor; assumed to be ``(batch, num_points, ndim)``
                -- TODO confirm against the caller.

        Returns:
            Output of the pooling block, plus the type token if enabled.
        """
        # perceiver
        x = self.mlp(mesh_pos)
        x = self.block(kv=x)

        if self.add_type_token:
            # (1, 1, dim) token; relies on broadcasting against x
            x = x + self.type_token

        return x
    
    
    
class latent(nn.Module):
    """Linear input projection followed by a stack of transformer blocks.

    Args:
        input_dim: Feature size expected by the input projection.
        condition_dim: If given, blocks are ``DitBlock`` instances created
            with ``cond_dim=condition_dim``; otherwise ``PrenormBlock``.
        dim: Width of the projection output and of every block.
        depth: Number of blocks.
        num_attn_heads: Number of attention heads per block.
        drop_path_rate: Drop-path rate; the maximum rate when decay is on.
        drop_path_decay: If true, per-block rates are spaced linearly from
            0 to ``drop_path_rate``; otherwise every block uses
            ``drop_path_rate``.
        init_weights: Init scheme name forwarded to projection and blocks.
        init_last_proj_zero: Forwarded to every block.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
            self,
            input_dim=128,
            condition_dim=None,
            dim=128,
            depth=6,
            num_attn_heads=4,
            drop_path_rate=0.0,
            drop_path_decay=True,
            init_weights="xavier_uniform",
            init_last_proj_zero=False,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.depth = depth
        self.num_attn_heads = num_attn_heads
        self.drop_path_rate = drop_path_rate
        self.drop_path_decay = drop_path_decay
        self.init_weights = init_weights
        self.init_last_proj_zero = init_last_proj_zero

        # projection is created before the blocks (keeps init order stable)
        self.input_proj = LinearProjection(input_dim, dim, init_weights=init_weights)

        # choose the block type: conditioned (DiT) or plain pre-norm
        if condition_dim is None:
            make_block = PrenormBlock
        else:
            make_block = partial(DitBlock, cond_dim=condition_dim)

        # one drop-path rate per block
        if drop_path_decay:
            rates = torch.linspace(0, drop_path_rate, depth).tolist()
        else:
            rates = [drop_path_rate for _ in range(depth)]

        stack = []
        for rate in rates:
            stack.append(
                make_block(
                    dim=dim,
                    num_heads=num_attn_heads,
                    drop_path=rate,
                    init_weights=init_weights,
                    init_last_proj_zero=init_last_proj_zero,
                )
            )
        self.blocks = nn.ModuleList(stack)

    def forward(self, x, condition=None, static_tokens=None):
        """Project ``x`` and run it through all blocks.

        Args:
            x: Token tensor; tokens are assumed to live on axis 1
                (the axis used for concatenation and slicing).
            condition: Optional conditioning, passed to every block as
                ``cond`` when it is not ``None``.
            static_tokens: Optional tokens prepended to ``x`` along axis 1
                before the projection and stripped again from the result.

        Returns:
            The processed tokens without the static-token prefix.
        """
        has_static = static_tokens is not None
        if has_static:
            x = torch.cat([static_tokens, x], dim=1)

        x = self.input_proj(x)

        # only pass `cond` when a condition was actually supplied
        extra = {} if condition is None else {"cond": condition}
        for layer in self.blocks:
            x = layer(x, **extra)

        if not has_static:
            return x
        # drop the prefix that was prepended above
        return x[:, static_tokens.size(1):]
    
    
    
    
class decoder_perceiver(nn.Module):
    """Decode latent tokens at query positions with a perceiver block.

    Latent tokens are linearly projected to ``dim``; query positions are
    turned into query tokens by an MLP (``ndim -> 4 * dim -> dim``); a
    ``PerceiverBlock`` is called with ``q=queries`` and ``kv=latents``; the
    result goes through an optional LayerNorm and a final linear projection
    to ``output_dim``.

    Args:
        input_dim: Feature size of the incoming latent tokens.
        output_dim: Feature size produced by the final projection.
        ndim: Size of the last axis of the query positions.
        dim: Internal width.
        num_attn_heads: Number of attention heads of the perceiver block.
        init_weights: Init scheme name forwarded to the kappamodules layers.
        init_last_proj_zero: Forwarded to the perceiver block.
        use_last_norm: If true, apply ``LayerNorm(dim, eps=1e-6)`` before
            the final projection; otherwise an identity.
        **kwargs: Forwarded to ``nn.Module.__init__``.
    """

    def __init__(
            self,
            input_dim=128,
            output_dim=1,
            ndim=2,
            dim=128,
            num_attn_heads=4,
            init_weights="xavier_uniform",
            init_last_proj_zero=False,
            use_last_norm=False,
            **kwargs,
    ):
        super().__init__(**kwargs)
        self.dim = dim
        self.num_attn_heads = num_attn_heads
        self.use_last_norm = use_last_norm

        # latent tokens -> internal width
        self.proj = LinearProjection(input_dim, dim, init_weights=init_weights)

        # query positions -> query tokens
        self.query_mlp = Mlp(
            in_dim=ndim,
            hidden_dim=4 * dim,
            out_dim=dim,
            init_weights=init_weights,
        )

        # queries attend to the latent tokens
        self.perceiver = PerceiverBlock(
            dim=dim,
            num_heads=num_attn_heads,
            init_last_proj_zero=init_last_proj_zero,
            init_weights=init_weights,
        )

        # output head
        if use_last_norm:
            self.norm = nn.LayerNorm(dim, eps=1e-6)
        else:
            self.norm = nn.Identity()
        self.pred = LinearProjection(dim, output_dim, init_weights=init_weights)

    def forward(self, x, query_pos):
        """Decode latent tokens ``x`` at ``query_pos``.

        Args:
            x: Latent tokens with last axis ``input_dim``.
            query_pos: Query positions with last axis ``ndim``.

        Returns:
            Output of the final projection; presumably one ``output_dim``
            vector per query position -- TODO confirm.
        """
        latent_tokens = self.proj(x)
        queries = self.query_mlp(query_pos)
        decoded = self.perceiver(q=queries, kv=latent_tokens)
        return self.pred(self.norm(decoded))
    
    
    
    
class UPT_complete(nn.Module):
    """Encoder -> latent transformer -> decoder pipeline.

    Args:
        ndim: Last-axis size of the query tensor (passed to the decoder as
            ``ndim``).
        input_dim: Last-axis size of the encoder input (passed to the
            encoder as ``ndim``).
        dim: Shared internal width of all three stages.
        n_head: Number of attention heads in every stage.
        n_layers: Depth of the latent transformer.
        num_latent: Number of output tokens of the encoder.
        output_dim: Feature size of the decoder output.
    """

    def __init__(self, ndim=3, input_dim=2, dim=128, n_head=4, n_layers=6, num_latent=256, output_dim=1):
        super().__init__()

        # stages are created in pipeline order: encoder, propagator, decoder
        self.encoder = encoder_perceiver(
            ndim=input_dim,
            dim=dim,
            num_attn_heads=n_head,
            num_output_tokens=num_latent,
        )
        self.propagator = latent(
            input_dim=dim,
            dim=dim,
            depth=n_layers,
            num_attn_heads=n_head,
        )
        self.decoder = decoder_perceiver(
            ndim=ndim,
            dim=dim,
            input_dim=dim,
            output_dim=output_dim,
            num_attn_heads=n_head,
        )

    def forward(self, query, x):
        """Run the full pipeline.

        Args:
            query: Query positions handed to the decoder.
            x: Indexable container; only its first element ``x[0]`` is
                encoded.

        Returns:
            The decoder output for ``query``.
        """
        latent_tokens = self.propagator(self.encoder(x[0]))
        return self.decoder(latent_tokens, query)
    
    
    
    
class UPT_latdec(nn.Module):
    """Encoder -> latent transformer -> decoder pipeline (``input_dim`` defaults to 32).

    Args:
        ndim: Last-axis size of the query tensor (passed to the decoder as
            ``ndim``).
        input_dim: Last-axis size of the encoder input (passed to the
            encoder as ``ndim``).
        dim: Shared internal width of all three stages.
        n_head: Number of attention heads in every stage.
        n_layers: Depth of the latent transformer.
        num_latent: Number of output tokens of the encoder.
        output_dim: Feature size of the decoder output.
    """

    def __init__(self, ndim=3, input_dim=32, dim=128, n_head=4, n_layers=6, num_latent=256, output_dim=1):
        super().__init__()

        # stages are created in pipeline order: encoder, propagator, decoder
        self.encoder = encoder_perceiver(
            ndim=input_dim,
            dim=dim,
            num_attn_heads=n_head,
            num_output_tokens=num_latent,
        )
        self.propagator = latent(
            input_dim=dim,
            dim=dim,
            depth=n_layers,
            num_attn_heads=n_head,
        )
        self.decoder = decoder_perceiver(
            ndim=ndim,
            dim=dim,
            input_dim=dim,
            output_dim=output_dim,
            num_attn_heads=n_head,
        )

    def forward(self, query, x):
        """Run the full pipeline.

        Args:
            query: Query positions handed to the decoder.
            x: Indexable container; only its first element ``x[0]`` is
                encoded.

        Returns:
            The decoder output for ``query``.
        """
        latent_tokens = self.propagator(self.encoder(x[0]))
        return self.decoder(latent_tokens, query)
    
    
    
def create_UPT(args):
    """Build a UPT model from an argument object.

    Args:
        args: Object exposing ``use_VAE``, ``trunk_size``, ``branch_sizes``
            (indexable; element 0 is used), ``output_size``, ``n_layer``,
            ``n_hidden`` and ``n_head``.

    Returns:
        A ``UPT_latdec`` when ``args.use_VAE`` is truthy, otherwise a
        ``UPT_complete`` created with ``num_latent=256``.
    """
    use_latent_variant = args.use_VAE

    # keyword arguments common to both model variants
    shared = dict(
        ndim=args.trunk_size,
        input_dim=args.branch_sizes[0],
        output_dim=args.output_size,
        n_layers=args.n_layer,
        dim=args.n_hidden,
        n_head=args.n_head,
    )

    if use_latent_variant:
        return UPT_latdec(**shared)
    return UPT_complete(num_latent=256, **shared)