import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from argparse import ArgumentParser
from pipeline_metaquery import MetaQueryPipeline
import torch
from datasets import load_dataset
from tqdm import tqdm
from trainer_utils import find_newest_checkpoint
from torchvision.transforms import v2

preprocess_image = v2.Compose(
    [
        v2.Lambda(lambda img: img.convert("RGB")),
        v2.Resize(384, interpolation=v2.InterpolationMode.BICUBIC),
        v2.CenterCrop(384),
    ]
)


def sample_metaquery(pipeline, images, args):
    return pipeline(
        image=images,
        caption="",
        num_inference_steps=args.num_inference_steps,
        guidance_scale=args.guidance_scale,
        image_guidance_scale=args.image_guidance_scale,
        enable_progress_bar=False,
    ).images


def main(args):
    args.checkpoint_path = find_newest_checkpoint(args.checkpoint_path)
    pipeline = MetaQueryPipeline.from_pretrained(
        args.checkpoint_path,
        ignore_mismatched_sizes=True,
        _gradient_checkpointing=False,
        torch_dtype=torch.bfloat16,
    )
    pipeline = pipeline.to(device="cuda", dtype=torch.bfloat16)
    sample_fn = sample_metaquery
    os.makedirs(args.output_dir, exist_ok=True)

    dataset = load_dataset(
        "ILSVRC/imagenet-1k",
        split="validation",
        cache_dir=args.dataset_folder,
        trust_remote_code=True,
        num_proc=12,
    )
    dataset = dataset.remove_columns("label")
    dataset = dataset.select(
        range(args.start_idx, args.end_idx if args.end_idx != -1 else len(dataset))
    )

    for i in tqdm(range(0, len(dataset), args.batch_size)):
        data = dataset[i : i + args.batch_size]["image"]
        data = [preprocess_image(img) for img in data]
        images = sample_fn(pipeline, data, args)

        for j, image in enumerate(images):
            image.save(f"{args.output_dir}/{args.start_idx+i+j:05d}.jpg")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--dataset_folder", type=str, default="/path/to/dataset_folder")
    parser.add_argument("--start_idx", type=int, default=0)
    parser.add_argument("--end_idx", type=int, default=-1)

    parser.add_argument("--output_dir", type=str, default="/path/to/output_dir")
    parser.add_argument(
        "--checkpoint_path", type=str, default="/path/to/checkpoint_path"
    )
    parser.add_argument("--guidance_scale", type=float, default=3.0)
    parser.add_argument("--image_guidance_scale", type=float, default=3.0)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_inference_steps", type=int, default=30)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()
    main(args)
