import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, List
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
import pandas as pd
import numpy as np
from diffusion_crf import TimeSeries, HarmonicOscillator, GaussianTransition, OrnsteinUhlenbeck
from diffusion_crf.sde.sde_base import AbstractSDE, AbstractLinearTimeInvariantSDE
from diffusion_crf.matrix import DenseMatrix, DiagonalMatrix, TAGS, Diagonal2x2BlockMatrix
from diffusion_crf import AbstractLinearSDE, ConditionedLinearSDE, PaddingLatentVariableEncoderWithPrior


def create_harmonic_oscillator_data(key: PRNGKeyArray,
                                    *,
                                    observation_noise: float,
                                    process_noise: float,
                                    n_trajectories: int = 1_000_000,
                                    save_folder: str = 'Data/harmonic_oscillator/harmonic_oscillator_data'):
  """Sample conditioned harmonic-oscillator trajectories and save them to disk.

  The steps are:

  1. Build a ``HarmonicOscillator`` SDE with frequency 1.0, friction
     coefficient 0.0 and observation_dim 1.
  2. Condition it, via ``ConditionedLinearSDE``, on potentials built from 20
     samples of ``cos(t)`` on ``[0, 20]``.
  3. Draw ``n_trajectories`` paths from the conditioned SDE, each evaluated at
     1000 evenly spaced times on the same interval.
  4. Keep only the first latent coordinate of each path.

  Args:
    key: PRNG key. It is split into one sub-key per trajectory.
    observation_noise: ``sigma`` of the ``PaddingLatentVariableEncoderWithPrior``
      that turns the ``cos(t)`` points into potentials. Also encoded in the
      output file name.
    process_noise: Noise argument forwarded to ``HarmonicOscillator``. Also
      encoded in the output file name.
    n_trajectories: Number of trajectories to sample.
    save_folder: Output directory. It is created if it does not exist.

  Side effects:
    Writes ``harmonic_oscillator_data_obs_noise_<o>_process_noise_<p>.npz``
    (arrays ``ts`` and ``yts``) into ``save_folder`` and prints its path.
    Noise values are formatted with 6 decimals. Trailing zeros and any
    trailing '.' are stripped, then '.' is replaced by '_' (e.g. 0.25 -> ``0_25``).
  """

  def noise_token(value):
    # Fixed 6-decimal representation, trimmed and made filename-safe.
    token = f"{value:.6f}".rstrip('0').rstrip('.')
    return token.replace('.', '_')

  start, stop = 0.0, 20.0

  # Frequency 1.0, no friction, 1-D observation.
  frequency, friction, obs_dim = 1.0, 0.0, 1
  oscillator = HarmonicOscillator(frequency, friction, process_noise, obs_dim)

  # Conditioning data: y = cos(t) at 20 evenly spaced times, one potential per point.
  anchor_ts = jnp.linspace(start, stop, 20)
  anchor_series = TimeSeries(ts=anchor_ts, yts=jnp.cos(anchor_ts)[:, None])
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=1, x_dim=2, sigma=observation_noise)
  conditioned_sde = ConditionedLinearSDE(oscillator, encoder(anchor_series))

  grid = jnp.linspace(start, stop, 1000)

  def draw(subkey):
    path = conditioned_sde.sample(subkey, grid)
    # Drop every latent coordinate except the first.
    return TimeSeries(ts=path.ts, yts=path.yts[..., :1])

  subkeys = random.split(key, n_trajectories)
  # lax.map evaluates `draw` over the keys in chunks of 2048 (vmapped per chunk).
  samples = jax.lax.map(draw, subkeys, batch_size=2048)

  file_name = (f'harmonic_oscillator_data_obs_noise_{noise_token(observation_noise)}'
               f'_process_noise_{noise_token(process_noise)}.npz')

  os.makedirs(save_folder, exist_ok=True)
  path_name = os.path.join(save_folder, file_name)
  print(f"Saving data to: {path_name}")
  np.savez(path_name, ts=samples.ts, yts=samples.yts)

