# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Samples a large number of images from a pre-trained SiT model using DDP.
Subsequently saves a .npz file that can be used to compute FID and other
evaluation metrics via evaluation/evaluator.py
"""

import os
import glob
import math
import argparse
from typing import Dict

import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import torch.distributed as dist

from omegaconf import OmegaConf

from download import find_model
from diffusion import create_diffusion
from diffusion.rectified_flow_svg import RectifiedFlow
from utils import instantiate_from_config
from evaluation.inception import InceptionV3


def get_config(ckpt_path):
    """
    Locate and load the experiment config stored two directory levels above a checkpoint.

    The checkpoint is expected at ``<exp_root>/<subdir>/<file>`` ('/'-separated);
    exactly one ``*.yaml`` file must live directly inside ``<exp_root>``.

    Args:
        ckpt_path: '/'-separated path to a checkpoint file.

    Returns:
        (exp_name, config): the last path component of ``<exp_root>`` and the
        config loaded with ``OmegaConf.load``.

    Raises:
        AssertionError: if ``<exp_root>`` does not contain exactly one ``*.yaml`` file.
    """
    parts = ckpt_path.split("/")[:-2]
    exp_name = parts[-1]
    exp_root = "/".join(parts)
    config_paths = glob.glob(os.path.join(exp_root, "*.yaml"))
    print(config_paths)

    # Explicit check rather than a bare `assert` wrapped in `try/except:`.
    # Asserts are stripped under `python -O`, which would let an empty or
    # ambiguous match through to the indexing below.
    if len(config_paths) != 1:
        raise AssertionError(
            f"len(config_path) != 1: found {config_paths} in {exp_root!r}"
        )
    return exp_name, OmegaConf.load(config_paths[0])


def create_npz_from_sample_folder(sample_dir, num=50_000, batch_size=4):
    """
    Merge per-sample Inception features into one ``{sample_dir}.npz`` file.

    Reads ``{sample_dir}/000000.npy`` .. ``{sample_dir}/{num-1:06d}.npy``,
    concatenates them along axis 0 and requires the result to be
    ``(num, 2048)``. The .npz holds ``activations``, their per-column mean
    ``mu`` (computed with ``np.nanmean``, so NaNs are ignored) and their
    covariance ``sigma`` (``np.cov``, which does not ignore NaNs).

    ``batch_size`` is accepted for interface compatibility and is unused.

    Returns:
        The path of the written .npz file.
    """
    loaded = [
        np.load(f"{sample_dir}/{idx:06d}.npy")
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    stacked = np.concatenate(loaded)
    assert stacked.shape == (num, 2048)

    out_path = f"{sample_dir}.npz"
    feature_mean = np.nanmean(stacked, axis=0)
    feature_cov = np.cov(stacked, rowvar=False)
    np.savez(out_path, activations=stacked, mu=feature_mean, sigma=feature_cov)
    print(f"Saved .npz file to {out_path} [shape={stacked.shape}].")
    return out_path


def create_npz_from_sample_folder_png(sample_dir, num=50_000):
    """
    Pack PNG samples into a single ``{sample_dir}.npz`` file (key ``arr_0``).

    Reads ``{sample_dir}/000000.png`` .. ``{sample_dir}/{num-1:06d}.png`` as
    uint8 arrays and stacks them. All images must share one height/width
    (required by ``np.stack``) and have 3 channels (checked below), so the
    saved array has shape ``(num, H, W, 3)``.

    Returns:
        The path of the written .npz file.
    """
    images = [
        np.asarray(Image.open(f"{sample_dir}/{idx:06d}.png")).astype(np.uint8)
        for idx in tqdm(range(num), desc="Building .npz file from samples")
    ]
    batch = np.stack(images)
    assert batch.shape == (num, batch.shape[1], batch.shape[2], 3)

    out_path = f"{sample_dir}.npz"
    np.savez(out_path, arr_0=batch)
    print(f"Saved .npz file to {out_path} [shape={batch.shape}].")
    return out_path


def _load_models(ckpt_arg, device):
    """
    Load the model(s) named by the ``--ckpt`` argument onto ``device`` in eval mode.

    ``ckpt_arg`` is either:
      * a checkpoint path laid out as ``<exp_root>/<subdir>/<file>.pt``, whose
        experiment config is found by ``get_config``; or
      * a Python dict literal (recognised by containing ``{``) with the keys
        ``"infer_exp_name"`` and ``"ckpt_step"`` plus one ``"start,end"`` key per
        timestep range, each mapping to a checkpoint path.

    Returns:
        (model, config, model_string_name, ckpt_string_name). ``model`` is a single
        module, or a dict mapping ``(start, end - 1)`` tuples to modules. ``config``
        is the config of the (last) loaded checkpoint.

    Raises:
        ValueError: if the dict literal has no timestep-range entries.
    """
    if "{" not in ckpt_arg:
        exp_name, config = get_config(ckpt_arg)
        model = instantiate_from_config(config.model).to(device)
        model.load_state_dict(find_model(ckpt_arg))
        print(f"Before Eval, model.training: {model.training}")
        model.eval()
        print(f"After Eval, model.training: {model.training}")
        return model, config, exp_name, os.path.basename(ckpt_arg).replace(".pt", "")

    import ast  # only needed for the multi-checkpoint form

    print("Loading model for different timesteps...")
    # literal_eval instead of eval(): the CLI string may only contain Python
    # literals, so it can no longer execute arbitrary code.
    spec = ast.literal_eval(ckpt_arg)
    exp_name = spec.pop("infer_exp_name")
    ckpt_step = spec.pop("ckpt_step")
    if not spec:
        raise ValueError(
            "--ckpt dict must contain at least one 'start,end' timestep-range entry."
        )

    model = {}
    for range_key, path in spec.items():
        bounds = [int(part) for part in range_key.split(",")]
        # End bound is decremented by one before use as a key
        # (presumably exclusive -> inclusive; confirm against RectifiedFlow).
        bounds[1] -= 1
        bounds = tuple(bounds)
        print("--------> loading model")
        print(f"Set timestep from: {bounds[0]} to {bounds[1]}")
        print(f"Using model: {path}")
        _, config = get_config(path)
        sub_model = instantiate_from_config(config.model).to(device)
        sub_model.load_state_dict(find_model(path))
        sub_model.eval()
        model[bounds] = sub_model
    return model, config, exp_name, ckpt_step


def _write_feature_npz(feature_dir, npz_path, num):
    """
    Merge every ``*.npy`` feature file in ``feature_dir`` into ``npz_path``.

    Files are read in sorted name order and concatenated along axis 0. The .npz
    holds ``activations``, their column mean ``mu`` and covariance ``sigma``.

    Raises:
        FileNotFoundError: if ``feature_dir`` is missing or holds no .npy files.
        ValueError: if the merged activations are not shaped ``(num, 2048)``
            (e.g. stale files left by a run with a different world size).

    Returns:
        ``npz_path``.
    """
    if not os.path.isdir(feature_dir):
        raise FileNotFoundError(f"Error: {feature_dir} is not a valid path.")
    filenames = sorted(name for name in os.listdir(feature_dir) if name.endswith(".npy"))
    if not filenames:
        raise FileNotFoundError(f"No .npy feature files found in {feature_dir}")

    activations = np.concatenate(
        [np.load(os.path.join(feature_dir, name)) for name in tqdm(filenames)]
    )
    print(activations.shape)
    if activations.shape != (num, 2048):
        raise ValueError(
            f"Expected activations of shape {(num, 2048)}, got {activations.shape} "
            f"from {len(filenames)} files in {feature_dir}"
        )
    mu = np.mean(activations, axis=0)
    sigma = np.cov(activations, rowvar=False)
    np.savez(npz_path, activations=activations, mu=mu, sigma=sigma)
    print(f"Saved .npz file to {npz_path} [shape={activations.shape}].")
    return npz_path


@torch.no_grad()
def main(args):
    """
    Run distributed sampling and dump Inception features for FID.

    Each rank draws class labels and noise, runs the sampler, decodes the result
    with the DINO decoder (rectified-flow path), runs InceptionV3 on it and saves
    one ``.npy`` feature file per batch under
    ``<sample_dir>/npy/<run-name>/``. After a barrier, rank 0 merges all files
    into ``<sample_dir>/<run-name>.npz`` (activations, mu, sigma).

    Raises:
        ValueError: if ``args.ckpt`` is not given.
    """
    if args.ckpt is None:
        raise ValueError("--ckpt is required (path to a checkpoint or a dict literal).")

    torch.backends.cuda.matmul.allow_tf32 = args.tf32
    assert torch.cuda.is_available(), "DDP sampling requires at least one GPU."
    torch.set_grad_enabled(False)

    # Initialize DDP: one process per GPU, a distinct seed per rank.
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = rank % torch.cuda.device_count()
    seed = args.global_seed * world_size + rank
    torch.manual_seed(seed)
    torch.cuda.set_device(device)
    print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

    # Load model(s)
    model, config, model_string_name, ckpt_string_name = _load_models(args.ckpt, device)

    # Ensure 'rf' key exists
    print(config)
    if "rf" not in config.basic:
        config.basic.rf = False

    inception = InceptionV3().to(device).eval()
    if config.basic.rf:
        print("Sampling with rectified flow")
        diffusion = RectifiedFlow(model)
    else:
        # Built once, outside the sampling loop.
        diffusion = create_diffusion(str(args.num_sampling_steps))

    # DINO decoder: turns sampled feature maps into the images fed to Inception.
    from SVG.svg_diffusion.ldm.models._decoder import DinoDecoder
    encoder_config = OmegaConf.load(config.basic.encoder_config)
    encoder_params = encoder_config.model.params

    dinov3 = DinoDecoder(
        ddconfig=encoder_params.ddconfig,
        dinoconfig=encoder_params.dinoconfig,
        lossconfig=encoder_params.lossconfig,
        embed_dim=encoder_params.embed_dim,
        ckpt_path=encoder_config.ckpt_path,
        extra_vit_config=encoder_params.extra_vit_config,
    ).to(device).eval()
    z_channels = encoder_params.ddconfig.z_channels

    # map_location="cpu": otherwise every rank materialises the tensors on the
    # device they were saved from before moving them to its own GPU.
    dinov3_sp_stats = torch.load("dinov3_sp_stats.pt", map_location="cpu")
    dinov3_sp_mean = dinov3_sp_stats["dinov3_sp_mean"].to(device)[:, :, :z_channels]
    dinov3_sp_std = dinov3_sp_stats["dinov3_sp_std"].to(device)[:, :, :z_channels]

    assert args.cfg_scale >= 1.0, "In most cases, cfg_scale should be >= 1.0"
    using_cfg = args.cfg_scale > 1.0
    if not using_cfg:
        print("Classifier-free guidance disabled.")

    # Prepare output folder
    folder_name = (
        f"{args.cfg_mode}_{args.tag}_{model_string_name}-{ckpt_string_name}-"
        f"size-{args.image_size}-cfg-{args.cfg_scale}-seed-{args.global_seed}-"
        f"FID-{int(args.num_fid_samples/1000)}K-bs{args.per_proc_batch_size}-"
        f"sampling_{args.num_sampling_steps}-shift{args.shift}-ema"
    )
    sample_folder_dir = f"{args.sample_dir}/npy/{folder_name}"
    if rank == 0:
        os.makedirs(sample_folder_dir, exist_ok=True)
        print(f"Saving .npy samples at {sample_folder_dir}")
    dist.barrier()

    # Compute number of samples per GPU (total is rounded up to a whole number
    # of global batches, so it can exceed args.num_fid_samples).
    n = args.per_proc_batch_size
    global_batch_size = n * world_size
    total_samples = int(math.ceil(args.num_fid_samples / global_batch_size) * global_batch_size)
    if rank == 0:
        print(f"Total number of images that will be sampled: {total_samples}")
    assert total_samples % world_size == 0
    samples_needed_this_gpu = total_samples // world_size
    assert samples_needed_this_gpu % n == 0
    iterations = samples_needed_this_gpu // n

    pbar = tqdm(range(iterations)) if rank == 0 else range(iterations)
    total = 0
    latent_size = args.image_size // 16

    # Sampling loop
    for _ in pbar:
        z = torch.randn(n, latent_size ** 2, z_channels, device=device)
        y = torch.randint(0, args.num_classes, (n,), device=device)
        y_null = None

        # Setup classifier-free guidance
        if using_cfg:
            z_cat = torch.cat([z, z], 0)
            # NOTE(review): 1000 is presumably the null (unconditional) label of a
            # 1000-class model; it does not follow --num-classes.
            y_null = torch.tensor([1000] * n, device=device)
            model_kwargs = dict(y=torch.cat([y, y_null], 0), cfg_scale=args.cfg_scale)
            method_name = "forward_with_cfg"
        else:
            z_cat = z
            model_kwargs = dict(y=y)
            method_name = "forward"
        if isinstance(model, dict):
            sample_fn = {k: getattr(v, method_name) for k, v in model.items()}
        else:
            sample_fn = getattr(model, method_name)

        # Run diffusion
        if config.basic.rf:
            samples = diffusion.sample(
                z, y, y_null,
                sample_steps=args.num_sampling_steps,
                cfg=args.cfg_scale,
                progress=False,
                mode=args.tag,
                timestep_shift=args.shift,
                cfg_mode=args.cfg_mode
            )[-1]

            if config.basic.get("feature_norm", False):
                samples = samples * dinov3_sp_std + dinov3_sp_mean

            # (B, T, D) tokens -> (B, D, latent_size, latent_size) feature map.
            B, T, D = samples.shape
            samples_dino_feature = samples.permute(0, 2, 1).reshape(B, D, latent_size, latent_size)
            samples = dinov3.decode(samples_dino_feature)

        else:
            # NOTE(review): this path feeds the raw sampler output to Inception
            # without dinov3.decode; confirm it is intended.
            samples = diffusion.p_sample_loop(
                sample_fn,
                z_cat.shape,
                z_cat,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                device=device,
            )
            if using_cfg:
                samples, _ = samples.chunk(2, dim=0)

        # Clamp and extract inception features.
        # NOTE(review): +128.0 is the rounding offset used before a uint8 cast;
        # no cast follows here, so values carry a +0.5/255 bias.
        samples = torch.clamp(127.5 * samples + 128.0, 0, 255)
        inception_feature = inception(samples / 255.).cpu().numpy()

        # File names are unique across ranks: rank + k * global_batch_size.
        index = rank + total
        np.save(f"{sample_folder_dir}/{index:06d}.npy", inception_feature)
        total += global_batch_size

    # Barrier before saving .npz
    dist.barrier()
    if rank == 0:
        # Written next to the npy/ folder, inside --sample-dir (previously a
        # hard-coded "samples/" that may not exist).
        npz_path = f"{args.sample_dir}/{folder_name}.npz"
        _write_feature_npz(sample_folder_dir, npz_path, num=total_samples)
        print("Done.")

    dist.barrier()
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Sample from a pre-trained model with DDP and save Inception features for FID."
    )
    parser.add_argument("--vae", type=str, default="stabilityai/sd-vae-ft-mse",
                        help="Kept for CLI compatibility; not used by the sampling path.")
    parser.add_argument("--sample-dir", type=str, default="samples",
                        help="Output root; per-batch features go to <sample-dir>/npy/<run-name>/.")
    parser.add_argument("--per-proc-batch-size", type=int, default=32)
    parser.add_argument("--num-fid-samples", type=int, default=50_000,
                        help="Target sample count; rounded up to a multiple of the global batch size.")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.5,
                        help="Classifier-free guidance scale; must be >= 1.0 (1.0 disables CFG).")
    parser.add_argument("--shift", type=float, default=0.0,
                        help="Timestep shift passed to the rectified-flow sampler.")
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--global-seed", type=int, default=0)
    # store_false: TF32 is ON by default and passing --tf32 turns it OFF.
    parser.add_argument("--tf32", action="store_false",
                        help="Pass to DISABLE TF32 matmuls (TF32 is enabled by default).")
    parser.add_argument("--ckpt", type=str, default=None,
                        help="Required. Checkpoint path laid out as <exp>/<subdir>/<name>.pt "
                             "(the single *.yaml in <exp> is used as config), or a dict literal "
                             "mapping 'start,end' timestep ranges to checkpoint paths plus "
                             "'infer_exp_name' and 'ckpt_step' keys.")
    parser.add_argument("--tag", type=str, default="",
                        help="Passed to the sampler as `mode` and included in the output folder name.")
    # '--cfg_mode' kept for backward compatibility; dest stays 'cfg_mode'.
    parser.add_argument("--cfg_mode", "--cfg-mode", type=str, default="constant",
                        help="Guidance schedule mode passed to the sampler as `cfg_mode`.")

    args = parser.parse_args()
    main(args)
