import tqdm
import jax.numpy as jnp
import jax
import jax.random as random
from functools import partial
from typing import Optional, Mapping, Tuple, List, Sequence, Union, Any, Callable, Dict, Iterator
import optax
import equinox as eqx
from jaxtyping import Array, PRNGKeyArray, PyTree, Scalar
import tqdm
import jax.tree_util as jtu
import os
from jax._src.util import curry
import wandb
from optax import tree_utils as otu
from optax import contrib
from diffusion_crf import TimeSeries, ProbabilisticTimeSeries, AbstractEncoder
from Models.empirical_metrics import compute_univariate_metrics
import matplotlib.pyplot as plt
import time
import numpy as np
import pandas as pd
from Utils.discriminative_metric_jax import discriminative_score_metrics
from Utils.predictive_metric_jax import predictive_score_metrics
from Utils.metric import get_context_fid_score
from Models.experiment_identifier import ExperimentIdentifier
import optuna
from pathlib import Path

__all__ = ['TrainingState',
            'Checkpointer',
            'Trainer',
            'default_optimizer']


def ensure_path_exists(path):
  """Create the directory `path`, together with any missing parents.

  An already existing directory is left as it is.
  """
  directory = Path(path)
  directory.mkdir(exist_ok=True, parents=True)


@curry
def _apply_update(step_size: Scalar, u, p):
  """Blend the value `p` towards the update `u` by a fraction `step_size`.

  The function is curried: `_apply_update(step_size)` returns a callable
  taking `(u, p)`.  When `u` is `None`, `p` is returned unchanged.
  """
  if u is None:
    return p
  keep_fraction = 1 - step_size
  return keep_fraction*p + step_size*u

def _is_none(x):
  """Return `True` only when `x` is `None`."""
  return isinstance(x, type(None))

def incremental_update(model: PyTree,
                       updates: PyTree,
                       step_size: Scalar) -> PyTree:
  """Move `model` a fraction `step_size` of the way towards `updates`.

  Every leaf `p` of `model` that has a matching update `u` becomes
  `(1 - step_size)*p + step_size*u`.  `None` entries in `updates` are treated
  as leaves, and the corresponding part of `model` is returned unchanged.
  """
  blend = _apply_update(step_size)
  return jtu.tree_map(blend, updates, model, is_leaf=_is_none)


class TrainingState(eqx.Module):
  """Everything that has to be carried from one training step to the next.

  **Attributes**:

  - `i`: Number of training updates applied so far
  - `key`: PRNG key carried between training steps
  - `model`: The current model
  - `opt_state`: The current optimizer state
  - `best_validation_loss`: Lowest validation loss passed to `update_best_validation_loss`
  - `number_of_steps_since_best_validation_loss`: Number of consecutive calls to
    `update_best_validation_loss` that did not improve on the best loss
  - `best_model`: The model that was current when the best loss was recorded
  """

  i: float
  key: PRNGKeyArray
  model: eqx.Module
  opt_state: optax.OptState

  best_validation_loss: float
  number_of_steps_since_best_validation_loss: int

  best_model: eqx.Module

  def __init__(self,
               i: float, # float so that it is not treated as static
               key: PRNGKeyArray,
               model: eqx.Module,
               opt_state: optax.OptState):
    self.i = i
    self.key = key
    self.model = model
    self.opt_state = opt_state

    # Internal variables: nothing has been validated yet, so any finite
    # validation loss counts as an improvement.
    self.best_validation_loss = jnp.inf
    self.number_of_steps_since_best_validation_loss = 0
    self.best_model = model

  def apply_training_update(self, next_key: PRNGKeyArray, model: eqx.Module, opt_state: optax.OptState) -> 'TrainingState':
    """Return a copy of this state after one optimizer step.

    The step counter `i` is incremented by one and `key`, `model` and
    `opt_state` are replaced by the given values.  The validation bookkeeping
    is carried over unchanged.
    """
    new_train_state = eqx.tree_at(lambda x: x.i, self, self.i + 1)
    new_train_state = eqx.tree_at(lambda x: x.key, new_train_state, next_key)
    new_train_state = eqx.tree_at(lambda x: x.model, new_train_state, model)
    new_train_state = eqx.tree_at(lambda x: x.opt_state, new_train_state, opt_state)
    return new_train_state

  def update_best_validation_loss(self, val_loss: float) -> 'TrainingState':
    """Record a new validation loss and return the updated state.

    If `val_loss` is strictly lower than the best loss seen so far, the best
    loss and best model are replaced by `val_loss` and the current model, and
    the patience counter is reset to zero.  Otherwise the counter is
    incremented by one.

    The comparison is a Python `if`, so this method is meant to be called
    with a concrete value, outside of `jit`.
    """
    if val_loss < self.best_validation_loss:
      # Reset the number of steps since the best validation loss
      new_train_state = eqx.tree_at(lambda x: x.number_of_steps_since_best_validation_loss, self, 0)

      # Update the best validation loss
      new_train_state = eqx.tree_at(lambda x: x.best_validation_loss, new_train_state, val_loss)

      # Update the best model
      new_train_state = eqx.tree_at(lambda x: x.best_model, new_train_state, self.model)

      return new_train_state

    # No improvement: increment the number of steps since the best validation loss
    return eqx.tree_at(lambda x: x.number_of_steps_since_best_validation_loss,
                       self,
                       self.number_of_steps_since_best_validation_loss + 1)

