import optax
import jax
import numpy as np
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
import numpyro.distributions as dist
from infbench.posterior import Posterior
from jax.random import PRNGKey, split
from jax import jvp, value_and_grad, vmap, grad
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial
import argparse
from sklearn.preprocessing import StandardScaler

from utils import generate_batch_index

# Optimizer-specific settings; every other key of ``config`` is shared.
# (For SGD with momentum, ``partial(optax.sgd, momentum=0.9)`` can be used as
# the 'optimizer' factory.)
_OPTIMIZER_SETTINGS = {
    'adam': {'log_dir': './results', 'optimizer': optax.adam},
    'sgd': {'log_dir': './results_sgd', 'optimizer': optax.sgd},
}

parser = argparse.ArgumentParser(
    description='Fit a mean-field Gaussian variational approximation to a '
                'benchmark posterior with control-variate gradient estimators.'
)
# These three have no sensible default: without them PRNGKey / the optimizer /
# Posterior would fail much later with an obscure error.
parser.add_argument('--s', type=int, required=True, help='random seed')
parser.add_argument('--lr', type=float, required=True, help='base step size')
parser.add_argument('--dataset', type=str, required=True,
                    help='dataset name passed to infbench Posterior')
parser.add_argument('--opt', type=str, default='adam',
                    choices=sorted(_OPTIMIZER_SETTINGS),
                    help='optimizer (default: adam)')
args = parser.parse_args()
SEED = args.s
lr = args.lr
dataset = args.dataset
dataset_name = str(dataset)
OPT = args.opt

# NOTE(review): the training loops use 'init_sigma' directly as the initial
# *log* scale (so the initial std is exp(init_sigma)), not as the std itself
# -- confirm which is intended.
config = {
    **_OPTIMIZER_SETTINGS[OPT],
    'init_sigma': 0.001,
    'batch_size': 100,
    'local_reparam': False,
}

use_local_reparam = config['local_reparam']

from numpyro.infer.util import potential_energy

# Load the benchmark posterior, its numpyro model and its data.
# NOTE(review): 'xxxxx' looks like a placeholder/redacted model identifier --
# confirm the intended first argument to Posterior.
model = Posterior('xxxxx', dataset_name)
kit_generator = model.numpy()
model_func = model.numpyro()  # numpyro model evaluated by get_log_p_func via potential_energy
# Rebinds the --dataset string (kept in dataset_name) to the data dict.
dataset = model.data()
N = dataset['N']  # number of data points: sizes per-datum tables and batch indexing
print(N)
X = dataset['X']
# NOTE(review): model.data() is called a second time here; presumably it
# returns the same data as `dataset` above -- verify.
kit = kit_generator(**model.data())
# Flat vector whose shape sets the dimension of loc / log_scale.
flattened_param_template = ravel_pytree(kit['param_template'])[0]
unflatten_func = kit['unflatten_func']  # flat vector -> params pytree expected by the model

def get_norm(X):
    """Return the mean over axis 0 of the squared Euclidean norm along the last axis."""
    squared_norms = (X ** 2).sum(axis=-1)
    return squared_norms.mean(axis=0)

def get_optimizer(step_size):
    """Build the optax optimizer selected by ``config['optimizer']``.

    For ``OPT == 'sgd'`` the step size is first rescaled by
    ``batch_size * 0.9 / N``; otherwise it is passed through unchanged.
    """
    effective_step = step_size
    if OPT == 'sgd':
        effective_step = step_size * config['batch_size'] * 0.9 / N
    make_optimizer = config['optimizer']
    return make_optimizer(effective_step)

def get_log_p_func(X):
    """Return a closure ``params -> -potential_energy(model_func, ...)`` on data ``X``.

    The closure evaluates numpyro's potential energy of ``model_func`` with
    keyword arguments ``X``, the global ``N`` and ``dataset['K']``, and
    negates it, so larger values mean higher model log density.
    """
    def log_joint(params):
        kwargs = {
            'X': X,
            'N': N,
            'K': dataset['K'],
        }
        energy = potential_energy(
            model_func,
            model_args=[],
            model_kwargs=kwargs,
            params=params,
        )
        return -energy

    return log_joint

