from typing import Callable
import jax
from jax import vmap
import einops
import flax.linen as nn
import jax.numpy as jnp
from einops import rearrange, repeat
from jax.nn.initializers import normal, xavier_uniform

import itertools



class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis.

    Computes ``scale * x / sqrt(mean(x**2, axis=-1) + eps)`` with a learned
    per-feature ``scale`` (initialized to ones). Unlike LayerNorm there is no
    mean subtraction and no bias.
    """

    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        gain = self.param("scale", nn.initializers.ones, (x.shape[-1],))
        mean_square = jnp.mean(x ** 2, axis=-1, keepdims=True)
        root_mean_square = jnp.sqrt(mean_square + self.eps)
        return gain * x / root_mean_square



class PairwiseNLFluxFieldsNorm(nn.Module):
    """Shared MLP evaluated on every unordered pair (i, j), i < j, of Npou entries.

    Each sample vector is split into per-entry field values, optional extra
    parameters and optional oriented areas (see ``pairing``). Every pair is
    passed through the network twice, once ordered (i, j) and once (j, i), and
    the two head outputs are summed. The result is therefore symmetric under
    swapping the two entries' field values.

    Attributes:
        Npou: number of entries that are paired (presumably partition-of-unity
            functions; TODO confirm).
        extra_params: number of scalars, placed after the field values, that are
            appended to every pair row.
        num_dimensions: number of oriented-area components per pair.
        num_fields: number of field values per entry; also the head width.
        num_hidden_layers: number of hidden Dense layers.
        hidden_layer_width: width of the tail and hidden Dense layers.
        layer_norm_eps: epsilon of the RMSNorm layers (used when use_norm).
        oriented_areas: if True, each sample ends with
            num_dimensions * Npou * (Npou - 1) // 2 oriented-area values, one
            consecutive block of num_dimensions per pair.
        use_norm: apply RMSNorm after each hidden layer.
        use_res: add a residual connection around each hidden layer.
        activation: activation applied after the tail and hidden layers.
        kernel_init: kernel initializer of every Dense layer.
    """

    Npou: int = 5
    extra_params: int = 0
    num_dimensions: int = 1
    num_fields: int = 1
    num_hidden_layers: int = 3
    hidden_layer_width: int = 32
    layer_norm_eps: float = 1e-5
    oriented_areas: bool = False
    use_norm: bool = True
    use_res: bool = True
    activation: Callable = nn.gelu
    kernel_init: Callable = xavier_uniform()

    def setup(self):
        self.tail = nn.Dense(self.hidden_layer_width, kernel_init=self.kernel_init)
        self.hidden_layers = [
            nn.Dense(self.hidden_layer_width, kernel_init=self.kernel_init)
            for _ in range(self.num_hidden_layers)
        ]

        if self.use_norm:
            self.norm = [
                RMSNorm(eps=self.layer_norm_eps)
                for _ in range(self.num_hidden_layers)
            ]

        self.head = nn.Dense(self.num_fields, kernel_init=self.kernel_init)

        # One block of num_dimensions oriented-area values per unordered pair.
        if self.oriented_areas:
            self.num_one_forms = self.num_dimensions * (self.Npou * (self.Npou - 1) // 2)
        else:
            self.num_one_forms = 0

        # (num_pairs, 2) table of index pairs (i, j), i < j, in lexicographic order.
        self.indexer = jnp.array(list(itertools.combinations(range(self.Npou), 2)))

    @nn.compact
    def __call__(self, x):
        """Evaluate the pairwise network on a batch.

        Args:
            x: batch of 1-D sample vectors, mapped over axis 0; each row is laid
                out as described in ``pairing``.

        Returns:
            ``jnp.squeeze`` of an array of shape (batch, num_pairs, num_fields).
            Note that squeeze also drops the batch axis when batch == 1.
        """
        forward = vmap(self.pairing, in_axes=(0, None))(x, False)
        backward = vmap(self.pairing, in_axes=(0, None))(x, True)

        forward = self.activation(self.tail(forward))
        backward = self.activation(self.tail(backward))

        for i, hidden_layer in enumerate(self.hidden_layers):
            # The same layer (shared weights) processes both orderings.
            update_fwd = self.activation(hidden_layer(forward))
            update_bwd = self.activation(hidden_layer(backward))

            if self.use_res:
                forward = forward + update_fwd
                backward = backward + update_bwd
            else:
                forward = update_fwd
                backward = update_bwd

            if self.use_norm:
                forward = self.norm[i](forward)
                backward = self.norm[i](backward)

        return jnp.squeeze(self.head(forward) + self.head(backward))

    def pairing(self, x, flip=False):
        """Build one row per index pair for a single sample.

        Args:
            x: 1-D sample vector laid out as
                ``[field values | extra params | oriented areas]``.
                The field values are stacked field-major:
                (F0_0, ..., F0_{Npou-1}, F1_0, ..., F1_{Npou-1}, ...).
                The extra-params segment has ``extra_params`` values. The
                oriented-areas segment has ``num_one_forms`` values and is
                present only when ``oriented_areas``.
            flip: if true, each pair (i, j) is ordered as (j, i).

        Returns:
            Array of shape (num_pairs, 2 * num_fields [+ num_dimensions] +
            extra_params). Each row holds the first entry's fields, then the
            second entry's fields, then (if enabled) the pair's oriented areas,
            then the extra params.
        """
        n_tail = self.extra_params + self.num_one_forms
        n_field_values = x.shape[0] - n_tail
        # Use explicit split points: negative slices such as x[:-k] or x[-k:-0]
        # collapse to empty arrays when k == 0 (e.g. oriented_areas=False).
        x_short = x[:n_field_values]
        extra_params_vector = x[n_field_values:n_field_values + self.extra_params]

        # Transpose so that row i holds all fields of entry i.
        x_short = jnp.reshape(x_short, (self.num_fields, -1)).T

        # jnp.where keeps this valid even if `flip` arrives as a traced value.
        pair_index = jnp.where(flip, jnp.flip(self.indexer, axis=1), self.indexer)
        # (num_pairs, 2, num_fields) -> (num_pairs, 2 * num_fields)
        pairs = jnp.reshape(x_short[pair_index], (-1, 2 * self.num_fields))

        if self.oriented_areas:
            oriented_areas = x[n_field_values + self.extra_params:]
            pairs = jnp.concatenate(
                [pairs, oriented_areas.reshape(-1, self.num_dimensions)], axis=-1
            )

        # Broadcast the extra parameters to every pair row.
        extra = jnp.tile(extra_params_vector, (pairs.shape[0], 1))
        return jnp.concatenate([pairs, extra], axis=1)



# Positional embedding from masked autoencoder https://arxiv.org/abs/2111.06377
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sinusoidal (sin/cos) embedding of arbitrary positions, as in MAE.

    Args:
        embed_dim: output width D; must be even.
        pos: array of positions of any shape; it is flattened to (M,).

    Returns:
        (M, D) array. The first D/2 columns are sin(pos * omega_k) and the last
        D/2 are cos(pos * omega_k), with omega_k = 1 / 10000**(2k / D).
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    exponents = jnp.arange(half_dim, dtype=jnp.float32) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    flat_pos = pos.reshape(-1)  # (M,)
    angles = jnp.einsum("m,d->md", flat_pos, frequencies)  # (M, D/2)

    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)  # (M, D)


def get_1d_sincos_pos_embed(embed_dim, length):
    """Return the sin/cos table for positions 0 .. length-1 with a leading unit axis.

    Output shape is (1, length, embed_dim); embed_dim must be even.
    """
    positions = jnp.arange(length, dtype=jnp.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    return table[None, ...]


def get_2d_sincos_pos_embed(embed_dim, grid_size):
    """Build a 2-D sin/cos positional embedding for an (n0, n1) grid of patches.

    Args:
        embed_dim: total embedding width. Each half is passed to the 1-D
            embedding, which itself requires an even width, so embed_dim must
            be divisible by 4.
        grid_size: (n0, n1), the number of patches along each spatial axis.

    Returns:
        Array of shape (1, n0 * n1, embed_dim). Row k corresponds to grid cell
        (k // n1, k % n1), i.e. row-major order, which is the order in which
        PatchEmbed flattens its (n0, n1) patch grid. The first embed_dim / 2
        columns encode the axis-0 index; the remaining columns encode the
        axis-1 index.
    """
    assert embed_dim % 2 == 0
    coords_0 = jnp.arange(grid_size[0], dtype=jnp.float32)
    coords_1 = jnp.arange(grid_size[1], dtype=jnp.float32)
    # Axis-0 coordinates must come first so that each grid[c] has shape
    # (n0, n1) and flattens in the same row-major order as the patch tokens.
    # Passing them in the other order yields (n1, n0) arrays, which scrambles
    # positions on non-square grids. For square grids both orders coincide.
    grid = jnp.stack(jnp.meshgrid(coords_0, coords_1, indexing="ij"), axis=0)

    emb_0 = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (n0*n1, D/2)
    emb_1 = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (n0*n1, D/2)
    pos_embed = jnp.concatenate([emb_0, emb_1], axis=1)  # (n0*n1, D)

    return jnp.expand_dims(pos_embed, 0)


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed each one.

    Input ``x`` is (B, W, H, C). The output is (B, (W // p0) * (H // p1), emb_dim),
    with tokens in row-major order over the (W // p0, H // p1) patch grid.
    Trailing pixels that do not fill a whole patch are dropped. If ``use_norm``
    is set, a LayerNorm is applied to the tokens.
    """

    patch_size: tuple = (16, 16)
    emb_dim: int = 768
    use_norm: bool = False
    kernel_init: Callable = xavier_uniform()
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        b, w, h, _ = x.shape

        # padding="VALID" makes the conv emit exactly floor(W/p0) x floor(H/p1)
        # patches, which matches num_patches below (and the encoders'
        # positional-embedding grid). Flax's default "SAME" padding emits
        # ceil(...) patches when the size is not divisible, which breaks the
        # reshape. For divisible sizes both paddings give identical results.
        x = nn.Conv(
            self.emb_dim,
            kernel_size=(self.patch_size[0], self.patch_size[1]),
            strides=(self.patch_size[0], self.patch_size[1]),
            padding="VALID",
            kernel_init=self.kernel_init,
            name="proj",
        )(x)

        num_patches = (
            w // self.patch_size[0],
            h // self.patch_size[1],
        )

        x = jnp.reshape(x, (b, num_patches[0] * num_patches[1], self.emb_dim))
        if self.use_norm:
            x = nn.LayerNorm(name="norm", epsilon=self.layer_norm_eps)(x)
        return x


class MlpBlock(nn.Module):
    """Two-layer feed-forward block: Dense(dim) -> GELU -> Dense(out_dim)."""

    dim: int = 256
    out_dim: int = 256
    kernel_init: Callable = xavier_uniform()

    @nn.compact
    def __call__(self, inputs):
        expand = nn.Dense(self.dim, kernel_init=self.kernel_init)
        hidden = nn.gelu(expand(inputs))
        project = nn.Dense(self.out_dim, kernel_init=self.kernel_init)
        return project(hidden)


class SelfAttnBlock(nn.Module):
    """Pre-norm transformer block.

    Computes ``h = x + MHSA(LN(x))`` and returns ``h + MLP(LN(h))``. The MLP
    hidden width is ``emb_dim * mlp_ratio``.
    """

    num_heads: int
    emb_dim: int
    mlp_ratio: int
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, inputs):
        normed = nn.LayerNorm(epsilon=self.layer_norm_eps)(inputs)
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.emb_dim
        )
        hidden = attention(normed, normed) + inputs

        mlp_in = nn.LayerNorm(epsilon=self.layer_norm_eps)(hidden)
        mlp_out = MlpBlock(dim=self.emb_dim * self.mlp_ratio, out_dim=self.emb_dim)(mlp_in)

        return hidden + mlp_out


