import math
from nis import match
import os

# JAX/XLA and TensorFlow GPU memory settings. They only take effect if set
# before the corresponding GPU backend initialises, which is why they appear
# here, ahead of the jax / tensorflow_datasets imports further down.
# Do not reserve most of the GPU memory up front.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] ='false'
# Allocate and free device memory exactly on demand (slower, smallest footprint).
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR']='platform'
# Let TensorFlow (pulled in by tensorflow_datasets) grow GPU memory on demand.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'


import sys
sys.path.append("google-research/")

from shift_match import make_resnet20_frn_fn
import argparse
import functools
import haiku as hk
from absl import flags
from tqdm.notebook import tqdm
from bnn_hmc.utils import data_utils
from bnn_hmc.utils import precision_utils
from bnn_hmc.utils import metrics
from bnn_hmc.utils import models
from jax import vmap
from jax.interpreters import xla
from jax import lax
from jax.random import normal, PRNGKey, split
import numpy as onp
from jax import numpy as jnp
from bnn_hmc.utils import checkpoint_utils
import tensorflow_datasets as tfds
import jax
import matplotlib
import seaborn as sns
from matplotlib import pyplot as plt
import itertools
import pickle
from tqdm import tqdm
import os
import os
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Shift-matching variants accepted on the command line via ``-m``; the chosen
# value is passed to ``make_resnet20_frn_fn`` when building the network.
match_types = [
    'None',
    'feature',
    'feature_cov',
    'channel_wise_joint',
    'channel_wise_sep',
    'spatial_joint',
    'spatial_joint_cov',
    'spatial_sep',
    'spatial_sep_cov',
    'fft_spatial',
    'channel_wise_sep_cov',
    'channel_wise_sep_cov_mean',
    'batch_norm',
    'spatial_sep_cov_mean'
]


parser = argparse.ArgumentParser(description='Experiment config.')
# Validation lives in argparse (not ``assert``) so it survives ``python -O``
# and a bad or missing flag yields a usage message instead of a traceback.
parser.add_argument('-m', action="store", required=True, choices=match_types,
                    help='shift-match variant to evaluate')
parser.add_argument('-p', action='store', type=int, required=True,
                    choices=(-1, 0, 1, 2),
                    help='corruption partition: -1 = all, 0 = first five, '
                         '1 = next five, 2 = the remainder')
args = parser.parse_args()
match_type = args.m
dataset_part = args.p


# datset_list = [
#   'brightness',
#  'contrast',
#  'defocus_blur',
#  'elastic',
#  'fog',
#  'frost',
#  'frosted_glass_blur',
#  'gaussian_blur',
#  'gaussian_noise',
#  'impulse_noise',
#  'jpeg_compression',
#  'motion_blur',
#  'pixelate',
#  'saturate',
#  'shot_noise',
#  'snow',
#  'spatter',
#  'speckle_noise',
#  'zoom_blur']


# Corruptions to evaluate, in the order the partition logic below slices them.
datset_list = [
    'frosted_glass_blur',
    'impulse_noise',
    'pixelate',
    'saturate',
    'brightness',
    'contrast',
    'defocus_blur',
    'elastic',
    'shot_noise',
    'spatter',
    'speckle_noise',
    'zoom_blur',
    'fog',
    'frost',
    'gaussian_blur',
    'gaussian_noise']


# Split the corruptions into groups so the evaluation can be spread across
# several jobs: -1 keeps every corruption, 0 and 1 select the first and second
# group of five, and any other value selects everything from index 10 onward.
if dataset_part in (0, 1):
  datset_list = datset_list[5 * dataset_part:5 * dataset_part + 5]
elif dataset_part != -1:
  datset_list = datset_list[10:]

print(f'Evaluating model {match_type} on dataset: {datset_list}')

# assert dataset in type_list

