import os
import time
import itertools
import sys
import inspect
import multiprocessing
from functools import partial
import pandas as pd
import numpy as np
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random, disable_jit
from jax.nn import log_softmax, softmax, one_hot
from jax.experimental import optimizers
import networks

# ---- Optimisation -------------------------------------------------------------
step_size = 0.1              # SGD learning rate handed to optimizers.sgd
kappa = 1.0                  # multiplier applied to the network output in predict()
do_difference_trick = False  # wrap the network factory with difference_trick()
batch_size = 128             # training mini-batch size
# Iteration budget: 60000 / batch_size * 20 (60000 presumably the MNIST training-set
# size, i.e. roughly 20 passes' worth); main() turns it into an epoch count per design.
num_iterations = (60000 / batch_size) * 20
test_batch_size = 1000       # batch size for evaluation / uncertainty-scoring streams

# ---- Experimental design ------------------------------------------------------
design_size_jump, max_design_size = 100, 800  # design grows 100, 200, ..., 800
evaluation_interval = 50     # evaluate every this many epochs
num_experiments = 5          # repetitions averaged in the result file
gpus = range(1)              # GPU ids; one pool worker per entry
subsample_size = 2000  # when use uncertainty sampling we compute the score only on subsample
loss_type = 'square'  # 'square' # 'cross_entropy' (anything but 'square' -> cross-entropy)
no_jit = False               # run main() under jax.disable_jit() when True

# One dict per run; keys are the parameter names of `main` (the __main__ block turns
# each dict into a positional tuple). `design_file` is 'random', 'uncertainty_sampling'
# or a CSV whose 'xi' column lists selected training indices; `result_file` is the CSV
# that `main` (over)writes with the aggregated results, so it must be unique per run.
tests_kwargs = [
    dict(width_factor=1,
         design_file='uncertainty_sampling',
         result_file='{width_factor: 1, design: uncertainty_sampling, loss: square}.csv'),
    dict(width_factor=8,
         design_file='uncertainty_sampling',
         result_file='{width_factor: 8, design: uncertainty_sampling, loss: square}.csv'),
    dict(width_factor=1,
         design_file='mnist_lenet_oed.s0_l0_th1.ntk.csv',
         result_file='{width_factor: 1, design: {sigma: 0, lambda: 0, theta: 1}, kernel: ntk, loss: square}.csv'),
    dict(width_factor=8,
         design_file='mnist_lenet_oed.s0_l0_th1.ntk.csv',
         result_file='{width_factor: 8, design: {sigma: 0, lambda: 0, theta: 1}, kernel: ntk, loss: square}.csv'),
    dict(width_factor=1,
         design_file='random',
         result_file='{width_factor: 1, design: random, loss: square}.csv'),
    dict(width_factor=8,
         design_file='random',
         result_file='{width_factor: 8, design: random, loss: square}.csv'),
    # NOTE(review): this entry duplicates the third one (same design and result file),
    # so it re-runs that experiment and overwrites its CSV. It may have been meant as the
    # width-1 counterpart of the coreset run below -- confirm.
    dict(width_factor=1,
         design_file='mnist_lenet_oed.s0_l0_th1.ntk.csv',
         result_file='{width_factor: 1, design: {sigma: 0, lambda: 0, theta: 1}, kernel: ntk, loss: square}.csv'),
    dict(width_factor=8,
         design_file='mnist_lenet_oed.coreset.ntk.csv',
         result_file='{width_factor: 8, design: coreset, kernel: ntk, loss: square}.csv'),
]


def timestr():
    """Return the current UTC time as a ``YYYY-mm-dd_HH:MM:SS`` string."""
    utc_now = time.gmtime()
    return time.strftime("%Y-%m-%d_%H:%M:%S", utc_now)


