from functools import wraps

import numpy as np

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.autograd import Function

from einops import rearrange, repeat

from timm.models.layers import DropPath

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def cache_fn(f):
    cache = None
    @wraps(f)
    def cached_fn(*args, _cache = True, **kwargs):
        if not _cache:
            return f(*args, **kwargs)
        nonlocal cache
        if cache is not None:
            return cache
        cache = f(*args, **kwargs)
        return cache
    return cached_fn

def max_neg_value(tensor):
    return -torch.finfo(tensor.dtype).max

class PreNorm(nn.Module):
    def __init__(self, dim, fn, context_dim = None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None

    def forward(self, x, **kwargs):
        x = self.norm(x)

        if exists(self.norm_context):
            context = kwargs['context']
            normed_context = self.norm_context(context)
            kwargs.update(context = normed_context)

        return self.fn(x, **kwargs)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

class FeedForward(nn.Module):
    def __init__(self, dim, mult = 4, drop_path_rate = 0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult * 2),
            GEGLU(),
            nn.Linear(dim * mult, dim)
        )

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x):
        return self.drop_path(self.net(x))

class Attention(nn.Module):
    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, query_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

    def forward(self, x, context = None, mask = None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h = h)
            sim.masked_fill_(~mask, max_neg_value)

        # attention, what we cannot get enough of
        attn = sim.softmax(dim = -1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.drop_path(self.to_out(out))
    
    
class LinearAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, query_dim, context_dim = None, heads = 8, dim_head = 64, drop_path_rate = 0.0):
        super(LinearAttention, self).__init__()
        
        inner_dim = dim_head * heads
        self.n_head = dim_head
        
        # key, query, value projections for all heads
        self.key = nn.Linear(query_dim, inner_dim)
        self.query = nn.Linear(context_dim, inner_dim)
        self.value = nn.Linear(context_dim, inner_dim)

        # output projection
        self.proj = nn.Linear(inner_dim, query_dim)
        
        # regularization
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()


    '''
        Linear Attention and Linear Cross Attention (if y is provided)
    '''
    def forward(self, x, y=None, kv_mask=None):
        y = x if y is None else y
        B, T1, C = x.size()
        _, T2, _ = y.size()
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.query(x).view(B, T1, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)
        k = self.key(y).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)    # (B, nh, T, hs)
        v = self.value(y).view(B, T2, self.n_head, C // self.n_head).transpose(1, 2)  # (B, nh, T, hs)

        q = q.softmax(dim=-1)
        k = k.softmax(dim=-1)

        if exists(mask):
            # mask_value = max_neg_value(q)
            mask = kv_mask[:, None, :, None]
            k = k.masked_fill_(~mask, 0.)
            v = v.masked_fill_(~mask, 0.)
            del mask

        k_cumsum = k.sum(dim=-2, keepdim=True)
        D_inv = 1. / (q * k_cumsum).sum(dim=-1, keepdim=True)       # normalized


        context = k.transpose(-2, -1) @ v
        y = (q @ context) * D_inv# + q
        # y = self.attn_drop((q @ context) * D_inv + q)

        # output projection
        y = rearrange(y, 'b h n d -> b n (h d)')
        y = self.proj(y)
        return self.drop_path(y)


class PointEmbed(nn.Module):
    def __init__(self, hidden_dim=48, dim=128):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 6)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 6),
                        torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6), e,
                        torch.zeros(self.embedding_dim // 6)]),
            torch.cat([torch.zeros(self.embedding_dim // 6),
                        torch.zeros(self.embedding_dim // 6), e]),
        ])
        self.register_buffer('basis', e)  # 3 x 16

        self.mlp = nn.Linear(self.embedding_dim+3, dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings
    
    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
        return embed
    

class PointEmbed2D(nn.Module):
    def __init__(self, hidden_dim=32, dim=128):
        super().__init__()

        assert hidden_dim % 4 == 0

        self.embedding_dim = hidden_dim
        e = torch.pow(2, torch.arange(self.embedding_dim // 4)).float() * np.pi
        e = torch.stack([
            torch.cat([e, torch.zeros(self.embedding_dim // 4)]),
            torch.cat([torch.zeros(self.embedding_dim // 4), e]),
        ])
        self.register_buffer('basis', e)  # 2 x 16

        self.mlp = nn.Linear(self.embedding_dim+2, dim)

    @staticmethod
    def embed(input, basis):
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings
    
    def forward(self, input):
        # input: B x N x 2
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2)) # B x N x C
        return embed


class DiagonalGaussianDistribution(object):
    def __init__(self, mean, logvar, deterministic=False):
        self.mean = mean
        self.logvar = logvar
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.mean.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.mean.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.mean(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2])
            else:
                return 0.5 * torch.mean(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=[1,2,3]):
        if self.deterministic:
            return torch.Tensor([0.])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean
    
    
# --- Gradient Reversal Layer ---
class GradReverse(Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)
    
    @staticmethod
    def backward(ctx, grad_output):
        # Reverse the gradient by multiplying with -lambda_
        return grad_output.neg() * ctx.lambda_, None

def grad_reverse(x, lambda_=1.0):
    return GradReverse.apply(x, lambda_)

class Discriminator(nn.Module):
    def __init__(
        self,
        *,
        depth=2,
        dim=512,
        input_dim = 64,
        output_dim = 2,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        lambda_ = 1.0
    ):
        super().__init__()

        self.lambda_ = lambda_
        self.depth = depth
        self.input_dim = input_dim

        get_latent_attn = lambda: PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=0.1))
        get_latent_ff = lambda: PreNorm(dim, FeedForward(dim, drop_path_rate=0.1))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}

        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.proj = nn.Linear(input_dim, dim)
        self.cls_head = nn.Linear(dim, output_dim)


    def forward(self, x):
        
        # x = grad_reverse(x, self.lambda_)
        x = self.proj(x)
        
        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x
        
        cls_token = x.mean(dim=1, keepdim=False)    
        cls_pred = self.cls_head(cls_token)

        return cls_pred

class KLAutoEncoder(nn.Module):
    def __init__(
        self,
        *,
        depth=24,
        dim=512,
        queries_dim=512,
        output_dim = 1,
        num_inputs = 2048,
        num_latents = 512,
        latent_dim = 64,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False,
        linear = False,
        drop_path_rate = 0.1,
        space_dim = 2
    ):
        super().__init__()

        self.depth = depth

        self.num_inputs = num_inputs
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim)) # Fixed number of latent tokens, same dimension as point embedding
        
        self.Attention = LinearAttention if linear else Attention

        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, self.Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim),
            PreNorm(dim, FeedForward(dim))
        ])

        if space_dim == 2:
            self.point_embed = PointEmbed2D(dim=dim)
        else:
            self.point_embed = PointEmbed(dim=dim)

        get_latent_attn = lambda: PreNorm(dim, self.Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=drop_path_rate))
        get_latent_ff = lambda: PreNorm(dim, FeedForward(dim, drop_path_rate=drop_path_rate))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}

        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.decoder_cross_attn = PreNorm(queries_dim, self.Attention(queries_dim, dim, heads = 1, dim_head = dim), context_dim = dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        self.to_outputs = nn.Linear(queries_dim, output_dim) if exists(output_dim) else nn.Identity()

        self.proj = nn.Linear(latent_dim, dim)

        self.mean_fc = nn.Linear(dim, latent_dim)
        self.logvar_fc = nn.Linear(dim, latent_dim)

    def encode(self, pc, return_kl=True):
        # pc: B x N x 3
        B, N, D = pc.shape
        # assert N == self.num_inputs
        
        ###### fps
        # flattened = pc.view(B*N, D)

        # batch = torch.arange(B).to(pc.device)
        # batch = torch.repeat_interleave(batch, N)

        # pos = flattened

        # ratio = 1.0 * self.num_latents / self.num_inputs

        # idx = fps(pos, batch, ratio=ratio)

        # sampled_pc = pos[idx]
        # sampled_pc = sampled_pc.view(B, -1, 3)
        ######

        # sampled_pc_embeddings = self.point_embed(sampled_pc)
        sampled_pc_embeddings = repeat(self.latents, 'n d -> b n d', b = B)

        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks

        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        mean = self.mean_fc(x)
        logvar = self.logvar_fc(x)

        posterior = DiagonalGaussianDistribution(mean, logvar)
        x = posterior.sample()
        if return_kl:
            kl = posterior.kl()
            return kl, x
        else:
            return x


    def decode(self, x, queries):

        x = self.proj(x)

        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # cross attend from decoder queries to latents
        queries_embeddings = self.point_embed(queries)
        latents = self.decoder_cross_attn(queries_embeddings, context = x)

        # optional decoder feedforward
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)
        
        return self.to_outputs(latents)

    def forward(self, queries, pc):
        kl, x = self.encode(pc)

        o = self.decode(x, queries).squeeze(-1)

        # return o.squeeze(-1), kl
        return {'logits': o, 'kl': kl}
    
    
