from functools import partial
from typing import Literal, Optional, Union, Tuple, Callable, List, Any
import einops
import equinox as eqx
import jax.random as random
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
from Models.models.attention import MultiheadAttention
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
import abc
from Models.models.wavenet import *
from diffusion_crf.timeseries import TimeSeries
from Models.models.base import AbstractEncoderDecoderModel

"""
This module implements a hybrid architecture combining transformer attention mechanisms
with WaveNet-style convolutional networks. It provides:

1. TransformerWaveNetResBlock - A residual block combining self-attention, optional
   cross-attention, and dilated convolutions for powerful sequence modeling

2. TransformerWaveNet - A complete neural network architecture that combines multiple
   transformer-wavenet blocks with projection layers for processing time series data

This hybrid approach leverages the strengths of both architectures: transformers excel at
capturing long-range dependencies while WaveNet's dilated convolutions efficiently model
local patterns. The implementation supports both causal (autoregressive) and non-causal
variants for different time series modeling tasks.
"""

class TransformerWaveNetResBlockHypers(AbstractHyperParams):
  """Hyperparameters for a single TransformerWaveNetResBlock."""
  # Kernel width of the convolution inside the block's WaveNetResBlock.
  wavenet_kernel_width: int = 8
  # Number of heads for the block's self-attention (and cross-attention, if any).
  num_transformer_heads: int = 4

class TransformerWaveNetResBlock(AbstractBatchableObject):
  """Pre-norm residual block: self-attention, optional cross-attention, then WaveNet.

  The WaveNet residual block takes the place of the feed-forward layer of a
  standard transformer block. Cross-attention over a conditioning sequence is
  only available when the block is constructed with ``cond_channels``.
  """

  layernorm1: eqx.nn.LayerNorm  # applied before self-attention
  layernorm2: eqx.nn.LayerNorm  # applied before the WaveNet block
  multiheaded_attention: MultiheadAttention
  wavenet_block: WaveNetResBlock

  # Both are None unless the block was built with `cond_channels`.
  cross_attention: Optional[MultiheadAttention]
  layernorm3: Optional[eqx.nn.LayerNorm]  # applied before cross-attention

  def __init__(self,
               in_channels: int,
               cond_channels: Optional[int] = None,
               hypers: Optional[TransformerWaveNetResBlockHypers] = TransformerWaveNetResBlockHypers(),
               *,
               causal: bool = True,
               key: PRNGKeyArray):
    """Build the block.

    Args:
      in_channels: Channel size of the sequence processed by the block.
      cond_channels: Channel size of the conditioning sequence. If None, the
        block has no cross-attention.
      hypers: Block hyperparameters. None falls back to the defaults.
      causal: Passed to the self-attention layer. Cross-attention is always
        built with ``causal=False``.
      key: PRNG key used to initialize all sublayers.
    """
    if hypers is None:
      hypers = TransformerWaveNetResBlockHypers()

    # Independent keys per randomly-initialized sublayer. The split count is
    # kept at 4 so the keys used for the WaveNet block (k1) and the
    # self-attention (k2) stay the same; k3 is dedicated to cross-attention.
    k1, k2, k3, _ = random.split(key, 4)

    # Layer norms over the channel axis.
    self.layernorm1 = eqx.nn.LayerNorm(shape=(in_channels,))
    self.layernorm2 = eqx.nn.LayerNorm(shape=(in_channels,))

    # Self-attention.
    self.multiheaded_attention = MultiheadAttention(num_heads=hypers.num_transformer_heads,
                                                    query_size=in_channels,
                                                    key_value_size=in_channels,
                                                    output_size=in_channels,
                                                    causal=causal,
                                                    key=k2)

    # A single WaveNet residual block replaces the usual feed-forward layer.
    # NOTE(review): `causal` is not forwarded here; confirm that WaveNetResBlock
    # is causal by construction.
    wavenet_hypers = WaveNetResBlockHypers(kernel_width=hypers.wavenet_kernel_width,
                                           dilation=1,
                                           hidden_channels=2*in_channels)
    self.wavenet_block = WaveNetResBlock(in_channels=in_channels,
                                         hypers=wavenet_hypers,
                                         key=k1)

    # Optional cross-attention over a conditioning sequence.
    if cond_channels is not None:
      self.cross_attention = MultiheadAttention(num_heads=hypers.num_transformer_heads,
                                                query_size=in_channels,
                                                key_value_size=cond_channels,
                                                output_size=in_channels,
                                                causal=False,
                                                key=k3)
      self.layernorm3 = eqx.nn.LayerNorm(shape=(in_channels,))
    else:
      self.cross_attention = None
      self.layernorm3 = None

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    return self.wavenet_block.batch_size

  def __call__(self, xts: Float[Array, 'T C'], hts: Optional[Float[Array, 'S C']] = None) -> Float[Array, 'T C']:
    """Apply the block to a sequence.

    Args:
      xts: Input sequence. The layer norms are vmapped over its first axis,
        so each row is one step with ``in_channels`` features.
      hts: Optional conditioning sequence used as key/value for
        cross-attention. Its feature size must be ``cond_channels``.

    Returns:
      The first output of the WaveNet block.

    Raises:
      ValueError: If ``hts`` is given but the block has no cross-attention.
    """
    # Self-attention sublayer (pre-norm, residual).
    yts = jax.vmap(self.layernorm1)(xts)
    yts = self.multiheaded_attention(query=yts, key_and_value=yts)
    xts = yts + xts

    # Cross-attention sublayer (pre-norm, residual). The query is the
    # normalized residual stream, mirroring the self-attention sublayer.
    if hts is not None:
      if self.cross_attention is None:
        raise ValueError("hts was provided but this block was built without "
                         "cond_channels, so it has no cross-attention.")
      yts = jax.vmap(self.layernorm3)(xts)
      yts = self.cross_attention(query=yts, key_and_value=hts)
      xts = yts + xts

    # WaveNet sublayer in place of the feed-forward.
    yts = jax.vmap(self.layernorm2)(xts)

    # The WaveNet block applies its own skip/residual connection, so its first
    # output is returned directly. NOTE(review): if that residual is taken
    # relative to the normalized `yts` rather than `xts`, the un-normalized
    # residual stream is not carried past this point. Confirm that is intended.
    out, _ = self.wavenet_block(yts)
    return out

