import torch
from heterogeneity_gym_hsp90.src.heterogeneity_gym.hsp90 import hsp90_jax as hsp90
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import jax.numpy as jnp
import numpy as np
from time import time
import cv2
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import os
import jax.scipy as jsc
import numpyro.distributions as dist
from jax import random, vmap
import sys

# NOTE(review): not referenced in the visible part of this file.
image_width_in_pixels = 128
# Reference dataset size: HSP90Data scales its 8000/1000 split counts by n / all_N.
all_N = 10000


class HSP90Data(Dataset):
    """Grayscale HSP90 image dataset backed by ``{img_dir}/{i}.png`` files.

    Indices ``0 .. n-1`` are scanned and every existing file is kept. When
    ``spl`` is true, indices are partitioned into contiguous ranges:
    ``train`` = ``[0, n*8000//all_N)``, ``valid`` = the next ``n*1000//all_N``
    indices, ``test`` = everything after that. An unrecognised ``split`` value
    (or ``spl=False``) keeps all indices.
    """

    def __init__(self, img_dir='../HSP90', split = 'train', n=10000, spl = False):
        """Collect the image paths belonging to ``split``.

        Args:
            img_dir: directory containing ``{i}.png`` files.
            split: one of ``'train'``, ``'valid'``, ``'test'``; only used when ``spl``.
            n: number of candidate indices ``0 .. n-1`` to scan.
            spl: if true, restrict to the index range of ``split``; otherwise
                every existing file is included regardless of ``split``.
        """
        self.files = []
        training_N = n * 8000 // all_N
        valid_N = n * 1000 // all_N
        for i in range(n):
            if spl and not self._in_split(i, split, training_N, valid_N):
                continue
            file = os.path.join(img_dir, f'{i}.png')
            # Missing indices are skipped silently, so len() may be < n.
            if not os.path.exists(file):
                continue
            self.files.append(file)

    @staticmethod
    def _in_split(i, split, training_N, valid_N):
        """Return whether index ``i`` belongs to ``split`` (unknown splits keep everything)."""
        if split == 'train':
            return i < training_N
        if split == 'valid':
            return training_N <= i < training_N + valid_N
        if split == 'test':
            return i >= training_N + valid_N
        return True

    def __len__(self):
        """Return the number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load image ``idx`` as grayscale and map pixel values [0, 255] onto [-20, 20].

        Raises:
            OSError: if OpenCV cannot read the file (``cv2.imread`` returns None).
        """
        path = self.files[idx]
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"could not read image file {path!r}")
        return (img/255) * 40 - 20