def difference_trick(stax_nn_gen):
    """Wrap a network factory so the model outputs the difference of two copies.

    ``stax_nn_gen()`` must return ``(init_fn, apply_fn, kernel_fn)``. The returned
    factory builds two copies of that network. The combined parameters are the
    concatenation of both copies' parameter sequences. The combined output is
    ``apply_1(first half) - apply_2(second half)``. The first copy's ``kernel_fn`` is
    passed through unchanged.
    """
    def new_stax_nn_gen():
        first_init, first_apply, kernel_fn = stax_nn_gen()
        second_init, second_apply, _ = stax_nn_gen()

        def new_init_fn(rng, input_shape):
            # Both copies are initialised from the same rng. Given deterministic init
            # functions for the same architecture, the two halves start out identical,
            # so the difference is zero at initialisation.
            _, first_params = first_init(rng, input_shape)
            out_shape, second_params = second_init(rng, input_shape)
            return out_shape, first_params + second_params

        @jit
        def new_apply_fn(params, input):
            half = len(params) // 2
            return first_apply(params[:half], input) - second_apply(params[half:], input)

        return new_init_fn, new_apply_fn, kernel_fn

    return new_stax_nn_gen


class GeneratorLen:
    """Pair an iterator with a known length.

    ``len()`` reports ``length``, which ``data_stream`` sets to the number of batches in
    one pass. ``next()`` pulls from the wrapped iterator, which may be infinite.
    """

    def __init__(self, gen, length):
        self.gen = gen
        self.length = length

    def __len__(self):
        return self.length

    def __iter__(self):
        # Complete the iterator protocol so the wrapper also works with ``for`` loops,
        # ``list()`` and ``itertools.islice``, not just with ``next()``.
        return self

    def __next__(self):
        return next(self.gen)


def data_stream(images, labels, stream_batch_size, shuffle=True, seed=10):
    """Return an endless stream of ``(images, labels)`` mini-batches.

    Each pass visits every example once, split into batches of at most
    ``stream_batch_size`` (the last batch may be shorter). When ``shuffle`` is true, the
    order is re-permuted at the start of every pass. ``len()`` of the result is the
    number of batches per pass.

    Args:
        images: array supporting integer-array indexing; assumed aligned with ``labels``
            along the first axis.
        labels: array whose length is the number of examples.
        stream_batch_size: maximum number of examples per batch.
        shuffle: re-permute the examples at the start of every pass.
        seed: seed of the permutation RNG (each call starts a fresh, reproducible RNG).

    Returns:
        ``GeneratorLen`` wrapping the infinite batch generator.
    """
    num_complete_batches, leftover = divmod(len(labels), stream_batch_size)
    num_batches = num_complete_batches + bool(leftover)
    rns = npr.RandomState(seed)

    def gen():
        if num_batches == 0:
            # Nothing to yield: without this guard, next() would loop forever.
            return
        while True:
            perm = rns.permutation(len(labels)) if shuffle else np.arange(len(labels))
            for i in range(num_batches):
                # Slice with this stream's own batch size so the batches cover exactly
                # the examples counted by num_batches.
                batch_idx = perm[i * stream_batch_size:(i + 1) * stream_batch_size]
                yield images[batch_idx], labels[batch_idx]

    return GeneratorLen(gen(), num_batches)


