import numpy as np
import torch
import os, sys
from tqdm import tqdm
from argparse import ArgumentParser
from utils import RNG
import imageio
import torchvision.utils as vutils
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import shutil
from pathlib import Path

from data import cifar10, ffhq256, xray
from vae_helpers import sample_part_images

# Model interfaces
import interface as vae_interface
# Registry of model interface modules, keyed by the --model_type value.
interface_dict = dict(vae=vae_interface)

# Default checkpoint per dataset and model type; consulted when --model_path is
# not given. None means no default checkpoint is registered for that pair.
test_path_dict = {
    "cifar10": dict(vae="saved_models/9fpka0rb/iter-60000"),
    "ffhq256": dict(vae=None),
}


# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def tensor2png(img_tensor, path):
    """Save a (C, H, W) float image tensor with values in [0, 1] as an 8-bit image.

    Values are clamped to [0, 1] first. Casting an out-of-range float directly to
    ``np.uint8`` does not saturate; it wraps around or is undefined. Model outputs
    that overshoot slightly would otherwise show up as speckle noise.

    Args:
        img_tensor: image tensor laid out as (C, H, W); may live on any device.
        path: destination file path, passed straight to ``imageio.imwrite``.
    """
    img = img_tensor.detach().cpu().clamp(0, 1).permute(1, 2, 0)
    # Round instead of truncating, so that e.g. 0.999 maps to 255 rather than 254.
    img_numpy = (img.numpy() * 255).round().astype(np.uint8)
    imageio.imwrite(path, img_numpy)


def sample_mask(opt, img, categories=None, seed=None, *args, **kwargs):
    """Sample one mask per image of ``img`` using ``sample_part_images``.

    ``opt`` should provide ``conditioning``, ``max_patches``, ``patch_size_frac``
    and, for foveal conditioning only, ``kls``.

    ``img`` is permuted with ``(0, 2, 3, 1)`` before being handed to
    ``sample_part_images``, so it is presumably a (B, C, H, W) batch (TODO
    confirm). Only the last channel of the sampled part images is returned, as a
    contiguous tensor.

    An integer ``categories`` is broadcast to one int tensor entry per image. When
    ``seed`` is given, sampling runs inside the ``RNG(seed)`` context.
    """
    if isinstance(categories, int):
        categories = torch.full((len(img),), categories, dtype=torch.int)

    def _draw():
        parts = sample_part_images(
            opt, img.permute(0, 2, 3, 1), *args,
            categories=categories, **kwargs
        )
        return parts[..., -1].contiguous()

    if seed is None:
        return _draw()
    with RNG(seed):
        return _draw()

def sample_interesting_mask_from_shape(opt, shape, categories=None, seed=None, *args,
                                       n_trials=50, probe_categories=5, **kwargs):
    """Sample a mask for ``shape`` using the seed that gives the "nicest" probe mask.

    The function tries ``n_trials`` seeds derived deterministically from ``seed``.
    For each one it samples a probe mask with ``probe_categories`` patches and
    keeps the seed whose probe mask has the largest ``mask.sum()`` (presumably the
    number of observed pixels). The final mask is then drawn with ``categories``
    under that best seed. Calls with different ``categories`` but the same
    ``seed`` therefore share one search result.

    Args:
        opt: options forwarded to ``sample_mask_from_shape``.
        shape: image shape without batch dimension, e.g. (C, H, W).
        categories: categories for the final mask.
        seed: starting point of the seed search. ``None`` starts the search at 0;
            it previously raised a TypeError.
        n_trials: number of candidate seeds to try (keyword-only).
        probe_categories: categories used for the probe masks (keyword-only).

    Returns:
        The mask returned by ``sample_mask_from_shape`` for the best seed. If no
        probe mask has a positive sum, the incoming ``seed`` is used.
    """
    search_seed = 0 if seed is None else seed
    best_seed = seed
    most_obs = 0
    for _ in range(n_trials):
        # Arbitrary transform to try various seeds while depending on the initial seed.
        search_seed = ((search_seed + 13) * 123) % 7890
        mask = sample_mask_from_shape(opt, shape, *args, categories=probe_categories,
                                      seed=search_seed, **kwargs)
        obs = mask.sum()
        if obs > most_obs:
            best_seed, most_obs = search_seed, obs
    return sample_mask_from_shape(opt, shape, *args, categories=categories,
                                  seed=best_seed, **kwargs)
    


