import jax.numpy as jnp
from jax.nn.initializers import normal, xavier_uniform
from jax import vmap

import flax.linen as nn
import itertools
from typing import Callable

import einops
from einops import rearrange, repeat







class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Attributes:
        eps: Constant added to the mean square (inside the square root) to
            keep the division well defined.
    """

    eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        """Return ``scale * x / sqrt(mean(x**2, axis=-1) + eps)``."""
        feature_dim = x.shape[-1]
        # One gain per feature, initialised to ones.
        gain = self.param("scale", nn.initializers.ones, feature_dim)
        mean_square = jnp.mean(x**2, axis=-1, keepdims=True)
        denominator = jnp.sqrt(mean_square + self.eps)
        return gain * x / denominator




class PairwiseNLFluxFieldsNorm(nn.Module):
    """MLP evaluated on every unordered pair of POU entries, in both orderings.

    Each sample is a flat vector laid out as
    ``[fields | extra parameters | oriented areas]``.  The field part is
    regrouped per POU entry, every pair ``(i, j)`` with ``i < j`` is formed,
    and the same network is applied to the ``(i, j)`` and ``(j, i)``
    orderings; the two head outputs are summed.

    Attributes:
        Npou: Number of POU entries per sample.
        extra_params: Number of extra scalar parameters following the fields.
        num_dimensions: Number of oriented-area components per pair.
        num_fields: Number of fields stored per POU entry.
        num_hidden_layers: Number of hidden dense layers.
        hidden_layer_width: Width of the tail and hidden dense layers.
        layer_norm_eps: Not read by this module (kept for interface
            compatibility).
        oriented_areas: Whether the input carries a trailing block of
            ``num_dimensions * Npou * (Npou - 1) / 2`` oriented-area values.
        use_norm: Apply an ``RMSNorm`` after every hidden layer.
        use_res: Add a residual connection around every hidden layer.
        activation: Non-linearity used after the tail and hidden layers.
        kernel_init: Kernel initialiser for all dense layers.
    """

    Npou: int = 5
    extra_params: int = 0
    num_dimensions: int = 1
    num_fields: int = 1
    num_hidden_layers: int = 3
    hidden_layer_width: int = 32
    layer_norm_eps: float = 1e-5
    oriented_areas: bool = False
    use_norm: bool = True
    use_res: bool = True
    activation: Callable = nn.gelu
    kernel_init: Callable = xavier_uniform()

    def setup(self):
        """Create the dense stack, optional norms and the pair index table."""
        self.tail = nn.Dense(self.hidden_layer_width, kernel_init=self.kernel_init)
        self.hidden_layers = [
            nn.Dense(self.hidden_layer_width, kernel_init=self.kernel_init)
            for _ in range(self.num_hidden_layers)
        ]

        if self.use_norm:
            self.norm = [RMSNorm() for _ in range(self.num_hidden_layers)]

        self.head = nn.Dense(self.num_fields, kernel_init=self.kernel_init)

        num_pairs = self.Npou * (self.Npou - 1) // 2
        if self.oriented_areas:
            self.num_one_forms = self.num_dimensions * num_pairs
        else:
            self.num_one_forms = 0

        # Row k holds the indices (i, j), i < j, of the k-th pair.
        self.indexer = jnp.array(list(itertools.combinations(range(self.Npou), 2)))

    @nn.compact
    def __call__(self, x):
        """Evaluate the pairwise network.

        Args:
            x: Batch of flat input vectors; axis 0 is the batch axis and each
                row follows the layout described in the class docstring.

        Returns:
            ``head(f(pairs)) + head(f(flipped pairs))`` with singleton axes
            squeezed out.
        """
        batched_pairing = vmap(self.pairing, in_axes=(0, None))
        x1 = batched_pairing(x, False)
        x2 = batched_pairing(x, True)

        x1 = self.activation(self.tail(x1))
        x2 = self.activation(self.tail(x2))

        for i, hidden_layer in enumerate(self.hidden_layers):
            y1 = self.activation(hidden_layer(x1))
            y2 = self.activation(hidden_layer(x2))

            if self.use_res:
                x1 = x1 + y1
                x2 = x2 + y2
            else:
                x1 = y1
                x2 = y2

            if self.use_norm:
                x1 = self.norm[i](x1)
                x2 = self.norm[i](x2)

        return jnp.squeeze(self.head(x1) + self.head(x2))

    def pairing(self, x, flip=False):
        """Build the per-pair feature rows for a single (unbatched) sample.

        Args:
            x: 1-D vector ``[fields | extra parameters | oriented areas]``.
            flip: When true, the two members of every pair are swapped.

        Returns:
            Array with one row per pair containing the fields of both pair
            members, then the pair's oriented-area components (if enabled),
            then the extra parameters (identical on every row).
        """
        # Slice with explicit non-negative offsets.  Negative bounds such as
        # ``x[:-0]`` or ``x[-k:-0]`` evaluate to an empty slice, which used to
        # discard the fields / extra parameters whenever ``num_one_forms``
        # (or the combined tail length) was zero.
        # NOTE: configurations with extra_params > 0 and oriented_areas=False
        # now actually feed the extra parameters to the network, which widens
        # the input of `tail` compared to what the broken slicing produced.
        total = x.shape[0]
        fields_end = total - self.extra_params - self.num_one_forms
        extras_end = total - self.num_one_forms

        x_short = x[:fields_end]
        extra_params_vector = x[fields_end:extras_end]

        # the field values are stacked in the input like
        # (F0_0, F0_1, ..., F0_Npou, F1_0, F1_1, ..., F1_Npou, ...);
        # reshape so that each row contains all the fields for a given POU
        x_short = jnp.reshape(x_short, (self.num_fields, -1)).T

        x_short = jnp.where(
            flip,
            x_short[jnp.flip(self.indexer, axis=1)],
            x_short[self.indexer],
        )
        # (pairs, 2, num_fields) -> (pairs, 2 * num_fields)
        x_short = jnp.reshape(x_short, (-1, 2 * self.num_fields))

        if self.oriented_areas:
            oriented_areas = x[extras_end:]
            # One row of `num_dimensions` components per pair (this is a
            # trailing unit axis when num_dimensions == 1).
            x = jnp.concatenate(
                [x_short, oriented_areas.reshape(-1, self.num_dimensions)], axis=1
            )
        else:
            x = x_short

        # and tack the extra parameters back on, repeated for every pair
        x = jnp.concatenate(
            [x, jnp.tile(extra_params_vector, (x.shape[0], 1))], axis=1
        )
        return x



def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Sinusoidal embedding of a set of positions.

    Args:
        embed_dim: Output width ``D``; must be even.
        pos: Array of positions of any shape; it is flattened to ``(M,)``.

    Returns:
        Array of shape ``(M, D)`` whose first ``D/2`` columns are the sines and
        whose last ``D/2`` columns are the cosines of ``pos * omega``, with
        ``omega_k = 1 / 10000 ** (k / (D/2))``.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    exponents = jnp.arange(half_dim, dtype=jnp.float32) / (embed_dim / 2.0)
    frequencies = 1.0 / 10000**exponents  # (D/2,)

    flat_positions = pos.reshape(-1)  # (M,)
    # Outer product of positions and frequencies -> (M, D/2).
    angles = jnp.einsum("m,d->md", flat_positions, frequencies)

    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=1)  # (M, D)


def get_1d_sincos_pos_embed(embed_dim, length):
    """Sinusoidal embedding table for positions ``0 .. length - 1``.

    Args:
        embed_dim: Embedding width (must be even).
        length: Number of consecutive integer positions to embed.

    Returns:
        Array of shape ``(1, length, embed_dim)``; the leading unit axis lets
        the table broadcast over a batch.
    """
    positions = jnp.arange(length, dtype=jnp.float32)
    table = get_1d_sincos_pos_embed_from_grid(embed_dim, positions)
    return jnp.expand_dims(table, 0)


class PatchEmbed1D(nn.Module):
    """Embed a 1-D signal into tokens with a strided convolution.

    Attributes:
        patch_size: One-element tuple; its entry is used both as the kernel
            width and as the stride of the convolution.
        emb_dim: Number of output features per token.
        use_norm: Apply a ``LayerNorm`` to the tokens.
        kernel_init: Kernel initialiser of the convolution.
        layer_norm_eps: Epsilon of the optional ``LayerNorm``.
    """

    patch_size: tuple = (4,)
    emb_dim: int = 768
    use_norm: bool = False
    kernel_init: Callable = xavier_uniform()
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        """Project a rank-3 input to ``emb_dim`` features per patch."""
        # Unpacking fails early unless the input has exactly three axes.
        _batch, _length, _channels = x.shape

        window = (self.patch_size[0],)
        tokens = nn.Conv(
            features=self.emb_dim,
            kernel_size=window,
            strides=window,
            kernel_init=self.kernel_init,
            name="proj",
        )(x)

        if not self.use_norm:
            return tokens
        return nn.LayerNorm(name="norm", epsilon=self.layer_norm_eps)(tokens)


class MlpBlock(nn.Module):
    """Two dense layers with a GELU in between.

    Attributes:
        dim: Width of the first (hidden) dense layer.
        out_dim: Width of the second (output) dense layer.
        kernel_init: Kernel initialiser shared by both layers.
    """

    dim: int = 256
    out_dim: int = 256
    kernel_init: Callable = xavier_uniform()

    @nn.compact
    def __call__(self, inputs):
        """Return ``Dense(out_dim)(gelu(Dense(dim)(inputs)))``."""
        # Construct in this order so the layers keep their automatic names.
        expand = nn.Dense(self.dim, kernel_init=self.kernel_init)
        project = nn.Dense(self.out_dim, kernel_init=self.kernel_init)
        return project(nn.gelu(expand(inputs)))


class SelfAttnBlock(nn.Module):
    """Pre-norm self-attention block followed by a pre-norm MLP block.

    Attributes:
        num_heads: Number of attention heads.
        emb_dim: Passed to the attention layer as ``qkv_features`` and used as
            the output width of the MLP.
        mlp_ratio: Hidden width of the MLP is ``emb_dim * mlp_ratio``.
        layer_norm_eps: Epsilon of both ``LayerNorm`` layers.
    """

    num_heads: int
    emb_dim: int
    mlp_ratio: int
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, inputs):
        """Apply attention and MLP sub-blocks, each with a residual add."""
        normed = nn.LayerNorm(epsilon=self.layer_norm_eps)(inputs)
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.emb_dim
        )
        # Queries and keys/values both come from the normalised input.
        attended = attention(normed, normed) + inputs

        hidden = nn.LayerNorm(epsilon=self.layer_norm_eps)(attended)
        hidden = MlpBlock(self.emb_dim * self.mlp_ratio, self.emb_dim)(hidden)

        return attended + hidden


class CrossAttnBlock(nn.Module):
    """Pre-norm cross-attention block followed by a pre-norm MLP block.

    Attributes:
        num_heads: Number of attention heads.
        emb_dim: Passed to the attention layer as ``qkv_features`` and used as
            the output width of the MLP.
        mlp_ratio: Hidden width of the MLP is ``emb_dim * mlp_ratio``.
        layer_norm_eps: Epsilon of all ``LayerNorm`` layers.
    """

    num_heads: int
    emb_dim: int
    mlp_ratio: int
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, q_inputs, kv_inputs):
        """Let ``q_inputs`` attend to ``kv_inputs``; residuals follow the query path."""
        queries = nn.LayerNorm(epsilon=self.layer_norm_eps)(q_inputs)
        context = nn.LayerNorm(epsilon=self.layer_norm_eps)(kv_inputs)

        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads, qkv_features=self.emb_dim
        )
        # The residual uses the un-normalised queries.
        attended = attention(queries, context) + q_inputs

        hidden = nn.LayerNorm(epsilon=self.layer_norm_eps)(attended)
        hidden = MlpBlock(self.emb_dim * self.mlp_ratio, self.emb_dim)(hidden)
        return attended + hidden


class Mlp(nn.Module):
    """Stack of residual dense layers with post-norm and a linear read-out.

    Attributes:
        num_layers: Number of residual dense layers.
        hidden_dim: Width of each residual layer; because of the residual add
            it has to be compatible with the width of the input.
        out_dim: Width of the final dense layer.
        kernel_init: Kernel initialiser of the residual layers (the final
            layer uses the ``nn.Dense`` default).
        layer_norm_eps: Epsilon of the ``LayerNorm`` layers.
    """

    num_layers: int
    hidden_dim: int
    out_dim: int
    kernel_init: Callable = xavier_uniform()
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, inputs):
        """Apply ``num_layers`` residual blocks, then project to ``out_dim``."""
        hidden = inputs
        for _layer in range(self.num_layers):
            dense = nn.Dense(features=self.hidden_dim, kernel_init=self.kernel_init)
            update = nn.gelu(dense(hidden))
            # Post-norm: normalise after adding the residual.
            hidden = nn.LayerNorm(epsilon=self.layer_norm_eps)(hidden + update)

        return nn.Dense(features=self.out_dim)(hidden)


# Initialiser for the encoder's positional-embedding variable: Encoder1D hands
# it to `self.variable` together with the embedding width and token count.
x_emb_init = get_1d_sincos_pos_embed


class Encoder1D(nn.Module):
    """Patch-embed a 1-D signal and run it through self-attention blocks.

    Attributes:
        patch_size: One-element tuple with the patch width.
        emb_dim: Token width.
        depth: Number of ``SelfAttnBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: MLP expansion factor inside each block.
        out_dim: Not read by this module (kept for interface compatibility).
        layer_norm_eps: Epsilon forwarded to every ``SelfAttnBlock``.
    """

    # The default and every use (`patch_size[0]`) treat this as a tuple.
    patch_size: tuple = (4,)
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1
    out_dim: int = 1
    layer_norm_eps: float = 1e-5

    @nn.compact
    def __call__(self, x):
        """Encode a rank-3 input into a sequence of ``emb_dim``-wide tokens."""
        # Unpacking fails early unless the input has exactly three axes.
        _, _, _ = x.shape

        x = PatchEmbed1D(self.patch_size, self.emb_dim)(x)

        # Size the positional table from the tokens actually produced.  The
        # patch convolution pads a trailing partial patch, so for lengths that
        # are not a multiple of the patch width `n // patch_size[0]` would be
        # one short and the addition below would fail.  For divisible lengths
        # both expressions agree.
        num_tokens = x.shape[1]

        x_emb = self.variable(
            "pos_emb",
            "enc_emb",
            x_emb_init,
            self.emb_dim,
            num_tokens,
        )

        x = x + x_emb.value

        for _ in range(self.depth):
            x = SelfAttnBlock(
                num_heads=self.num_heads,
                emb_dim=self.emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(x)

        return x
    


"""
class ConditionalEncoder1D(nn.Module):
    patch_size: int = (4,)
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1
    out_dim: int = 1
    layer_norm_eps: float = 1e-5
    kernel_init: Callable = xavier_uniform()

    @nn.compact
    def __call__(self, x, z):
        _,n,_ = x.shape

        x = PatchEmbed1D(self.patch_size, self.emb_dim)(x)

        x_emb = self.variable(
            "pos_emb",
            "enc_emb",
            x_emb_init,
            self.emb_dim,
            n // self.patch_size[0],
        )

        x = x + x_emb.value

        for _ in range(self.depth):
            
            cond=nn.Dense(self.emb_dim,kernel_init=self.kernel_init)(z)
            cond=Mlp(num_layers=1, hidden_dim=self.emb_dim, out_dim=self.emb_dim )(cond)##2 layers was working perfect
            cond=jnp.concat([x,cond[:,None]],axis=-2)
            
            x = CrossAttnBlock(
                self.num_heads, self.emb_dim, self.mlp_ratio, self.layer_norm_eps
            )(x,cond)

        return x


