from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel, UNet2DModel, UNet2DModel_G, UNet2DModel_H, DDPMPipeline_H

import torch
from diffusers import DiffusionPipeline
from utils import local_basis
from utils import compute_grsm_metric
import numpy as np
import random
from tqdm import tqdm
import os
from accelerate import Accelerator

# Accelerate supplies this process's device and the total number of processes.
accelerator = Accelerator()
torch_device = accelerator.device
ngpus = accelerator.num_processes

# dataset = 'cifar10'
dataset = 'celeba'
# Basis compared in the measurement loop: 'x' selects x_basis, any other
# value selects h_basis.
space = 'h'

########## settings #############
# All dataset-dependent values live in one branch so they cannot drift apart.
# Any value other than 'celeba' falls through to the CIFAR10 configuration.
if dataset == 'celeba':
    model_id = "google/ddpm-celebahq-256"
    weight = 'output_jae/test_18/checkpoint-4000/unet' #"output_jae/test_44/checkpoint-8700/unet/" # CelebA
    weight_default = 'output_jae/test_22/checkpoint-1000/unet' # CelebA
    pooling_kernel = 8 # pooling kernel of h feature
    num_topk = 51
    save_tag = 'celeba'
else:
    model_id = "google/ddpm-cifar10-32"
    weight = 'output_jae/test_40/checkpoint-7000/unet' # CIFAR10
    weight_default = 'output_jae/test_34/checkpoint-100/unet' # CIFAR10
    pooling_kernel = 4
    num_topk = 25
    save_tag = 'cifar'

if accelerator.is_local_main_process:
    print('weight=', weight)
    print('weight default=', weight_default)

close_sample_eps = 1e-4 # perturbation intensity
# seed_max = 100000 # max seed
N = 100 # the number of pair

# The folder label follows the selected dataset; it used to say "cifar" even
# for CelebA runs. The CIFAR10 path is unchanged.
save_folder = f"geodesic/{save_tag}_N{N}_eps{close_sample_eps}_top{num_topk}/" # save_folder name

# The N pairs are split evenly across processes; an explicit error (rather
# than an assert) still fires when Python runs with -O.
if N % ngpus != 0:
    raise ValueError(
        f"N={N} must be divisible by the number of processes ({ngpus})"
    )
NperGPU = N // ngpus
###############################
# os.makedirs(save_folder, exist_ok=True)

# Load both checkpoints onto the device chosen by Accelerate. The initial
# noise for sampling is created on `torch_device`, so the models are placed on
# the same device instead of the literal 'cuda'.
unet = UNet2DModel_H.from_pretrained(weight).to(torch_device)
unet_default = UNet2DModel_H.from_pretrained(weight_default).to(torch_device)

# Batch of one sample; channel count and spatial size come from the UNet config.
sampling_shape = (
    1,
    unet.config.in_channels,
    unet.config.sample_size,
    unet.config.sample_size,
)
if accelerator.is_local_main_process:
    print("sampling shape = ", sampling_shape)
    print(f"{N} samples")
    print(f"Top - {num_topk}, close {close_sample_eps}")

# One scheduler, configured for 10 inference steps, is shared by both pipelines.
scheduler = DDIMScheduler.from_pretrained(model_id)
scheduler.set_timesteps(num_inference_steps=10)

# NOTE(review): neither pipeline object is referenced in the visible part of
# this script; they are kept so the module-level names remain available.
pipeline = DDPMPipeline_H(
    unet=unet,
    scheduler=scheduler,
).to(torch_device)
pipeline_default = DDPMPipeline_H(
    unet=unet_default,
    scheduler=scheduler,
).to(torch_device)