# def get_params_and_state_sgd(seed):
#   sgd_base_dir = f'/home/ooo123/data/sgd_log/seed{seed}'
#   ckpt_dir = os.path.join(sgd_base_dir, 'model_step_499.pt')
#   ckpt_dict = checkpoint_utils.load_checkpoint(ckpt_dir)
#   params = ckpt_dict["params"]
#   net_state = ckpt_dict["net_state"]
#   # params = jax.tree_map(lambda p: p[chain_id], params)
#   # net_state = jax.tree_map(lambda p: p[chain_id], net_state)
#   return params, net_state

def get_params_and_state_sgd(seed,
                             base_dir='/home/ooo123/data/sgd_log',
                             step=199):
  """Load the parameters and network state of one SGD-trained ensemble member.

  Args:
    seed: Training seed; selects the run directory
      ``<base_dir>/sgd_mom_0.9__..._seed_<seed>``.
    base_dir: Directory holding one sub-directory per SGD run. Defaults to the
      location previously hard-coded here.
    step: Checkpoint step to load, i.e. the file ``model_step_<step>.pt``.

  Returns:
    ``(params, net_state)`` exactly as stored in the checkpoint.
  """
  run_name = ('sgd_mom_0.9__lr_sch_i_3e-07___epochs_200_wd_10.0_batchsize_80'
              f'_temp_1.0__seed_{seed}')
  ckpt_path = os.path.join(base_dir, run_name, f'model_step_{step}.pt')
  ckpt_dict = checkpoint_utils.load_checkpoint(ckpt_path)
  return ckpt_dict["params"], ckpt_dict["net_state"]

def get_params_and_state(chain_id, sample_id):
  """Load one stored sample checkpoint and slice out a single chain.

  Args:
    chain_id: Index applied to the leading axis of every leaf of the stored
      ``params`` and ``net_state`` (presumably one entry per chain — TODO
      confirm against the checkpoint writer).
    sample_id: Selects the file ``cifar10/state-<sample_id>.pkl``, resolved
      relative to the current working directory.

  Returns:
    ``(params, net_state)`` for the requested chain.
  """
  ckpt_dict = checkpoint_utils.load_checkpoint(
      "cifar10/state-{}.pkl".format(sample_id))
  # ``jax.tree_map`` is deprecated and removed in recent JAX releases;
  # ``jax.tree_util.tree_map`` is the stable spelling.
  params = jax.tree_util.tree_map(lambda p: p[chain_id], ckpt_dict["params"])
  net_state = jax.tree_util.tree_map(
      lambda p: p[chain_id], ckpt_dict["net_state"])
  return params, net_state

# Match types whose covariance cache is the one computed for another type.
# Types not listed here use a cache file named after themselves.
COV_CACHE_ALIASES = {
    'None': 'channel_wise_sep',
    'spatial_sep_cov_mean': 'spatial_sep_cov',
    'channel_wise_sep_cov': 'channel_wise_sep',
    'channel_wise_sep_cov_mean': 'channel_wise_sep',
}


def get_cov_train_cache_sgd(match_type, seed,
                            cache_dir='/home/ooo123/data/cov_caches'):
  """Load the cached training statistics used for shift matching.

  Args:
    match_type: Shift-match variant; mapped through ``COV_CACHE_ALIASES`` to
      the variant whose cache file is actually stored.
    seed: Seed of the SGD ensemble member the cache belongs to.
    cache_dir: Directory containing the cache files. Defaults to the location
      previously hard-coded here.

  Returns:
    Whatever ``checkpoint_utils.load_checkpoint`` returns for the file
    ``<cache_dir>/<type>__<seed>_sgd_BA_new.pkl``; the caller passes it to
    ``net_apply`` as the network state.
  """
  cache_type = COV_CACHE_ALIASES.get(match_type, match_type)
  ckpt_path = os.path.join(cache_dir, f"{cache_type}__{seed}_sgd_BA_new.pkl")
  return checkpoint_utils.load_checkpoint(ckpt_path)


