from debug import *
import os
import argparse
import numpy as np
from Utils.io_utils import load_yaml_config
from diffusion_crf import *
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
import jax.random as random
import jax.tree_util as jtu
import optax
from Models.trainer import Trainer as JaxTrainer, TrainingState
import jax.numpy as jnp
import optax
import pickle
import equinox as eqx
import json
from jax._src.util import curry
import tqdm
from main import get_dataset_no_leakage, load_jax_model
from Models.experiment_identifier import ExperimentIdentifier
from Models.empirical_metrics import wasserstein2_distance, compute_univariate_metrics
from Models.models.base import AbstractModel
from Utils.discriminative_metric_jax import discriminative_score_metrics
import pandas as pd
import numpy as np
import os
import datetime
import filelock
import time
import glob
import wadler_lindig as wl
import hashlib

def get_discretization_info(ts: Float[Array, 'T'], freq: int):
  """Build a DiscretizeInfo that inserts `freq` evenly spaced times into every
  interval between consecutive entries of `ts`.

  For the interval ending at ts[i] the inserted times are
  ts[i] + k*(ts[i] - ts[i-1])/(freq + 1) for k = -freq, ..., -1, i.e. the
  interval is cut into freq + 1 equal pieces.  No points are inserted before
  ts[0].

  Args:
    ts: 1-D array of base times.
    freq: number of new times per interval.

  Returns:
    DiscretizeInfo(new_times, ts).
  """
  assert ts.ndim == 1

  # Width of the interval that ends at each base time.  ts[0] has no previous
  # interval, so it borrows the first width; its points are discarded below.
  widths = jnp.diff(ts)
  widths = jnp.concatenate([widths[:1], widths])

  # Steps -freq, ..., -1 (in units of width/(freq + 1)) back from each base time.
  steps = jnp.arange(-freq, 0)
  grid = ts[:, None] + widths[:, None]*steps/(freq + 1)

  # Time-major flattening; the first `freq` points lie before ts[0].
  new_times = grid.ravel()[freq:]
  return DiscretizeInfo(new_times, ts)

def get_random_discretization_info(key: PRNGKeyArray, ts: Float[Array, 'T'], freq: int) -> DiscretizeInfo:
  """Build a DiscretizeInfo with `freq` uniformly random new times per
  interval of `ts`.

  Each of the `freq` draws places one point in every interval
  [ts[i], ts[i+1]) (right-open, for increasing `ts`).  All points are then
  sorted, so for strictly increasing `ts` the points of each interval appear
  together and the intervals appear in order.

  Args:
    key: PRNG key.
    ts: 1-D array of base times.
    freq: number of random times per interval.

  Returns:
    DiscretizeInfo(new_times=<sorted random times>, base_times=ts).
  """
  left_edges = ts[:-1]
  right_edges = ts[1:]
  widths = right_edges - left_edges

  def one_draw(subkey):
    # u lies in [-1, 0): a fraction of the interval measured back from its right edge.
    # One value per base time is drawn; the last one is not used.
    u = random.uniform(subkey, (ts.shape[0],)) - 1.0
    return right_edges + u[:-1]*widths

  draws = jax.vmap(one_draw)(random.split(key, freq))
  return DiscretizeInfo(new_times=jnp.sort(draws.ravel()), base_times=ts)

