from functools import partial
from typing import Literal, Optional, Union, Tuple, Callable, List, Any
import einops
import equinox as eqx
import jax.random as random
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
from Models.models.timefeatures import TimeFeatures
from Models.models.attention import MultiheadAttention
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
import abc
from Models.models.wavenet import *
from diffusion_crf.timeseries import TimeSeries
from Models.models.base import AbstractEncoderDecoderModel
from Models.models.transformer_wavenet import TransformerWaveNet

"""
This module implements a transformer-based encoder-decoder architecture for time series modeling.
It provides:

1. MyModelHypers - Configuration class defining architecture hyperparameters
2. MyModel - An encoder-decoder model combining transformer and WaveNet architectures
   with time feature embeddings for effectively processing temporal data

The architecture uses separate encoder and decoder networks to process conditioning and
target sequences, with a context vector passing information between them. Time features
are explicitly incorporated to help the model understand temporal relationships and patterns.
"""

class MyModelHypers(AbstractHyperParams):
  """Architecture hyperparameters for `MyModel`.

  The same instance is handed unchanged to both the encoder and the decoder
  `TransformerWaveNet` built in `MyModel.__init__`.
  """
  # NOTE(review): AbstractHyperParams is not imported by name in this file;
  # presumably it arrives via `from Models.models.wavenet import *` - confirm.

  # Required (no default). Not read directly in this file; presumably the number
  # of blocks inside each TransformerWaveNet - confirm against that class.
  n_blocks: int
  # Used by MyModel as the time-feature width, the encoder output width and the
  # decoder's conditioning width.
  hidden_channel_size: int = 32

  # The two fields below are not read directly in this file; presumably they are
  # consumed by TransformerWaveNet - confirm against that class.
  wavenet_kernel_width: int = 8
  num_transformer_heads: int = 4