class CrossAttnBlock(nn.Module):
    """Pre-norm cross-attention block.

    The queries and the key/value sequence are normalized separately. The
    block computes ``h = q + MHA(LN(q), LN(kv))`` and returns
    ``h + MLP(LN(h))``, with MLP hidden width ``emb_dim * mlp_ratio``.
    """

    num_heads: int
    emb_dim: int
    mlp_ratio: int
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, q_inputs, kv_inputs):
        queries = nn.LayerNorm(epsilon=self.layer_norm_eps)(q_inputs)
        context = nn.LayerNorm(epsilon=self.layer_norm_eps)(kv_inputs)

        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.emb_dim
        )
        hidden = attention(queries, context) + q_inputs

        mlp_in = nn.LayerNorm(epsilon=self.layer_norm_eps)(hidden)
        mlp_out = MlpBlock(dim=self.emb_dim * self.mlp_ratio, out_dim=self.emb_dim)(mlp_in)
        return hidden + mlp_out


class TimeAggregation(nn.Module):
    """Compress the time axis with learned latent queries.

    ``num_latents`` learned query tokens cross-attend, separately for every
    spatial position, to the T input time steps. This is repeated ``depth``
    times. Shapes: (B, T, S, D) -> (B, num_latents, S, D).
    """

    emb_dim: int
    depth: int
    num_heads: int = 8
    num_latents: int = 64
    mlp_ratio: int = 1
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):  # (B, T, S, D) --> (B, T', S, D)
        queries = self.param(
            "latents", normal(), (self.num_latents, self.emb_dim)
        )  # (T', D)

        batch_size, num_spatial = x.shape[0], x.shape[2]
        queries = repeat(queries, "t d -> b s t d", b=batch_size, s=num_spatial)  # (B, S, T', D)
        context = rearrange(x, "b t s d -> b s t d")  # (B, S, T, D)

        for _ in range(self.depth):
            block = CrossAttnBlock(
                num_heads=self.num_heads,
                emb_dim=self.emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )
            queries = block(queries, context)

        return rearrange(queries, "b s t d -> b t s d")  # (B, T', S, D)