def get_batched_intermediate_series(model: AbstractModel, key: PRNGKeyArray, yts: TimeSeries, freq: int, debug: bool = False):
  """Sample a latent path at random intermediate times and gather the true
  backward-message statistics at those times.

  The linear SDE of `model` is conditioned on `model.encoder(yts)` and
  discretized at the base times of `yts` plus `freq` random times per
  interval (see `get_random_discretization_info`).  One path is sampled from
  the resulting CRF.

  Args:
    model: model whose `encoder` and `linear_sde` (including `linear_sde.L`)
      are used.
    key: PRNG key.  It is split internally so that the random times and the
      path sample come from independent streams.
    yts: observed series.
    freq: number of random intermediate times per interval.
    debug: if True, drop into pdb before returning.

  Returns:
    (xts_base, batched_xts_new, Jt_beta, LLT, LLT_ht_beta, info):
      xts_base: sampled values at the base times.
      batched_xts_new: sampled values at the new times, rearranged from
        '(T F) ...' to 'F T ...' with F = freq.
      Jt_beta: the `J` of the (backward + node) messages, at the new times.
      LLT: L @ L.T.
      LLT_ht_beta: LLT @ h of those messages, at the new times.
      info: the DiscretizeInfo used.
  """
  # A JAX key must not be reused for independent draws: split it so the random
  # times and the CRF sample are independent.
  times_key, sample_key = random.split(key)

  # Construct the conditioned linear SDE
  prob_series = model.encoder(yts)
  cond_sde = ConditionedLinearSDE(model.linear_sde, prob_series)
  L = model.linear_sde.L

  # Random discretization (`get_discretization_info` is the deterministic alternative)
  info = get_random_discretization_info(times_key, yts.ts, freq)

  # Discretize the SDE at the base + new times
  result = cond_sde.discretize(info.ts, info=info)
  crf = result.crf

  # Sample all of the intermediate and base values
  xts_values = crf.sample(sample_key)
  xts_new = TimeSeries(info.new_times, info.filter_new_times(xts_values))
  xts_base = TimeSeries(info.base_times, info.filter_base_times(xts_values))

  # Rebatch the intermediate samples: (T F) -> F T
  batched_xts_new = jtu.tree_map(lambda x: einops.rearrange(x, '(T F) ... -> F T ...', F=freq), xts_new)

  # True backward + node messages, in mixed and natural parameterizations
  bwd = crf.get_backward_messages()
  beta = (bwd + crf.node_potentials).to_mixed()
  h = beta.to_nat().h

  # Restrict to the intermediate times
  Jt_beta = info.filter_new_times(beta.J)
  ht_beta = info.filter_new_times(h)
  LLT = L@L.T
  LLT_ht_beta = jax.vmap(lambda h_t: LLT@h_t)(ht_beta)

  if debug:
    import pdb; pdb.set_trace()

  return xts_base, batched_xts_new, Jt_beta, LLT, LLT_ht_beta, info

def get_mean_and_h(model: AbstractModel, yts: TimeSeries, key: PRNGKeyArray, freq: int, debug: bool = False):
  """Compare the model's predicted backward messages with the true ones at
  random intermediate times.

  `model.interpolation_freq * freq` random times are drawn per interval with
  `get_batched_intermediate_series`.  All quantities are trimmed so that they
  start at `model.latent_generation_start_index`.

  Args:
    model: model providing `interpolation_freq`,
      `latent_generation_start_index` and `predict_current_backward_messages`.
    yts: observed series.
    key: PRNG key.
    freq: extra upsampling factor applied on top of `model.interpolation_freq`.
    debug: if True, drop into pdb (here and in the helper).

  Returns:
    (intermediate_mut, intermediate_LLT_ht, LLT_ht_beta), all flattened in
    '(T F)' order:
      intermediate_mut: `.mu` of the model's predicted backward messages.
      intermediate_LLT_ht: LLT @ J_beta @ mu at each intermediate time, where
        J_beta is the true message `J` and mu is the model's mean.
      LLT_ht_beta: LLT @ h_beta from the true messages.
  """
  # Number of random intermediate times per interval; this is the size of the
  # batched leading axis everywhere below.
  total_freq = model.interpolation_freq*freq
  start = model.latent_generation_start_index

  out = get_batched_intermediate_series(model, key, yts, total_freq, debug=debug)
  xts, batched_intermediate_xts, intermediate_Jt_beta, LLT, LLT_ht_beta, _ = out

  def to_batched(tree):
    return jtu.tree_map(lambda leaf: einops.rearrange(leaf, '(T F) ... -> F T ...', F=total_freq), tree)

  def to_flat(tree):
    return jtu.tree_map(lambda leaf: einops.rearrange(leaf, 'F T ... -> (T F) ...', F=total_freq), tree)

  def reshape_and_trim(tree):
    # Drop the intervals before the generation start index
    return to_flat(to_batched(tree)[:, start:])

  LLT_ht_beta = reshape_and_trim(LLT_ht_beta)
  intermediate_Jt_beta = reshape_and_trim(intermediate_Jt_beta)

  # Trim the samples the same way
  bix = batched_intermediate_xts[:, start:]
  x = xts[start:]

  batched_intermediate_betat = jax.vmap(model.predict_current_backward_messages, in_axes=(0, None, None))(bix, yts, x)

  # Unbatch with the same F that was used to batch.  The leading axis has size
  # total_freq, so using `freq` here fails whenever interpolation_freq != 1.
  # It must also match the layout of intermediate_Jt_beta.
  intermediate_betat = to_flat(batched_intermediate_betat)

  def matmul(J: AbstractMatrix, mu: Float[Array, 'D']):
    # Recurse through batch dimensions, then compute LLT @ J @ mu
    if J.batch_size is not None:
      return jax.vmap(matmul)(J, mu)
    return LLT@J@mu

  intermediate_mut = intermediate_betat.mu
  intermediate_LLT_ht = matmul(intermediate_Jt_beta, intermediate_mut)

  if debug:
    import pdb; pdb.set_trace()

  return intermediate_mut, intermediate_LLT_ht, LLT_ht_beta


