#!/usr/bin/env python
# generate_var_dataset.py
#
# Utility that loads the VAR model + VAE and generates N images for a SINGLE class,
# saving them under out_dir/train/<class_name>/*.png
#
# Changes from original:
#  - Adds --class_id and only generates for that class
#  - Generates exactly --num_images total (not per-class)
#  - Uses a fixed base seed; outputs are reproducible for the same --seed, --batch_size and world size
#  - Safer distributed barriers (no-op if not initialized)
#
# Note: The imagenet class map is expected at imagenet/imagenet_class_to_idx.json

import argparse
import json
import os
import random
import shutil
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image

from models import build_vae_var

# -----------------------------------------------------------------------------#
# Helpers                                                                      #
# -----------------------------------------------------------------------------#

def init_distributed():
    """Initialise torch.distributed from torchrun-style env vars, if present.

    Returns:
        ``(rank, local_rank, world_size)``. If a process group already exists,
        its rank/world size are reported (``local_rank`` from ``LOCAL_RANK``).
        If ``RANK``/``WORLD_SIZE`` are not both set (plain ``python script.py``),
        returns ``(0, 0, 1)`` and leaves torch.distributed uninitialised.
    """
    if dist.is_initialized():
        return dist.get_rank(), int(os.environ.get("LOCAL_RANK", 0)), dist.get_world_size()
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return 0, 0, 1

    rank = int(os.environ["RANK"])
    world = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if torch.cuda.is_available():
        # Bind this process to its GPU before creating the NCCL group.
        torch.cuda.set_device(local_rank)
        backend = "nccl"
    else:
        # NCCL needs CUDA; use gloo so CPU-only multi-process launches still work.
        backend = "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    return rank, local_rank, world


def maybe_barrier():
    """Block until every rank arrives; does nothing without a process group."""
    if not dist.is_initialized():
        return
    dist.barrier()


def maybe_destroy_pg():
    """Tear down the default process group, but only if one was created."""
    has_group = dist.is_initialized()
    if has_group:
        dist.destroy_process_group()


def set_seed(seed: int):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    The seed is deliberately NOT offset by rank, so every rank starts from
    the same RNG state.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


def save_tensor_as_png(t: torch.Tensor, path: Path):
    """Write an image tensor to *path* as a PNG.

    ``t`` is clamped to [0, 1], scaled by 255 and cast to uint8 (the cast
    truncates rather than rounds). It is then moved to CPU and permuted from
    channel-first ``(C, H, W)`` to ``(H, W, C)`` for PIL. Assumes C is 3 (RGB)
    or 1; TODO confirm against the sampler's output.
    """
    pixels = t.clamp(0, 1).mul(255).to(torch.uint8).cpu()
    hwc = pixels.permute(1, 2, 0).contiguous().numpy()
    Image.fromarray(hwc).save(path, compress_level=3)

# -----------------------------------------------------------------------------#
# Argument parsing                                                             #
# -----------------------------------------------------------------------------#

def parse_args():
    """Parse and validate the command-line options.

    Raises:
        SystemExit: (via ``ArgumentParser.error``, exit code 2) when
            ``--batch_size`` < 1 or ``--num_images`` < 0.
    """
    # Pass the title as ``description``; positionally it would become ``prog``
    # and replace the program name in the usage line.
    p = argparse.ArgumentParser(description='Generate VAR samples for a single class')
    p.add_argument('--out_dir',     type=str, required=True,
                   help='Root folder to save PNGs (will create train/<class> subfolder)')
    p.add_argument('--class_id',    type=int, required=True,
                   help='Class ID (index) to generate from (as defined by imagenet_class_to_idx.json)')
    p.add_argument('--num_images',  type=int, default=750,
                   help='Total number of images to generate (for this one class)')
    p.add_argument('--batch_size',  type=int, default=64,
                   help='Batch size per device')
    p.add_argument('--cfg',         type=float, default=4.0,
                   help='Classifier-free guidance scale')
    p.add_argument('--top_k',       type=int, default=900,
                   help='Top-k sampling')
    p.add_argument('--top_p',       type=float, default=0.95,
                   help='Top-p sampling')
    p.add_argument('--more_smooth', action='store_true',
                   help='Use the more_smooth flag in VAR sampler')
    p.add_argument('--model_depth', type=int, default=16, choices=[16],
                   help='VAR model depth (affects checkpoint name)')
    p.add_argument('--seed',        type=int, default=123,
                   help='Fixed seed for full reproducibility')
    p.add_argument('--cuda',        action='store_true',
                   help='Unused (kept for compatibility)')
    args = p.parse_args()

    # A batch size <= 0 would stop a min(batch_size, remaining)-style batching
    # loop from ever advancing; a negative count cannot size a label tensor.
    if args.batch_size < 1:
        p.error(f"--batch_size must be >= 1 (got {args.batch_size})")
    if args.num_images < 0:
        p.error(f"--num_images must be >= 0 (got {args.num_images})")
    return args

# -----------------------------------------------------------------------------#
# Main                                                                          #
# -----------------------------------------------------------------------------#
def _load_class_list(path):
    """Return class names from a ``{name: idx}`` JSON file, ordered by idx."""
    with open(path, 'r', encoding='utf-8') as f:
        class_map = json.load(f)
    return sorted(class_map, key=class_map.__getitem__)