################################################################################################################

class Checkpointer(eqx.Module):
  """Saves and restores an equinox pytree at a fixed location on disk.

  **Attributes**:

  - `save_path`: Root folder under which the checkpoint is stored
  - `model_folder`: `<save_path>/models`, created when the object is constructed
  - `experiment_identifier`: Identifier of the experiment being checkpointed
  """

  save_path: str
  model_folder: str
  experiment_identifier: ExperimentIdentifier

  def __init__(self,
               save_path: str,
               experiment_identifier: ExperimentIdentifier):
    models_dir = os.path.join(save_path, 'models')
    ensure_path_exists(models_dir)

    self.save_path = save_path
    self.model_folder = models_dir
    self.experiment_identifier = experiment_identifier

  @property
  def saved_model_path(self):
    """Full path of the single checkpoint file inside `model_folder`."""
    checkpoint_name = 'saved_model.pickle'
    return os.path.join(self.model_folder, checkpoint_name)

  def model_exists(self) -> bool:
    """Whether a checkpoint file is present at `saved_model_path`."""
    checkpoint_file = self.saved_model_path
    return os.path.exists(checkpoint_file)

  def save_eqx_module(self,
                      model: eqx.Module):
    """Serialise the leaves of `model` to `saved_model_path`."""
    destination = self.saved_model_path
    eqx.tree_serialise_leaves(destination, model)

  def load_eqx_module(self,
                      model_example: eqx.Module):
    """Load the checkpoint, using `model_example` as the template pytree."""
    source = self.saved_model_path
    return eqx.tree_deserialise_leaves(source, model_example)

################################################################################################################