def get_raw_harmonic_oscillator_data(observation_noise: float,
                                process_noise: float,
                                seq_length: Union[int, None],
                                key: Optional[PRNGKeyArray] = None,
                                save_folder: str = 'Data/harmonic_oscillator/harmonic_oscillator_data'):
  """Load a dataset written by ``create_harmonic_oscillator_data``, optionally downsampled.

  Args:
    observation_noise: Observation noise the dataset was generated with. It is
      used only to build the file name.
    process_noise: Process noise the dataset was generated with. It is used
      only to build the file name.
    seq_length: If ``None``, the full series is returned. Otherwise, the number
      of evenly spaced time indices to keep, always including the first and
      last stored points.
    key: Unused. Kept so existing callers that pass it keep working.
    save_folder: Directory containing the ``.npz`` file.

  Returns:
    A ``TimeSeries`` built from the stored ``ts``/``yts`` arrays. When
    ``seq_length`` is given, it is indexed as ``series[:, indices]``.

  Raises:
    ValueError: If ``seq_length`` is less than 1, or larger than the number of
      stored time points (subsampling would otherwise repeat indices and
      produce duplicate time stamps).
    FileNotFoundError: If the expected ``.npz`` file does not exist.
  """
  # Validate before touching the filesystem so bad arguments fail fast.
  if seq_length is not None and seq_length < 1:
    raise ValueError(f"seq_length must be a positive integer or None, got {seq_length}")

  # Format floating point numbers consistently for filenames
  # (must match the naming used when the data was saved).
  obs_noise_str = f"{observation_noise:.6f}".rstrip('0').rstrip('.')
  obs_noise_str = obs_noise_str.replace('.', '_')

  proc_noise_str = f"{process_noise:.6f}".rstrip('0').rstrip('.')
  proc_noise_str = proc_noise_str.replace('.', '_')

  save_path = os.path.join(save_folder, f'harmonic_oscillator_data_obs_noise_{obs_noise_str}_process_noise_{proc_noise_str}.npz')

  # np.load on an .npz returns a lazily-read NpzFile that keeps the file open.
  # Copy the arrays out inside a `with` so the handle is always closed.
  with np.load(save_path) as data:
    ts = jnp.array(data['ts'])
    yts = jnp.array(data['yts'])
  data_series = TimeSeries(ts=ts, yts=yts)

  print(f'Loaded data from {save_path}')

  if seq_length is None:
    return data_series

  n_points = len(data_series[0])
  if seq_length > n_points:
    raise ValueError(f"seq_length ({seq_length}) exceeds the number of stored time points ({n_points})")

  # Subsample the data to get the desired time interval between observations.
  # Evenly spaced indices (truncated to int) always include both endpoints.
  indices = jnp.linspace(0, n_points - 1, seq_length, dtype=int)
  downsampled_data_series = data_series[:,indices]

  assert len(downsampled_data_series[0]) == seq_length, f"Downsampled data series has length {len(downsampled_data_series[0])} instead of {seq_length}"

  return downsampled_data_series

def get_experiment_param_ranges():
  """Return the hyper-parameter grid swept by the harmonic-oscillator experiments.

  Returns:
    A dict with three keys, each mapping to a 1-D ``jnp`` array:

    - ``observation_noise_range``: 0.01, 0.5, 1.0
    - ``process_noise_range``: 0.01, 0.25, 0.5
    - ``seq_length_range``: 4, 8, 16

  Alternative grids previously left in this function as commented-out code:

  - observation noise ``[0.01, 0.5, 1.0]``
  - process noise ``[0.0, 0.5, 1.0]`` or ``[0.01, 0.25, 0.5, 0.75, 1.0]``
  - sequence lengths ``[25, 50, 100, 200]``
  """
  return {
      'observation_noise_range': jnp.array([0.01, 0.5, 1.0]),
      'process_noise_range': jnp.array([0.01, 0.25, 0.5]),
      'seq_length_range': jnp.array([4, 8, 16]),
  }

if __name__ == '__main__':
  # Interactive/debugging helpers. Several are not referenced below. They are
  # kept because `from debug import *` may have import-time side effects.
  import matplotlib.pyplot as plt
  from debug import *
  import tqdm
  from generax.util import tree_shapes
  # Uncomment to run in 64-bit precision:
  # jax.config.update('jax_enable_x64', True)

  key = random.PRNGKey(0)
  sweep = get_experiment_param_ranges()
  obs_grid = sweep['observation_noise_range']
  proc_grid = sweep['process_noise_range']

  # NOTE(review): the same `key` is used for every (observation_noise,
  # process_noise) pair, so every dataset is sampled from identical sub-keys.
  # Confirm this is intentional.
  for obs_noise in obs_grid:
    for proc_noise in proc_grid:
      # Generate and save a 10k-trajectory dataset for this configuration ...
      create_harmonic_oscillator_data(key,
                                      observation_noise=obs_noise,
                                      process_noise=proc_noise,
                                      n_trajectories=10_000)

      # ... then reload it downsampled to 98 time points (result discarded).
      # `key` is accepted but not used by the loader.
      get_raw_harmonic_oscillator_data(observation_noise=obs_noise,
                                       process_noise=proc_noise,
                                       seq_length=98,
                                       key=key)