def sample_mask_from_shape(opt, shape, categories=None, seed=None, *args, **kwargs):
    """Sample a single mask for an image of the given ``shape``.

    ``opt`` should provide ``conditioning``, ``max_patches``, ``patch_size_frac``
    and, for foveal conditioning only, ``kls``.

    ``shape`` must have exactly three entries, i.e. no batch dimension. An
    all-zero placeholder batch of size one is passed to ``sample_mask``, and the
    batch dimension is dropped from the result.
    """
    # Only shapes without a batch dimension are accepted.
    assert len(shape) == 3
    placeholder_batch = torch.zeros(1, *shape)
    batched_mask = sample_mask(opt, placeholder_batch, *args,
                               categories=categories, seed=seed, **kwargs)
    return batched_mask[0]


def _resolve_model_path(args):
    """Return ``args.model_path``, or the default from ``test_path_dict`` if it is unset.

    Raises:
        ValueError: if no path was given and no default checkpoint is registered
            for ``(args.dataset, args.model_type)``. For example, ffhq256/vae is
            registered as ``None``.
    """
    if args.model_path is not None:
        return args.model_path
    model_path = test_path_dict.get(args.dataset, {}).get(args.model_type)
    if model_path is None:
        raise ValueError(
            f"No default checkpoint registered for dataset={args.dataset!r}, "
            f"model_type={args.model_type!r}; pass --model_path explicitly."
        )
    return model_path


def _load_test_image(args):
    """Load test image ``args.img_idx`` of ``args.dataset`` as a (C, H, W) float tensor in [0, 1].

    Raises:
        ValueError: if ``args.dataset`` is not one of cifar10, ffhq256, xray.
    """
    if args.dataset == "cifar10":
        _, _, (teX, _) = cifar10(args.data_root, one_hot=False)
    elif args.dataset == "ffhq256":
        _, _, teX = ffhq256(args.data_root)
    elif args.dataset == "xray":
        _, _, teX = xray(args.data_root)
        teX = ImageFolder(teX, transforms.ToTensor())  # normalized to 0, 1
        print("Dataset loaded")
        # ImageFolder items are (image, label) pairs.
        return teX[args.img_idx][0]
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")
    print("Dataset loaded")
    # cifar10/ffhq256 items are (H, W, C) uint8 arrays with values in [0, 255].
    # Convert only the requested image. Converting the whole test set to float
    # first would cost several times the dataset's memory just to read one item.
    return torch.as_tensor(teX[args.img_idx]).permute(2, 0, 1) / 255.


def main(args):
    """Render a qualitative inpainting grid for one test image and save it as a PNG.

    For every patch count in ``range(args.max_patches)``, the grid gets the
    original image, the masked image, and ``args.n`` inpainted samples. The grid is
    laid out with ``args.n + 2`` images per row, so that is one row per patch
    count, assuming ``inpaint`` returns a single image per call.

    The PNG is written to
    ``<model_path without extension>_qual/<conditioning>_<img_idx>.png``. If that
    file already exists, nothing is loaded and the function returns early.

    Raises:
        ValueError: for a model type without a registered interface, or when no
            checkpoint path can be resolved.
    """
    print(args.model_path)
    ## Prepare the model path and type objects
    assert args.model_type is not None
    if args.model_type not in interface_dict:
        raise ValueError(
            f"No model interface registered for model_type={args.model_type!r} "
            f"(available: {sorted(interface_dict)})"
        )
    interface = interface_dict[args.model_type]
    model_path = _resolve_model_path(args)
    results_dir = f"{os.path.splitext(model_path)[0]}_qual"
    os.makedirs(results_dir, exist_ok=True)
    results_path = os.path.join(results_dir, f"{args.conditioning}_{args.img_idx}.png")
    if os.path.exists(results_path):
        print(f"{results_path} already exists. Skipping ...")
        return
    ## Prepare the inpainting model
    inpainting_model = interface.create_model_from_path(model_path)
    inpainting_model.to(device)
    print("Inpainting model loaded")
    ## Load the test image as (C, H, W) with values in [0, 1]
    img = _load_test_image(args).to(device)
    results = []
    for categories in tqdm(range(args.max_patches)):
        ## Prepare the mask; it is seeded by the image index so reruns are reproducible
        mask = sample_interesting_mask_from_shape(args, img.shape, categories=categories,
                                                  seed=args.img_idx)
        mask = mask.to(device)
        ## Prepare the masked image
        masked_img = mask.unsqueeze(0) * img
        with torch.no_grad():
            inpainted_batch = [
                inpainting_model.inpaint(img.clone().unsqueeze(0),
                                         mask.clone().unsqueeze(0)).cpu()
                for _ in range(args.n)
            ]
            results.append(img.unsqueeze(0).cpu())
            results.append(masked_img.unsqueeze(0).cpu())
            results.extend(inpainted_batch)
    results = torch.cat(results, dim=0)
    # make_grid's old `range=` argument only took effect with normalize=True, so it
    # never did anything here, and newer torchvision rejects it with a TypeError.
    # Clamp explicitly so that out-of-range inpaintings cannot wrap around when the
    # grid is cast to uint8.
    to_save = vutils.make_grid(results.clamp(0, 1), nrow=args.n + 2)
    tensor2png(to_save, results_path)
    print(f"Saved to {results_path}")


