import equinox as eqx
import jax.numpy as jnp
import jax.random as random
from typing import Optional, Union, Tuple
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool, PyTree
from diffusion_crf import AbstractBatchableObject, auto_vmap
from Data.pendulum.double_pendulum import get_raw_pendulum_data, get_raw_improved_pendulum_data
from diffusion_crf import AbstractBatchableObject, auto_vmap, TimeSeries
import jax
from Data.dynamical_system.data_gen import get_raw_physics_data
from Data.harmonic_oscillator.make_data import get_raw_harmonic_oscillator_data
from Data.harmonic_oscillator.make_data2 import get_raw_harmonic_oscillator_data2
from Models.experiment_identifier import ExperimentIdentifier

def make_series(data: Float[Array, 'T D']):
  """Wrap an array in a TimeSeries whose times are the row indices.

  Args:
    data: Array whose leading axis is treated as time.

  Returns:
    TimeSeries built from times 0, 1, ..., data.shape[0] - 1 (cast to the
    dtype of `data`) and `data` itself as the values.
  """
  num_steps = data.shape[0]
  # Times share the dtype of the values.
  index_times = jnp.arange(num_steps).astype(data.dtype)
  return TimeSeries(index_times, data)

def get_pendulum_dataset(config: dict) -> TimeSeries:
  """Load the raw pendulum data and turn every sample into a TimeSeries.

  Args:
    config: Must contain 'dataset_size', which is forwarded as `num_samples`
      to get_raw_pendulum_data.

  Returns:
    The result of mapping make_series over the leading axis of the raw data
    with jax.vmap.
  """
  raw_samples = get_raw_pendulum_data(num_samples=config['dataset_size'])
  # make_series handles one sample; vmap applies it along the leading axis.
  to_batched_series = jax.vmap(make_series)
  return to_batched_series(raw_samples)

def get_improved_pendulum_dataset(config: dict,
                                  key: Optional[PRNGKeyArray] = None,
                                  return_vector_fields: bool = False,
                                  train_val_test_split: Optional[Tuple[int, int, int]] = None,
                                  return_raw_data: bool = False) -> TimeSeries:
  """Build the "improved" pendulum dataset in raw, split, or windowed form.

  The raw generator is asked for `dataset_size + seq_length - 1` points
  (presumably so that sliding windows of `seq_length` number `dataset_size`
  -- confirm against make_windowed_batches). The timestamps of the noisy and
  denoised series are shifted so that they start at zero.

  Args:
    config: Needs 'dataset_size', 'seq_length', 'sample_rate' and 'noise_std';
      'pred_length' is also read when `train_val_test_split` is given.
    key: Forwarded to the raw generator and, when splitting, to
      make_train_val_test_split (the same key is used for both series).
    return_vector_fields: Only consulted on the windowed path (i.e. when
      neither `return_raw_data` nor `train_val_test_split` is used).
    train_val_test_split: Passed through unchanged to
      make_train_val_test_split; None disables splitting.
    return_raw_data: If True, return the unwindowed series immediately.

  Returns:
    One of, in priority order:
      * (noisy, denoised) un-windowed series if `return_raw_data`;
      * (noisy_splits, denoised_splits) if `train_val_test_split` is given;
      * (noisy_windows, denoised_windows, vector_field_windows) if
        `return_vector_fields`;
      * (noisy_windows, denoised_windows) otherwise.

    NOTE(review): the `-> TimeSeries` annotation does not describe these
    tuple returns.
  """
  max_length = config['dataset_size'] + config['seq_length'] - 1
  noisy, clean, vector_fields = get_raw_improved_pendulum_data(max_length=max_length,
                                                               sample_rate=config['sample_rate'],
                                                               noise_std=config['noise_std'],
                                                               key=key)

  # Re-base time so that the first observation sits at t = 0.
  raw_ts = noisy.ts
  shifted_ts = raw_ts - raw_ts[0]
  noisy = TimeSeries(shifted_ts, noisy.yts)
  clean = TimeSeries(shifted_ts, clean.yts)

  if return_raw_data:
    return noisy, clean

  if train_val_test_split is not None:
    from Utils.Data_utils.real_datasets import make_train_val_test_split
    seq_length = config['seq_length']
    # Conditioning length is whatever remains of a window after the forecast part.
    cond_length = seq_length - config['pred_length']
    noisy_splits, clean_splits = [
      make_train_val_test_split(key, series, train_val_test_split, seq_length, cond_length)
      for series in (noisy, clean)
    ]
    return noisy_splits, clean_splits

  # Windowed path: cut every series into windows of seq_length.
  window = config['seq_length']
  windowed = [series.make_windowed_batches(window_size=window) for series in (noisy, clean)]
  if return_vector_fields:
    # The vector-field series is windowed as returned by the generator;
    # its timestamps are not re-based.
    windowed.append(vector_fields.make_windowed_batches(window_size=window))
  return tuple(windowed)


