from functools import partial
from typing import Literal, Optional, Union, Tuple, Callable, List, Any
import einops
import equinox as eqx
import jax.random as random
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
from diffusion_crf.gaussian import StandardGaussian
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
import abc

"""
This module implements WaveNet-inspired neural network architectures for time series processing.
It provides:

1. Core building blocks:
   - AbstractHyperParams - Base class for hyperparameter configuration
   - CausalConv1d - Causal 1D convolution for sequence modeling
   - WaveNetResBlock - Residual block with dilated convolutions and gating mechanisms

2. Complete architectures:
   - WaveNet - Full implementation of a causal dilated convolutional network
     with summed skip connections and an optional strictly autoregressive
     mode, in which each output depends only on strictly earlier inputs

These components enable the construction of models that can effectively capture
long-range temporal dependencies while maintaining causal relationships in time series data.
The implementation is fully compatible with JAX transformations including automatic
differentiation and vectorization.
"""

class AbstractHyperParams(eqx.Module, abc.ABC):
  """Marker base class for hyper-parameter containers.

  Subclasses are `eqx.Module`s whose fields hold configuration values.
  `partition_hypers` treats every instance of this type as a single leaf so
  hyper-parameters can be split away from the rest of a module's pytree.
  """

def partition_hypers(module: eqx.Module) -> Tuple[eqx.Module, eqx.Module]:
  """Split `module` into its non-hyper-parameter part and its hyper-parameters.

  Every `AbstractHyperParams` instance inside `module` is treated as a single
  leaf, so its own fields are not traversed.

  Args:
    module: The pytree (typically an `eqx.Module`) to split.

  Returns:
    A `(rest, hypers)` pair as produced by `eqx.partition`. `rest` keeps every
    leaf that is not an `AbstractHyperParams`. `hypers` keeps only the
    `AbstractHyperParams` leaves. Positions removed from either side are
    filled with `None`. The two halves can be merged back with
    `eqx.combine(rest, hypers)`.
  """
  def is_hyper(node: Any) -> bool:
    return isinstance(node, AbstractHyperParams)

  return eqx.partition(module, lambda leaf: not is_hyper(leaf), is_leaf=is_hyper)

################################################################################################################

class CausalConv1dHypers(AbstractHyperParams):
  """Hyper-parameters for a causal 1D convolution.

  Fields:
    kernel_width: Number of taps of the convolution kernel.
    stride: Convolution stride.
    dilation: Dilation factor. NOTE(review): `padding` below does not take
      this value into account, so it is only causal-correct for dilation 1.
    use_bias: Whether the convolution has a bias term.
  """
  kernel_width: int = 3
  stride: int = 1
  dilation: int = 1
  use_bias: bool = True

  @property
  def padding(self) -> int:
    """Zero-padding size (in time steps) for a causal conv of this width.

    Equal to `kernel_width - 1`. `dilation` is not included.
    """
    taps = self.kernel_width
    return taps - 1