class Trainer():
  """Class that will monitor training and handle checkpointing.

  **Attributes**:

  - `checkpointer`: Object that saves checkpoints of the model
  """

  experiment_identifier: ExperimentIdentifier
  config: dict
  checkpointer: Checkpointer
  train_data: TimeSeries
  validation_data: TimeSeries
  test_data: TimeSeries
  optimizer: optax.GradientTransformation

  def __init__(self,
               experiment_identifier: ExperimentIdentifier,
               config: dict,
               checkpoint_path: str,
               train_data: TimeSeries,
               validation_data: TimeSeries,
               test_data: TimeSeries,
               optimizer: optax.GradientTransformation):
    """Store the experiment description, data splits and optimizer.

    A `Checkpointer` rooted at `checkpoint_path` is created as part of
    construction; creating it makes the checkpoint folders on disk.
    """
    print('Trainer checkpoint_path: ', checkpoint_path)

    # Experiment description
    self.experiment_identifier = experiment_identifier
    self.config = config

    # Data splits
    self.train_data = train_data
    self.validation_data = validation_data
    self.test_data = test_data

    # Optimisation and checkpointing
    self.optimizer = optimizer
    self.checkpointer = Checkpointer(checkpoint_path, experiment_identifier)

  @property
  def pred_len(self):
    return self.config['dataset']['pred_length']

  def apply_predictive_mask(self, data: TimeSeries) -> TimeSeries:
    """Return a copy of `data` whose last `pred_len` steps are marked unobserved.

    The final `pred_len` entries along the second-to-last axis of
    `data.observation_mask` are set to `False` (presumably the time axis --
    TODO confirm against `TimeSeries`).  If `pred_len` exceeds the length of
    that axis, the whole mask is cleared.  A `pred_len` of zero leaves the
    mask untouched.
    """
    observation_mask = data.observation_mask

    # Compute the start index explicitly instead of slicing with `-pred_len:`.
    # With pred_len == 0 the slice `-0:` selects *every* step, which would
    # hide the whole series instead of nothing.
    num_steps = observation_mask.shape[-2]
    first_hidden_step = max(num_steps - self.pred_len, 0)

    observation_mask = observation_mask.at[..., first_hidden_step:, :].set(False)
    return eqx.tree_at(lambda x: x.observation_mask, data, observation_mask)

  @property
  def save_folder(self):
    return self.checkpointer.save_path

  @property
  def plot_folder(self):
    """`<save_folder>/plots`; the folder is created on access if it is missing."""
    plots_dir = os.path.join(self.save_folder, 'plots')
    ensure_path_exists(plots_dir)
    return plots_dir

  @property
  def batch_size(self):
    return self.config['dataset']['train_batch_size']

  @property
  def val_batch_size(self):
    return self.config['dataset']['val_batch_size']

  def make_train_data_iterator(self):
    key = yield random.PRNGKey(0)
    try:
        while True:
            idx = random.randint(
              key,
              minval=0,
              maxval=self.train_data.batch_size,
              shape=(self.batch_size,)
            )
            key = yield self.train_data[idx]
    except StopIteration:
        return

  def make_sequential_data_iterator(self, dataset: TimeSeries):
    total_samples = dataset.batch_size
    num_batches = (total_samples + self.val_batch_size - 1) // self.val_batch_size  # Ceiling division

    for current_idx in range(num_batches):
      start_idx = current_idx * self.val_batch_size
      end_idx = min((current_idx + 1) * self.val_batch_size, total_samples)
      batch = dataset[start_idx:end_idx]
      yield batch

  def init(self,
           model: eqx.Module,
           retrain: bool = False,
           just_load: bool = False) -> TrainingState:
    """Create the initial training state, resuming from disk when possible.

    **Arguments**:

    - `model`: The freshly constructed model.  It is also used as the template
               when a checkpoint is loaded.
    - `retrain`: If `True`, ignore any existing checkpoint and start from scratch.
    - `just_load`: If `True`, do not start a wandb run.

    **Returns**:

    The restored `TrainingState` if a checkpoint exists and `retrain` is
    `False`, otherwise a fresh state at step 0.
    """
    key0 = random.PRNGKey(1)

    # Build the wandb run name from the checkpoint path, dropping the path
    # components that are listed below.
    ignored_tags = {'', 'pi_drsheldon_umass_edu', 'eddie', 'wandb', 'experiment_results', 'checkpoints', 'model_checkpoints', 'samples'}
    tags = [tag for tag in self.checkpointer.save_path.split('/') if tag not in ignored_tags]
    name = '-'.join(tags)

    if not just_load:
      wandb.init(project='autoregressive_forecasting',
                 dir='/project/pi_drsheldon_umass_edu/eddie',
                 config=self.config,
                 resume='allow',
                 reinit=True,
                 name=name)

    # Fresh state: only the inexact (floating point) arrays are optimised
    opt_state = self.optimizer.init(eqx.filter(model, eqx.is_inexact_array))
    train_state = TrainingState(jnp.array(0.0), key0, model, opt_state)

    # Without a checkpoint on disk there is nothing to resume from
    if retrain or not self.checkpointer.model_exists():
      return train_state

    # Load the most recent checkpoint
    return self.restore(train_state)

  def train_objective(
    self,
    model: eqx.Module,
    batched_data: TimeSeries,
    key: PRNGKeyArray,
    debug: bool = False
  ):
    """Evaluate the model's losses on a batch and average them.

    `model.loss_fn` is vmapped over the batch with one PRNG key per series.
    The entry named by `experiment_identifier.get_model_objective()` is the
    objective.

    **Returns**:

    A tuple `(loss, losses)`: the batch mean of the objective and the batch
    mean of every entry returned by `model.loss_fn`.

    **Raises**:

    `ValueError` if the objective name is not among the returned losses.

    With `debug=True` the loss is first evaluated un-vmapped on the first
    series and a `pdb` prompt is opened after the vmapped evaluation.
    """

    def per_series_loss(single_series, series_key):
      all_losses = model.loss_fn(single_series, series_key, debug=debug)
      objective_name = self.experiment_identifier.get_model_objective()
      if objective_name not in all_losses:
        raise ValueError(f'Invalid objective: {objective_name}')
      return all_losses[objective_name], all_losses

    series_keys = random.split(key, batched_data.batch_size)

    if debug:
      # Run one series eagerly so that failures are easy to step through
      per_series_loss(batched_data[0], series_keys[0])

    per_series_out = eqx.filter_vmap(per_series_loss)(batched_data, series_keys)
    if debug:
      import pdb; pdb.set_trace()

    # Reduce every per-series quantity to a scalar
    loss, losses = jtu.tree_map(jnp.mean, per_series_out)
    return loss, losses

  def train_step(self,
                 train_state: TrainingState,
                 data: TimeSeries,
                 debug: bool = False) -> Tuple[TrainingState, Mapping[str, Any]]:
    """Apply a single optimizer update.

    **Arguments**:

    - `train_state`: The current training state
    - `data`: Either a single batch (`data.batch_size` is an `int`) or a
              double batch whose leading axis indexes micro-batches for
              gradient accumulation
    - `debug`: If `True`, re-run the objective in debug mode and open `pdb`
               before the update is applied

    **Returns**:

    The updated training state and a dictionary of scalar metrics.  The
    dictionary holds the losses from `train_objective` plus `'objective'` and
    `'param_diff'` (mean over parameter arrays of the RMS change).
    """
    model = train_state.model
    opt_state = train_state.opt_state
    train_key, next_key = random.split(train_state.key)

    value_and_grad_fn = eqx.filter_value_and_grad(self.train_objective, has_aux=True)

    def apply_val_grad(data_and_key: Tuple[TimeSeries, PRNGKeyArray]):
      micro_batch, micro_key = data_and_key
      return value_and_grad_fn(model, micro_batch, micro_key)

    if isinstance(data.batch_size, int):
      # Single batch: compute the gradients of the objective directly
      (obj, aux), grads = apply_val_grad((data, train_key))
    else:
      # Otherwise we have a double batch of data and will do gradient accumulation
      assert len(data.batch_size) == 2
      num_micro_batches = data.batch_size[0]
      micro_keys = random.split(train_key, num_micro_batches)
      (obj, aux), grads = jax.lax.map(apply_val_grad, (data, micro_keys))

      # NOTE(review): the metrics are averaged over micro-batches while the
      # gradients are summed -- confirm that the sum is intended.
      obj, aux = jtu.tree_map(lambda leaf: jnp.mean(leaf, axis=0), (obj, aux))
      grads = jtu.tree_map(lambda leaf: jnp.sum(leaf, axis=0), grads)

    aux['objective'] = obj

    # We need aux to contain scalars in order to write it correctly.
    aux = jtu.tree_map(jnp.mean, aux)

    # Compute the optimizer update for the trainable (inexact array) leaves
    model_params, model_state = eqx.partition(model, eqx.is_inexact_array)
    updates, new_opt_state = self.optimizer.update(grads, opt_state, model_params)

    if debug:
      debug_batch = data if isinstance(data.batch_size, int) else data[0]
      self.train_objective(model, debug_batch, train_key, debug=True)
      import pdb; pdb.set_trace()

    new_model_params = eqx.apply_updates(model_params, updates)
    new_model = eqx.combine(new_model_params, model_state)

    # See how much the parameters changed: RMS difference per array, then the mean
    rms_changes = jtu.tree_map(lambda new, old: jnp.sqrt(jnp.mean((new - old)**2)),
                               new_model_params,
                               model_params)
    aux['param_diff'] = jnp.mean(jnp.array(jtu.tree_leaves(rms_changes)))

    # Package the updated training state
    updated_train_state = train_state.apply_training_update(next_key, new_model, new_opt_state)

    return updated_train_state, aux

  def _compute_validation_loss(self, model: eqx.Module, objective_fn):
    """Average `objective_fn` over the whole validation set.

    `objective_fn(model, batch, key)` must return `(mean_loss, aux)` for a
    batch.  Batches are visited in order and the per-batch means are weighted
    by batch size, so a shorter final batch is handled correctly.  The key
    sequence always starts from `PRNGKey(0)`.

    **Returns**:

    A tuple `(mean_loss, first_batch)` where `first_batch` is the first
    validation batch (reused by the caller for plotting).

    **Raises**:

    `ValueError` if the validation set yields no batches.
    """
    total_samples = self.validation_data.batch_size
    val_batch_size = self.val_batch_size

    # Stream the batches rather than materialising all of them at once.  The
    # number of batches is only needed for the progress bar.
    num_batches = (total_samples + val_batch_size - 1) // val_batch_size
    val_data_iter = self.make_sequential_data_iterator(self.validation_data)

    val_key = random.PRNGKey(0)
    objective_sum = 0.0
    first_batch = None

    for val_batch in tqdm.tqdm(val_data_iter, total=num_batches, leave=False):
      if first_batch is None:
        first_batch = val_batch

      loss, _ = objective_fn(model, val_batch, val_key)

      # Undo the mean reduction so that we can get the cumulative sum of the metrics
      objective_sum = objective_sum + loss*val_batch.batch_size

      # Split the key
      _, val_key = random.split(val_key)

    if first_batch is None:
      raise ValueError('Cannot compute the validation loss: the validation set is empty')

    return objective_sum/total_samples, first_batch

  def train(self,
            train_state: TrainingState,
            num_steps: int,
            checkpoint_every: int = 1000,
            early_stopping_patience: int = 3,
            gradient_accumulation: Optional[int] = 4,
            trial: Optional[optuna.trial.Trial] = None) -> TrainingState:
    """Run the training loop with periodic validation and checkpointing.

    **Arguments**:

    - `train_state`: State to start (or resume) from.  Training runs from step
                     `int(train_state.i)` up to `num_steps`.
    - `num_steps`: Total number of training steps
    - `checkpoint_every`: Validate and checkpoint every this many steps.
                          `None` or a non-positive value disables both.
    - `early_stopping_patience`: Stop once more than this many consecutive
                                 validation checks fail to improve the best
                                 validation loss
    - `gradient_accumulation`: Number of micro-batches per update.  `None` or
                               a value `<= 1` disables accumulation.
    - `trial`: Optional optuna trial that receives the validation loss and
               may prune the run

    **Returns**:

    The final training state.  It is also checkpointed before returning.

    **Raises**:

    - `ValueError` if the parameter update becomes NaN (after re-running the
      failing step in debug mode), or if the validation set is empty
    - `optuna.TrialPruned` if `trial` decides to prune
    """

    # If we're re-training a model that has already converged, we should just return straight away
    if train_state.number_of_steps_since_best_validation_loss > early_stopping_patience:
      print(f'Early stopping at step {train_state.i} because the validation loss has not improved for {early_stopping_patience} steps')
      wandb.finish()
      return train_state

    # Keep the un-jitted step around so a failing step can be replayed in the debugger
    not_jitted_train_step = self.train_step
    train_step = eqx.filter_jit(self.train_step)

    # Also create the jitted train objective function (used for validation)
    jitted_train_objective = eqx.filter_jit(self.train_objective)

    @partial(eqx.filter_vmap, in_axes=(None, 0, None))
    def sample_fn(model, key, series):
      assert series.batch_size is None

      # Create samples
      samples = model.sample(key, series, debug=False)

      # Also return the true latent series
      latent_series = model.basic_interpolation(key, series)

      return samples, latent_series

    jitted_sample_fn = eqx.filter_jit(sample_fn)

    # Construct the progress bar.  Plain Python integers keep the modulo
    # checks below on the host instead of creating a device array per step.
    start = int(train_state.i)
    pbar = tqdm.tqdm(range(start, num_steps))

    # Construct the training data iterator; the first value only primes the generator
    train_data_iterator = self.make_train_data_iterator()
    next(train_data_iterator)

    # Use the same set of training data for creating plots to monitor training
    training_eval_batch = None
    validation_eval_batch = None

    use_accumulation = gradient_accumulation is not None and gradient_accumulation > 1

    # Training loop
    for step in pbar:

      # Get the training data
      # NOTE(review): `train_state.key` is consumed here for sampling batch
      # indices and is also split inside `train_step` -- confirm that this
      # key reuse is acceptable.
      if use_accumulation:
        data_keys = random.split(train_state.key, gradient_accumulation)
        micro_batches = [train_data_iterator.send(data_key) for data_key in data_keys]
        # Stack the micro-batches along a new leading axis
        data = jtu.tree_map(lambda *x: jnp.array(x), *micro_batches)
      else:
        data = train_data_iterator.send(train_state.key)

      # Take a training step
      old_train_state = train_state
      train_state, aux = train_step(train_state, data)

      if training_eval_batch is None:
        training_eval_batch = data

      # Check for NaN values in aux
      if jnp.isnan(aux['param_diff']):
        print('NaN parameters detected, entering debugger')
        not_jitted_train_step(old_train_state, data, debug=True)
        raise ValueError("NaN parameters detected")

      # Update the progress bar
      description = ', '.join([f'{k}={float(v.mean()):.4f}' for k, v in aux.items()])
      pbar.set_description(description)

      if step % 100 == 0:
        # Log the training metrics to wandb
        aux_for_wandb = {f'training/{k}': v for k, v in aux.items()}
        aux_for_wandb['training/step'] = train_state.i
        wandb.log(aux_for_wandb)

      # Only validate/checkpoint on the configured schedule (and never at step 0)
      if not (step and checkpoint_every and checkpoint_every > 0 and (step % checkpoint_every == 0)):
        continue

      #################
      # Early stopping
      #################
      # Evaluate the model on the validation set
      objective_mean, first_val_batch = self._compute_validation_loss(train_state.model, jitted_train_objective)
      if validation_eval_batch is None:
        validation_eval_batch = first_val_batch

      # Update the best validation loss
      train_state = train_state.update_best_validation_loss(objective_mean)

      print('About to checkpoint model')
      self.checkpoint(train_state)
      print('Checkpointed model')

      # Save off to wandb
      print(f'\nValidation loss: {objective_mean}')
      print('\n')
      for_wandb = {'validation/validation_loss': objective_mean, 'validation/step': train_state.i}
      wandb.log(for_wandb)

      # Determine if we should quit training
      if train_state.number_of_steps_since_best_validation_loss > early_stopping_patience:
        print(f'Early stopping at step {train_state.i} because the validation loss has not improved for {early_stopping_patience} steps')
        break

      if self.config['command_line_args']['log_plots']:

        # Use a fixed key so that plots from different steps are comparable
        train_eval_key = random.PRNGKey(0)
        keys = random.split(train_eval_key, 8)

        # Log the training plots to wandb
        ground_truth = training_eval_batch[0]
        if use_accumulation:
          # With accumulation the stored batch has an extra leading axis
          ground_truth = ground_truth[0]
        samples, latent_series = jitted_sample_fn(train_state.model, keys, ground_truth)
        self.log_plot(ground_truth, samples, latent_series, 'training', train_state.i, log_to_wandb=True)

        # Log the validation plots to wandb
        ground_truth = validation_eval_batch[0]
        samples, latent_series = jitted_sample_fn(train_state.model, keys, ground_truth)
        self.log_plot(ground_truth, samples, latent_series, 'validation', train_state.i, log_to_wandb=True)

      if trial is not None:
        # Report the validation loss to the trial as plain Python numbers
        trial.report(float(objective_mean), int(train_state.i))

        # Prune if the trial is no longer promising
        if trial.should_prune():
          # Close the wandb run like every other exit from this method does
          wandb.finish()
          raise optuna.TrialPruned()

    # Final checkpoint
    self.checkpoint(train_state)
    print('Checkpointed model')
    wandb.finish()
    return train_state

  def log_plot(self,
               ground_truth: TimeSeries, # Original y_{1:N}
               predictions: TimeSeries, # samples from q(x_{1:N} | y_{1:k})
               latent_series: TimeSeries, # samples from p(x_{1:N} | y_{1:N})
               dataset_name: str,
               step: int,
               log_to_wandb: bool = True):
    """Plot ground truth, predictions and latent series, then save the figure.

    The figure is written to `<plot_folder>/<dataset_name>_plot_step_<step>.png`
    and, when `log_to_wandb` is set, also uploaded through `wandb.save`.  The
    current matplotlib figure is closed afterwards.
    """
    panels = [ground_truth, predictions, latent_series]
    panel_titles = ['Ground Truth', 'Predictions', 'Latent Series']

    _fig, _axes = TimeSeries.plot_multiple_series(series_list=panels,
                                                  titles=panel_titles,
                                                  common_title=f'Dataset: {dataset_name}, Step: {step}',
                                                  index='all',
                                                  show_plot=False,
                                                  use_max_dims=True)

    # Save the current figure to disk
    filename = f'{self.plot_folder}/{dataset_name}_plot_step_{step}'
    image_path = f'{filename}.png'
    plt.savefig(image_path)

    if log_to_wandb:
      # Save the plot to wandb
      wandb.save(image_path, base_path=self.plot_folder)
      print(f'Logged plot to wandb at step {step} for {dataset_name} dataset with filename {filename}')

    plt.close()

  def checkpoint(self, train_state: TrainingState):
    """Write `train_state` to disk and refresh the training metadata.

    The metadata update receives `config['model']['optimizer']['max_train_steps']`
    when that entry exists and `None` otherwise.
    """
    # Save off the model
    print('About to checkpoint model')
    self.checkpointer.save_eqx_module(train_state)
    print('Checkpointed model')

    # Imported here rather than at module level, and only after the state is on disk
    from Models.training_tracker import update_training_metadata_from_train_state

    # Get max steps from config, if every level of the nesting is present
    config = self.config
    has_max_steps = ('model' in config
                     and 'optimizer' in config['model']
                     and 'max_train_steps' in config['model']['optimizer'])
    max_steps = config['model']['optimizer']['max_train_steps'] if has_max_steps else None

    # Update training metadata
    print('About to update training metadata')
    update_training_metadata_from_train_state(self.experiment_identifier,
                                              train_state,
                                              max_steps)
    print('Updated training metadata')

  def restore(self, train_state: TrainingState) -> TrainingState:
    # Load the model
    train_state = self.checkpointer.load_eqx_module(train_state)
    print(f'Restored train_state {self.checkpointer.saved_model_path}')

    return train_state