def get_extended_log_prob(model: AbstractModel,
                          yts: TimeSeries,
                          key: PRNGKeyArray,
                          freq: int,
                          debug: bool = False):
  """Score a densely sampled path against the model's predicted backward
  messages, extended to random intermediate times.

  Steps:
    1. Sample x_{1:N} from the CRF built for `yts` and keep the generation
       buffer x_{l:N} (l = `model.latent_generation_start_index`).
    2. Predict the model's next backward messages on that buffer.
    3. Draw `freq` random times inside each interval of the buffer.  Move each
       predicted message back to those times with the linear SDE transition
       (`update_and_marginalize_out_y`).
    4. Sample a path from the SDE conditioned on `model.encoder(yts)` at
       `info2.ts` (presumably the merged base + intermediate grid).  Evaluate
       log p(x_{t_{k+1}} | x_{t_k}) under the transition updated with the
       backward message at t_{k+1}.

  Args:
    model: trained model.
    yts: observed series.
    key: PRNG key.  It is split internally into independent keys for the CRF
      sample, the random times and the conditioned-SDE sample.
    freq: number of random intermediate times per interval of the buffer.
    debug: currently unused; kept for signature compatibility.

  Returns:
    Log-probabilities, one per consecutive pair of times in `info2.ts`.
  """
  # One independent key per random draw.  Reusing a single key would make the
  # CRF sample, the intermediate times and the SDE sample statistically dependent.
  crf_key, times_key, sde_key = random.split(key, 3)
  start = model.latent_generation_start_index

  # Construct the CRF for the observed series and precompute the forward and backward messages
  crf_state = model.make_crf_state(yts)
  crf: CRF = crf_state.crf
  messages: Messages = crf_state.messages
  info: DiscretizeInfo = crf_state.discretization_info

  # Sample from p(x_{1:N} | Y_{1:N})
  xts_values = crf.sample(crf_key, messages=messages)
  xts = TimeSeries(info.ts, xts_values)
  xts_generation_buffer = xts[start:]  # x_{l:N}

  # The true (backward + node) messages after index l are only used to check
  # that there are generation_len - 1 of them.
  bwd_and_node = crf_state.messages.bwd + crf.node_potentials
  true_next_bwd_and_node = bwd_and_node[start + 1:]
  assert true_next_bwd_and_node.batch_size == model.generation_len - 1

  # The model's predicted next backward messages, beta_{t_{k+1}}(x_{l:k})
  predicted_next_bwd_and_node = model.predict_next_backward_messages(yts, xts_generation_buffer)
  buffer_tkp1 = xts_generation_buffer.ts[1:]

  # Upsample the generation buffer: `freq` random times per interval, one row per interval
  info2 = get_random_discretization_info(times_key, xts_generation_buffer.ts, freq)
  batched_intermediate_times = info2.new_times.reshape((-1, freq))

  def get_continuous_extension(t, t_next, bwd_next):
    # Move the backward message at t_next back to time t through the SDE transition
    base_transition = model.linear_sde.get_transition_distribution(t, t_next)
    return base_transition.update_and_marginalize_out_y(bwd_next)

  # Inner vmap: over intervals.  Outer vmap: over the `freq` columns of times.
  out = jax.vmap(jax.vmap(get_continuous_extension), in_axes=(1, None, None), out_axes=1)(batched_intermediate_times, buffer_tkp1, predicted_next_bwd_and_node)
  all_intermediate_bwd_messages = jtu.tree_map(lambda x: einops.rearrange(x, 'T F ... -> (T F) ...', F=freq), out)

  # Prepend a dummy (zero) message for the first buffer time so the predicted
  # messages line up with the base times; it is discarded by the [1:] below.
  zero = predicted_next_bwd_and_node[0].zeros_like(predicted_next_bwd_and_node[0])
  dummy_bwd_message = jtu.tree_map(lambda x, y: jnp.concatenate([x[None], y], axis=0), zero, predicted_next_bwd_and_node)

  # Interleave the intermediate messages with the base-time messages
  all_bwd_messages = info2.interleave(all_intermediate_bwd_messages, dummy_bwd_message)[1:]

  # Sample a path at every time of the upsampled grid
  prob_series = model.encoder(yts)
  cond_sde = ConditionedLinearSDE(model.linear_sde, prob_series)
  all_xts = cond_sde.sample(sde_key, info2.ts)

  def log_likelihood(tk, tkp1, x_tk, x_tkp1, beta_tkp1):
    # log p(x_{t_{k+1}} | x_{t_k}) under the transition updated with beta_{t_{k+1}}
    phi_tk_tkp1 = model.linear_sde.get_transition_distribution(tk, tkp1)
    conditioned_transition = phi_tk_tkp1.unnormalized_update_y(beta_tkp1)
    predictive_distribution = conditioned_transition.condition_on_x(x_tk)
    return predictive_distribution.log_prob(x_tkp1)

  tk, tkp1 = info2.ts[:-1], info2.ts[1:]
  x_tk, x_tkp1 = all_xts.yts[:-1], all_xts.yts[1:]
  return jax.vmap(log_likelihood)(tk, tkp1, x_tk, x_tkp1, all_bwd_messages)


