import torch

from torch import nn, einsum
from torch.nn import functional as F
from einops import rearrange, repeat

from cfd.models.utils import *

'''
class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value)

        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class GeomCFDBlock(nn.Module):
    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
    
class GeomCFD(nn.Module):
    # NOTE: This is a redundant implementation that contains unused parameters. For a simplified but equivalent implementation, see GeoCA3D.
    def __init__(self, cfd_model, geom_encoder=None, geom_proj=None, in_out_dim=64, n_heads=4, d_head=16, dropout=0., context_dim=512, gated_ff=True) -> None:
        super().__init__()
        self.geom_encoder = geom_encoder
        self.geom_proj = geom_proj
        self.cfd_model = cfd_model
        if self.geom_encoder is not None:
            self.n_blocks = self.cfd_model.nb_hidden_layers + 2
            dims = [in_out_dim] + [self.cfd_model.size_hidden_layers] * self.cfd_model.nb_hidden_layers + [in_out_dim]
            self.blocks = nn.ModuleList([GeomCFDBlock(dim=dim, n_heads=n_heads, d_head=d_head, dropout=dropout, context_dim=context_dim, gated_ff=True) for dim in dims])

    def forward(self, data):
        cfd_data, geom_data = data

        if self.geom_encoder is None:
            x = self.cfd_model(cfd_data)
            return x
        
        x, edge_index = cfd_data.x, cfd_data.edge_index

        if hasattr(self.cfd_model, 'get_edge_attr'):
            edge_attr = self.cfd_model.get_edge_attr(x, edge_index)

        x = self.cfd_model.encoder(x)

        z = self.geom_encoder(geom_data) @ self.geom_proj
        z = z / z.norm(dim=-1, keepdim=True) 
        #z = z.expand(x.shape[0], -1)
        z = z.repeat_interleave(x.shape[0] // z.shape[0], 0)
        z = rearrange(z, '(b n) d -> b n d', n=1)

        if self.cfd_model.enc_dim == self.cfd_model.dec_dim:
            x_in = x

        x = rearrange(x, '(b n) d -> b n d', n=1)
        x = self.blocks[0](x, context=z)
        x = rearrange(x, 'b n d -> (b n) d')

        if hasattr(self.cfd_model, 'get_edge_attr'):
            x = self.cfd_model.in_layer(x, edge_index, edge_attr)
        else:
            x = self.cfd_model.in_layer(x, edge_index)

        if self.cfd_model.bn_bool:
            x = self.cfd_model.bn[0](x)

        x = self.cfd_model.activation(x)

        for i in range(1, self.n_blocks - 2):
            if hasattr(self.cfd_model, 'res_bool') and self.cfd_model.res_bool:
                x_res = x

            x = rearrange(x, '(b n) d -> b n d', n=1)
            x = self.blocks[i](x, context=z)
            x = rearrange(x, 'b n d -> (b n) d')

            if hasattr(self.cfd_model, 'get_edge_attr'):
                x = self.cfd_model.hidden_layers[i-1](x, edge_index, edge_attr)
            else:
                x = self.cfd_model.hidden_layers[i-1](x, edge_index)

            if self.cfd_model.bn_bool:
                x = self.cfd_model.bn[i](x)

            x = self.cfd_model.activation(x)

            if hasattr(self.cfd_model, 'res_bool') and self.cfd_model.res_bool:
                x = x + x_res

        x = rearrange(x, '(b n) d -> b n d', n=1)
        x = self.blocks[-2](x, context=z)
        x = rearrange(x, 'b n d -> (b n) d')

        if hasattr(self.cfd_model, 'get_edge_attr'):
            x = self.cfd_model.out_layer(x, edge_index, edge_attr)
        else:
            x = self.cfd_model.out_layer(x, edge_index)

        x = rearrange(x, '(b n) d -> b n d', n=1)
        x = self.blocks[-1](x, context=z)
        x = rearrange(x, 'b n d -> (b n) d')

        if self.cfd_model.enc_dim == self.cfd_model.dec_dim:
            x = x + x_in

        x = self.cfd_model.decoder(x) 
        
        return x
'''

