import argparse
import random
from pathlib import Path
from PIL import Image
import math

import numpy as np
import torch
import torch.nn.functional as F
import yaml
from tqdm import tqdm

from utils import num_to_groups
from multihead_edm.model import MultiHeadEDM

def image_grid(imgs, rows, cols):
    """
    Tile a list of PIL images into a single ``rows`` x ``cols`` RGB grid.

    Images are placed in row-major order: image ``i`` lands in row
    ``i // cols`` and column ``i % cols``. The cell size is taken from the
    first image; every other image is pasted at that pitch.

    Args:
        imgs: Sequence of PIL images; its length must equal ``rows * cols``.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A new RGB PIL image of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    # The first image defines the size of every grid cell.
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        grid.paste(img, box=(col * w, row * h))
    return grid

def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a list of PIL images.

    A 3-D input is treated as a single image and gets a leading batch axis.
    Values are multiplied by 255, rounded and cast to ``uint8``, so the input
    is presumably channels-last floats in ``[0, 1]`` (confirm against callers).

    Args:
        images: Numpy array with 3 dimensions (one image) or a batch of
            images with the channel axis last.

    Returns:
        A list of PIL images, one per batch entry (a list even for a single
        image). Single-channel inputs become mode ``"L"`` images.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        # Squeeze only the channel axis: a bare squeeze() would also collapse
        # a height or width of 1 and hand PIL an array of the wrong shape.
        pil_images = [Image.fromarray(image.squeeze(axis=-1), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

def main(args):
    """
    Generate class-conditional samples with a trained MultiHeadEDM model.

    Loads the model configuration from ``args.config_path`` and the weights
    from ``<output_dir>/model.pt``, then writes the generated images as PNG
    files to ``<output_dir>/images/<class_index>/<NNNNN>.png``.

    Args:
        args: Parsed command-line namespace providing ``seed``, ``device``,
            ``output_dir``, ``config_path``, ``batch_size``, ``num_samples``
            and ``num_classes``.
    """

    # Set random seeds for reproducibility
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # device
    device = torch.device(args.device)

    # Create work directory if it doesn't exist
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    synth_dir = output_dir / 'images'
    synth_dir.mkdir(parents=True, exist_ok=True)

    # One sub-directory per class label
    for i in range(args.num_classes):
        class_dir = synth_dir / str(i)
        class_dir.mkdir(parents=True, exist_ok=True)

    with open(args.config_path, "r") as f:
        config = yaml.safe_load(f)

    model = MultiHeadEDM(**config)
    model.load_state_dict(torch.load(output_dir / "model.pt",  map_location=args.device))
    model = model.to(device)
    model.eval()

    # Build the flat, class-sorted list of labels to generate. When
    # num_samples is not a multiple of num_classes, the first `remainder`
    # classes receive one extra sample so exactly num_samples images are
    # produced instead of tripping the length check below.
    num_per_class, remainder = divmod(args.num_samples, args.num_classes)
    conditions = []
    for i in range(args.num_classes):
        count = num_per_class + (1 if i < remainder else 0)
        conditions.extend([i] * count)

    # Sanity check (still catches e.g. a negative num_samples)
    assert len(conditions) == args.num_samples

    # Project helper; presumably splits the label list into chunks of at most
    # batch_size labels -- confirm against utils.num_to_groups.
    conditions = num_to_groups(conditions, args.batch_size)

    # Running per-class file index used to name the saved images
    num_per_class_count = [0] * args.num_classes
    with torch.no_grad():
        for y_cond in tqdm(conditions, desc="Generating images"):
            y_torch = torch.tensor(y_cond).to(device)
            y_onehot = F.one_hot(y_torch, num_classes=args.num_classes)

            images = model.inference(y_onehot)
            # Rescale to [0, 255] (model output is presumably in [-1, 1] --
            # confirm) and move channels last for PIL.
            images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()

            # Squeeze only the channel axis for grayscale images; a bare
            # squeeze() would also drop a height or width of 1.
            images_pil = [
                Image.fromarray(image.squeeze(axis=-1), mode="L") if image.shape[-1] == 1
                else Image.fromarray(image, mode="RGB") if image.shape[-1] == 3
                else Image.fromarray(image)
                for image in images_np
            ]

            for img, y in zip(images_pil, y_cond):
                img.save(synth_dir / str(y) / f'{num_per_class_count[y]:05d}.png')
                num_per_class_count[y] += 1

if __name__ == "__main__":
    # Command-line options as (flag, type, default), registered in this order.
    option_table = (
        # runtime
        ("--device", str, "cuda:0"),
        ("--seed", int, 42),
        # locations of the trained model directory and the model config
        ("--output_dir", str, "./expr/mnist/multihead_edm"),
        ("--config_path", str, "./multihead_edm/config_mnist.yaml"),
        # sampling
        ("--batch_size", int, 256),
        ("--num_samples", int, 50000),
        ("--num_classes", int, 10),
    )

    parser = argparse.ArgumentParser()
    for flag, value_type, fallback in option_table:
        parser.add_argument(flag, type=value_type, default=fallback)

    args = parser.parse_args()

    main(args)