import jax
import jax.numpy as jnp
from jax import random
from functools import partial
from typing import Optional, Mapping, Tuple, Sequence, Union, Any, Callable, Dict, List
import einops
import equinox as eqx
from abc import ABC, abstractmethod
import diffrax
from jaxtyping import Array, PRNGKeyArray
import jax.tree_util as jtu
import os
from jaxtyping import Array, PRNGKeyArray, Float, Scalar, Bool
from jax._src.util import curry
import pandas as pd
import numpy as np
from diffusion_crf.sde.ode_sde_solve import ODESolverParams, ode_solve
from diffusion_crf import TimeSeries

"""Taken from https://scipython.com/blog/the-double-pendulum/"""

@curry
def vector_field(L1, L2, m1, m2, t, xt):
  """Time derivative of the double-pendulum state.

  The equations follow https://scipython.com/blog/the-double-pendulum/.
  Because of ``@curry`` the first call only binds arguments; in this file it
  is used as ``vector_field(L1, L2, m1, m2)`` to obtain a callable that is
  then invoked with ``(t, xt)``.

  Args:
    L1: Length of the first rod.
    L2: Length of the second rod.
    m1: Mass of the first bob.
    m2: Mass of the second bob.
    t: Time.  Not used by the equations.
    xt: State indexed as ``[theta1, z1, theta2, z2]``; ``z1`` and ``z2`` are
      the time derivatives of the angles.

  Returns:
    ``jnp.array([theta1dot, z1dot, theta2dot, z2dot])``.
  """
  gravity = 9.81
  angle1, omega1, angle2, omega2 = xt[0], xt[1], xt[2], xt[3]

  # Every coupling term depends on the angle difference between the two rods.
  delta = angle1 - angle2
  cos_d = jnp.cos(delta)
  sin_d = jnp.sin(delta)
  sin1 = jnp.sin(angle1)
  sin2 = jnp.sin(angle2)

  total_mass = m1 + m2
  # Denominator shared by both angular accelerations.
  shared_denom = m1 + m2*sin_d**2

  accel1_numer = (m2*gravity*sin2*cos_d
                  - m2*sin_d*(L1*omega1**2*cos_d + L2*omega2**2)
                  - total_mass*gravity*sin1)
  accel2_numer = (total_mass*(L1*omega1**2*sin_d - gravity*sin2 + gravity*sin1*cos_d)
                  + m2*L2*omega2**2*sin_d*cos_d)

  accel1 = accel1_numer / L1 / shared_denom
  accel2 = accel2_numer / L2 / shared_denom
  return jnp.array([omega1, accel1, omega2, accel2])

