from functools import partial
from typing import Literal, Optional, Union, Tuple, Callable, List, Any
import einops
import equinox as eqx
import jax.random as random
import jax
import jax.tree_util as jtu
import jax.numpy as jnp
import numpy as np
from diffusion_crf import *


from Models.models.autoregressive import MyAutoregressiveModel
from Models.models.autoregressive import MyReparameterizedAutoregressiveModel, MyReparameterizedAutoregressiveRNNModel
from Models.models.non_probabilistic import MyNonProbabilisticModel
from Models.models.non_probabilistic import MyNonProbabilisticRNNModel
from Models.models.neural_sde import MyNeuralSDE
from Models.models.neural_sde import MyNeuralSDERNN
from Models.models.diffusion_model import MyDiffusionModel
from Models.models.diffusion_model import MyDiffusionRNNModel
from Models.models.baseline_autoregressive import MyBaselineAutoregressiveModel
from Models.models.baseline_autoregressive import MyBaselineAutoregressiveRNNModel
from Models.models.baseline_diffusion_model import MyBaselineDiffusionModel
from Models.models.baseline_diffusion_model import MyBaselineDiffusionRNNModel
from Models.models.ho_models.discrete import MyReparameterizedAutoregressiveRNNBwdModel
from Models.models.ho_models.continuous import MyNeuralSDERNNBwd

################################################################################################################

def _potential_cov_type(sde: AbstractLinearSDE) -> str:
  """Return the `potential_cov_type` string that matches the class of `sde.F`.

  The isinstance checks run in a fixed order and the first match wins.

  Raises:
    TypeError: if `sde.F` is none of DiagonalMatrix, DenseMatrix,
      Diagonal2x2BlockMatrix or Diagonal3x3BlockMatrix.
  """
  F = sde.F
  if isinstance(F, DiagonalMatrix):
    return 'diagonal'
  if isinstance(F, DenseMatrix):
    return 'dense'
  if isinstance(F, Diagonal2x2BlockMatrix):
    return 'block2x2'
  if isinstance(F, Diagonal3x3BlockMatrix):
    return 'block3x3'
  # Fail loudly and explain why, instead of leaving the value unbound.
  raise TypeError(f'Unsupported matrix type for sde.F: {type(F).__name__}')


def _potential_kwargs(sde: AbstractLinearSDE, model_config: dict) -> dict:
  """Keyword arguments for models that take `potential_cov_type` and `parametrization`."""
  return dict(potential_cov_type=_potential_cov_type(sde),
              parametrization=model_config['parametrization'])


def _transformer_kwargs(model_config: dict) -> dict:
  """Backbone hyperparameters shared by the non-RNN (n_layers/filter_width/...) models."""
  return dict(n_layers=model_config['n_layers'],
              filter_width=model_config['filter_width'],
              hidden_channel_size=model_config['hidden_channel_size'],
              num_transformer_heads=model_config['num_transformer_heads'])


def _rnn_kwargs(model_config: dict) -> dict:
  """Backbone hyperparameters shared by the RNN-based models."""
  return dict(hidden_size=model_config['hidden_size'])


def _sequence_kwargs(freq, data_configs: dict, key: PRNGKeyArray) -> dict:
  """Keyword arguments passed to every model: frequency, lengths and PRNG key.

  `cond_len` is `seq_length - pred_length` (presumably the length of the
  conditioning prefix -- confirm against the model classes).
  """
  seq_len = data_configs['seq_length']
  return dict(interpolation_freq=freq,
              seq_len=seq_len,
              cond_len=seq_len - data_configs['pred_length'],
              key=key)


def _legacy_latent_cond_len(model_type: str, model_configs: dict):
  """Return `model_configs['latent_cond_len']`, rejecting old configs that lack it.

  Raises:
    ValueError: if the config has no 'latent_cond_len' entry (deprecated model).
  """
  if 'latent_cond_len' not in model_configs:
    # An explicit raise rather than `assert`, which `python -O` strips.
    raise ValueError(f"Deprecated model: {model_type} config has no 'latent_cond_len'")
  return model_configs['latent_cond_len']


