import click
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from jax import random, vmap, jit
from jax.example_libraries import optimizers
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import jax.numpy as jnp
import jaxopt
import optax
import equinox as eqx

def _build_optimizer(learning_rate, iterations, warmup_steps=1000):
    """Build the gradient transformation used for training.

    The chain has three stages:
    1. Clip the global gradient norm to 10.
    2. Apply Adam (Nesterov momentum) with step size ``learning_rate``.
    3. Scale the result by a warmup-cosine schedule that ramps linearly from
       0 to 1, then cosine-decays to 0.1.

    Args:
        learning_rate: Adam step size.
        iterations: Total number of training steps (the schedule length).
        warmup_steps: Length of the linear warmup phase for runs longer
            than ``warmup_steps`` steps.

    Returns:
        An ``optax.GradientTransformation``.
    """
    # optax's cosine phase needs ``decay_steps - warmup_steps > 0``. A fixed
    # 1000-step warmup would make any run with iterations <= 1000 fail when
    # the schedule is built, so shrink the warmup for short runs only.
    if iterations <= warmup_steps:
        warmup_steps = iterations // 2
    decay_steps = max(iterations, warmup_steps + 1)
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=1.0,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        end_value=0.1,
        exponent=1.0,
    )
    return optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate, nesterov=True),
        optax.scale_by_schedule(schedule),
    )


@click.command()
@click.option('--model', default='Y2Regression')
@click.option('--objective', default='PVICRPS')
@click.option('--posterior', default='NormalizingFlow')
@click.option('--seed', default=0)
@click.option('--s', default=1)
@click.option('--iterations', default=10000)
@click.option('--n', default=100)
@click.option('--g', default=5)
@click.option('--h', default=5)
@click.option('--alpha', default=0.0)
@click.option('--prediction_sample', default=1000)
@click.option('--test_sample', default=10000)
@click.option('--learning_rate', default=0.001)
@click.option('--plot', is_flag=True)
@click.option('--optimizer', default='sgd')
@click.option('--select_best', is_flag=True,
              help='Evaluate and save the parameters with the lowest validation CRPS '
                   'instead of the final iterate.')