"""

        
class CVit1D(nn.Module):
    """Encoder over the input signal plus a cross-attention decoder over query coordinates.

    Attributes:
        patch_size: One-element tuple with the encoder patch width.
        grid_size: One-element tuple with the number of latent grid points.
        latent_dim: Width of each learned grid latent.
        emb_dim: Encoder token width.
        depth: Number of encoder self-attention blocks.
        num_heads: Encoder attention heads.
        dec_emb_dim: Decoder token width.
        dec_num_heads: Decoder attention heads.
        dec_depth: Number of decoder cross-attention blocks.
        exp_short_dim: Not read by this module.
        exp_dim: Not read by this module.
        num_mlp_layers: Residual layers in the output MLP.
        mlp_ratio: MLP expansion factor in encoder and decoder blocks.
        out_dim: Number of output channels per query coordinate.
        layer_norm_eps: Epsilon of every ``LayerNorm``.
        embedding_type: ``"grid"`` interpolates learned latents at the query
            coordinates; ``"mlp"`` embeds the coordinates with an ``MlpBlock``.
            Any other value leaves the coordinates unchanged.
        eps: Sharpness of the Gaussian weights of the grid embedding.
    """

    patch_size: tuple = (4,)
    grid_size: tuple = (200,)
    latent_dim: int = 256
    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    dec_emb_dim: int = 256
    dec_num_heads: int = 8
    dec_depth: int = 1

    exp_short_dim: int = 8
    exp_dim: int = 2

    num_mlp_layers: int = 1
    mlp_ratio: int = 1
    out_dim: int = 1
    layer_norm_eps: float = 1e-5
    embedding_type: str = "grid"
    # Previously hard-coded as 1e5; appended last so positional construction
    # of the existing fields is unaffected.
    eps: float = 1e5

    def setup(self):
        """Create the reference grid on [0, 1] and its learned latents."""
        if self.embedding_type == "grid":
            n_x = self.grid_size[0]
            self.grid = jnp.linspace(0, 1, n_x)
            self.latents = self.param("latents", normal(), (n_x, self.latent_dim))

    @nn.compact
    def __call__(self, x, coords):
        """Predict ``out_dim`` values at every query coordinate.

        Args:
            x: Rank-3 input signal; axis 0 is the batch axis.
            coords: Query coordinates shared by the whole batch.  For the
                grid embedding they must broadcast against ``grid[None, :]``
                (e.g. shape ``(num_queries, 1)``).

        Returns:
            Array of shape ``(batch, num_queries, out_dim)``.
        """
        batch, _, _ = x.shape

        if self.embedding_type == "grid":
            d2 = (coords - self.grid[None, :]) ** 2
            # Normalised Gaussian weights over the grid points.  softmax is
            # the same quantity as exp(.) / sum(exp(.)) but stays finite when
            # every exponent underflows (query far from all grid points).
            w = nn.softmax(-self.eps * d2, axis=1)

            coords = jnp.einsum("ic,pi->pc", self.latents, w)
            coords = nn.Dense(self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        elif self.embedding_type == "mlp":
            coords = MlpBlock(self.dec_emb_dim, self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        coords = einops.repeat(coords, "n d -> b n d", b=batch)

        # Keyword arguments are required here: Encoder1D's sixth field is
        # `out_dim`, so passing layer_norm_eps positionally bound it to the
        # wrong field and the encoder silently kept its default epsilon.
        x = Encoder1D(
            patch_size=self.patch_size,
            emb_dim=self.emb_dim,
            depth=self.depth,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            layer_norm_eps=self.layer_norm_eps,
        )(x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(x)

        # Project encoder tokens to the decoder width.
        x = nn.Dense(self.dec_emb_dim)(x)

        # Query embeddings attend to the encoded tokens.
        for _ in range(self.dec_depth):
            coords = CrossAttnBlock(
                num_heads=self.dec_num_heads,
                emb_dim=self.dec_emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(coords, x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        x = Mlp(
            num_layers=self.num_mlp_layers,
            hidden_dim=self.dec_emb_dim,
            out_dim=self.out_dim,
            layer_norm_eps=self.layer_norm_eps,
        )(x)

        return x
    




class AttnLatentFlux(nn.Module):
    """Pool a token sequence into a few scalars with one learned query token.

    Attributes:
        emb_dim: Width of the learned latent query and of the attention/MLP
            features.
        depth: Number of ``CrossAttnBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: MLP expansion factor inside each block.
        layer_norm_eps: Epsilon of the ``LayerNorm`` layers.
        extra_params: Number of output scalars per sample.
    """

    emb_dim: int
    depth: int
    num_heads: int = 8
    mlp_ratio: int = 1
    layer_norm_eps: float = 1e-5
    extra_params: int = 1

    @nn.compact
    def __call__(self, x):  # (B, S, D_in) --> (B, extra_params)
        """Cross-attend a single latent token to ``x`` and project it.

        Args:
            x: Token sequence; axis 0 is the batch axis.

        Returns:
            Array of shape ``(batch, extra_params)``.
        """
        batch = x.shape[0]

        latents = self.param("latents", normal(), (1, self.emb_dim))  # (1, D)
        latents = repeat(latents, "s d -> b s d", b=batch)  # (B, 1, D)

        # The latent token queries the input sequence.
        for _ in range(self.depth):
            latents = CrossAttnBlock(
                self.num_heads, self.emb_dim, self.mlp_ratio, self.layer_norm_eps
            )(latents, x)

        # NOTE(review): this norm used to be applied to `x`, whose value was
        # then overwritten on the next statement, so it had no effect on the
        # output.  It is now applied to the latents that feed the projection,
        # which looks like the intent — confirm.  The norm's parameters now
        # have width `emb_dim` instead of the width of `x`.
        latents = nn.LayerNorm(epsilon=self.layer_norm_eps)(latents)

        flat = rearrange(latents, "b s d -> b (s d)", b=batch)  # (B, D)
        return nn.Dense(self.extra_params)(flat)
    

class CVit1DLatent(nn.Module):
    """``CVit1D``-style model that also returns scalars pooled from the encoder tokens.

    Attributes:
        patch_size: One-element tuple with the encoder patch width.
        grid_size: One-element tuple with the number of latent grid points.
        latent_dim: Width of each learned grid latent.
        emb_dim: Encoder token width.
        depth: Number of encoder self-attention blocks.
        num_heads: Encoder attention heads.
        mlp_ratio: MLP expansion factor in all attention blocks.
        dec_emb_dim: Decoder token width.
        dec_num_heads: Decoder attention heads.
        dec_depth: Number of decoder cross-attention blocks.
        num_mlp_layers: Residual layers in the output MLP.
        out_dim: Number of output channels per query coordinate.
        embedding_type: ``"grid"`` interpolates learned latents at the query
            coordinates; ``"mlp"`` embeds the coordinates with an ``MlpBlock``.
            Any other value leaves the coordinates unchanged.
        eps: Sharpness of the Gaussian weights of the grid embedding.
        layer_norm_eps: Epsilon of every ``LayerNorm``.
        exp_dim: Number of pooled scalars returned alongside the field.
        exp_depth: Not read by this module.
    """

    patch_size: tuple = (4,)
    grid_size: tuple = (200,)
    latent_dim: int = 256

    emb_dim: int = 256
    depth: int = 3
    num_heads: int = 8
    mlp_ratio: int = 1

    dec_emb_dim: int = 256
    dec_num_heads: int = 8
    dec_depth: int = 1

    num_mlp_layers: int = 1

    out_dim: int = 1

    embedding_type: str = "grid"
    eps: float = 1e5
    layer_norm_eps: float = 1e-5

    exp_dim: int = 2
    exp_depth: int = 1

    def setup(self):
        """Create the reference grid on [0, 1] and its learned latents."""
        if self.embedding_type == "grid":
            n_x = self.grid_size[0]
            self.grid = jnp.linspace(0, 1, n_x)
            self.latents = self.param("latents", normal(), (n_x, self.latent_dim))

    @nn.compact
    def __call__(self, x, coords):
        """Predict the field at the query coordinates and the pooled scalars.

        Args:
            x: Rank-3 input signal; axis 0 is the batch axis.
            coords: Query coordinates shared by the whole batch.  For the
                grid embedding they must broadcast against ``grid[None, :]``
                (e.g. shape ``(num_queries, 1)``).

        Returns:
            Tuple ``(field, z)``: ``field`` has shape
            ``(batch, num_queries, out_dim)`` and ``z`` has shape
            ``(batch, exp_dim)``.
        """
        batch, _, _ = x.shape

        if self.embedding_type == "grid":
            d2 = (coords - self.grid[None, :]) ** 2
            # Normalised Gaussian weights over the grid points.  softmax is
            # the same quantity as exp(.) / sum(exp(.)) but stays finite when
            # every exponent underflows (query far from all grid points).
            w = nn.softmax(-self.eps * d2, axis=1)

            coords = jnp.einsum("ic,pi->pc", self.latents, w)
            coords = nn.Dense(self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        elif self.embedding_type == "mlp":
            coords = MlpBlock(self.dec_emb_dim, self.dec_emb_dim)(coords)
            coords = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        coords = einops.repeat(coords, "n d -> b n d", b=batch)

        # Keyword arguments are required here: Encoder1D's sixth field is
        # `out_dim`, so passing layer_norm_eps positionally bound it to the
        # wrong field and the encoder silently kept its default epsilon.
        x = Encoder1D(
            patch_size=self.patch_size,
            emb_dim=self.emb_dim,
            depth=self.depth,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            layer_norm_eps=self.layer_norm_eps,
        )(x)

        # Pooled scalars come from the raw encoder tokens (before the norm
        # below).  Width, depth and head count of this branch are fixed here
        # rather than taken from the module's fields.
        z = AttnLatentFlux(
            emb_dim=64,
            depth=1,
            num_heads=8,
            mlp_ratio=self.mlp_ratio,
            layer_norm_eps=self.layer_norm_eps,
            extra_params=self.exp_dim,
            name="latent_flux",
        )(x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(x)

        # Project encoder tokens to the decoder width.
        x = nn.Dense(self.dec_emb_dim)(x)

        # Query embeddings attend to the encoded tokens.
        for _ in range(self.dec_depth):
            coords = CrossAttnBlock(
                num_heads=self.dec_num_heads,
                emb_dim=self.dec_emb_dim,
                mlp_ratio=self.mlp_ratio,
                layer_norm_eps=self.layer_norm_eps,
            )(coords, x)

        x = nn.LayerNorm(epsilon=self.layer_norm_eps)(coords)

        x = Mlp(
            num_layers=self.num_mlp_layers,
            hidden_dim=self.dec_emb_dim,
            out_dim=self.out_dim,
            layer_norm_eps=self.layer_norm_eps,
            name="head",
        )(x)

        return x, z





'''
class PreModel(nn.Module):
   tr_config: ConfigDict
   in_features: int=128
   out_features: int=1
   activation: Callable = nn.gelu
   

   @nn.compact
   def __call__(self,x,coords):
       x=nn.Dense(self.in_features)(x)[...,None]
       x=self.activation(x)
       x=CVit1D(**self.tr_config)(x,coords)
       x=nn.softmax(x,axis=-1)
       x=nn.Dense(self.out_features)(x)
       return x
   

class PreModel2(nn.Module):
   tr_config: ConfigDict
   in_features: int=128
   out_features: int=1
   activation: Callable = nn.gelu
   
   @nn.compact
   def __call__(self,x,coords):
       z=x
       
       x=nn.Dense(self.in_features)(x)[...,None]
       x=self.activation(x)
       x=CVit1D(**self.tr_config)(x,coords)
       x=nn.softmax(x,axis=-1)
       
       z=nn.Dense(x.shape[-1]*self.out_features)(z)
       z=self.activation(z)
       
       z=jnp.reshape(z,(-1,x.shape[-1],self.out_features))
       x=vmap(jnp.matmul)(x,z)
       
       return x

class CrModel(nn.Module):
    tr_config: ConfigDict
    in_features: int=128
    activation: Callable = nn.gelu
    
    @nn.compact
    def __call__(self,x,coords):
        x=nn.Dense(self.in_features)(x)[...,None]
        x=self.activation(x)
        x=CVit1D(**self.tr_config)(x,coords)
        return x
'''
         

 