from debug import *
import os
import argparse
import numpy as np
from Utils.io_utils import load_yaml_config
from diffusion_crf import *
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
import jax.random as random
import jax.tree_util as jtu
import optax
from Models.trainer import Trainer as JaxTrainer, TrainingState
import jax.numpy as jnp
import optax
import pickle
import equinox as eqx
import json
import tqdm
from jax._src.util import curry
import tqdm
from main import get_dataset_no_leakage, load_jax_model
from Models.experiment_identifier import ExperimentIdentifier
from Models.empirical_metrics import wasserstein2_distance, compute_univariate_metrics
from Models.models.base import AbstractModel
from Utils.discriminative_metric_jax import discriminative_score_metrics
from Utils.metric import get_context_fid_score
import pandas as pd
import numpy as np
import os
import datetime
import filelock
import time
from Models.result_data_checkpointer import ResultDataCheckpointer, ResultDataCheckpointer, ExperimentResultData, SampleGenerationMetadata, SampleGenerationMetadata

"""
This module implements a test-time evaluation framework with checkpointing capabilities for diffusion models.
It provides functions for:

1. get_model_evaluation_samples - Generates and manages batched samples from trained models with checkpoint support
2. evaluate_and_save_result - Performs comprehensive evaluation using multiple metrics and saves results
3. save_experiment_results - Handles the persistence of evaluation results in a thread-safe manner

The framework allows for resumable evaluation of large test datasets, generating both latent sequences
and model samples for comparison. It calculates various metrics including CRPS, NLL, discriminative scores,
and FID scores under different evaluation conditions.
"""

def get_result_data_for_model(experiment_identifier: ExperimentIdentifier,
                              model: AbstractModel,
                              key: PRNGKeyArray,
                              series: TimeSeries,
                              n_samples_for_empirical_distribution: int,
                              _to_latent: bool = True):
  """Draw posterior and model samples for a single (unbatched) series.

  For each of `n_samples_for_empirical_distribution` keys this draws one sample
  from p(x_{1:N} | y_{1:N}) via `model.basic_interpolation` and one from
  q(x_{1:N} | y_{1:k}) via `model.sample`. Both are then downsampled to the
  original frequency.

  Args:
    experiment_identifier: Stored on the returned result object.
    model: Trained model to sample from.
    key: PRNG key for all draws.
    series: A single unbatched series (`series.batch_size is None`).
    n_samples_for_empirical_distribution: Number of samples to draw.
    _to_latent: Forwarded to `model.sample` as `to_latent`.

  Returns:
    ExperimentResultData with `data_series=series`. `latent_seq` and
    `model_samples` are stacked along a leading axis of size
    `n_samples_for_empirical_distribution` (produced by `jax.lax.map`).
  """
  assert series.batch_size is None, "This function is not batched!"

  def pull_latent_seq_and_model_seq(key: PRNGKeyArray):
    # Split the key so the posterior draw and the model draw get independent
    # randomness. Reusing one key for both would make them dependent.
    interpolation_key, sample_key = random.split(key)

    # p(x_{1:N} | y_{1:N})
    latent_seq = model.basic_interpolation(interpolation_key, series)

    # q(x_{1:N} | y_{1:k})
    samples = model.sample(sample_key, series, to_latent=_to_latent)

    # Downsample the latent sequence and the samples to the original frequency
    latent_seq = model.downsample_seq_to_original_freq(latent_seq)
    samples = model.downsample_seq_to_original_freq(samples)

    return latent_seq, samples

  # One key per sample of the empirical distribution
  keys = random.split(key, n_samples_for_empirical_distribution)
  latent_seq, model_samples = jax.lax.map(pull_latent_seq_and_model_seq, keys)
  result_data = ExperimentResultData(experiment_identifier=experiment_identifier,
                                     data_series=series, # (T, D)
                                     latent_seq=latent_seq, # (n_samples, T, D)
                                     model_samples=model_samples) # (n_samples, T, D)

  return result_data

