from functools import partial
from typing import Literal, Optional, Union, Tuple, Callable, List, Any
import einops
import equinox as eqx
import jax.random as random
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy as np
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
from Models.models.timefeatures import TimeFeatures
from diffusion_crf.gaussian import StandardGaussian
from diffusion_crf.base import AbstractBatchableObject, auto_vmap
import abc
from Models.models.base import AbstractEncoderDecoderModel, AbstractHiddenState
from diffusion_crf import TimeSeries
from Models.models.wavenet import AbstractHyperParams

################################################################################################################

class GRUCellHypers(AbstractHyperParams):
  """Hyperparameters for a single `GRUCell`.

  Attributes:
    hidden_size: width of the GRU hidden state (16 unless overridden).
  """
  hidden_size: int = 16

class GRUCell(AbstractBatchableObject):
  """Batchable wrapper around a single `eqx.nn.GRUCell`.

  The batch shape is inferred from the leading axes of the wrapped cell's
  input-to-hidden weight, so a cell built under `jax.vmap` reports the
  vmapped axes as its batch size.
  """

  gru: eqx.nn.GRUCell    # the wrapped Equinox GRU cell
  hypers: GRUCellHypers  # hyperparameters used to construct `gru`

  def __init__(
    self,
    in_channels: int,
    hypers: GRUCellHypers,
    *,
    key: PRNGKeyArray
  ):
    self.hypers = hypers
    self.gru = eqx.nn.GRUCell(
      input_size=in_channels,
      hidden_size=hypers.hidden_size,
      key=key,
    )

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Leading batch shape: None (unbatched), an int, or a tuple of ints.

    Raises:
      ValueError: if the weight has fewer than two dimensions.
    """
    n_dims = self.gru.weight_ih.ndim
    if n_dims < 2:
      raise ValueError(f"Invalid number of dimensions: {self.gru.weight_ih.ndim}")
    if n_dims == 2:
      # A plain weight matrix: no batch axis.
      return None
    if n_dims == 3:
      return self.gru.weight_ih.shape[0]
    # Several leading batch axes: everything but the trailing matrix dims.
    return self.gru.weight_ih.shape[:-2]

  @auto_vmap
  def __call__(self,
               input: Float[Array, 'Din'],
               hidden: Float[Array, 'hidden_size']) -> Float[Array, 'hidden_size']:
    """Advance `hidden` by one step using `input`; returns the new hidden state."""
    return self.gru(input, hidden)

################################################################################################################

class GRURNNState(AbstractHiddenState):
  """Recurrent state threaded through `GRURNN.single_step`.

  Attributes:
    index: step counter; 0 for a fresh state, incremented by each `single_step`.
    state: GRU hidden vector; any extra leading axes are batch axes.
  """
  index: int
  state: Float[Array, 'hidden_size']

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Leading batch shape of `state`.

    Returns None for a 1-D (unbatched) state, an int for one batch axis, and
    a tuple of all leading axes when there are two or more batch axes (the
    trailing axis is always the hidden dimension).

    Raises:
      ValueError: if `state` is 0-dimensional.
    """
    ndim = self.state.ndim
    if ndim == 1:
      return None
    elif ndim == 2:
      return self.state.shape[0]
    elif ndim >= 3:
      # Accept any number of leading batch axes (not just exactly two), the
      # same way GRUCell.batch_size handles multiply-batched weights.
      return self.state.shape[:-1]
    else:
      raise ValueError(f"Invalid number of dimensions: {self.state.ndim}")

class GRURNNHypers(AbstractHyperParams):
  """Hyperparameters for a single-layer `GRURNN`.

  Attributes:
    hidden_size: width of the recurrent hidden state (no default).
  """
  hidden_size: int