def main(width_factor, design_file, result_file):
    """Train LeNet-5 on growing MNIST training designs and record test accuracy.

    For each of ``num_experiments`` repetitions, the training set ("design") grows in
    steps of ``design_size_jump`` up to ``max_design_size``. For every design size, a
    freshly initialised network is trained with SGD and evaluated on the test set every
    ``evaluation_interval`` epochs. After each experiment, the per-design-size mean and
    std over all experiments run so far are written to ``result_file``.

    Args:
        width_factor: width multiplier forwarded to ``networks.lenet5``.
        design_file: ``'random'``, ``'uncertainty_sampling'``, or the path of a CSV
            whose ``xi`` column lists training indices in selection order.
        result_file: path of the CSV that receives the aggregated results.
    """
    def loss(param, batch):
        # Summed (not averaged) loss over the batch.
        inputs, targets = batch
        preds = predict(param, inputs)
        if loss_type == 'square':
            return 0.5 * jnp.sum((preds - targets) ** 2)
        else:
            return jnp.sum(-log_softmax(preds) * targets)

    def accuracy_batch(param, batch):
        # Number of examples whose argmax prediction equals the argmax of the target.
        inputs, targets = batch
        target_class = jnp.argmax(targets, axis=1)
        predicted_class = jnp.argmax(predict(param, inputs), axis=1)
        return jnp.sum(predicted_class == target_class)

    def accuracy(param, batches):
        # Fraction correct over one pass of `batches`, weighted per example, so that a
        # short final batch does not count as much as a full one.
        correct, total = 0, 0
        for _ in range(len(batches)):
            batch = next(batches)
            correct += int(accuracy_batch(param, batch))
            total += len(batch[1])
        return correct / total if total else float('nan')

    def uncertainty_scores_batch(param, batch):
        # Per-example uncertainty: 1 - (largest per-class score); larger = less certain.
        inputs, _ = batch
        preds = predict(param, inputs)
        n, k = preds.shape
        if loss_type == 'square':
            # Score class i by a Gaussian-shaped likelihood of the outputs around the
            # one-hot target e_i. NOTE: `2 * jnp.pi ** (-k / 2)` parses as
            # 2 * pi^(-k/2), not the Gaussian normaliser (2*pi)^(-k/2). It is only a
            # positive constant scale, so the ordering of the scores (the only thing
            # used below) is unaffected up to float rounding.
            probs = []
            for i in range(k):
                class_target = one_hot(jnp.ones(n) * i, k)
                probs.append(
                    2 * jnp.pi ** (-k / 2) * jnp.exp(
                        -0.5 * jnp.sum((preds - class_target) ** 2, axis=1, keepdims=True)))
            probs = jnp.concatenate(probs, axis=1)
        else:
            probs = softmax(preds, axis=1)
        return 1 - jnp.max(probs, axis=1)

    def uncertainty_scores(param, batches):
        # Concatenated uncertainty scores over one pass of `batches`.
        return jnp.concatenate([uncertainty_scores_batch(param, next(batches)) for _ in range(len(batches))])

    design_method = design_file if design_file in ['random', 'uncertainty_sampling'] else 'from_file'

    network_gen = partial(networks.lenet5, width_factor)
    init_random_params, predict_, _ = network_gen() if not do_difference_trick else difference_trick(network_gen)()

    @jit
    def predict(param, inp):
        return predict_(param, inp) * kappa

    rng = random.PRNGKey(10)

    # NOTE(review): relies on the key order stored in the .npz file.
    train_images0, test_images, train_labels0, test_labels = np.load(
        f'{os.environ["HOME"]}/data/mnist.2020-08-25_11:35:50.npz').values()
    test_batches = data_stream(test_images, test_labels, test_batch_size)

    opt_init, opt_update, get_params = optimizers.sgd(step_size)

    @jit
    def update(i, optimizer_state, batch):
        param = get_params(optimizer_state)
        return opt_update(i, grad(loss)(param, batch), optimizer_state)

    design_sizes = range(design_size_jump, max_design_size + 1, design_size_jump)
    if design_method == 'from_file':
        xi0 = pd.read_csv(design_file)['xi']

    result_dfs = []
    for experiment in range(num_experiments):
        if design_method == 'random':
            xi0 = npr.choice(len(train_labels0), max_design_size, replace=False)

        # Rows are collected in a list: DataFrame.append was deprecated in pandas 1.4
        # and removed in pandas 2.0.
        result_rows = []
        result_df = pd.DataFrame()
        start_experiment_time = time.time()
        xi = []
        for design_size in design_sizes:
            if design_method != 'uncertainty_sampling':
                xi = xi0[:design_size]
            else:  # uncertainty_sampling
                if len(xi) > 0:
                    # Add the `design_size_jump` most uncertain examples from a random
                    # subsample of the unselected pool. Scores come from `params`, the
                    # model trained on the previous design size.
                    pool = list(set(range(len(train_labels0))) - set(xi))
                    pool = npr.choice(pool, subsample_size, replace=False).tolist()
                    pool_images, pool_labels = jnp.array(train_images0[pool]), jnp.array(train_labels0[pool])
                    pool_batches = data_stream(pool_images, pool_labels, test_batch_size, shuffle=False)
                    xi = xi + [pool[i] for i in
                               jnp.argsort(-uncertainty_scores(params, pool_batches))[:design_size_jump]]
                else:
                    xi = npr.choice(len(train_labels0), design_size_jump, replace=False).tolist()

            train_images, train_labels = jnp.array(train_images0[xi]), jnp.array(train_labels0[xi])
            train_batches = data_stream(train_images, train_labels, batch_size)
            rng, rng_params = random.split(rng, 2)

            _, init_params = init_random_params(rng_params, (-1, 28, 28, 1))

            opt_state = opt_init(init_params)
            itercount = itertools.count()

            accs = []
            print(f"\nStarting training with design size {design_size}")
            # Scale the epoch count so that smaller designs get more epochs for a
            # similar iteration budget. It is rounded down to a multiple of
            # evaluation_interval.
            num_epochs = int((num_iterations + evaluation_interval) / (
                    design_size / batch_size) / evaluation_interval) * evaluation_interval
            for epoch in range(num_epochs):
                start_time = time.time()
                for _ in range(len(train_batches)):
                    opt_state = update(next(itercount), opt_state, next(train_batches))
                epoch_time = time.time() - start_time
                if (epoch % evaluation_interval) == 0:
                    start_report_time = time.time()
                    params = get_params(opt_state)
                    train_acc = accuracy(params, train_batches)
                    test_acc = accuracy(params, test_batches)
                    print("Epoch {}/{} in {:0.2f} sec".format(epoch, num_epochs, epoch_time))
                    print("Training set accuracy {}".format(train_acc))
                    print("Test set accuracy {}".format(test_acc))
                    print("Report in {:0.2f} sec".format(time.time() - start_report_time))
                    accs.append(test_acc)
            # Final parameters; the next uncertainty-sampling round scores with them.
            params = get_params(opt_state)

            # accs[j] was measured at epoch j * evaluation_interval.
            result_rows.append(
                {'design_size': design_size, 'mean_acc': np.mean(accs[-10:]), 'max_acc': np.max(accs),
                 'max_acc_epoch': np.argmax(accs) * evaluation_interval, 'num_epochs': num_epochs})
            result_df = pd.DataFrame(result_rows)
            print(f'Result:\n{result_df}')

        result_dfs.append(result_df)
        # Stack the per-experiment tables along a new last axis
        # -> (design sizes, columns, experiments) and reduce over experiments.
        stacked = np.stack([df.values for df in result_dfs], axis=2)
        columns = result_dfs[0].columns
        result_mean_df = pd.DataFrame(data=stacked.mean(axis=2), columns=[f'mean_{c}' for c in columns])
        result_std_df = pd.DataFrame(data=stacked.std(axis=2), columns=[f'std_{c}' for c in columns])
        final_result_df = pd.concat([result_mean_df, result_std_df], axis=1)
        final_result_df['experiment'] = experiment

        print(
            f'Experiment {experiment}/{num_experiments} in {time.time() - start_experiment_time}\n'
            f'Result:\n'
            f'{final_result_df}\n')
        print(f'Experiment {experiment}/{num_experiments}, write result to {result_file}')
        final_result_df.to_csv(result_file, index=False)