class HSP90NoRot:
    """HSP90 simulation task in which every image is rendered from one fixed pose.

    Holds train/valid/test ``HSP90Data`` datasets and shuffled ``DataLoader``s
    read from ``dir``, and an ``hsp90.HSP90_Model`` simulator whose latent
    density is ``DiscreteDensity``. All renders use the pose vector
    ``[0, 0, 0, 90, 90, 0]`` (see ``_fixed_poses``).
    """

    d = 1  # presumably the dimensionality of theta (oracle builds (N, 1) thetas) — confirm

    def __init__(self, dir = '../HSP90', batch_size = 64, noise_std = 10., n = 10000, *args, spl = False, **kwargs):
        """Build datasets, loaders and the simulator.

        Args:
            dir: directory containing the ``{i}.png`` images.
            batch_size: batch size of all three loaders (all shuffle).
            noise_std: passed to the simulator as ``noise_strength``.
            n: number of candidate image indices scanned by ``HSP90Data``.
            *args: accepted and ignored.
            spl: forwarded to ``HSP90Data``. With the default ``False`` the
                split names are ignored and train/valid/test all contain the
                same files; pass ``True`` for disjoint contiguous splits.
            **kwargs: accepted and ignored.
        """
        # NOTE(review): ``self.valid_data`` / ``self.test_data`` shadow the
        # generator methods of the same name defined below, so on an instance
        # they refer to these Datasets. Kept for backward compatibility.
        self.train_data = HSP90Data(split='train', img_dir=dir, n=n, spl=spl)
        self.valid_data = HSP90Data(split='valid', img_dir=dir, n=n, spl=spl)
        self.test_data = HSP90Data(split='test', img_dir=dir, n=n, spl=spl)
        self.train_loader = DataLoader(self.train_data, batch_size=batch_size, shuffle=True)
        self.valid_loader = DataLoader(self.valid_data, batch_size=batch_size, shuffle=True)
        self.test_loader = DataLoader(self.test_data, batch_size=batch_size, shuffle=True)

        self.model = hsp90.HSP90_Model(latent_density=DiscreteDensity(), pixel_size=1.1, defocus_range=(1000, 2000), noise_strength=noise_std)

    @staticmethod
    def _fixed_poses():
        """Return the single (1, 6) pose used for every render: zeros with entries 3 and 4 set to 90."""
        poses = jnp.zeros((1, 6))
        return poses.at[:, 3].set(90.).at[:, 4].set(90.)

    def _render(self, theta, key):
        """Render image(s) for ``theta`` from the fixed pose and return only the images."""
        images, _, _, _ = self.model.render_images_from_interpolated_latent(theta, key, poses=self._fixed_poses())
        return images

    @staticmethod
    def _energy_score(y1, y2, y):
        """Two-sample energy-score estimate of samples ``y1``, ``y2`` against observation(s) ``y``.

        Computes ``mean(|y1-y|/2 + |y2-y|/2 - |y1-y2|/2)``. Norms against ``y``
        are taken over axes (1, 2); ``|y1-y2|`` is the norm over all entries.
        Assumes image stacks shaped (batch, H, W) — TODO confirm with the simulator.
        """
        return jnp.mean(jnp.linalg.norm(y1 - y, axis = (1, 2)) / 2 + jnp.linalg.norm(y2 - y, axis = (1, 2)) / 2 - jnp.linalg.norm(y1 - y2) / 2)

    @staticmethod
    def _cycle(loader):
        """Yield batches from ``loader`` forever, re-iterating it each time it is exhausted.

        Only ``StopIteration`` triggers a restart; any other error from the
        loader propagates. An empty loader ends in ``RuntimeError`` (PEP 479).
        """
        iterator = iter(loader)
        while True:
            try:
                batch = next(iterator)
            except StopIteration:
                iterator = iter(loader)
                batch = next(iterator)
            yield batch

    def link(self, theta):
        """Map unconstrained ``theta`` into (0, 19) via ``19 * sigmoid(theta)``.

        19 is the largest index the default ``DiscreteDensity`` (N=20) can return.
        """
        return jsc.special.expit(theta) * 19

    def log_prior(self, theta):
        """Return the sum of independent Normal(loc=0, scale=2) log-densities over ``theta``."""
        return jnp.sum(dist.Normal(0, 2).log_prob(theta))

    def log_likelihood(self, theta, y):
        raise NotImplementedError()

    def log_likelihoods(self, theta, y):
        raise NotImplementedError()

    def test_log_likelihoods(self, theta, y):
        raise NotImplementedError()

    def sample_datapoint(self, theta, rng_key):
        """Render one simulated image stack for ``link(theta)[0]`` using the second half of ``rng_key``'s split."""
        theta = self.link(theta)[0]
        _, render_key = random.split(rng_key)
        return self._render(theta, render_key)

    def sample_prior(self, key):
        raise NotImplementedError()

    def data(self, ):
        """Endless generator over training batches (reshuffled on every pass)."""
        yield from self._cycle(self.train_loader)

    def valid_data(self, ):
        """Endless generator over validation batches.

        NOTE(review): shadowed on instances by the ``valid_data`` Dataset set in ``__init__``.
        """
        yield from self._cycle(self.valid_loader)

    def test_data(self, ):
        """Endless generator over test batches.

        NOTE(review): shadowed on instances by the ``test_data`` Dataset set in ``__init__``.
        """
        yield from self._cycle(self.test_loader)

    def likelihood_parameters(self, theta):
        raise NotImplementedError()

    def M(self, theta = None):
        raise NotImplementedError()

    def validate_crps(self, theta1, theta2, key, valid_y, link = True):
        """Energy score of two simulated samples (from ``theta1``, ``theta2``) against ``valid_y``.

        Args:
            theta1, theta2: latent parameters; passed through ``link`` when ``link`` is true.
            key: PRNG key; split into 4 and the last two subkeys drive the two renders.
            valid_y: observed images (anything ``jnp.array`` accepts, e.g. a loader batch).
            link: whether to apply ``self.link`` to the thetas first.
        """
        valid_y = jnp.array(valid_y)
        if link:
            theta1 = self.link(theta1)
            theta2 = self.link(theta2)
        _, _, key1, key2 = random.split(key, 4)
        y1 = self._render(theta1, key1)
        y2 = self._render(theta2, key2)
        return self._energy_score(y1, y2, valid_y)

    def plot(self, theta1, theta2, y3):
        """Show a 3x6 grid: row 0 the theta1 render, row 1 the theta2 render, row 2 the first six images of ``y3``."""
        fig, axes = plt.subplots(3, 6, sharex=True, sharey=True, figsize=(8, 5))
        theta1 = self.link(theta1)
        theta2 = self.link(theta2)
        _, _, key1, key2 = random.split(random.PRNGKey(0), 4)
        y1 = np.array(self._render(theta1, key1))
        y2 = np.array(self._render(theta2, key2))
        y3 = np.array(y3)
        for ax in axes[0]:
            ax.imshow(y1[0])
        for ax in axes[1]:
            ax.imshow(y2[0])
        for i, ax in enumerate(axes[2]):
            ax.imshow(y3[i])
        plt.show()
        # Close (not just clear) so repeated calls do not accumulate open figures.
        plt.close(fig)

    def test_crps(self, theta1, theta2, key, test_y):
        """Energy score against ``test_y``; identical to ``validate_crps`` with ``link=True``."""
        return self.validate_crps(theta1, theta2, key, test_y, link=True)

    def oracle(self, N = 10):
        """Print the energy score obtained when both thetas are drawn from the true ``DiscreteDensity``.

        Draws ``N`` theta pairs (no ``link`` applied), vmaps ``validate_crps``
        over them for every validation batch, prints the per-batch mean and
        finally the mean over batches.
        """
        density = DiscreteDensity()
        theta1 = np.expand_dims(density.sample(N), 1)
        theta2 = np.expand_dims(density.sample(N), 1)

        keys = random.split(random.PRNGKey(0), N)
        batch_scores = []
        for batch in self.valid_loader:
            crps = vmap(self.validate_crps, in_axes=(0, 0, 0, None, None))(theta1, theta2, keys, batch, False)
            batch_mean = jnp.mean(crps)
            batch_scores.append(batch_mean)
            print('oracle', batch_mean)

        print("The oracle stats is", jnp.mean(jnp.array(batch_scores)))




