from diffusers import UNet2DModel_H, DDIMScheduler
from torchvision.utils import save_image
import torch
import numpy as np
import os
from diffusers import DDPMPipeline, DDPMScheduler, DDIMScheduler, UNet2DModel, UNet2DModel_H, DDPMPipeline_H
from eval import compute_fid, compute_ppl, compute_ppl_end, add_dimensions, compute_distortion_per_timesteps, calculate_frechet_distance
from accelerate import Accelerator

# Evaluation script: for every configured dataset, load each UNet checkpoint into a
# DDIM-scheduled DDPMPipeline_H and print the enabled metrics (FID, PPL, PPL_end).
# Meant to be launched with `accelerate launch`; each process generates its share of
# the samples and the per-process results are all-gathered before reporting.

# --- Distributed setup --------------------------------------------------------
accelerator = Accelerator()
torch_device = accelerator.device
ngpus = accelerator.num_processes

# --- Experiment configuration -------------------------------------------------
# dataset_weights[i] : UNet checkpoints to evaluate for dataset i.
# dataset_numpy[i]   : FID reference statistics for dataset i (.npz with 'mu' and 'sigma').
# model_paths[i]     : repo/folder the DDIMScheduler config for dataset i is loaded from.
#
# Alternatives used in earlier runs:
# dataset_weights = [[f"iso_exp/lsun_bedrooms_32batch/checkpoint-{i}/unet" for i in range(1600, 4000, 200)], [f"iso_exp/lsun_church_32batch/checkpoint-{i}/unet" for i in range(1600, 4000, 200)]]
# dataset_weights = [["google/ddpm-ema-bedroom-256", "iso_exp/lsun_bedrooms_32batch/checkpoint-3000/unet"], ["google/ddpm-ema-church-256", "iso_exp/lsun_church_32batch/checkpoint-3200/unet"]]
# dataset_weights = [["google/ddpm-ema-bedroom-256"], ["google/ddpm-ema-church-256"]]
# dataset_numpy = ["assets/stats/bedroom_50k.npz", "assets/stats/church_50k.npz"]
# model_paths = ["google/ddpm-ema-church-256", "google/ddpm-ema-bedroom-256"]
# model_paths = ["google/ddpm-ema-bedroom-256"]
# weight_paths = ["google/ddpm-ema-celebahq-256"]
# weight_paths = ["output_jae/test_18/checkpoint-4000/unet", "output_jae/test_22/checkpoint-1000/unet", "output_jae/test_33/checkpoint-200/unet"]
# weight_paths = [f"output_jae/test_{m}/checkpoint-{n}/unet" for n in range(100, 600, 100) for m in range(34, 36)]
# weight_paths = [f"iso_exp/lsun_bedrooms_32batch/checkpoint-{i}/unet" for i in range(200, 1500, 200)] + [f"iso_exp/lsun_church_32batch/checkpoint-{i}/unet" for i in range(200, 1500, 200)]
dataset_weights = [["iso_exp/lsun_church_32batch/checkpoint-1800/unet"]]
dataset_numpy = ["assets/stats/church_50k.npz"]
model_paths = ["google/ddpm-ema-church-256"]

N_fid = 10000  # 50000
N_ppl = 10000  # 10000
fid_num_inference_steps = 100
ppl_num_inference_steps = 20

batch_size = 64
fid_cond = True
ppl_cond = False
ppl_end_cond = False

# Number of passes over the whole configuration. The script originally looped forever
# (`while True` with no exit); None keeps that behaviour, an int stops after that many passes.
NUM_ROUNDS = None

if not (len(dataset_weights) == len(dataset_numpy) == len(model_paths)):
    raise ValueError(
        "dataset_weights, dataset_numpy and model_paths must have the same length, got "
        f"{len(dataset_weights)}, {len(dataset_numpy)} and {len(model_paths)}"
    )
# Every process generates an equal share of the samples, so both totals must split evenly.
if N_fid % ngpus != 0 or N_ppl % ngpus != 0:
    raise ValueError(
        f"N_fid ({N_fid}) and N_ppl ({N_ppl}) must both be divisible by the number of processes ({ngpus})"
    )