def get_physics_dataset(config: dict,
                        key: Optional[PRNGKeyArray] = None,
                        train_val_test_split: Optional[Tuple[int, int, int]] = None,
                        return_raw_data: bool = False) -> TimeSeries:
  """Load a physics dataset and split it into fixed train/val/test sets.

  Follows the setup referenced by the original author:
  https://arxiv.org/pdf/2503.10375

  Args:
    config: Needs 'name' and 'noise_std'. 'dataset_size', 'seq_length' and
      'pred_length' are asserted to be 2400, 150 and 75 respectively.
    key: Forwarded to get_raw_physics_data and then reused for the shuffle;
      it is passed to jax.random.permutation, so a real key is needed unless
      `return_raw_data` is True.
    train_val_test_split: Accepted but not used; the split sizes are
      hard-coded below.
    return_raw_data: If True, return the (noisy, denoised) series as loaded.

  Returns:
    (noisy, denoised) if `return_raw_data`; otherwise
    ((train, val, test), (denoised_train, denoised_val, denoised_test)).

    NOTE(review): the `-> TimeSeries` annotation does not describe these
    tuple returns.
  """
  system_name = config['name']
  # The split below is hard-wired, so reject any other configuration.
  assert config['dataset_size'] == 2400
  assert config['seq_length'] == 150
  assert config['pred_length'] == 75

  noisy, clean = get_raw_physics_data(system_name, noise_std=config['noise_std'], key=key)
  if return_raw_data:
    return noisy, clean

  # Per the original notes: the first 75 steps are observed, the next 75 are
  # predicted and the last 75 are for extrapolation. Train/val keep the first
  # 150 steps of each sequence, test keeps the last 150.
  window = 150

  # 1600 sequences for training, 400 for validation, the rest for testing
  # (400 when the dataset holds 2400 sequences).
  n_train, n_val = 1600, 400
  n_fit = n_train + n_val

  # A single permutation is applied to both series so they stay aligned.
  order = jax.random.permutation(key, jnp.arange(noisy.batch_size))

  def _carve(series):
    # Shuffle, then cut into (train, val, test).
    shuffled = series[order]
    head = shuffled[:n_fit, :window]
    return head[:n_train], head[n_train:], shuffled[n_fit:, -window:]

  return _carve(noisy), _carve(clean)

def extract_params_from_config_name(config_name: str) -> dict:
  """
  Parse a harmonic oscillator config name into its numeric parameters.

  The name must consist of exactly four underscore-separated fields, the
  first of which is 'ho', e.g. 'ho_4_0p01_0p01'. In the two noise fields
  every 'p' stands for a decimal point.

  Args:
    config_name: A string like 'ho_4_0p01_0p01'

  Returns:
    A dictionary with the extracted parameters:
    {
      'seq_length': int,
      'observation_noise': float,
      'process_noise': float
    }

  Raises:
    ValueError: If the name does not have the expected layout, or if one of
      the numeric fields cannot be converted.
  """
  fields = config_name.split('_')

  well_formed = len(fields) == 4 and fields[0] == 'ho'
  if not well_formed:
    raise ValueError(f"Invalid config name format: {config_name}")

  _, length_token, obs_token, proc_token = fields

  def _decode_noise(token):
    # '0p01' -> 0.01
    return float(token.replace('p', '.'))

  return {
    'seq_length': int(length_token),
    'observation_noise': _decode_noise(obs_token),
    'process_noise': _decode_noise(proc_token),
  }


