import argparse

import jax
import jax.numpy as np

from jax import grad, jit, vmap, pmap, value_and_grad
from jax import random

from jax.tree_util import tree_multimap, tree_map
from utils import optimizers
from utils import lr_schedules

import haiku as hk

import numpy as onp

import tensorflow_datasets as tfds
import tensorflow as tf

from jax.config import config

import os
import requests

import pickle
import time

from models.util import get_model

from utils.alternate_net_applies import freeze_params, linearize
from utils.training_utils import train_epoch
from utils.eval import eval_ds
from utils.losses import nll, accuracy
from utils.misc import get_single_copy, manual_pmap_tree

from posteriors.utils import sample_weights_diag
from posteriors.swag import init_swag, update_swag, collect_posterior

parser = argparse.ArgumentParser(description='Runs basic train loop on a supervised learning task')

# Filesystem locations.
# NOTE(review): --dir is parsed but not read anywhere in this script.
parser.add_argument("--dir", type=str, default=None,
                    help="Training directory for logging results")
parser.add_argument("--data_dir", type=str, default=None,
                    help="Directory for storing datasets")

# Randomness and optimization hyperparameters.
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--n_epochs", type=int, default=300)
parser.add_argument("--swag_begin_epoch", type=int, default=151)
parser.add_argument("--init_lr", type=float, default=0.1)
parser.add_argument("--swag_lr", type=float, default=0.01)
parser.add_argument("--wd", type=float, default=5e-4)

# Model / dataset selection.
# NOTE(review): --load_file is parsed but not read anywhere in this script.
parser.add_argument("--load_file", type=str, default=None,
                    help="Load parameters and state from ERM")
parser.add_argument("--model", type=str, default="ResNet26",
                    help="Model class")
parser.add_argument("--dataset", type=str, default="cifar10",
                    help="Dataset: {cifar10, svhn_cropped, cifar100}")

args = parser.parse_args()

### Per-channel normalization statistics. These are the standard CIFAR-10
### values; they are applied unchanged to every dataset (svhn_cropped and
### cifar100 included). Both arrays have shape (3,) so they broadcast over
### the trailing channel axis of an image.
channel_means = np.array([0.4914, 0.4822, 0.4465])
channel_stds = np.array([0.2023, 0.1994, 0.2010])

# Label-space size for each supported dataset. An unknown dataset used to fall
# back silently to 100 classes (a wrong one-hot width for anything that is not
# cifar100); it is now rejected up front with a clear message.
_N_CLASSES_BY_DATASET = {'cifar10': 10, 'svhn_cropped': 10, 'cifar100': 100}
try:
    n_classes = _N_CLASSES_BY_DATASET[args.dataset]
except KeyError:
    raise ValueError(
        "Unsupported dataset {!r}; expected one of {}".format(
            args.dataset, sorted(_N_CLASSES_BY_DATASET))) from None
batch_size = 128

n_devices = jax.device_count()

def preprocess_inputs(datapoint):
    """Turn one raw TFDS example into a ``(normalized_image, one_hot_label)`` pair.

    The image is divided by 255 (presumably uint8 pixels in [0, 255] -- confirm
    per dataset), then standardized with the module-level ``channel_means`` /
    ``channel_stds``. The integer label is one-hot encoded to ``n_classes``
    entries.
    """
    scaled = datapoint['image'] / 255
    standardized = (scaled - channel_means) / channel_stds
    one_hot = tf.one_hot(datapoint['label'], n_classes)
    return standardized, one_hot

def augment_train_data(image, label):
    """Randomly augment a single training image; the label passes through.

    The image is padded (or center-cropped) to 36x36, a random 32x32x3 window
    is cut from it, and that window is randomly mirrored left-right.
    """
    padded = tf.image.resize_with_crop_or_pad(image, 36, 36)
    cropped = tf.image.random_crop(padded, size=(32, 32, 3))
    flipped = tf.image.random_flip_left_right(cropped)
    return flipped, label

