import os
import functools
import pandas as pd
from typing import Sequence

from absl import app
from absl import logging
import jax
import jax.numpy as jnp
from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification import experiment as experiment_py
from ml_collections import config_flags

from src.utils import write_pickle, read_pickle


# Path template for the pretrained checkpoint that fine-tuning starts from.
# Judging by the path, this is the final pickled state of the non-DP
# CINIC-10-without-CIFAR WRN runs. The two `{}` placeholders receive the two
# numbers parsed from the `wrn_<a>_<b>` part of the experiment name.
PRETRAIN_PATH_BASE = (
    'out/black_box/nondp_cinic10_without_cifar_wrn_{}_{}/'
    'state_final.pkl'
)

# `--config` flag: path to an ml_collections config file (default
# 'configs/mnist.py'). `main` reads `experiment_kwargs.config` from the
# loaded value and hands it to `train_eval`.
_CONFIG = config_flags.DEFINE_config_file(
    'config', 'configs/mnist.py', help_string='Experiment configuration file.'
)

def _finetune_save_dir(experiment_name: str) -> str:
  """Returns the output directory for the fine-tuning run `experiment_name`.

  Names containing 'crafted' are placed under `out/crafted/black_box/`; all
  others go under `out/black_box/`.
  """
  if 'crafted' in experiment_name:
    return f'out/crafted/black_box/ft_cinic10_{experiment_name}'
  return f'out/black_box/ft_cinic10_{experiment_name}'


def _pretrained_checkpoint_path(experiment_name: str) -> str:
  """Returns the pretrained checkpoint path for the run's WRN architecture.

  The first `wrn_<int>_<int>` substring of `experiment_name` is parsed. Its
  two numbers (named width and depth here) fill the two placeholders of
  `PRETRAIN_PATH_BASE`, in that order.

  Raises:
    ValueError: if `experiment_name` has no `wrn_<int>_<int>` substring.
  """
  import re

  match = re.search(r'wrn_(\d+)_(\d+)', experiment_name)
  if match is None:
    # Raise explicitly instead of using `assert False`. Asserts are stripped
    # under `python -O`, and the failure would then surface later as a
    # confusing NameError on unbound variables.
    raise ValueError(
        'Cannot infer WRN architecture from experiment name '
        f'{experiment_name!r}; expected a "wrn_<width>_<depth>" substring.'
    )
  width, depth = match.groups()
  return PRETRAIN_PATH_BASE.format(width, depth)


def _add_dummy_dirac_metadata(batch):
  """Attaches all-zero `param_idx` and all-False `dirac` metadata to `batch`.

  This makes batches compatible with the black-box training code without
  actually using dirac gradients. The new arrays copy the shape of
  `batch['metadata']['id']`. `batch['metadata']` is mutated in place and
  `batch` is returned.
  """
  ids = batch['metadata']['id']
  batch['metadata']['param_idx'] = jnp.zeros_like(ids).astype(int)
  batch['metadata']['dirac'] = jnp.zeros_like(ids).astype(bool)
  return batch


def train_eval(config: config_base.ExperimentConfig) -> None:
  """Fine-tunes a pretrained WRN and evaluates it, without a Jaxline loop.

  Steps:
    1. Builds an `ImageClassificationExperiment` from `config` and initializes
       it with PRNG key 134.
    2. Replaces the initialized `params` with the EMA entry of the pretrained
       checkpoint's `params_avg`, and sets `params_avg` to the checkpoint's
       value. Network, optimizer and noise state and the update step are kept
       from the fresh initialization.
    3. Makes every training batch carry dummy dirac metadata.
    4. Trains until `experiment.max_num_updates` updates, or until a logged
       training `acc1` equals 100, evaluating periodically.
    5. Pickles the state before training (`state_init.pkl`) and after training
       (`state_final.pkl`) into the run's output directory.

  Args:
    config: Experiment configuration. `config.name` selects the pretrained
      checkpoint. The experiment's configured name selects the output
      directory and the evaluation cadence.

  Raises:
    ValueError: if `config.name` contains no `wrn_<int>_<int>` part.
  """
  logging.info('[Setup] Setting up experiment with config %s', config)
  experiment = experiment_py.ImageClassificationExperiment(config)

  save_dir = _finetune_save_dir(experiment._config.name)
  # `exist_ok=True` avoids the race in a separate exists() check.
  os.makedirs(save_dir, exist_ok=True)

  # Resolve the checkpoint before the (expensive) initialization, so that a
  # malformed experiment name fails fast.
  pretrain_path = _pretrained_checkpoint_path(config.name)

  logging.info('[Init] Starting initialization...')
  state, step_on_host, train_data = experiment.initialize(
      jax.random.PRNGKey(134))
  logging.info('[Init] Initialization complete.')

  # Swap in the pretrained weights. Everything except the parameters comes
  # from the fresh initialization.
  from jax_privacy.src.training.updater import UpdaterState
  state_pretrained = read_pickle(pretrain_path)
  state = UpdaterState(
      params=state_pretrained['params_avg']['ema'],
      params_avg=state_pretrained['params_avg'],
      network_state=state['network_state'],
      update_step=state['update_step'],
      opt_state=state['opt_state'],
      noise_state=state['noise_state'],
  )

  # Lazily attach dummy dirac metadata to every training batch.
  train_data = (_add_dummy_dirac_metadata(batch) for batch in train_data)

  write_pickle(state, os.path.join(save_dir, 'state_init.pkl'))

  # Initial evaluation of the pretrained weights.
  eval_metrics = experiment.evaluate(state=state, step_on_host=step_on_host)
  logging.info('[Eval] %s', eval_metrics)

  # The name does not change during training, so the cadence is computed once:
  # sparse evaluation for 'nondp_' runs, frequent evaluation otherwise.
  eval_interval = 5000 if experiment._config.name.startswith('nondp_') else 100

  while step_on_host < experiment.max_num_updates:
    state, train_metrics, step_on_host = experiment.update(
        state=state,
        step_on_host=step_on_host,
        inputs_producer=functools.partial(next, train_data),
    )

    if step_on_host % 10 == 0:
      logging.info('[Train] %s', train_metrics)
      # Early stopping once training accuracy reaches 100. This was intended
      # for non-DP runs, but the check applies to every run, and only on
      # logging steps.
      if train_metrics['acc1'] == 100:
        break

    if step_on_host % eval_interval == 0:
      eval_metrics = experiment.evaluate(
          state=state, step_on_host=step_on_host)
      logging.info('[Eval] %s', eval_metrics)

  # Final evaluation, always run, including after early stopping.
  eval_metrics = experiment.evaluate(state=state, step_on_host=step_on_host)
  logging.info('[Eval] %s', eval_metrics)

  logging.info('Training and evaluation complete.')
  write_pickle(state, os.path.join(save_dir, 'state_final.pkl'))

def main(argv: Sequence[str]) -> None:
  """absl entry point: runs `train_eval` on the config given by `--config`.

  Args:
    argv: Arguments remaining after absl flag parsing. Only the program name
      (argv[0]) is accepted.

  Raises:
    app.UsageError: if any extra positional arguments are present.
  """
  positional_args = argv[1:]
  if positional_args:
    raise app.UsageError('Too many command-line arguments.')

  train_eval(_CONFIG.value.experiment_kwargs.config)


if __name__ == '__main__':
  # absl parses flags (including --config) before invoking `main`.
  app.run(main)