import click
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from jax import random, vmap
from jax.example_libraries import optimizers
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import scipy
import jaxopt
import jax.numpy as jnp

def _log_predictive_density(log_likelihoods, num_samples):
    """Return the summed Monte Carlo log predictive density.

    Axis 0 of ``log_likelihoods`` indexes the ``num_samples`` posterior draws
    (callers ``vmap`` over a batch returned by ``p.sample``). For each
    remaining entry the estimate ``logsumexp_s(ll[s]) - log(num_samples)``
    is formed, and the estimates are summed.
    """
    per_point = np.array(scipy.special.logsumexp(log_likelihoods, axis=0))
    return np.sum(per_point - np.log(num_samples))


def _sample_and_score(m, p, key, params, num_samples):
    """Sample ``num_samples`` thetas from ``p`` and score them with ``m.valid_log_likelihoods``.

    Returns ``(theta_sample, log_predictive_density)``.
    """
    theta_sample = p.sample(key, params, num_samples)
    log_likelihoods = vmap(m.valid_log_likelihoods)(theta_sample)
    return theta_sample, _log_predictive_density(log_likelihoods, num_samples)


def _make_optimizer(name, schedule):
    """Return the ``(opt_init, opt_update, get_params)`` triple for optimizer ``name``.

    Raises:
        ValueError: if ``name`` is not one of 'sgd', 'nesterov', 'rmsprop'.
    """
    if name == 'sgd':
        return optimizers.sgd(schedule)
    if name == 'nesterov':
        return optimizers.nesterov(schedule, 0.9)
    if name == 'rmsprop':
        return optimizers.rmsprop_momentum(schedule)
    raise ValueError(name)


