import click
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import scipy
import torch
from sbi.inference import SNPE
import cv2
from img2vec_pytorch import Img2Vec
from PIL import Image

class DiscreteDensity:
    """Fixed categorical distribution over the integers ``0 .. N-1``.

    The unnormalised weights form two bumps (peaks at indices 5 and 15);
    they are normalised to probabilities once, at construction time.
    """

    def __init__(self, N=20):
        # Number of categories. The weight table below has exactly 20
        # entries, so only N=20 yields a valid distribution when sampling.
        self.N = N
        first_bump = [1, 2, 4, 7, 10, 12, 10, 7, 4, 2]
        second_bump = [1, 2, 3, 5, 7, 8, 7, 5, 3, 2]
        self.cat = np.array(first_bump + second_bump)
        total = np.sum(self.cat)
        self.probs = self.cat / total

    def sample(self, num, shuffle=True):
        """Draw ``num`` i.i.d. category indices using NumPy's global RNG.

        ``shuffle`` is accepted for interface compatibility but is unused.
        """
        support = np.arange(self.N)
        return np.random.choice(support, num, p=self.probs)


def _iter_embeddings(img2vec, directory, n):
    """Yield ``(index, feature_vector)`` for each readable ``{index}.png`` in *directory*.

    Indices ``0 .. n-1`` are scanned in order. Files that are missing, or that
    ``cv2.imread`` cannot decode, are skipped, so the yielded indices may have gaps.
    """
    for i in tqdm(range(n)):
        file = os.path.join(directory, f'{i}.png')
        if not os.path.exists(file):
            continue
        img = cv2.imread(file)
        if img is None:
            # cv2.imread signals an unreadable/corrupt file by returning None.
            click.echo(f'warning: could not read {file}, skipping', err=True)
            continue
        # NOTE(review): cv2 loads channels in BGR order while PIL treats a
        # 3-channel array as RGB. Simulated and observed images go through the
        # same path, so the convention is kept to avoid changing the embeddings.
        vec = img2vec.get_vec(Image.fromarray(img), tensor=False)
        yield i, vec


@click.command()
@click.option('--seed', default=0)
@click.option('--n', default=10000)
@click.option('--img_dir', default='../HSP90_sim')
@click.option('--data_dir', default='../HSP90_fix')
@click.option('--plot', is_flag = True)
@click.option('--samples', default = 100)
@click.option('--noise', default = 5.0)
@click.option('--device', default = 'cuda')
def main(seed, n, img_dir, data_dir, plot, samples, noise, device):
    """Train an SNPE posterior on simulated images and evaluate it on observed ones.

    Steps:
      1. Embed ``img_dir/{i}.png`` (i < n) with ResNet-50 and pair each
         embedding with its label from ``img_dir/label.npz['theta']``.
      2. Train an SNPE (NSF) density estimator on (theta, embedding) pairs.
      3. For each ``data_dir/{i}.png``, draw ``samples`` posterior samples and
         keep the first draw and the sample mean.
      4. KS-test both collections against a ``DiscreteDensity`` reference
         sample and write results under ``result/HSP90NoRot_{noise}_{n}/SBI_{seed}/``.

    ``noise`` is only used to name the results directory.
    """
    # Seed both RNGs: torch drives training and posterior sampling, NumPy
    # drives the DiscreteDensity reference sample used by the KS tests.
    torch.manual_seed(seed)
    np.random.seed(seed)
    prior = torch.distributions.Normal(torch.tensor(0.).to(device), 1.)
    img2vec = Img2Vec(model='resnet50', cuda=(device == 'cuda'))

    # --- Training set: embeddings of simulated images and their labels. ---
    train_indices = []
    train_vecs = []
    for i, vec in _iter_embeddings(img2vec, img_dir, n):
        train_indices.append(i)
        train_vecs.append(vec)
    if not train_vecs:
        raise click.ClickException(f'no readable images found in {img_dir}')
    # Stack first: building a tensor from a list of ndarrays is very slow.
    ys = torch.as_tensor(np.stack(train_vecs), dtype=torch.float32).to(device)

    label_file = os.path.join(img_dir, 'label.npz')
    with np.load(label_file) as labels:
        theta = labels['theta']
    # Take the labels of the images that were actually loaded. Skipped files
    # would otherwise shift every later label. Then map
    # label -> logit((label + 1) / 21), the inverse of the expit(.) * 21 - 1
    # transform applied to the posterior samples below.
    # NOTE(review): assumes labels lie in [0, 19] so the logit is finite.
    theta = torch.as_tensor(theta, dtype=torch.float32)[train_indices]
    theta = torch.logit((theta + 1) / 21).to(device)

    inference = SNPE(prior, density_estimator='nsf', device=device)
    inference.append_simulations(theta, ys, proposal=None).train()
    posterior = inference.build_posterior()

    # --- Evaluation: posterior samples for each observed image. ---
    one_samples = []
    mean_samples = []
    for _, vec in _iter_embeddings(img2vec, data_dir, n):
        x = torch.as_tensor(vec, dtype=torch.float32).to(device)
        draws = posterior.sample((samples,), x=x, show_progress_bars=False)
        draws = draws.detach().cpu().numpy()
        one_samples.append(draws[0][0])      # first posterior draw
        mean_samples.append(np.mean(draws))  # mean over all draws
    if not one_samples:
        raise click.ClickException(f'no readable images found in {data_dir}')

    # Undo the logit transform applied to theta during training.
    one_samples = scipy.special.expit(np.array(one_samples)) * 21 - 1
    mean_samples = scipy.special.expit(np.array(mean_samples)) * 21 - 1
    if plot:
        sns.kdeplot(data=one_samples,)
        plt.show()
        plt.clf()
        sns.kdeplot(data=mean_samples,)
        plt.show()

    result_dir = f'result/HSP90NoRot_{noise}_{n}/SBI_{seed}/'
    os.makedirs(result_dir, exist_ok=True)
    true_y = DiscreteDensity().sample(10000)
    # Compute each KS test once; the result is both printed and saved.
    ks_one = scipy.stats.kstest(true_y, one_samples)
    ks_mean = scipy.stats.kstest(true_y, mean_samples)
    print(ks_one)
    print(ks_mean)

    np.savez_compressed(os.path.join(result_dir, 'sample.npz'),
                        one_samples=one_samples, mean_samples=mean_samples)

    with open(os.path.join(result_dir, 'eval'), 'w', encoding='utf-8') as f:
        print(ks_one.statistic, ks_one.pvalue, file=f)
        print(ks_mean.statistic, ks_mean.pvalue, file=f)

    print(np.mean(one_samples))
    print(np.mean(mean_samples))



if __name__ == "__main__":
    # Script entry point: the click command reads its options from sys.argv.
    main()