import flax.typing
import jax
import jax.numpy as jnp
from einops import rearrange, repeat
from flax import linen as nn

from ol.utils import Array, shuffle_arrays


class Extender(nn.Module):
  """
  Abstract base class for a boundary function extender.

  An extender projects a boundary function to a domain function.
  Subclasses are expected to override `setup` and `call`; `__call__`
  simply forwards its arguments to `call`.
  """

  def setup(self):
    # Concrete extenders must define their own submodules.
    raise NotImplementedError

  def __call__(self,
    f_boundary: Array,
    f_domain: Array,
    **kwargs,
  ) -> Array:
    """Forward `f_boundary`, `f_domain` and any keyword arguments to `call`."""
    return self.call(f_boundary, f_domain, **kwargs)

  def call(self, f_boundary, f_domain, **kwargs) -> Array:
    """Compute the extension; must be implemented by subclasses."""
    raise NotImplementedError

  @property
  def configs(self):
    """
    Return a dict mapping every annotated attribute of the class
    (except 'parent') to its current value on this instance.
    """
    # Use regular attribute lookup instead of invoking the `__getattr__`
    # dunder directly: the dunder is only a fallback hook, it skips the
    # normal lookup and may run lazy module setup as a side effect.
    return {
      attr: getattr(self, attr)
      for attr in self.__annotations__ if attr != 'parent'
    }

class FeedForward(nn.Module):
  """
  Two-layer perceptron: Dense -> swish -> Dropout -> Dense.

  The hidden layer has `features * mult` units and the output layer has
  `features` units.
  """
  # Output width; when None, the size of the last axis of the input is used.
  features: int = None
  # Multiplier applied to the output width to get the hidden width.
  mult: int = 4
  # Rate passed to the dropout layer that follows the activation.
  dropout: float = 0.0

  @nn.compact
  def __call__(self, x, deterministic=False):
    """
    Apply the block to `x`.

    `deterministic` is forwarded to the dropout layer.
    """
    # Fall back to the input width when no explicit width is configured.
    out_features = x.shape[-1] if self.features is None else self.features

    hidden = nn.Dense(out_features * self.mult)(x)
    hidden = nn.Dropout(self.dropout)(nn.swish(hidden), deterministic=deterministic)
    return nn.Dense(out_features)(hidden)

class Attention(nn.Module):
  """
  Multi-head scaled dot-product attention.

  Queries come from `x`; keys and values come from `context`, or from `x`
  itself when no context is given. The head-averaged attention weights are
  sown into the 'intermediates' collection under the name 'scores'.
  """
  # Number of attention heads.
  heads: int = 8
  # Feature size of each head.
  head_dim: int = 64
  # Rate of the dropout applied to the projected output.
  dropout: float = 0.0

  @nn.compact
  def __call__(self, x, context=None, mask=None, deterministic=False):
    """
    Attend from `x` to `context`.

    Args:
      x: query input, rank 3 (batch, n, features).
      context: optional key/value input, rank 3 (batch, m, features);
        `x` is used when it is None.
      mask: optional (batch, m) array; it is cast to bool and handed to the
        softmax as its `where` argument, shared by all heads and queries.
      deterministic: forwarded to the dropout layer.

    Returns:
      Array of shape (batch, n, x.shape[-1]).
    """
    inner_dim = self.heads * self.head_dim
    source = x if context is None else context

    # Linear projections: one for the queries, a fused one for keys and values
    queries = nn.Dense(inner_dim, use_bias=False)(x)
    keys_values = nn.Dense(inner_dim * 2, use_bias=False)(source)
    keys, values = jnp.split(keys_values, 2, axis=-1)

    # Fold the heads into the batch axis
    def split_heads(arr):
      return rearrange(arr, 'b n (h d) -> (b h) n d', h=self.heads)

    queries, keys, values = split_heads(queries), split_heads(keys), split_heads(values)

    # Scaled dot-product scores
    scores = jnp.einsum('b i d, b j d -> b i j', queries, keys) * self.head_dim ** -0.5

    # Broadcast the mask over heads and query positions
    if mask is None:
      where = None
    else:
      where = repeat(mask, 'b m -> (b h) n m', h=self.heads, n=scores.shape[1]).astype('bool')
    weights = nn.softmax(scores, where=where, axis=-1)

    # Record the attention weights averaged over the heads
    per_head = rearrange(weights, '(b h) i j -> b h i j', h=self.heads)
    self.sow(col='intermediates', name='scores', value=per_head.mean(axis=1))

    # Weighted sum of the values, then merge the heads back together
    attended = jnp.einsum('b i j, b j d -> b i d', weights, values)
    attended = rearrange(attended, '(b h) n d -> b n (h d)', h=self.heads)

    # Project back to the input width and apply dropout
    projected = nn.Dense(x.shape[-1])(attended)
    return nn.Dropout(self.dropout)(projected, deterministic=deterministic)