class GRURNN(AbstractBatchableObject):
  """Single-layer GRU recurrence followed by a linear read-out.

  `__call__` runs a whole sequence with `jax.lax.scan`; `get_initial_state`
  plus `single_step` perform the same computation one time step at a time.
  """

  gru: GRUCell                                # recurrent cell
  initial_state: Float[Array, 'hidden_size']  # starting hidden vector (zeros at construction)
  out_proj: eqx.nn.Linear                     # hidden_size -> out_channels

  def __init__(self,
               in_channels: int,
               out_channels: int,
               hypers: GRURNNHypers,
               key: PRNGKeyArray):
    # Four sub-keys are drawn but only the first two are used; the split count
    # determines the derived keys, so it is kept at 4.
    cell_key, proj_key, _, _ = random.split(key, 4)

    self.gru = GRUCell(in_channels=in_channels,
                       hypers=GRUCellHypers(hidden_size=hypers.hidden_size),
                       key=cell_key)
    self.out_proj = eqx.nn.Linear(in_features=hypers.hidden_size,
                                  out_features=out_channels,
                                  key=proj_key)
    self.initial_state = jnp.zeros(hypers.hidden_size)

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape, delegated to the wrapped GRU cell."""
    return self.gru.batch_size

  @auto_vmap
  def __call__(
    self,
    xts: Float[Array, 'T C'],
    global_context: Optional[Float[Array, 'hidden_size']] = None
  ) -> Float[Array, 'T C']:
    """Run the GRU along axis 0 of `xts` and project every hidden state.

    Args:
      xts: input sequence, one row per time step.
      global_context: optional vector added to `initial_state` before the scan.

    Returns:
      `out_proj` applied to the hidden state produced at every time step.
    """
    h0 = self.initial_state if global_context is None else self.initial_state + global_context

    def step(h_prev, x):
      h_next = self.gru(x, h_prev)
      # Carry the new state forward and also emit it as this step's output.
      return h_next, h_next

    _, hidden_seq = jax.lax.scan(step, h0, xts)
    return jax.vmap(self.out_proj)(hidden_seq)

  def get_initial_state(self, global_context: Optional[Float[Array, 'hidden_size']] = None) -> GRURNNState:
    """Return the step-0 state for `single_step`, optionally offset by `global_context`."""
    start = self.initial_state
    if global_context is not None:
      start = start + global_context
    return GRURNNState(index=0, state=start)

  def single_step(
    self,
    xt: Float[Array, 'D'],
    state: GRURNNState
  ) -> Tuple[Float[Array, 'D'], GRURNNState]:
    """Consume one input `xt`; return (projected output, advanced state)."""
    new_hidden = self.gru(xt, state.state)
    projected = self.out_proj(new_hidden)
    return projected, GRURNNState(index=state.index + 1, state=new_hidden)

################################################################################################################

class StackedGRUState(AbstractHiddenState):
  """Recurrent state of a `StackedGRURNN`.

  Despite the `List` annotation, `gru_states` is a single `GRURNNState` whose
  leaves carry a leading layer axis (e.g. as built with `jax.vmap` in
  `StackedGRURNN.get_initial_state`).
  """
  gru_states: List[GRURNNState]  # layer-batched GRURNNState, not a Python list

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape, read from entry 0 of `gru_states`."""
    leading = self.gru_states[0]
    return leading.batch_size

  @property
  def index(self) -> int:
    """Step counter, read from entry 0 of `gru_states`."""
    leading = self.gru_states[0]
    return leading.index

class StackedGRUSequenceHypers(AbstractHyperParams):
  """Hyperparameters for `StackedGRURNN`.

  Attributes:
    hidden_size: hidden-state width of every GRU layer.
    intermediate_channels: width of the sequence passed between layers
      (output of `in_proj`, input/output of each layer).
    num_layers: number of stacked GRU layers.
  """
  hidden_size: int
  intermediate_channels: int
  num_layers: int