def _ensure_checkpoints(paths, base_url):
    """Download each missing checkpoint from ``<base_url>/<path>`` to ``path``.

    Fails loudly (raises) instead of silently continuing when a download fails.
    """
    import torch.hub  # local import: only needed for the download path

    for ck in paths:
        if os.path.exists(ck):
            continue
        parent = os.path.dirname(ck)
        if parent:
            # Nested checkpoint paths must land at `ck`, not at its basename in cwd.
            os.makedirs(parent, exist_ok=True)
        # download_url_to_file writes via a temporary file, so a failed
        # download should not leave a truncated checkpoint at `ck`.
        torch.hub.download_url_to_file(f"{base_url}/{ck}", ck)


def _shard(total, world, rank):
    """Return ``(start, count)``: the contiguous slice of ``range(total)`` owned by ``rank``.

    The first ``total % world`` ranks each get one extra item.
    """
    base, rem = divmod(total, world)
    count = base + (1 if rank < rem else 0)
    start = rank * base + min(rank, rem)
    return start, count


@torch.inference_mode()
def main():
    """Generate ``--num_images`` samples of one class, sharded across ranks.

    PNGs are written to ``<out_dir>/train/<class_name>/<global_idx:05d>.png``.
    Before any rank writes, rank 0 wipes and recreates that class folder.
    """
    args = parse_args()
    rank, local_rank, world = init_distributed()
    master = (rank == 0)

    # Fixed, reproducible seed (same on every rank)
    set_seed(args.seed)

    # Constants / checkpoints
    VAE_CKPT  = 'vae_ch160v4096z32.pth'
    VAR_CKPT  = 'merged_models/ns750k/merged_aux1250304_w1.0.pth'
    CLASS_MAP = 'imagenet/imagenet_class_to_idx.json'
    HF_BASE   = "https://huggingface.co/FoundationVision/var/resolve/main"

    # Ordered class list: position == class index
    class_list = _load_class_list(CLASS_MAP)
    num_classes = len(class_list)
    cls_idx = args.class_id
    # Explicit check rather than `assert`, so it still runs under `python -O`.
    if not 0 <= cls_idx < num_classes:
        raise ValueError(f"--class_id must be in [0, {num_classes-1}], got {cls_idx}")
    cls_name = class_list[cls_idx]

    # Output goes to <out_dir>/train/<class_name>, as the CLI help promises.
    # Only this class folder is cleared, so other classes' outputs survive.
    out_dir = Path(args.out_dir) / 'train' / cls_name
    if master:
        shutil.rmtree(out_dir, ignore_errors=True)
        out_dir.mkdir(parents=True, exist_ok=True)
    maybe_barrier()

    # Download checkpoints if needed (only on master)
    if master:
        _ensure_checkpoints((VAE_CKPT, VAR_CKPT), HF_BASE)
    maybe_barrier()

    # Build models
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    # Skip default Linear/LayerNorm weight init (process-wide patch): all
    # weights are overwritten by the checkpoints loaded below.
    setattr(torch.nn.Linear, "reset_parameters", lambda *_: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda *_: None)

    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,
        device=device, patch_nums=(1,2,3,4,5,6,8,10,13,16),
        num_classes=num_classes, depth=args.model_depth, shared_aln=False,
    )

    ckpt = torch.load(VAR_CKPT, map_location='cpu')
    # Accept training checkpoints ('var_wo_ddp' / 'state_dict') or a raw state dict.
    sd = ckpt.get('var_wo_ddp', ckpt.get('state_dict', ckpt))
    var.load_state_dict(sd, strict=True)

    vae.load_state_dict(torch.load(VAE_CKPT, map_location="cpu"))
    vae.eval().requires_grad_(False)
    var.eval().requires_grad_(False)

    # Deterministic contiguous sharding of the N global indices across ranks
    N = int(args.num_images)
    start, local_N = _shard(N, world, rank)

    if local_N == 0:
        # Still join the final barrier that working ranks hit below.
        maybe_barrier()
        print(f"No work for rank {rank}.")
        maybe_destroy_pg()
        return

    # Labels for this class only
    labels = torch.full((local_N,), cls_idx, device=device, dtype=torch.long)

    ptr = 0
    while ptr < local_N:
        cur_B = min(args.batch_size, local_N - ptr)
        lbl = labels[ptr:ptr+cur_B]

        # One sampler seed per *batch*: base seed + global index of the
        # batch's first sample. Individual images are therefore reproducible
        # only for the same --seed, --batch_size and world size.
        g_seed = int(args.seed) + (start + ptr)

        imgs = var.autoregressive_infer_cfg(
            B=cur_B, label_B=lbl,
            cfg=args.cfg, top_k=args.top_k, top_p=args.top_p,
            g_seed=g_seed, more_smooth=args.more_smooth,
        )

        for i in range(cur_B):
            gid = start + ptr + i  # global index in [0, N-1]
            save_tensor_as_png(imgs[i], out_dir / f"{gid:05d}.png")

        ptr += cur_B

    maybe_barrier()
    if master:
        print(f"Generated {N} images for class '{cls_name}' (id={cls_idx}) under '{out_dir}'.")
    maybe_destroy_pg()


if __name__ == "__main__":
    # Run generation only when executed as a script, not when imported.
    main()