def get_ds_name(corruption, intensity):
  """Return the ``cifar10_corrupted/<corruption>_<intensity>`` dataset name."""
  return f"cifar10_corrupted/{corruption}_{intensity}"

def evalaute(pred, target):
  """Score predictions against labels and bundle them with the metrics.

  The (misspelled) name is kept because callers use this spelling.

  Args:
    pred: Predicted class probabilities (callers pass softmax outputs).
    target: Ground-truth labels for the same examples.

  Returns:
    Dict with ``'prediction'`` (``pred`` itself), ``'acc'``, ``'nll'`` and
    ``'ece'`` (the ``'ece'`` entry of ``metrics.calibration_curve``).
  """
  return {
    'prediction': pred,
    'acc': metrics.accuracy(pred, target),
    'nll': metrics.nll(pred, target),
    'ece': metrics.calibration_curve(pred, target)['ece'],
  }

# net_apply = precision_utils.rewrite_high_precision(net_apply)



# Dummy batch of shape (10, 32, 32, 3) (presumably NHWC CIFAR-sized images),
# used only to give ``net_init`` concrete input shapes.
x = normal(PRNGKey(0), (10, 32, 32, 3))

# Per-member and ensemble results are pickled into this directory. Create it
# up front so a missing directory does not abort the run after the first
# (expensive) model evaluation.
log_dir = '/home/ooo123/data/shift_match_log'
os.makedirs(log_dir, exist_ok=True)

net_init, net_apply = hk.transform_with_state(
    make_resnet20_frn_fn(match_type)
)

# Keyword options forwarded to ``net_apply`` at evaluation time. Only the
# ``batch_norm`` variant sets ``feature_only``; every variant matches before
# the activation.
shift_match_config = {
  'feature_only': match_type == 'batch_norm',
  'shift_match_before_act': True
}



# Initialise once. The loop below never reads these outputs (parameters come
# from the SGD checkpoints and state from the covariance caches), so there is
# no reason to redo this for every corruption/intensity pair.
# NOTE(review): this call could likely be dropped altogether — confirm nothing
# else relies on ``params`` / ``init_state``.
params, init_state = net_init(PRNGKey(0), batch=x, is_training=True, shift_match_mode=None)

# Seeds of the SGD runs that form the ensemble.
sgd_seeds = range(0, 500, 10)

for dataset in datset_list:
  for intensity in range(1, 6):
    print(f"Evaluating {dataset} of intensity {intensity}")
    test_set_c, _, _ = data_utils.load_image_dataset(
        'test', -1, get_ds_name(dataset, intensity).lower())
    # Batch size -1: presumably the whole split arrives as a single batch —
    # TODO confirm against data_utils.
    cifar_c, cifar_c_targets = next(iter(test_set_c))
    y_hat = 0.
    counter = 0.
    for seed in tqdm(sgd_seeds):
      counter += 1
      pt = get_params_and_state_sgd(seed)[0]
      state = get_cov_train_cache_sgd(match_type, seed)
      preds, _ = net_apply(pt, state, None, cifar_c, False, 'match',
                           **shift_match_config)
      preds = jax.nn.softmax(preds, -1)
      result_dict = evalaute(preds, cifar_c_targets)
      log_name = f'{dataset}_{intensity}_{match_type}_sgd_{seed}.pkl'
      with open(os.path.join(log_dir, log_name), "wb") as f:
        pickle.dump(result_dict, f)
      y_hat += preds
    # Ensemble prediction: the mean of the members' softmax outputs.
    y_hat = y_hat / counter
    result_dict = evalaute(y_hat, cifar_c_targets)
    log_name = f'{dataset}_{intensity}_{match_type}_sgd_ens_large.pkl'
    with open(os.path.join(log_dir, log_name), "wb") as f:
      pickle.dump(result_dict, f)

