import subprocess
# Report GPU status via nvidia-smi before the rest of the module is imported.
# This is best-effort: any failure (e.g. nvidia-smi is not installed or exits
# with a non-zero status) is printed and otherwise ignored.
try:
  print("="*80)
  # Flush before spawning the child: nvidia-smi writes straight to the inherited
  # stdout, so without a flush our header can appear *after* its report whenever
  # Python's stdout is block-buffered (e.g. redirected to a log file or pipe).
  print("GPU STATUS:", flush=True)
  subprocess.run(["nvidia-smi"], check=True)
  print("="*80)
except Exception as e:
  print(f"Error running nvidia-smi: {e}")
  print("="*80)


from debug import *
import os
import argparse
import numpy as np
from Utils.io_utils import load_yaml_config
from diffusion_crf import *
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int, PRNGKeyArray, Scalar, Bool
import jax.random as random
import jax.tree_util as jtu
import optax
from Models.trainer import Trainer
import jax.numpy as jnp
import optax
import pickle
from Models.model_selector import create_model
import equinox as eqx
from Models.models.base import EmptyModel, AbstractModel
from Utils.Data_utils.my_datasets import get_improved_pendulum_dataset, get_physics_dataset, get_harmonic_oscillator_dataset, get_harmonic_oscillator_dataset2
from diffusion_crf.sde.sde_base import max_likelihood_ltisde
from jax._src.util import curry
import tqdm
from Utils.Data_utils.real_datasets import get_stocks_dataset, get_energy_dataset, get_etth_dataset, get_fmri_dataset
from Utils.Data_utils.mujoco_dataset import get_mujoco_dataset
from Utils.Data_utils.sine_dataset import get_sine_dataset
import pandas as pd
from Models.experiment_identifier import ExperimentIdentifier
import threading
import time
import datetime
import wadler_lindig as wl
import optuna
import copy

def config_hack(config: dict):
  """Apply the forced, last-minute adjustments to an assembled experiment config.

  Two things happen, in this order:

  1. Memory guard: when the interpolation frequency is 2 or 4 the training batch
     size is floor-divided by that factor and the gradient accumulation
     multiplier is multiplied by it.  The validation batch size is then always
     set to the (possibly reduced) training batch size.
  2. If the model config has an 'override_params' entry, it is merged
     recursively into the returned config: where both sides hold a dict the
     merge descends, otherwise the override value replaces the existing one.

  Args:
    config: dict with 'dataset', 'model' and 'command_line_args' entries.  The
      nested dicts are modified in place.

  Returns:
    A new top-level dict referencing the (mutated) 'dataset', 'model' and
    'command_line_args' entries, plus whatever the overrides added or replaced.
  """
  dataset_cfg = config['dataset']
  model_cfg = config['model']
  cli_args = config['command_line_args']
  freq = cli_args['freq']

  # Hack to ensure that we won't run out of memory
  for factor in (2, 4):
    if freq == factor:
      dataset_cfg['train_batch_size'] = dataset_cfg['train_batch_size']//factor
      dataset_cfg['gradient_accumulation_batch_size_multiplier'] = dataset_cfg['gradient_accumulation_batch_size_multiplier']*factor
      break
  dataset_cfg['val_batch_size'] = dataset_cfg['train_batch_size'] # Avoid OOM

  def _merge_into(target, patch):
    # Recursive dict merge: descend only when both sides are dicts.
    for name, new_value in patch.items():
      old_value = target.get(name)
      if isinstance(new_value, dict) and isinstance(old_value, dict):
        _merge_into(old_value, new_value)
      else:
        target[name] = new_value

  merged = dict(dataset=dataset_cfg,
                model=model_cfg,
                command_line_args=cli_args)

  if 'override_params' in model_cfg:
    _merge_into(merged, model_cfg['override_params'])

  return merged

