#!/usr/bin/env python
# generate_and_fid.py
#
# Simplified utility: loads a trained xAR base model, a KL-VAE, and one auxiliary
# model; merges parameters as (1 + w) * base - w * aux; samples N images with
# classifier-free guidance scale cfg; computes FID/IS; saves the scores as JSON
# under a per-epoch folder; then deletes the generated images.

import argparse
import os
import time
import shutil
import json
import random
import contextlib
from pathlib import Path

import cv2
import numpy as np
import torch
import torch_fidelity
from tqdm import tqdm

from util import misc
from models.vae import AutoencoderKL
from models import xar


def parse_args():
    """Declare the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace holding one attribute per option declared below.
    """
    parser = argparse.ArgumentParser('Generate xAR samples + compute FID/IS')

    # (flag, add_argument keyword arguments), registered in this exact order so
    # the generated --help listing keeps its layout.
    option_table = (
        # Model / checkpoint selection
        ('--model', {'type': str, 'required': True, 'help': 'Which xar.* constructor to use'}),
        ('--model_ckpt', {'type': str, 'required': True, 'help': 'Path to base .pth weights'}),
        ('--model_ckpt_aux', {'type': str, 'required': True,
                              'help': 'Path to aux .pth weights (checkpoint-{epoch}.pth)'}),
        ('--w', {'type': float, 'required': True, 'help': 'Merge weight'}),
        ('--cfg', {'type': float, 'required': True, 'help': 'Classifier-free guidance scale'}),
        ('--vae_path', {'type': str, 'required': True, 'help': 'KL-VAE checkpoint path'}),
        # Sampling configuration
        ('--num_images', {'type': int, 'default': 50000, 'help': 'Total images to generate'}),
        ('--batch_size', {'type': int, 'default': 64, 'help': 'Images per device per forward pass'}),
        ('--flow_steps', {'type': int, 'default': 64, 'help': 'Diffusion/flow steps'}),
        ('--img_size', {'type': int, 'default': 256, 'help': 'Image resolution for stats'}),
        ('--class_num', {'type': int, 'default': 1000, 'help': 'Number of classes'}),
        ('--device', {'type': str, 'default': 'cuda', 'help': 'Torch device'}),
        # Metric computation
        ('--fid_stats', {'type': str, 'default': None, 'help': 'Precomputed FID stats file'}),
        ('--no_isc', {'action': 'store_true', 'help': 'Skip Inception Score'}),
        ('--cuda', {'action': 'store_true', 'help': 'Use GPU for metrics'}),
        ('--seed', {'type': int, 'default': 42, 'help': 'Random seed for reproducibility'}),
        # Distributed launch settings (consumed by misc.init_distributed_mode)
        ('--dist_url', {'default': 'env://'}),
        ('--world_size', {'type': int, 'default': 1}),
        ('--local_rank', {'type': int, 'default': -1}),
        ('--dist_on_itp', {'action': 'store_true', 'default': False, 'help': 'ITP cluster flag'}),
        # Output locations
        ('--out_dir', {'type': str, 'default': 'temp_images', 'help': 'Folder for generated images'}),
        ('--out_dir_fid', {'type': str, 'required': True, 'help': 'Folder to save FID/IS JSON'}),
    )
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    return parser.parse_args()


def load_checkpoint(path):
    """Read a serialized checkpoint from ``path`` with its tensors placed on the CPU.

    Args:
        path: Location of the checkpoint file handed to ``torch.load``.

    Returns:
        Whatever object was stored in the file.

    NOTE: ``torch.load`` relies on pickle, so only load checkpoints from
    sources you trust.
    """
    cpu_device = torch.device('cpu')
    checkpoint = torch.load(path, map_location=cpu_device)
    return checkpoint