def single_bwd_extension_evaluation(experiment_identifier: ExperimentIdentifier,
                                    random_key_seed: int,
                                    *,
                                    save_dir: str = "Models/dynamic_latent_size_models/ho_models/results",
                                    train_if_needed: bool = False,
                                    debug: bool = False) -> pd.Series:
  """Evaluate one trained model with `get_extended_log_prob` on its test split.

  Every test series gets its own PRNG key.  The model's backward messages are
  extended to `freq = 4` random intermediate times per interval, and the
  resulting per-step log-probabilities are averaged into one score.

  Args:
    experiment_identifier: identifies the experiment (data, config, checkpoint).
    random_key_seed: seed for `jax.random.PRNGKey`.
    save_dir: directory where the JSON result file is written (created if missing).
    train_if_needed: if True and training is not complete, train with
      `main.run_trial`.  By default the stored train state is always loaded.
    debug: if True, drop into pdb after the log-probabilities are computed.

  Returns:
    pd.Series with the score and experiment metadata.  The same dict is
    written to `<save_dir>/<md5 of the outputs>_run_single.json`.
  """
  key = random.PRNGKey(random_key_seed)

  # Only the test split is evaluated
  datasets = experiment_identifier.get_data_fixed()
  test_data = datasets['test_data']

  # Load the stored train state; train only when explicitly requested
  if train_if_needed and not experiment_identifier.training_is_complete():
    from main import run_trial
    train_state = run_trial(experiment_identifier, train_if_needed=True, retrain=True)
  else:
    train_state = experiment_identifier.get_train_state()
  model: AbstractModel = train_state.model  # train_state.best_model

  freq = 4  # Upsample each interval with 4 random intermediate times

  @curry
  def filled_extension_fn(model, freq, key_and_series):
    series_key, series = key_and_series
    return get_extended_log_prob(model, series, series_key, freq)

  batch_size = 256
  keys = random.split(key, test_data.batch_size)
  log_probs = jax.lax.map(filled_extension_fn(model, freq), (keys, test_data), batch_size=batch_size)

  if debug:
    import pdb; pdb.set_trace()

  # Summary score: per-step log-probability averaged over series and steps
  mean_log_prob = jnp.mean(log_probs)

  # Get the model info
  config = experiment_identifier.create_config()
  # NOTE(review): backward_extension_evaluation reads 'data_latent_sigma';
  # confirm which key this config actually provides.
  latent_sigma = config['dataset']['tracking_sigma']
  sde_type = config['command_line_args']['sde_type']
  model_freq = config['command_line_args']['freq']
  dataset_name = config['dataset']['name']

  outputs = dict(mean_log_prob=float(mean_log_prob),
                 latent_sigma=float(latent_sigma),
                 sde_type=sde_type,
                 model_freq=int(model_freq),
                 dataset_name=dataset_name,
                 random_key_seed=int(random_key_seed))

  # Create save directory if it doesn't exist
  os.makedirs(save_dir, exist_ok=True)

  # Name the file by a hash of its contents
  sorted_items = json.dumps(outputs, sort_keys=True)
  json_name = hashlib.md5(sorted_items.encode()).hexdigest()
  filename = f"{save_dir}/{json_name}_run_single.json"

  with open(filename, 'w') as f:
    json.dump(outputs, f, indent=2)

  print(f"Results saved to {filename}")

  return pd.Series(outputs)

