#!/usr/bin/env python
# coding: utf-8

# **Copyright 2018 The JAX Authors.**

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import numpy as np

import tensorflow_datasets as tfds
import tensorflow as tf
# tf.config.set_visible_devices([], device_type='GPU')
from tensorflow_datasets.core.utils import gcs_utils
gcs_utils._is_gcs_disabled = True

from jax.scipy.special import logsumexp
from copy import deepcopy

import os
import time
import argparse
from optimizers import run_optex, run_benchmark, run_standard, run_line_search, tuning_mattern

def random_layer_params(m, n, key, scale=1e-2):
    """Randomly initialise the weights and biases of one dense layer.

    Args:
      m: number of input features of the layer.
      n: number of output features of the layer.
      key: JAX PRNG key; it is split into one key for the weights and one
        for the biases.
      scale: multiplier applied to the standard-normal draws.

    Returns:
      A tuple ``(w, b)`` with ``w`` of shape ``(n, m)`` and ``b`` of shape
      ``(n,)``.
    """
    weight_key, bias_key = random.split(key)
    weights = random.normal(weight_key, (n, m))
    biases = random.normal(bias_key, (n,))
    return scale * weights, scale * biases

def init_network_params(sizes, key):
    """Initialise every dense layer of a fully-connected network.

    Args:
      sizes: layer widths ``[d_in, h_1, ..., d_out]``; each consecutive pair
        ``(sizes[i], sizes[i + 1])`` defines one layer.
      key: JAX PRNG key, split into ``len(sizes)`` sub-keys. Layer ``i`` uses
        sub-key ``i``; the last sub-key is left unused.

    Returns:
      A list of ``(w, b)`` tuples as produced by ``random_layer_params``.
    """
    layer_keys = random.split(key, len(sizes))
    layer_shapes = zip(sizes[:-1], sizes[1:])
    return [
        random_layer_params(fan_in, fan_out, layer_key)
        for (fan_in, fan_out), layer_key in zip(layer_shapes, layer_keys)
    ]

def vectorize(params):
    """Flatten a list of ``(w, b)`` layer tuples into one 1-D parameter vector.

    Arrays are concatenated layer by layer, weight before bias, each flattened
    in row-major order. ``unvectorize`` reads the vector back in this order.
    """
    pieces = []
    for layer in params:
        pieces.extend(leaf.flatten() for leaf in layer)
    return jnp.concatenate(pieces)

def unvectorize(vector, params):
    """Inverse of ``vectorize``: rebuild per-layer ``(w, b)`` tuples.

    Args:
      vector: flat parameter vector laid out as produced by ``vectorize``.
      params: template list of ``(w, b)`` tuples. Only the ``size`` and
        ``shape`` of each array are read; the values are ignored.

    Returns:
      A new list of ``(w, b)`` tuples with the template's shapes, filled from
      consecutive slices of ``vector``.
    """
    rebuilt = []
    offset = 0
    for layer in params:
        w_template, b_template = layer[0], layer[1]
        w_end = offset + w_template.size
        b_end = w_end + b_template.size
        weights = jnp.reshape(vector[offset:w_end], w_template.shape)
        biases = jnp.reshape(vector[w_end:b_end], b_template.shape)
        rebuilt.append((weights, biases))
        offset = b_end
    return rebuilt

def relu(x):
    """Rectified linear unit: element-wise maximum of ``x`` and zero."""
    return jnp.maximum(x, 0)

def predict(params, image):
    """Return per-example log-probabilities from the MLP.

    Each hidden layer computes ``relu(w @ a + b)``. When a layer's weight
    matrix is square (input and output widths match), that activation is
    added to the incoming one as a residual connection. The final layer is
    affine, and its logits are normalised with ``logsumexp``, so the result
    is a log-softmax.

    Args:
      params: list of ``(w, b)`` tuples, ``w`` shaped ``(n_out, n_in)``.
      image: a single flattened example (``batched_predict`` maps this over
        a batch).

    Returns:
      Log-probabilities, one per output unit.
    """
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        if w.shape[0] == w.shape[1]:
            # Residual connection. Use out-of-place addition: ``+=`` would
            # mutate ``image`` in place if a caller passed a NumPy array and
            # the first layer were square.
            activations = activations + relu(outputs)
        else:
            activations = relu(outputs)

    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits - logsumexp(logits)

# Vectorised ``predict``: the parameters are shared across examples (None),
# and examples are mapped over the leading axis (0) of the image batch.
batched_predict = vmap(predict, (None, 0))

def one_hot(x, k, dtype=jnp.float32):
    """Create a one-hot encoding of x of size k.

    Args:
      x: array of integer class labels; it must support ``x[:, None]``
        indexing (presumably a 1-D NumPy/JAX array).
      k: number of classes, i.e. the width of the encoding.
      dtype: dtype of the returned array.

    Returns:
      Array whose row ``i`` has a 1 in column ``x[i]`` and 0 elsewhere. A
      label outside ``[0, k)`` yields an all-zero row.
    """
    labels = x[:, None]
    class_ids = jnp.arange(k)
    return jnp.array(labels == class_ids, dtype=dtype)