class CausalConv1d(AbstractBatchableObject):
  """Causal 1D convolution over a time-major `(T, channels)` sequence.

  Wraps `eqx.nn.Conv1d`. The conv pads both ends of the time axis by
  `hypers.padding` (= `kernel_width - 1`) steps, and only the first `T`
  outputs are kept. With stride 1, output step t therefore depends only on
  input steps t - (kernel_width - 1), ..., t.

  NOTE(review): `hypers.dilation` is stored but never passed to
  `eqx.nn.Conv1d`, so the convolution is always undilated.
  """
  conv1d: eqx.nn.Conv1d  # Underlying convolution (channels-first).
  hypers: CausalConv1dHypers  # Hyper-parameters this conv was built from.

  def __init__(
    self,
    in_channels: int,
    out_channels: int,
    hypers: Optional[CausalConv1dHypers] = CausalConv1dHypers(),
    *,
    key: PRNGKeyArray
  ):
    """Build the convolution.

    Args:
      in_channels: Number of input channels.
      out_channels: Number of output channels.
      hypers: Convolution hyper-parameters. `None` selects the defaults.
      key: PRNG key used to initialise the convolution parameters.
    """
    # The signature advertises `Optional`; fall back to the defaults rather
    # than failing on `None.kernel_width`.
    if hypers is None:
      hypers = CausalConv1dHypers()

    self.conv1d = eqx.nn.Conv1d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=hypers.kernel_width,
                                stride=hypers.stride,
                                padding=hypers.padding,
                                use_bias=hypers.use_bias,
                                key=key)
    self.hypers = hypers

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Leading batch shape of the parameters, inferred from the weight rank.

    Returns `None` for a 3-D (unbatched) weight and an `int` for a 4-D
    weight. For higher ranks it returns the tuple of leading dimensions.

    Raises:
      ValueError: If the weight has fewer than 3 dimensions.
    """
    if self.conv1d.weight.ndim == 3:
      return None
    elif self.conv1d.weight.ndim == 4:
      return self.conv1d.weight.shape[0]
    elif self.conv1d.weight.ndim > 4:
      return self.conv1d.weight.shape[:-3]
    else:
      raise ValueError(f"Invalid number of dimensions: {self.conv1d.weight.ndim}")

  @auto_vmap
  def __call__(self, x: Float[Array, 'T Din']) -> Float[Array, 'T Dout']:
    """Apply the causal convolution to `x` (annotated `(T, Din)`)."""
    # eqx.nn.Conv1d expects channels-first input, hence the transposes.
    # Padding was applied on both sides, so keeping only the first T outputs
    # discards the ones that would read future inputs.
    return self.conv1d(x.T)[:,:x.shape[0]].T

################################################################################################################

class WaveNetResBlockHypers(AbstractHyperParams):
  """Hyper-parameters for `WaveNetResBlock`.

  Fields:
    kernel_width: Width of the gating and filter convolutions.
    dilation: Dilation factor recorded on the gating/filter convolutions'
      hyper-parameters.
    hidden_channels: Channel count of the gated activation produced by the
      gating/filter convolutions and consumed by the 1x1 output/skip convs.
  """
  kernel_width: int = 2
  dilation: int = 1
  hidden_channels: int = 32

class WaveNetResBlock(AbstractBatchableObject):
  """Gated residual block of a WaveNet.

  For an input `x` with `in_channels` channels the block computes

      p          = sigmoid(gating_conv(x)) * tanh(filter_conv(x))
      new_hidden = out_conv(p) + x
      skip       = skip_conv(p)

  The gating/filter convolutions are causal, have width
  `hypers.kernel_width`, and output `hypers.hidden_channels` channels. The
  output/skip convolutions are 1x1 and map back to `in_channels`.
  """
  gating_conv: CausalConv1d  # Pre-activation of the sigmoid gate.
  filter_conv: CausalConv1d  # Pre-activation of the tanh filter.
  out_conv: CausalConv1d  # 1x1 projection added to the residual stream.
  skip_conv: CausalConv1d  # 1x1 projection returned as the skip output.

  def __init__(self,
               in_channels: int,
               hypers: Optional[WaveNetResBlockHypers] = WaveNetResBlockHypers(),
               *,
               key: PRNGKeyArray):
    """Build the block.

    Args:
      in_channels: Channels of the residual stream (input and both outputs).
      hypers: Block hyper-parameters. `None` selects the defaults.
      key: PRNG key, split four ways for the four convolutions.
    """
    # The signature advertises `Optional`; fall back to the defaults rather
    # than failing on `None.kernel_width`.
    if hypers is None:
      hypers = WaveNetResBlockHypers()
    k1, k2, k3, k4 = random.split(key, 4)

    # Gate and filter share hyper-parameters. NOTE(review): `dilation` is
    # recorded here, but CausalConv1d as written does not forward it to
    # eqx.nn.Conv1d.
    dilation_conv_hypers = CausalConv1dHypers(kernel_width=hypers.kernel_width,
                                              stride=1,
                                              dilation=hypers.dilation,
                                              use_bias=True)

    self.gating_conv = CausalConv1d(in_channels=in_channels,
                                    out_channels=hypers.hidden_channels,
                                    hypers=dilation_conv_hypers,
                                    key=k1)
    self.filter_conv = CausalConv1d(in_channels=in_channels,
                                    out_channels=hypers.hidden_channels,
                                    hypers=dilation_conv_hypers,
                                    key=k2)

    # Pointwise (width-1) projections back to the residual width.
    conv1x1_hypers = CausalConv1dHypers(kernel_width=1,
                                        stride=1,
                                        dilation=1,
                                        use_bias=True)

    self.out_conv = CausalConv1d(in_channels=hypers.hidden_channels,
                                 out_channels=in_channels,
                                 hypers=conv1x1_hypers,
                                 key=k3)
    self.skip_conv = CausalConv1d(in_channels=hypers.hidden_channels,
                                  out_channels=in_channels,
                                  hypers=conv1x1_hypers,
                                  key=k4)

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Leading batch shape, as reported by the gating convolution."""
    return self.gating_conv.batch_size

  def __call__(self, x: Float[Array, 'T C']) -> Tuple[Float[Array, 'T C'], Float[Array, 'T C']]:
    """Apply the block.

    Args:
      x: Residual-stream input (annotated `(T, C)`).

    Returns:
      `(new_hidden, skip)`: the updated residual stream and the skip output.
    """
    gate_out = jax.nn.sigmoid(self.gating_conv(x))
    filter_out = jax.nn.tanh(self.filter_conv(x))

    # Gated activation unit.
    p = gate_out * filter_out

    new_hidden = self.out_conv(p) + x
    skip = self.skip_conv(p)
    return new_hidden, skip