def backward_extension_evaluation(experiment_identifier_ar: ExperimentIdentifier,
                                  experiment_identifier_sde: ExperimentIdentifier,
                                  *,
                                  random_key_seed: int,
                                  save_dir: str = "Models/dynamic_latent_size_models/ho_models/results") -> pd.Series:
  """Compare an autoregressive (AR) model and a neural-SDE model through their
  predicted backward messages at random intermediate times.

  Both models are obtained with `main.run_trial(..., train_if_needed=True,
  retrain=True)`.  They are evaluated with `get_mean_and_h` on the AR
  experiment's test split, with one PRNG key per test series (the same keys
  for both models).  Each model's upsampling factor is
  `1024 // len(train_data[0])` of its own experiment.

  Metrics, averaged over series and times:
    mse: squared distance between the two models' predicted means.
    kl_div: 0.5 * squared distance between their LLT @ J_beta @ mu terms.
    kl_div_matching: 0.5 * squared distance between the true LLT @ h_beta and
      the AR model's LLT @ J_beta @ mu.

  Args:
    experiment_identifier_ar: experiment of the AR model (also supplies the test data).
    experiment_identifier_sde: experiment of the neural-SDE model.
    random_key_seed: seed for `jax.random.PRNGKey`.
    save_dir: directory where the JSON result file is written (created if missing).

  Returns:
    pd.Series with the metrics and run metadata.  The same dict is written to
    `<save_dir>/<md5 of the outputs>_run2.json`.
  """
  key = random.PRNGKey(random_key_seed)

  # Only the AR experiment's test split is evaluated (for both models); the
  # train splits are only read for their sequence lengths.
  split_names = ('train_data', 'val_data', 'test_data')
  ar_datasets = experiment_identifier_ar.get_data_fixed()
  sde_datasets = experiment_identifier_sde.get_data_fixed()
  train_data_ar, _, test_data = (ar_datasets[name] for name in split_names)
  train_data_sde, _, _ = (sde_datasets[name] for name in split_names)

  # Train (if necessary) and fetch both models
  from main import run_trial
  model_ar: AbstractModel = run_trial(experiment_identifier_ar, train_if_needed=True, retrain=True).model
  model_sde: AbstractModel = run_trial(experiment_identifier_sde, train_if_needed=True, retrain=True).model

  # Fixed Monte Carlo budget on the time axis:
  # freq = n_total_random_steps // sequence length, for each model.
  n_total_random_steps = 1024
  ar_len = len(train_data_ar[0])
  sde_len = len(train_data_sde[0])
  ar_freq, sde_freq = (n_total_random_steps // length for length in (ar_len, sde_len))

  @curry
  def evaluate_series(model, freq, key_and_series):
    series_key, series = key_and_series
    return get_mean_and_h(model, series, series_key, freq)

  series_keys = random.split(key, test_data.batch_size)
  map_batch_size = 64
  mut_ar, llt_ht_ar, llt_ht_beta_ar = jax.lax.map(evaluate_series(model_ar, ar_freq), (series_keys, test_data), batch_size=map_batch_size)
  mut_sde, llt_ht_sde, _ = jax.lax.map(evaluate_series(model_sde, sde_freq), (series_keys, test_data), batch_size=map_batch_size)

  def mean_squared_norm(diff):
    # Squared Euclidean norm over the last axis, averaged over all other axes
    return jnp.sum(diff**2, axis=-1).mean()

  mse = mean_squared_norm(mut_sde - mut_ar)
  kl_div = 0.5*mean_squared_norm(llt_ht_sde - llt_ht_ar)
  kl_div_matching = 0.5*mean_squared_norm(llt_ht_beta_ar - llt_ht_ar)

  # Experiment metadata
  ar_config = experiment_identifier_ar.create_config()
  sde_config = experiment_identifier_sde.create_config()

  outputs = dict(mse=float(mse),
                 kl_div=float(kl_div),
                 kl_div_matching=float(kl_div_matching),
                 ar_len=int(ar_len),
                 sde_len=int(sde_len),
                 ar_latent_sigma=float(ar_config['dataset']['data_latent_sigma']),
                 sde_latent_sigma=float(sde_config['dataset']['data_latent_sigma']),
                 ar_sde_type=ar_config['command_line_args']['sde_type'],
                 sde_sde_type=sde_config['command_line_args']['sde_type'],
                 random_key_seed=int(random_key_seed))

  # Persist under a content hash so distinct results never overwrite each other
  os.makedirs(save_dir, exist_ok=True)
  digest = hashlib.md5(json.dumps(outputs, sort_keys=True).encode()).hexdigest()
  filename = f"{save_dir}/{digest}_run2.json"
  with open(filename, 'w') as f:
    json.dump(outputs, f, indent=2)

  print(f"Results saved to {filename}")
  return pd.Series(outputs)


def parse_args():
  """Parse this script's command-line options.

  Typed options (all default to None): --config_name (str), --sde (str),
  --model_freq (int), --random_key_seed (int).
  Flags (default False): --load_results, --other_evaluation.

  Returns:
    argparse.Namespace with one attribute per option.
  """
  parser = argparse.ArgumentParser(description="")

  typed_options = (
    ("--config_name", str, "Config name"),
    ("--sde", str, "SDE type"),
    ("--model_freq", int, "Model frequency"),
    ("--random_key_seed", int, "Random key"),
  )
  for flag, value_type, help_text in typed_options:
    parser.add_argument(flag, type=value_type, help=help_text)

  switch_options = (
    ("--load_results", "Load results from file"),
    ("--other_evaluation", "Run the other evaluation"),
  )
  for flag, help_text in switch_options:
    parser.add_argument(flag, action="store_true", help=help_text)

  return parser.parse_args()

BWD_EVAL_SAVE_DIR = "Models/dynamic_latent_size_models/ho_models/results"

# Fields a saved result must contain to be included in the comparison plots.
# `backward_extension_evaluation` writes them; results written to the same
# directory by other evaluations may lack them.
_PLOT_REQUIRED_KEYS = ('random_key_seed', 'mse', 'kl_div', 'ar_len', 'ar_latent_sigma', 'ar_sde_type')


def _plot_metric_panel(ax, data, *, x, y, hue, title, xlabel, ylabel,
                       style=None, log_base=10, legend_title=None):
  """Draw a seaborn line plot of `y` against `x` on `ax` with a log x-axis.

  `log_base` is the base of the logarithmic x-axis.  A legend titled
  `legend_title` is added only when a title is given.
  """
  import seaborn as sns
  sns.lineplot(data=data, x=x, y=y, hue=hue, style=style, markers=True, dashes=True, ax=ax)
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.set_xscale("log", base=log_base)
  if legend_title is not None:
    ax.legend(title=legend_title)


if __name__ == '__main__':
  from debug import *
  import matplotlib.pyplot as plt
  import pickle
  from diffusion_crf.sde import *
  from diffusion_crf.ssm.simple_encoder import PaddingLatentVariableEncoderWithPrior
  from diffusion_crf.ssm.simple_decoder import PaddingLatentVariableDecoder
  from diffusion_crf.sde.conditioned_linear_sde import ConditionedLinearSDE
  from diffusion_crf.sde.langevin_dynamics import CriticallyDampedLangevinDynamics
  from diffusion_crf.sde.sde_base import TimeScaledLinearTimeInvariantSDE
  from diffusion_crf.neural_diffusion_crf.interpolate_potential import InterpolateResult, initialize_potential_interpolation
  from Models.experiment_identifier import ExperimentIdentifier
  from main import load_empty_model
  from diffusion_crf import TAGS

  args = parse_args()

  # Deliberately kept at module scope: evaluation code in this file may read
  # `save_dir` as a module global.
  save_dir = BWD_EVAL_SAVE_DIR

  if args.other_evaluation:
    # Single-model evaluation of the autoregressive model
    config_name = args.config_name

    ei_ar = ExperimentIdentifier.make_experiment_id(config_name=config_name,
                                                    objective='mse',
                                                    model_name='my_autoregressive_reparam_rnn_bwd',
                                                    sde_type='tracking',
                                                    freq=args.model_freq,
                                                    group='final_models',
                                                    seed=args.random_key_seed)

    results = single_bwd_extension_evaluation(ei_ar, args.random_key_seed)

  elif not args.load_results:
    # Compare the autoregressive model against the neural-SDE model
    config_name = args.config_name
    sde = args.sde

    ei_ar = ExperimentIdentifier.make_experiment_id(config_name=config_name,
                                                    objective='mse',
                                                    model_name='my_autoregressive_reparam_rnn_bwd',
                                                    sde_type=sde,
                                                    freq=0,
                                                    group='harmonic_oscillator',
                                                    seed=0)
    ei_sde = ExperimentIdentifier.make_experiment_id(config_name=config_name,
                                                     objective='mse',
                                                     model_name='my_neural_sde_rnn_bwd',
                                                     sde_type=sde,
                                                     freq=1,
                                                     group='harmonic_oscillator',
                                                     seed=0)

    results = backward_extension_evaluation(ei_ar, ei_sde, random_key_seed=args.random_key_seed, save_dir=save_dir)

  else:
    # Load every saved comparison result and plot it
    result_files = glob.glob(f"{save_dir}/*.json")
    result_files = [f for f in result_files if "combined_results" not in f]

    print(f"Found {len(result_files)} result files")

    all_results = {}  # filename -> result dict
    result_data = []  # rows for the DataFrame

    for file_path in result_files:
      with open(file_path, 'r') as f:
        results = json.load(f)

      # Skip files without the plotted fields (older formats, or other evaluations)
      missing = [k for k in _PLOT_REQUIRED_KEYS if k not in results]
      if missing:
        print(f"Skipping {file_path} (missing {missing})")
        continue

      all_results[os.path.basename(file_path)] = results
      result_data.append(results)
      print(f"Loaded results from {file_path}")

    results_df = pd.DataFrame(result_data)

    print("\nDataFrame summary:")
    print(f"Shape: {results_df.shape}")
    print(f"Columns: {results_df.columns.tolist()}")
    print("\nFirst few rows:")
    print(results_df.head())

    if results_df.empty:
      print("No plottable results found; nothing to plot.")
    else:
      import seaborn as sns
      import matplotlib.pyplot as plt

      sns.set(style="whitegrid", palette="muted", font_scale=1.2)

      # 2x2 overview: each metric against latent sigma and against sequence length
      fig, axes = plt.subplots(2, 2, figsize=(16, 12))
      overview_panels = {
        (0, 0): dict(x="ar_latent_sigma", y="mse", hue="ar_len",
                     title="MSE vs Latent Sigma", xlabel="Latent Sigma", ylabel="MSE",
                     log_base=10, legend_title="Sequence Length / SDE Type"),
        (0, 1): dict(x="ar_latent_sigma", y="kl_div", hue="ar_len",
                     title="KL Divergence vs Latent Sigma", xlabel="Latent Sigma", ylabel="KL Divergence",
                     log_base=10, legend_title="Sequence Length / SDE Type"),
        (1, 0): dict(x="ar_len", y="mse", hue="ar_latent_sigma",
                     title="MSE vs Sequence Length", xlabel="Sequence Length", ylabel="MSE",
                     log_base=2, legend_title="Latent Sigma / SDE Type"),
        (1, 1): dict(x="ar_len", y="kl_div", hue="ar_latent_sigma",
                     title="KL Divergence vs Sequence Length", xlabel="Sequence Length", ylabel="KL Divergence",
                     log_base=2, legend_title="Latent Sigma / SDE Type"),
      }
      for (row, col), panel in overview_panels.items():
        _plot_metric_panel(axes[row, col], results_df, style="ar_sde_type", **panel)

      plt.tight_layout()
      plt.show()

      # One figure per SDE type for better visibility
      for sde_type in results_df['ar_sde_type'].unique():
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        sde_data = results_df[results_df['ar_sde_type'] == sde_type]

        _plot_metric_panel(axes[0], sde_data, x="ar_len", y="mse", hue="ar_latent_sigma",
                           title=f"MSE vs Sequence Length ({sde_type})",
                           xlabel="Sequence Length", ylabel="MSE", log_base=2)
        _plot_metric_panel(axes[1], sde_data, x="ar_len", y="kl_div", hue="ar_latent_sigma",
                           title=f"KL Divergence vs Sequence Length ({sde_type})",
                           xlabel="Sequence Length", ylabel="KL Divergence", log_base=2)

        plt.tight_layout()
        plt.show()

    # results_df and all_results stay module globals.  Run the script with
    # `python -i ... --load_results` to inspect them interactively instead of
    # relying on a hard-coded breakpoint.