class Mlp(nn.Module):
    """Residual MLP followed by a linear head.

    Each of ``num_layers`` layers computes ``x = LN(x + gelu(Dense(x)))``, so
    the input width must equal ``hidden_dim``. The head is a Dense to
    ``out_dim`` that uses Flax's default kernel initializer.
    """

    num_layers: int
    hidden_dim: int
    out_dim: int
    kernel_init: Callable = xavier_uniform()
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, inputs):
        hidden = inputs
        for _ in range(self.num_layers):
            dense = nn.Dense(features=self.hidden_dim, kernel_init=self.kernel_init)
            update = nn.gelu(dense(hidden))
            hidden = nn.LayerNorm(epsilon=self.layer_norm_eps)(hidden + update)

        return nn.Dense(features=self.out_dim)(hidden)


# Initializer of the encoders' ("pos_emb", "enc_s_emb") variable; it is called
# as s_emb_init(emb_dim, (num_patches_axis0, num_patches_axis1)).
s_emb_init = get_2d_sincos_pos_embed


class Encoder(nn.Module):
    """Patch-embed an image, add a 2-D sin/cos positional embedding, apply self-attention.

    Input ``x`` is (B, W, H, C). The output is (B, (W // p0) * (H // p1), emb_dim)
    after ``depth`` SelfAttnBlocks. The positional embedding is stored in the
    non-trainable-by-default "pos_emb" collection under "enc_s_emb".
    """

    patch_size: tuple = (16, 16)
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1
    out_dim: int = 1  # NOTE(review): not used by this module.
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        _, w, h, _ = x.shape

        x = PatchEmbed(patch_size=self.patch_size, emb_dim=self.emb_dim)(x)

        s_emb = self.variable(
            "pos_emb",
            "enc_s_emb",
            s_emb_init,
            self.emb_dim,
            (w // self.patch_size[0], h // self.patch_size[1]),
        )

        x = x + s_emb.value

        for _ in range(self.depth):
            x = SelfAttnBlock(
                num_heads=self.num_heads,
                emb_dim=self.emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x)

        return x


class ConditionalEncoder(nn.Module):
    """Patch encoder whose layers cross-attend to the patch tokens plus a conditioning token.

    At every layer ``z`` is mapped to an ``emb_dim`` token by Dense -> Mlp
    (with fresh parameters per layer). That token is appended to the patch
    tokens, and the patch tokens cross-attend to the extended sequence.

    Call args:
        x: (B, W, H, C) image.
        z: conditioning array. It must be 2-D (B, Dz) so that the appended
            token ``cond[:, None]`` has the same rank as the (B, N, emb_dim)
            patch tokens.

    Returns:
        (B, N, emb_dim) tokens, where N = (W // p0) * (H // p1).
    """

    patch_size: tuple = (16, 16)
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1
    out_dim: int = 1  # NOTE(review): not used by this module.
    layer_norm_eps: float = 1e-5
    kernel_init: Callable = xavier_uniform()

    @nn.compact
    def __call__(self, x, z):
        _, w, h, _ = x.shape

        x = PatchEmbed(
            patch_size=self.patch_size,
            emb_dim=self.emb_dim,
            kernel_init=self.kernel_init,
        )(x)

        s_emb = self.variable(
            "pos_emb",
            "enc_s_emb",
            s_emb_init,
            self.emb_dim,
            (w // self.patch_size[0], h // self.patch_size[1]),
        )

        x = x + s_emb.value

        for _ in range(self.depth):
            cond = nn.Dense(self.emb_dim, kernel_init=self.kernel_init)(z)
            # NOTE: a 2-layer Mlp was also reported to work well here.
            cond = Mlp(
                num_layers=1,
                hidden_dim=self.emb_dim,
                out_dim=self.emb_dim,
                kernel_init=self.kernel_init,
                layer_norm_eps=self.layer_norm_eps,
            )(cond)
            # (B, N, D) patch tokens + (B, 1, D) conditioning token -> (B, N + 1, D)
            context = jnp.concat([x, cond[:, None]], axis=-2)

            x = CrossAttnBlock(
                num_heads=self.num_heads,
                emb_dim=self.emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x, context)

        return x





class Vit(nn.Module):
    """ViT: Encoder -> LayerNorm -> Mlp head producing one output vector per token.

    NOTE(review): ``patch_size`` defaults to a 3-tuple, but ``Encoder`` only uses
    ``patch_size[0]`` and ``patch_size[1]`` as its 2-D conv patch, while the head
    width uses ``patch_size[1] * patch_size[2]``. With the default (1, 16, 16)
    the conv patch is therefore (1, 16). This looks like a leftover from a
    (T, H, W) variant; confirm the intent.
    """

    patch_size: tuple = (1, 16, 16)
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1
    num_mlp_layers: int = 1
    out_dim: int = 1
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        # Keyword arguments: Encoder's sixth field is `out_dim`, so passing
        # layer_norm_eps positionally would silently drop it.
        x = Encoder(
            patch_size=self.patch_size,
            emb_dim=self.emb_dim,
            depth=self.depth,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            layer_norm_eps=self.layer_norm_eps,
        )(x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(x)

        x = Mlp(
            num_layers=self.num_mlp_layers,
            hidden_dim=self.emb_dim,
            out_dim=self.patch_size[1] * self.patch_size[2] * self.out_dim,
            layer_norm_eps=self.layer_norm_eps,
        )(x)
        return x


class FourierEmbs(nn.Module):
    """Fourier-feature embedding with a learnable frequency matrix.

    Returns ``[cos(x @ K), sin(x @ K)]`` along the last axis. K is a
    (x.shape[-1], embed_dim // 2) parameter initialized from
    ``normal(embed_scale)``.
    """

    embed_scale: float
    embed_dim: int

    @nn.compact
    def __call__(self, x):
        frequencies = self.param(
            "kernel", normal(self.embed_scale), (x.shape[-1], self.embed_dim // 2)
        )
        projected = jnp.dot(x, frequencies)
        return jnp.concatenate([jnp.cos(projected), jnp.sin(projected)], axis=-1)
    



        
class AttnLatentFlux(nn.Module):
    """Pool a token sequence into ``extra_params`` outputs through one learned query.

    A single learned latent token of width ``emb_dim`` cross-attends to the
    input tokens ``depth`` times. The result is flattened and projected by a
    Dense layer to ``extra_params`` outputs.
    """

    emb_dim: int
    depth: int
    num_heads: int = 8
    mlp_ratio: int = 1
    layer_norm_eps: float = 1e-5
    extra_params: int = 1

    @nn.compact
    def __call__(self, x):  # x assumed (B, N, D) tokens --> (B, extra_params)
        latents = self.param("latents", normal(), (1, self.emb_dim))  # (1, emb_dim)

        latents = repeat(latents, "s d -> b s d", b=x.shape[0])  # (B, 1, emb_dim)

        for _ in range(self.depth):
            latents = CrossAttnBlock(
                self.num_heads, self.emb_dim, self.mlp_ratio, self.layer_norm_eps
            )(latents, x)

        # The latents go straight to the head. A LayerNorm applied to `x` at
        # this point would have no effect, because `x` is not used again.
        # Normalize `latents` here if a final normalization is wanted.
        flat = rearrange(latents, "b s d -> b (s d)", b=x.shape[0])  # (B, emb_dim)
        return nn.Dense(self.extra_params)(flat)
        




class CVit2DLatent(nn.Module):
    """Continuous ViT that decodes query coordinates and also returns a pooled latent output.

    Call args:
        x: (B, W, H, C) input image.
        coords: query coordinates. In "grid" mode they have shape (P, 2) and
            are matched against a regular grid on [0, 1]^2.
        z: conditioning input; required when ``cond_enc == "cond"``.

    Returns:
        (decoded, flux). ``decoded`` is (B, P, out_dim). ``flux`` is the
        (B, exp_dim) output of the "latent_flux" AttnLatentFlux head applied to
        the encoder tokens.
    """

    patch_size: tuple = (16, 16)
    grid_size: tuple = (128, 128)
    latent_dim: int = 256

    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1

    dec_emb_dim: int = 256
    dec_num_heads: int = 8
    dec_depth: int = 1

    num_mlp_layers: int = 1
    out_dim: int = 1

    embedding_type: str = "grid"
    eps: float = 1e5
    layer_norm_eps: float = 1e-5

    exp_dim: int = 1
    exp_depth: int = 1

    cond_enc: str = "cond"

    def setup(self):
        if self.embedding_type == "grid":
            # Regular (n_x * n_y, 2) grid on [0, 1]^2 with one learned latent per node.
            n_x, n_y = self.grid_size[0], self.grid_size[1]

            x = jnp.linspace(0, 1.0, n_x)
            y = jnp.linspace(0, 1.0, n_y)
            xx, yy = jnp.meshgrid(x, y, indexing="ij")

            self.grid = jnp.hstack([xx.flatten()[:, None], yy.flatten()[:, None]])
            self.latents = self.param("latents", normal(), (n_x * n_y, self.latent_dim))

    @nn.compact
    def __call__(self, x, coords, z=None):
        b = x.shape[0]

        if self.embedding_type == "grid":
            # Squared distance from every query point to every grid node: (P, G).
            d2 = ((coords[:, jnp.newaxis, :] - self.grid[jnp.newaxis, :, :]) ** 2).sum(
                axis=2
            )
            # softmax(-eps * d2) equals exp(-eps*d2) / sum(exp(-eps*d2)) but
            # subtracts the row maximum first. Without that, every weight
            # underflows to 0 (giving 0/0 = NaN) once eps * d2 is large,
            # e.g. for coarse grids with eps = 1e5.
            ww = jax.nn.softmax(-self.eps * d2, axis=1)

            coords = jnp.einsum("ic,pi->pc", self.latents, ww)
            coords = nn.Dense(self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        elif self.embedding_type == "fourier":
            coords = FourierEmbs(embed_scale=2 * jnp.pi, embed_dim=self.dec_emb_dim)(
                coords
            )

        elif self.embedding_type == "mlp":
            coords = MlpBlock(self.dec_emb_dim, self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        coords = einops.repeat(coords, "n d -> b n d", b=b)

        # Keyword arguments: the encoders' sixth field is `out_dim`, so passing
        # layer_norm_eps positionally would silently drop it.
        if self.cond_enc == "cond":
            if z is None:
                raise ValueError("cond_enc='cond' requires a conditioning input z")
            x = ConditionalEncoder(
                patch_size=self.patch_size,
                emb_dim=self.emb_dim,
                depth=self.depth,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x, z)

        elif self.cond_enc == "no_cond":
            x = Encoder(
                patch_size=self.patch_size,
                emb_dim=self.emb_dim,
                depth=self.depth,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x)

        else:
            raise ValueError(
                f"Unknown cond_enc {self.cond_enc!r}; expected 'cond' or 'no_cond'"
            )

        # The latent head's width (64) and head count (8) are deliberately
        # fixed rather than tied to emb_dim / num_heads.
        flux = AttnLatentFlux(
            emb_dim=64,
            depth=self.exp_depth,
            num_heads=8,
            mlp_ratio=self.mlp_ratio,
            layer_norm_eps=self.layer_norm_eps,
            extra_params=self.exp_dim,
            name="latent_flux",
        )(x)

        # Normalize the encoder tokens after the latent head has read them.
        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(x)

        x = nn.Dense(self.dec_emb_dim)(x)

        for _ in range(self.dec_depth):
            coords = CrossAttnBlock(
                num_heads=self.dec_num_heads,
                emb_dim=self.dec_emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(coords, x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        x = Mlp(
            num_layers=self.num_mlp_layers,
            hidden_dim=self.dec_emb_dim,
            out_dim=self.out_dim,
            layer_norm_eps=self.layer_norm_eps,
            name="mlp",
        )(x)

        return x, flux



class CVit2D(nn.Module):
    """Continuous ViT: encode an image, then decode arbitrary query coordinates.

    Call args:
        x: (B, W, H, C) input image.
        coords: query coordinates. In "grid" mode they have shape (P, 2) and
            are matched against a regular grid on [0, 1]^2.
        z: conditioning input; required when ``cond_enc == "cond"``.

    Returns:
        (B, P, out_dim) predictions, one per query coordinate.
    """

    patch_size: tuple = (16, 16)
    grid_size: tuple = (128, 128)
    latent_dim: int = 256
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    dec_emb_dim: int = 256
    dec_num_heads: int = 8
    dec_depth: int = 1
    num_mlp_layers: int = 1
    mlp_ratio: int = 1
    out_dim: int = 1
    eps: float = 1e5
    layer_norm_eps: float = 1e-5
    embedding_type: str = "grid"

    cond_enc: str = "cond"

    def setup(self):
        if self.embedding_type == "grid":
            # Regular (n_x * n_y, 2) grid on [0, 1]^2 with one learned latent per node.
            n_x, n_y = self.grid_size[0], self.grid_size[1]

            x = jnp.linspace(0, 1.0, n_x)
            y = jnp.linspace(0, 1.0, n_y)
            xx, yy = jnp.meshgrid(x, y, indexing="ij")

            self.grid = jnp.hstack([xx.flatten()[:, None], yy.flatten()[:, None]])
            self.latents = self.param("latents", normal(), (n_x * n_y, self.latent_dim))

    @nn.compact
    def __call__(self, x, coords, z=None):
        b = x.shape[0]

        if self.embedding_type == "grid":
            # Squared distance from every query point to every grid node: (P, G).
            d2 = ((coords[:, jnp.newaxis, :] - self.grid[jnp.newaxis, :, :]) ** 2).sum(
                axis=2
            )
            # softmax(-eps * d2) equals exp(-eps*d2) / sum(exp(-eps*d2)) but
            # subtracts the row maximum first. Without that, every weight
            # underflows to 0 (giving 0/0 = NaN) once eps * d2 is large,
            # e.g. for coarse grids with eps = 1e5.
            w = jax.nn.softmax(-self.eps * d2, axis=1)

            coords = jnp.einsum("ic,pi->pc", self.latents, w)
            coords = nn.Dense(self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        elif self.embedding_type == "fourier":
            coords = FourierEmbs(embed_scale=2 * jnp.pi, embed_dim=self.dec_emb_dim)(
                coords
            )

        elif self.embedding_type == "mlp":
            coords = MlpBlock(self.dec_emb_dim, self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        coords = einops.repeat(coords, "n d -> b n d", b=b)

        # Keyword arguments: the encoders' sixth field is `out_dim`, so passing
        # layer_norm_eps positionally would silently drop it.
        if self.cond_enc == "cond":
            if z is None:
                raise ValueError("cond_enc='cond' requires a conditioning input z")
            x = ConditionalEncoder(
                patch_size=self.patch_size,
                emb_dim=self.emb_dim,
                depth=self.depth,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x, z)

        elif self.cond_enc == "no_cond":
            x = Encoder(
                patch_size=self.patch_size,
                emb_dim=self.emb_dim,
                depth=self.depth,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x)

        else:
            raise ValueError(
                f"Unknown cond_enc {self.cond_enc!r}; expected 'cond' or 'no_cond'"
            )

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(x)

        x = nn.Dense(self.dec_emb_dim)(x)

        # Query-coordinate tokens cross-attend to the encoded image tokens.
        for _ in range(self.dec_depth):
            coords = CrossAttnBlock(
                num_heads=self.dec_num_heads,
                emb_dim=self.dec_emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(coords, x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        x = Mlp(
            num_layers=self.num_mlp_layers,
            hidden_dim=self.dec_emb_dim,
            out_dim=self.out_dim,
            layer_norm_eps=self.layer_norm_eps,
            name="mlp",
        )(x)

        return x