class MyModel(AbstractEncoderDecoderModel):
  """Encoder-decoder model built from two `TransformerWaveNet` networks.

  The encoder turns the conditioning series into a context array; the decoder
  consumes the latent series together with that context. Each side has its own
  `TimeFeatures` embedding whose output is concatenated onto the series values
  along the last axis before the network is applied.
  """

  encoder: TransformerWaveNet
  decoder: TransformerWaveNet
  time_features_encoder: TimeFeatures
  time_features_decoder: TimeFeatures

  hypers: MyModelHypers = eqx.field(static=True)

  def __init__(self,
               cond_in_channels: int,
               in_channels: int,
               out_channels: int,
               hypers: MyModelHypers,
               key: PRNGKeyArray,
               strict_autoregressive: bool = False,
               causal_decoder: bool = True):
    """Build the time-feature embeddings, the encoder and the decoder.

    Args:
      cond_in_channels: Channel count of the conditioning series values.
      in_channels: Channel count of the latent (decoder input) series values.
      out_channels: Channel count produced by the decoder.
      hypers: Shared hyperparameters; `hidden_channel_size` sets the time
        feature width, the encoder output width and the decoder's
        conditioning width.
      key: PRNG key, split four ways (encoder, decoder, encoder time
        features, decoder time features).
      strict_autoregressive: Forwarded unchanged to both networks.
      causal_decoder: Forwarded as the decoder's `causal` flag; the encoder is
        always built with `causal=False`.
    """
    enc_key, dec_key, enc_time_key, dec_time_key = random.split(key, 4)
    width = hypers.hidden_channel_size

    self.hypers = hypers

    # Both time embeddings share the same sizes and differ only in their key.
    time_kwargs = dict(embedding_size=2*width, out_features=width)
    self.time_features_encoder = TimeFeatures(key=enc_time_key, **time_kwargs)
    self.time_features_decoder = TimeFeatures(key=dec_time_key, **time_kwargs)

    # The encoder sees the whole conditioning series, so it is never causal.
    self.encoder = TransformerWaveNet(
        in_channels=cond_in_channels + width,
        out_channels=width,
        hypers=hypers,
        key=enc_key,
        causal=False,
        strict_autoregressive=strict_autoregressive)

    # The decoder is conditioned on the encoder output (width channels).
    self.decoder = TransformerWaveNet(
        in_channels=in_channels + width,
        out_channels=out_channels,
        hypers=hypers,
        key=dec_key,
        cond_channels=width,
        causal=causal_decoder,
        strict_autoregressive=strict_autoregressive)

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch size as reported by the encoder."""
    return self.encoder.batch_size

  def create_context(self, condition_series: TimeSeries) -> Float[Array, 'S C']:
    """Encode the conditioning series into the context used by the decoder.

    The encoder time embedding is vmapped over `condition_series.ts`, joined
    to `condition_series.yts` along the last axis, and fed to the encoder.
    """
    embed_times = jax.vmap(self.time_features_encoder)
    encoder_input = jnp.concatenate(
        [condition_series.yts, embed_times(condition_series.ts)], axis=-1)
    return self.encoder(encoder_input)

  def __call__(self,
               condition_series: TimeSeries,
               latent_series: TimeSeries,
               context: Optional[Float[Array, 'S C']] = None) -> Float[Array, 'T C']:
    """Run the decoder on the latent series, conditioned on the context.

    Args:
      condition_series: Series to encode; only used when `context` is None.
      latent_series: Series whose `ts` are embedded and whose `yts` are fed
        to the decoder.
      context: Optional precomputed result of `create_context`; when omitted
        it is computed from `condition_series`.

    Returns:
      The decoder output.
    """
    ctx = self.create_context(condition_series) if context is None else context

    # Time embedding per step, appended to the series values on the last axis.
    embed_times = jax.vmap(self.time_features_decoder)
    decoder_input = jnp.concatenate(
        [latent_series.yts, embed_times(latent_series.ts)], axis=-1)

    return self.decoder(decoder_input, ctx)

################################################################################################################

if __name__ == '__main__':
  # Manual smoke test: build a small model, run it once, and check that the
  # decoder output at each time step does not depend on later inputs.
  from debug import *
  import matplotlib.pyplot as plt
  import jax
  import jax.numpy as jnp
  import jax.random as random
  import equinox as eqx
  import einops
  import pickle
  import pandas as pd
  from Utils.timefeatures import time_features

  # NOTE(review): pickle.load can execute arbitrary code; only run this against
  # a trusted 'series.pkl'.
  # Use a context manager so the file handle is closed after loading.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)
  series = TimeSeries(ts=series.ts, yts=series.yts)[:20]
  key = random.PRNGKey(0)

  # Define model hyperparameters
  num_layers = 2
  hypers = MyModelHypers(
      n_blocks=num_layers,
      hidden_channel_size=32,
      wavenet_kernel_width=8,
      num_transformer_heads=4,
  )

  # Create the model
  model = MyModel(
    cond_in_channels=series.yts.shape[-1],
    in_channels=series.yts.shape[-1],
    out_channels=series.yts.shape[-1],
    hypers=hypers,
    key=key,
    strict_autoregressive=False
  )

  condition_series = series[:-10]
  latent_series = series
  output = model(condition_series, latent_series)

  # Encode once; the context does not depend on the decoder input, so it can be
  # reused for every decoder evaluation below.
  context = model.create_context(condition_series)

  def model_fixed_enc(series_dec_yts):
    # Decoder as a function of the latent values only, with the encoder output
    # held fixed via the precomputed context.
    series_dec2 = TimeSeries(ts=latent_series.ts, yts=series_dec_yts)
    return model(condition_series, series_dec2, context=context)

  series_dec_yts = latent_series.yts
  out = model_fixed_enc(series_dec_yts)

  # Compute the Jacobian of the output with respect to the latent values
  J = eqx.filter_jacfwd(model_fixed_enc)(series_dec_yts)
  # Collapse axes 1 and 3 (the output and input channel axes) into a
  # time-by-time dependency matrix. Sum magnitudes rather than signed values so
  # that entries of opposite sign cannot cancel and hide a dependency.
  J = jnp.abs(J).sum(axis=(1, 3))

  # Check that the Jacobian is lower triangular
  assert jnp.allclose(J, jnp.tril(J))

  # Plot the Jacobian
  plt.imshow(J)
  plt.colorbar()
  plt.show()


  # Drop into the debugger so the results can be inspected interactively.
  import pdb; pdb.set_trace()