################################################################################################################

def default_optimizer(lr=1e-3,
                      clip_norm=15.0,
                      warmup=1000,
                      decay_steps=3e5,
                      end_value=0.1,
                      cosine_exponent=1.0,
                      weight_decay_hyper=10.0,
                      dataset_size=None,
                      batch_size=None) -> optax.GradientTransformation:
  """
  Gradient clipping, AdamW, and cosine decay with warmup.

  The schedule is applied as a multiplier on the AdamW update: it warms up
  from 0 to 1 over `warmup` steps and then follows a cosine decay towards
  `end_value`.  `lr` is therefore the peak learning rate and `end_value` is
  a fraction of it.

  **Arguments**:

  - `lr`: Peak learning rate passed to AdamW
  - `clip_norm`: Maximum global gradient norm
  - `warmup`: Number of warmup steps
  - `decay_steps`: Value passed as `decay_steps` to the schedule
  - `end_value`: Final value of the schedule multiplier
  - `cosine_exponent`: Exponent of the cosine decay
  - `weight_decay_hyper`: Scaling constant in the weight decay formula below
  - `dataset_size`: Number of training examples.  If given, the weight decay
                    is `1/(lr*dataset_size/batch_size*weight_decay_hyper)`.
                    Otherwise it is fixed at `0.0001`.
  - `batch_size`: Training batch size.  Required when `dataset_size` is given.

  **Raises**:

  `ValueError` if `dataset_size` is given without `batch_size`.
  """
  schedule = optax.warmup_cosine_decay_schedule(init_value=0.0,
                                                peak_value=1.0,
                                                warmup_steps=warmup,
                                                decay_steps=decay_steps,
                                                end_value=end_value,
                                                exponent=cosine_exponent)

  # https://arxiv.org/pdf/2405.13698v1
  if dataset_size is not None:
    if batch_size is None:
      # Without this check the division below fails with an opaque TypeError
      raise ValueError('batch_size must be provided when dataset_size is given')
    weight_decay = 1/(lr*dataset_size/batch_size*weight_decay_hyper)
  else:
    weight_decay = 0.0001

  return optax.chain(
    optax.clip_by_global_norm(clip_norm),
    optax.adamw(lr, weight_decay=weight_decay),
    optax.scale_by_schedule(schedule),
  )

################################################################################################################
