from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, UNet2DModel_G, UNet2DModel_H, DDPMPipeline_H

import torch
from diffusers import DiffusionPipeline
from utils import local_basis
from utils import compute_grsm_metric
import numpy as np
import random
from tqdm import tqdm
import os

model_id = "google/ddpm-cifar10-32"
while True:
    ##########settings#############
    weight = 'iso_exp/lsun_church_32batch/checkpoint-3200/unet/' # CelebA
    weight_default = "google/ddpm-ema-church-256" # CelebA
    # weight = "iso_exp/lsun_bedrooms_32batch/checkpoint-3000/unet/"
    # weight_default = "google/ddpm-ema-bedroom-256"
    # weight = 'output_jae/test_39/checkpoint-6000/unet' # CIFAR10
    # weight_default = 'output_jae/test_34/checkpoint-100/unet' # CIFAR10
    print('weight=', weight)
    print('weight default=', weight_default)

    close_sample_eps = 1e-1 # perturbation intensity
    seed_max = 100000 # max seed
    pooling_kernel = 8 # pooling kernel of h feature
    N = 100 # the number of pair
    num_topk = 102
    save_folder = f"geodesic/bedroom_N{N}_eps{close_sample_eps}_top{num_topk}/" # save_folder name
    ###############################
    # os.makedirs(save_folder, exist_ok=True)

    unet = UNet2DModel_H.from_pretrained(weight).to('cuda')
    unet_default = UNet2DModel_H.from_pretrained(weight_default).to('cuda')
    sampling_shape = (1,unet.config.in_channels, unet.config.sample_size, unet.config.sample_size)

    scheduler = DDIMScheduler.from_pretrained(model_id)
    pipeline = DDPMPipeline_H(
                        unet=unet,
                        scheduler=scheduler,
                    ).to('cuda')
    pipeline_default = DDPMPipeline_H(
                        unet=unet_default,
                        scheduler=scheduler,
                    ).to('cuda')


    def close_sample(sample, eps=close_sample_eps, DOT_THRESHOLD=0.999995):
        """Return a point near ``sample`` by interpolating a fraction ``eps`` toward random noise.

        A noise array ``v1`` with the same shape as ``sample`` is drawn from a
        standard normal distribution. The result is the spherical linear
        interpolation (slerp) from ``sample`` (t=0) toward ``v1`` (t=1),
        evaluated at ``t = eps``. When the two are nearly (anti-)parallel
        (``|cos angle| > DOT_THRESHOLD``), plain linear interpolation is used
        instead, because slerp divides by ``sin(angle)``, which is then ~0.

        Args:
            sample: torch.Tensor or numpy array to perturb.
            eps: interpolation fraction toward the noise. The default is the
                value of ``close_sample_eps`` when this function is defined.
            DOT_THRESHOLD: cosine cutoff above which linear interpolation
                replaces slerp.

        Returns:
            The perturbed point. A torch.Tensor input yields a torch.Tensor on
            the input's device; any other input yields a numpy array.
        """
        inputs_are_torch = isinstance(sample, torch.Tensor)
        if inputs_are_torch:
            input_device = sample.device
            v0 = sample.cpu().numpy()
            # Noise is drawn with torch on the input's device (same RNG stream as before).
            v1 = torch.randn_like(sample).cpu().numpy()
        else:
            # torch.randn_like only accepts tensors, so draw numpy noise for numpy input.
            v0 = np.asarray(sample)
            if not np.issubdtype(v0.dtype, np.floating):
                v0 = v0.astype(np.float64)
            v1 = np.random.standard_normal(v0.shape).astype(v0.dtype)
        t = eps

        # Cosine of the angle between v0 and v1.
        # np.linalg.norm with no axis/ord treats each array as one flattened vector.
        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            # Nearly parallel: sin(theta_0) ~ 0, so fall back to linear interpolation.
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(input_device)

        return v2

    # start = time()

    def _geodesic_metrics(model):
        """Return a numpy array of N geodesic metrics measured on ``model``.

        For each of N random seeds:
          1. ``local_basis`` is called with that seed.
          2. Its returned sample is perturbed with ``close_sample``.
          3. ``local_basis`` is called again on the perturbed sample.
          4. The two ``h_basis`` results are compared with ``compute_grsm_metric``
             (``d=num_topk``, ``metric_type='geodesic'``).
        """
        # Use this model's own config for the sampling shape.
        # Then the reference model is never sampled with the other model's shape.
        shape = (1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
        metrics = []
        for _pair in tqdm(range(N)):
            seed = random.randint(0, seed_max)
            sample, h_basis, _, _ = local_basis(model, seed=seed, pooling_kernel=pooling_kernel, shape=shape)
            perturbed = close_sample(sample)
            _, h_basis2, _, _ = local_basis(model, sample=perturbed, pooling_kernel=pooling_kernel, shape=shape)
            metrics.append(compute_grsm_metric(h_basis.cpu(), h_basis2.cpu(), d=num_topk, metric_type='geodesic'))
        return np.array(metrics)

    # Model under evaluation first, then the reference model (same order of RNG use as before).
    metric_list = _geodesic_metrics(unet)
    print("iso mean = ", np.mean(metric_list), "iso std = ",np.std(metric_list))

    metric_list_d = _geodesic_metrics(unet_default)
    print("base mean = ", np.mean(metric_list_d), "base std = ",np.std(metric_list_d))

    # np.save(os.path.join(save_folder, "iso"),metric_list)
    # np.save(os.path.join(save_folder, "base"),metric_list_d)




# end = time()
# print(end-start)
# sample_default, h_basis_default, s_default, x_basis_default = local_basis(unet_default, seed=seed, pooling_kernel=pooling_kernel)

# print("Allocated:", round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB', 'Cached:', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')