################################################################################################################

class TransformerWaveNetHypers(AbstractHyperParams):
  """Hyperparameters for a full TransformerWaveNet."""

  # Number of TransformerWaveNetResBlocks in the stack.
  n_blocks: int
  # Channel size produced by the input projection and used by every block.
  hidden_channel_size: int = 32

  # Kernel width of the input projection conv and of each block's WaveNet conv.
  wavenet_kernel_width: int = 8
  # Number of attention heads in each block.
  num_transformer_heads: int = 4


class TransformerWaveNet(AbstractBatchableObject):
  """Input projection, a stack of TransformerWaveNetResBlocks, and an output projection.

  The blocks share one structure, so they are stored as a single stacked module
  and applied sequentially with ``jax.lax.scan``.
  """

  in_projection_conv: CausalConv1d
  # ONE TransformerWaveNetResBlock whose array leaves carry a leading axis of
  # size `n_blocks` (built by a vmap in __init__, unstacked by lax.scan in
  # __call__). This is not a Python list.
  blocks: TransformerWaveNetResBlock
  out_projection_conv: CausalConv1d

  strict_autoregressive: bool = eqx.field(static=True)

  def __init__(self,
               in_channels: int,
               out_channels: int,
               hypers: TransformerWaveNetHypers,
               key: PRNGKeyArray,
               cond_channels: Optional[int] = None,
               causal: bool = True,
               strict_autoregressive: bool = False):
    """Build the network.

    Args:
      in_channels: Channel size of the input sequence.
      out_channels: Channel size of the output sequence.
      hypers: Network hyperparameters.
      key: PRNG key for all parameter initialization.
      cond_channels: If given, every block gets cross-attention over a
        conditioning sequence with this many channels.
      causal: Forwarded to every block (controls its self-attention).
      strict_autoregressive: If True, ``__call__`` shifts the input right by
        one step and drops the last output step.
    """
    k1, k2, _, k4 = random.split(key, 4)

    # Input projection to the hidden width.
    initial_hypers = CausalConv1dHypers(kernel_width=hypers.wavenet_kernel_width,
                                        stride=1,
                                        dilation=1,
                                        use_bias=True)
    self.in_projection_conv = CausalConv1d(in_channels=in_channels,
                                           out_channels=hypers.hidden_channel_size,
                                           hypers=initial_hypers,
                                           key=k1)

    # Every block uses the same hyperparameters; only the key differs.
    block_hypers = TransformerWaveNetResBlockHypers(wavenet_kernel_width=hypers.wavenet_kernel_width,
                                                    num_transformer_heads=hypers.num_transformer_heads)

    def make_block(block_key: PRNGKeyArray) -> TransformerWaveNetResBlock:
      return TransformerWaveNetResBlock(in_channels=hypers.hidden_channel_size,
                                        cond_channels=cond_channels,
                                        hypers=block_hypers,
                                        causal=causal,
                                        key=block_key)

    # filter_vmap batches array leaves along axis 0 and passes any non-array
    # leaves through unbatched. Plain jax.vmap would reject non-array leaves.
    keys = random.split(k2, hypers.n_blocks)
    self.blocks = eqx.filter_vmap(make_block)(keys)

    # 1x1 output projection.
    conv1x1_hypers = CausalConv1dHypers(kernel_width=1,
                                        stride=1,
                                        dilation=1,
                                        use_bias=True)
    self.out_projection_conv = CausalConv1d(in_channels=hypers.hidden_channel_size,
                                            out_channels=out_channels,
                                            hypers=conv1x1_hypers,
                                            key=k4)

    self.strict_autoregressive = strict_autoregressive

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    return self.in_projection_conv.batch_size

  @auto_vmap
  def __call__(self, x: Float[Array, 'T C'], hts: Optional[Float[Array, 'S C']] = None) -> Float[Array, 'T C']:
    """Run the network on a sequence.

    Args:
      x: Input sequence with time along the first axis.
      hts: Optional conditioning sequence passed to every block.

    Returns:
      Output of the final projection. In strict-autoregressive mode it has the
      same number of time steps as ``x``.
    """
    if self.strict_autoregressive:
      # Prepend one zero step (and drop the last output below). If every layer
      # is causal, output step t then depends only on x[:t].
      x = jnp.pad(x, ((1, 0), (0, 0)))

    hidden = self.in_projection_conv(x)

    # Split stacked block arrays from static structure so lax.scan can slice
    # one block's parameters per iteration.
    params, static = eqx.partition(self.blocks, eqx.is_array)

    def apply_block(carry, block_params):
      block = eqx.combine(block_params, static)
      return block(carry, hts), ()

    hidden, _ = jax.lax.scan(apply_block, hidden, params)

    out = self.out_projection_conv(hidden)

    if self.strict_autoregressive:
      out = out[:-1]

    return out

################################################################################################################