@click.command()
@click.option('--model', default='Toy')
@click.option('--objective', default='VIBasic')
@click.option('--regularizer', default='VIBasic')
@click.option('--posterior', default='Basic')
@click.option('--seed', default=0)
@click.option('--lamb', default=0.0)
@click.option('--alpha', default=0.0)
@click.option('--s', default=1)
@click.option('--iterations', default=1000)
@click.option('--n', default=100)
@click.option('--g', default=5)
@click.option('--prediction_sample', default=10000)
@click.option('--learning_rate', default=0.001)
@click.option('--vi_training', is_flag=True)
@click.option('--test_bias', is_flag=True)
@click.option('--optimizer', default='sgd')
@click.option('--plot', is_flag=True)
def main(model, objective, regularizer, posterior, seed, lamb, alpha, s, iterations, n, g, prediction_sample, learning_rate, vi_training, test_bias, optimizer, plot):
    """Train a variational posterior and write its evaluation results.

    The model, posterior family, objective and regularizer classes are looked
    up by name in the ``model``, ``posterior`` and ``objective`` modules.
    Training optionally pretrains with ``objective.VIBasic``. It then
    optimises ``(1 - lamb) * objective + lamb * regularizer`` for
    ``iterations`` steps, ascending the objective.

    Every 1000 steps the validation log predictive density is estimated and
    the best parameters so far are tracked. Results go to
    result_inter/<model>_<alpha>_<n>_<g>_<posterior>_<objective>/<regularizer>_<lamb>_<seed>_<s>,
    together with a .pdf plot of the per-step loss.

    The --plot flag is accepted but currently unused.
    """
    m = getattr(importlib.import_module('model'), model)(n, g, alpha=alpha)
    rng_key = random.PRNGKey(seed)
    # The data key is fixed rather than derived from `seed`, presumably so the
    # dataset is identical across seeds.
    y = m.data(random.PRNGKey(1))
    test_y = None
    if isinstance(y, tuple):
        # Some models return a (y, test_y) pair.
        y, test_y = y

    p = getattr(importlib.import_module('posterior'), posterior)(m.d)
    params = p.gen_params()

    objectives = importlib.import_module('objective')
    o = getattr(objectives, objective)(m, p, y, s=s)
    r = getattr(objectives, regularizer)(m, p, y, s=s)

    def scheduler(step):
        # Full learning rate for the first half of the run, then 10x smaller.
        if step < iterations // 2:
            return learning_rate
        return learning_rate / 10

    opt_init, opt_update, get_params = _make_optimizer(optimizer, scheduler)
    # The carried training state is (optimizer state, PRNG key).
    opt_state = opt_init(params), rng_key

    if vi_training:
        vi_objective = getattr(objectives, 'VIBasic')(m, p, y, s=s)
        print('Pretraining with VI...')

        def vi_step(i, state):
            param, key = state
            data_key, key = random.split(key)
            value, grads = vi_objective.value_and_grad(data_key, get_params(param))
            # opt_update minimises, so the gradient is negated to ascend the objective.
            return value, (opt_update(i, -(grads / n), param), key)

        # NOTE(review): losses from this phase are not included in the saved
        # loss plot; confirm that is intended.
        for i in tqdm(range(iterations // 2)):
            value, opt_state = vi_step(i, opt_state)
            if i % 1000 == 0:
                print(i, value, get_params(opt_state[0]))

    def regularized_objective(key, params):
        """Return (value, grad) of (1 - lamb) * objective + lamb * regularizer."""
        key1, key2 = random.split(key)
        v1, g1 = o.value_and_grad(key1, params)
        v2, g2 = r.value_and_grad(key2, params)
        return (1 - lamb) * v1 + lamb * v2, (1 - lamb) * g1 + lamb * g2

    def train_step(i, state):
        param, key = state
        data_key, key = random.split(key)
        value, grads = regularized_objective(data_key, get_params(param))
        norm = jnp.linalg.norm(grads)
        grads = grads / n
        # Clipping uses the norm measured before the 1/n scaling, so a
        # clipped update ends up with norm 100 / n.
        if norm > 100:
            grads = grads / norm * 100
        # Skip the update (keep the previous optimizer state) if any entry is NaN.
        if jnp.any(jnp.isnan(grads)):
            return value, (param, key)
        return value, (opt_update(i, -grads, param), key)

    rej = getattr(objectives, 'PVIRejection')(m, p, y, s=s) if test_bias else None

    data = []
    best_pred = -np.inf
    best_param = params
    for i in tqdm(range(iterations)):
        value, opt_state = train_step(i, opt_state)
        param, rng_key = opt_state
        if i % 1000 == 0:
            current = get_params(param)
            print(i, value, current)
            # Take fresh subkeys for evaluation and store the new carry key.
            # Sampling with the carried key directly would reuse it, because
            # the next train_step also splits that key.
            rng_key, sample_key, bias_key = random.split(rng_key, 3)
            opt_state = (param, rng_key)
            _, predictive_ll = _sample_and_score(m, p, sample_key, current, prediction_sample)
            if predictive_ll > best_pred:
                best_pred = predictive_ll
                best_param = current
            print(predictive_ll)
            if test_bias and i > 0:
                # Norm of the difference between the Monte Carlo mean gradients
                # (1000 draws each) of PVIRejection and the training objective.
                rej_grads, obj_grads = [], []
                for draw_key in random.split(bias_key, 1000):
                    key1, key2 = random.split(draw_key)
                    rej_grads.append(rej.value_and_grad(key1, current)[1])
                    obj_grads.append(o.value_and_grad(key2, current)[1])
                print(np.linalg.norm(np.mean(rej_grads, axis=0) - np.mean(obj_grads, axis=0)))
        data.append({'step': i, 'loss': float(value)})

    # Final evaluation uses the last iterate, not best_param.
    param, rng_key = opt_state
    param_list = get_params(param)
    theta_sample, predictive_ll = _sample_and_score(m, p, rng_key, param_list, prediction_sample)
    valid_ll = predictive_ll
    # NOTE(review): best_param is reported next to the final iterate's
    # predictive_ll, and the test LL below also comes from the final iterate.
    # Confirm that this pairing is intended.
    print(best_param, predictive_ll)

    if test_y is not None:
        log_likelihoods = vmap(m.test_log_likelihoods, in_axes=(0, None))(theta_sample, test_y)
        test_ll = _log_predictive_density(log_likelihoods, prediction_sample)
        # When a test split exists, the score written to the result file is the test LL.
        predictive_ll = test_ll
        print('test set :', test_ll)

    result_dir = f"result_inter/{model}_{alpha}_{n}_{g}_{posterior}_{objective}/"
    os.makedirs(result_dir, exist_ok=True)
    result_file = f'{result_dir}{regularizer}_{lamb}_{seed}_{s}'
    print('writing to', result_file)

    with open(result_file, 'w') as f:
        print(' '.join(map(str, best_param)), predictive_ll, file=f)
        if hasattr(m, 'valid_y'):
            # Same samples and likelihood function as the final evaluation
            # above, so its value is reused rather than recomputed.
            print('valid set :', valid_ll)
            print('valid set :', valid_ll, file=f)

    sns.lineplot(data=pd.DataFrame(data), x='step', y='loss')
    plt.savefig(result_file + '.pdf')
    # Release the figure so repeated in-process calls don't draw onto it.
    plt.close()


if __name__ == '__main__':
    # Let click parse the command-line options and invoke main().
    main()