class FCLayer(nn.Module):
    """Linear projection of a context tensor into the query width, then dropout.

    Args:
        query_dim: Output feature size of the projection.
        context_dim: Input feature size of the projection; resolved through
            ``default(context_dim, query_dim)`` (presumably falls back to
            ``query_dim`` when left as ``None`` -- confirm against
            ``cfd.models.utils.default``).
        dropout: Probability handed to ``nn.Dropout`` after the projection.
    """

    def __init__(self, query_dim, context_dim=None, dropout=0.):
        super().__init__()
        in_features = default(context_dim, query_dim)

        projection = nn.Linear(in_features, query_dim)
        regulariser = nn.Dropout(dropout)
        # Kept as a two-element Sequential so parameters stay under `to_out.0`.
        self.to_out = nn.Sequential(projection, regulariser)

    def forward(self, x, context=None):
        """Project ``default(context, x)`` through ``to_out``.

        NOTE(review): when ``context`` is supplied, ``x`` does not appear to
        influence the result (assuming ``default`` returns its first argument
        whenever it is set).
        """
        source = default(context, x)
        return self.to_out(source)

class GeoCA3DBlock(nn.Module):
    """Three residual updates: self projection, context projection, feed-forward.

    Args:
        dim: Feature width of ``x``; used for both LayerNorms, both FCLayers'
            output size and the FeedForward.
        dropout: Dropout probability forwarded to every sub-layer.
        context_dim: Feature width of the conditioning tensor consumed by
            ``fc2``.
        gated_ff: Forwarded to ``FeedForward`` as ``glu``.
    """

    def __init__(self, dim, dropout=0., context_dim=None, gated_ff=True):
        super().__init__()
        # Creation order is deliberately unchanged: it fixes parameter
        # registration order and the order in which layers draw initial weights.
        self.fc1 = FCLayer(query_dim=dim, dropout=dropout)
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.fc2 = FCLayer(
            query_dim=dim,
            context_dim=context_dim,
            dropout=dropout,
        )
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x, context=None):
        """Return ``x`` after the three residual updates.

        Args:
            x: Features whose last dimension is ``dim``.
            context: Conditioning tensor handed to ``fc2``.
        """
        # 1) projection of the normalised features, added back onto x
        self_update = self.fc1(self.norm1(x))
        x = self_update + x

        # 2) projection driven by the context (un-normalised x is passed along)
        context_update = self.fc2(x, context=context)
        x = context_update + x

        # 3) feed-forward on the re-normalised features
        ff_update = self.ff(self.norm2(x))
        return ff_update + x