################################################################################################################

class WaveNetHypers(AbstractHyperParams):
  """Hyper-parameters for `WaveNet`.

  Fields:
    dilations: One dilation value per residual block. Its length sets the
      number of blocks.
    initial_filter_width: Kernel width of the causal input projection.
    filter_width: Kernel width of each block's gating/filter convolutions.
    residual_channels: Channels of the residual stream between blocks.
    dilation_channels: Hidden (gated) channels inside each block.
    skip_channels: Channels produced by the 1x1 skip projection and consumed
      by the output projection.
  """
  dilations: Int[Array, 'N']
  initial_filter_width: int = 4
  filter_width: int = 2
  residual_channels: int = 32
  dilation_channels: int = 32
  skip_channels: int = 32

class WaveNet(AbstractBatchableObject):
  """Causal convolutional WaveNet mapping `(T, in_channels)` to `(T, out_channels)`.

  Forward pass:
    1. (strict mode only) prepend one zero time step;
    2. causal input projection to `residual_channels`;
    3. residual blocks applied in sequence with `jax.lax.scan`, with their
       skip outputs summed;
    4. swish -> 1x1 `skip_conv` -> swish -> 1x1 `out_projection_conv`;
    5. (strict mode only) drop the last time step, so output t depends only
       on inputs at steps < t.

  NOTE(review): the per-block `dilations` are recorded in each block's
  hyper-parameters. As far as this file shows, `CausalConv1d` never passes a
  dilation to `eqx.nn.Conv1d`, so every block is effectively undilated.
  """
  # All residual blocks live in ONE `WaveNetResBlock` built with `jax.vmap`.
  # Its array leaves carry a leading axis of length `len(hypers.dilations)`,
  # which `jax.lax.scan` iterates over.
  blocks: WaveNetResBlock

  in_projection_conv: CausalConv1d
  skip_conv: CausalConv1d
  out_projection_conv: CausalConv1d

  # If True, outputs are shifted so each step sees only strictly earlier inputs.
  strict_autoregressive: bool = eqx.field(static=True)

  def __init__(self,
               in_channels: int,
               out_channels: int,
               hypers: WaveNetHypers,
               key: PRNGKeyArray,
               strict_autoregressive: bool = False):
    """Build the network.

    Args:
      in_channels: Channels of the input sequence.
      out_channels: Channels of the output sequence.
      hypers: Network hyper-parameters.
      key: PRNG key for parameter initialisation.
      strict_autoregressive: See the class docstring.
    """
    k1, k2, k3, k4 = random.split(key, 4)

    # Create the first projection
    initial_hypers = CausalConv1dHypers(kernel_width=hypers.initial_filter_width,
                                        stride=1,
                                        dilation=1,
                                        use_bias=True)
    self.in_projection_conv = CausalConv1d(in_channels=in_channels,
                                           out_channels=hypers.residual_channels,
                                           hypers=initial_hypers,
                                           key=k1)

    # Create the intermediate blocks. `make_block` is vmapped, so `dilation`
    # is a traced scalar here and the result is one stacked block.
    def make_block(dilation: int, key: PRNGKeyArray) -> WaveNetResBlock:
      block_hypers = WaveNetResBlockHypers(kernel_width=hypers.filter_width,
                                           dilation=dilation,
                                           hidden_channels=hypers.dilation_channels)
      return WaveNetResBlock(in_channels=hypers.residual_channels,
                             hypers=block_hypers,
                             key=key)

    keys = random.split(k2, len(hypers.dilations))
    self.blocks = jax.vmap(make_block)(hypers.dilations, keys)

    # Create the final projections
    conv1x1_hypers = CausalConv1dHypers(kernel_width=1,
                                        stride=1,
                                        dilation=1,
                                        use_bias=True)
    self.skip_conv = CausalConv1d(in_channels=hypers.residual_channels,
                                  out_channels=hypers.skip_channels,
                                  hypers=conv1x1_hypers,
                                  key=k3)
    self.out_projection_conv = CausalConv1d(in_channels=hypers.skip_channels,
                                            out_channels=out_channels,
                                            hypers=conv1x1_hypers,
                                            key=k4)

    self.strict_autoregressive = strict_autoregressive

  @property
  def batch_size(self) -> Union[None,int,Tuple[int, ...]]:
    """Leading batch shape of the network parameters."""
    # `blocks` is stacked along a per-layer axis, so its leading dimension is
    # the number of layers, not a batch dimension (and `blocks` is a single
    # module, not a list). The unstacked input projection reports the batch
    # shape directly.
    return self.in_projection_conv.batch_size

  @auto_vmap
  def __call__(self, x: Float[Array, 'T C']) -> Float[Array, 'T C']:
    """Run the network on `x` (annotated `(T, C)`) and return `(T, out_channels)`."""
    if self.strict_autoregressive:
      # Shift right by one step; undone by dropping the last output below.
      x = jnp.pad(x, ((1, 0), (0, 0)))

    # Initial projection
    hidden = self.in_projection_conv(x)

    # Residual blocks: scan over the stacked block parameters, threading the
    # residual stream and collecting only each block's skip output.
    params, static = eqx.partition(self.blocks, eqx.is_array)
    def apply_block(hidden, block_params):
      block = eqx.combine(block_params, static)
      hidden, skip = block(hidden)
      return hidden, skip

    _, skips = jax.lax.scan(apply_block, hidden, params)
    out_pre_swish = skips.sum(axis=0)

    # Output projection
    out = jax.nn.swish(out_pre_swish)
    out = self.skip_conv(out)
    out = jax.nn.swish(out)
    out = self.out_projection_conv(out)

    if self.strict_autoregressive:
      out = out[:-1]

    return out