def create_raw_pendulum_data(key,
                             t0: float = 0.0,
                             t1: float = 100.0,
                             n_points: int = 1_000_000,
                             length1: float = 1.0,
                             length2: float = 1.0,
                             mass1: float = 1.0,
                             mass2: float = 1.0,
                             save_path: str = 'data/pendulum/pendulum_data.npz'):
  """Simulate a double pendulum and save its Cartesian trajectory as an ``.npz``.

  The ODE defined by ``vector_field`` is solved from a fixed initial state on a
  uniform time grid with spacing ``(t1 - t0) / n_points``.  The grid holds
  ``n_points + int(0.1 * n_points)`` samples; the leading 10% are discarded as
  burn-in, so the saved timestamps end at roughly ``t0 + 1.1 * (t1 - t0)``
  rather than at ``t1``.

  The archive contains:
    * ``ts``: the retained save times.
    * ``yts``: Cartesian positions ``[x1, y1, x2, y2]`` of the two bobs.
    * ``vfs``: Cartesian velocities, obtained by pushing the angular
      velocities through the angle -> Cartesian map with ``jax.jvp``.
    * ``length1``, ``length2``, ``mass1``, ``mass2``: the physical parameters.

  Args:
    key: PRNG key.  Currently unused, because the randomised save-time
      sampling is disabled; kept so existing callers keep working.
    t0: Start time of the time grid.
    t1: Sets the grid spacing together with ``t0`` and ``n_points``.
    n_points: Number of samples retained after burn-in.
    length1: Length of the first rod.
    length2: Length of the second rod.
    mass1: Mass of the first bob.
    mass2: Mass of the second bob.
    save_path: Destination ``.npz`` file.  Missing parent directories are
      created.

  Returns:
    None.  The result is written to ``save_path``.
  """
  # np.savez does not create directories.  Create the destination up front so
  # a missing folder cannot throw away the result of a long solve.
  out_dir = os.path.dirname(save_path)
  if out_dir:
    os.makedirs(out_dir, exist_ok=True)

  x0 = jnp.array([3*jnp.pi/7, 0, 3*jnp.pi/4, 0])
  vf = vector_field(length1, length2, mass1, mass2)
  # NOTE(review): the meaning of these options is defined by
  # diffusion_crf.sde.ode_sde_solve and is not visible in this file.
  # Alternatives previously tried here: stepsize_controller='pid',
  # max_steps=10_000_000, solver='kvaerno5'.
  solver_params = ODESolverParams(stepsize_controller='constant',
                                  rtol=1e-10,
                                  atol=1e-10,
                                  max_steps=10_000,
                                  solver='euler',
                                  adjoint='direct',
                                  progress_meter='tqdm')

  # Uniform save times.  A randomised variant was
  #   random.uniform(key, (n_points + burn_in,), minval=0.0, maxval=1.0)
  # whose average interval (after the factor 2 below) is also 1.0.
  burn_in = int(0.1*n_points)
  save_time_intervals = jnp.ones(n_points + burn_in)*0.5
  save_times2 = jnp.cumsum(save_time_intervals)*2  # 1, 2, ..., n_points + burn_in
  save_times = save_times2/n_points*(t1 - t0) + t0

  print('Solving the pendulum ODE...')
  yts = ode_solve(vf, x0, save_times, solver_params).yts
  print('Done!')

  # Drop the burn-in prefix.
  ts = save_times[burn_in:]
  yts = yts[burn_in:]

  # Also compute the vector field at the retained times.
  vf_values = jax.vmap(vf)(ts, yts)

  def to_cartesian(thetas):
    """Map the two angles to the Cartesian positions of both bobs."""
    theta1, theta2 = thetas[0], thetas[1]
    x1 = length1*jnp.sin(theta1)
    y1 = -length1*jnp.cos(theta1)
    x2 = x1 + length2*jnp.sin(theta2)
    y2 = y1 - length2*jnp.cos(theta2)
    return jnp.stack([x1, y1, x2, y2], axis=-1)

  def pushforward(state, state_dot):
    """Return (Cartesian position, Cartesian velocity) for a single state."""
    # Entries 0 and 2 of the state are the angles; the same entries of the
    # vector field are their time derivatives.
    angle_idx = jnp.array([0, 2])
    return jax.jvp(to_cartesian, (state[angle_idx],), (state_dot[angle_idx],))

  yts_cartesian, vfs_cartesian = jax.vmap(pushforward)(yts, vf_values)

  save_items = dict(ts=np.array(ts),
                    yts=np.array(yts_cartesian),
                    vfs=np.array(vfs_cartesian),
                    length1=length1,
                    length2=length2,
                    mass1=mass1,
                    mass2=mass2)

  # Save the data to a file
  print(f"Saving data to: {save_path}")
  np.savez(save_path, **save_items)


def get_raw_pendulum_data(num_samples: int = 10_000,
                          noise_std: float = 0.0,
                          noisy: bool = False,
                          key: Optional[PRNGKeyArray] = None):
  """Load the pendulum data set and return its trailing observations.

  Reads ``Data/pendulum/pendulum_data.npz``, which must contain the arrays
  ``ts``, ``yts`` and ``observation_mask``.  Note that this is not the set of
  keys written by ``create_raw_pendulum_data`` in this file.

  Args:
    num_samples: Size of the trailing slice taken from the series.
    noise_std: Standard deviation of the Gaussian noise added when ``noisy``
      is True.
    noisy: Whether to add noise to the observations.
    key: PRNG key used to draw the noise; required when ``noisy`` is True.

  Returns:
    A NumPy array holding the ``yts`` of the trailing slice.

  Raises:
    ValueError: If ``noisy`` is True but no ``key`` was given.
  """
  if noisy and key is None:
    raise ValueError('key must be provided if noisy is True')

  save_path = 'Data/pendulum/pendulum_data.npz'
  # The context manager closes the archive's file handle.  Indexing an NpzFile
  # reads each array into memory, so they stay valid after the block.
  with np.load(save_path) as archive:
    ts = archive['ts']
    yts = archive['yts']
    observation_mask = archive['observation_mask']

  if noisy:
    yts = yts + noise_std*random.normal(key, yts.shape)

  series = TimeSeries(ts=jnp.array(ts),
                      yts=jnp.array(yts),
                      observation_mask=jnp.array(observation_mask))

  return np.array(series[-num_samples:].yts)

