import os
import functools
import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import torch
from torch.nn.functional import cross_entropy

from absl import app
from absl import logging
from typing import Sequence
from ml_collections import config_flags

from jax_privacy.experiments.image_classification import config_base
from jax_privacy.experiments.image_classification import experiment as experiment_py
from third_party.jax_privacy.jax_privacy.src.training.updater import UpdaterState
from src.utils import write_pickle, read_pickle
import pdb

# Default `batch_size` used by the dataset loader in `train_eval`; it is the
# second entry of `batch_dims`, after the local device count.
EVAL_BATCH_SIZE = 100

# This script only supports running in a single JAX process.
assert jax.process_count() == 1


# Flag named 'config' selecting the experiment configuration file; defaults to
# 'configs/mnist.py'.
_CONFIG = config_flags.DEFINE_config_file(
    'config',
    'configs/mnist.py',
    help_string='Experiment configuration file.',
)

def train_eval(config: config_base.ExperimentConfig) -> None:
  """Scores the final saved model on every dataset split and pickles the result.

  Reads `state_final.pkl` from the experiment's output directory, applies the
  EMA-averaged parameters to the train, eval and additional-eval datasets, and
  writes per-example metadata, labels, logits and derived logit statistics to
  `out.pkl` in the same directory.

  Saved keys: 'id', 'include', 'holdout', 'canary', 'label', 'logit',
  'logit_target', 'sum_logit_nontarget', 'logit_highest_nontarget',
  'neg_loss', 'logit_diff' and 'logit_next_diff'.

  Args:
    config: Experiment configuration. Supplies the experiment name (used to
      locate the output directory), the dataset configs and the model
      hyper-parameters.
  """
  # Kept function-local, as in the original script.
  from itertools import chain
  import haiku as hk
  from third_party.jax_privacy.jax_privacy.experiments.image_classification.models.cifar import WideResNet

  # Setup.
  logging.info('[Setup] Setting up experiment with config %s', config)

  experiment = experiment_py.ImageClassificationExperiment(config)

  # Load the state saved at the last training iteration.
  if 'crafted' in experiment._config.name:
    save_dir = f'out/crafted/black_box/{experiment._config.name}'
  else:
    save_dir = f'out/black_box/{experiment._config.name}'

  state_path = os.path.join(save_dir, 'state_final.pkl')
  state = read_pickle(state_path)

  # Dataloaders.
  def get_ds(
    data,
    batch_size=EVAL_BATCH_SIZE,
  ):
    """Loads `data` in eval mode, batched as (local devices, batch_size)."""
    return data.load_dataset(
      batch_dims=(
        jax.local_device_count(),
        batch_size,
      ),
      is_training=False,
      drop_metadata=False,  # metadata carries the per-example ids/flags
      shard_data=False,
    )

  def get_ds_all():
    """Chains the train, eval and additional-eval datasets."""
    return chain(
      get_ds(experiment._config.data_train),
      get_ds(experiment._config.data_eval),
      get_ds(experiment._config.data_eval_additional),
    )

  # Model.
  num_classes = config.data_train.config.num_classes
  model_config = config.model

  def get_embedding(x):
    """Builds the WideResNet and returns its (logits, embeddings) for `x`."""
    model = WideResNet(
        num_classes=num_classes,
        depth=model_config.depth,
        width=model_config.width,
        dropout_rate=model_config.dropout_rate,
        use_skip_init=model_config.use_skip_init,
        use_skip_paths=model_config.use_skip_paths,
        which_conv=model_config.which_conv,
        which_norm=model_config.which_norm,
        activation=model_config.activation,
        groups=model_config.groups,
        is_dp=model_config.is_dp,
    )
    return model(x, is_training=False, return_embedding=True)

  model = hk.transform(get_embedding)

  @jax.pmap
  def pmap_apply(params, rng, x):
    return model.apply(params, rng, x)

  params = state['params_avg']['ema']
  # The replicated version of the weights was saved, so it is handled here by
  # repeating along the leading axis.
  # NOTE(review): jnp.repeat multiplies the leading axis by the local device
  # count, which only matches pmap's expectation if the saved leading axis is
  # 1 -- TODO confirm against how the state was written.
  params = jax.tree.map(
      lambda x: jnp.repeat(x, jax.local_device_count(), axis=0), params)

  rng_key = jax.random.PRNGKey(0)

  metadata_keys = ('id', 'include', 'holdout', 'canary')
  collected = {k: [] for k in (*metadata_keys, 'label', 'logit')}

  for batch in get_ds_all():
    rng_key, apply_key = jax.random.split(rng_key, 2)
    apply_keys = jax.random.split(apply_key, jax.local_device_count())

    # Per-example bookkeeping, flattened over the (device, batch) axes.
    metadata = batch['metadata']
    for k in metadata_keys:
      collected[k].append(metadata[k].reshape(-1))

    # Labels arrive one-hot with the class axis last (axis 2).
    labels = jnp.argmax(batch['label'], axis=2).reshape(-1)
    collected['label'].append(labels)

    # Forward pass. The model also returns embeddings; they are not saved.
    logits, _ = pmap_apply(params, apply_keys, batch['image'])
    collected['logit'].append(logits.reshape(-1, logits.shape[2]))

  stacked = {k: jnp.concatenate(v) for k, v in collected.items()}

  # Remove duplicates - e.g., train and canary sets overlap (since we train on
  # half of the canaries). Keeps the first occurrence of each id.
  _, idxs = jnp.unique(stacked['id'], return_index=True)
  out = {k: np.array(v[idxs]) for k, v in stacked.items()}

  # Derived per-example logit statistics.
  n_rows = len(out['logit'])
  rows = np.arange(n_rows)
  labels = out['label']

  logit_target = out['logit'][rows, labels].copy()
  out['logit_target'] = logit_target

  # Sum of all logits except the target class (target zeroed out).
  masked = out['logit'].copy()
  masked[rows, labels] = 0
  sum_logit_nontarget = masked.sum(axis=1)
  out['sum_logit_nontarget'] = sum_logit_nontarget

  # Largest logit among the non-target classes (target set to -inf).
  masked = out['logit'].copy()
  masked[rows, labels] = -np.inf
  logit_highest_nontarget = masked.max(axis=1)
  # Fix: this value was previously computed but never stored; the original
  # re-assigned 'sum_logit_nontarget' here instead.
  out['logit_highest_nontarget'] = logit_highest_nontarget

  neg_loss = -cross_entropy(
      torch.tensor(out['logit']),
      torch.tensor(labels).long(),
      reduction='none',
  ).numpy()
  out['neg_loss'] = neg_loss

  out['logit_diff'] = logit_target - sum_logit_nontarget
  out['logit_next_diff'] = logit_target - logit_highest_nontarget

  # Save.
  out_path = os.path.join(save_dir, 'out.pkl')
  write_pickle(out, out_path)

def main(argv: Sequence[str]) -> None:
  """Entry point: rejects extra positional arguments and runs `train_eval`.

  Args:
    argv: Command-line arguments; only the program name is accepted.

  Raises:
    app.UsageError: If anything beyond the program name is passed.
  """
  unexpected_args = argv[1:]
  if unexpected_args:
    raise app.UsageError('Too many command-line arguments.')

  # The experiment config sits under `experiment_kwargs` of the loaded config.
  train_eval(_CONFIG.value.experiment_kwargs.config)

# When executed as a script, hand `main` to absl's application runner.
if __name__ == '__main__':
  app.run(main)