def parse_args():
  """Parse the command line, load the YAML config and assemble the run config.

  The YAML file given by --config_file is loaded with `load_yaml_config`; the
  entry named by --model_name becomes the model config and the 'dataset' entry
  becomes the dataset config.  The parsed arguments are converted to a plain
  dict (with an extra 'just_load' entry, the negation of --train) and the three
  parts are passed through `config_hack` before being pretty-printed.

  Notes:
    * `log_plots` is currently forced to False for every model, so the
      --log_plots flag has no effect.
    * --no_leakage and --use_latent_cond_len are `store_true` flags whose
      default is already True, so they are always True and cannot be switched
      off from the command line.  As a consequence the 'latent_cond_len'
      removal below never triggers for values coming from the CLI.

  Returns:
    The dict produced by `config_hack`, holding at least the 'dataset', 'model'
    and 'command_line_args' entries.
  """
  parser = argparse.ArgumentParser(description='Launch experiments')
  parser.add_argument('--config_file', type=str, default=None,
                      help='path of config file')

  # Standard args
  parser.add_argument('--retrain', action='store_true', default=False, help='Retrain the model.')
  parser.add_argument('--train', action='store_true', default=False, help='Train or Test using JAX model.')
  parser.add_argument('--model_name', type=str, help='JAX model to use.')
  parser.add_argument('--debug', action='store_true', default=False, help='Debug the loss function.')
  parser.add_argument('--test_time_evaluation', action='store_true', default=False, help='Test the model on the test set.')
  parser.add_argument('--log_plots', action='store_true', default=False, help='Log plots to wandb.')
  parser.add_argument('--checkpoint_every', type=int, default=200, help='Checkpoint the model every N steps.')
  parser.add_argument('--sanity_check', action='store_true', default=False, help='Run a sanity check to see if all parts of the experimental pipeline are working.')
  parser.add_argument('--global_key_seed', type=int, default=0, help='Global key seed for random number generators.')

  parser.add_argument('--only_generate_samples', action='store_true', default=False, help='Only generate samples, don\'t compute metrics.')
  parser.add_argument('--restart_evaluation', action='store_true', default=False, help='Restart the evaluation from the beginning.')
  parser.add_argument('--nfe_evaluation', action='store_true', default=False, help='Run the NFE evaluation.')

  # Experiment dependent args
  parser.add_argument('--freq', type=int, default=100, help='Frequency of data interpolation.  We will get `freq` values in between every two time points in the original data.')
  parser.add_argument('--sde_type', type=str, default='brownian', help='Type of SDE to use.', choices=['brownian', 'ornstein_uhlenbeck', 'langevin', 'max_likelihood', 'tracking', 'harmonic_oscillator'])

  # For grouping in wandb
  parser.add_argument('--group', type=str, required=True, help='Group to use.')

  # Data leakage fixing
  parser.add_argument('--no_leakage', action='store_true', default=True, help='Fix the data leakage issue.')
  parser.add_argument('--use_latent_cond_len', action='store_true', default=True, help='Use latent conditioning length.')

  # Hyperparameter tuning
  parser.add_argument('--hyperparameter_tuning', action='store_true', default=False, help='Run hyperparameter tuning.')

  args = parser.parse_args()

  # "Just load" means: do not train, only restore an existing model.
  args.just_load = not args.train
  config = load_yaml_config(args.config_file)

  # Get the model configs
  model_configs = config[args.model_name]

  # Drop the latent conditioning length when it is not supposed to be used.
  if not args.use_latent_cond_len:
    model_configs.pop('latent_cond_len', None)

  # Get the dataset configs
  dataset_configs = config['dataset']

  # Get the command line args as a plain dict (shares storage with `args`)
  command_line_args = vars(args)

  # Plot logging during training is force-disabled for every model.  Earlier
  # logic enabled it for all models except 'my_neural_ode' and announced that,
  # but the value was then unconditionally overwritten with False, so the
  # announcement contradicted the effective config.  Report what is actually used.
  command_line_args['log_plots'] = False
  print('Not logging plots during training')

  # Return the final configs that we'll use
  output_configs = dict(dataset=dataset_configs,
                        model=model_configs,
                        command_line_args=command_line_args)

  # Final hacks that we'll force on the config
  output_configs = config_hack(output_configs)

  print('=======================')
  wl.pprint(output_configs)
  print('=======================')

  return output_configs

