import os
import cv2
import math
import numpy as np
from tqdm import tqdm
import imageio
import torch
import torchvision.utils as tvu
from torchvision import transforms

def get_beta_schedule(*, beta_start, beta_end, num_diffusion_timesteps):
    """Return a linearly spaced beta schedule.

    Args:
        beta_start: First beta value (inclusive).
        beta_end: Last beta value (inclusive).
        num_diffusion_timesteps: Number of evenly spaced betas to produce.

    Returns:
        A float64 numpy array of shape ``(num_diffusion_timesteps,)``.
    """
    schedule = np.linspace(
        start=beta_start,
        stop=beta_end,
        num=num_diffusion_timesteps,
        dtype=np.float64,
    )
    # Sanity check on the produced schedule's length.
    expected_shape = (num_diffusion_timesteps,)
    assert schedule.shape == expected_shape
    return schedule

class Diffusion(object):
    """Forward-noising helper built on a linear beta schedule.

    Only the beta schedule is stored; the cumulative products
    ``cumprod(1 - betas)`` are recomputed from ``self.betas`` on each
    ``sample_noise`` call.
    """

    def __init__(self, beta_start, beta_end, num_diffusion_timesteps):
        """Build and store the linear beta schedule.

        Args:
            beta_start: First beta value of the schedule.
            beta_end: Last beta value of the schedule.
            num_diffusion_timesteps: Number of betas in the schedule.
        """
        # float64 numpy array of shape (num_diffusion_timesteps,).
        self.betas = get_beta_schedule(
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps,
        )

    def sample_noise(self, image, end_step=10, stride=2):
        """Return a chain of progressively noised copies of ``image``.

        With ``a = cumprod(1 - betas)`` and ``k = stride * i``, each step
        computes ``x[i + 1] = sqrt(a[k]) * x[i] + sqrt(1 - a[k]) * e``, where
        ``e`` is fresh standard-normal noise shaped like ``image``. Every step
        starts from the *previous* frame, so noise compounds across steps
        (this is not the closed-form ``q(x_t | x_0)``).

        Args:
            image: Array with a ``.shape`` attribute (e.g. a float numpy image).
            end_step: Number of noising steps; values <= 0 return ``[image]``.
            stride: Schedule-index increment per step (default 2).

        Returns:
            List of ``max(end_step, 0) + 1`` arrays; element 0 is ``image``.

        Raises:
            ValueError: if ``stride`` is negative.
            IndexError: if the last schedule index ``stride * (end_step - 1)``
                is past the end of the schedule (e.g. ``end_step > 500`` for
                a 1000-step schedule with the default stride).
        """
        if stride < 0:
            # A negative index would silently wrap around to the schedule's end.
            raise ValueError(f"stride must be non-negative, got {stride}")
        x0 = image
        a = np.cumprod(1 - self.betas, axis=0)
        # Fail up front with a readable message instead of mid-loop.
        if end_step > 0 and stride * (end_step - 1) >= len(a):
            raise IndexError(
                f"end_step={end_step} with stride={stride} needs schedule index "
                f"{stride * (end_step - 1)}, but the schedule has only "
                f"{len(a)} timesteps"
            )
        x = [x0]
        for i in range(end_step):
            k = stride * i
            e = np.random.normal(0, 1.0, x0.shape)
            x.append(x[i] * math.sqrt(a[k]) + e * math.sqrt(1.0 - a[k]))
        return x

    def sample_rigid_noise(self, image, end_step=9, theta=0.5):
        """Return a chain of ``image`` with additive Gaussian noise accumulated.

        Each step adds ``theta * e`` to the previous frame, where ``e`` is fresh
        standard-normal noise shaped like ``image``. Frame ``i`` therefore
        carries noise with standard deviation ``theta * sqrt(i)``. The beta
        schedule is not used.

        NOTE: the original author's note says to keep ``end_step`` at 9 (or 8)
        for now.

        Args:
            image: Array with a ``.shape`` attribute (e.g. a float numpy image).
            end_step: Number of noising steps; values <= 0 return ``[image]``.
            theta: Scale of the noise added at each step (default 0.5).

        Returns:
            List of ``max(end_step, 0) + 1`` arrays; element 0 is ``image``.
        """
        x0 = image
        x = [x0]
        for _ in range(end_step):
            e = np.random.normal(0, 1.0, x0.shape)
            x.append(x[-1] + theta * e)
        return x

def create_diffusion_runner(beta_start, beta_end, num_diffusion_timesteps):
    """Construct a ``Diffusion`` helper with a linear beta schedule.

    Convenience factory; equivalent to calling ``Diffusion(...)`` directly
    with the same three arguments.
    """
    return Diffusion(
        beta_start=beta_start,
        beta_end=beta_end,
        num_diffusion_timesteps=num_diffusion_timesteps,
    )

if __name__ == '__main__':
    # Demo: load one image, run the forward-noising chain, and save a strip
    # of evenly spaced snapshots to noisetest.png.
    import matplotlib.pyplot as plt

    image_path = '/path/to/DR/GenImage/sdv4/train/ai/crop/000_sdv4_00000.png'
    image = cv2.imread(image_path)
    # cv2.imread returns None (instead of raising) for missing/unreadable
    # files, which would otherwise surface as an opaque cvtColor error.
    if image is None:
        raise FileNotFoundError(f"could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (256, 256))
    image = image / 255.0  # scale 8-bit pixel values to [0, 1]

    runner = create_diffusion_runner(0.0001, 0.02, 1000)
    image_list = runner.sample_noise(image, 400)
    print(len(image_list))

    # sample_noise(image, 400) yields 401 frames, but a 1x9 grid only holds 9
    # panels; show 9 evenly spaced frames (always the first and the last).
    num_panels = min(9, len(image_list))
    frame_ids = np.linspace(0, len(image_list) - 1, num_panels).round().astype(int)

    # Create the figure first so figsize and axis('off') apply to the figure
    # actually drawn, rather than spawning a separate empty one.
    fig = plt.figure(figsize=(10, 10))
    for panel, frame_id in enumerate(frame_ids, start=1):
        ax = fig.add_subplot(1, num_panels, panel)
        ax.axis('off')
        # Noised frames leave [0, 1]; clip for display so imshow doesn't warn.
        ax.imshow(np.clip(image_list[frame_id], 0.0, 1.0))

    # Save before show(): once the window is closed the figure may be
    # released, and a later savefig would write a blank canvas.
    fig.savefig("noisetest.png")
    plt.show()
    # imageio.mimsave("ai.gif", [np.uint8(image * 255) for image in image_list], 'GIF', fps=100, loop=1)  
