from diffusers import UNet2DModel_H, DDIMScheduler
from torchvision.utils import save_image
import torch
import numpy as np
import os
from diffusers import DDPMPipeline, DDPMScheduler, DDIMScheduler, UNet2DModel, UNet2DModel_H, DDPMPipeline_H
from eval import compute_fid, compute_ppl, add_dimensions, compute_distortion_per_timesteps
from accelerate import Accelerator
from accelerate import PartialState

# ---------------------------------------------------------------------------
# Evaluation settings
# ---------------------------------------------------------------------------

# Number of samples requested from compute_fid / compute_ppl.
N = 1000
# Leading (batch) entry of the sampling shape handed to the metrics.
batch_size = 64

# Source of the DDIM scheduler configuration.
model_path = "google/ddpm-celebahq-256"
# Reference statistics file passed to compute_fid.
fid_stats_path = "assets/stats/mattymchen_celeba.npz"

# Checkpoints to evaluate: an explicit list of UNet folders (here the
# training steps 4500 and 4800).  Setting this to a run directory instead,
# e.g. "output_jae/test_1", makes the script evaluate the "unet" folder of
# every "checkpoint*" entry found inside that directory.
weight_paths = [
    f"output_jae/test_12/checkpoint-{n}/unet"
    for n in range(4500, 5000, 300)
]

# Device selected by accelerate for the current process.
distributed_state = PartialState()
torch_device = distributed_state.device

def _evaluate_checkpoint(unet_path, label):
    """Load one UNet checkpoint and print its FID and PPL.

    Reads the module-level settings ``model_path``, ``torch_device``,
    ``batch_size``, ``N`` and ``fid_stats_path``.

    Args:
        unet_path: Location passed to ``UNet2DModel_H.from_pretrained``.
        label: Text shown in the header line printed for this checkpoint.
    """
    print(f'\nPrinting path lengths of {label}')
    unet = UNet2DModel_H.from_pretrained(unet_path)
    scheduler = DDIMScheduler.from_pretrained(model_path)
    scheduler.set_timesteps(num_inference_steps=20)
    pipeline = DDPMPipeline_H(unet=unet, scheduler=scheduler)
    pipeline.set_progress_bar_config(disable=True)
    # Always use the device chosen by accelerate so that both the list and
    # the directory mode run on the same device.
    pipeline.to(torch_device)
    generator = None

    # Presumably (batch, channels, height, width) -- confirm against eval.py.
    sampling_shape = [batch_size, 3, 256, 256]

    print("Calculating FID")
    # The positional ``1`` presumably corresponds to ``n_gpus`` (compare the
    # compute_ppl call below) -- confirm against eval.py.
    fid = compute_fid(N, 1, sampling_shape, pipeline, generator, fid_stats_path, torch_device)
    print(f"FID {fid:.2f}")
    print(f"FID with {N} samples.")

    print("Calculating PPL")
    ppl = compute_ppl(
        n_samples=N,
        n_gpus=1,
        sampling_shape=sampling_shape,
        sampler=pipeline,
        gen=generator,
        device=torch_device,
    )
    print(f"PPL {ppl:.2f}")
    print(f"PPL with {N} samples.")


def _checkpoint_sort_key(name):
    """Sort key ordering checkpoint folders by their trailing step number.

    Names without a numeric suffix after the last "-" are placed after all
    numbered ones, ordered alphabetically.
    """
    suffix = name.rpartition("-")[2]
    if suffix.isdecimal():
        return (0, int(suffix), name)
    return (1, 0, name)


if isinstance(weight_paths, (list, tuple)):
    # Explicit collection of UNet folders: evaluate them in the given order.
    for weight in weight_paths:
        _evaluate_checkpoint(weight, weight)
else:
    # Run directory: evaluate the "unet" folder of every checkpoint* entry.
    # os.listdir returns entries in arbitrary order, so sort by step to get
    # a reproducible, chronological report.
    step_weights = sorted(
        (x for x in os.listdir(weight_paths) if x.startswith('checkpoint')),
        key=_checkpoint_sort_key,
    )
    for step_weight in step_weights:
        _evaluate_checkpoint(
            os.path.join(weight_paths, step_weight, 'unet'),
            step_weight,
        )
