import argparse

import jax
import jax.numpy as np

from jax import grad, jit, vmap, pmap, value_and_grad
from jax import random

from jax.tree_util import tree_multimap, tree_map
from utils import optimizers
from utils import adaptation_utils
from utils.regularizers import weighted_parameter_loss
import haiku as hk

import numpy as onp

import tensorflow_datasets as tfds
import tensorflow as tf

from jax.config import config

import os
import requests

import pickle
import time

from models.util import get_model

from utils.training_utils import train_epoch
from utils.eval import eval_ds_all, get_labels

from utils.losses import nll, accuracy, entropy, brier, ece
from utils.misc import get_single_copy, manual_pmap_tree

from posteriors.utils import sample_weights_diag
from posteriors.swag import init_swag, update_swag, collect_posterior

import resource
# Query the current (soft, hard) limits on open file descriptors for this process.
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)

# Command-line interface. Options without an explicit default fall back to
# argparse's own (None for valued options, False for store_true flags).
parser = argparse.ArgumentParser(description='Runs basic train loop on a supervised learning task')

# Locations.
parser.add_argument("--dir", type=str, help="Training directory for logging results")
parser.add_argument("--log_prefix", type=str, help="Name prefix for logging results")
parser.add_argument("--data_dir", type=str, default='datasets', help="Directory for storing datasets")

# Run configuration.
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--wd", type=float, default=0.)
parser.add_argument("--model", type=str, default="ResNet50", help="Model class")

# Corruption variant; the two values are joined as '<type>_<level>' further down.
parser.add_argument("--corruption_type", type=str, default="brightness")
parser.add_argument("--corruption_level", type=int, default=1)

# Optimisation settings.
parser.add_argument("--n_epochs", type=int, default=1)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--lr", type=float, default=0.00025)

# Boolean switches (False unless the flag is present).
parser.add_argument("--adapt_bn_only", action='store_true')
parser.add_argument("--use_swag_posterior", action='store_true')

# SWAG posterior settings.
parser.add_argument("--swag_posterior_weight", type=float, default=1e-3)
parser.add_argument("--swag_posterior_damp", type=float, default=1e-4)

args = parser.parse_args()

# Depth of the one-hot label vectors built in preprocess_inputs.
n_classes = 1000

# Per-channel offset and scale that preprocess_inputs applies to images after
# dividing the raw pixel values by 255.
channel_means, channel_stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

def preprocess_inputs(datapoint):
    """Turn one dataset record into a normalised image and a one-hot label.

    Args:
        datapoint: mapping with an 'image' entry and a 'label' entry
            (presumably a tfds record with 0-255 pixel values -- confirm
            against the dataset in use).

    Returns:
        A ``(image, label)`` pair: the image divided by 255 and standardised
        with ``channel_means`` / ``channel_stds``, and the label one-hot
        encoded to ``n_classes`` entries.
    """
    scaled = datapoint['image'] / 255
    standardised = (scaled - channel_means) / channel_stds
    one_hot_label = tf.one_hot(datapoint['label'], n_classes)
    return standardised, one_hot_label

# Name of the corrupted-ImageNet configuration, e.g. 'brightness_1'.
corruption_str = '{}_{}'.format(args.corruption_type, args.corruption_level)
print(corruption_str, flush=True)