################################################################################################################

if __name__ == '__main__':
  # Ad-hoc debugging script: loads a few series from `series.pkl` and sets up
  # an SDE, an emission encoder/decoder and a Gaussian prior for them.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE

  N = 3
  # Use a context manager so the file handle is closed.
  # NOTE: unpickling can execute arbitrary code; only load a trusted file.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Truncate the second axis of yts / observation_mask to length 2
  # (presumably the first two observation channels -- confirm).
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts

  # Discretisation grid with `freq` times as many points as the data.
  freq = 2
  discretization_times = jnp.linspace(data_times[0], data_times[-1], freq * data_times.shape[0])

  key = random.PRNGKey(0)

  # Create the SDE and emission potential encoder
  y_dim = yts.shape[1]
  # Alternative SDE (was constructed and immediately overwritten):
  # sde = BrownianMotion(sigma=0.1, dim=y_dim)
  sde = CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=y_dim)
  # sde = TimeScaledLinearTimeInvariantSDE(sde, time_scale=1/freq)

  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim,
                                                  x_dim=sde.dim,
                                                  sigma=0.01)
  decoder = PaddingLatentVariableDecoder(y_dim=y_dim,
                                         x_dim=sde.dim)
  covariance_matrix = encoder(series).node_potentials[0].to_mixed().J

  I = sde.F.eye(sde.F.shape[0])
  prior = StandardGaussian(mu=jnp.zeros(sde.dim), Sigma=I)

  # Pick a covariance structure matching the SDE's drift-matrix type.
  # NOTE(review): `potential_cov_type` stays unbound if `sde.F` is none of
  # these types.
  from diffusion_crf.matrix import *
  if isinstance(sde.F, DiagonalMatrix):
    potential_cov_type = 'diagonal'
  elif isinstance(sde.F, DenseMatrix):
    potential_cov_type = 'dense'
  elif isinstance(sde.F, Diagonal2x2BlockMatrix):
    potential_cov_type = 'block2x2'
  elif isinstance(sde.F, Diagonal3x3BlockMatrix):
    potential_cov_type = 'block3x3'