################################################################################################################

def load_empty_model(experiment_identifier: ExperimentIdentifier, **kwargs):
  """Build an `EmptyModel` holding only the latent SDE and the encoder.

  The config is obtained from `experiment_identifier.create_config(**kwargs)`.
  The settings `latent_sigma` and `noise_std` are resolved with increasing
  priority from: the built-in fallback (the only base value used for the
  'exp_april_12' group), the dataset config entry of the same name, the dataset
  config entry '<name>_updated', and finally a keyword argument of the same
  name.  `time_scale_mult` and `tracking_sigma` are taken from kwargs when
  given, otherwise from the dataset config.

  Args:
    experiment_identifier: supplies the config and the experiment group.
    **kwargs: forwarded to `create_config`; may also carry the overrides
      'latent_sigma', 'noise_std', 'time_scale_mult' and 'tracking_sigma'.

  Returns:
    An `EmptyModel` wrapping the time-scaled SDE and the encoder.

  Raises:
    ValueError: if the 'sde_type' command line argument has no matching SDE.
  """
  config = experiment_identifier.create_config(**kwargs)
  cli_args = config['command_line_args']
  model_cfg = config['model']
  data_cfg = config['dataset']

  obs_dim = data_cfg['dim']
  interpolation_freq = cli_args['freq']
  sde_kind = cli_args['sde_type']

  def _resolve(setting, fallback):
    # fallback / dataset value  <  '<setting>_updated'  <  explicit kwarg
    if experiment_identifier.group == 'exp_april_12':
      value = fallback
    else:
      value = data_cfg.get(setting, fallback)
    value = data_cfg.get(setting + '_updated', value)
    return kwargs.get(setting, value)

  #########################
  # Create the interpolator
  #########################
  latent_sigma = _resolve('latent_sigma', 0.1)

  def _tracking_sde():
    # The dataset value is looked up even when a kwarg override is supplied.
    sigma = kwargs.get('tracking_sigma', data_cfg['tracking_sigma'])
    return HigherOrderTrackingModel(sigma=sigma, position_dim=obs_dim, order=2)

  # Factories are only invoked for the selected SDE type.
  sde_factories = {
    'brownian': lambda: BrownianMotion(sigma=latent_sigma, dim=obs_dim),
    'ornstein_uhlenbeck': lambda: OrnsteinUhlenbeck(sigma=latent_sigma, lambda_=0.1, dim=obs_dim),
    'langevin': lambda: CriticallyDampedLangevinDynamics(mass=0.1, beta=0.1, dim=obs_dim),
    'tracking': _tracking_sde,
    'harmonic_oscillator': lambda: HarmonicOscillator(freq=1.0, coeff=0.0, sigma=latent_sigma, observation_dim=1),
  }
  make_sde = sde_factories.get(sde_kind)
  if make_sde is None:
    raise ValueError(f'Unknown SDE type: {sde_kind}')
  base_sde = make_sde()

  # The time scale multiplier comes from kwargs first, then the dataset config.
  time_scale = kwargs.get('time_scale_mult', data_cfg.get('time_scale_mult', 1.0))
  sde = TimeScaledLinearTimeInvariantSDE(base_sde, time_scale=time_scale)

  noise_std = _resolve('noise_std', 0.01)

  encoder_uses_prior = model_cfg.get('use_encoder_prior', True)
  encoder = PaddingLatentVariableEncoderWithPrior(y_dim=obs_dim,
                                                x_dim=sde.dim,
                                                sigma=noise_std,
                                                use_prior=encoder_uses_prior)

  seq_length = data_cfg['seq_length']
  return EmptyModel(linear_sde=sde,
                    encoder=encoder,
                    interpolation_freq=interpolation_freq,
                    obs_seq_len=seq_length,
                    cond_len=seq_length - data_cfg['pred_length'])