class AutoEncoder(nn.Module):
    def __init__(
        self,
        *,
        depth=24,
        dim=512,
        queries_dim=512,
        output_dim = 1,
        num_inputs = 2048,
        num_latents = 512,
        latent_dim = 64,
        heads = 8,
        dim_head = 64,
        weight_tie_layers = False,
        decoder_ff = False,
        linear = False,
        drop_path_rate = 0.1,
        space_dim = 2
    ):
        super().__init__()

        self.depth = depth

        self.num_inputs = num_inputs
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, dim)) # Fixed number of latent tokens, same dimension as point embedding
        
        self.Attention = LinearAttention if linear else Attention

        self.cross_attend_blocks = nn.ModuleList([
            PreNorm(dim, self.Attention(dim, dim, heads = 1, dim_head = dim), context_dim = dim),
            PreNorm(dim, FeedForward(dim))
        ])

        if space_dim == 2:
            self.point_embed = PointEmbed2D(dim=dim)
        else:
            self.point_embed = PointEmbed(dim=dim)

        get_latent_attn = lambda: PreNorm(dim, self.Attention(dim, heads = heads, dim_head = dim_head, drop_path_rate=drop_path_rate))
        get_latent_ff = lambda: PreNorm(dim, FeedForward(dim, drop_path_rate=drop_path_rate))
        get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        cache_args = {'_cache': weight_tie_layers}

        for i in range(depth):
            self.layers.append(nn.ModuleList([
                get_latent_attn(**cache_args),
                get_latent_ff(**cache_args)
            ]))

        self.decoder_cross_attn = PreNorm(queries_dim, self.Attention(queries_dim, dim, heads = 1, dim_head = dim), context_dim = dim)
        self.decoder_ff = PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None

        self.to_outputs = nn.Linear(queries_dim, output_dim) if exists(output_dim) else nn.Identity()

        self.proj = nn.Linear(latent_dim, dim)

        self.mean_fc = nn.Linear(dim, latent_dim)
        # self.logvar_fc = nn.Linear(dim, latent_dim)

    def encode(self, pc, return_kl=True):
        # pc: B x N x 3
        B, N, D = pc.shape
        # assert N == self.num_inputs
        
        ###### fps
        # flattened = pc.view(B*N, D)

        # batch = torch.arange(B).to(pc.device)
        # batch = torch.repeat_interleave(batch, N)

        # pos = flattened

        # ratio = 1.0 * self.num_latents / self.num_inputs

        # idx = fps(pos, batch, ratio=ratio)

        # sampled_pc = pos[idx]
        # sampled_pc = sampled_pc.view(B, -1, 3)
        ######

        # sampled_pc_embeddings = self.point_embed(sampled_pc)
        sampled_pc_embeddings = repeat(self.latents, 'n d -> b n d', b = B)

        pc_embeddings = self.point_embed(pc)

        cross_attn, cross_ff = self.cross_attend_blocks

        x = cross_attn(sampled_pc_embeddings, context = pc_embeddings, mask = None) + sampled_pc_embeddings
        x = cross_ff(x) + x

        x = self.mean_fc(x)
        # logvar = self.logvar_fc(x)

        return x


    def decode(self, x, queries):

        x = self.proj(x)

        for self_attn, self_ff in self.layers:
            x = self_attn(x) + x
            x = self_ff(x) + x

        # cross attend from decoder queries to latents
        queries_embeddings = self.point_embed(queries)
        latents = self.decoder_cross_attn(queries_embeddings, context = x)

        # optional decoder feedforward
        if exists(self.decoder_ff):
            latents = latents + self.decoder_ff(latents)
        
        return self.to_outputs(latents)

    def forward(self, queries, pc):
        x = self.encode(pc)

        o = self.decode(x, queries).squeeze(-1)

        # return o.squeeze(-1), kl
        return {'logits': o}
    