def create_model(sde: AbstractLinearSDE,
                 encoder: AbstractEncoder,
                 config: dict,
                 key: PRNGKeyArray) -> PyTree:
  """Construct the model selected by ``config['model']['model_type']``.

  Config entries read:
    config['command_line_args']['freq']  -> ``interpolation_freq``.
    config['model']['model_type']        -> selects the model class.
    config['model']['model']             -> per-model hyperparameters
                                            (backbone, parametrization, heads).
    config['model']['latent_cond_len']   -> ``latent_cond_len`` for the models
                                            that take it.
    config['dataset']['seq_length'], config['dataset']['pred_length']
                                         -> ``seq_len`` and
                                            ``cond_len = seq_length - pred_length``.

  Args:
    sde: Linear SDE passed to every model. For models that take a
      ``potential_cov_type``, the class of ``sde.F`` selects it.
    encoder: Encoder passed to every model.
    config: Nested configuration dict (entries listed above).
    key: PRNG key forwarded to the model constructor.

  Returns:
    The constructed model instance.

  Raises:
    ValueError: unknown ``model_type``, or a legacy config without
      ``latent_cond_len`` for a model that requires it.
    TypeError: the model needs ``potential_cov_type`` but ``sde.F`` has an
      unsupported type.
    KeyError: a config entry required by the selected model is missing.
  """
  args = config['command_line_args']
  model_configs = config['model']
  data_configs = config['dataset']

  freq = args['freq']
  model_type = model_configs['model_type']
  model_config = model_configs['model']

  if model_type == 'MyAutoregressiveModel':
    latent_cond_len = _legacy_latent_cond_len(model_type, model_configs)
    return MyAutoregressiveModel(sde,
                                 encoder,
                                 **_transformer_kwargs(model_config),
                                 **_potential_kwargs(sde, model_config),
                                 **_sequence_kwargs(freq, data_configs, key),
                                 latent_cond_len=latent_cond_len,
                                 predict_cov=model_config['predict_cov'])

  if model_type == 'MyReparameterizedAutoregressiveModel':
    return MyReparameterizedAutoregressiveModel(sde,
                                                encoder,
                                                **_transformer_kwargs(model_config),
                                                **_potential_kwargs(sde, model_config),
                                                **_sequence_kwargs(freq, data_configs, key),
                                                latent_cond_len=model_configs['latent_cond_len'],
                                                predict_cov=model_config['predict_cov'])

  # These three RNN models share an identical constructor signature.
  rnn_latent_models = {
    'MyReparameterizedAutoregressiveRNNModel': MyReparameterizedAutoregressiveRNNModel,
    'MyReparameterizedAutoregressiveRNNBwdModel': MyReparameterizedAutoregressiveRNNBwdModel,
    'MyNeuralSDERNNBwd': MyNeuralSDERNNBwd,
  }
  if model_type in rnn_latent_models:
    return rnn_latent_models[model_type](sde,
                                         encoder,
                                         **_rnn_kwargs(model_config),
                                         **_potential_kwargs(sde, model_config),
                                         **_sequence_kwargs(freq, data_configs, key),
                                         latent_cond_len=model_configs['latent_cond_len'],
                                         predict_cov=model_config['predict_cov'])

  if model_type == 'MyNonProbabilisticModel':
    return MyNonProbabilisticModel(sde,
                                   encoder,
                                   **_transformer_kwargs(model_config),
                                   **_potential_kwargs(sde, model_config),
                                   **_sequence_kwargs(freq, data_configs, key),
                                   predict_cov=model_config['predict_cov'])

  if model_type == 'MyNonProbabilisticRNNModel':
    return MyNonProbabilisticRNNModel(sde,
                                      encoder,
                                      **_rnn_kwargs(model_config),
                                      **_potential_kwargs(sde, model_config),
                                      **_sequence_kwargs(freq, data_configs, key),
                                      predict_cov=model_config['predict_cov'])

  if model_type == 'MyNeuralSDE':
    latent_cond_len = _legacy_latent_cond_len(model_type, model_configs)
    return MyNeuralSDE(sde,
                       encoder,
                       **_transformer_kwargs(model_config),
                       **_sequence_kwargs(freq, data_configs, key),
                       latent_cond_len=latent_cond_len,
                       predict_flow_or_drift=model_config['predict_flow_or_drift'])

  if model_type == 'MyNeuralSDERNN':
    return MyNeuralSDERNN(sde,
                          encoder,
                          **_rnn_kwargs(model_config),
                          **_sequence_kwargs(freq, data_configs, key),
                          latent_cond_len=model_configs['latent_cond_len'],
                          predict_flow_or_drift=model_config['predict_flow_or_drift'],
                          # Optional hyperparameters: passed as None when absent.
                          n_layers=model_config.get('n_layers', None),
                          intermediate_channels=model_config.get('intermediate_channels', None))

  if model_type == 'MyDiffusionModel':
    latent_cond_len = _legacy_latent_cond_len(model_type, model_configs)
    return MyDiffusionModel(sde,
                            encoder,
                            **_transformer_kwargs(model_config),
                            **_sequence_kwargs(freq, data_configs, key),
                            latent_cond_len=latent_cond_len)

  if model_type == 'MyDiffusionRNNModel':
    return MyDiffusionRNNModel(sde,
                               encoder,
                               **_rnn_kwargs(model_config),
                               **_sequence_kwargs(freq, data_configs, key),
                               latent_cond_len=model_configs['latent_cond_len'])

  if model_type == 'MyBaselineAutoregressiveModel':
    return MyBaselineAutoregressiveModel(sde,
                                         encoder,
                                         **_transformer_kwargs(model_config),
                                         **_potential_kwargs(sde, model_config),
                                         **_sequence_kwargs(freq, data_configs, key))

  if model_type == 'MyBaselineAutoregressiveRNNModel':
    return MyBaselineAutoregressiveRNNModel(sde,
                                            encoder,
                                            **_rnn_kwargs(model_config),
                                            **_potential_kwargs(sde, model_config),
                                            **_sequence_kwargs(freq, data_configs, key))

  if model_type == 'MyBaselineDiffusionModel':
    return MyBaselineDiffusionModel(sde,
                                    encoder,
                                    **_transformer_kwargs(model_config),
                                    **_sequence_kwargs(freq, data_configs, key))

  if model_type == 'MyBaselineDiffusionRNNModel':
    return MyBaselineDiffusionRNNModel(sde,
                                       encoder,
                                       **_rnn_kwargs(model_config),
                                       **_sequence_kwargs(freq, data_configs, key))

  raise ValueError(f'Unknown model type: {model_type}')

################################################################################################################