# Validation split of the selected corruption, preprocessed and served in
# fixed-size batches of 128 (the trailing partial batch is dropped).
ds_test = (
    tfds.load(
        'imagenet2012_corrupted/{}'.format(corruption_str),
        split='validation',
        data_dir=args.data_dir,
    )
    .map(preprocess_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(128, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Threading options only take effect once attached with `with_options`;
# mutating the object returned by `ds_test.options()` does not reconfigure
# the dataset, so build a fresh Options object and apply it explicitly.
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
ds_test = ds_test.with_options(options)

rng = random.PRNGKey(args.seed)

# Path fragment that marks runs which adapted only the batch-norm parameters.
bn_only_str = 'adaptbnonly_' if args.adapt_bn_only else ''

# Per-seed log file for this configuration; its directory is created if missing.
filename = 'logs/entropy_minimization_imagenet/{}/posteriorweight{}_posteriordamp{}_{}lr{}_batchsize{}/seed{}_{}.pkl'.format(args.model, args.swag_posterior_weight, args.swag_posterior_damp, bn_only_str, args.lr, args.batch_size, args.seed, corruption_str)
os.makedirs(os.path.dirname(filename), exist_ok=True)
print(filename, flush=True)

# Best-effort probe: report whether a log for this configuration already
# exists and can be unpickled. The loaded object itself is not used.
# NOTE(review): pickle executes code from the file; only point this at logs
# produced by trusted runs.
try:
    with open(filename, 'rb') as existing_log:
        pickle.load(existing_log)
except FileNotFoundError:
    print(filename, 'file not found')
except Exception as err:
    # Keep the probe non-fatal, but do not mislabel a corrupt or unreadable
    # file as missing (and let KeyboardInterrupt/SystemExit propagate).
    print(filename, 'file could not be loaded:', repr(err))
else:
    print(filename, 'file loaded')

# Labels of the test set, shared by every metric computation below.
test_labels = get_labels(tfds.as_numpy(ds_test))

def marginal_logits(logits):
    """Average predictive distributions over the leading axis, in log space.

    Args:
        logits: array-like whose leading axis indexes the predictions to be
            averaged (the caller passes one logits array per seed) and whose
            last axis indexes classes.

    Returns:
        Log of the mean softmax probability across the leading axis; the
        leading axis is reduced away.
    """
    stacked = np.array(logits)
    n_members = stacked.shape[0]
    log_probs = jax.nn.log_softmax(stacked, axis=-1)
    # logsumexp with weights b = 1/n yields log(mean(exp(log_probs))).
    return jax.scipy.special.logsumexp(log_probs, axis=0, b=1/n_members)

log_dict = {}


def _compute_stats(logits, labels):
    """Return [nll, entropy, accuracy, brier, ece] of `logits` against `labels`."""
    return [metric(logits, labels) for metric in (nll, entropy, accuracy, brier, ece)]


def _ensemble_stats(logits_suffix, marginal_tag):
    """Evaluate the per-seed logits saved under `logits_suffix` and their ensemble.

    For seeds 1..10, loads
    ``.../seed<seed>_<corruption>_<logits_suffix>.npy`` from the per-seed log
    directory of the current configuration, computes the five metrics against
    ``test_labels`` and prints them. The per-seed logits are then combined
    with ``marginal_logits`` and the same metrics are computed and printed
    for the combination.

    Args:
        logits_suffix: file-name suffix selecting which saved logits to load
            (e.g. 'initial_logits').
        marginal_tag: text printed in front of the ensemble metrics.

    Returns:
        ``(marginal_stats, mean_member_stats)``: the metric list of the
        ensemble, and the element-wise mean of the per-seed metric lists.
    """
    member_logits = []
    member_stats = []
    for seed in range(1, 11):
        logits_path = 'logs/entropy_minimization_imagenet/{}/posteriorweight{}_posteriordamp{}_{}lr{}_batchsize{}/seed{}_{}_{}.npy'.format(args.model, args.swag_posterior_weight, args.swag_posterior_damp, bn_only_str, args.lr, args.batch_size, seed, corruption_str, logits_suffix)
        seed_logits = np.load(logits_path)
        seed_stats = _compute_stats(seed_logits, test_labels)
        member_stats.append(seed_stats)
        print(seed, seed_stats)
        member_logits.append(seed_logits)

    marginal_stats = _compute_stats(marginal_logits(member_logits), test_labels)
    print(marginal_tag, marginal_stats)
    return marginal_stats, np.array(member_stats).mean(axis=0)


# Each pair of entries holds the ensemble metrics followed by the mean of the
# per-seed metrics, for one of the three saved logit sets.
log_dict['Initial Test'], log_dict['Initial Mean Stats Test'] = _ensemble_stats('initial_logits', 'marginal stats')
log_dict['Initial Batchnorm Adapted Test'], log_dict['Initial Batchnorm Adapted Mean Stats Test'] = _ensemble_stats('bn_logits', 'marginal BN stats')
log_dict['Epoch_0 Test'], log_dict['Epoch_0 Mean Stats Test'] = _ensemble_stats('final_logits', 'marginal final stats')

# The template has eight placeholders; the previous call passed an extra,
# undefined `opt_str` argument and therefore raised NameError before saving.
results_filename = 'logs/entropy_minimization_imagenet_ensemble/{}/posteriorweight{}_posteriordamp{}_{}lr{}_batchsize{}/seed{}_{}.pkl'.format(args.model, args.swag_posterior_weight, args.swag_posterior_damp, bn_only_str, args.lr, args.batch_size, args.seed, corruption_str)
os.makedirs(os.path.dirname(results_filename), exist_ok=True)
# Use a context manager so the results file is flushed and closed deterministically.
with open(results_filename, 'wb') as results_file:
    pickle.dump(log_dict, results_file)
print(log_dict)

