from collections import OrderedDict
import numpy as np
import torch
import os, sys
from tqdm import tqdm
from argparse import ArgumentParser
import wandb
from utils import RNG
import imageio
import datetime
from torch.utils.data import Dataset, TensorDataset, DataLoader, ConcatDataset
import tempfile

from data import cifar10, ffhq256
from vae_helpers import sample_part_images
from fid import calculate_fid_given_paths
import shutil
import json

# Model interfaces
import interface as vae_interface
# Registry mapping a --model_type value to the module providing
# `create_model_from_path` for that model family.
interface_dict = {
    "vae": vae_interface,
}
# Default checkpoint path per dataset and model type, used when --model_path is
# not given. A value of None means no default checkpoint is recorded.
test_path_dict = {
    "cifar10": {
        "vae": "saved_models/9fpka0rb/iter-60000",
    },
    "ffhq256": {
        "vae": None,
    },
}


DEBUG = False
# Timestamped directory name, fixed once at import time.
DEBUG_DIR = f"debug/{datetime.datetime.now().strftime('%y%m%d_%H%M%S')}"
DEBUG_COUNTER = 0
DEBUG_LABEL = ""
if DEBUG:
    INCEPTION_N = 5
else:
    INCEPTION_N = int(5e4)
INCEPTION_SPLITS = 10
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


def tensor2png(img_tensor, path):
    """Write a single channel-first image tensor with values in [0, 1] to `path`.

    The tensor is moved to the CPU, permuted from (C, H, W) to (H, W, C) and
    quantised to 8 bits before being handed to ``imageio.imwrite``.

    Quantisation rounds to the nearest level and clips to [0, 255]. Plain
    truncation can turn float round-off (``k / 255 * 255`` landing just below
    ``k``) into ``k - 1``. Without clipping, values outside [0, 1] (e.g. from a
    model's output) would wrap around when cast to uint8.
    """
    img_numpy = img_tensor.detach().cpu().permute(1, 2, 0).numpy()
    img_numpy = np.clip(np.rint(img_numpy * 255), 0, 255).astype(np.uint8)
    imageio.imwrite(path, img_numpy)


def sample_mask(opt, img, categories=None, seed=None, *args, **kwargs):
    """Sample one mask per image in `img` via ``sample_part_images``.

    opt: options object forwarded to ``sample_part_images``; it should provide
        ``conditioning``, ``max_patches``, ``patch_size_frac`` and, for foveal
        conditioning only, ``kls``.
    img: batch of images; it is permuted with (0, 2, 3, 1) before being passed
        on (presumably NCHW -> NHWC — confirm against sample_part_images).
    categories: a plain int is expanded into an int tensor holding that value
        once per image in the batch; anything else is passed through unchanged.
    seed: if not None, sampling runs inside ``RNG(seed)`` (presumably to make
        the masks reproducible).
    *args, **kwargs: extra arguments forwarded to ``sample_part_images``.

    Returns the last channel of ``sample_part_images``' output as a contiguous
    tensor.
    """
    if isinstance(categories, int):
        categories = torch.ones(len(img)).int() * categories

    def draw():
        channels_last = img.permute(0, 2, 3, 1)
        parts = sample_part_images(opt, channels_last, *args, categories=categories, **kwargs)
        return parts[..., -1].contiguous()

    if seed is None:
        return draw()
    with RNG(seed):
        return draw()


class ScaledTensorDataset(Dataset):
    """Map-style dataset over a tensor, dividing every item by 255.

    Each item is returned as ``(item / 255., 0)``: the constant 0 is a
    placeholder label so batches have an ``(input, target)`` shape; only the
    first element is meaningful.
    """

    def __init__(self, tensor):
        # Stored by reference; scaling happens lazily in __getitem__.
        self.tensor = tensor

    def __len__(self):
        return len(self.tensor)

    def __getitem__(self, index):
        scaled = self.tensor[index] / 255.
        dummy_label = 0
        return scaled, dummy_label


def _select_partition(data, mode):
    """Return the [0, 1]-scaled dataset for `mode` from the ``(trX, vaX, teX)`` triple."""
    trX, vaX, teX = data
    if mode == "valid":
        return ScaledTensorDataset(vaX)
    if mode == "test":
        return ScaledTensorDataset(teX)
    # mode == "all": every split, concatenated in train/valid/test order.
    return ConcatDataset([ScaledTensorDataset(trX),
                          ScaledTensorDataset(vaX),
                          ScaledTensorDataset(teX)])