def get_harmonic_oscillator_dataset(config: dict,
                                    experiment_identifier: ExperimentIdentifier,
                                    key: Optional[PRNGKeyArray] = None,
                                    train_val_test_split: Optional[Tuple[int, int, int]] = None,
                                    return_raw_data: bool = False) -> TimeSeries:
  """Generate harmonic oscillator data and split it into train/val/test.

  Args:
    config: Must contain a 'dataset' dict with 'name', 'data_noise_std',
      'data_latent_sigma' and 'seq_length'.
    experiment_identifier: Its `config_name` must parse with
      extract_params_from_config_name (a ValueError is raised otherwise);
      the parsed values themselves are not used.
    key: Forwarded to the raw generator and reused for the shuffle.
    train_val_test_split: Three values of which the first two are used as
      fractions of the dataset size. Despite the None default it is unpacked
      unconditionally, so it must be supplied.
    return_raw_data: Accepted but not used.

  Returns:
    (train_series, val_series, test_series), where the test set receives
    whatever remains after the train and validation sets.

    NOTE(review): the `-> TimeSeries` annotation does not describe this
    tuple return.
  """
  # Called only for its validation side effect; the result is discarded.
  _unused_params = extract_params_from_config_name(experiment_identifier.config_name)

  dataset_config = config['dataset']
  _dataset_name = dataset_config['name']  # read but not used further
  obs_noise = dataset_config['data_noise_std']
  proc_noise = dataset_config['data_latent_sigma']
  window = dataset_config['seq_length']

  all_series = get_raw_harmonic_oscillator_data(obs_noise, proc_noise, window, key=key)

  # The third entry is ignored: the test set size is implied by the other two.
  train_frac, val_frac, _test_frac = train_val_test_split

  total = all_series.batch_size
  n_train = int(total * train_frac)
  n_val = int(total * val_frac)
  n_fit = n_train + n_val

  # Shuffle before splitting.
  order = jax.random.permutation(key, jnp.arange(total))
  shuffled = all_series[order]

  fit_series = shuffled[:n_fit]
  return fit_series[:n_train], fit_series[n_train:], shuffled[n_fit:]

def get_harmonic_oscillator_dataset2(config: dict,
                                     experiment_identifier: ExperimentIdentifier,
                                     key: Optional[PRNGKeyArray] = None,
                                     train_val_test_split: Optional[Tuple[int, int, int]] = None,
                                     return_raw_data: bool = False) -> TimeSeries:
  """Generate the second harmonic oscillator dataset and split it.

  Args:
    config: Must contain a 'dataset' dict with 'name', 'data_latent_sigma'
      and 'seq_length'.
    experiment_identifier: Accepted but not used.
    key: Forwarded to the raw generator and reused for the shuffle.
    train_val_test_split: Three values of which the first two are used as
      fractions of the dataset size. Despite the None default it is unpacked
      unconditionally, so it must be supplied.
    return_raw_data: Accepted but not used.

  Returns:
    (train_series, val_series, test_series), where the test set receives
    whatever remains after the train and validation sets.

    NOTE(review): the `-> TimeSeries` annotation does not describe this
    tuple return.
  """
  dataset_config = config['dataset']
  _dataset_name = dataset_config['name']  # read but not used further
  proc_noise = dataset_config['data_latent_sigma']
  window = dataset_config['seq_length']

  all_series = get_raw_harmonic_oscillator_data2(proc_noise, window, key=key)

  # The third entry is ignored: the test set size is implied by the other two.
  train_frac, val_frac, _test_frac = train_val_test_split

  total = all_series.batch_size
  n_train = int(total * train_frac)
  n_val = int(total * val_frac)
  n_fit = n_train + n_val

  # Shuffle before splitting.
  order = jax.random.permutation(key, jnp.arange(total))
  shuffled = all_series[order]

  fit_series = shuffled[:n_fit]
  return fit_series[:n_train], fit_series[n_train:], shuffled[n_fit:]

