import time

from sacred import Ingredient
from pathlib import Path
import torch
from torch.distributions import Normal
from torchvision import transforms

from latents import (LatentSimpleData, MNIST, HorseZebra, LatentMixtureData,
                     LSUN)

# Sacred ingredient that owns the dataset configuration and the captured
# data-loading functions defined in this module.
# NOTE: per the original author this has to come before importing models —
# presumably because model modules look up this ingredient at import time;
# verify before reordering.
experiment_ingredient = Ingredient('experiment')


@experiment_ingredient.named_config
def annealing():
    """Enable annealing with 100 annealing iterations."""
    annealing = True  # forwarded by fetch_datasets to MNIST and the test_* datasets
    annealing_iterations = 100  # forwarded alongside ``annealing``

@experiment_ingredient.named_config
def test_simple():
    """Synthetic ``LatentSimpleData`` datasets with ``latent_shape = [1]``."""
    datasets_name = 'test_simple'
    latent_shape = [1]  # passed as ``shape`` to LatentSimpleData


@experiment_ingredient.named_config
def test_mixture():
    """Synthetic ``LatentMixtureData`` datasets with ``latent_shape = [1]``."""
    datasets_name = 'test_mixture'
    latent_shape = [1]  # passed as ``shape`` to LatentMixtureData


@experiment_ingredient.named_config
def mnist():
    """MNIST stored under ``<parent of the working directory>/datasets``."""
    # Resolved when the config is evaluated, relative to the current
    # working directory (not to this file).
    root = Path('.').resolve().parent / "datasets"
    datasets_name = 'MNIST'


@experiment_ingredient.named_config
def lsun_bedroom():
    """LSUN bedroom (train and val) under ``<parent of cwd>/datasets/lsun_bedroom``."""
    # Resolved when the config is evaluated, relative to the current
    # working directory (not to this file).
    root = Path('.').resolve().parent / "datasets" / "lsun_bedroom"
    datasets_name = 'lsun_bedroom'


@experiment_ingredient.named_config
def horsezebras():
    """Horse2zebra images from ``trainA`` and ``testA`` under ``root``."""
    # Resolved when the config is evaluated, relative to the current
    # working directory (not to this file).
    root = Path('.').resolve().parent / "datasets" / "horse2zebra"
    # One HorseZebra dataset is built per directory, in this order.
    datasets_paths = (root / "trainA", root / "testA")
    datasets_name = 'horse2zebra'

    # Makes data_loader log each dataset's ``n_cleanups`` count.
    data_pre_cleanup = True


@experiment_ingredient.config
def cfg():
    """Default configuration of the ``experiment`` ingredient.

    The named configs in this module override subsets of these values.
    """
    root = None  # data directory; set by the mnist/lsun_bedroom/horsezebras configs
    datasets_paths = None  # per-dataset directories; only read for 'horse2zebra'
    datasets_name = 'test_simple'  # selects the branch taken in fetch_datasets
    latent_shape = [1, 10, 10]  # ``shape`` of the synthetic 'test_*' datasets
    annealing = False  # forwarded to MNIST and the synthetic 'test_*' datasets
    annealing_iterations = None  # forwarded alongside ``annealing``

    data_pre_cleanup = False  # if True, data_loader logs per-dataset cleanup counts


