from torchvision import datasets, transforms
import torch

import os
import warnings

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive
import numpy as np
from typing import Callable, Dict, List, Optional, Tuple, Union
import torchvision


class ImageFolderDataset(torch.utils.data.Dataset):
    """A generic data loader where the images are arranged in this way: ::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/asd932_.png

    Every immediate sub-directory of ``root`` is treated as one class; class
    indices follow the sorted order of the sub-directory names. Image files are
    collected recursively below each class directory (symlinks are followed).

    Supports conventional image formats when reading local images:
    ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp']
    (matched case-insensitively). Any other file found under a class directory
    raises ``NotImplementedError`` while the dataset is being indexed.

    Args:
        root: Dataset root directory (``~`` is expanded).
        transform: Optional callable applied to each loaded image.
        target_transform: Optional callable applied to each integer label.
    """

    # Lower-case file suffixes accepted by ``_has_file_allowed_extension``.
    EXTENSIONS = (
        ".jpg",
        ".jpeg",
        ".png",
        ".ppm",
        ".bmp",
        ".pgm",
        ".tif",
        ".tiff",
        ".webp",
    )

    def __init__(
            self,
            root: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ):
        # prepare info
        self.transform = transform
        self.target_transform = target_transform
        self.loader = datasets.folder.default_loader

        # Discover the classes, then index every image path and its label.
        self.classes, self.class_to_index = self._find_classes(root)
        self.data, self.targets = self._make_dataset(
            root=root,
            class_to_idx=self.class_to_index,
            is_allowed_file=self._has_file_allowed_extension,
        )
        self.data_size = len(self.data)
        # ``indices`` maps dataset positions to entries of ``data``/``targets``
        # so that ``trim_dataset`` can subsample without touching them.
        self.indices = list(range(self.data_size))

        # Number of samples per class, keyed by class name.
        self.label_statistics = self._count_label_statistics(labels=self.targets)

    def __getitem__(self, index):
        """Return ``(image, target)`` at position ``index`` with transforms applied."""
        data_idx = self.indices[index]
        img_path = self.data[data_idx]
        img = self.loader(img_path)
        target = self.targets[data_idx]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self):
        """Number of samples currently selected (reflects ``trim_dataset``)."""
        return len(self.indices)

    @staticmethod
    def _find_classes(root) -> Tuple[List[str], Dict[str, int]]:
        """
        Finds the class folders in a dataset.

        Only the immediate sub-directories of ``root`` are considered classes;
        they are sorted by name and numbered from 0.

        Args:
            root (string): Root directory path (``~`` is expanded, matching
                ``_make_dataset``).

        Returns:
            tuple: (classes, class_to_idx) where classes are the sub-directory
            names of ``root`` and class_to_idx maps each name to its index.
        """
        root = os.path.expanduser(root)
        with os.scandir(root) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _has_file_allowed_extension(self, filename: str) -> bool:
        """Checks if a file is an allowed extension (case-insensitive)."""
        return filename.lower().endswith(self.EXTENSIONS)

    @staticmethod
    def _make_dataset(
            root: str,
            class_to_idx: Dict[str, int],
            is_allowed_file: Callable[[str], bool],
    ) -> Tuple[List[str], List[int]]:
        """Collect ``(paths, labels)`` for every file below each class directory.

        Classes are visited in sorted name order; directories and file names
        within a class are visited in sorted order, so the result is
        deterministic for a given directory tree.

        Args:
            root: Dataset root directory (``~`` is expanded).
            class_to_idx: Mapping from class sub-directory name to label.
            is_allowed_file: Predicate applied to each full file path.

        Returns:
            (imgs, labels): parallel lists of file paths and integer labels.

        Raises:
            NotImplementedError: if any file fails ``is_allowed_file``.
        """
        imgs = []
        labels = []
        root = os.path.expanduser(root)

        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(root, target_class)
            if not os.path.isdir(target_dir):
                continue
            for dirpath, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(dirpath, fname)
                    if is_allowed_file(path):
                        imgs.append(path)
                        labels.append(class_index)
                    else:
                        # Strict by design: any non-image file aborts indexing.
                        raise NotImplementedError(
                            f"The extension = {fname.split('.')[-1]} is not supported yet."
                        )
        return imgs, labels

    def _count_label_statistics(self, labels: list) -> Dict[str, int]:
        """
        This function returns the statistics of label category.

        With ``class_to_index`` set, the result maps each class name (in label
        order) to its number of occurrences in ``labels``. Otherwise labels
        ``0 .. k-1`` are counted, where ``k`` is the number of distinct labels,
        and keyed by their string form.
        """
        label_statistics = {}

        if self.class_to_index is not None:
            for k, v in sorted(self.class_to_index.items(), key=lambda item: item[1]):
                num_occurrence = labels.count(v)
                label_statistics[k] = num_occurrence
        else:
            # get the number of categories.
            num_categories = len(set(labels))
            for i in range(num_categories):
                num_occurrence = labels.count(i)
                label_statistics[str(i)] = num_occurrence

        return label_statistics

    def trim_dataset(self, data_size: int, random_seed: Optional[int] = None) -> None:
        """Randomly keep ``data_size`` samples, drawn without replacement.

        Only ``indices`` is subsampled; ``data`` and ``targets`` are untouched,
        so the method can be called repeatedly to shrink the dataset further.

        Args:
            data_size: Number of samples to keep; must not exceed ``len(self)``.
            random_seed: Seed for ``np.random.default_rng``; ``None`` yields a
                non-deterministic selection.
        """
        assert data_size <= len(
            self
        ), "given data size should be smaller than the original data size."
        rng = np.random.default_rng(random_seed)
        positions = rng.choice(len(self), size=data_size, replace=False)
        # ``self.indices`` is a plain list, which cannot be fancy-indexed by a
        # NumPy array (that raised TypeError); select element-wise instead.
        self.indices = [self.indices[int(p)] for p in positions]
        self.data_size = len(self.indices)


