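'''Cache the network state accumulated over the CIFAR-10 training set
("cov train"), with shift matching applied before the activation.

For every match type and every SGD checkpoint seed, the training split is
passed through a ResNet-20 FRN network and the resulting state is pickled
into the cache directory.
'''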

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import sys
sys.path.append("google-research/")
import os

from tqdm import tqdm
import pickle
import itertools
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib
import jax
import tensorflow_datasets as tfds
from bnn_hmc.utils import checkpoint_utils
from jax import numpy as jnp
import numpy as onp
from jax.random import normal, PRNGKey, split
from jax import lax
from jax.interpreters import xla
from jax import vmap
from bnn_hmc.utils import models
from bnn_hmc.utils import metrics
from bnn_hmc.utils import precision_utils
from bnn_hmc.utils import data_utils
from tqdm.notebook import tqdm
from absl import flags
import haiku as hk
import functools
import argparse
from shift_match import make_resnet20_frn_fn, shift_match_builder


# Command-line interface: a single optional `-d` flag. When the flag is not
# given, argparse leaves `args.d` as None.
parser = argparse.ArgumentParser(description='Experiment config.')
parser.add_argument('-d', action="store")
args = parser.parse_args()
# NOTE(review): in the code visible here `dataset` is only echoed in the
# message below; the training data loaded later is hard-coded to 'cifar10'.
dataset = args.d
print(f'Evaluating performance on {dataset}')


def get_params_and_state_sgd(seed):
  '''Load the parameters and network state of one SGD checkpoint.

  Args:
    seed: value substituted into the run directory name (`..._seed_{seed}`).

  Returns:
    A `(params, net_state)` tuple read from the "params" and "net_state"
    entries of the loaded checkpoint dictionary.
  '''
  # Commented-out alternative checkpoint location, kept for reference:
  #   sgd_base_dir = f'/home/ooo123/data/sgd_log/seed{seed}'
  #   ckpt_dir = os.path.join(sgd_base_dir, 'model_step_499.pt')
  run_name = f'sgd_mom_0.9__lr_sch_i_3e-07___epochs_200_wd_10.0_batchsize_80_temp_1.0__seed_{seed}'
  checkpoint_path = os.path.join(
      '/home/ooo123/data/sgd_log/', run_name, 'model_step_199.pt')
  checkpoint = checkpoint_utils.load_checkpoint(checkpoint_path)
  # Commented-out per-chain selection, kept for reference:
  #   params = jax.tree_map(lambda p: p[chain_id], params)
  #   net_state = jax.tree_map(lambda p: p[chain_id], net_state)
  return checkpoint["params"], checkpoint["net_state"]



def get_ds_name(corruption, intensity):
  '''Build the dataset name for one corrupted-CIFAR-10 configuration.

  Args:
    corruption: corruption identifier, formatted into the name as-is.
    intensity: intensity level, formatted into the name as-is.

  Returns:
    The string `cifar10_corrupted/<corruption>_<intensity>`.
  '''
  name_template = "cifar10_corrupted/{}_{}"
  return name_template.format(corruption, intensity)


# Directory the pickled states are written to.
# cache_dir = '/home/ooo123_321321/scratch/cov_caches'
cache_dir = '/home/ooo123/data/cov_caches'

# Keyword arguments forwarded to both network init and apply.
shift_match_config = {
  'feature_only': False,
  'shift_match_before_act': True,
}

# Match types to process; the commented-out entries are currently disabled.
match_types = [
    # 'None',
    # 'feature',
    # 'channel_wise_joint',
    'channel_wise_sep',
    # 'spatial_joint',
    # 'spatial_sep',
    # 'fft_spatial',
    'batch_norm',
]

# Dummy batch passed to the network initialiser: normal samples drawn with a
# fixed PRNG key, plus an all-ones array of the same shape and dtype.
x = normal(PRNGKey(0), shape=(10, 32, 32, 3))
y = jnp.ones_like(x)

# Load the training split once: it does not depend on the match type, and the
# same iterable is already re-iterated for every seed below.
train_set, _, _ = data_utils.load_image_dataset('train', 5000, 'cifar10')

for match_type in match_types:
  net_init, net_apply = hk.transform_with_state(
      make_resnet20_frn_fn(match_type)
  )
  # Use a per-match-type copy so the 'batch_norm' override cannot leak into
  # the shared module-level config or into match types processed afterwards.
  config = dict(shift_match_config)
  if match_type == 'batch_norm':
    config['feature_only'] = True
    config['shift_match_before_act'] = True
  # Initialise with the module-level dummy batch `x`. The inner loop below
  # deliberately uses different names so that `x` is never rebound to a
  # training batch before the next match type is initialised.
  # `params` is not used afterwards; the weights come from SGD checkpoints.
  params, init_state = net_init(
      PRNGKey(0), batch=x, is_training=True, shift_match_mode=None, **config)
  for seed in tqdm(range(0, 500, 10)):
    # Every seed starts again from the freshly initialised state.
    state = init_state
    pt, _ = get_params_and_state_sgd(seed)
    for batch_x, _ in tqdm(train_set):
      # Only the updated state is kept; the network output is discarded.
      # NOTE(review): the positional arguments presumably line up with
      # batch, is_training=False, shift_match_mode='acc' (the keyword names
      # used for net_init above) -- confirm against make_resnet20_frn_fn.
      _, state = net_apply(pt, state, None, batch_x, False, 'acc', **config)
    # sgd_BA_new.pkl comes from the new run
    # sgd_BA.pkl is the original results
    out_path = os.path.join(cache_dir, f"{match_type}__{seed}_sgd_BA_new.pkl")
    with open(out_path, "wb") as f:
      pickle.dump(state, f)