def main(model, objective, posterior, seed, s, iterations, n, g, h, alpha,
         prediction_sample, test_sample, learning_rate, plot, optimizer,
         select_best=False):
    """Fit a variational posterior and evaluate it by CRPS.

    Steps:
    1. Instantiate ``model.<model>``, ``posterior.<posterior>`` and
       ``objective.<objective>`` from the project modules.
    2. Train the posterior parameters, printing the loss and the validation
       CRPS every 1000 steps.
    3. Print the test CRPS and write it, plus posterior samples, under
       ``result/``.

    ``--optimizer`` is accepted for CLI compatibility but is currently
    ignored: training always uses the clipped Adam chain from
    ``_build_optimizer``. With ``--select_best``, the parameters with the
    lowest validation CRPS are evaluated instead of the final iterate. Their
    results go to a ``_best``-suffixed directory, so default runs are not
    overwritten.
    """
    m = getattr(importlib.import_module('model'), model)(n, g, m=h, alpha=alpha)
    rng_key = random.PRNGKey(seed)
    # The data key is fixed (not derived from --seed).
    y = m.data(random.PRNGKey(1))
    test_y = None
    if isinstance(y, tuple):
        y, test_y = y

    init_key, rng_key = random.split(rng_key)
    p = getattr(importlib.import_module('posterior'), posterior)(m.d, init_key)
    flow = p.gen_params()

    o = getattr(importlib.import_module('objective'), objective)(m, p, y, s=s)

    tx = _build_optimizer(learning_rate, iterations)
    opt_state = tx.init(eqx.filter(flow, eqx.is_inexact_array))

    def loss_fn(flow, key):
        # Negative objective, normalised by the number of observations n.
        return -o.objective(key, flow) / n

    # Built once so the same jitted function is reused on every step.
    value_and_grad = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn))

    def train_step(opt_state, flow, key):
        """Run one optimisation step; return (loss, new_opt_state, new_flow)."""
        value, grads = value_and_grad(flow, key)
        # Leave state and parameters untouched if any gradient leaf is NaN/Inf.
        should_skip, _ = optax.skip_not_finite(grads, opt_state, flow)
        if should_skip:
            return value, opt_state, flow
        # Filter params the same way as in tx.init so the trees line up.
        updates, opt_state = tx.update(
            grads, opt_state, eqx.filter(flow, eqx.is_inexact_array))
        return value, opt_state, eqx.apply_updates(flow, updates)

    def draw_pairs(key, flow, count):
        """Draw 2*count posterior samples plus `count` simulation keys from `key`."""
        key, sim_key = random.split(key)
        theta, _ = eqx.filter_vmap(flow.sample_and_log_prob)(random.split(key, count * 2))
        return theta, random.split(sim_key, count)

    def evaluate_test_crps(key, flow):
        """Return the mean test CRPS of `flow` against `test_y`."""
        theta, sim_keys = draw_pairs(key, flow, test_sample)
        # NOTE(review): test_y is None when m.data returns a single value;
        # confirm m.test_crps accepts that.
        crpss = vmap(m.test_crps, in_axes=(0, 0, 0, None))(
            theta[test_sample:], theta[:test_sample], sim_keys, test_y)
        return jnp.mean(crpss)

    def validation_crps(key, flow):
        """Return the mean validation CRPS of `flow`; also print the sample mean."""
        theta, sim_keys = draw_pairs(key, flow, prediction_sample)
        print(jnp.mean(theta, axis=0))
        crpss = vmap(m.validate_crps, in_axes=(0, 0, 0))(
            theta[prediction_sample:], theta[:prediction_sample], sim_keys)
        return jnp.mean(crpss)

    valid_key, rng_key = random.split(rng_key)
    best_valid = validation_crps(valid_key, flow)
    # apply_updates returns a new pytree, so keeping a reference is enough.
    best_flow = flow

    for i in tqdm(range(iterations)):
        # Every consumer gets its own subkey; rng_key is never handed out
        # twice. The old version reused the validation key as the next
        # step's data key.
        data_key, rng_key = random.split(rng_key)
        value, opt_state, flow = train_step(opt_state, flow, data_key)
        if i % 1000 == 0:
            print(i, value)
            valid_key, rng_key = random.split(rng_key)
            valid = validation_crps(valid_key, flow)
            if valid < best_valid:
                best_valid = valid
                best_flow = flow
            print(valid)

    final_flow = best_flow if select_best else flow

    # Evaluated once with a fixed key and reused for both stdout and the file.
    test_score = evaluate_test_crps(random.PRNGKey(0), final_flow)
    print(test_score)

    theta_sample, _ = eqx.filter_vmap(final_flow.sample_and_log_prob)(
        random.split(rng_key, test_sample))
    if plot:
        # Build the frame in one go; never rebind `s`, which names the result dir.
        samples = np.asarray(theta_sample, dtype=np.float64)
        frame = pd.DataFrame(samples, columns=[f'beta{d}' for d in range(samples.shape[1])])
        sns.kdeplot(data=frame, x='beta0', y='beta1')
        plt.show()
    print(m.beta)

    suffix = '_best' if select_best else ''
    result_dir = f'result/{model}_{alpha}_{n}_{g}_{h}/PVI_{seed}_{s}{suffix}/'
    os.makedirs(result_dir, exist_ok=True)
    with open(os.path.join(result_dir, 'eval'), 'w') as f:
        print(test_score, file=f)
    np.savez_compressed(os.path.join(result_dir, 'sample.npz'), samples=np.array(theta_sample))


if __name__ == "__main__":
    # Script entry point: click parses sys.argv and invokes the command.
    main()