def create_autoencoder(depth=8, dim=512, M=512, latent_dim=64, N=2048, linear=False, drop_path_rate=0.1, deterministic=False, lambda_=1.0, domain_num=2, adversarial=True, dataset='Stress',output_dim=1):
    if dataset == 'Inductor':
        space_dim = 3
    else:
        space_dim = 2
    
    if deterministic:
        model = AutoEncoder(
            depth=depth,
            dim=dim,
            queries_dim=dim,
            output_dim = output_dim,
            num_inputs = N,
            num_latents = M,
            latent_dim = latent_dim,
            heads = 8,
            dim_head = 64,
            linear = linear,
            drop_path_rate = drop_path_rate,
            space_dim = space_dim
        )
    else:
        model = KLAutoEncoder(
            depth=depth,
            dim=dim,
            queries_dim=dim,
            output_dim = output_dim,
            num_inputs = N,
            num_latents = M,
            latent_dim = latent_dim,
            heads = 8,
            dim_head = 64,
            linear = linear,
            drop_path_rate = drop_path_rate,
            space_dim = space_dim
        )
    if adversarial:
        discriminator = Discriminator(
            dim = dim,
            input_dim = latent_dim,
            output_dim = domain_num,
            lambda_ = lambda_
        )
        return discriminator, model
    else:
        return model