# Interactive scratch script: it builds an experiment, loads its data and an
# empty model, fits a linear SDE to pendulum data and samples from it,
# stopping in pdb twice along the way.
if __name__ == '__main__':
  # NOTE(review): plt, tqdm, AbstractModel and TAGS are imported but not
  # referenced below; the star imports may supply names such as
  # ConditionedLinearSDE -- confirm before pruning.
  import matplotlib.pyplot as plt
  from debug import *
  import tqdm
  from diffusion_crf import *
  from diffusion_crf.sde.sde_base import max_likelihood_ltisde
  # Enable JAX's 64-bit mode.
  jax.config.update('jax_enable_x64', True)

  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model, AbstractModel
  from diffusion_crf import TAGS
  # Alternative experiment (noisy double pendulum), kept for reference:
  # ei = ExperimentIdentifier.make_experiment_id(config_name='noisy_double_pendulum',
  #                                             objective=None,
  #                                             model_name='my_autoregressive',
  #                                             sde_type='tracking',
  #                                             freq=0,
  #                                             group='no_leakage_latent_forecasting',
  #                                             seed=0)
  ei = ExperimentIdentifier.make_experiment_id(config_name='lorenz',
                                              objective=None,
                                              model_name='my_autoregressive_rnn',
                                              sde_type='tracking',
                                              freq=0,
                                              group='asdf',
                                              seed=0)
  # Load the experiment's data; only `series` (from the test split) is derived
  # here, and it is not referenced again below.
  datasets = ei.get_data_fixed()
  train_data, val_data, test_data = datasets['train_data'], datasets['val_data'], datasets['test_data']
  series = test_data[0][:20]

  # Only the encoder of the (empty) model is used further down.
  dummy_model = load_empty_model(ei)
  encoder = dummy_model.encoder


  # NOTE(review): the same key is reused for data generation, the initial
  # state draw and SDE sampling below.
  key = random.PRNGKey(0)
  config = dict(dataset_size=1000, seq_length=64, sample_rate=8, noise_std=0.3)
  # With return_vector_fields=True the call returns (noisy windows, denoised
  # windows, vector-field windows); the denoised windows are discarded.
  dataset1, _, dataset2 = get_improved_pendulum_dataset(config, key, return_vector_fields=True)
  # Concatenate noisy observations and vector fields along the last axis.
  dataset = TimeSeries(ts=dataset1.ts, yts=jnp.concatenate([dataset1.yts, dataset2.yts], axis=-1))
  import pdb; pdb.set_trace()  # breakpoint: inspect the combined dataset

  # Time step taken from the first consecutive difference of the timestamps.
  dt = jnp.diff(dataset.ts, axis=-1).ravel()[0]
  sde = max_likelihood_ltisde(dataset.yts, dt)

  # A single random starting point of shape (1, 4) at time t0.
  t0 = 0.0
  x0 = random.normal(key, (1, 4))
  initial_point = TimeSeries(ts=jnp.array([t0]), yts=x0)

  # Encode the starting point, condition the fitted SDE on it and sample on
  # 1000 evenly spaced times in [0, 5]. `parallel` is True only when the
  # first JAX device is a GPU.
  prob_series = encoder(initial_point)
  cond_sde = ConditionedLinearSDE(sde, prob_series, parallel=jax.devices()[0].platform == 'gpu')
  sampled_series = cond_sde.sample(key, jnp.linspace(0.0, 5.0, 1000))

  import pdb; pdb.set_trace()  # breakpoint: inspect the sampled series