def eval_fulldataset_loss(key, params, X):
    """Monte Carlo estimate of the loss over the whole dataset.

    Data indices are split into batches of 5000 by ``generate_batch_index``.
    On each batch, ``loss_func`` is evaluated for 500 independent
    standard-normal noise draws and averaged. The result is the unweighted
    mean of those per-batch averages, so every batch gets equal weight
    regardless of its size.
    """
    noise_shape = (500,) + params['loc'].shape
    loss_over_noise = vmap(loss_func, (None, None, 0))
    key, batch_key = split(key)
    per_batch = []
    for batch_idx in generate_batch_index(batch_key, N, 5000):
        key, noise_key = split(key)
        noise = jax.random.normal(noise_key, shape=noise_shape)
        per_batch.append(loss_over_noise(params, X[batch_idx], noise).mean())
    return np.mean(per_batch)


@partial(jax.jit, donate_argnums=(0,))
def idx_update_func(x, idx, y):
    """Return ``x`` with ``x[idx]`` replaced by ``y`` (broadcast as needed).

    Compiled with ``x`` donated, so its buffer may be reused for the result;
    callers must not use ``x`` after the call and should rebind to the output.
    """
    return x.at[idx].set(y)

# Number of optimisation steps per run (an alternative setting was 75000).
NUM_ITERS = 20000
# Minibatch size used by the training loops, taken from the optimizer config.
BATCH_SIZE = config['batch_size']

def elbo(sample, log_q_func, log_p_func):
    """Single-sample *negative* ELBO at the flat parameter vector ``sample``.

    Returns ``sum(log_q_func(sample)) - log_p_func(unflatten_func(sample))``,
    i.e. a quantity to be minimised. Passing a ``log_q_func`` that returns
    zeros turns this into ``-log p`` alone.
    """
    log_joint = log_p_func(unflatten_func(sample))
    log_variational = log_q_func(sample).sum()
    return log_variational - log_joint

@jax.jit
def loss_func(params, X, eps):
    """Reparameterised single-sample negative ELBO on data ``X``.

    The sample is ``z = loc + exp(log_scale) * eps``. It is scored under
    q = Normal(loc, exp(log_scale)) and under the model log joint on ``X``.
    """
    mean = params["loc"]
    std = jnp.exp(params["log_scale"])
    z = mean + std * eps
    q_log_prob = dist.Normal(mean, std).log_prob
    return elbo(z, q_log_prob, get_log_p_func(X))

@jax.jit
def get_cv(params, X, eps):
    """Taylor control-variate term for the ``loc`` gradient.

    Returns the Hessian of ``-log p`` (on data ``X``) at ``params['loc']``
    applied to ``exp(log_scale) * eps``. It is computed forward-over-reverse
    as the JVP of the gradient. The result is linear in ``eps``.
    """
    neg_log_joint = partial(
        elbo,
        log_q_func=jnp.zeros_like,  # drop the log q term: only -log p remains
        log_p_func=get_log_p_func(X),
    )
    direction = eps * jnp.exp(params["log_scale"])
    _, hessian_vector = jvp(grad(neg_log_joint), (params["loc"],), (direction,))
    return hessian_vector

@jax.jit
def get_sample_grad(params, X, eps):
    """Gradient of ``-log p`` on data ``X``, evaluated at ``params['loc']``.

    Args:
        params: dict holding the flat ``'loc'`` vector (``'log_scale'`` is
            not used).
        X: data passed to the model log joint.
        eps: unused; accepted so this function can be vmapped with the same
            argument layout as ``loss_func`` and ``get_cv``.

    Returns:
        Flat gradient vector with the shape of ``params['loc']``.
    """
    del eps  # evaluated at the variational mean, not at a noisy sample
    log_p_func = get_log_p_func(X)
    return grad(lambda z: -log_p_func(unflatten_func(z)))(params["loc"])


def get_loss_eps_grad(key, params, idx, local_reparam=True):
    """Per-datum loss values and gradients on the minibatch ``X[idx]``.

    ``idx`` is an index tuple whose first entry holds the batch indices.
    With ``local_reparam`` each datum gets its own standard-normal noise
    draw; otherwise a single draw is shared across the batch.

    Returns:
        ``(loss, grads, eps)``. ``loss`` and ``grads`` have a leading batch
        axis; ``eps`` is the noise that was used.
    """
    loc_shape = params['loc'].shape
    if not local_reparam:
        eps = jax.random.normal(key, shape=loc_shape)
        eps_axis = None
    else:
        eps = jax.random.normal(key, shape=(len(idx[0]),) + loc_shape)
        eps_axis = 0
    per_datum = vmap(value_and_grad(loss_func), (None, 0, eps_axis))
    loss, grads = per_datum(params, X[idx], eps)
    return loss, grads, eps