def load_pacs_train_datasets(data_root="pacs data root", cache_dir="../data_np"):
    """Load the four PACS domains as augmented training tensors, cached as ``.npy``.

    For each domain, if ``<cache_dir>/<domain>_data_tr.npy`` or
    ``<cache_dir>/<domain>_label_tr.npy`` is missing, every image under
    ``os.path.join(data_root, domain)`` is loaded, passed once through
    ``train_transform``, stacked, and saved to those two files. The tensors are
    then read back from the cache.

    Note: the random augmentation is therefore sampled only once and frozen
    into the cache. The cache is keyed by file name only, so changing
    ``data_root`` or the transform does not invalidate existing files.

    Args:
        data_root: Directory containing one sub-directory per domain. The
            default is a placeholder that must be replaced with the real
            PACS location.
        cache_dir: Directory for the ``.npy`` cache; it is created if missing.

    Returns:
        (x_trains, y_trains): lists with one entry per domain in the order
        art, cartoon, photo, sketch. ``x_trains`` holds the stacked image
        tensors and ``y_trains`` the ``long`` label tensors.
    """
    domains = ["art", "cartoon", "photo", "sketch"]
    loader = torchvision.datasets.folder.default_loader

    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
            transforms.RandomGrayscale(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    x_trains, y_trains = [], []
    for domain in domains:
        # os.path.join avoids relying on data_root ending with a separator.
        root = os.path.join(data_root, domain)
        data_file = os.path.join(cache_dir, "%s_data_tr.npy" % domain)
        label_file = os.path.join(cache_dir, "%s_label_tr.npy" % domain)

        # Rebuild when either file is missing (e.g. after an interrupted save).
        if not (os.path.exists(data_file) and os.path.exists(label_file)):
            print("making np data")
            dset = ImageFolderDataset(root)
            images = torch.stack([train_transform(loader(path)) for path in dset.data])
            labels = torch.tensor(dset.targets).long()

            # np.save does not create missing parent directories.
            os.makedirs(cache_dir, exist_ok=True)
            np.save(data_file, images.numpy())
            np.save(label_file, labels.numpy())

        x_trains.append(torch.tensor(np.load(data_file)))
        y_trains.append(torch.tensor(np.load(label_file)).long())

    return x_trains, y_trains


def load_pacs_simclr_dataset(data_root="pacs data root", cache_dir="../data_np"):
    """Load two SimCLR-style augmented views of every PACS image, cached as ``.npy``.

    For each domain, if any of ``<cache_dir>/<domain>_data1_simclr.npy``,
    ``<cache_dir>/<domain>_data2_simclr.npy`` or
    ``<cache_dir>/<domain>_label_simclr.npy`` is missing, all images under
    ``os.path.join(data_root, domain)`` are passed through ``simclr_transforms``
    twice, once for each view. The views are drawn independently and the result
    is saved; the tensors are then read back from the cache.

    Note: both views are sampled once and frozen into the cache, which is keyed
    by file name only.

    Args:
        data_root: Directory containing one sub-directory per domain. The
            default is a placeholder that must be replaced with the real
            PACS location.
        cache_dir: Directory for the ``.npy`` cache; it is created if missing.

    Returns:
        (x_views1, x_views2, y_labels): lists with one entry per domain in the
        order art, cartoon, photo, sketch.
    """
    domains = ["art", "cartoon", "photo", "sketch"]
    loader = torchvision.datasets.folder.default_loader

    img_size = 224
    simclr_transforms = transforms.Compose([
        transforms.RandomResizedCrop(size=img_size, scale=(0.2, 1.)),  # TODO: modify the hard-coded size
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    x_views1, x_views2, y_labels = [], [], []
    for domain in domains:
        # os.path.join avoids relying on data_root ending with a separator.
        root = os.path.join(data_root, domain)
        view1_file = os.path.join(cache_dir, "%s_data1_simclr.npy" % domain)
        view2_file = os.path.join(cache_dir, "%s_data2_simclr.npy" % domain)
        label_file = os.path.join(cache_dir, "%s_label_simclr.npy" % domain)

        # Rebuild when any cache file is missing (e.g. after an interrupted save).
        if not all(os.path.exists(f) for f in (view1_file, view2_file, label_file)):
            print("making np data")
            dset = ImageFolderDataset(root)
            # Two separate passes: each view gets its own random augmentation.
            view1 = torch.stack([simclr_transforms(loader(path)) for path in dset.data])
            view2 = torch.stack([simclr_transforms(loader(path)) for path in dset.data])
            labels = torch.tensor(dset.targets).long()

            # np.save does not create missing parent directories.
            os.makedirs(cache_dir, exist_ok=True)
            np.save(view1_file, view1.numpy())
            np.save(view2_file, view2.numpy())
            np.save(label_file, labels.numpy())

        x_views1.append(torch.tensor(np.load(view1_file)))
        x_views2.append(torch.tensor(np.load(view2_file)))
        y_labels.append(torch.tensor(np.load(label_file)).long())

    return x_views1, x_views2, y_labels


def load_pacs_test_datasets(data_root="root of pacs dataset", cache_dir="../data_np"):
    """Load the four PACS domains as deterministic evaluation tensors, cached as ``.npy``.

    For each domain, if ``<cache_dir>/<domain>_data.npy`` or
    ``<cache_dir>/<domain>_label.npy`` is missing, every image under
    ``os.path.join(data_root, domain)`` is resized to 224x224, converted to a
    tensor and normalized (no random augmentation), and the result is saved.
    The tensors are then read back from the cache. The cache is keyed by file
    name only, so changing ``data_root`` does not invalidate existing files.

    Args:
        data_root: Directory containing one sub-directory per domain. The
            default is a placeholder that must be replaced with the real
            PACS location.
        cache_dir: Directory for the ``.npy`` cache; it is created if missing.

    Returns:
        (x_tests, y_tests, domains): per-domain image tensors, per-domain
        ``long`` label tensors, and the list of domain names in matching order.
    """
    domains = ["art", "cartoon", "photo", "sketch"]
    loader = torchvision.datasets.folder.default_loader

    test_transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ])

    x_tests, y_tests = [], []
    for domain in domains:
        # os.path.join avoids relying on data_root ending with a separator.
        root = os.path.join(data_root, domain)
        data_file = os.path.join(cache_dir, "%s_data.npy" % domain)
        label_file = os.path.join(cache_dir, "%s_label.npy" % domain)

        # Rebuild when either file is missing (e.g. after an interrupted save).
        if not (os.path.exists(data_file) and os.path.exists(label_file)):
            print("making np data")
            dset = ImageFolderDataset(root)
            images = torch.stack([test_transform(loader(path)) for path in dset.data])
            labels = torch.tensor(dset.targets).long()

            # np.save does not create missing parent directories.
            os.makedirs(cache_dir, exist_ok=True)
            np.save(data_file, images.numpy())
            np.save(label_file, labels.numpy())

        x_tests.append(torch.tensor(np.load(data_file)))
        y_tests.append(torch.tensor(np.load(label_file)).long())

    return x_tests, y_tests, domains