def load_jax_model(train_data: TimeSeries,
                   val_data: TimeSeries,
                   test_data: TimeSeries,
                   experiment_identifier: ExperimentIdentifier,
                   config: Optional[dict] = None):
  """Create the model, optimizer and trainer, then initialize the train state.

  Args:
    train_data: training split handed to the `Trainer`.
    val_data: validation split handed to the `Trainer`.
    test_data: test split handed to the `Trainer`.
    experiment_identifier: identifies the experiment; also provides the
      checkpoint folder and, when `config` is None, the config itself.
    config: full experiment config; built from `experiment_identifier` if None.

  Returns:
    A `(trainer, train_state)` tuple, where `train_state` is whatever
    `trainer.init` returns.
  """
  if config is None:
    config = experiment_identifier.create_config()

  args = config['command_line_args']
  model_configs = config['model']

  #########################
  # Create the interpolator
  #########################
  # NOTE(review): the SDE/encoder are built from the identifier's own config,
  # not from the `config` argument — confirm the two always agree.
  dummy_model = load_empty_model(experiment_identifier)
  sde, encoder = dummy_model.linear_sde, dummy_model.encoder

  #########################
  # Create the model
  #########################
  key = random.PRNGKey(args['global_key_seed'])
  model = create_model(sde,
                       encoder,
                       config=config,
                       key=key)

  # Print out how many parameters the model has (array leaves only)
  params, _ = eqx.partition(model, eqx.is_array)
  n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
  print(f'Model has {n_params} parameters')

  #########################
  # Create the optimizer
  #########################
  optimizer_configs = model_configs['optimizer']
  num_steps = optimizer_configs['max_train_steps']
  lr = optimizer_configs['lr']
  warmup_steps = optimizer_configs['warmup_steps']

  # Use the public export instead of the private `optax.schedules._schedule`
  # module, whose path is an implementation detail of optax.
  from optax.schedules import warmup_cosine_decay_schedule
  lr_fn = warmup_cosine_decay_schedule(init_value=0.0,
                                      peak_value=lr,
                                      # Keep the warmup strictly shorter than the whole schedule
                                      warmup_steps=min(warmup_steps, num_steps - 1),
                                      decay_steps=num_steps)

  optimizer = optax.adamw(lr_fn)

  #########################
  # Create the trainer
  #########################
  trainer = Trainer(experiment_identifier=experiment_identifier,
                    config=config,
                    checkpoint_path=experiment_identifier.model_folder_name,
                    train_data=train_data,
                    validation_data=val_data,
                    test_data=test_data,
                    optimizer=optimizer)

  #########################
  # Figure out if we're training or just loading
  #########################
  # Only load (no training) unless --train is set without --debug.
  _just_load = (not args['train']) or bool(args['debug'])
  _retrain = bool(args['retrain'])

  #########################
  # Initialize the trainer
  #########################
  train_state = trainer.init(model=model, retrain=_retrain, just_load=_just_load)
  return trainer, train_state

################################################################################################################

def get_dataset(*args, **kwargs):
  """Deprecated entry point that always fails; arguments are ignored.

  The callers in this file load data through `get_dataset_no_leakage`.

  Raises:
    AssertionError: always.
  """
  # Raise explicitly instead of `assert 0, ...`: assert statements are removed
  # when Python runs with -O, which would turn this into a silent `return None`.
  raise AssertionError("Deprecating this function")