def get_model_evaluation_samples(experiment_identifier: ExperimentIdentifier,
                                 model: AbstractModel,
                                 test_data: TimeSeries,
                                 key: PRNGKeyArray,
                                 max_allowable_batch_size: int = 256,
                                 n_samples_for_empirical_distribution: int = 32,
                                 restart_evaluation: bool = False):
  """Pull samples from p(x_{1:N} | y_{1:N}) and q(x_{1:N} | y_{1:k}) for each test sequence.

  A `ResultDataCheckpointer` tracks progress per test-sequence index, so the
  evaluation can be resumed with a different `max_allowable_batch_size`
  (e.g. on different hardware). Test sequence `i` always uses the key
  `random.split(key, test_data.batch_size)[i]`, so its samples do not depend
  on how the work is chunked.

  Args:
    experiment_identifier: Identifies the experiment. Passed to the checkpointer
      and stored on each result.
    model: Trained model providing `basic_interpolation`, `sample` and
      `downsample_seq_to_original_freq`.
    test_data: Batched test sequences. `test_data.batch_size` is the number of
      sequences.
    key: Root PRNG key.
    max_allowable_batch_size: Number of test-sequence indices requested from the
      checkpointer (and saved) per iteration. Inside a chunk, the model is vmapped
      over about this many elements (test sequences x empirical samples).
    n_samples_for_empirical_distribution: Number of samples drawn per test sequence.
    restart_evaluation: Forwarded to `ResultDataCheckpointer`.

  Returns:
    The value of `checkpointer.load_all_results()` after every index has been processed.

  Raises:
    ValueError: If stored checkpoint metadata was created with a different
      test-set size or number of samples.
  """

  def pull_latent_and_model_samples(inputs: Tuple[PRNGKeyArray, TimeSeries, jnp.ndarray]) -> ExperimentResultData:
    key, input_seq, index = inputs
    assert input_seq.batch_size is None, "This function is not batched!"

    def pull_latent_seq_and_model_seq(key: PRNGKeyArray):
      # Independent keys for the two draws. Reusing one key would make the
      # posterior sample and the model sample dependent.
      # NOTE: checkpoints written before this change used one shared key, so
      # they are not bit-for-bit reproducible by this code.
      interpolation_key, sample_key = random.split(key)

      # p(x_{1:N} | y_{1:N})
      latent_seq = model.basic_interpolation(interpolation_key, input_seq)

      # q(x_{1:N} | y_{1:k})
      samples = model.sample(sample_key, input_seq)

      # Downsample the latent sequence and the samples to the original frequency
      latent_seq = model.downsample_seq_to_original_freq(latent_seq)
      samples = model.downsample_seq_to_original_freq(samples)

      return latent_seq, samples

    # One key per sample of the empirical distribution
    keys = random.split(key, n_samples_for_empirical_distribution)
    latent_seq, model_samples = jax.vmap(pull_latent_seq_and_model_seq)(keys)
    result_data = ExperimentResultData(experiment_identifier=experiment_identifier,
                                       data_series=input_seq, # (T, D)
                                       latent_seq=latent_seq, # (n_samples, T, D)
                                       model_samples=model_samples) # (n_samples, T, D)

    jax.debug.print("index: {index}", index=index)
    return result_data

  # Map over fewer test sequences at once, so that each vmapped model call sees
  # about max_allowable_batch_size elements. Clamp to 1 so lax.map never gets
  # batch_size=0 when max_allowable_batch_size < n_samples_for_empirical_distribution.
  model_batch_size = max(1, max_allowable_batch_size // n_samples_for_empirical_distribution)
  all_keys = random.split(key, test_data.batch_size)
  all_indices = jnp.arange(test_data.batch_size)

  if 'my_neural_sde' in experiment_identifier.model_name or 'my_neural_ode' in experiment_identifier.model_name:
    ######################
    # Use a Python for loop for neural ODE and neural SDE models and process one element at a time
    ######################
    jitted_pull_latent_and_model_samples = eqx.filter_jit(pull_latent_and_model_samples)

    def sample_loop(key: PRNGKeyArray, input_seq: TimeSeries, index: jnp.ndarray):
      items = (key, input_seq, index)

      per_item_results = []
      for i in tqdm.tqdm(range(key.shape[0])):
        single_item = jtu.tree_map(lambda x: x[i], items)
        per_item_results.append(jitted_pull_latent_and_model_samples(single_item))

      # Stack the per-item pytrees into one pytree with a leading batch axis
      return jtu.tree_map(lambda *xs: jnp.stack(xs), *per_item_results)

  else:
    @jax.jit
    def sample_loop(key: PRNGKeyArray, input_seq: TimeSeries, index: jnp.ndarray):
      return jax.lax.map(pull_latent_and_model_samples, (key, input_seq, index), batch_size=model_batch_size)

  ######################
  # Create a checkpointer that will save the result data to a file
  ######################
  checkpointer = ResultDataCheckpointer(experiment_identifier, restart_evaluation=restart_evaluation)

  # Check if metadata exists, create it if it doesn't
  if not checkpointer.has_evaluation_metadata():
    print("No metadata found. Creating new metadata for evaluation.")
    checkpointer.create_evaluation_metadata(
      test_data_size=test_data.batch_size,
      n_samples_for_empirical_distribution=n_samples_for_empirical_distribution
    )
  else:
    # Verify that the test data size matches the stored metadata
    metadata: SampleGenerationMetadata = checkpointer.get_evaluation_metadata()
    if not metadata.parameters_match(test_data.batch_size, n_samples_for_empirical_distribution):
      error_msg = (f"Parameters do not match stored metadata. "
                   f"Expected test_data_size={metadata.test_data_size}, "
                   f"n_samples_for_empirical_distribution={metadata.n_samples_for_empirical_distribution}. "
                   f"Got test_data_size={test_data.batch_size}, "
                   f"n_samples_for_empirical_distribution={n_samples_for_empirical_distribution}.")
      raise ValueError(error_msg)

  # Process data in chunks of at most max_allowable_batch_size indices
  while True:
    # Get the next batch of indices to process
    metadata: SampleGenerationMetadata = checkpointer.get_evaluation_metadata()
    indices_to_process = metadata.get_next_indices(max_allowable_batch_size)

    # If no indices left to process, we're done
    if not indices_to_process:
      print("All indices have been processed.")
      break

    start_index = indices_to_process[0]
    end_index = indices_to_process[-1]

    print(f"Processing indices {start_index}-{end_index} ({len(indices_to_process)} samples)")

    # Use a plain slice when the indices form a contiguous run (the common case).
    # Otherwise gather exactly the requested indices, so the generated data always
    # lines up one-to-one with `indices_to_process` when it is saved.
    # NOTE(review): the gather path assumes TimeSeries supports integer-array indexing.
    if list(indices_to_process) == list(range(start_index, end_index + 1)):
      selector = slice(start_index, end_index + 1)
    else:
      selector = jnp.asarray(indices_to_process)

    current_test_data = test_data[selector]
    current_keys = all_keys[selector]
    current_indices = all_indices[selector]  # only used for logging

    # Pull the latent and model samples for the current batch
    result_data = sample_loop(current_keys, current_test_data, current_indices)

    # Save the result data to file
    checkpointer.save_result_data(result_data, indices_to_process)

  # Load all the results from the checkpointer
  results = checkpointer.load_all_results()

  return results

def evaluate_and_save_result(trained_model_config: dict,
                             experiment_identifier: ExperimentIdentifier,
                             train_state: Optional[TrainingState] = None) -> None:
  """Generate evaluation samples for a trained model, then compute and save its metrics.

  Samples come from `get_model_evaluation_samples`, which either generates
  them or resumes from checkpoints. Metrics are then computed for every
  setting listed in the dataset config's `evaluation_settings` and saved via
  `compute_and_save_metrics`.

  Args:
    trained_model_config: Config dict. This function reads
      `command_line_args['sanity_check']` and `command_line_args['restart_evaluation']`.
    experiment_identifier: The experiment to evaluate.
    train_state: Optional already-loaded training state. If None, the model is
      loaded with `load_jax_model`.

  Raises:
    ValueError: If training has not completed and this is not a sanity check,
      or if 'future_denoised_observation' is requested but the experiment has
      no denoised data.
  """
  key = random.PRNGKey(0)

  # Check if the experiment is done training
  sanity_check = trained_model_config['command_line_args']['sanity_check']
  if experiment_identifier.get_training_metadata().is_complete == False and not sanity_check:
    print(f"Experiment {experiment_identifier} is not done training. Skipping evaluation.")
    raise ValueError(f"Experiment {experiment_identifier} is not done training. Skipping evaluation.")

  # Load the data and split it into train, validation, and test.
  # Denoised test data only exists for some experiments. It stays None otherwise,
  # so the denoised setting below can fail with a clear message.
  denoised_test_data = None
  if experiment_identifier.has_denoised_data():
    out = get_dataset_no_leakage(experiment_identifier, trained_model_config, return_denoised_data=True)
    train_data, val_data, test_data, _, _, denoised_test_data = out
  else:
    train_data, val_data, test_data = get_dataset_no_leakage(experiment_identifier, trained_model_config)

  if train_state is None:
    # Load the model (the trainer itself is not needed here)
    _, train_state = load_jax_model(train_data,
                                    val_data,
                                    test_data,
                                    experiment_identifier,
                                    trained_model_config)
  model: AbstractModel = train_state.best_model

  # Number of test sequences generated and checkpointed per iteration
  config = experiment_identifier.create_config()
  max_allowable_batch_size = config['dataset']['train_batch_size']*2

  if experiment_identifier.model_name == 'neural_ode' or experiment_identifier.model_name == 'neural_sde':
    # These models are extremely slow to run, so checkpoint every 32 test sequences
    max_allowable_batch_size = 32

  #########################
  # Compute the empirical distribution based metrics.  These
  # mainly evaluate the marginal distribution of our trained model
  # and not the full distributional properties of the model.
  #########################

  # This is a fixed parameter that we use to compute the empirical distribution
  n_samples_for_empirical_distribution = 32

  # Get the model evaluation samples using the batch-size invariant implementation
  result_data = get_model_evaluation_samples(experiment_identifier,
                                             model,
                                             test_data,
                                             key,
                                             max_allowable_batch_size=max_allowable_batch_size,
                                             n_samples_for_empirical_distribution=n_samples_for_empirical_distribution,
                                             restart_evaluation=trained_model_config['command_line_args']['restart_evaluation'])

  print("Done with get_model_evaluation_samples!")
  print("========================================")
  print("========================================")
  print("========================================")

  # Skip this and go directly to computing the metrics
  # if trained_model_config['command_line_args']['only_generate_samples']:
  #   return

  #########################
  # Compute the distributional metrics
  #########################
  config = result_data.experiment_identifier.create_config()
  dataset_config = config['dataset']
  pred_len = dataset_config['pred_length']
  evaluation_settings = dataset_config['evaluation_settings']

  # Setting: future_latent
  target_samples_latent: TimeSeries = result_data.latent_seq
  model_samples_latent: TimeSeries = result_data.model_samples

  # For every test sequence, pick one of its posterior samples at random as the
  # "true" latent sequence. Use a separate seed: `key` was already split to
  # generate the samples, and JAX keys must not be reused after splitting.
  n_total_samples = target_samples_latent.yts.shape[0]
  n_batch_elements = target_samples_latent.yts.shape[1]
  index_key = random.PRNGKey(1)
  indices = random.randint(index_key, shape=(n_total_samples,), minval=0, maxval=n_batch_elements)
  target_samples_future_latent: TimeSeries = target_samples_latent[jnp.arange(n_total_samples),indices,-pred_len:]
  model_samples_future_latent: TimeSeries = model_samples_latent[:,:,-pred_len:]

  if 'future_latent' in evaluation_settings:
    compute_and_save_metrics(
      target_samples=target_samples_future_latent,
      model_samples=model_samples_future_latent,
      setting="future_latent",
      experiment_identifier=experiment_identifier
    )

  # Setting: future_observation
  # Keep only the leading model dimensions that correspond to the observed dimensions
  target_samples_observation: TimeSeries = result_data.data_series
  model_samples_obs_yts = result_data.model_samples.yts[...,:target_samples_observation.yts.shape[-1]]
  model_samples_observation: TimeSeries = TimeSeries(result_data.model_samples.ts, model_samples_obs_yts)
  target_samples_future_obs: TimeSeries = result_data.data_series[:,-pred_len:]
  model_samples_future_obs: TimeSeries = model_samples_observation[:,:,-pred_len:]

  if 'future_observation' in evaluation_settings:
    compute_and_save_metrics(
      target_samples=target_samples_future_obs,
      model_samples=model_samples_future_obs,
      setting="future_observation",
      experiment_identifier=experiment_identifier
    )

  # Setting: future_denoised_observation
  if 'future_denoised_observation' in evaluation_settings:
    if denoised_test_data is None:
      raise ValueError(f"Evaluation setting 'future_denoised_observation' was requested for "
                       f"{experiment_identifier}, but the experiment has no denoised data.")

    # Compare the same model samples against the denoised test data
    denoised_samples_future_obs: TimeSeries = denoised_test_data[:,-pred_len:]
    compute_and_save_metrics(
      target_samples=denoised_samples_future_obs,
      model_samples=model_samples_future_obs,
      setting="future_denoised_observation",
      experiment_identifier=experiment_identifier
    )
  print(f"Computed and saved metrics for all settings to {experiment_identifier.result_metric_path}")


def compute_and_save_metrics(target_samples: TimeSeries,
                            model_samples: TimeSeries,
                            setting: str,
                            experiment_identifier: ExperimentIdentifier) -> None:
  """
  Compute and save metrics comparing target and model samples in one evaluation setting.

  Only metrics missing from the results CSV (as reported by
  `check_metrics_to_compute`) are computed. Each metric is written to the CSV
  as soon as it is available.

  Args:
    target_samples: Ground-truth sequences with a single batch dimension (n_target_samples,).
    model_samples: Model samples with two batch dimensions (n_target_samples, n_eval_samples).
      `model_samples[i]` holds the samples paired with `target_samples[i]`.
    setting: The evaluation setting ('full_latent', 'future_latent', 'full_observation',
      'future_observation', 'future_denoised_observation'). Used as the column prefix.
    experiment_identifier: The identifier for the current experiment. Results are
      written to `experiment_identifier.result_metric_path`.
  """
  # Sanity check the input batch shapes.
  # target_samples has one batch dimension; model_samples has two.
  n_target_samples = target_samples.batch_size
  assert isinstance(model_samples.batch_size, tuple)
  assert model_samples.batch_size[0] == n_target_samples, f"Expected {n_target_samples} target samples, got {model_samples.batch_size[0]}"

  # Check which metrics still need to be computed
  results_csv_file_path = experiment_identifier.result_metric_path
  metrics_to_compute = check_metrics_to_compute(experiment_identifier, setting, results_csv_file_path)

  if not metrics_to_compute:
    print(f"All metrics for setting '{setting}' already computed. Skipping.")
    return

  # CRPS, NLL and NRMSE all come from one vmapped pass, so compute them once if any is needed
  pointwise_metric_names = [name for name in ('crps', 'nll', 'nrmse') if name in metrics_to_compute]
  if pointwise_metric_names:
    crps, nll, nrmse = jax.vmap(compute_crps_nll_nrmse_over_each_time_and_dimension)(target_samples, model_samples)
    pointwise_values = {'crps': crps, 'nll': nll, 'nrmse': nrmse}

    for name in pointwise_metric_names:
      # Average over target sequences, samples, time and dimensions
      mean_value = float(jnp.mean(pointwise_values[name]))
      update_results_file(setting=setting,
                          experiment_identifier=experiment_identifier,
                          path=results_csv_file_path,
                          **{name: mean_value})
      print(f"Saved {name.upper()}={mean_value:.4f} for setting '{setting}'")
  else:
    print(f"Skipping CRPS, NLL and NRMSE for setting '{setting}' because they are not in metrics_to_compute")

  # Compute the discriminative metric if needed (uses only the first model sample per target)
  if 'discrim' in metrics_to_compute:
    discrim, _, _ = discriminative_score_metrics(target_samples.yts, model_samples.yts[:,0])
    discrim_value = float(discrim)
    update_results_file(setting=setting,
                        experiment_identifier=experiment_identifier,
                        path=results_csv_file_path,
                        discrim=discrim_value)
    print(f"Saved Discriminative={discrim_value:.4f} for setting '{setting}'")
  else:
    print(f"Skipping Discriminative for setting '{setting}' because it is not in metrics_to_compute")

  # Compute the FID metric if needed (uses only the first model sample per target)
  if 'fid' in metrics_to_compute:
    fid = get_context_fid_score(target_samples.yts, model_samples.yts[:,0])
    # NOTE(review): assumes get_context_fid_score returns an object with a `mean`
    # attribute. If it returns a plain numpy array, `.mean` is a method and this fails.
    fid_value = float(fid.mean)
    update_results_file(setting=setting,
                        experiment_identifier=experiment_identifier,
                        path=results_csv_file_path,
                        fid=fid_value)
    print(f"Saved FID={fid_value:.4f} for setting '{setting}'")
  else:
    print(f"Skipping FID for setting '{setting}' because it is not in metrics_to_compute")

def check_metrics_to_compute(experiment_identifier: ExperimentIdentifier, setting: str, path: str) -> list:
  """
  Check which metrics still need to be computed based on what's already in the results file.

  A metric counts as done when the column "{setting}_{metric}" exists and has a
  non-NaN value in the row for this experiment. If the file is missing or
  unreadable, or the experiment has no row yet, every metric is returned.

  Args:
    experiment_identifier: The identifier for the experiment. Provides the
      candidate metrics and the row label.
    setting: The evaluation setting, used as the column prefix.
    path: Path to the results CSV file.

  Returns:
    The metrics from `experiment_identifier.get_metrics_to_compute()` that still
    need computing, in their original order.
  """
  # Candidate metrics for this experiment
  all_metrics = experiment_identifier.get_metrics_to_compute()

  # Rows are keyed by the string form of the model identifier
  experiment_id_str = str(experiment_identifier.get_model_identifier())

  # If file doesn't exist, compute all metrics
  if not os.path.exists(path):
    return all_metrics

  try:
    df = pd.read_csv(path, index_col=0)
  except Exception:
    # Unreadable file (e.g. empty or malformed): recompute everything.
    # `except Exception` rather than a bare `except:` so that
    # KeyboardInterrupt/SystemExit still propagate.
    return all_metrics

  # If experiment doesn't exist in the dataframe, compute all metrics
  if experiment_id_str not in df.index:
    return all_metrics

  # Keep the metrics whose column is missing or empty for this experiment
  metrics_to_compute = []
  for metric in all_metrics:
    column_name = f"{setting}_{metric}"
    if column_name not in df.columns or pd.isna(df.loc[experiment_id_str, column_name]):
      metrics_to_compute.append(metric)

  return metrics_to_compute

def update_results_file(setting: str,
                       experiment_identifier: ExperimentIdentifier,
                       path: str,
                       **metrics) -> None:
  """
  Write metric values for one experiment/setting into the results CSV.

  The CSV is indexed by the string form of the experiment's model identifier.
  Each metric goes into a column named "{setting}_{metric}". The whole
  read-modify-write cycle runs while holding the lock file "{path}.lock"
  (60 s timeout).

  Args:
    setting: The evaluation setting ('full_latent', 'future_latent', 'full_observation',
      'future_observation', 'future_denoised_observation'), used as the column prefix.
    experiment_identifier: Identifies the row to create or update.
    path: Path of the results CSV file. Its parent directory is created if missing.
    **metrics: Metric name -> value pairs to store.
  """
  row_label = str(experiment_identifier.get_model_identifier())

  columns = {}
  for metric_key, metric_value in metrics.items():
    columns[f"{setting}_{metric_key}"] = metric_value

  parent_dir = os.path.dirname(path)
  if parent_dir:
    os.makedirs(parent_dir, exist_ok=True)

  lock = filelock.FileLock(f"{path}.lock", timeout=60)
  with lock:
    table = pd.read_csv(path, index_col=0) if os.path.exists(path) else pd.DataFrame()

    if row_label not in table.index:
      # First results for this experiment: append a fresh row holding only these columns
      fresh_row = pd.Series(columns, name=row_label)
      table = pd.concat([table, pd.DataFrame([fresh_row])])
    else:
      # Existing row: overwrite just the columns for this setting, leave the rest untouched
      for column, value in columns.items():
        table.loc[row_label, column] = value

    table.to_csv(path)
    print(f"Updated metrics for setting '{setting}' in {path}")

def compute_crps_and_nll_over_each_time_and_dimension(target_sample: TimeSeries,
                                                     model_samples: TimeSeries) -> Tuple[Float[Array, "..."],
                                                                                        Float[Array, "..."]]:
  """
  Evaluate CRPS and NLL independently at every (time, dimension) entry.

  Args:
    target_sample: One target sequence without a batch dimension. Its `yts` is
      mapped over axis 0 (time), then axis 0 again (dimension).
    model_samples: Samples for the same target with a leading sample axis. Its
      `yts` is mapped over axis 1 (time), then axis 1 again (dimension).

  Returns:
    Tuple (crps, nll) with one value per (time, dimension) entry.
  """
  from Models.empirical_metrics import compute_univariate_metrics

  # Inner map runs over the dimension axis, outer map over the time axis.
  # compute_univariate_metrics therefore presumably sees all samples of a single
  # scalar entry plus the matching target value.
  per_dimension = jax.vmap(compute_univariate_metrics, in_axes=(1, 0))
  per_time_and_dimension = jax.vmap(per_dimension, in_axes=(1, 0))

  crps, log_likelihood, _unused = per_time_and_dimension(model_samples.yts, target_sample.yts)
  return crps, -log_likelihood

def compute_crps_nll_nrmse_over_each_time_and_dimension(target_sample: TimeSeries,
                                                        model_samples: TimeSeries) -> Tuple[Float[Array, "..."],
                                                                                        Float[Array, "..."],
                                                                                        Float[Array, "..."]]:
  """
  Evaluate CRPS and NLL at every (time, dimension) entry, plus per-sample NRMSE.

  Args:
    target_sample: One target sequence without a batch dimension.
    model_samples: Samples for the same target with a single integer batch
      (sample) dimension.

  Returns:
    Tuple (crps, nll, nrmse). crps and nll hold one value per (time, dimension)
    entry. nrmse holds, for each model sample, one value per dimension: the RMSE
    over time divided by the target's standard deviation over time.
    NOTE(review): nrmse is inf/nan for any dimension where the target is
    constant over time (std == 0), e.g. a length-1 window.
  """
  assert target_sample.batch_size is None
  assert isinstance(model_samples.batch_size, int)

  from Models.empirical_metrics import compute_univariate_metrics

  # Outer vmap runs over time, inner over dimensions (in_axes pick the sample-major model layout)
  univariate_fn = jax.vmap(jax.vmap(compute_univariate_metrics, in_axes=(1, 0)), in_axes=(1, 0))
  crps, log_likelihood, _unused = univariate_fn(model_samples.yts, target_sample.yts)
  nll = -log_likelihood

  def nrmse_for_one_sample(truth: TimeSeries, pred: TimeSeries) -> Float[Array, ""]:
    assert truth.batch_size is None and pred.batch_size is None
    squared_error = (truth.yts - pred.yts)**2
    rmse_per_dim = jnp.sqrt(jnp.mean(squared_error, axis=0))
    normalized = rmse_per_dim / jnp.std(truth.yts, axis=0)

    assert normalized.shape == (truth.yts.shape[-1],)
    return normalized

  # Broadcast the single target against every model sample
  nrmse = jax.vmap(nrmse_for_one_sample, in_axes=(None, 0))(target_sample, model_samples)
  return crps, nll, nrmse