def prepare_labels(num_images, class_num, batch_size, world):
    """Build the class-label schedule, padded to a whole number of global batches.

    Labels are laid out class by class: ``num_images // class_num`` copies of
    every class id in ascending order. If ``num_images`` is not a multiple of
    ``class_num``, the leftover slots are filled with class ids
    ``0 .. remainder - 1`` so that exactly ``num_images`` real labels exist.
    Zero-valued padding is then appended so the total length is a multiple of
    ``batch_size * world``; every padded entry therefore sits at an index
    ``>= num_images`` and can be recognised (and dropped) by index alone.

    Args:
        num_images: Number of real labels to produce (must be >= 0).
        class_num: Number of distinct class ids (must be > 0).
        batch_size: Per-process batch size.
        world: Number of processes sharing each global batch.

    Returns:
        Tuple ``(labels, batches, bs_world)``: the 1-D ``torch.long`` label
        tensor including padding, the number of global batches, and the
        global batch size ``batch_size * world``.

    Raises:
        ValueError: If ``num_images`` is negative, or ``class_num`` or
            ``batch_size * world`` is not positive.
    """
    if num_images < 0:
        raise ValueError(f"num_images must be >= 0, got {num_images}")
    if class_num <= 0:
        raise ValueError(f"class_num must be > 0, got {class_num}")
    bs_world = batch_size * world
    if bs_world <= 0:
        raise ValueError(f"batch_size * world must be > 0, got {bs_world}")

    per, remainder = divmod(num_images, class_num)
    labels = torch.tensor(np.arange(class_num).repeat(per), dtype=torch.long)
    if remainder:
        # Top up with the first `remainder` classes; without this the schedule
        # is shorter than num_images and padding would be mistaken for real labels.
        extra = torch.tensor(np.arange(remainder), dtype=torch.long)
        labels = torch.cat([labels, extra], 0)

    # Number of entries needed to reach the next multiple of bs_world (0 if aligned).
    pad = -labels.numel() % bs_world
    if pad:
        labels = torch.cat([labels, torch.zeros(pad, dtype=torch.long)], 0)
    batches = labels.numel() // bs_world
    return labels, batches, bs_world


def set_seed(seed):
    """Seed every random number generator this script uses and pin cuDNN behaviour.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    then turns cuDNN autotuning off and requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _barrier_if_distributed():
    """Wait for all ranks; do nothing when no process group has been initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def _merge_state_dicts(base_state, aux_state, w):
    """Return ``(1 + w) * base - w * aux`` per key, keeping base where that is not possible."""
    merged = {}
    for key, base_t in base_state.items():
        aux_t = aux_state.get(key)
        # Fall back to the base tensor when the aux checkpoint lacks the key
        # or either side contains NaN/Inf values.
        mergeable = (
            aux_t is not None
            and torch.isfinite(base_t).all()
            and torch.isfinite(aux_t).all()
        )
        merged[key] = ((1 + w) * base_t - w * aux_t) if mergeable else base_t
    return merged