def accuracy(params, images, targets):
    """Return the fraction of examples whose predicted class matches the target.

    Args:
      params: list of ``(w, b)`` layer tuples.
      images: batch of flattened examples, one per row.
      targets: one-hot labels, one row per example; the true class is the
        arg-max of each row.
    """
    log_probs = batched_predict(params, images)
    hits = jnp.argmax(log_probs, axis=1) == jnp.argmax(targets, axis=1)
    return jnp.mean(hits)

def loss(params_vec, images, targets, init_params=None):
    """Mean cross-entropy of the network on one batch.

    Args:
      params_vec: flat parameter vector in the layout produced by
        ``vectorize``.
      images: batch of flattened examples, one per row.
      targets: one-hot labels, one row per example.
      init_params: template ``(w, b)`` list giving the layer shapes for
        ``unvectorize``. Despite the ``None`` default it is required in
        practice, because ``unvectorize`` iterates over it.

    Returns:
      Scalar mean over examples of ``-sum(targets * log_probs)``.
    """
    params = unvectorize(params_vec, init_params)
    # ``predict`` already returns log-softmax outputs; applying log_softmax
    # again leaves them unchanged up to rounding.
    log_probs = jax.nn.log_softmax(batched_predict(params, images))
    per_example = jnp.sum(-targets * log_probs, axis=1)
    return jnp.mean(per_example)

def get_train_test_data(data_dir, name="mnist"):
    """Load a TFDS image-classification dataset fully into memory.

    Args:
      data_dir: TFDS data directory.
      name: TFDS dataset name. It must provide 'train' and 'test' splits and
        'image' / 'label' features.

    Returns:
      ``(train_images, train_labels, test_images, test_labels, num_labels,
      num_pixels)``. Images are flattened to ``(N, num_pixels)``; their pixel
      values are passed through without rescaling. Labels are one-hot
      encoded with ``num_labels`` columns.
    """
    splits, info = tfds.load(name=name, batch_size=-1, data_dir=data_dir, with_info=True, shuffle_files=True)
    splits = tfds.as_numpy(splits)
    num_labels = info.features['label'].num_classes
    height, width, channels = info.features['image'].shape
    num_pixels = height * width * channels

    def flatten_split(split):
        # Flatten every image to one row and one-hot encode the labels.
        images = split['image']
        flat_images = jnp.reshape(images, (len(images), num_pixels))
        return flat_images, one_hot(split['label'], num_labels)

    train_images, train_labels = flatten_split(splits['train'])
    test_images, test_labels = flatten_split(splits['test'])

    print('Train:', train_images.shape, train_labels.shape)
    print('Test:', test_images.shape, test_labels.shape)

    return train_images, train_labels, test_images, test_labels, num_labels, num_pixels


def get_train_batches(data_dir, name="mnist", batch_size=128):
    """Return an iterable of ``(images, labels)`` NumPy batches from the train split.

    ``as_supervised=True`` makes each element an ``(image, label)`` tuple
    rather than a feature dict. No shuffling is requested here, so batches
    follow the order in which TFDS yields the split.
    """
    train_split = tfds.load(name=name, split='train', as_supervised=True, data_dir=data_dir)
    batched = train_split.batch(batch_size).prefetch(1)
    return tfds.as_numpy(batched)


def parse_arguments(argv=None):
    """Parse command-line options for the OptEx MLP experiments.

    Args:
      argv: optional list of argument strings. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
      An ``argparse.Namespace`` holding the options defined below.
    """
    parser = argparse.ArgumentParser(description='OptEx(Network) experiments')
    parser.add_argument('--data', default="mnist", type=str, help='dataset name')
    parser.add_argument('--opt_name', default="sgd", type=str, help='optimizer name')
    parser.add_argument('--method', default="optex", type=str, help='method name')
    parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
    parser.add_argument('--batch_size', default=512, type=int, help='batch size')
    parser.add_argument('--num_epochs', default=20, type=int, help='number of epochs')
    parser.add_argument('--num_parall', default=4, type=int, help='number of parallel iterations')
    parser.add_argument('--seed', default=0, type=int, help='seed for random number generator')
    parser.add_argument('--num_runs', default=5, type=int, help='number of runs')
    parser.add_argument('--print_every', default=30, type=int, help='print progress every N iterations')
    # The help text was previously a copy of --num_runs' ("number of runs").
    parser.add_argument('--edim', default=10000, type=int, help='effective dimension used by OptEx (--method optex)')
    return parser.parse_args(argv)