def _write_fid_images(H, dataset, model, sample_dir, n_categories, N, ref_dir):
    """Write inpainted completions and, if `ref_dir` is given, the ground-truth images.

    Completions for mask category ``c`` are written to
    ``sample_dir/<c>/<index>.png``. When `ref_dir` is not None, every dataset
    image is written to ``ref_dir/<index>.png``. At most `N` images (all of
    them if `N` is None) are inpainted per category.

    Returns the number of images inpainted per category.
    """
    sample_i = 0
    ref_i = 0
    loader = DataLoader(dataset, batch_size=H.batch_size, shuffle=False)
    for batch in tqdm(loader, desc="Data generation"):
        img_batch = batch[0].to(device)

        if ref_dir is not None:
            for img in img_batch:
                tensor2png(img, os.path.join(ref_dir, f"{ref_i}.png"))
                ref_i += 1

        if N is None or sample_i < N:
            # Truncate the final batch so exactly N completions are produced.
            todo = img_batch if N is None else img_batch[:N - sample_i]
            for category in range(n_categories):
                # Deterministic seed, distinct for every (batch offset, category).
                seed = sample_i * n_categories + category
                mask = sample_mask(H, todo, categories=category, seed=seed).to(device)
                samples = model.inpaint(todo, mask)
                for b, img in enumerate(samples):
                    path = os.path.join(sample_dir, str(category), f"{sample_i + b}.png")
                    tensor2png(img, path)
            sample_i += len(todo)

        # Once enough completions exist, keep iterating only if the reference
        # images still need to be dumped.
        if N is not None and sample_i >= N and ref_dir is None:
            break
    return sample_i


def _score_fid_samples(path2, sample_dir, n_categories, n_samples):
    """Compute per-category FIDs, then one FID over all categories pooled together."""
    scores = {}
    for category in tqdm(range(n_categories), desc="FID computation"):
        cat_dir = os.path.join(sample_dir, str(category))
        key = f'fid-{category}-{len(os.listdir(cat_dir))}'
        scores[key] = calculate_fid_given_paths([path2, cat_dir],
                                                cache_path1=True, cache_path2=False)
        # Move this category's images up into sample_dir (prefixed with the
        # category to avoid name clashes) so the final call scores them all.
        for fname in os.listdir(cat_dir):
            shutil.move(os.path.join(cat_dir, fname), os.path.join(sample_dir, f"{category}_{fname}"))
    scores[f'fid-{n_samples}'] = calculate_fid_given_paths([path2, sample_dir],
                                                           cache_path1=True, cache_path2=False)
    return scores


@torch.no_grad()
def fid(H, data, model, mode, N):
    """Compute FID scores of inpainted completions against a dataset partition.

    H: options object. This function reads ``H.conditioning``, ``H.dataset`` and
        ``H.batch_size``. It is also forwarded to ``sample_mask``, which
        additionally expects ``max_patches``, ``patch_size_frac`` and, for
        foveal conditioning, ``kls``.
    data: ``(trX, vaX, teX)`` image tensors, laid out (N, C, H, W) with values
        in [0, 255] as prepared by the ``__main__`` block.
    model: object whose ``inpaint(images, mask)`` returns a batch of images
        with values in [0, 1].
    mode: "valid", "test" or "all" (all three splits concatenated).
    N: number of images to inpaint per mask category, or None for the whole
        partition.

    Returns a dict with one ``'fid-<category>-<count>'`` entry per mask
    category, plus ``'fid-<n_samples>'`` for all categories pooled.

    Raises FileExistsError if the reference image directory exists without its
    statistics file (e.g. left behind by an interrupted external process).
    Temporary sample and reference directories are removed even on failure.
    """
    assert mode in ["valid", "test", "all"]
    dataset = _select_partition(data, mode)

    base_dir = os.path.join(os.environ.get('TMP_DIR', '.'), "FID")
    os.makedirs(base_dir, exist_ok=True)
    n_categories = {'patches': 6, 'patches-missing': 6, 'blank': 1, 'foveal': 6}[H.conditioning]

    # Reference side of the FID: prefer cached statistics for the whole dataset.
    # Otherwise use per-partition statistics, dumping the partition's images to
    # dir2 when no statistics file exists yet.
    all_stats = os.path.join(base_dir, f"{H.dataset}_all.npz")
    if os.path.exists(all_stats):
        stats2, dir2 = all_stats, None
    else:
        stats2 = os.path.join(base_dir, f"{H.dataset}_{mode}.npz")
        dir2 = os.path.join(base_dir, f"{H.dataset}_{mode}")
    print(f"[FID] dataset_dir = {dir2}")
    print(f"[FID] dataset_stats = {stats2}")

    write_reference = not os.path.exists(stats2)
    if write_reference and os.path.exists(dir2):
        msg = f"{stats2} doesn't exist but {dir2} does"
        print(msg)
        raise FileExistsError(msg)

    # Created only after the checks above so a failed check leaks nothing.
    sample_dir = tempfile.mkdtemp(dir=base_dir, prefix=f"{mode}_")
    try:
        if write_reference:
            os.makedirs(dir2)
        for c in range(n_categories):
            os.makedirs(os.path.join(sample_dir, str(c)))

        n_samples = _write_fid_images(H, dataset, model, sample_dir, n_categories, N,
                                      dir2 if write_reference else None)

        # NOTE(review): presumably cache_path1=True makes calculate_fid_given_paths
        # persist dir2's statistics (to stats2) for later runs — confirm in fid.py.
        path2 = stats2 if os.path.exists(stats2) else dir2
        return _score_fid_samples(path2, sample_dir, n_categories, n_samples)
    finally:
        # Always clean up, so an interrupted run does not leave a partial dir2
        # that makes the next run fail the existence check above.
        shutil.rmtree(sample_dir, ignore_errors=True)
        if dir2 is not None and os.path.exists(dir2):
            shutil.rmtree(dir2, ignore_errors=True)


