import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, List
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
import pandas as pd
import numpy as np
from diffusion_crf import TimeSeries, HarmonicOscillator, GaussianTransition, OrnsteinUhlenbeck
from diffusion_crf.sde.sde_base import AbstractSDE, AbstractLinearTimeInvariantSDE
from diffusion_crf.matrix import DenseMatrix, DiagonalMatrix, TAGS, Diagonal2x2BlockMatrix
from diffusion_crf import AbstractLinearSDE, ConditionedLinearSDE, PaddingLatentVariableEncoderWithPrior

def make_sample(sigma, key):
  """Draw one standardized trajectory from a conditioned harmonic-oscillator SDE.

  A `HarmonicOscillator` (freq=1.0, coeff=0.0) with process noise `sigma` is
  conditioned on a single observation of value 0.0 whose time is drawn
  uniformly between -2 and 2, and the conditioned SDE is then sampled at
  1000 evenly spaced times in [0, 20].

  Args:
    sigma: Process-noise scale forwarded to `HarmonicOscillator`.
    key: JAX PRNG key. It is split internally so that the observation time
      and the SDE sample are drawn from independent streams.  (Earlier
      versions reused `key` for both draws, so trajectories for a given seed
      differ from those versions.)

  Returns:
    A `TimeSeries` with the save times and the sampled values, standardized
    along axis 0 (presumably the time axis -- confirm against
    `ConditionedLinearSDE.sample`) to zero mean and unit standard deviation.
  """
  # Never consume the same key twice: one subkey per random draw.
  obs_time_key, sample_key = random.split(key)

  observation_dim = 1
  linear_sde = HarmonicOscillator(freq=1.0, coeff=0.0, sigma=sigma, observation_dim=observation_dim)
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=1, x_dim=2, sigma=0.01)

  # A single observation (value 0.0) at a random time.
  obs_times = random.uniform(obs_time_key, minval=-2.0, maxval=2.0, shape=(1,))
  obs_values = jnp.array([[0.0]])
  observations = TimeSeries(ts=obs_times, yts=obs_values)
  prob_series = encoder(observations)

  sde = ConditionedLinearSDE(linear_sde, prob_series)

  # Sample a trajectory from the SDE
  save_times = jnp.linspace(0.0, 20.0, 1000)
  xts = sde.sample(sample_key, save_times)

  ts, xts_values = xts.ts, xts.yts

  # Standardize each column independently.
  xts_values = (xts_values - xts_values.mean(axis=0)) / xts_values.std(axis=0)
  return TimeSeries(ts=ts, yts=xts_values)

def create_harmonic_oscillator_data2(key: PRNGKeyArray,
                                    *,
                                    process_noise: float,
                                    n_trajectories: int = 10_000,
                                    save_folder: str = 'Data/harmonic_oscillator/harmonic_oscillator_data2'):
  """Simulate `n_trajectories` trajectories and store them in a single `.npz`.

  Every trajectory comes from `make_sample` (1000 save times each); only the
  first output channel is kept.  The arrays are written under the names
  `ts` and `yts` to
  `<save_folder>/harmonic_oscillator_data_process_noise_<noise>_2.npz`, where
  `<noise>` is `process_noise` printed with up to six decimals, trailing
  zeros removed and the decimal point replaced by `_` (e.g. 0.5 -> `0_5`).

  Args:
    key: PRNG key, split into one subkey per trajectory.
    process_noise: Process-noise level passed to `make_sample` as `sigma`.
    n_trajectories: Number of trajectories to simulate.
    save_folder: Output directory; created if it does not exist.
  """
  def first_channel_sample(sample_key):
    sample = make_sample(sigma=process_noise, key=sample_key)
    return TimeSeries(ts=sample.ts, yts=sample.yts[..., :1])

  # Simulate in chunks of 256 trajectories at a time.
  trajectory_keys = random.split(key, n_trajectories)
  trajectories = jax.lax.map(first_channel_sample, trajectory_keys, batch_size=256)

  # Filename-safe, consistently formatted noise level.
  noise_tag = f"{process_noise:.6f}".rstrip('0').rstrip('.').replace('.', '_')
  file_name = f'harmonic_oscillator_data_process_noise_{noise_tag}_2.npz'

  # Write everything to disk.
  os.makedirs(save_folder, exist_ok=True)
  path_name = os.path.join(save_folder, file_name)
  print(f"Saving data to: {path_name}")
  np.savez(path_name, ts=trajectories.ts, yts=trajectories.yts)