if __name__ == "__main__":
    args = parse_arguments()
    root = f"./results/mlp-{args.data}"
    os.makedirs(root, exist_ok=True)

    # Hyperparameters: layer widths per dataset family. Consecutive equal
    # widths (e.g. 512 -> 512) give square hidden weight matrices, which
    # ``predict`` treats as residual layers.
    if "mnist" in args.data:
        layer_sizes = [784, 512, 512, 256, 256, 256, 128, 128, 10]
    elif "cifar10" in args.data:
        layer_sizes = [3072, 512, 512, 512, 256, 256, 256, 128, 128, 10]
    else:
        # Without this branch ``layer_sizes`` would be undefined below.
        raise ValueError(
            f"Unsupported dataset {args.data!r}: expected a name containing 'mnist' or 'cifar10'"
        )

    # Each layer has m * n weights plus n biases.
    n_params = sum((m + 1) * n for m, n in zip(layer_sizes[:-1], layer_sizes[1:]))
    print("# of params:", n_params)
    data_dir = './data/tfds'

    # Map --method to its optimizer entry point explicitly instead of
    # ``eval``-ing a string built from command-line input.
    runners = {
        "optex": run_optex,
        "benchmark": run_benchmark,
        "standard": run_standard,
        "line_search": run_line_search,
    }
    if args.method not in runners:
        raise ValueError(f"Unknown --method {args.method!r}; choose from {sorted(runners)}")
    runner = runners[args.method]

    # The data does not depend on the run index, so load it once instead of
    # re-reading the full dataset for every run.
    train_images, train_labels, test_images, test_labels, num_labels, num_pixels = get_train_test_data(data_dir, name=args.data)
    train_batches = list(get_train_batches(data_dir, args.data, args.batch_size))

    all_results = []

    for r in range(args.num_runs):
        run_seed = args.seed + r * 1234
        init_params = init_network_params(layer_sizes, random.PRNGKey(run_seed))
        x0 = vectorize(init_params)
        # Shallow copy so any list-level changes made during one run do not
        # leak into the next run.
        datas = list(train_batches)
        # Local, per-run generator for the length-scale tuning split; it does
        # not touch NumPy's global random state.
        split_rng = np.random.default_rng(run_seed)

        print("# of iters:", len(datas))

        arguments = {
            "opt_name":     "optax." + args.opt_name,
            "lr":           args.lr,
            "x0":           x0,
            # Iterations per epoch, divided by the number of parallel steps.
            "num_iters":    len(datas) // args.num_parall,
            "num_parall":   args.num_parall,
            "datas":        datas,
            "opt_state":    None,
        }

        if args.method in ["optex", "line_search", "benchmark"]:
            arguments["inter_results"] = {}
        if args.method == "optex":
            arguments["effective_dim"] = args.edim
            arguments["inter_results"].update({"length_scale": 1.0})

        print("\nTraining start for [%s]..." % args.method.upper())

        def objective(p, image, label):
            # Flatten the raw batch images and one-hot encode labels before
            # evaluating the loss at the flat parameter vector ``p``.
            return loss(
                p,
                jnp.reshape(image, (len(image), num_pixels)),
                one_hot(label, num_labels),
                init_params=init_params,
            )

        acc_results = []

        for epoch in range(args.num_epochs):
            start = time.time()
            x, _fx, opt_state = runner(objective, **arguments)
            print("Elapsed Time: %.4f" % (time.time() - start))

            # Resume the next epoch from where this one stopped.
            arguments.update({
                "x0":           x,
                "opt_state":    opt_state,
            })

            if args.method == "optex":
                history = arguments["inter_results"]
                xs = np.concatenate(history["x_history"], axis=0)
                ys = np.concatenate(history["g_history"], axis=0)

                # 80/20 fit/hold-out split of the history, sampled WITHOUT
                # replacement via a permutation. ``np.random.choice`` defaults
                # to ``replace=True``, which would yield duplicate fit points
                # and an oversized hold-out set.
                order = split_rng.permutation(len(xs))
                n_fit = int(0.8 * len(xs))
                indices, target_indices = order[:n_fit], order[n_fit:]

                length_scale = tuning_mattern(
                    xs[indices],
                    ys[indices],
                    xs[target_indices],
                    ys[target_indices],
                    choice=[0.01, 0.1, 1, 10, 100],
                    # NOTE(review): fixed at 5000 rather than args.edim —
                    # confirm this is intentional.
                    effective_dim=5000,
                )

                arguments["inter_results"].update({
                    "length_scale": length_scale,
                })
                print("optimized length scale:", length_scale)

            params = unvectorize(x, init_params)
            train_acc = accuracy(params, train_images, train_labels)
            test_acc = accuracy(params, test_images, test_labels)
            print(f"Epoch {epoch}")
            print(f"Training set accuracy {train_acc}")
            print(f"Test set accuracy {test_acc}")

            acc_results.append([train_acc, test_acc])

        all_results.append(acc_results)

    np.save(f"{root}/{args.opt_name}({args.lr})-{args.num_epochs}x{args.num_parall}-{args.method}.npy", np.array(all_results))