import click
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from jax import random, vmap, jit
from jax.example_libraries import optimizers
import jax
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import jax.numpy as jnp
import jaxopt
import optax
import equinox as eqx
import scipy


class DiscreteDensity:
    """Fixed categorical distribution over the integers ``0 .. N-1``.

    The unnormalized weights are hard-coded (20 bins, two bumps); ``probs``
    holds them normalized to sum to one. Sampling uses NumPy's global RNG.
    """

    def __init__(self, N=20):
        self.N = N
        weights = [1, 2, 4, 7, 10, 12, 10, 7, 4, 2,
                   1, 2, 3, 5, 7, 8, 7, 5, 3, 2]
        self.cat = np.array(weights)
        self.probs = self.cat / self.cat.sum()

    def sample(self, num, shuffle=True):
        """Draw ``num`` i.i.d. bin indices according to ``probs``.

        ``shuffle`` is accepted for interface compatibility and has no effect.
        """
        support = np.arange(self.N)
        return np.random.choice(support, num, p=self.probs)


@click.command()
@click.option('--model', default='HSP90')
@click.option('--objective', default='PVICRPSKernel')
@click.option('--posterior', default='NormalizingFlow1D')
@click.option('--seed', default=0)
@click.option('--s', default=32)
@click.option('--iterations', default=10000)
@click.option('--evaluate_it', default=1000)
@click.option('--batch_size', default=32)
@click.option('--prediction_sample', default=10)
@click.option('--test_sample', default=10)
@click.option('--plot_sample', default=10000)
@click.option('--learning_rate', default=0.0001)
@click.option('--plot', is_flag=True)
@click.option('--optimizer', default = 'sgd')
@click.option('--datapath', default = '../HSP90')
@click.option('--noise', default = 10.)
@click.option('--k', default = 2)
@click.option('--outdir', default = "HSP90")
@click.option('--n', default = 10000)
def main(model, objective, posterior, seed, s, iterations, evaluate_it, batch_size, prediction_sample, test_sample, plot_sample, learning_rate, plot, optimizer, datapath, noise, k, outdir, n):
    """Train a posterior approximation and report CRPS / KS-test results.

    The model, posterior and objective classes are looked up by name in the
    local ``model``, ``posterior`` and ``objective`` modules. Training runs
    ``iterations`` steps. On every step where ``i % evaluate_it == 0`` the loss
    and validation CRPS are printed. With ``--plot``, a figure is also saved
    under ``figure/<outdir>/``. After training, the final CRPS and a KS test
    against ``DiscreteDensity`` samples are printed and written to
    ``result/<model>_<noise>_<n>/PVI_<seed>_<s>/`` (``sample.npz`` and ``eval``).

    NOTE(review): ``--optimizer`` is accepted but ignored; the optimizer is
    always clip-by-global-norm -> Adam(nesterov) -> warmup-cosine schedule.
    """
    model_cls = getattr(importlib.import_module('model'), model)
    m = model_cls(batch_size=batch_size, dir=datapath, noise_std=noise, n=n)
    rng_key = random.PRNGKey(seed)

    init_key, rng_key = random.split(rng_key)
    posterior_cls = getattr(importlib.import_module('posterior'), posterior)
    p = posterior_cls(m.d, init_key, knots=k)
    params = p.gen_params()

    objective_cls = getattr(importlib.import_module('objective'), objective)
    o = objective_cls(m, p, y=m.data, s=s)

    warmup_steps = 1000
    # optax runs the cosine phase for ``decay_steps - warmup_steps`` steps and
    # rejects a non-positive length, so keep decay_steps beyond the warm-up
    # even for short runs (unchanged whenever iterations >= 2002).
    schedule = optax.warmup_cosine_decay_schedule(init_value=0.0,
                                                  peak_value=1.0,
                                                  warmup_steps=warmup_steps,
                                                  decay_steps=max(iterations // 2, warmup_steps + 1),
                                                  end_value=0.1,
                                                  exponent=1.0)
    tx = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.adam(learning_rate, nesterov=True),
        optax.scale_by_schedule(schedule),
    )
    data_iter = m.data()

    # Training state tuple: (optax state, flow parameters, PRNG key, data iterator).
    opt_state = tx.init(eqx.filter(params, eqx.is_inexact_array)), params, rng_key, data_iter

    def _loss(flow, key, y):
        # The objective is maximized, so the optimizer minimizes its negation.
        return -o.objective(key, flow, y)

    loss_and_grad = eqx.filter_jit(eqx.filter_value_and_grad(_loss))

    def _fresh_key(state):
        """Split the key stored in ``state``; return (subkey, state with the advanced key).

        Writing the advanced key back prevents the next training step from
        re-deriving the same subkey.
        """
        tx_state, flow, key, batches = state
        sub_key, key = random.split(key)
        return sub_key, (tx_state, flow, key, batches)

    def step(i, state):
        """Run one optimizer step on the next data batch; return (loss, new state)."""
        tx_state, flow, key, batches = state
        data_key, key = random.split(key)
        y = jnp.array(next(batches))

        value, grads = loss_and_grad(flow, data_key, y)
        skip, _ = optax.skip_not_finite(grads, tx_state, flow)
        if skip:
            # Non-finite gradients: keep the previous optimizer state and parameters.
            print(i, 'skipped')
            return value, (tx_state, flow, key, batches)
        updates, tx_state = tx.update(grads, tx_state, flow)
        return value, (tx_state, eqx.apply_updates(flow, updates), key, batches)

    def _mean_crps(key, flow, num, crps_fn):
        """Mean of ``crps_fn`` over ``m.valid_loader`` using ``2 * num`` flow samples.

        The samples are split into two halves of ``num`` that are passed
        pairwise to ``crps_fn`` together with one simulation key per pair.
        """
        key, sim_key = random.split(key)
        theta, _ = eqx.filter_vmap(flow.sample_and_log_prob)(random.split(key, num * 2))
        sim_keys = random.split(sim_key, num)
        batched_crps = vmap(crps_fn, in_axes=(0, 0, 0, None))
        scores = [jnp.mean(batched_crps(theta[num:], theta[:num], sim_keys, batch))
                  for batch in m.valid_loader]
        return jnp.mean(jnp.array(scores))

    def test_crps_score(key, flow):
        # NOTE(review): iterates m.valid_loader, not a separate test loader — confirm intended.
        return _mean_crps(key, flow, test_sample, m.test_crps)

    def validation_crps(key, flow):
        return _mean_crps(key, flow, prediction_sample, m.validate_crps)

    def _plot_samples(flow, i):
        """Save a figure of ``plot_sample`` posterior draws (fixed key) for iteration ``i``."""
        all_samples, _ = eqx.filter_vmap(flow.sample_and_log_prob)(random.split(random.PRNGKey(0), plot_sample))
        if len(all_samples[0]) > 1:
            plotdata = pd.DataFrame([
                {'beta1': float(m.link(sample[0])), 'beta2': float(m.link(sample[1]))}
                for sample in all_samples
            ])
            sns.scatterplot(data=plotdata, x='beta1', y='beta2', s=1, color=".2")
            sns.kdeplot(data=plotdata, x='beta1', y='beta2')
            plt.xlim([0, 6])
            plt.ylim([0, 6])
            plt.savefig(f'figure/{outdir}/{i}.png')
            plt.clf()
        else:
            true_y = DiscreteDensity().sample(10000)
            plotdata = np.array(m.link(all_samples)).flatten()
            print(scipy.stats.kstest(true_y, plotdata))
            plt.rcParams.update({'font.size': 25})
            fig = plt.figure(figsize=(6, 4))
            sns.kdeplot(true_y, label='True')
            sns.kdeplot(plotdata, label='Inferred')
            plt.legend()
            plt.xticks([])
            plt.yticks([])
            plt.xlim([0, 20])
            plt.xlabel('')
            plt.ylabel('')
            plt.title('Parameter distribution')
            plt.savefig(f'figure/{outdir}/{i}.pdf')
            # A new figure is created on every call, so close it rather than
            # just clearing it; otherwise pyplot keeps every figure alive.
            plt.close(fig)

    valid_key, opt_state = _fresh_key(opt_state)
    best_valid = validation_crps(valid_key, params)
    print('The number before training is', best_valid)
    # NOTE(review): best_param tracks the best-validation flow but is never
    # used; the final evaluation below uses the last iterate.
    best_param = eqx.tree_at(lambda t:t, params, replace_fn=lambda t:t)
    m.oracle(N=prediction_sample)

    if plot:
        os.makedirs(f'figure/{outdir}', exist_ok=True)

    data = []  # per-step losses; NOTE(review): collected but never saved.
    for i in tqdm(range(iterations)):
        value, opt_state = step(i, opt_state)
        flow = opt_state[1]
        if i % evaluate_it == 0:
            print(i, value)
            valid_key, opt_state = _fresh_key(opt_state)
            valid = validation_crps(valid_key, flow)
            if valid < best_valid:
                best_valid = valid
                best_param = eqx.tree_at(lambda t:t, flow, replace_fn=lambda t:t)
            print(valid)
            if plot:
                _plot_samples(flow, i)
        data.append({'step': i, 'loss': float(value)})

    _, flow, rng_key, _ = opt_state

    test_score = test_crps_score(random.PRNGKey(0), flow)
    print(test_score)

    theta_sample, _ = eqx.filter_vmap(flow.sample_and_log_prob)(random.split(rng_key, plot_sample))

    result_dir = f'result/{model}_{noise}_{n}/PVI_{seed}_{s}/'
    os.makedirs(result_dir, exist_ok=True)
    plotdata = np.array(m.link(theta_sample)).flatten()
    true_y = DiscreteDensity().sample(10000)
    stats = scipy.stats.kstest(true_y, plotdata)
    print(stats)
    np.savez_compressed(os.path.join(result_dir, 'sample.npz'), samples=plotdata)

    with open(os.path.join(result_dir, 'eval'), 'w') as f:
        print(test_score, stats.statistic, stats.pvalue, file=f)


# Script entry point: invoke the click command, which parses the CLI options into main().
if __name__ == '__main__':
    main()