class StackedGRURNN(AbstractBatchableObject):
  """Depth-stacked GRU network with input and output projections.

  Inputs are projected to `intermediate_channels`, passed through
  `num_layers` `GRURNN` blocks (each intermediate -> intermediate), then
  projected to `out_channels`. The blocks are stored as ONE `GRURNN` whose
  array leaves carry a leading layer axis (built with `jax.vmap`), and depth
  is iterated with `jax.lax.scan`.
  """

  gru_blocks: GRURNN                            # layer-batched GRURNN (leading axis = num_layers)
  initial_state: Float[Array, 'N hidden_size']  # per-layer starting hidden vectors
  in_proj: eqx.nn.Linear                        # in_channels -> intermediate_channels
  out_proj: eqx.nn.Linear                       # intermediate_channels -> out_channels
  hypers: StackedGRUSequenceHypers

  def __init__(self,
               in_channels: int,
               out_channels: int,
               hypers: StackedGRUSequenceHypers,
               key: PRNGKeyArray):
    k1, k2, k3, k4 = random.split(key, 4)

    gru_hypers = GRURNNHypers(hidden_size=hypers.hidden_size)

    self.in_proj = eqx.nn.Linear(in_features=in_channels,
                                 out_features=hypers.intermediate_channels,
                                 key=k1)

    self.out_proj = eqx.nn.Linear(in_features=hypers.intermediate_channels,
                                  out_features=out_channels,
                                  key=k2)

    def make_block(key: PRNGKeyArray) -> GRURNN:
      return GRURNN(in_channels=hypers.intermediate_channels,
                    out_channels=hypers.intermediate_channels,
                    hypers=gru_hypers,
                    key=key)

    # Derive the per-layer keys from the dedicated sub-key k3. Splitting k1
    # again (already consumed by `in_proj`) would reuse a PRNG key and break
    # independence between the two initialisations.
    keys = random.split(k3, hypers.num_layers)
    self.gru_blocks = jax.vmap(make_block)(keys)

    self.initial_state = jnp.zeros((hypers.num_layers, hypers.hidden_size))
    self.hypers = hypers

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape, read from entry 0 of the layer-batched GRU blocks."""
    return self.gru_blocks[0].batch_size

  @auto_vmap
  def __call__(
    self,
    xts: Float[Array, 'T C'],
    global_context: Optional[Float[Array, 'N hidden_size']] = None
  ) -> Float[Array, 'T C']:
    """Run the full stack over a sequence.

    Args:
      xts: input sequence, time on axis 0.
      global_context: optional per-layer offsets added to `initial_state`.

    Returns:
      `out_proj` applied to the last layer's output at every time step.
    """
    hidden_state = self.initial_state
    if global_context is not None:
      hidden_state = hidden_state + global_context

    # Separate array leaves (stacked along the layer axis) from the static
    # structure so `scan` can slice one layer's parameters per iteration.
    params, static = eqx.partition(self.gru_blocks, eqx.is_array)

    # Scan over the depth of the GRU: each layer consumes the previous
    # layer's whole output sequence.
    def scan_body(carry, inputs):
      hts = carry
      params, starting_hidden_state = inputs
      gru = eqx.combine(params, static)
      hts = gru(hts, starting_hidden_state)
      return hts, ()

    hts = jax.vmap(self.in_proj)(xts)
    hts, _ = jax.lax.scan(scan_body, hts, (params, hidden_state))
    out = jax.vmap(self.out_proj)(hts)
    return out

  def get_initial_state(self, global_context: Optional[Float[Array, 'N hidden_size']] = None) -> StackedGRUState:
    """Return the step-0 state for `single_step`, optionally offset per layer."""
    def make_state(vec: Float[Array, 'hidden_size']) -> GRURNNState:
      return GRURNNState(index=0, state=vec)
    # vmap over the layer axis yields one GRURNNState with layer-batched leaves.
    if global_context is not None:
      gru_states = jax.vmap(make_state)(self.initial_state + global_context)
    else:
      gru_states = jax.vmap(make_state)(self.initial_state)
    return StackedGRUState(gru_states=gru_states)

  def single_step(
    self,
    xt: Float[Array, 'D'],
    state: StackedGRUState
  ) -> Tuple[Float[Array, 'D'], StackedGRUState]:
    """Consume one input through every layer; return (output, advanced state)."""

    params, static = eqx.partition(self.gru_blocks, eqx.is_array)

    # Scan over the depth of the GRU: layer i's output feeds layer i+1, and
    # each layer's advanced state is collected along the layer axis.
    def scan_body(carry, inputs):
      hts = carry
      params, starting_hidden_state = inputs
      gru: GRURNN = eqx.combine(params, static)
      hts, hidden_state_out = gru.single_step(hts, starting_hidden_state)
      return hts, hidden_state_out

    hts = self.in_proj(xt)
    hts, hidden_states_out = jax.lax.scan(scan_body, hts, (params, state.gru_states))
    out = self.out_proj(hts)
    return out, StackedGRUState(gru_states=hidden_states_out)

################################################################################################################

class MyGRUModelState(AbstractHiddenState):
  """State carried by `MyGRUModel.single_step`.

  Wraps the decoder's recurrent state. That is a `GRURNNState` when the
  decoder is a single-layer `GRURNN`, or a `StackedGRUState` when it is a
  `StackedGRURNN`.
  """
  # Both variants occur: GRURNN.get_initial_state returns a GRURNNState.
  decoder_state: Union[GRURNNState, StackedGRUState]

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape, delegated to the decoder state."""
    return self.decoder_state.batch_size

  @property
  def index(self) -> int:
    """Step counter, delegated to the decoder state."""
    return self.decoder_state.index

