import os
import functools
import pandas as pd
from typing import Sequence

from absl import app
from absl import logging
import jax
import jax.numpy as jnp
from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification import experiment as experiment_py
from ml_collections import config_flags

from src.utils import write_pickle

import pdb

# Registers the `--config` flag pointing at the experiment config file
# (defaults to the MNIST config). `main` reads the loaded config through
# `_CONFIG.value`.
_CONFIG = config_flags.DEFINE_config_file(
    'config',
    default='configs/mnist.py',
    help_string='Experiment configuration file.',
)

def _resolve_save_dir(name: str) -> str:
  """Returns the output directory for the experiment called `name`."""
  # Runs whose name contains 'crafted' are kept in a separate subtree.
  if 'crafted' in name:
    return f'out/crafted/black_box/{name}'
  return f'out/black_box/{name}'


def _logging_intervals(name: str) -> tuple[int, int]:
  """Returns `(train_interval, eval_interval)`, in steps, for `name`.

  Runs whose name contains 'nondp' log train metrics every 100 steps and
  evaluate every 5000 steps; all other runs use 10 and 100 respectively.
  """
  if 'nondp' in name:
    return 100, 5000
  return 10, 100


def _add_dummy_dirac_metadata(inputs):
  """Adds placeholder Dirac-gradient metadata to a batch, in place.

  Sets `inputs['metadata']['param_idx']` to zeros (int) and
  `inputs['metadata']['dirac']` to all-False (bool), both shaped like
  `inputs['metadata']['id']`. This makes batches compatible with the
  black-box code without actually using Dirac gradients.

  Args:
    inputs: A batch dict with a `'metadata'` dict containing an `'id'` array.

  Returns:
    The same (mutated) `inputs` object.
  """
  ids = inputs['metadata']['id']
  inputs['metadata']['param_idx'] = jnp.zeros_like(ids).astype(int)
  inputs['metadata']['dirac'] = jnp.zeros_like(ids).astype(bool)
  return inputs


def train_eval(config: config_base.ExperimentConfig) -> None:
  """Runs the experiment as a loop without a Jaxline experiment.

  Builds an `ImageClassificationExperiment`, pickles the initial state,
  trains until `experiment.max_num_updates`, periodically logging train
  metrics and evaluating, runs a final evaluation (unless the last step was
  already evaluated), and pickles the final state.

  Args:
    config: Experiment configuration passed to
      `ImageClassificationExperiment`. The experiment's config name selects
      the output directory and the logging/evaluation cadence.
  """
  # Setup.
  logging.info('[Setup] Setting up experiment with config %s', config)

  experiment = experiment_py.ImageClassificationExperiment(config)
  name = experiment._config.name  # pylint: disable=protected-access

  # Save directory. `exist_ok` avoids the exists()/makedirs() race.
  save_dir = _resolve_save_dir(name)
  os.makedirs(save_dir, exist_ok=True)

  # Initialization.
  logging.info('[Init] Starting initialization...')
  state, step_on_host, train_data = experiment.initialize(
      jax.random.PRNGKey(134))
  logging.info('[Init] Initialization complete.')

  # Wrap the input stream so each batch carries dummy Dirac metadata.
  train_data = (_add_dummy_dirac_metadata(batch) for batch in train_data)

  # Save initial state.
  write_pickle(state, os.path.join(save_dir, 'state_init.pkl'))

  # Loop invariants: these depend only on the experiment name / data stream.
  train_interval, eval_interval = _logging_intervals(name)
  inputs_producer = functools.partial(next, train_data)
  # Whether the current `state` has already been evaluated; used to avoid a
  # redundant final evaluation of the very same state.
  evaluated_current_state = False

  # Training loop.
  while step_on_host < experiment.max_num_updates:
    state, train_metrics, step_on_host = experiment.update(
        state=state,
        step_on_host=step_on_host,
        inputs_producer=inputs_producer,
    )
    evaluated_current_state = False

    # Log train.
    if step_on_host % train_interval == 0:
      logging.info('[Train] %s', train_metrics)

    # Periodic evaluation.
    if step_on_host % eval_interval == 0:
      eval_metrics = experiment.evaluate(
          state=state,
          step_on_host=step_on_host,
      )
      logging.info('[Eval] %s', eval_metrics)
      evaluated_current_state = True

  # Final evaluation (skipped if the loop just evaluated this exact state).
  if not evaluated_current_state:
    eval_metrics = experiment.evaluate(
        state=state,
        step_on_host=step_on_host,
    )
    logging.info('[Eval] %s', eval_metrics)

  logging.info('Training and evaluation complete.')
  # Save last checkpoint.
  write_pickle(state, os.path.join(save_dir, 'state_final.pkl'))

def main(argv: Sequence[str]) -> None:
  """absl entry point: reads the `--config` flag and runs `train_eval`.

  Args:
    argv: Command-line arguments remaining after flag parsing; only the
      program name (argv[0]) is accepted.

  Raises:
    app.UsageError: If any positional arguments beyond argv[0] are given.
  """
  extra_args = argv[1:]
  if extra_args:
    raise app.UsageError('Too many command-line arguments.')

  # The experiment config lives under `experiment_kwargs.config` of the
  # loaded config file.
  train_eval(_CONFIG.value.experiment_kwargs.config)


if __name__ == '__main__':
  app.run(main)