from diffusers import UNet2DModel_H, DDIMScheduler
from torchvision.utils import save_image
import torch
import numpy as np
import os
from diffusers import DDPMPipeline, DDPMScheduler, DDIMScheduler, UNet2DModel, UNet2DModel_H, DDPMPipeline_H
from eval import compute_fid, compute_ppl, compute_ppl_end, add_dimensions, compute_distortion_per_timesteps, calculate_frechet_distance
from accelerate import Accelerator

# Distributed context: accelerate supplies this process's device and the number
# of participating processes, which is used below to shard the sample budgets.
accelerator = Accelerator()
torch_device = accelerator.device
ngpus = accelerator.num_processes

# Hub repo / local path the DDIM scheduler configuration is loaded from.
# NOTE(review): the scheduler config comes from a CelebA-HQ repo while the FID
# statistics and the active checkpoints below are CIFAR-10 -- confirm that the
# noise schedule is meant to be shared between the two.
model_path = "google/ddpm-ema-celebahq-256"
# .npz file holding the reference statistics (read through its 'mu' and
# 'sigma' entries when the FID is computed).
fid_stats_path = "assets/stats/cifar10.npz"
# fid_stats_path = "assets/stats/mattymchen_celeba.npz"

# Checkpoints (UNet folders) to evaluate. Earlier experiment sets are kept
# below for reference; exactly one assignment should be active.
# weight_paths = ["google/ddpm-ema-celebahq-256"]
# weight_paths = ["output_jae/test_18/checkpoint-4000/unet", "output_jae/test_22/checkpoint-1000/unet", "output_jae/test_33/checkpoint-200/unet"]
# weight_paths = ["output_jae/test_34/checkpoint-100/unet"] + [f"output_jae/test_{36}/checkpoint-{n}/unet" for n in range(2000, 8000, 2000)]
# weight_paths = ["output_jae/test_34/checkpoint-100/unet", "output_jae/test_39/checkpoint-1000/unet"]
# weight_paths = ["pretrained/cifar10-32/checkpoint-528/unet", 'output_jae/test_39/checkpoint-8000/unet']
# weight_paths = ['output_jae/test_43/checkpoint-1000/unet', 'output_jae/test_41/checkpoint-7200/unet', 'output_jae/test_33/checkpoint-200/unet', 'output_jae/test_18/checkpoint-3000/unet']
# weight_paths = ['output_jae/test_18/checkpoint-4000/unet', 'output_jae/test_18/checkpoint-3000/unet']
weight_paths = ["pretrained/cifar10-32/checkpoint-528/unet"] + [f'output_jae/test_49/checkpoint-{n}/unet' for n in range(200, 1400, 200)]
# weight_paths = weight_paths * 100

N_fid = 10000 #50000           # total samples for FID, summed over all processes
N_ppl = 50000 #10000           # total samples for PPL, summed over all processes
num_inference_steps = 20       # DDIM sampling steps
epsilon = 1e-4                 # forwarded to compute_ppl as `epsilon`
batch_size = 800               # leading dimension of the sampling shape
fid_cond = True                # compute FID
ppl_cond = True                # compute PPL
ppl_end_cond = False           # compute PPL_end

# Both budgets are split evenly across processes; an uneven split would
# silently drop samples and make the reported sample counts wrong.
assert N_fid % ngpus == 0, f"N_fid={N_fid} must be divisible by the number of processes ({ngpus})"
assert N_ppl % ngpus == 0, f"N_ppl={N_ppl} must be divisible by the number of processes ({ngpus})"
NperGPU_fid = N_fid // ngpus
NperGPU_ppl = N_ppl // ngpus