class MyGRUModelHypers(AbstractHyperParams):
  """Hyperparameters for `MyGRUModel`.

  Attributes:
    hidden_size: GRU hidden width; `MyGRUModel` also uses it as the size of
      its time-feature embeddings.
  """
  hidden_size: int

class MyGRUModel(AbstractEncoderDecoderModel):
  """GRU encoder-decoder model with learned time features.

  The encoder reads the conditioning series (values concatenated with encoder
  time features). Its output at the last time step becomes a context that is
  added to the decoder's initial hidden state. The decoder reads the latent
  series (values concatenated with decoder time features).

  With `n_layers=None`, encoder and decoder are single-layer `GRURNN`s. With an
  integer `n_layers`, they are `StackedGRURNN`s and `intermediate_channels`
  is required.
  """

  # Single-layer GRURNN when n_layers is None, otherwise StackedGRURNN.
  encoder: Union[GRURNN, StackedGRURNN]
  decoder: Union[GRURNN, StackedGRURNN]
  time_features_encoder: TimeFeatures
  time_features_decoder: TimeFeatures

  hypers: MyGRUModelHypers = eqx.field(static=True)

  def __init__(self,
               cond_in_channels: int,
               in_channels: int,
               out_channels: int,
               hypers: MyGRUModelHypers,
               key: PRNGKeyArray,
               n_layers: Optional[int] = None,
               intermediate_channels: Optional[int] = None):
    """Build the encoder, decoder and time-feature embeddings.

    Args:
      cond_in_channels: channel count of the conditioning series values.
      in_channels: channel count of the latent series values fed to the decoder.
      out_channels: channel count produced by the decoder.
      hypers: model hyperparameters; `hidden_size` doubles as time-feature size.
      key: PRNG key for initialisation.
      n_layers: if given, use `StackedGRURNN`s with this many layers.
      intermediate_channels: inter-layer width of the stacked GRUs; required
        whenever `n_layers` is given.

    Raises:
      ValueError: if `n_layers` is given without `intermediate_channels`.
    """
    if n_layers is not None and intermediate_channels is None:
      # Fail early with a clear message instead of deep inside eqx.nn.Linear.
      raise ValueError("intermediate_channels must be provided when n_layers is set")

    k1, k2, k3, k4 = random.split(key, 4)

    # Create the time embedding
    time_feature_size = hypers.hidden_size
    self.time_features_encoder = TimeFeatures(embedding_size=2*time_feature_size,
                                              out_features=time_feature_size,
                                              key=k3)
    self.time_features_decoder = TimeFeatures(embedding_size=2*time_feature_size,
                                              out_features=time_feature_size,
                                              key=k4)

    if n_layers is None:
      gru_hypers = GRURNNHypers(hidden_size=hypers.hidden_size)

      # The encoder emits a hidden_size vector used to offset the decoder's
      # initial hidden state.
      encoder = GRURNN(in_channels=cond_in_channels + time_feature_size,
                       out_channels=hypers.hidden_size,
                       hypers=gru_hypers,
                       key=k1)

      decoder = GRURNN(in_channels=in_channels + time_feature_size,
                       out_channels=out_channels,
                       hypers=gru_hypers,
                       key=k2)
    else:
      # Create the stacked GRUs
      gru_hypers = StackedGRUSequenceHypers(hidden_size=hypers.hidden_size,
                                            intermediate_channels=intermediate_channels,
                                            num_layers=n_layers)
      # The encoder emits one hidden_size vector per decoder layer, flattened;
      # `create_context` reshapes it to (n_layers, hidden_size).
      encoder = StackedGRURNN(in_channels=cond_in_channels + time_feature_size,
                              out_channels=hypers.hidden_size*n_layers,
                              hypers=gru_hypers,
                              key=k1)
      decoder = StackedGRURNN(in_channels=in_channels + time_feature_size,
                              out_channels=out_channels,
                              hypers=gru_hypers,
                              key=k2)

    self.encoder = encoder
    self.decoder = decoder

    self.hypers = hypers

  @property
  def batch_size(self) -> Union[None,int,Tuple[int]]:
    """Batch shape, delegated to the encoder."""
    return self.encoder.batch_size

  def create_context(self, condition_series: TimeSeries) -> Float[Array, 'S C']:
    """Encode `condition_series` into a decoder context.

    Takes the encoder output at the last time step. If it already has shape
    (hidden_size,) it is returned unchanged. Otherwise it is reshaped to
    (-1, hidden_size), giving one row per layer of a `StackedGRURNN` decoder.
    """
    time_features_enc = jax.vmap(self.time_features_encoder)(condition_series.ts)
    xts_enc = jnp.concatenate([condition_series.yts, time_features_enc], axis=-1)
    out = self.encoder(xts_enc)[-1]
    if out.shape == (self.hypers.hidden_size,):
      return out
    else:
      return out.reshape((-1, self.hypers.hidden_size))

  def __call__(self,
               condition_series: TimeSeries,
               latent_series: TimeSeries,
               context: Optional[Float[Array, 'S C']] = None) -> Float[Array, 'T C']:
    """Decode `latent_series` over all time steps at once.

    Args:
      condition_series: encoded into the context when `context` is None.
      latent_series: decoder input series.
      context: precomputed `create_context` output, to skip re-encoding.

    Returns:
      Decoder output for every time step of `latent_series`.
    """

    if context is None:
      context = self.create_context(condition_series)

    # Create the time embedding
    time_features_dec = jax.vmap(self.time_features_decoder)(latent_series.ts)

    # Concatenate the time features with the series
    xts_dec = jnp.concatenate([latent_series.yts, time_features_dec], axis=-1)

    out = self.decoder(xts_dec, context)
    return out

  def get_initial_state(
    self,
    condition_series: TimeSeries,
    context: Optional[Float[Array, 'S C']] = None
  ) -> MyGRUModelState:
    """Build the step-0 state for `single_step` from the conditioning context."""
    if context is None:
      context = self.create_context(condition_series)
    decoder_state = self.decoder.get_initial_state(context)
    return MyGRUModelState(decoder_state=decoder_state)

  def single_step(
    self,
    t: Scalar,
    xt: Float[Array, 'D'],
    state: MyGRUModelState
  ) -> Tuple[Float[Array, 'D'], MyGRUModelState]:
    """Decode one step at time `t` with input `xt`; return (output, new state)."""

    # Create the time embedding and concatenate with the series
    time_features_dec = self.time_features_decoder(t)
    xt_dec = jnp.concatenate([xt, time_features_dec], axis=-1)

    # Run the decoder and get the new state
    out, state = self.decoder.single_step(xt_dec, state.decoder_state)

    return out, MyGRUModelState(decoder_state=state)