def run_gd(seed, step_size, local_reparam=True):
    """Baseline SVI with plain reparameterisation gradients (no control variate).

    Args:
        seed: integer seed for ``PRNGKey``.
        step_size: base step size handed to ``get_optimizer``.
        local_reparam: per-datum noise if True, one shared draw per batch otherwise.

    Returns:
        ``(losses, grad_norms)``: a full-dataset loss estimate every 100 steps,
        and the per-step mean squared per-datum ``loc`` gradient.
    """
    key, init_key = split(PRNGKey(seed))
    # NOTE(review): config['init_sigma'] is used as the initial *log* scale,
    # so the initial std is exp(init_sigma) rather than init_sigma -- confirm.
    params = {
        "loc": jax.random.normal(init_key, flattened_param_template.shape) / 100,
        "log_scale": jnp.ones_like(flattened_param_template) * config['init_sigma'],
    }
    optimizer = get_optimizer(step_size)
    opt_state = optimizer.init(params)

    losses, grad_norms = [], []
    step = 0
    # NUM_ITERS is only checked between epochs, so the last epoch always runs
    # to completion.
    while step <= NUM_ITERS:
        key, shuffle_key = split(key)
        for batch in generate_batch_index(shuffle_key, N, BATCH_SIZE):
            key, noise_key = split(key)
            _, per_datum_grads, _ = get_loss_eps_grad(
                noise_key, params, (batch,), local_reparam
            )
            grad_norms.append((per_datum_grads['loc'] ** 2).mean())
            batch_grads = tree_map(lambda g: g.mean(0), per_datum_grads)
            updates, opt_state = optimizer.update(batch_grads, opt_state)
            params = optax.apply_updates(params, updates)
            if step % 100 == 0:
                key, eval_key = split(key)
                losses.append(eval_fulldataset_loss(eval_key, params, X))
            step += 1
    return losses, np.array(grad_norms)


def run_cv(seed, step_size, local_reparam=True):
    """SVI whose ``loc`` gradients are corrected by the Taylor control variate.

    For each datum, ``get_cv(params, x_i, eps)`` is subtracted from the
    ``loc`` gradient. That term is the Hessian of ``-log p`` at ``loc``
    applied to ``exp(log_scale) * eps``. It is linear in the standard-normal
    ``eps``, so it has zero mean and the estimate stays unbiased.

    Returns:
        ``(losses, grad_norms)``: a full-dataset loss estimate every 100 steps,
        and the per-step mean squared corrected ``loc`` gradient.
    """
    key, init_key = split(PRNGKey(seed))
    # NOTE(review): config['init_sigma'] is used as the initial *log* scale,
    # so the initial std is exp(init_sigma) rather than init_sigma -- confirm.
    params = {
        "loc": jax.random.normal(init_key, flattened_param_template.shape) / 100,
        "log_scale": jnp.ones_like(flattened_param_template) * config['init_sigma'],
    }
    optimizer = get_optimizer(step_size)
    opt_state = optimizer.init(params)
    batched_cv = vmap(get_cv, (None, 0, 0 if local_reparam else None))

    losses, grad_norms = [], []
    step = 0
    # NUM_ITERS is only checked between epochs, so the last epoch always runs
    # to completion.
    while step <= NUM_ITERS:
        key, shuffle_key = split(key)
        for batch in generate_batch_index(shuffle_key, N, BATCH_SIZE):
            idx = (batch,)
            key, noise_key = split(key)
            _, grads, eps = get_loss_eps_grad(noise_key, params, idx, local_reparam)
            grads['loc'] = grads['loc'] - batched_cv(params, X[idx], eps)
            grad_norms.append((grads['loc'] ** 2).mean())
            batch_grads = tree_map(lambda g: g.mean(0), grads)
            updates, opt_state = optimizer.update(batch_grads, opt_state)
            params = optax.apply_updates(params, updates)
            if step % 100 == 0:
                key, eval_key = split(key)
                losses.append(eval_fulldataset_loss(eval_key, params, X))
            step += 1
    return losses, np.array(grad_norms)