class CrossAttentionExtender(Extender):
  """
  Extender built from a stack of cross-attention blocks.

  Boundary and domain features are first embedded to `latent_dim`. Then, in
  each of the `depth` blocks, the domain features attend to the boundary
  features (residual + LayerNorm) and go through a feed-forward layer
  (residual + LayerNorm). A final feed-forward layer projects the domain
  features to `out_dim`.
  """
  # Feature size of the returned domain function.
  out_dim: int = 4
  # Width of the embeddings; also used as the per-head dimension.
  latent_dim: int = 16
  # Number of cross-attention blocks.
  depth: int = 4
  # Number of attention heads per block.
  n_heads: int = 2
  # Hidden-width multiplier of every feed-forward layer.
  ff_mult: int = 1
  # Fraction of boundary nodes discarded in each block when not deterministic.
  p_masking: float = 0.0
  # Dropout rate of the attention layers.
  attn_dropout: float = 0.0
  # Dropout rate of the feed-forward layers.
  ff_dropout: float = 0.0

  def setup(self):
    # Embeddings of the boundary and of the domain features
    self.ff_initial_boundary = nn.Sequential([
      FeedForward(mult=self.ff_mult, dropout=self.ff_dropout, features=self.latent_dim),
      nn.LayerNorm(),
    ], name='ff_initial_boundary')
    self.ff_initial_domain = nn.Sequential([
      FeedForward(mult=self.ff_mult, dropout=self.ff_dropout, features=self.latent_dim),
      nn.LayerNorm(),
    ], name='ff_initial_domain')
    # One attention layer, one feed-forward layer and two LayerNorms per block
    self.attention_layers = [
      Attention(heads=self.n_heads, head_dim=self.latent_dim, dropout=self.attn_dropout, name=f'attention_{i}')
      for i in range(self.depth)
    ]
    self.attention_lns = [nn.LayerNorm() for _ in range(self.depth)]
    self.ff_layers = [
      FeedForward(mult=self.ff_mult, dropout=self.ff_dropout, name=f'ff_{i}')
      for i in range(self.depth)
    ]
    self.ff_lns = [nn.LayerNorm() for _ in range(self.depth)]
    # Output projection
    self.ff_final = FeedForward(mult=self.ff_mult, dropout=self.ff_dropout, features=self.out_dim, name='ff_final')

  def call(self, f_boundary: Array, f_domain: Array, m_boundary: Array = None, deterministic: bool = False):
    """
    Extend the boundary function to the domain.

    Args:
      f_boundary: boundary features, rank 3 with the boundary nodes on axis 1.
      f_domain: domain features, rank 3 with the domain nodes on axis 1.
      m_boundary: optional mask over the boundary nodes (batch, nodes),
        forwarded to the attention layers. May be None in both modes.
      deterministic: when False, dropout is active and each block only sees a
        random subset of `int(n * (1 - p_masking))` boundary nodes; this
        requires a 'masking' RNG stream.

    Returns:
      The domain features projected to `out_dim`.
    """
    # Embed the boundary features
    f_boundary = self.ff_initial_boundary(f_boundary, deterministic=deterministic)
    # Embed the domain features
    f_domain = self.ff_initial_domain(f_domain, deterministic=deterministic)

    # Number of boundary nodes kept per block; the boundary embedding is not
    # modified inside the loop, so this is computed once.
    size_boundary_masked = int(f_boundary.shape[1] * (1 - self.p_masking))
    has_mask = m_boundary is not None

    # Cross-attention blocks
    for i in range(self.depth):
      # Randomly mask some of the boundary nodes
      if deterministic:
        _f_boundary, _m_boundary = f_boundary, m_boundary
      else:
        rngkey = self.make_rng('masking')
        # Without a user mask, shuffle an all-True placeholder so the call to
        # shuffle_arrays keeps the same form, then pass None to attention.
        # NOTE(review): shuffle_arrays is presumed to permute all arrays
        # consistently along `axis` - confirm against ol.utils.
        mask_to_shuffle = m_boundary if has_mask else jnp.ones(f_boundary.shape[:2], dtype=bool)
        _f_boundary, _m_boundary = shuffle_arrays(rngkey=rngkey, arrays=[f_boundary, mask_to_shuffle], axis=1)
        _f_boundary = _f_boundary[:, :size_boundary_masked]
        _m_boundary = _m_boundary[:, :size_boundary_masked] if has_mask else None
      f_domain = f_domain + self.attention_layers[i](f_domain, _f_boundary, _m_boundary, deterministic=deterministic)
      f_domain = self.attention_lns[i](f_domain)
      f_domain = f_domain + self.ff_layers[i](f_domain, deterministic=deterministic)
      f_domain = self.ff_lns[i](f_domain)

    # Project the output to the desired dimension
    f_domain = self.ff_final(f_domain, deterministic=deterministic)

    return f_domain