if __name__ == "__main__":
    # Command-line entry point: load a dataset partition and a model, compute
    # inpainting FID scores, and store them as JSON next to the checkpoint.
    parser = ArgumentParser()
    parser.add_argument('--dataset', required=True, choices=["cifar10", "ffhq256"])
    parser.add_argument('--partition', required=True, choices=["test", "valid", "all"],
                        help="Chooses the dataset partition to compute FID on.")
    parser.add_argument('--model_type', type=str, choices=["np", "pluralistic", "vae", "ce"])
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--conditioning', type=str,
                        choices=['patches', 'patches-missing', 'blank', 'foveal'], default='patches')
    parser.add_argument('--max_patches', type=int, default=5)
    parser.add_argument('--patch_size_frac', type=float, default=0.35,
                        help="Patch width as fraction of image width.")
    parser.add_argument('--data_root', type=str, default='./')
    parser.add_argument('--batch_size', type=int, default=None)
    parser.add_argument('-N', type=int, default=None)
    # Kept as a module-level name: other code in this file may refer to `args`.
    args = parser.parse_args()

    # Validate the model selection before the (potentially slow) dataset load.
    if args.model_type is None:
        parser.error("--model_type is required")
    if args.model_type not in interface_dict:
        parser.error(f"no interface registered for --model_type {args.model_type!r} "
                     f"(available: {', '.join(interface_dict)})")
    model_path = args.model_path
    if model_path is None:
        model_path = test_path_dict[args.dataset].get(args.model_type)
        if model_path is None:
            parser.error(f"no default checkpoint for model type {args.model_type!r} on "
                         f"{args.dataset!r}; pass --model_path")

    # Batch-size defaults ("np" only applies once such an interface is registered).
    if args.model_type == "np":
        args.batch_size = 1
    if args.batch_size is None:
        if args.dataset == "cifar10":
            args.batch_size = 32
        elif args.model_type == "vae":
            args.batch_size = 4
        else:
            args.batch_size = 16
    print(args)

    ## Load the dataset splits.
    if args.dataset == "cifar10":
        (trX, _), (vaX, _), (teX, _) = cifar10(args.data_root, one_hot=False)
    elif args.dataset == 'ffhq256':
        trX, vaX, teX = ffhq256(args.data_root)
    # Convert (N, H, W, C) arrays to (N, C, H, W) tensors. The loaders are assumed
    # to return uint8 pixels in [0, 255]; ScaledTensorDataset maps them to [0, 1].
    trX = torch.as_tensor(trX).permute(0, 3, 1, 2)
    vaX = torch.as_tensor(vaX).permute(0, 3, 1, 2)
    teX = torch.as_tensor(teX).permute(0, 3, 1, 2)
    print("Dataset loaded")

    ## Results go next to the checkpoint, named after it, the conditioning and the partition.
    interface = interface_dict[args.model_type]
    results_name = f"{os.path.splitext(os.path.basename(model_path))[0]}_fid_{args.conditioning}_{args.partition}"
    results_path = os.path.join(os.path.dirname(model_path), f"{results_name}.json")

    ## Prepare the inpainting model.
    inpainting_model = interface.create_model_from_path(model_path)
    inpainting_model.to(device)
    print("Inpainting model loaded")

    results = fid(args, (trX, vaX, teX), inpainting_model, mode=args.partition, N=args.N)
    with open(results_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Stored the results at {results_path}")