def run_saga(seed, step_size, local_reparam=True, cv_start_epoch=4):
    """SVI with the SAGA-style "dual" control variate on ``loc`` gradients.

    For every datum i, the variational parameters at which its gradient was
    last evaluated are cached in ``mu_table[i]`` / ``scale_table[i]``.
    ``grad_mean`` tracks ``(1/N) * sum_i get_sample_grad(table_i)``. Once
    enabled, each per-datum ``loc`` gradient is corrected to::

        g_i - (get_sample_grad(table_i) + get_cv(table_i, eps)) + grad_mean

    Epoch 0 is a warm-up that only fills the tables and ``grad_mean``.

    NOTE(review): the warm-up assumes ``generate_batch_index`` visits every
    index exactly once per epoch; otherwise ``grad_mean`` ends up biased --
    confirm.

    Args:
        seed: integer seed for ``PRNGKey``.
        step_size: base step size handed to ``get_optimizer``.
        local_reparam: per-datum noise if True, one shared draw per batch otherwise.
        cv_start_epoch: 0-based index of the first epoch in which the correction
            is applied. The default 4 applies it from the fifth epoch on.
            The warm-up epoch is never corrected.

    Returns:
        ``(losses, grad_norms)``: a full-dataset loss estimate every 100 steps,
        and the per-step mean squared (corrected) ``loc`` gradient.
    """
    key = PRNGKey(seed)
    key, _key = split(key)
    # NOTE(review): config['init_sigma'] is used as the initial *log* scale,
    # so the initial std is exp(init_sigma) rather than init_sigma -- confirm.
    loc, log_scale = (
        jax.random.normal(_key, flattened_param_template.shape) / 100,
        jnp.ones_like(flattened_param_template) * config['init_sigma'],
    )
    params = {"loc": loc, "log_scale": log_scale}
    losses = []
    grad_norms = []
    optimizer = get_optimizer(step_size)
    opt_state = optimizer.init(params)
    iter_counter = 0

    # eps carries a batch axis only under local reparameterisation.
    eps_vmap_flag = 0 if local_reparam else None

    # SAGA tables: parameters at which each datum's gradient was last taken.
    mu_table = jnp.zeros((N,) + loc.shape)
    scale_table = jnp.zeros((N,) + log_scale.shape)
    # Running mean over data of get_sample_grad at the tabled parameters.
    grad_mean = 0
    epoch = 0
    while iter_counter <= NUM_ITERS:
        key, _key = split(key)
        shuffled_idx = generate_batch_index(_key, N, BATCH_SIZE)
        for idx in shuffled_idx:
            idx = (idx,)
            key, _key = split(key)
            _, grads, eps = get_loss_eps_grad(_key, params, idx, local_reparam)
            sample_grads = vmap(get_sample_grad, (None, 0, eps_vmap_flag))(
                params, X[idx], eps
            )
            if epoch == 0:
                # Warm-up: the tables are still empty, so just accumulate the
                # fresh per-datum gradients into the mean.
                grad_mean += sample_grads.sum(0) / N
            else:
                old_params = {'loc': mu_table[idx], 'log_scale': scale_table[idx]}
                cv_term_0 = vmap(get_sample_grad, (0, 0, eps_vmap_flag))(
                    old_params, X[idx], eps
                )
                if epoch >= cv_start_epoch:
                    # The HVP term is only needed when the correction is applied.
                    cv_term_1 = vmap(get_cv, (0, 0, eps_vmap_flag))(
                        old_params, X[idx], eps
                    )
                    grads['loc'] = grads['loc'] - (cv_term_0 + cv_term_1) + grad_mean
                # SAGA cache update: swap each datum's stale contribution for
                # its fresh one (after grad_mean was used above).
                grad_mean += (sample_grads - cv_term_0).sum(0) / N
            # Record the parameters these gradients were evaluated at.
            mu_table = idx_update_func(mu_table, idx, params['loc'])
            scale_table = idx_update_func(scale_table, idx, params['log_scale'])

            grad_norms.append((grads['loc'] ** 2).mean())
            grads = tree_map(lambda g: g.mean(0), grads)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            if iter_counter % 100 == 0:
                key, _key = split(key)
                losses.append(eval_fulldataset_loss(_key, params, X))
            iter_counter += 1
        epoch += 1
    return losses, np.array(grad_norms)