################################################################################################################

def blah():
  """Debug helper: run a single-layer `MyGRUModel` on one test series, then open pdb.

  Loads the 'noisy_double_pendulum' experiment data, sizes the model to the
  series' channel count, runs a forward pass using the series as both the
  condition and the latent input, and stops in the debugger.
  """
  key = random.PRNGKey(0)

  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model  # unused here; presumably for interactive use in pdb

  experiment = ExperimentIdentifier.make_experiment_id(
    config_name='noisy_double_pendulum',
    objective=None,
    model_name='my_autoregressive',
    sde_type='tracking',
    freq=0,
    group='no_leakage_latent_forecasting',
    seed=0,
  )
  datasets = experiment.get_data_fixed()
  train_data = datasets['train_data']
  val_data = datasets['val_data']
  test_data = datasets['test_data']
  series = test_data[0]

  n_channels = series.yts.shape[-1]
  model = MyGRUModel(cond_in_channels=n_channels,
                     in_channels=n_channels,
                     out_channels=n_channels,
                     hypers=MyGRUModelHypers(hidden_size=16),
                     key=key)
  out = model(series, series)
  import pdb; pdb.set_trace()


if __name__ == '__main__':
  # Debug script: run MyGRUModel on a few pickled series in parallel (scan)
  # mode and in step-by-step mode, then inspect both results in pdb.
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE

  # blah()

  N = 10
  # `with` closes the file handle deterministically. NOTE: pickle.load can
  # execute arbitrary code, so only load a trusted, locally produced file.
  with open('series.pkl', 'rb') as series_file:
    series = pickle.load(series_file)[:N]
  # Keep the first two entries along axis 1 of yts and the observation mask
  # (presumably the channel axis).
  series = eqx.tree_at(lambda x: x.yts, series, series.yts[:,:2])
  series = eqx.tree_at(lambda x: x.observation_mask, series, series.observation_mask[:,:2])
  data_times, yts = series.ts, series.yts


  key = random.PRNGKey(0)
  # Alternative configuration: stacked-GRU variant.
  # model = MyGRUModel(cond_in_channels=yts.shape[-1],
  #                    in_channels=yts.shape[-1],
  #                    out_channels=yts.shape[-1],
  #                    hypers=MyGRUModelHypers(hidden_size=16),
  #                    key=key,
  #                    n_layers=2,
  #                    intermediate_channels=2)
  model = MyGRUModel(cond_in_channels=yts.shape[-1],
                     in_channels=yts.shape[-1],
                     out_channels=yts.shape[-1],
                     hypers=MyGRUModelHypers(hidden_size=16),
                     key=key,
                     n_layers=None)


  # Run the model in parallel mode (not really because this is an RNN)
  full_out = model(series, series)

  # Get the starting state for sequential mode
  state: MyGRUModelState = model.get_initial_state(series)

  # Run the model in sequential mode
  sequential_out = []
  for i in range(len(series)):
    t, xt = series.ts[i], series.yts[i]
    out, state = model.single_step(t, xt, state)
    sequential_out.append(out)
    print(state.decoder_state.index)

  sequential_out = jnp.array(sequential_out)

  import pdb; pdb.set_trace()