def close_sample(sample, eps=close_sample_eps, DOT_THRESHOLD=0.999995):
    """Return a point close to ``sample`` by interpolating toward fresh noise.

    Standard-normal noise with the same shape as ``sample`` is drawn and the
    result is the spherical linear interpolation between ``sample`` and that
    noise at fraction ``eps``. When the two are nearly (anti-)parallel the
    slerp weights divide by a value close to zero, so plain linear
    interpolation is used instead.

    Args:
        sample: ``torch.Tensor`` or NumPy array to perturb. Non-floating
            NumPy input is promoted to float64.
        eps: interpolation fraction; 0 keeps ``sample`` and 1 reaches the noise.
        DOT_THRESHOLD: absolute cosine similarity above which the linear
            fallback is used.

    Returns:
        A ``torch.Tensor`` on the input's device when ``sample`` is a tensor,
        otherwise a NumPy array.
    """
    inputs_are_torch = isinstance(sample, torch.Tensor)
    t = eps

    if inputs_are_torch:
        input_device = sample.device
        # Noise is drawn with torch on the input's device, then both operands
        # move to NumPy for the interpolation math.
        v1 = torch.randn_like(sample).cpu().numpy()
        v0 = sample.detach().cpu().numpy()
    else:
        # Previously torch.randn_like was called unconditionally, so non-tensor
        # input raised before this branch could be used.
        v0 = np.asarray(sample)
        if not np.issubdtype(v0.dtype, np.floating):
            v0 = v0.astype(np.float64)
        v1 = np.random.standard_normal(v0.shape).astype(v0.dtype, copy=False)

    # Cosine of the angle between the two arrays, taken over all elements.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t  # angle to rotate away from v0
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)

    return v2

def forward(sample, model=None):
    """Denoise ``sample`` over every scheduled timestep except the last one.

    Args:
        sample: starting tensor, fed to the model unchanged at the first step.
        model: network that predicts the residual. Defaults to the
            module-level ``unet`` so existing ``forward(sample)`` calls behave
            exactly as before.

    Returns:
        The ``prev_sample`` from the last executed scheduler step, or
        ``sample`` itself if no step is executed.
    """
    if model is None:
        model = unet
    image = sample
    for t in scheduler.timesteps[:-1]:
        with torch.no_grad():
            residual = model(image, t)[0]["sample"]
        image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
    return image


def _nearby_basis_metrics(model):
    """Measure the geodesic metric between local bases of nearby samples.

    For each of ``NperGPU`` pairs, random noise and a slightly perturbed copy
    of it are both denoised with ``model``; ``local_basis`` is evaluated at
    each result and the two bases are compared with ``compute_grsm_metric``.
    The per-process values are then gathered across all processes.

    Args:
        model: network used for both denoising and ``local_basis``.

    Returns:
        NumPy array holding the gathered metric values from all processes.
    """
    metrics = []
    pbar = tqdm(total=NperGPU, disable=not accelerator.is_local_main_process)
    for _ in range(NperGPU):
        x = torch.randn(sampling_shape, device=torch_device)
        sample = forward(x, model)
        _, h_basis, _, x_basis = local_basis(
            model,
            sample=sample,
            seed=None,
            pooling_kernel=pooling_kernel,
            shape=sampling_shape,
        )

        # The perturbation is applied to the initial noise, not to the
        # denoised sample, and the neighbour is denoised the same way.
        sample2 = forward(close_sample(x), model)
        _, h_basis2, _, x_basis2 = local_basis(
            model,
            sample=sample2,
            pooling_kernel=pooling_kernel,
            shape=sampling_shape,
        )

        if space == 'x':
            basis_a, basis_b = x_basis, x_basis2
        else:
            basis_a, basis_b = h_basis, h_basis2
        metrics.append(
            compute_grsm_metric(
                basis_a.cpu(), basis_b.cpu(), d=num_topk, metric_type='geodesic'
            )
        )
        pbar.update(1)
    pbar.close()

    # Every process must finish its share before the values are gathered.
    accelerator.wait_for_everyone()
    gathered = accelerator.gather(torch.tensor(metrics).to(torch_device))
    return np.array(gathered.detach().cpu())


# start = time()
# The measurement repeats indefinitely; stop the script manually.
while True:
    metric_list = _nearby_basis_metrics(unet)
    if accelerator.is_local_main_process:
        print("iso mean = ", np.mean(metric_list), "iso std = ",np.std(metric_list))

    # The baseline pass evaluates `unet_default`. It previously repeated the
    # first pass on `unet`, so the "base" numbers described the same model.
    metric_list_d = _nearby_basis_metrics(unet_default)
    if accelerator.is_local_main_process:
        print("base mean = ", np.mean(metric_list_d), "base std = ",np.std(metric_list_d))

# np.save(os.path.join(save_folder, "iso"),metric_list)
# np.save(os.path.join(save_folder, "base"),metric_list_d)




# end = time()
# print(end-start)
# sample_default, h_basis_default, s_default, x_basis_default = local_basis(unet_default, seed=seed, pooling_kernel=pooling_kernel)

# print("Allocated:", round(torch.cuda.memory_allocated(0)/1024**3, 1), 'GB', 'Cached:', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')
