import click
import importlib
import pandas as pd
import numpy as np
import seaborn as sns
from jax import random, vmap
from jax.example_libraries import optimizers
import jax
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import scipy
import generax
import jaxopt
import jax.numpy as jnp
import optax
import equinox as eqx
import functools
import torch
from sbi.inference import SNRE, prepare_for_sbi, simulate_for_sbi
from sbi.inference.base import infer

@click.command()
@click.option('--model', default='Y2Regression')
@click.option('--seed', default=0)
@click.option('--s', default=1)
@click.option('--iterations', default=1000)
@click.option('--n', default=100)
@click.option('--g', default=5)
@click.option('--h', default=5)
@click.option('--alpha', default=0.)
@click.option('--prediction_sample', default=1000)
@click.option('--test_sample', default=10000)
@click.option('--learning_rate', default=0.00001)
@click.option('--sbi_samples', default=20000)
@click.option('--round', default=1)
@click.option('--plot', is_flag=True)
def main(model, seed, s, iterations, n, g, h, alpha, prediction_sample, test_sample, learning_rate, sbi_samples, round, plot):
    """Run sequential neural ratio estimation (SNRE) and score the posterior.

    The class named by --model is loaded from the local `model` module and
    built as `cls(n, g, m=h, alpha=alpha)`. For each of --round rounds,
    --sbi_samples parameter vectors are simulated and the ratio estimator is
    retrained; round one draws from a standard-normal prior, later rounds
    draw from the previous round's NUTS posterior. Finally 2 * --test_sample
    posterior samples are drawn, paired up and scored with the model's
    `test_crps`. The mean score is printed and written to
    result/<model>_<alpha>_<n>_<g>_<h>/Torch4_<seed>_<s>_<round>/eval,
    and the samples are saved next to it as sample.npz.

    --seed seeds torch only. --s only appears in the output path.
    --iterations, --prediction_sample and --learning_rate are accepted but
    not used by this command.
    """
    # NOTE: the parameter name `round` shadows the builtin; it is kept because
    # click derives it from the `--round` option.
    if round < 1:
        # With no training round there is no posterior to sample from; fail
        # before any simulation work instead of hitting an unbound local later.
        raise click.BadParameter('must be at least 1', param_hint='--round')

    model_module = importlib.import_module('model')
    m = getattr(model_module, model)(n, g, m=h, alpha=alpha)

    # The data key is fixed, so the observed data does not depend on --seed.
    y = m.data(random.PRNGKey(1))
    test_y = None
    if isinstance(y, tuple):
        # A tuple result is unpacked as (observed data, held-out test data).
        y, test_y = y

    torch.manual_seed(seed)
    num_dim = m.d
    prior = torch.distributions.MultivariateNormal(torch.zeros(num_dim), torch.eye(num_dim))
    x = torch.tensor(np.array(m.x))

    def simulator(theta):
        """Simulate one summary vector for a single parameter vector `theta`.

        The first `m.x_dim` entries of `theta` are the coefficients and the
        entry at index `m.x_dim` is the log of the noise scale.
        """
        beta = theta[:m.x_dim]
        sigma = torch.exp(theta[m.x_dim])
        noisy = torch.randn((n, h)) * sigma + torch.sum(beta * x, -1)
        # Collapse the last axis into a sum of squares.
        return torch.sum(noisy * noisy, -1)

    inference = SNRE(prior, )
    observation = torch.tensor(np.array(y))
    samples = prior.sample((sbi_samples,))
    for i in tqdm(range(round)):
        theta = torch.Tensor(samples)
        # randomness='different' gives every vmapped call its own noise draw.
        simulated = torch.vmap(simulator, randomness='different')(theta)
        _ = inference.append_simulations(theta, simulated,).train()
        posterior = inference.build_posterior(sample_with="mcmc", mcmc_method='nuts').set_default_x(observation)
        if i < round - 1:
            # Proposals for the next round come from the current posterior;
            # skipped after the last round, where they would go unused.
            samples = posterior.sample((sbi_samples,), thin=1)

    # Two independent halves of posterior samples are needed for the CRPS pairs.
    samples = posterior.sample((test_sample * 2,), x=observation, thin=1)
    if plot:
        # The plot assumes the first three parameters are (beta1, beta2, sigma).
        rows = [
            {'beta1': float(sample[0]), 'beta2': float(sample[1]), 'sigma': float(sample[2])}
            for sample in samples
        ]
        sns.kdeplot(data=pd.DataFrame(rows), x='beta1', y='beta2')
        plt.show()
    print(m.beta)
    samples = jnp.array(samples)

    # Score sample i of the first half against sample i of the second half,
    # each pair with its own PRNG key split from a fixed root key.
    crps_keys = random.split(random.PRNGKey(0), test_sample)
    crpss = vmap(m.test_crps, in_axes=(0, 0, 0, None))(
        samples[:test_sample], samples[test_sample:], crps_keys, test_y
    )
    mean_crps = jnp.mean(crpss)
    print(mean_crps)

    result_dir = f'result/{model}_{alpha}_{n}_{g}_{h}/Torch4_{seed}_{s}_{round}'
    os.makedirs(result_dir, exist_ok=True)
    with open(f'{result_dir}/eval', 'w') as f:
        print(mean_crps, file=f)

    np.savez_compressed(f'{result_dir}/sample.npz', samples=np.array(samples))



# Invoke the click command only when this file is executed as a script,
# so importing the module does not start a run.
if __name__ == '__main__':
    main()