def get_dataset_no_leakage(experiment_identifier: ExperimentIdentifier,
                           config: dict,
                           return_denoised_data: bool = False,
                           _return_indices_only: bool = False,
                           return_raw_data: bool = False):
  """Load the configured dataset, zero its start times and standardize it.

  Args:
    experiment_identifier: forwarded to the harmonic oscillator loaders only.
    config: experiment config.  Reads 'global_key_seed' from
      'command_line_args' and 'name', 'train_proportion', 'val_proportion',
      'seq_length' and 'pred_length' from 'dataset'.
    return_denoised_data: also return the denoised splits.  Only the
      'noisy_double_pendulum' and the physics datasets ('lorenz', 'fitzhugh',
      'lotka', 'brusselator', 'van_der_pol') provide them.
    _return_indices_only: not used by this function.
    return_raw_data: forwarded to the loader; when true, the loader's result is
      returned right after the start-time reset, without unpacking into splits
      and without standardization.

  Returns:
    `dataset` if `return_raw_data` is true; otherwise
    `(train, val, test)`, extended by `(denoised_train, denoised_val,
    denoised_test)` when `return_denoised_data` is true.

  Raises:
    NotImplementedError: for the removed 'double_pendulum' dataset.
    ValueError: for an unknown dataset name, or when denoised data is requested
      from a dataset that does not provide it.
  """

  #########################
  # Load the data and split it into train, validation, and test
  #########################
  args = config['command_line_args']
  dataset_config = config['dataset']
  name = dataset_config['name']
  key = random.PRNGKey(args['global_key_seed'])

  train_proportion = dataset_config['train_proportion']
  val_proportion = dataset_config['val_proportion']
  test_proportion = 1.0 - train_proportion - val_proportion
  train_val_test_split = (train_proportion, val_proportion, test_proportion)

  # Loaders called with (config=dataset_config, args, key, split, return_raw_data)
  standard_loaders = {
    'stocks': get_stocks_dataset,
    'energy': get_energy_dataset,
    'etth': get_etth_dataset,
    'mujoco': get_mujoco_dataset,
    'sines': get_sine_dataset,
    'fmri': get_fmri_dataset,
  }
  # Loaders called with the *full* config and the experiment identifier
  harmonic_loaders = {
    'harmonic_oscillator': get_harmonic_oscillator_dataset,
    'harmonic_oscillator2': get_harmonic_oscillator_dataset2,
  }
  gluonts_names = ('m4', 'solar', 'electricity', 'traffic',
                   'exchange', 'uber_tlc', 'kdd_cup', 'wiki')
  physics_names = ('lorenz', 'fitzhugh', 'lotka', 'brusselator', 'van_der_pol')

  # Stays None for every dataset that has no denoised counterpart.
  denoised_dataset = None

  if name == 'double_pendulum':
    raise NotImplementedError('Depracated')
  elif name == 'noisy_double_pendulum':
    dataset, denoised_dataset = get_improved_pendulum_dataset(config=dataset_config,
                                                              key=key,
                                                              train_val_test_split=train_val_test_split,
                                                              return_raw_data=return_raw_data)
  elif name in standard_loaders:
    dataset = standard_loaders[name](config=dataset_config,
                                     args=args,
                                     key=key,
                                     train_val_test_split=train_val_test_split,
                                     return_raw_data=return_raw_data)
  elif name in gluonts_names:
    # Imported lazily so the dependency is only needed for these datasets
    from Data.gluonts.dataset import get_gluonts_dataset
    train_prop, val_prop, _ = train_val_test_split
    # The test proportion is passed as None for this loader
    dataset = get_gluonts_dataset(config=dataset_config,
                                  args=args,
                                  key=key,
                                  train_val_test_split=(train_prop, val_prop, None),
                                  return_raw_data=return_raw_data)
  elif name in physics_names:
    dataset, denoised_dataset = get_physics_dataset(config=dataset_config,
                                                    key=key,
                                                    train_val_test_split=train_val_test_split,
                                                    return_raw_data=return_raw_data)
  elif name in harmonic_loaders:
    dataset = harmonic_loaders[name](config=config,
                                     experiment_identifier=experiment_identifier,
                                     key=key,
                                     train_val_test_split=train_val_test_split,
                                     return_raw_data=return_raw_data)
  else:
    raise ValueError(f'Unknown dataset: {name}')

  # Previously this case crashed later on an unbound local variable.
  if return_denoised_data and denoised_dataset is None:
    raise ValueError(f'Dataset {name!r} does not provide denoised data')

  #########################
  # Make all of the data have starting times of 0
  #########################
  def reset_start_time(series: TimeSeries):
    # Batched series are handled one element at a time via vmap
    if series.batch_size is not None:
      return jax.vmap(reset_start_time)(series)
    ts_adjusted = series.ts - series.ts[0]
    return eqx.tree_at(lambda x: x.ts, series, ts_adjusted)

  def reset_all(data):
    # A tuple/list of splits is mapped element-wise (and returned as a tuple)
    if isinstance(data, (tuple, list)):
      return tuple(map(reset_start_time, data))
    return reset_start_time(data)

  dataset = reset_all(dataset)
  if return_denoised_data:
    denoised_dataset = reset_all(denoised_dataset)

  if return_raw_data:
    return dataset

  train_data, val_data, test_data = dataset

  # Standardize each split with the mean/std of its `data[:cond_len]` slice.
  # The original intent is to use only the observed (conditioning) part of each
  # sequence so that no information about the predicted part leaks in.
  # NOTE(review): elsewhere in this file `train_data[0]` and `train_data[:2]`
  # select along the leading batch axis; if slicing does the same here, this
  # picks the first `cond_len` sequences rather than the first `cond_len` time
  # steps — confirm against TimeSeries.__getitem__.
  # NOTE(review): there is no guard against a zero standard deviation.
  cond_len = dataset_config['seq_length'] - dataset_config['pred_length']
  def scale_and_shift(data: TimeSeries):
    obs_data = data[:cond_len]
    mean = obs_data.yts.mean(axis=(0, 1))
    std = obs_data.yts.std(axis=(0, 1))
    return eqx.tree_at(lambda x: x.yts, data, (data.yts - mean)/std)

  if name != 'harmonic_oscillator':
    # Don't do this for the harmonic oscillator
    train_data = scale_and_shift(train_data)
    val_data = scale_and_shift(val_data)
    test_data = scale_and_shift(test_data)

  if return_denoised_data:
    denoised_train_data, denoised_val_data, denoised_test_data = denoised_dataset

    # The denoised splits are standardized with their own statistics.
    denoised_train_data = scale_and_shift(denoised_train_data)
    denoised_val_data = scale_and_shift(denoised_val_data)
    denoised_test_data = scale_and_shift(denoised_test_data)

    return train_data, val_data, test_data, denoised_train_data, denoised_val_data, denoised_test_data

  return train_data, val_data, test_data

