import torch
import torchvision
import torchvision.transforms as transforms
import pandas as pd
from openpyxl import Workbook
from typing import Tuple
import torch.distributions.normal as norm
import numpy as np
from torch.distributions import Normal

# Schedule scale used by GaussianDiffusionTrainer: alpha_t = 1 - k * beta_t,
# and the injected Gaussian noise is multiplied by sqrt(1 / k).
k = 0.5

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 on each
# of the three channels maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split read from a local copy (download=False, so the data
# must already be present under this root).
trainset = torchvision.datasets.CIFAR10(
    root='/idas/users/liusirui/caculate_steps/data/',
    train=True,
    transform=transform,
    download=False,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

class GaussianDiffusionTrainer:
    """Forward (noising) process of a diffusion model with a k-scaled schedule.

    Uses the module-level constant ``k``: the per-step factor is
    ``alpha_t = 1 - k * beta_t`` and the injected noise is scaled by
    ``sqrt(1 / k)``.
    """

    def __init__(self, beta: Tuple[float, float], T: int):
        """Precompute the noise schedule on the module-level ``device``.

        Args:
            beta: ``(beta_start, beta_end)``; ``beta_t`` is linearly spaced
                between them over ``T`` steps.
            T: number of diffusion steps.
        """
        self.T = T
        self.beta_t = torch.linspace(*beta, T, dtype=torch.float32, device=device)
        alpha_t = 1.0 - k * self.beta_t
        alpha_t_bar = torch.cumprod(alpha_t, dim=0)
        # Per-step coefficients of x_0 and of the (k-scaled) noise in x_t.
        self.signal_rate = torch.sqrt(alpha_t_bar)
        self.noise_rate = torch.sqrt(1.0 - alpha_t_bar)

    def forward(self, x_0, t):
        """Sample x_t directly from x_0 for a single step index ``t``.

        x_t = signal_rate[t] * x_0 + noise_rate[t] * sqrt(1 / k) * eps,
        with eps drawn from ``torch.randn_like(x_0)``.

        Args:
            x_0: batch of clean inputs; dim 0 is the batch dimension.
            t: integer step index in ``[0, T)``, applied to every sample.

        Returns:
            Tensor with the same shape as ``x_0``.

        Raises:
            ValueError: if ``t`` is outside ``[0, T)`` (otherwise ``gather``
                would fail later, as a device-side assert on CUDA).
        """
        if not 0 <= t < self.T:
            raise ValueError(f"t must be in [0, {self.T}), got {t}")
        t = torch.full((x_0.shape[0],), t, dtype=torch.int64, device=x_0.device)
        # Draw the noise on x_0's own device/dtype, consistent with `t` above,
        # instead of forcing the module-level device.
        epsilon = torch.randn_like(x_0)
        scaling_factor = torch.sqrt(torch.tensor(1 / k, device=x_0.device))
        x_t = (extract(self.signal_rate, t, x_0.shape) * x_0 +
               extract(self.noise_rate, t, x_0.shape) * epsilon * scaling_factor)
        return x_t

def extract(v, i, shape):
    """Pick per-sample coefficients from ``v`` and reshape them for broadcasting.

    Args:
        v: 1-D tensor of per-step coefficients, indexed along dim 0.
        i: 1-D integer tensor of step indices, one per sample.
        shape: shape of the tensor the result will multiply; only its length
            is used.

    Returns:
        float32 tensor on ``i``'s device with shape ``(len(i), 1, ..., 1)``
        and ``len(shape)`` dimensions.
    """
    # torch.gather requires the source and index to share a device, so move
    # the index to v first; the result is then returned on i's device.
    out = torch.gather(v, index=i.to(v.device), dim=0)
    out = out.to(device=i.device, dtype=torch.float32)
    return out.view([i.shape[0]] + [1] * (len(shape) - 1))

# Measure how far the per-pixel distribution of x_t (at step t) is from a
# Gaussian: histogram every pixel over 8 bins whose edges are -3..3 multiples
# of sqrt(1/k), compare each bin's empirical frequency with the standard
# normal mass of that bin, sum |difference| over bins per pixel, and report
# the L2 norm of that per-pixel map.
trainer = GaussianDiffusionTrainer((0.0001, 0.02), 1000)

# Per-pixel bin counts laid out as (channel, height, width, bin); int64 keeps
# the counts exact.
pixel_counts = torch.zeros(3, 32, 32, 8, dtype=torch.int64, device=device)

t = 999
total_images = 0

# Interior bin edges -3, -2, ..., 3 times sqrt(1/k), computed in float32 like
# the original thresholds. bucketize(right=True) maps x < e0 -> 0,
# e[j-1] <= x < e[j] -> j, and x >= e6 -> 7 (the same half-open bins).
k_tensor = torch.tensor(1/k, dtype=torch.float32)
bin_edges = (torch.arange(-3, 4, dtype=torch.float32) * torch.sqrt(k_tensor)).to(device)
num_bins = bin_edges.numel() + 1

for images, _ in trainloader:
    images = images.to(device)
    total_images += images.shape[0]
    x_t = trainer.forward(images, t)
    # Accumulate counts per batch instead of concatenating every x_t: avoids
    # quadratic copying and holding the whole noised dataset in memory.
    bins = torch.bucketize(x_t, bin_edges, right=True)
    for b in range(num_bins):
        pixel_counts[..., b] += (bins == b).sum(dim=0)

if total_images == 0:
    raise RuntimeError("trainloader yielded no images; cannot build a histogram")

# Standard-normal probability mass of each bin, computed once in float64:
# [Phi(-3), Phi(-2)-Phi(-3), ..., 1-Phi(3)].
std_normal = Normal(torch.tensor(0.0, dtype=torch.float64),
                    torch.tensor(1.0, dtype=torch.float64))
edge_cdf = std_normal.cdf(torch.arange(-3, 4, dtype=torch.float64))
edge_cdf = torch.cat([edge_cdf.new_zeros(1), edge_cdf, edge_cdf.new_ones(1)])
expected_probs = (edge_cdf[1:] - edge_cdf[:-1]).to(device)

ratios = pixel_counts.to(torch.float64) / total_images
# Per-pixel sum over bins of |empirical - expected|, shape (3, 32, 32).
result_tensor = (ratios - expected_probs).abs().sum(dim=-1).to(torch.float32)

norm_result = torch.norm(result_tensor, p=2)
print(norm_result)