import os
# Configure GPU memory behaviour before jax / tensorflow are imported below:
# allocate device memory on demand instead of preallocating it up front.
os.environ.update({
    'XLA_PYTHON_CLIENT_PREALLOCATE': 'false',
    'XLA_PYTHON_CLIENT_ALLOCATOR': 'platform',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})

import functools
import itertools
from shift_match import make_resnet20_frn_fn
import haiku as hk
from absl import flags
from tqdm.notebook import tqdm
from bnn_hmc.utils import data_utils
from bnn_hmc.utils import precision_utils
from bnn_hmc.utils import metrics
from bnn_hmc.utils import models
from jax import vmap
from jax.interpreters import xla
from jax import lax
from jax.random import normal, PRNGKey, split
import numpy as onp
from jax import numpy as jnp
from bnn_hmc.utils import checkpoint_utils
from matplotlib import pyplot as plt
import seaborn as sns
import matplotlib
import jax
import sys
import tensorflow_datasets as tfds
import os
import pickle
# NOTE(review): these re-assign exactly the values already set at the top of
# the file, so they are redundant; kept so the settings are still guaranteed
# at this point.
os.environ.update(
    TF_FORCE_GPU_ALLOW_GROWTH='true',
    XLA_PYTHON_CLIENT_ALLOCATOR='platform',
    XLA_PYTHON_CLIENT_PREALLOCATE='false',
)


# NOTE(review): the bnn_hmc imports above have already executed by this point,
# so this entry cannot influence how they were resolved; it only affects
# imports performed later.
_GOOGLE_RESEARCH_PATH = "google-research/"
sys.path.append(_GOOGLE_RESEARCH_PATH)

# Output directory for the pickled evaluation results written by save_result.
log_dir = '/home/ooo123/data/sm_sgd_result_log'

def get_params_and_state_sgd(seed):
  """Load the SGD checkpoint of the run trained with `seed`.

  Reads 'model_step_199.pt' from that run's directory under
  /home/ooo123/data/sgd_log_aug/ via checkpoint_utils.load_checkpoint.

  Args:
    seed: seed value embedded in the run directory name.

  Returns:
    (params, net_state) taken from the loaded checkpoint dict.
  """
  # The directory name appears to encode the run's training hyper-parameters.
  run_dir = ('/home/ooo123/data/sgd_log_aug/'
             f'sgd_mom_0.9__lr_sch_i_1e-06___epochs_200_wd_0.0001_batchsize_80_temp_1.0__seed_{seed}')
  checkpoint = checkpoint_utils.load_checkpoint(
      os.path.join(run_dir, 'model_step_199.pt'))
  return checkpoint["params"], checkpoint["net_state"]

cache_dir = '/home/ooo123/data/sm_sgd_cov_cache'


def load_cov_cache(seed):
  """Load the two cached network states for one SGD seed from cache_dir.

  Args:
    seed: seed of the SGD run whose cached states should be loaded.

  Returns:
    (state_bn_sm, state_sm): the state loaded from f"bn_sm__{seed}.pkl"
    first, then the state loaded from f"sm_{seed}.pkl". In this script the
    first is used for the 'sgd_bn_sm' method and the second for 'sgd_sm'.
  """
  # The local names previously had these two swapped (the "bn_sm__" file was
  # stored in `state_dict_sm`); the load order and return order are unchanged.
  # NOTE(review): the double underscore in "bn_sm__" presumably matches the
  # files on disk — verify before renaming.
  state_bn_sm = checkpoint_utils.load_checkpoint(
      os.path.join(cache_dir, f"bn_sm__{seed}.pkl"))
  state_sm = checkpoint_utils.load_checkpoint(
      os.path.join(cache_dir, f"sm_{seed}.pkl"))
  return state_bn_sm, state_sm

_DEFAULT_BN_CONFIG = {
    "decay_rate": 0.9,
    "eps": 1e-5,
    "create_scale": True,
    "create_offset": True
}


def normalization_layer():
  return hk.BatchNorm(**_DEFAULT_BN_CONFIG)


# Extra keyword arguments splatted into every call of the model's transformed
# functions in this script.
shift_match_config = dict(
    feature_only=False,
    shift_match_before_act=True,
)

def evalaute(pred, target):
  """Score `pred` against `target` with bnn_hmc's metrics module.

  Args:
    pred: model predictions as passed to metrics.accuracy / metrics.nll
      (in this script: averaged softmax probabilities).
    target: ground-truth labels accepted by the same metric functions.

  Returns:
    dict with keys 'prediction' (the raw `pred`), 'acc', 'nll' and 'ece'
    (the 'ece' entry of metrics.calibration_curve).
  """
  # Values are evaluated in dict-literal order: accuracy, nll, then calibration.
  return {
      'prediction': pred,
      'acc': metrics.accuracy(pred, target),
      'nll': metrics.nll(pred, target),
      'ece': metrics.calibration_curve(pred, target)['ece'],
  }

def get_ds_name(corruption, intensity):
  """Build the CIFAR-10-C dataset name passed to data_utils.load_image_dataset.

  Example: get_ds_name('fog', 3) -> 'cifar10_corrupted/fog_3'.
  """
  return f"cifar10_corrupted/{corruption}_{intensity}"