################################################################################################################

def run_trial(experiment_identifier: ExperimentIdentifier,
              command_line_args: Optional[dict] = None,
              trial: Optional[optuna.trial.Trial] = None,
              train_if_needed: bool = False,
              retrain: bool = False):
  """Load data and model for one experiment and run the training loop.

  Args:
    experiment_identifier: identifies the experiment and creates its config.
      With a `trial`, it is replaced by a new identifier built from the
      hyperparameter-injected config.
    command_line_args: command line arguments as a dict.  Without a `trial`
      these are used directly as the run arguments when given; otherwise the
      config's 'command_line_args' entry is used.
    trial: optional optuna trial; when given, hyperparameters are injected into
      the config and the trial is passed on to the trainer.
    train_if_needed: forwarded to `create_config` (non-trial runs only).
    retrain: forwarded to `create_config` (non-trial runs only).

  Returns:
    The final train state, or its `best_validation_loss` when a `trial` was
    supplied.
  """
  is_tuning_run = trial is not None

  if is_tuning_run:
    # Inject the hyperparameters into the config
    config = experiment_identifier.create_config(trial=trial,
                                                 inject_hyperparameters=True,
                                                 command_line_args=command_line_args)
    experiment_identifier = ExperimentIdentifier(config)
  else:
    config = experiment_identifier.create_config(command_line_args=command_line_args,
                                                 train_if_needed=train_if_needed,
                                                 retrain=retrain)

  model_configs = config['model']
  dataset_config = config['dataset']

  # Explicit command line args win only for non-trial runs
  if is_tuning_run or command_line_args is None:
    args = config['command_line_args']
  else:
    args = command_line_args

  #########################
  # Load the data and split it into train, validation, and test
  #########################
  # We had data leakage in the previous set of experiments!
  train_data, val_data, test_data = get_dataset_no_leakage(experiment_identifier, config)

  #########################
  # Load the model
  #########################
  trainer, train_state = load_jax_model(train_data,
                                        val_data,
                                        test_data,
                                        experiment_identifier,
                                        config)

  #########################
  # Work out the training length and the checkpoint interval
  #########################
  num_train_steps = model_configs['optimizer']['max_train_steps']
  if args['sanity_check']:
    num_train_steps = 10

  if args['checkpoint_every'] != -1:
    checkpoint_every = args['checkpoint_every']
  else:
    # -1 requests an automatically chosen interval
    K = dataset_config['train_batch_size'] * dataset_config['gradient_accumulation_batch_size_multiplier']
    p = 0.99
    N = train_data.batch_size

    # Chat-GPT generated estimate for the number of gradient steps with a batch size of K
    # needed until we are confident with probability p that we have seen each of the N
    # examples at least once.
    checkpoint_every = int(N/K*jnp.log(N/(1-p)))
    print(f'Checkpointing every {checkpoint_every} gradient updates with a batch size of {K} and a total dataset size of {N} to ensure that we have seen each example at least once with probability {p}')

  # If we're retraining, reset the training metadata
  if args['retrain']:
    from Models.training_tracker import reset_training_metadata
    reset_training_metadata(experiment_identifier)

  train_state = trainer.train(train_state=train_state,
                              num_steps=num_train_steps,
                              checkpoint_every=checkpoint_every,
                              gradient_accumulation=dataset_config['gradient_accumulation_batch_size_multiplier'],
                              trial=trial)

  # Hyperparameter tuning only needs the objective value
  return train_state if trial is None else train_state.best_validation_loss