def _worker_index(name=None):
    """Return the trailing ``-<n>`` number of a process name, or 0 when there is none.

    ``multiprocessing.Pool`` workers are named like ``'ForkPoolWorker-3'``, and the
    main process is ``'MainProcess'``. Defaults to the current process's name.
    """
    if name is None:
        name = multiprocessing.current_process().name
    _, sep, suffix = name.rpartition('-')
    return int(suffix) if sep and suffix.isdigit() else 0


def _wrapper(*args, **kwargs):
    """Pool entry point: pin this worker to one GPU, then run ``main``.

    ``Pool.map`` passes one item per task: the tuple of positional ``main`` arguments.
    """
    args = args[0]
    print(f'multiprocessing.current_process()={multiprocessing.current_process()}')
    # Use the public Process.name rather than parsing str(process). The repr format
    # differs across Python versions (newer ones have no commas), which made the old
    # split-based parse raise ValueError.
    device_idx = _worker_index() % len(gpus)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpus[device_idx])
    print(f'Run main{args} on device {device_idx}')
    if no_jit:
        with disable_jit():
            main(*args, **kwargs)
    else:
        main(*args, **kwargs)


if __name__ == '__main__':
    # Pool.map hands _wrapper a single item per task, so turn each keyword dict into a
    # positional tuple ordered like main's parameters.
    args = inspect.getfullargspec(main).args
    tests = [tuple(t[a] for a in args) for t in tests_kwargs]

    # One worker per GPU. The context manager shuts the pool down after map() returns.
    with multiprocessing.Pool(len(gpus)) as p:
        p.map(_wrapper, tests)

    print('END OF SCRIPT')
    sys.exit(0)