# Corruptions (cifar10_corrupted/<name>_<intensity>) evaluated in this run.
# Currently disabled; add back to the list to re-enable: frosted_glass_blur,
# impulse_noise, pixelate, saturate, brightness, contrast, defocus_blur,
# elastic, shot_noise, spatter, speckle_noise, zoom_blur.
datset_list = [
    'fog',
    'frost',
    'gaussian_blur',
    'gaussian_noise',
]

# Method labels. save_result pairs these positionally with the prediction
# arrays it receives, so the order here must match the callers' order.
method_names = [
    'sgd',        # apply with training flag False, no shift matching
    'sgd_bn',     # apply with training flag True, no shift matching
    'sgd_sm',     # training flag False, shift_match_mode 'match'
    'sgd_bn_sm',  # training flag True, shift_match_mode 'match'
]

def save_result(ds, name, preds, target):
  """Evaluate each method's predictions and pickle the metrics to log_dir.

  Writes one file per method, named f'{method}_{name}_{ds}.pkl', containing
  the dict returned by evalaute(pred, target).

  Args:
    ds: short dataset name used in the output filenames.
    name: run label used in the output filenames (e.g. 'de', or a seed).
    preds: one prediction array per entry of method_names, in the same order.
    target: labels forwarded to evalaute.

  Raises:
    ValueError: if the number of predictions differs from len(method_names);
      zip() would otherwise silently drop the unmatched methods/predictions.
  """
  preds = list(preds)
  if len(preds) != len(method_names):
    raise ValueError(
        f'save_result expected {len(method_names)} prediction arrays '
        f'({method_names}), got {len(preds)}')
  # Create the output directory if needed instead of failing on open().
  os.makedirs(log_dir, exist_ok=True)
  for method, pred in zip(method_names, preds):
    result_dict = evalaute(pred, target)
    out_path = os.path.join(log_dir, f'{method}_{name}_{ds}.pkl')
    with open(out_path, "wb") as f:
      pickle.dump(result_dict, f)


# NOTE(review): train_set is never used below; the load is kept only to
# preserve the original script's side effects. Consider removing it.
train_set, _, _ = data_utils.load_image_dataset('train', 5000, 'cifar10')

# Every selected corruption at intensities 1..5.
ds_list = [get_ds_name(dataset, intensity).lower()
           for dataset, intensity in itertools.product(datset_list, range(1, 6))]
# ds_list = ['cifar10'] + ds_list

# Seeds of the SGD runs whose predictions are averaged into the ensemble.
seeds = list(range(0, 500, 10))

# Build the transformed network once. Only net_apply is needed: parameters and
# states all come from checkpoints, so the per-seed net_init call (whose
# outputs were immediately discarded/overwritten) is dropped.
_, net_apply = hk.transform_with_state(
    make_resnet20_frn_fn('channel_wise_sep_cov_mean',
                         normalization_layer=normalization_layer,
                         activation=jax.nn.relu))


def _member_probs(params, state_bn, state_w_bn, state_wo_bn, images):
  """Return one ensemble member's softmax outputs, ordered like method_names.

  Order: sgd, sgd_bn, sgd_sm, sgd_bn_sm.
  """
  # (net state, 5th apply arg, 6th apply arg) per method. The last two are
  # passed positionally exactly as before — presumably is_training and
  # shift_match_mode, matching the model's init keywords.
  method_args = [
      (state_bn, False, None),        # sgd
      (state_bn, True, None),         # sgd_bn
      (state_wo_bn, False, 'match'),  # sgd_sm
      (state_w_bn, True, 'match'),    # sgd_bn_sm
  ]
  probs = []
  for state, training, sm_mode in method_args:
    logits, _ = net_apply(params, state, None, images, training, sm_mode,
                          **shift_match_config)
    probs.append(jax.nn.softmax(logits, -1))
  return probs


for ds in ds_list:
  print(f"Evaluating {ds}")
  test_set_c, _, _ = data_utils.load_image_dataset('test', -1, ds)
  cifar_c, cifar_c_targets = next(iter(test_set_c))
  # ds looks like "cifar10_corrupted/<corruption>_<intensity>"; results are
  # saved under the part after "/". Unlike the old ds[18:] slice, a plain name
  # such as 'cifar10' is kept whole instead of becoming ''.
  ds_short = ds.split('/', 1)[-1]

  # Running sum of member probabilities, one entry per method.
  prob_sums = [0.] * len(method_names)
  for seed in tqdm(seeds):
    params, state_bn = get_params_and_state_sgd(seed)
    state_w_bn, state_wo_bn = load_cov_cache(seed)
    member_probs = _member_probs(params, state_bn, state_w_bn, state_wo_bn,
                                 cifar_c)
    prob_sums = [total + p for total, p in zip(prob_sums, member_probs)]
    # Per-member results:
    # save_result(ds_short, seed, member_probs, cifar_c_targets)

  # Deep-ensemble prediction: mean of member probabilities over all seeds
  # (derived from `seeds` rather than a hard-coded 50).
  ensemble_probs = [total / len(seeds) for total in prob_sums]
  save_result(ds_short, 'de', ensemble_probs, cifar_c_targets)