import torch

from diffusion_arithmetics.noise_learning.distance_classification import get_noise_sample_by_distance_classification
# from diffusion_arithmetics.utils import plot_diffusion


def convert(tensor: torch.Tensor) -> torch.Tensor:
    """Min-max normalise ``tensor`` into the range [0, 1].

    The smallest element maps to 0 and the largest to 1. A constant tensor
    (zero value range) is mapped to all zeros instead of producing NaNs
    from a 0/0 division.

    Args:
        tensor: Tensor of any non-empty shape. The input is not modified.

    Returns:
        A new tensor of the same shape with values rescaled to [0, 1].
    """
    shifted = tensor - tensor.min()
    value_range = shifted.max()
    # When every element is equal, ``shifted`` is all zeros; dividing by 1
    # keeps it zero (and keeps true-division dtype promotion) instead of NaN.
    safe_range = torch.where(value_range == 0, torch.ones_like(value_range), value_range)
    return shifted / safe_range


def _lowest_variance_images(samples: torch.Tensor, count: int = 10) -> list:
    """Return the ``count`` lowest-variance samples, each min-max normalised.

    Samples are taken along the first dimension of ``samples`` (after moving
    it to the CPU). They are ordered by ``Tensor.var()`` over all of their
    elements; ties keep their original order because ``sorted`` is stable.

    Args:
        samples: Batch of samples; iterated along dim 0.
        count: How many of the lowest-variance samples to keep.

    Returns:
        A list of at most ``count`` tensors, each rescaled to [0, 1] by
        ``convert``.
    """
    by_variance = sorted(samples.cpu(), key=lambda sample: sample.var())
    return [convert(sample) for sample in by_variance[:count]]


imagenet_samples = torch.load("experiments/angles/openai_imagenet_samples.pth")
imagenet_lowest_variance = _lowest_variance_images(imagenet_samples)
# plot_diffusion(imagenet_lowest_variance)

cifar10_samples = torch.load("experiments/angles/openai_cifar10_samples.pth")
cifar10_lowest_variance = _lowest_variance_images(cifar10_samples)
# plot_diffusion(cifar10_lowest_variance)

cifar10_noise = torch.load("experiments/angles/openai_cifar10_noise.pth")
# The CIFAR-10 samples are already loaded above and nothing since has
# modified them, so reuse them instead of reading the same file again.
# Clones are passed, presumably so the callee cannot alter the tensors kept
# here (TODO confirm whether it mutates its inputs).
get_noise_sample_by_distance_classification(
    noises=cifar10_noise.clone(),
    samples=cifar10_samples.clone(),
    examples_limit=10,
    top_n_fits=4,
)