NperGPU_fid = N_fid // ngpus
NperGPU_ppl = N_ppl // ngpus


def _print_main(*args):
    """Print only on the local main process, so multi-process runs don't duplicate output."""
    if accelerator.is_local_main_process:
        print(*args)


def _gather(values):
    """All-gather a per-process result into one tensor on this process's device.

    Uses `accelerator.device` rather than `.cuda()` so the script also runs on CPU.
    """
    accelerator.wait_for_everyone()
    return accelerator.gather(torch.tensor(values).to(torch_device))


def _fid_from_activations(act, fid_stats_path):
    """Return the FID between gathered activations and the reference stats in `fid_stats_path`.

    `act` is the gathered tensor; like the original code, its rows are treated as
    samples (`np.cov(..., rowvar=False)`). The .npz file must contain 'mu' and 'sigma'.
    """
    act = act.detach().cpu().numpy()
    mu = np.mean(act, axis=0)
    sigma = np.cov(act, rowvar=False)
    # NpzFile keeps the archive open; the context manager closes it once both arrays are loaded.
    with np.load(fid_stats_path) as stats:
        data_pools_mean = stats['mu']
        data_pools_sigma = stats['sigma']
    return calculate_frechet_distance(data_pools_mean, data_pools_sigma, mu, sigma)


def _report_mean_metric(label, compute_fn, pipeline, generator, sampling_shape):
    """Compute a per-sample metric (PPL / PPL_end) on every process and print its global mean."""
    _print_main(f"Calculating {label}")
    values = compute_fn(n_samples=NperGPU_ppl, n_gpus=ngpus, sampling_shape=sampling_shape, num_inference_steps=ppl_num_inference_steps, sampler=pipeline, gen=generator, device=torch_device, accelerator=accelerator)
    gathered = _gather(values)
    if accelerator.is_local_main_process:
        print(f"{label} {gathered.mean().item():.2f}")
        print(f"{label} with {N_ppl} samples.")


def _evaluate_checkpoint(weight, model_path, fid_stats_path):
    """Load one UNet checkpoint, build the sampling pipeline and print every enabled metric."""
    _print_main(f'\nPrinting metrics of {weight}')
    unet = UNet2DModel_H.from_pretrained(weight)
    scheduler = DDIMScheduler.from_pretrained(model_path)
    scheduler.set_timesteps(num_inference_steps=fid_num_inference_steps)
    pipeline = DDPMPipeline_H(unet=unet, scheduler=scheduler)
    pipeline.set_progress_bar_config(disable=True)
    pipeline.to(torch_device)
    generator = None  # no seeded generator is passed to the metric helpers
    sampling_shape = [batch_size, 3, unet.sample_size, unet.sample_size]

    if fid_cond:
        _print_main(f"Calculating FID with num_inference_steps={fid_num_inference_steps}.")
        result = compute_fid(NperGPU_fid, ngpus, sampling_shape, fid_num_inference_steps, pipeline, generator, fid_stats_path, torch_device, accelerator)
        if ngpus > 1:
            # In the multi-process case compute_fid's return value is this process's
            # activations; FID is computed once all of them have been gathered.
            act = _gather(result)
            if accelerator.is_local_main_process:
                fid = _fid_from_activations(act, fid_stats_path)
        else:
            fid = result
        if accelerator.is_local_main_process:
            print(f"FID {fid:.2f}")
            print(f"FID with {N_fid} samples.")

    if ppl_cond:
        _report_mean_metric("PPL", compute_ppl, pipeline, generator, sampling_shape)

    if ppl_end_cond:
        _report_mean_metric("PPL_end", compute_ppl_end, pipeline, generator, sampling_shape)

    _print_main("Calculating metrics done!")


round_idx = 0
while NUM_ROUNDS is None or round_idx < NUM_ROUNDS:
    round_idx += 1
    for weight_paths, fid_stats_path, model_path in zip(dataset_weights, dataset_numpy, model_paths):
        # Accept a single checkpoint path as well as a list of them.
        if isinstance(weight_paths, str):
            weight_paths = [weight_paths]
        for weight in weight_paths:
            _evaluate_checkpoint(weight, model_path, fid_stats_path)