def get_raw_improved_pendulum_data(max_length: Optional[int] = None,
                                   sample_rate: Optional[int] = 100,
                                   noise_std: Optional[float] = None,
                                   key: Optional[PRNGKeyArray] = None,
                                   save_path: str = 'Data/pendulum/pendulum_50_000_seconds_64_hz.npz',
                                   native_sample_rate: Optional[int] = None):
  """
  Get the raw improved pendulum data.

  Loads ``ts``, ``yts`` and ``vfs`` from ``save_path``, downsamples them by an
  integer stride, optionally keeps only the trailing ``max_length`` samples and
  optionally adds Gaussian noise to the positions.

  Args:
    max_length: The maximum length of the data to return (taken from the end).
      ``None`` keeps everything.
    sample_rate: The desired sample rate of the returned data (in Hz).  The
      stride is ``native_sample_rate // sample_rate`` (at least 1), so the
      request is only met exactly when it divides the native rate.  ``None``
      disables downsampling.
    noise_std: Standard deviation of the noise added to the positions.
      ``None`` adds no noise.
    key: PRNG key for the noise; required when ``noise_std`` is not None.
    save_path: Path of the ``.npz`` file to load.
    native_sample_rate: Sample rate of the stored data (in Hz).  When ``None``
      it is inferred from the mean spacing of the stored timestamps, so files
      recorded at 64 Hz or 100 Hz are both handled correctly.

  Returns:
    A tuple ``(full_data_series, full_original_series, full_vfs_series)``:
    the (possibly noisy) observations, the noise-free positions and the
    velocities, all on the same time grid.

  Raises:
    ValueError: If ``noise_std`` is given without ``key``, if ``sample_rate``
      or ``native_sample_rate`` is not positive, or if the native rate cannot
      be inferred because the timestamps are not increasing.
  """
  if noise_std is not None and key is None:
    raise ValueError('key must be provided if noise_std is not None')
  if sample_rate is not None and sample_rate <= 0:
    raise ValueError('sample_rate must be positive')
  if native_sample_rate is not None and native_sample_rate <= 0:
    raise ValueError('native_sample_rate must be positive')

  # Load the data from the file.  The context manager closes the archive;
  # the arrays are read into memory on access.
  with np.load(save_path) as archive:
    ts, yts, vfs = archive['ts'], archive['yts'], archive['vfs']

  # Work out the integer stride that takes the stored rate to sample_rate.
  if sample_rate is None or len(ts) < 2:
    downsample_factor = 1
  else:
    if native_sample_rate is None:
      # Infer the stored rate from the data rather than assuming 100 Hz.
      mean_dt = float(ts[-1] - ts[0]) / (len(ts) - 1)
      if mean_dt <= 0:
        raise ValueError('timestamps must be increasing to infer the sample rate')
      native_sample_rate = round(1.0 / mean_dt)
    # Floor division matches the previous behaviour for 100 Hz data; the
    # clamp avoids a zero slice step when sample_rate exceeds the native rate.
    downsample_factor = max(1, int(native_sample_rate // sample_rate))

  full_original_series = TimeSeries(ts=jnp.array(ts), yts=jnp.array(yts))
  full_vfs_series = TimeSeries(ts=jnp.array(ts), yts=jnp.array(vfs))

  full_original_series = full_original_series[::downsample_factor]
  full_vfs_series = full_vfs_series[::downsample_factor]

  # Take the last max_length
  if max_length is not None:
    full_original_series = full_original_series[-max_length:]
    full_vfs_series = full_vfs_series[-max_length:]

  # Add noise to the data if noise_std is not None
  ts = full_original_series.ts
  samples = full_original_series.yts
  if noise_std is not None:
    noisy_samples = samples + noise_std*random.normal(key, samples.shape)
  else:
    noisy_samples = samples
  full_data_series = TimeSeries(ts=ts, yts=noisy_samples)

  return full_data_series, full_original_series, full_vfs_series

################################################################################################################

def create_pendulum_gif(series,
                        index=0,
                        output_path='pendulum.gif',
                        fps=30,  # Set default to standard 30fps
                        time_scale=1.0,  # Speed multiplier
                        trail_secs=1,
                        length1=1.0,
                        length2=1.0,
                        vf_series=None,
                        arrow_scale=0.1,
                        show_time=False):  # Option to display time
  """
  Create a gif animation of a double pendulum from a TimeSeries object.

  The animation is meant to last ``(ts[-1] - ts[0]) / time_scale`` seconds.
  If the series has more frames than ``fps`` allows for that duration, frames
  are subsampled evenly and played at ``fps``.  If it has fewer, every frame
  is kept and the frame rate is lowered so that ``time_scale`` is still
  honoured.

  Args:
    series: TimeSeries object containing pendulum state.  Columns 0-3 of
      ``series.yts`` are read as ``x1, y1, x2, y2``.
    index: Index if the series is batched
    output_path: Path where the gif will be saved
    fps: Target (maximum) frames per second for the output
    time_scale: Time speed multiplier (2.0 = twice as fast)
    trail_secs: Length of the trailing effect in seconds
    length1: Length of the first pendulum rod
    length2: Length of the second pendulum rod
    vf_series: Optional TimeSeries object containing velocity vectors.
      Columns 2 and 3 of its ``yts`` are drawn as the velocity of the
      second bob.
    arrow_scale: Scaling factor for velocity arrows
    show_time: Whether to display current time in the animation

  Returns:
    ``output_path``.

  Raises:
    ValueError: If ``fps`` or ``time_scale`` is not positive, if the series is
      empty, or if ``vf_series`` has a different length than ``series``.
  """
  import matplotlib.pyplot as plt
  import numpy as np
  from matplotlib.patches import Circle
  from matplotlib.animation import FuncAnimation

  if fps <= 0:
    raise ValueError('fps must be positive')
  if time_scale <= 0:
    raise ValueError('time_scale must be positive')

  # Handle batched series if necessary
  if series.batch_size is not None:
    series = series[index]

  # Handle velocity field series if provided
  has_velocity = vf_series is not None
  if has_velocity and vf_series.batch_size is not None:
    vf_series = vf_series[index]

  # Extract data from TimeSeries
  ts = np.array(series.ts)
  yts = np.array(series.yts)
  vf_data = np.array(vf_series.yts) if has_velocity else None

  n_frames = len(ts)
  if n_frames == 0:
    raise ValueError('series must contain at least one time point')
  if has_velocity and len(vf_data) != n_frames:
    raise ValueError('vf_series must have the same length as series')

  # Requested playback duration and the frame budget it implies at `fps`.
  total_time_span = ts[-1] - ts[0]
  target_duration = total_time_span / time_scale
  # At least one frame, so very short series do not become empty arrays.
  target_frames = max(1, int(fps * target_duration))

  if n_frames > target_frames:
    # Too many frames: subsample evenly and play at the requested fps.
    sample_indices = np.linspace(0, n_frames-1, target_frames, dtype=int)
    ts = ts[sample_indices]
    yts = yts[sample_indices]
    if has_velocity:
      vf_data = vf_data[sample_indices]
    actual_fps = fps
  elif target_duration > 0:
    # Too few frames: keep them all and slow the frame rate so the playback
    # speed still matches time_scale (never faster than `fps`).
    actual_fps = min(fps, n_frames / target_duration)
  else:
    actual_fps = fps

  # Extract position data
  x1, y1, x2, y2 = yts[:, 0], yts[:, 1], yts[:, 2], yts[:, 3]

  # Extract velocity data if provided
  if has_velocity:
    vx2, vy2 = vf_data[:, 2], vf_data[:, 3]  # Velocity of second pendulum bob
    v_mags = np.sqrt(vx2**2 + vy2**2)
    max_v_mag = np.max(v_mags)
    median_v_mag = np.median(v_mags)
    # Frame-independent arrow scaling, computed once.
    if max_v_mag > 0:
      scale_factor = arrow_scale * min(1.0, 2.0 * median_v_mag / max_v_mag)
    else:
      scale_factor = 0.0
    # A reasonable maximum arrow length relative to pendulum size
    max_arrow_length = 0.25 * (length1 + length2)

  # Plotted bob circle radius
  r = 0.05 * max(length1, length2)
  axis_limit = length1 + length2 + r

  # Trail length in frames, using the spacing of the (possibly subsampled) ts
  dt = (ts[1] - ts[0]) if len(ts) > 1 else 0.01
  max_trail = int(trail_secs / dt) if dt > 0 else 0
  ns = 20  # Number of trail segments
  s = max(1, max_trail // ns)  # Frames per trail segment

  # Create figure
  fig, ax = plt.subplots(figsize=(8, 8))

  def setup_axes():
    """Apply the fixed limits/aspect; needed again after every ax.clear()."""
    ax.set_xlim(-axis_limit, axis_limit)
    ax.set_ylim(-axis_limit, axis_limit)
    ax.set_aspect('equal', adjustable='box')
    ax.axis('off')

  def init():
    setup_axes()
    return []

  def animate(i):
    ax.clear()
    setup_axes()

    # Draw pendulum rods
    ax.plot([0, x1[i], x2[i]], [0, y1[i], y2[i]], lw=2, c='k')

    # Draw the pivot and the two bobs
    ax.add_patch(Circle((0, 0), r/2, fc='k', zorder=10))
    ax.add_patch(Circle((x1[i], y1[i]), r, fc='b', ec='b', zorder=10))
    ax.add_patch(Circle((x2[i], y2[i]), r, fc='r', ec='r', zorder=10))

    # Draw trail with fading effect: older segments get a smaller alpha
    for j in range(ns):
      imin = i - (ns-j)*s
      if imin < 0:
        continue
      imax = min(i+1, imin + s + 1)
      alpha = (j/ns)**2
      ax.plot(x2[imin:imax], y2[imin:imax], c='r', solid_capstyle='butt',
              lw=2, alpha=alpha)

    # Display current time if requested
    if show_time:
      ax.text(0.02, 0.02, f"t = {ts[i]:.2f}s", transform=ax.transAxes,
              fontsize=10, bbox=dict(facecolor='white', alpha=0.7))

    # Draw velocity vector if provided
    if has_velocity and v_mags[i] > 0:
      arrow_dx = vx2[i] * scale_factor
      arrow_dy = vy2[i] * scale_factor

      # Cap the arrow length
      arrow_length = np.sqrt(arrow_dx**2 + arrow_dy**2)
      if arrow_length > max_arrow_length:
        arrow_dx = arrow_dx * max_arrow_length / arrow_length
        arrow_dy = arrow_dy * max_arrow_length / arrow_length

      # Draw arrow from second bob position
      ax.arrow(x2[i], y2[i],
               arrow_dx, arrow_dy,
               head_width=0.05, head_length=0.1,
               fc='g', ec='g',
               length_includes_head=True,
               zorder=11)

    return []

  # Calculate frame interval in milliseconds
  interval = 1000 / actual_fps

  try:
    # Create animation
    anim = FuncAnimation(fig, animate, frames=len(ts),
                         init_func=init, blit=True, interval=interval)

    # Save as gif
    anim.save(output_path, writer='pillow', fps=actual_fps, dpi=80)
  finally:
    # Release the figure even if writing the gif fails.
    plt.close(fig)

  print(f"Animation saved to {output_path}")
  print(f"Original time span: {total_time_span:.2f} seconds")
  print(f"Animation duration: {len(ts)/actual_fps:.2f} seconds at {actual_fps:.1f} fps")
  print(f"Speed multiplier: {time_scale:.1f}x")

  return output_path





# Script entry point: (1) generate and save a pendulum data set, then -- in a
# part that is currently unreachable because of the exit() call below --
# (2) load a noisy window of it and smooth it with a tracking model.
if __name__ == '__main__':
  import matplotlib.pyplot as plt
  from debug import *
  import tqdm
  from diffusion_crf import *
  # turn on x64
  jax.config.update('jax_enable_x64', True)

  # Generate T seconds of data at sampling_freq samples per second.
  key = random.PRNGKey(0)
  T = 50_000
  sampling_freq = 64 # How many times per second we sample
  _ = create_raw_pendulum_data(key,
                                  t0=0.0,
                                  t1=T,
                                  n_points=sampling_freq*T,
                                  save_path='data/pendulum/pendulum_50_000_seconds_64_hz.npz')

  # NOTE: the script stops here.  Everything below only runs if this exit()
  # is removed.
  exit()

  # Load a short, noisy window of the saved trajectory.
  seq_len = 64
  sample_rate = 8
  noise_std = 0.3
  out = get_raw_improved_pendulum_data(max_length=seq_len,
                                       sample_rate=sample_rate,
                                       noise_std=noise_std,
                                       key=key,
                                       save_path='data/pendulum/pendulum_50_000_seconds_64_hz.npz')
  full_data_series, full_original_series, full_vfs_series = out
  # Ground truth: noise-free positions concatenated with the velocities.
  full_ground_truth_yts = jnp.concatenate([full_original_series.yts, full_vfs_series.yts], axis=-1)
  full_ground_truth_series = TimeSeries(ts=full_original_series.ts, yts=full_ground_truth_yts)

  # Observations: noisy positions, with zeros in place of the velocities.
  full_observed_yts = jnp.concatenate([full_data_series.yts, jnp.zeros_like(full_vfs_series.yts)], axis=-1)
  full_observed_series = TimeSeries(ts=full_data_series.ts, yts=full_observed_yts)

  # import pdb; pdb.set_trace()

  # create_pendulum_gif(full_original_series,
  #                     vf_series=full_vfs_series,
  #                     output_path='ground_truth.gif',
  #                     time_scale=0.1)

  # Try inferring the velocity and smoothed position using a tracking model
  # NOTE(review): HigherOrderTrackingModel, PaddingLatentVariableEncoderWithPrior
  # and ConditionedLinearSDE are not imported by name; presumably they come
  # from the star imports above -- confirm.
  sigmas = [4.0]
  # sigmas = jnp.linspace(2.0, 6.0, 10)
  all_series = [full_observed_series, full_ground_truth_series]
  titles = ['observed', 'true']
  for sigma in sigmas:
    y_dim = full_data_series.yts.shape[-1]
    sde = HigherOrderTrackingModel(sigma=sigma, position_dim=y_dim, order=2)
    encoder = PaddingLatentVariableEncoderWithPrior(y_dim=y_dim, x_dim=sde.dim, sigma=noise_std)

    # Condition the SDE on the encoded observations and draw one sample at
    # the observation times.
    prob_series = encoder(full_data_series)
    conditioned_sde = ConditionedLinearSDE(sde, prob_series)
    sampled_series = conditioned_sde.sample(key, full_data_series.ts)
    all_series.append(sampled_series)
    titles.append(f'sigma={sigma}')

  TimeSeries.plot_multiple_series(all_series,
                                  titles=titles,
                                  marker_size=3,
                                  width_per_series=3,
                                    height_per_dim=1.5)


  # import pdb; pdb.set_trace()
  # Split the sample from the last sigma into its first y_dim channels and
  # the remaining channels.
  pred_position_series = TimeSeries(ts=sampled_series.ts, yts=sampled_series.yts[..., :y_dim])
  pred_velocity_series = TimeSeries(ts=sampled_series.ts, yts=sampled_series.yts[..., y_dim:])

  # create_pendulum_gif(pred_position_series,
  #                     vf_series=pred_velocity_series,
  #                     output_path='filtered_behavior.gif',
  #                     time_scale=0.1)



  # Drop into the debugger so the results can be inspected interactively.
  import pdb; pdb.set_trace()