class GeoCA3D(nn.Module):
    """Wrap a graph CFD model and interleave geometry-conditioned blocks.

    When ``geom_encoder`` is ``None`` the module is a plain pass-through to
    ``cfd_model``. Otherwise one ``GeoCA3DBlock`` is applied before the
    wrapped model's ``in_layer``, before each of its ``hidden_layers``, before
    its ``out_layer`` and once more after it, each conditioned on the
    projected, L2-normalised geometry embedding.

    The conditioned path reads these attributes from ``cfd_model``:
    ``encoder``, ``in_layer``, ``hidden_layers``, ``out_layer``, ``decoder``,
    ``activation``, ``bn_bool`` (and ``bn`` when it is true), ``enc_dim``,
    ``dec_dim``, ``nb_hidden_layers``, ``size_hidden_layers``; and optionally
    ``get_edge_attr`` and ``res_bool``.

    Args:
        cfd_model: The wrapped CFD network.
        geom_encoder: Callable producing geometry embeddings, or ``None`` to
            disable conditioning.
        geom_proj: Right-hand operand of the matrix product applied to the
            encoder output (required whenever ``geom_encoder`` is set).
        in_out_dim: Feature width of the first and last conditioning blocks.
        dropout: Dropout probability forwarded to every block.
        context_dim: Feature width of the projected geometry embedding.
        gated_ff: Forwarded to every ``GeoCA3DBlock``.
    """

    def __init__(self, cfd_model, geom_encoder=None, geom_proj=None, in_out_dim=64, dropout=0., context_dim=512, gated_ff=True) -> None:
        super().__init__()
        self.geom_encoder = geom_encoder
        self.geom_proj = geom_proj
        self.cfd_model = cfd_model
        if self.geom_encoder is not None:
            n_hidden = self.cfd_model.nb_hidden_layers
            hidden_dim = self.cfd_model.size_hidden_layers
            self.n_blocks = n_hidden + 2
            dims = [in_out_dim] + [hidden_dim] * n_hidden + [in_out_dim]
            # Fix: honour the caller's `gated_ff` (it used to be hard-coded to
            # True here, so the constructor argument had no effect).
            self.blocks = nn.ModuleList([
                GeoCA3DBlock(dim=dim, dropout=dropout, context_dim=context_dim, gated_ff=gated_ff)
                for dim in dims
            ])

    def _condition(self, block_idx, x, z):
        """Apply conditioning block ``block_idx`` to flat node features ``x``."""
        # Every row becomes its own length-1 sequence so it lines up with `z`.
        x = rearrange(x, '(b n) d -> b n d', n=1)
        x = self.blocks[block_idx](x, context=z)
        return rearrange(x, 'b n d -> (b n) d')

    def forward(self, data):
        """Run the (optionally geometry-conditioned) CFD model.

        Args:
            data: Pair ``(cfd_data, geom_data)``. ``cfd_data`` must provide
                ``x`` and ``edge_index`` on the conditioned path;
                ``geom_data`` is passed to ``geom_encoder``.

        Returns:
            Output of ``cfd_model(cfd_data)`` when no geometry encoder is set,
            otherwise the output of ``cfd_model.decoder``.
        """
        cfd_data, geom_data = data
        cfd = self.cfd_model

        if self.geom_encoder is None:
            return cfd(cfd_data)

        x, edge_index = cfd_data.x, cfd_data.edge_index

        # Edge attributes are derived from the raw (pre-encoder) features and
        # then passed positionally to every graph layer.
        if hasattr(cfd, 'get_edge_attr'):
            conv_args = (edge_index, cfd.get_edge_attr(x, edge_index))
        else:
            conv_args = (edge_index,)

        x = cfd.encoder(x)

        z = self.geom_encoder(geom_data) @ self.geom_proj
        z = z / z.norm(dim=-1, keepdim=True)
        # One context row per feature row. NOTE(review): assumes x.shape[0] is
        # an integer multiple of z.shape[0] (equal-sized groups) -- confirm.
        z = z.repeat_interleave(x.shape[0] // z.shape[0], 0)
        z = rearrange(z, '(b n) d -> b n d', n=1)

        # Skip connection from encoder output to decoder input, only possible
        # when the two widths agree.
        long_skip = cfd.enc_dim == cfd.dec_dim
        if long_skip:
            x_in = x

        use_residual = getattr(cfd, 'res_bool', False)

        x = self._condition(0, x, z)
        x = cfd.in_layer(x, *conv_args)
        if cfd.bn_bool:
            x = cfd.bn[0](x)
        x = cfd.activation(x)

        # Block i precedes hidden_layers[i - 1]; bn[0] belongs to in_layer.
        for i in range(1, self.n_blocks - 2):
            if use_residual:
                x_res = x

            x = self._condition(i, x, z)
            x = cfd.hidden_layers[i - 1](x, *conv_args)
            if cfd.bn_bool:
                x = cfd.bn[i](x)
            x = cfd.activation(x)

            if use_residual:
                x = x + x_res

        x = self._condition(-2, x, z)
        x = cfd.out_layer(x, *conv_args)
        x = self._condition(-1, x, z)

        if long_skip:
            x = x + x_in

        return cfd.decoder(x)