def run_svrg(seed, step_size, local_reparam=True):
    """SVI with an SVRG-style variant of the dual control variate.

    Every ``max(1, N // BATCH_SIZE)`` steps, a snapshot ``old_params`` of the
    current parameters is taken. ``grad_mean`` is then recomputed as the mean
    of ``get_sample_grad`` over all of ``X`` at that snapshot. Per-datum
    ``loc`` gradients are corrected to::

        g_i - (get_sample_grad(old) + get_cv(old, eps)) + grad_mean

    Args:
        seed: integer seed for ``PRNGKey``.
        step_size: base step size handed to ``get_optimizer``.
        local_reparam: per-datum noise if True, one shared draw per batch otherwise.

    Returns:
        ``(losses, grad_norms)``: a full-dataset loss estimate every 100 steps,
        and the per-step mean squared corrected ``loc`` gradient.
    """
    key = PRNGKey(seed)
    key, _key = split(key)
    # NOTE(review): config['init_sigma'] is used as the initial *log* scale,
    # so the initial std is exp(init_sigma) rather than init_sigma -- confirm.
    loc, log_scale = (
        jax.random.normal(_key, flattened_param_template.shape) / 100,
        jnp.ones_like(flattened_param_template) * config['init_sigma'],
    )
    params = {"loc": loc, "log_scale": log_scale}
    losses = []
    grad_norms = []
    optimizer = get_optimizer(step_size)
    opt_state = optimizer.init(params)
    iter_counter = 0

    # Snapshot refresh period: roughly one pass over the data. Clamped to >= 1
    # so that N < BATCH_SIZE does not cause a modulo-by-zero below.
    inner_loop_size = max(1, N // BATCH_SIZE)
    eps_vmap_flag = 0 if local_reparam else None
    # old_params / grad_mean are always set on step 0 (0 % k == 0).
    while iter_counter <= NUM_ITERS:
        key, _key = split(key)
        shuffled_idx = generate_batch_index(_key, N, BATCH_SIZE)
        for idx in shuffled_idx:
            if iter_counter % inner_loop_size == 0:
                old_params = params
                # Full-data mean of per-datum gradients at the snapshot; the
                # eps argument is ignored by get_sample_grad.
                grad_mean = vmap(get_sample_grad, (None, 0, None))(
                    old_params, X, jnp.zeros_like(params['loc'])
                ).mean(0)
            idx = (idx,)
            key, _key = split(key)
            _, grads, eps = get_loss_eps_grad(_key, params, idx, local_reparam)
            cv_term_0 = vmap(get_sample_grad, (None, 0, eps_vmap_flag))(
                old_params, X[idx], eps
            )
            cv_term_1 = vmap(get_cv, (None, 0, eps_vmap_flag))(
                old_params, X[idx], eps
            )
            grads['loc'] = grads['loc'] - (cv_term_0 + cv_term_1) + grad_mean
            grad_norms.append((grads['loc'] ** 2).mean())
            grads = tree_map(lambda g: g.mean(0), grads)
            updates, opt_state = optimizer.update(grads, opt_state)
            params = optax.apply_updates(params, updates)
            if iter_counter % 100 == 0:
                key, _key = split(key)
                losses.append(eval_fulldataset_loss(_key, params, X))
            iter_counter += 1
    return losses, np.array(grad_norms)


import os

LOG_DIR = config['log_dir']
print(f'Stepsize: {lr}')
# np.save does not create missing parent directories, so make sure the output
# directory exists before any (potentially long) run starts.
os.makedirs(LOG_DIR, exist_ok=True)


def _save_run(tag, loss, grad_norm):
    """Save one method's loss trace and per-step gradient norms under LOG_DIR."""
    prefix = f'./{LOG_DIR}/{dataset_name}_{tag}_{lr}_{SEED}'
    np.save(f'{prefix}_loss.npy', loss)
    np.save(f'{prefix}_gradnorm.npy', grad_norm)


# Select experiments by (un)commenting the lines below. File tags:
#   naive -> run_gd, cv -> run_cv, dual -> run_saga, svrg -> run_svrg,
#   svrg2 -> run_svrg with fixed loop size.
# _save_run('naive', *run_gd(SEED, lr, use_local_reparam))
# _save_run('cv', *run_cv(SEED, lr, use_local_reparam))
dual_loss, dual_grad_norm = run_saga(SEED, lr, use_local_reparam)
_save_run('dual', dual_loss, dual_grad_norm)
# _save_run('svrg', *run_svrg(SEED, lr, use_local_reparam))
# _save_run('svrg2', *run_svrg(SEED, lr, use_local_reparam))
# NOTE(review): run_saga_late is not defined in this file.
# _save_run('dual_late', *run_saga_late(SEED, lr, use_local_reparam))