def get_raw_harmonic_oscillator_data2(process_noise: float,
                                seq_length: Union[int, None],
                                key: Optional[PRNGKeyArray] = None,
                                save_folder: str = 'Data/harmonic_oscillator/harmonic_oscillator_data2'):
  """Load a saved trajectory file and optionally subsample its time points.

  The file name is derived from `process_noise` with the same formatting the
  writer uses: up to six decimals, trailing zeros stripped, `.` replaced by
  `_`.

  Args:
    process_noise: Process-noise level identifying which file to load.
    seq_length: Number of time points to keep.  `None` returns the full
      series unchanged.
    key: PRNG key used to choose which time points are kept.  Defaults to
      `random.PRNGKey(0)` when omitted.
    save_folder: Directory containing the `.npz` file.

  Returns:
    A `TimeSeries` built from the stored `ts` and `yts` arrays, restricted
    to `seq_length` randomly chosen (then sorted) indices when `seq_length`
    is given.
  """
  proc_noise_str = f"{process_noise:.6f}".rstrip('0').rstrip('.')
  proc_noise_str = proc_noise_str.replace('.', '_')

  save_path = os.path.join(save_folder, f'harmonic_oscillator_data_process_noise_{proc_noise_str}_2.npz')
  # An .npz archive keeps its file handle open until closed, so read the
  # arrays inside a context manager; jnp.array copies them out.
  with np.load(save_path) as data:
    ts = jnp.array(data['ts'])
    yts = jnp.array(data['yts'])
  data_series = TimeSeries(ts=ts, yts=yts)

  print(f'Loaded data from {save_path}')

  if seq_length is None:
    return data_series

  # Keep a random subset of `seq_length` indices (without replacement).  The
  # same indices are applied along the second axis of every entry selected by
  # the leading axis (presumably trajectories x time -- confirm).
  if key is None:
    key = random.PRNGKey(0)
  indices = random.choice(key, len(data_series[0]), shape=(seq_length,), replace=False)
  indices = jnp.sort(indices)  # Sort to maintain temporal ordering
  downsampled_data_series = data_series[:,indices]

  assert len(downsampled_data_series[0]) == seq_length, f"Downsampled data series has length {len(downsampled_data_series[0])} instead of {seq_length}"

  return downsampled_data_series

def get_experiment_param_ranges():
  """Return the parameter sweeps for the experiments.

  Returns:
    A dict with two JAX arrays:
      * `process_noise_range`: process-noise levels to run.
      * `seq_length_range`: sequence lengths to run.
  """
  # Process noise.  Alternative sweeps (disabled):
  #   [0.1]
  #   [0.01, 0.25, 0.5, 1.0, 2.0, 4.0]
  #   [0.01, 0.25, 0.5, 0.75, 1.0]
  process_noise_values = [0.5]

  # Sequence lengths.  Alternative sweep (disabled):
  #   [2, 4, 6, 10, 15, 20, 30, 51, 60, 68, 85, 102]
  seq_length_values = [4, 8, 16, 32, 64, 128]

  return {
    'process_noise_range': jnp.array(process_noise_values),
    'seq_length_range': jnp.array(seq_length_values),
  }

if __name__ == '__main__':
  # Script entry point: for each configured process-noise level, generate the
  # dataset and then load it back with a 98-point subsample.
  import matplotlib.pyplot as plt
  from debug import *
  import tqdm
  from generax.util import tree_shapes
  # To enable 64-bit precision, uncomment:
  # jax.config.update('jax_enable_x64', True)

  root_key = random.PRNGKey(0)
  noise_levels = get_experiment_param_ranges()['process_noise_range']

  for noise_level in noise_levels:
    # Write the trajectories for this noise level to disk.
    create_harmonic_oscillator_data2(
      root_key,
      process_noise=noise_level,
      n_trajectories=10_000,
    )

    # Read the file back and subsample it.
    get_raw_harmonic_oscillator_data2(
      process_noise=noise_level,
      seq_length=98,
      key=root_key,
    )