class DiscreteDensity:
    def __init__(self, N=20):
        self.N = N
        self.cat = np.array([1, 2, 4, 7, 10, 12, 10, 7, 4, 2,
                     1, 2, 3, 5, 7, 8, 7, 5, 3, 2])
        self.probs = self.cat / np.sum(self.cat)
    def sample(self, num, shuffle=True):
        return np.random.choice(np.arange(self.N), num, p=self.probs)

if __name__ == '__main__':
    # Render 10000 fixed-pose HSP90 images at five noise strengths (1..5) and
    # write each set as 8-bit grayscale PNGs to ../HSP90_after_{k}/{index}.png.
    # For a given index all five sets share the same latent sample and key.
    print(time())
    noise_strengths = (1., 2., 3., 4., 5.)
    models = [
        hsp90.HSP90_Model(latent_density=DiscreteDensity(), pixel_size=1.1,
                          defocus_range=(1000, 2000), noise_strength=strength)
        for strength in noise_strengths
    ]
    print(time(), 'loaded')

    out_dirs = [f"../HSP90_after_{k}" for k in range(1, len(noise_strengths) + 1)]
    for out_dir in out_dirs:
        # cv2.imwrite returns False (it does not raise) when the directory is missing.
        os.makedirs(out_dir, exist_ok=True)

    num_images = 10000
    # Render intensities are clipped to [mi, ma] and mapped linearly onto [0, 255].
    mi, ma = -20, 20
    # Same fixed pose for every image: zeros with entries 3 and 4 set to 90.
    poses = jnp.zeros((1, 6)).at[:, 3].set(90.).at[:, 4].set(90.)

    density = DiscreteDensity()
    rng_keys = random.split(random.PRNGKey(0), num_images)
    thetas = []
    for index in range(num_images):
        theta = density.sample(1)
        thetas.append(theta)
        # Only the first subkey is used; every noise level renders with it.
        render_key, _ = random.split(rng_keys[index])
        for model, out_dir in zip(models, out_dirs):
            images, _, _, _ = model.render_images_from_latent(theta, render_key, poses=poses)
            scaled = (jnp.clip(images[0], mi, ma) - mi) / (ma - mi) * 255
            path = f"{out_dir}/{index}.png"
            # Values lie in [0, 255]; uint8 truncation matches the former int cast.
            if not cv2.imwrite(path, np.array(scaled, dtype=np.uint8)):
                raise OSError(f"cv2.imwrite failed for {path}")

    # NOTE(review): labels are saved only next to the noise-1 images, although
    # all five directories were rendered from the same thetas.
    np.savez_compressed('../HSP90_after_1/label.npz', theta=np.array(thetas))