from os.path import join
from pathlib import Path

import numpy as np
import torchvision
from PIL import Image
from joblib import Parallel, delayed
from torch.utils.data import Dataset
from torchvision.transforms import transforms

import augment


class Normalize:
    """Per-channel normalization using statistics tied to a dataset name.

    Wraps ``transforms.Normalize`` with hard-coded mean/std presets:
    ``'mvtec'`` uses values that match the commonly used ImageNet statistics,
    ``'svhn'`` uses SVHN-specific values.

    Args:
        dataset: Preset name, either ``'mvtec'`` or ``'svhn'``.

    Raises:
        ValueError: If ``dataset`` is not one of the known presets; the
            offending value is the exception's only argument.
    """

    def __init__(self, dataset):
        # (name, mean, std) triples; compared with ``==`` in order so any
        # object, hashable or not, can be passed in.
        presets = (
            ('mvtec', [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ('svhn', [0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]),
        )
        for name, mean, std in presets:
            if dataset == name:
                self.normalize = transforms.Normalize(mean=mean, std=std)
                return
        raise ValueError(dataset)

    def __call__(self, img):
        """Apply the preset normalization to ``img`` and return the result."""
        normalize = self.normalize
        return normalize(img)


class MVTecAD(Dataset):
    """Image-folder dataset laid out like MVTec AD.

    Directory layout read by this class:
    ``<root>/<obj_type>/train/good/*.png`` in train mode and
    ``<root>/<obj_type>/test/<sub_dir>/*.png`` otherwise, where images under a
    ``good`` directory are normal. Every image is read, resized to
    ``img_size x img_size`` and converted to RGB once, at construction time,
    and kept in memory.

    Args:
        root: Dataset root directory.
        obj_type: Object category (sub-directory of ``root``).
        ano_type: Only used outside train mode. ``'all'`` loads every
            sub-directory of ``test`` (including ``good``); any other value
            loads ``test/good`` plus ``test/<ano_type>``.
        transform: Optional transform applied after ``ToTensor`` (so it
            receives a tensor).
        mode: ``'train'`` loads the normal training images; any other value
            loads the test split.
        img_size: Side length of the square resize.
        n_jobs: Number of joblib workers used to read the images.

    Raises:
        ValueError: If ``ano_type`` names a sub-directory with no ``.png``
            images (args are ``(obj_type, ano_type)``).
    """

    def __init__(self, root, obj_type, ano_type, transform=None, mode='train',
                 img_size=256, n_jobs=10):
        self.root = Path(root)
        self.mode = mode
        self.img_size = img_size

        transform_list = [transforms.ToTensor()]
        if transform is not None:
            transform_list.append(transform)
        self.transform = transforms.Compose(transform_list)

        self.image_names = self._collect_image_names(
            self.root, obj_type, ano_type, mode)

        # Decode everything up front so __getitem__ only pays for the tensor
        # conversion and the user transform.
        self.imgs = Parallel(n_jobs=n_jobs)(
            delayed(self.read_image)(file) for file in self.image_names
        )

    @staticmethod
    def _collect_image_names(root, obj_type, ano_type, mode):
        """Return the image paths for the requested split, in sorted order.

        ``glob`` yields files in filesystem order, so results are sorted to
        make sample order (and anything seeded per index) reproducible.
        With a specific ``ano_type`` the ``good`` images come first, followed
        by the anomalous ones.
        """
        root = Path(root)
        if mode == 'train':
            return sorted((root / obj_type / 'train' / 'good').glob('*.png'))

        path = root / obj_type / 'test'
        if ano_type == 'all':
            return sorted(path.glob(str(Path('*') / '*.png')))

        good = sorted(path.glob(str(Path('good') / '*.png')))
        anom = sorted(path.glob(str(Path(ano_type) / '*.png')))
        if len(anom) == 0:
            raise ValueError(obj_type, ano_type)
        return good + anom

    def read_image(self, file):
        """Load ``file``, resize it to ``img_size`` square and convert to RGB."""
        # Image.open is lazy and keeps the file handle open; the context
        # manager closes it once resize() has materialized the pixels.
        with Image.open(file) as img:
            return img.resize((self.img_size, self.img_size)).convert('RGB')

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``(image, y)`` for sample ``idx``.

        ``image`` is the cached image passed through ``self.transform``.
        In train mode ``y`` is the int ``0``; otherwise it is a bool that is
        True when the image's parent directory is not ``good``.
        """
        # NOTE(review): y is an int in train mode but a bool otherwise. Kept
        # as-is because callers may rely on the resulting collated dtype.
        img = self.transform(self.imgs[idx])
        if self.mode == 'train':
            y = 0
        else:
            file_name = self.image_names[idx]
            y = file_name.parts[-2] != 'good'
        return img, y


class SyntheticMVTecAD(MVTecAD):
    """MVTecAD variant whose test anomalies are synthesized from normal images.

    It differs from MVTecAD only when ``mode != 'train'``, in ``__init__`` and
    ``__getitem__``. The test split is loaded with ``ano_type='all'``, and
    every image outside a ``good`` directory is replaced by ``aug_func``
    applied to a randomly chosen ``good`` test image. The choice uses NumPy's
    global RNG. Labels still come from the directory names, so the number of
    anomalies matches the real test split.

    Args:
        root, obj_type, transform, mode, img_size: As for MVTecAD.
        aug_func: Callable that receives one image tensor with a leading
            batch dimension of size 1 and returns a tensor of the same form.

    Raises:
        ValueError: If ``transform`` is given when ``mode != 'train'`` (test
            images are returned as plain ToTensor output), or if anomalies
            must be synthesized but there are no ``good`` test images.
    """

    def __init__(self, root, obj_type, aug_func, transform=None, mode='train',
                 img_size=256):
        # Validate before the base class spends time decoding every image;
        # raise rather than assert so the check survives ``python -O``.
        if mode != 'train' and transform is not None:
            raise ValueError(
                'SyntheticMVTecAD does not support a transform when '
                'mode != "train"')
        super().__init__(root, obj_type, 'all', transform, mode, img_size)

        if self.mode == 'train':
            return

        good_mask = [self._is_good(name) for name in self.image_names]
        to_tensor = transforms.ToTensor()
        # Reuse the images MVTecAD already decoded (same read_image, same
        # order as image_names) instead of reading the good ones again.
        good_imgs = [to_tensor(img)
                     for img, good in zip(self.imgs, good_mask) if good]
        if not good_imgs and not all(good_mask):
            raise ValueError(
                obj_type, 'no good test images to synthesize anomalies from')

        good_iter = iter(good_imgs)
        synthesized = []
        for good in good_mask:
            if good:
                synthesized.append(next(good_iter))
            else:
                target_img = good_imgs[np.random.randint(len(good_imgs))]
                augmented = aug_func(target_img.unsqueeze(0))
                synthesized.append(augmented.squeeze(0))
        self.imgs = synthesized

    @staticmethod
    def _is_good(path):
        """Return True if ``path``'s parent directory is named ``good``."""
        # Path.parts is separator-agnostic; splitting str(path) on '/' did
        # not work for Windows paths.
        return Path(path).parts[-2] == 'good'

    def __getitem__(self, idx):
        """Return ``(image, y)``.

        In train mode this defers to MVTecAD. Otherwise ``image`` is the
        precomputed tensor (real good image or synthesized anomaly). ``y`` is
        True for images whose parent directory is not ``good``.
        """
        if self.mode == 'train':
            return super().__getitem__(idx)
        return self.imgs[idx], not self._is_good(self.image_names[idx])


class SVHN(torchvision.datasets.SVHN):
    """One-class view of SVHN: digit ``obj_type`` is normal, others anomalous.

    After filtering, ``self.labels`` is rewritten in place to 0 for samples
    whose original label equals ``obj_type`` and 1 for every other sample.

    Args:
        root: Parent directory; the data lives (and is downloaded) under
            ``<root>/SVHN``.
        obj_type: Normal class; converted with ``int()``.
        ano_type: Ignored in train mode. Otherwise ``'all'`` keeps every
            sample, and any other value (converted with ``int()``) keeps only
            samples labelled ``obj_type`` or ``ano_type``.
        transform: Forwarded to torchvision's SVHN.
        mode: Forwarded positionally as torchvision SVHN's second argument
            (its split name). In ``'train'`` mode only ``obj_type`` samples
            are kept.

    Raises:
        ValueError: If a specific ``ano_type`` leaves no anomalous samples
            (e.g. it equals ``obj_type``); args are ``(obj_type, ano_type)``.
    """

    def __init__(self, root, obj_type, ano_type, transform=None, mode='train'):
        super().__init__(join(root, 'SVHN'), mode, transform, download=True)
        obj_type = int(obj_type)
        if mode == 'train':
            index = self.labels == obj_type
            self.data = self.data[index]
            self.labels = self.labels[index]
        elif ano_type != 'all':
            ano_type = int(ano_type)
            index = (self.labels == obj_type) | (self.labels == ano_type)
            self.data = self.data[index]
            self.labels = self.labels[index]
            # Mirror MVTecAD: an evaluation split with no anomalies is a
            # configuration error, not a silently degenerate dataset.
            if not (self.labels != obj_type).any():
                raise ValueError(obj_type, ano_type)

        # Compute the mask once before rewriting, so setting normals to 0
        # cannot affect which samples are marked 1.
        normal_idx = self.labels == obj_type
        self.labels[normal_idx] = 0
        self.labels[~normal_idx] = 1


def load_data(root, dataset, obj_type, ano_type, transform=None, mode='train',
              img_size=256, syn_args=None):
    """Construct the dataset selected by ``dataset``.

    Args:
        root: Directory containing the dataset folders
            (``mvtec_anomaly_detection``, ``MPDD``, ``SVHN``).
        dataset: One of ``'mvtec'``, ``'mpdd'`` or ``'svhn'``.
        obj_type: Object category / normal class, forwarded to the dataset.
        ano_type: Anomaly selection forwarded to the dataset. For ``'mvtec'``
            the value ``'synthetic'`` builds a SyntheticMVTecAD instead.
        transform: Optional user transform. MVTec/MPDD apply it after
            ``ToTensor``; SVHN applies it before ``ToTensor``.
        mode: Split selector forwarded to the dataset.
        img_size: Resize side length for MVTec/MPDD (unused for SVHN).
        syn_args: Keyword arguments for ``augment.to_aug_function('cutdiff',
            ...)``, used only for synthetic MVTec; ``None`` means none.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: If ``dataset`` is unknown (the name is the exception's
            argument).
    """
    if dataset == 'mvtec':
        path = join(root, 'mvtec_anomaly_detection')
        if ano_type == 'synthetic':
            # syn_args defaults to None, and unpacking None raises TypeError.
            aug_func = augment.to_aug_function('cutdiff', **(syn_args or {}))
            return SyntheticMVTecAD(path, obj_type, aug_func, transform, mode,
                                    img_size)
        return MVTecAD(path, obj_type, ano_type, transform, mode, img_size)
    if dataset == 'mpdd':
        path = join(root, 'MPDD')
        return MVTecAD(path, obj_type, ano_type, transform, mode, img_size)
    if dataset == 'svhn':
        t_list = [] if transform is None else [transform]
        transform = transforms.Compose(t_list + [transforms.ToTensor()])
        return SVHN(root, obj_type, ano_type, transform, mode)
    raise ValueError(dataset)