@experiment_ingredient.capture
def fetch_datasets(datasets_name, root, datasets_paths, latent_shape, annealing,
                   annealing_iterations, _log):
    """Build the datasets selected by ``datasets_name``.

    Arguments not passed explicitly are filled in by sacred (``@capture``)
    from the ``experiment`` ingredient config.

    Args:
        datasets_name: one of ``'horse2zebra'``, ``'lsun_bedroom'``,
            ``'MNIST'``, ``'test_simple'`` or ``'test_mixture'``.
        root: data directory, read by ``'lsun_bedroom'`` and ``'MNIST'``.
        datasets_paths: directories for ``'horse2zebra'``; one ``HorseZebra``
            dataset is built per entry.
        latent_shape: ``shape`` of the synthetic ``'test_*'`` datasets.
        annealing: forwarded to ``MNIST`` and the synthetic datasets.
        annealing_iterations: forwarded alongside ``annealing``.
        _log: sacred-provided logger.

    Returns:
        ``(datasets, data_shape)``: a list of datasets (train split first for
        MNIST and LSUN, ``datasets_paths`` order for horse2zebra) and the
        sample shape reported by the first dataset.

    Raises:
        ValueError: if an argument the selected dataset needs is missing.
        NotImplementedError: if ``datasets_name`` is not supported.
    """
    if datasets_name == 'horse2zebra':
        # Fail with a clear message instead of a TypeError/IndexError below.
        if not datasets_paths:
            raise ValueError("Datasets paths should not be None or empty")
        datasets = [HorseZebra(f) for f in datasets_paths]
        data_shape = datasets[0].data_shape
    elif datasets_name in ("celebHQ", "imageNet", "cifar10"):
        # Recognised names whose loaders have not been written yet.
        raise NotImplementedError(
            f"Dataset {datasets_name!r} is not implemented yet")
    elif datasets_name == "lsun_bedroom":
        # str(None) would silently become the bogus path "None".
        if root is None:
            raise ValueError("Root should not be None")
        transform = transforms.Compose([
            transforms.Resize((64, 64)),
            transforms.ToTensor(),
        ])
        # perf_counter is monotonic, so it is the right clock for durations.
        start = time.perf_counter()
        datasets = [LSUN(str(root), classes=[f"bedroom_{split}"],
                         transform=transform,
                         # Every target is mapped to the constant tensor 0.
                         target_transform=lambda x: torch.tensor(0))
                    for split in ['train', 'val']]

        _log.info(f"LSUN initialized. Time taken {time.perf_counter()-start}")

        # First element (presumably the image) of the first train sample.
        data_shape = datasets[0][0][0].shape
    elif datasets_name == "MNIST":
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])
        datasets = [MNIST(root, annealing=annealing,
                          annealing_iterations=annealing_iterations,
                          train=train, download=True, transform=transform,
                          # Every target is mapped to the constant tensor 0.
                          target_transform=lambda x: torch.tensor(0))
                    for train in [True, False]]
        # First element (presumably the image) of the first train sample.
        data_shape = datasets[0][0][0].shape
    elif datasets_name == 'test_simple':
        if latent_shape is None:
            raise ValueError("Latent shape should not be None")
        # Two independently constructed datasets (presumably train / val).
        datasets = [LatentSimpleData(size=512*1000, annealing=annealing,
                                     shape=latent_shape,
                                     annealing_iterations=annealing_iterations,
                                     loc=-1, std=2) for _ in range(2)]

        data_shape = datasets[0].shape
    elif datasets_name == 'test_mixture':
        if latent_shape is None:
            raise ValueError("Latent shape should not be None")
        # Two independently constructed datasets (presumably train / val).
        datasets = [LatentMixtureData(size=512*1000, annealing=annealing,
                                      shape=latent_shape,
                                      annealing_iterations=annealing_iterations,
                                      num_classes=1, locs=[-1, 1], stds=[1, 1],
                                      probs=[1/2., 1/2.]) for _ in range(2)]

        data_shape = datasets[0].shape
    else:
        raise NotImplementedError(f"Unknown datasets_name {datasets_name!r}")
    return (datasets, data_shape)


@experiment_ingredient.capture
def data_loader(data_pre_cleanup, datasets_name, _log):
    """Load the configured datasets and return them with the dataset name.

    Calls ``fetch_datasets`` (its arguments are filled from the ingredient
    config by sacred) and, when requested, logs the cleanup counts of the
    first two datasets.

    Args:
        data_pre_cleanup: if True, log ``n_cleanups`` of ``datasets[0]``
            (as "Train") and ``datasets[1]`` (as "Val").
        datasets_name: configured dataset name; appended to the result.
        _log: sacred-provided logger.

    Returns:
        ``(datasets, data_shape, datasets_name)``.
    """
    datasets, data_shape = fetch_datasets()

    if data_pre_cleanup:
        # Both count lines must be f-strings; the "Val" line previously was a
        # plain string and logged the literal placeholder instead of the count.
        msg = ("\n Number of cleanups in dataset:"
               f"\n\t\t Train: {datasets[0].n_cleanups}"
               f"\n\t\t Val: {datasets[1].n_cleanups}")
        _log.info(msg)

    return (datasets, data_shape, datasets_name)
