import click
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from jax import random, vmap
from jax.example_libraries import optimizers
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import scipy
import generax
import jaxopt
import jax.numpy as jnp
import optax
import equinox as eqx
import functools

@click.command()
@click.option('--model', default='Y2Regression')
@click.option('--seed', default=0)
@click.option('--s', default=1)
@click.option('--iterations', default=1000)
@click.option('--n', default=100)
@click.option('--prediction_sample', default=1000)
@click.option('--learning_rate', default=0.00001)
@click.option('--sbi_samples', default = 1000)
def main(model, seed, s, iterations, n, prediction_sample, learning_rate, sbi_samples):
    """Train a conditional neural-spline flow on simulated data and report its score.
    \f
    The form feed above makes click cut the ``--help`` text after the summary.

    Args:
        model: Name of a class in the project-local ``model`` module; it is
            instantiated as ``Class(n)``.
        seed: Seed for the PRNG key that drives the initialisation samples and
            the flow's parameter initialisation.
        s: Only used as a suffix of the result directory name.
        iterations: Number of optimisation steps passed to the trainer.
        n: Passed to the model constructor and used as the flow's ``cond_shape``.
        prediction_sample: Number of flow samples drawn per evaluation.
        learning_rate: Learning rate of the Adagrad optimiser.
        sbi_samples: Number of simulated (theta, y) pairs used for the
            data-dependent initialisation of the flow.
    """
    model_module = importlib.import_module('model')
    m = getattr(model_module, model)(n)

    # The observed data always comes from a fixed key, independent of --seed.
    y = m.data(random.PRNGKey(1))
    test_y = None
    if isinstance(y, tuple):
        y, test_y = y

    data_key, rng_key = random.split(random.PRNGKey(seed))
    prior_key, y_key = random.split(data_key)

    # Simulated (theta, y) pairs, used below only for data-dependent init.
    thetas = vmap(m.sample_prior)(random.split(prior_key, sbi_samples))
    ys = vmap(m.sample_datapoint)(random.split(y_key, sbi_samples), thetas)

    def get_dataset_iter():
        """Return an endless generator of freshly simulated training batches."""
        # NOTE(review): the training stream is keyed by PRNGKey(0), so it does
        # not vary with --seed -- confirm this is intended.
        key = random.PRNGKey(0)

        def get_train_ds(key, batch_size: int = 64):
            while True:
                key, sample_key = random.split(key, 2)
                prior_key, y_key = random.split(sample_key)
                thetas = vmap(m.sample_prior)(random.split(prior_key, batch_size))
                ys = vmap(m.sample_datapoint)(random.split(y_key, batch_size), thetas)
                yield dict(x=thetas, y=ys)

        return get_train_ds(key)

    flow_key, init_key, rng_key = random.split(rng_key, 3)

    flow = generax.NeuralSpline(
        input_shape=(m.d, ),
        cond_shape=(n, ),
        n_flow_layers=2,
        working_size=8,
        hidden_size=16,
        n_blocks=2,
        n_spline_knots=4,
        key=flow_key
    )
    flow = flow.data_dependent_init(thetas, ys, key=init_key)

    result_dir = f'result/{model}_{n}/{seed}_{s}'
    trainer = generax.Trainer(checkpoint_path=result_dir)

    schedule = optax.warmup_cosine_decay_schedule(init_value=0.0,
                                                  peak_value=1.0,
                                                  warmup_steps=1000,
                                                  decay_steps=100000,
                                                  end_value=0.1,
                                                  exponent=1.0)
    # Clip gradients, apply Adagrad (wrapped by apply_if_finite), then rescale
    # the resulting update by the warmup/cosine schedule.
    optimizer = optax.chain(
        optax.clip_by_global_norm(10.0),
        optax.apply_if_finite(optax.adagrad(learning_rate), 10),
        optax.scale_by_schedule(schedule),
    )

    def loss(flow, data, key):
        """Negative mean conditional log-likelihood of a batch; ``key`` is unused."""
        x = data['x']
        y = data['y']
        log_px = eqx.filter_vmap(flow.log_prob)(x, y)
        objective = -log_px.mean()

        aux = dict(log_px=log_px)
        return objective, aux

    def evaluate_crps(key, flow):
        """Sample from the flow conditioned on the observed ``y`` and average ``m.test_crps``."""
        key, sim_key = random.split(key)
        theta_sample, logp = eqx.filter_vmap(
            flow.sample_and_log_prob, in_axes=(0, None)
        )(random.split(key, prediction_sample), y)
        # NOTE(review): test_y stays None when m.data returns a single value;
        # presumably m.test_crps copes with that -- confirm.
        crpss = vmap(m.test_crps, in_axes=(0, 0, None))(
            theta_sample, random.split(sim_key, prediction_sample), test_y
        )
        return jnp.mean(crpss)

    test_key, rng_key = random.split(rng_key)

    # Score of the untrained (only data-initialised) flow.
    print(evaluate_crps(test_key, flow))

    def evaluate(model):
        # Fixed key so periodic evaluations are comparable across steps.
        print(evaluate_crps(random.PRNGKey(0), model))

    flow = trainer.train(model=flow,
                         objective=loss,
                         evaluate_model=evaluate,
                         optimizer=optimizer,
                         num_steps=iterations,
                         data_iterator=get_dataset_iter(),
                         test_every=1000,
                         retrain=True)

    # Evaluate once and reuse the value for both stdout and the result file.
    final_score = evaluate_crps(random.PRNGKey(0), flow)
    print(final_score)

    # Make sure the output directory exists before writing the result file.
    os.makedirs(result_dir, exist_ok=True)
    result_file = f'{result_dir}/eval'
    with open(result_file, 'w') as f:
        print(final_score, file=f)



# Script entry point: run the click command only when executed directly.
if __name__ == "__main__":
    main()