def main():
  """Entry point: parse the config and dispatch on the command line flags.

  In order of precedence:
    * --test_time_evaluation: run `evaluate_and_save_result` and return.
    * --hyperparameter_tuning: create a study and optimize `run_trial` for
      200 trials, returning the result of `study.optimize`.
    * --train: run a single `run_trial` and evaluate the resulting train state.
    * otherwise: load the data and model, and with --debug run the interactive
      (pdb-driven) debugging script below.
  """
  from Models.checkpointed_test_time_evaluation import evaluate_and_save_result
  config = parse_args()
  model_configs = config['model']
  args = config['command_line_args'] # Already a plain dict (parse_args applies vars() to the namespace)
  dataset_config = config['dataset']

  print('\n\n===================')
  print(f'Config: {config}')
  print('===================\n\n')

  experiment_identifier = ExperimentIdentifier(config=config)

  print('\n\n===================')
  print(f'Experiment identifier: \n{str(experiment_identifier)}')
  print('===================\n\n')

  if args['test_time_evaluation']:
    # Run the test time evaluation and save the results
    evaluate_and_save_result(config, experiment_identifier=experiment_identifier)
    return

  if args['hyperparameter_tuning']:
    from Models.hyper_params import create_study
    study = create_study(experiment_identifier)
    # Each trial runs a full training; run_trial returns the best validation loss
    # as the objective.  Returning here skips the final "Done" message.
    return study.optimize(lambda trial: run_trial(experiment_identifier, command_line_args=args, trial=trial),
                          n_trials=200)

  elif args['train']:
    train_state = run_trial(experiment_identifier, command_line_args=args)
    evaluate_and_save_result(config, experiment_identifier=experiment_identifier, train_state=train_state) # Evaluate the trained model as well

  else:
    # Just load the model to do debugging
    train_data, val_data, test_data = get_dataset_no_leakage(experiment_identifier, config)
    trainer, train_state = load_jax_model(train_data,
                                          val_data,
                                          test_data,
                                          experiment_identifier,
                                          config)

    trained_model = train_state.model
    best_model = train_state.best_model


    #########################
    # Optionally run the debug code
    #########################
    # NOTE: everything below is an interactive scratch script; it stops at
    # several pdb breakpoints and is not meant to run unattended.
    if args['debug']:

      key = random.PRNGKey(0)

      sample = trained_model.sample(key, test_data[0], debug=False)
      # # control = trained_model.predict_control(test_data[0], sample[trained_model.latent_generation_start_index:])
      # import pdb; pdb.set_trace()
      import pdb; pdb.set_trace()
      loss_out1 = trained_model.loss_fn(train_data[0], key, debug=True)
      import pdb; pdb.set_trace()
      # raw_data = experiment_identifier.get_raw_data()

      # Raw uint32 key data; presumably a specific PRNG key being reproduced — TODO confirm
      key = jnp.array([839183663, 3740430601], dtype=jnp.uint32)
      sample = trained_model.sample(key, test_data[0], debug=False)

      import matplotlib.pyplot as plt
      sample.plot_series(show_plot=False)
      plt.savefig('blah.png')
      import pdb; pdb.set_trace()


      # NOTE(review): this helper is never called in this function.  It reads
      # `keys` (only assigned further below) and `series` (never assigned in
      # this function) from the enclosing scope.
      def make_samples(model: AbstractModel):
        assert series.batch_size is None
        return jax.vmap(model.sample, in_axes=(0, None))(keys, series)

      sample = trained_model.sample(key, test_data[0], debug=False)
      import pdb; pdb.set_trace()

      # loss_out1 = trained_model.loss_fn(train_data[0], key, debug=True)
      loss_out2 = trained_model.loss_fn(val_data[0], key, debug=True)


      # NOTE(review): `series` is not assigned anywhere in this function; unless
      # one of the star imports provides it, this line raises NameError — confirm.
      loss_out = best_model.loss_fn(series, key, debug=True)


      # sample = trained_model.sample(key, series, input_not_upsampled=True, debug=True)
      # trainer.log_plot(series, sample[None], 'blah', 0)

      import pdb; pdb.set_trace()


      # Get samples using a batch of data
      # (64 keys, one shared series: the keys are vmapped, the series is broadcast)
      # NOTE(review): `partial` is not imported explicitly in this file; presumably
      # it comes from one of the star imports — confirm.
      training_eval_batch = train_data[4]
      keys = random.split(key, 64) # 92%/s
      sample_fn = partial(trained_model.sample, input_not_upsampled=True)
      vmapped_sample_fn = eqx.filter_vmap(sample_fn, in_axes=(0, None))
      samples = eqx.filter_jit(vmapped_sample_fn)(keys, training_eval_batch)
      # # trainer.log_plot(training_eval_batch, samples, 'blah', 0)

      print('\n==============\n')

      # Get samples using a batch of data
      # (exact repeat of the previous experiment)
      training_eval_batch = train_data[4]
      keys = random.split(key, 64) # 92%/s
      sample_fn = partial(trained_model.sample, input_not_upsampled=True)
      vmapped_sample_fn = eqx.filter_vmap(sample_fn, in_axes=(0, None))
      samples = eqx.filter_jit(vmapped_sample_fn)(keys, training_eval_batch)
      # # trainer.log_plot(training_eval_batch, samples, 'blah', 0)

      print('\n==============\n')

      # Get samples using a batch of data
      # (one key per series: both arguments are vmapped)
      training_eval_batch = train_data[:2] # 1228%/s
      keys = random.split(key, training_eval_batch.batch_size)
      sample_fn = partial(trained_model.sample, input_not_upsampled=True)
      vmapped_sample_fn = eqx.filter_vmap(sample_fn, in_axes=(0, 0))
      samples = eqx.filter_jit(vmapped_sample_fn)(keys, training_eval_batch)

      print('\n==============\n')

      # Get samples using a batch of data
      # (a single shared key, vmapped over 16 series)
      training_eval_batch = train_data[:16] # 320%/s
      sample_fn = partial(trained_model.sample, input_not_upsampled=True)
      vmapped_sample_fn = eqx.filter_vmap(sample_fn, in_axes=(None, 0))
      samples = eqx.filter_jit(vmapped_sample_fn)(key, training_eval_batch)

      import pdb; pdb.set_trace()

  # Reached after the train and load/debug paths (the other two return early)
  print(f'Done with experiment!')

if __name__ == '__main__':
  main()