# Training pipeline.
# `cache()` must come BEFORE `shuffle()`: a cached dataset replays its elements
# in exactly the order of the first complete pass, so shuffling upstream of the
# cache would freeze the example order after epoch 0. (Trade-off: the shuffle
# buffer now holds preprocessed float images instead of raw examples.)
# Augmentation stays downstream of the cache so each epoch gets fresh random
# crops/flips.
ds_train = (
    tfds.load(args.dataset, split='train', data_dir=args.data_dir, shuffle_files=True)
    .map(preprocess_inputs, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .cache()
    .shuffle(50000, reshuffle_each_iteration=True)
    .map(augment_train_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # per-device batches, then grouped along a leading device axis (for pmap)
    .batch(batch_size, drop_remainder=True)
    .batch(n_devices, drop_remainder=True)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# `Dataset.options()` hands back an Options object that is not wired into the
# dataset; the tuned threading settings only take effect once re-attached
# with `with_options`.
options = ds_train.options()
options.experimental_threading.private_threadpool_size = 48
options.experimental_threading.max_intra_op_parallelism = 1
ds_train = ds_train.with_options(options)

# Evaluation pipelines: deterministic order, no augmentation, partial last batch kept.
ds_test = tfds.load(args.dataset, split='test', data_dir=args.data_dir).map(preprocess_inputs).cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)
ds_train_eval = tfds.load(args.dataset, split='train', data_dir=args.data_dir).map(preprocess_inputs).cache().batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)


### Model Initialization
# The same PRNG key is broadcast to every device, so each replica starts from
# identical parameters.
rng = random.PRNGKey(args.seed)
rng = np.broadcast_to(rng, (n_devices,) + rng.shape)

model = get_model(args.model, n_classes)

# One device-stacked training batch, used only to give model.init concrete shapes.
init_batch_images = next(iter(tfds.as_numpy(ds_train)))[0]


def _replicated_init(device_rng, images):
    """Initialize parameters/state on a single device in training mode."""
    return model.init(device_rng, images, is_training=True)


# initializes copies of parameters and states on each device
init_params, init_state = pmap(_replicated_init)(rng, init_batch_images)

net_state = init_state

### Training-mode forward pass: takes an explicit RNG and runs with is_training=True
@jit
def net_apply(params, state, rng, x):
    """Jitted ``model.apply`` in training mode.

    Returns whatever ``model.apply`` returns; for a Haiku stateful transform
    that is presumably ``(outputs, new_state)`` -- confirm against get_model.
    """
    training = True
    return model.apply(params, state, rng, x, training)

@jit
def net_apply_eval(params, state, x):
    """Jitted ``model.apply`` in evaluation mode: no RNG, is_training=False."""
    no_rng = None
    return model.apply(params, state, no_rng, x, False)

num_epochs = args.n_epochs

def step_size_schedule(i):
    """Learning rate for optimizer step ``i`` under the SWA schedule.

    Delegates to ``lr_schedules.swa_lr_schedule`` with the command-line
    initial LR, SWAG LR, SWAG start epoch and total epoch count.
    """
    schedule_config = (args.init_lr, args.swag_lr, args.swag_begin_epoch, args.n_epochs)
    return lr_schedules.swa_lr_schedule(i, *schedule_config)

# Single (un-replicated) copies of the initial parameters / state.
# NOTE(review): single_params, single_state, x, all_param_names and param_names
# are not read anywhere later in this script; fetching `x` also pulls one extra
# batch through the training pipeline.
single_params, single_state = get_single_copy(init_params), get_single_copy(init_state)
x = next(iter(tfds.as_numpy(ds_train)))[0][0]

param_names = all_param_names = init_params.keys()

# Momentum optimizer (mass 0.9, weight decay args.wd) driven by the SWA LR
# schedule; its state is initialized once per device via pmap.
opt_init, opt_update, get_params = optimizers.momentum(
    step_size=step_size_schedule, mass=0.9, wd=args.wd)
opt_state = pmap(opt_init)(init_params)

# Fresh, un-replicated key that the training loop splits once per epoch.
rng = random.PRNGKey(args.seed)

# main train loop
# SWAG moment statistics. Stays None until the first epoch >= swag_begin_epoch,
# so later code can tell whether SWAG ever started.
swag_state = None
# Full-dataset evaluation is expensive, so it runs every `eval_every` epochs --
# and always after the final epoch, so the model that gets saved is reported.
eval_every = 5
for epoch in range(num_epochs):
    start = time.time()
    np_ds = tfds.as_numpy(ds_train)
    rng, cur_rng = random.split(rng, 2)
    opt_state, net_state, train_loss = train_epoch(epoch,
                                                   opt_state,
                                                   net_state,
                                                   cur_rng,
                                                   np_ds,
                                                   nll,
                                                   get_params,
                                                   net_apply,
                                                   opt_update,
                                                   distributed=True)
    print('Epoch {}: {} {}'.format(epoch, train_loss, time.time() - start), flush=True)

    if epoch >= args.swag_begin_epoch:
        # SWAG accumulates a single (un-replicated) copy of the parameters.
        cur_params = get_single_copy(get_params(opt_state))
        if swag_state is None:
            swag_state = init_swag(cur_params)
        else:
            swag_state = update_swag(swag_state, cur_params)

    if epoch % eval_every == 0 or epoch == num_epochs - 1:
        # needs to flatten params for non-distributed eval
        # takes first copy of params since all copies of params should be identical
        eval_params = get_params(opt_state)
        eval_params, eval_net_state = get_single_copy((eval_params, net_state))

        start = time.time()
        test_results = eval_ds(tfds.as_numpy(ds_test),
                               eval_params,
                               eval_net_state,
                               net_apply_eval,
                               (nll, accuracy))
        train_results = eval_ds(tfds.as_numpy(ds_train_eval),
                                eval_params,
                                eval_net_state,
                                net_apply_eval,
                                (nll, accuracy))
        print("Evaluation {}".format(epoch), train_results, test_results, time.time() - start)

        if swag_state is not None:
            # NOTE(review): the SWA mean is evaluated with the network state of
            # the current SGD iterate. If that state holds batch-norm statistics,
            # SWA usually recomputes them for the averaged weights -- confirm.
            swag_means, swag_vars = collect_posterior(swag_state)
            swa_test_results = eval_ds(tfds.as_numpy(ds_test),
                                       swag_means,
                                       eval_net_state,
                                       net_apply_eval,
                                       (nll, accuracy))
            print("Eval SWA {}".format(epoch), swa_test_results)


# Parameters and state after ERM solution (the final SGD iterate), reduced to a
# single copy from the per-device replicas.
erm_params, erm_state = get_params(opt_state), net_state

single_erm_params = get_single_copy(erm_params)
single_erm_state = get_single_copy(erm_state)

hyperparams_str = 'seed{}_initlr{}_swaglr{}'.format(args.seed, args.init_lr, args.swag_lr)
print("Hyperparams: {}".format(hyperparams_str))
# `filedir` is a filename prefix (it ends in '_'), not a directory.
filedir = 'saved_models_swag/{}/{}/{}_'.format(args.dataset, args.model, hyperparams_str)
os.makedirs(os.path.dirname(filedir), exist_ok=True)
with open(filedir + 'params_and_state.pkl', 'wb') as f:
    pickle.dump((single_erm_params, single_erm_state), f)

# SWAG statistics exist only if some epoch reached swag_begin_epoch; otherwise
# `swag_state` was never built, and dumping it would fail (or save nothing useful).
if args.swag_begin_epoch < num_epochs:
    with open(filedir + 'swag_state.pkl', 'wb') as f:
        pickle.dump(swag_state, f)
else:
    print("SWAG never started (swag_begin_epoch={} >= n_epochs={}); "
          "skipping swag_state.pkl".format(args.swag_begin_epoch, num_epochs), flush=True)