# Evaluate every checkpoint listed in `weight_paths`: wrap its UNet in a DDIM
# sampling pipeline and print the metrics enabled by fid_cond / ppl_cond /
# ppl_end_cond. Only the local main process prints.
if isinstance(weight_paths, list):
    is_main = accelerator.is_local_main_process

    for weight in weight_paths:
        if is_main:
            print(f'\nPrinting metrics of {weight}')

        # Build the sampler for this checkpoint.
        unet = UNet2DModel_H.from_pretrained(weight)
        scheduler = DDIMScheduler.from_pretrained(model_path)
        scheduler.set_timesteps(num_inference_steps=num_inference_steps)
        pipeline = DDPMPipeline_H(unet=unet, scheduler=scheduler)
        pipeline.set_progress_bar_config(disable=True)
        pipeline.to(torch_device)
        generator = None
        sampling_shape = [batch_size, 3, unet.sample_size, unet.sample_size]

        if fid_cond:
            if is_main:
                print(f"Calculating FID with num_inference_steps={num_inference_steps}.")
            if accelerator.num_processes > 1:
                # With several processes the value returned by compute_fid is
                # gathered and used as a (samples x features) matrix whose
                # mean and covariance are compared with the reference stats.
                act = compute_fid(NperGPU_fid, ngpus, sampling_shape, num_inference_steps, pipeline, generator, fid_stats_path, torch_device, accelerator)
                accelerator.wait_for_everyone()
                # Use the accelerate-assigned device instead of a bare .cuda().
                act = accelerator.gather(torch.tensor(act).to(torch_device))
                if is_main:
                    act = act.cpu().detach().numpy()
                    all_pool_mean = np.mean(act, axis=0)
                    all_pool_sigma = np.cov(act, rowvar=False)

                    stats = np.load(fid_stats_path)
                    data_pools_mean = stats['mu']
                    data_pools_sigma = stats['sigma']
                    fid = calculate_frechet_distance(data_pools_mean, data_pools_sigma, all_pool_mean, all_pool_sigma)
            else:
                # Single process: the return value is used directly as the FID.
                # The argument list mirrors the multi-process call above; the
                # original call here omitted num_inference_steps, shifting
                # every following positional argument by one.
                # NOTE(review): confirm against compute_fid's signature.
                fid = compute_fid(NperGPU_fid, ngpus, sampling_shape, num_inference_steps, pipeline, generator, fid_stats_path, torch_device, accelerator)
            if is_main:
                print(f"FID {fid:.2f}")
                print(f"FID with {N_fid} samples.")

        if ppl_cond:
            if is_main:
                print("Calculating PPL")

            ppl = compute_ppl(n_samples=NperGPU_ppl, n_gpus=ngpus, sampling_shape=sampling_shape, num_inference_steps=num_inference_steps, sampler=pipeline, gen=generator, device=torch_device, accelerator=accelerator, epsilon=epsilon)
            accelerator.wait_for_everyone()
            # Gather the per-process values, then average over all of them.
            ppl = accelerator.gather(torch.tensor(ppl).to(torch_device))
            if is_main:
                ppl = ppl.mean().cpu().detach().numpy()
                print(f"PPL {ppl:.2f}")
                print(f"PPL with {N_ppl} samples.")

        if ppl_end_cond:
            if is_main:
                print("Calculating PPL_end")

            ppl_end = compute_ppl_end(n_samples=NperGPU_ppl, n_gpus=ngpus, sampling_shape=sampling_shape, num_inference_steps=num_inference_steps, sampler=pipeline, gen=generator, device=torch_device, accelerator=accelerator)
            accelerator.wait_for_everyone()
            ppl_end = accelerator.gather(torch.tensor(ppl_end).to(torch_device))
            if is_main:
                ppl_end = ppl_end.mean().cpu().detach().numpy()
                print(f"PPL_end {ppl_end:.2f}")
                print(f"PPL_end with {N_ppl} samples.")

        if is_main:
            print("Calculating metrics done!")



# else:
#     step_weights = [x for x in os.listdir(weight_paths) if x.startswith('checkpoint')]
#     for step_weight in step_weights:
#         if accelerator.is_local_main_process:
#             print(f'\nPrinting metrics of {step_weight}')
#         unet_path = os.path.join(weight_paths, step_weight, 'unet')
#         unet = UNet2DModel_H.from_pretrained(unet_path)
#         scheduler = DDIMScheduler.from_pretrained(model_path)
#         scheduler.set_timesteps(num_inference_steps=20)
#         pipeline = DDPMPipeline_H(
#                             unet=unet,
#                             scheduler=scheduler,
#                         )
#         pipeline.set_progress_bar_config(disable=True)
#         torch_device = "cuda" if torch.cuda.is_available() else "cpu"
#         # unet.to(torch_device)
#         pipeline.to(torch_device)
#         generator = None

#         sampling_shape = [batch_size, 3, 256, 256]
#         if accelerator.is_local_main_process:
#             print("Calculating FID")
#         if accelerator.num_processes > 1:
#             act = compute_fid(NperGPU, ngpus, sampling_shape, pipeline, generator, fid_stats_path, torch_device, accelerator)
#             accelerator.wait_for_everyone()
#             # print("before",act[:10])
#             act = accelerator.gather(torch.tensor(act).cuda())
#             # print("after",act[:10])

#             if accelerator.is_local_main_process:
#                 act = act.cpu().detach().numpy()
#                 mu = np.mean(act, axis=0)
#                 sigma = np.cov(act, rowvar=False)
#                 m = torch.from_numpy(mu).cuda()
#                 s = torch.from_numpy(sigma).cuda()

#                 all_pool_mean = m.cpu().numpy()
#                 all_pool_sigma = s.cpu().numpy()

#                 stats = np.load(fid_stats_path)
#                 data_pools_mean = stats['mu']
#                 data_pools_sigma = stats['sigma']
#                 fid = calculate_frechet_distance(data_pools_mean, data_pools_sigma, all_pool_mean, all_pool_sigma)
#         else:
#             fid = compute_fid(NperGPU, ngpus, sampling_shape, pipeline, generator, fid_stats_path, torch_device, accelerator)
#         if accelerator.is_local_main_process:
#             print(f"FID {fid:.2f}")
#             print(f"FID with {N} samples.")

#         if accelerator.is_local_main_process:
#             print("Calculating PPL")
#         ppl = compute_ppl(n_samples=NperGPU, n_gpus=ngpus, sampling_shape=sampling_shape, sampler=pipeline, gen=generator, device=torch_device, accelerator=accelerator)
#         accelerator.wait_for_everyone()
#         ppl = accelerator.gather(torch.tensor(ppl).cuda())        
#         if accelerator.is_local_main_process:
#             ppl = ppl.mean().cpu().detach().numpy()
#             print(f"PPL {ppl:.2f}")
#             print(f"PPL with {N} samples.")