@torch.inference_mode()
def main():
    """Generate images from a merged xAR model and record FID/IS for them.

    Steps:
      1. Parse arguments and initialise distributed mode via ``misc``.
      2. On rank 0, wipe ``--out_dir`` and make sure ``--out_dir_fid`` exists.
      3. Build the VAE and the xAR model, load the base and aux checkpoints
         and load the merged weights ``(1 + w) * base - w * aux``.
      4. Sample images batch by batch (each rank handles its own slice of
         every global batch) and write them as PNG files.
      5. On rank 0, compute FID (and Inception Score unless ``--no_isc``)
         with ``torch_fidelity``, write the result to
         ``<out_dir_fid>/<epoch>/cfg=<cfg>_w=<w>.json`` and delete the images.

    Raises:
        NotImplementedError: If no FID statistics file is given and there is
            no default for the requested image size.
        OSError: If a generated image cannot be written to disk.
    """
    args = parse_args()

    # --- quietly initialize distributed ---
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
        misc.init_distributed_mode(args)
    rank, world = misc.get_rank(), misc.get_world_size()
    # ---------------------------------------

    # Resolve the reference statistics before any expensive work. Doing this on
    # every rank means all ranks fail together instead of rank 0 failing only
    # after the full generation pass.
    stats = args.fid_stats or ('fid_stats/adm_in256_stats.npz' if args.img_size == 256 else None)
    if stats is None:
        raise NotImplementedError(f"No default stats for img_size={args.img_size}")

    set_seed(args.seed)

    out_root = Path(args.out_dir)
    fid_root = Path(args.out_dir_fid)
    if rank == 0:
        # only clear the image directory
        shutil.rmtree(out_root, ignore_errors=True)
        out_root.mkdir(parents=True, exist_ok=True)
        # leave fid_root untouched, but ensure it exists
        fid_root.mkdir(parents=True, exist_ok=True)

    vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=args.vae_path)
    vae.eval().to(args.device)
    for p in vae.parameters():
        p.requires_grad = False

    assert args.model in xar.__dict__, f"Model '{args.model}' not found"
    xar_model = xar.__dict__[args.model](
        img_size=args.img_size, vae_stride=16, patch_size=1,
        vae_embed_dim=16, class_num=args.class_num,
        attn_dropout=0.1, proj_dropout=0.1
    ).to(args.device)

    base_ckpt = load_checkpoint(args.model_ckpt)
    base_state = base_ckpt.get('state_dict', base_ckpt)
    print(args.model_ckpt_aux)
    aux_ckpt = load_checkpoint(args.model_ckpt_aux)
    # For the aux model prefer the 'model_ema' entry, then 'state_dict',
    # then treat the checkpoint itself as the state dict.
    aux_state = aux_ckpt.get('model_ema', aux_ckpt.get('state_dict', aux_ckpt))

    w = args.w
    merged = _merge_state_dicts(base_state, aux_state, w)

    xar_model.load_state_dict(merged, strict=True)
    xar_model.eval()
    # NOTE(review): device_ids uses args.local_rank (default -1); confirm that the
    # launcher or misc.init_distributed_mode sets it before multi-process runs.
    model = (torch.nn.parallel.DistributedDataParallel(
        xar_model, device_ids=[args.local_rank], output_device=args.local_rank
    ) if world > 1 else xar_model)
    # sample_tokens lives on the underlying module, not on the DDP wrapper.
    sampler = (model.module if world > 1 else model).sample_tokens

    cfg = args.cfg
    labels, total_batches, bs_world = prepare_labels(
        args.num_images, args.class_num, args.batch_size, world)

    tag = f"aux_{Path(args.model_ckpt_aux).stem}_w{w:.3f}_cfg{cfg:.3f}"
    img_dir = out_root / tag
    if rank == 0:
        img_dir.mkdir(parents=True)
    # Other ranks must not start writing before rank 0 has finished clearing
    # out_root and creating img_dir.
    _barrier_if_distributed()

    # NOTE(review): every rank re-seeds with the same value here, so ranks share
    # the same RNG stream and differ only in their labels; confirm this is intended.
    set_seed(args.seed)

    it = tqdm(
        range(total_batches),
        desc=f"Gen-{tag}",
        dynamic_ncols=True,
        leave=False
    ) if rank == 0 else range(total_batches)

    for b in it:
        # This rank's slice of global batch b.
        idx0 = b * bs_world + rank * args.batch_size
        if idx0 >= args.num_images:
            break
        idx1 = idx0 + args.batch_size
        lbl = labels[idx0:idx1].to(args.device)
        with torch.cuda.amp.autocast():
            z = sampler(num_steps=args.flow_steps, cfg=cfg, label=lbl)
            # 0.2325: presumably the latent scaling factor used in training; confirm.
            imgs = vae.decode(z / 0.2325)
            imgs = (imgs + 1) / 2
        arr = imgs.mul(255).clamp_(0, 255).byte().cpu().permute(0, 2, 3, 1).numpy()
        for j, im in enumerate(arr):
            gid = idx0 + j
            if gid >= args.num_images:
                break
            out_path = img_dir / f"{gid:05d}.png"
            # The channel axis is reversed because cv2.imwrite expects BGR order
            # (assumes the decoder output is RGB). imwrite reports failure only
            # through its return value, so check it explicitly.
            if not cv2.imwrite(str(out_path), im[:, :, ::-1]):
                raise OSError(f"Failed to write image: {out_path}")

    # All ranks must have finished writing before rank 0 reads the folder.
    _barrier_if_distributed()
    if rank == 0:
        m = torch_fidelity.calculate_metrics(
            input1=str(img_dir), input2=None,
            fid_statistics_file=stats, cuda=args.cuda,
            isc=not args.no_isc, fid=True, kid=False, prc=False, verbose=False
        )
        fid_val = m['frechet_inception_distance']
        isc_mean = m.get('inception_score_mean')
        isc_std = m.get('inception_score_std')
        summary = f"w = {w:.1f}, cfg = {cfg:.1f} ====>  FID: {fid_val:.4f}"
        if isc_mean is not None:
            # Inception Score is absent when --no_isc is given.
            summary += f", IS: {isc_mean:.4f} ± {isc_std:.4f}"
        print(summary)

        # extract epoch from aux checkpoint filename
        epoch_str = Path(args.model_ckpt_aux).stem.split('-')[-1]
        epoch_dir = fid_root / epoch_str
        epoch_dir.mkdir(parents=True, exist_ok=True)

        # save metrics JSON under the epoch folder
        metrics = {'fid': fid_val, 'w': w, 'cfg': cfg}
        if isc_mean is not None:
            metrics.update({'is_mean': isc_mean, 'is_std': isc_std})
        fname = f"cfg={cfg:.3f}_w={w:.3f}.json"
        with open(epoch_dir / fname, 'w') as fout:
            json.dump(metrics, fout, indent=2)

        # cleanup only the images
        shutil.rmtree(img_dir)

# Script entry point: run the generation + evaluation pipeline only when this
# file is executed directly, not when it is imported.
if __name__ == '__main__':
    main()