def _model_type_from_dir(name):
    """Map a ``checkpoints_shared`` sub-directory name to its model family.

    Names starting with "vae", "ce" or "np" map to that prefix; any other name is
    returned unchanged.
    """
    for prefix in ("vae", "ce", "np"):
        if name.startswith(prefix):
            return prefix
    return name


def _iter_checkpoints(model_type, run_dir):
    """Yield the checkpoint paths stored in ``run_dir`` for the given model family.

    Layouts, relative to ``checkpoints_shared``:
      vae:         model_type/dataset/run_id/iter-*/ (directories; ``*_qual``
                   output directories are skipped)
      pluralistic: model_type/dataset/run_id/*.pth
      ce, np:      model_type/dataset/run_id/*.pt
    """
    if model_type == "vae":
        for model_path in run_dir.iterdir():
            if not model_path.is_dir() or model_path.name.endswith("_qual"):
                continue
            assert model_path.name.startswith("iter")
            yield model_path
    elif model_type == "pluralistic":
        yield from run_dir.glob("*.pth")
    elif model_type in ("ce", "np"):
        yield from run_dir.glob("*.pt")
    else:
        raise AssertionError(f"Unexpected model type: {model_type!r}")


def _render_images(args):
    """Run ``main`` for image indices ``0 .. args.num_images - 1``."""
    for img_idx in range(args.num_images):
        args.img_idx = img_idx
        main(args)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--dataset', choices=["cifar10", "ffhq256", "xray"])
    parser.add_argument('--model_type', type=str, choices=["np", "pluralistic", "vae"])
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--conditioning', type=str,
                        choices=['patches', 'patches-missing', 'blank', 'foveal'], default='patches')
    parser.add_argument('--max_patches', type=int, default=10)
    parser.add_argument('-n', type=int, default=10)
    parser.add_argument('--patch_size_frac', type=float, default=0.35,
                        help="Patch width as fraction of image width.")
    parser.add_argument('--data_root', type=str, default='./')
    parser.add_argument('--all', action="store_true")
    parser.add_argument('--num_images', type=int, default=10,
                        help="Number of test images (indices 0..num_images-1) to render per model.")
    args = parser.parse_args()
    print(args)

    if not args.all:
        assert args.dataset is not None
        _render_images(args)
    else:
        for conditioning in ['patches', 'patches-missing']:
            args.conditioning = conditioning
            for model_type_dir in Path("checkpoints_shared").iterdir():
                if not model_type_dir.is_dir():
                    continue  # ignore stray files next to the model directories
                model_type = _model_type_from_dir(model_type_dir.name)
                assert model_type in ["np", "pluralistic", "vae", "ce"]
                if model_type not in interface_dict:
                    # main() cannot load this family, so skip it instead of
                    # aborting the whole sweep.
                    print(f"Skipping {model_type_dir}: no interface registered "
                          f"for model type {model_type!r}")
                    continue
                for dataset_dir in model_type_dir.iterdir():
                    if not dataset_dir.is_dir():
                        continue
                    dataset = dataset_dir.name
                    assert dataset in ["cifar10", "ffhq256"]
                    for run_dir in dataset_dir.iterdir():
                        if not run_dir.is_dir():
                            continue
                        args.model_type = model_type
                        args.dataset = dataset
                        print(f"\n>>>> {conditioning} / {model_type_dir.name} / {args.dataset} <<<<")
                        for model_path in _iter_checkpoints(model_type, run_dir):
                            args.model_path = model_path
                            _render_images(args)
