from torch.utils.data import Dataset
#from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from utils.conf import base_path
from datasets.utils.validation import get_train_val
from datasets.utils.continual_dataset import ContinualDataset, store_masked_loaders
from datasets.utils.continual_dataset import get_previous_train_loader
from typing import Tuple
from datasets.transforms.denormalization import DeNormalize
from augmentations import get_aug
from PIL import Image
import numpy as np
import os
import os.path
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
from datasets.utils.cached_image_folder import CachedImageFolder


def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]):
    """Return whether ``filename`` ends with one of ``extensions``.

    Only the filename is lower-cased before comparing, so ``extensions``
    must already be lowercase for the check to be case-insensitive.

    Args:
        filename (string): path or name of the file to test.
        extensions (tuple of strings): lowercase suffixes to accept.

    Returns:
        bool: True when some suffix matches, False otherwise.
    """
    lowered_name = filename.lower()
    return lowered_name.endswith(extensions)


def is_image_file(filename: str):
    """Checks if a file is an allowed image extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
                (case-insensitive).
    """
    # This file defines no module-level ``IMG_EXTENSIONS``, so the previous
    # reference to that global raised NameError on every call. Spell out the
    # same extension list that ImageFolder uses.
    image_extensions = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm',
                        '.tif', '.tiff', '.webp')
    return has_file_allowed_extension(filename, image_extensions)


def make_dataset(
    directory: str,
    class_to_idx: Dict[str, int],
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
):
    """Collect ``(path, class_index)`` pairs for every accepted file below ``directory``.

    Exactly one of ``extensions`` / ``is_valid_file`` must be supplied: either
    files are filtered by suffix, or by the given predicate on the full path.
    Class folders are visited in sorted name order. A class whose folder does
    not exist is skipped. Sub-directories (including symlinked ones) are walked
    recursively, in sorted order.

    Args:
        directory: dataset root; ``~`` is expanded.
        class_to_idx: mapping from class-folder name to class index.
        extensions: allowed lowercase file suffixes.
        is_valid_file: predicate deciding whether a path is a sample.

    Returns:
        list of ``(path, class_index)`` tuples.

    Raises:
        ValueError: if both or neither of ``extensions`` and ``is_valid_file``
            are given.
    """
    root_dir = os.path.expanduser(directory)
    # Equivalent to "both None or both set": the two None-checks agree.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is None:
        accept = cast(Callable[[str], bool], is_valid_file)
    else:
        def accept(candidate: str):
            return has_file_allowed_extension(candidate, cast(Tuple[str, ...], extensions))

    collected = []
    for class_name in sorted(class_to_idx):
        label = class_to_idx[class_name]
        class_dir = os.path.join(root_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for walk_root, _subdirs, file_names in sorted(os.walk(class_dir, followlinks=True)):
            for file_name in sorted(file_names):
                full_path = os.path.join(walk_root, file_name)
                if accept(full_path):
                    collected.append((full_path, label))
    return collected
def pil_loader(path: str):
    """Open the image at ``path`` with PIL and return it converted to RGB.

    The file is opened explicitly and closed by the ``with`` block, which avoids
    Pillow's ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    ``convert`` runs inside the block, so the pixels are read before the file closes.
    """
    with open(path, 'rb') as handle:
        return Image.open(handle).convert('RGB')

def accimage_loader(path: str):
    """Load ``path`` with ``accimage``, falling back to PIL on ``IOError``.

    ``accimage`` is imported lazily, so the module only needs it when this
    backend is actually used.
    """
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # Likely a decoding problem accimage can't handle; let PIL try.
        image = pil_loader(path)
    return image

def default_loader(path: str):
    """Load ``path`` with whichever image backend torchvision is configured for.

    Uses ``accimage_loader`` when ``torchvision.get_image_backend()`` reports
    ``'accimage'``, and ``pil_loader`` otherwise.
    """
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)

class ImageFolder(Dataset):
    """
    Image dataset read from a class-per-folder directory tree.

    Every sub-directory of ``root`` is one class, indexed in sorted name order.
    Every accepted file beneath it is one ``(path, class_index)`` sample,
    loaded lazily by ``loader`` in ``__getitem__``.
    """
    # NOTE(review): this metadata (and the previous "Tiny Imagenet" docstring)
    # looks copied from a Tiny-ImageNet dataset class; it is kept unchanged in
    # case callers read it.
    NAME = 'seq-tinyimgnet'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 100
    N_TASKS = 10

    # File suffixes accepted when no ``is_valid_file`` predicate is supplied.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

    def __init__(self, root: str, train: bool = True, transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None,
                 is_valid_file: Optional[Callable[[str], bool]] = None,
                 loader: Callable[[str], Any] = default_loader):
        """
        Args:
            root: directory whose sub-folders are the classes.
            train: stored as ``self.train``; not otherwise used by this class.
            transform: applied to each loaded sample in ``__getitem__``.
            target_transform: applied to each class index in ``__getitem__``.
            is_valid_file: predicate on file paths. When None, files are
                filtered by ``IMG_EXTENSIONS`` instead.
            loader: callable turning a path into a sample.

        Raises:
            RuntimeError: if no accepted file is found under ``root``.
        """
        # Plain ToTensor transform stored for callers; not used inside this class.
        self.not_aug_transform = transforms.Compose([transforms.ToTensor()])
        self.root = root
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        classes, class_to_idx = self._find_classes(self.root)
        # make_dataset requires exactly one of extensions / is_valid_file.
        extensions = self.IMG_EXTENSIONS if is_valid_file is None else None
        samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)

        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

        # NumPy copies of samples/targets. np.array over (str, int) pairs
        # yields a string array, so ``self.data`` holds class indices as strings.
        self.data = np.array(self.samples)
        self.target = np.array(self.targets)

    def _find_classes(self, dir: str):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes is the sorted list of
            immediate sub-directory names of ``dir``, and class_to_idx maps
            each name to its position in that list.
        """
        # The context manager closes the scandir iterator (and its OS
        # directory handle) promptly, avoiding a ResourceWarning.
        with os.scandir(dir) as entries:
            classes = sorted(entry.name for entry in entries if entry.is_dir())
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        # NOTE(review): SequentialImagenet may replace ``samples`` with a NumPy
        # array when subsampling. Such an array is string-typed, so ``target``
        # would then be a string rather than an int. Confirm downstream handling.
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Number of samples currently held in ``self.samples``."""
        return len(self.samples)



class SequentialImagenet(ContinualDataset):
    """
    Class-incremental ImageNet-1k continual dataset.

    The class constants declare ``N_TASKS`` tasks of ``N_CLASSES_PER_TASK``
    classes each. The per-task splitting itself is done by the
    ``ContinualDataset`` / ``store_masked_loaders`` machinery.
    """

    NAME = 'seq-imagenet'
    SETTING = 'class-il'
    N_CLASSES_PER_TASK = 100
    N_TASKS = 10
    # Root of the raw ImageNet-1k data. Read from its ``train/`` and ``val/``
    # sub-folders, or from zip/map files when ``args.dataset.zip_mode`` is set.
    DATA_ROOT = '/st1/dataset/imagenet1k/raw-data/'

    def _zip_dataset(self, split: str, transform):
        """Build a zip-backed ``CachedImageFolder`` for ``split`` ('train' or 'val').

        The original code built these names from a ``prefix`` variable that
        was never assigned (UnboundLocalError), and used the same files for
        train and test. This assumes the ``<split>_map.txt`` /
        ``<split>.zip@/`` layout under ``DATA_ROOT`` — TODO confirm against the
        data on disk.
        """
        ann_file = split + "_map.txt"
        prefix = split + ".zip@/"
        return CachedImageFolder(self.DATA_ROOT, ann_file, prefix, transform,
                                 cache_mode='part')

    @staticmethod
    def _subsample(dataset, ratio: float, debug: bool = False):
        """Keep a random ``ratio`` fraction of ``dataset``, in place, without replacement.

        ``samples`` and ``targets`` are replaced by NumPy arrays indexed with
        the random picks. When ``debug`` is true, a pdb breakpoint fires after
        the picks are drawn.
        """
        n_samples = len(dataset.samples)
        picks = np.random.choice(n_samples, int(n_samples * ratio), replace=False)
        if debug:
            import pdb; pdb.set_trace()
        dataset.samples = np.array(dataset.samples)[picks]
        dataset.targets = np.array(dataset.targets)[picks]

    def _build_datasets(self, debug: bool = False):
        """Build ``(train_dataset, memory_dataset, test_dataset)`` for both public getters.

        ``memory_dataset`` is None unless ``args.validation`` is set. Train and
        test sets are subsampled when their ratios are below 1. ``debug``
        enables the ``args.ddebug`` breakpoint during train subsampling.
        """
        transform = get_aug(config=self.args, name='imagenet', is_train=True, transform_single=False, to_pil_image=False)
        test_transform = get_aug(config=self.args, name='imagenet', is_train=False, transform_single=True, to_pil_image=False)
        train_root = os.path.join(self.DATA_ROOT, 'train/')
        val_root = os.path.join(self.DATA_ROOT, 'val/')

        if self.args.dataset.zip_mode:
            train_dataset = self._zip_dataset('train', transform)
            test_dataset = self._zip_dataset('val', test_transform)
        else:
            train_dataset = ImageFolder(train_root, transform=transform)
            test_dataset = ImageFolder(val_root, transform=test_transform)

        if self.args.validation:
            memory_dataset = ImageFolder(train_root, transform=test_transform)
            memory_dataset, _ = get_train_val(memory_dataset, test_transform, self.NAME)
        else:
            memory_dataset = None

        if self.args.dataset.train_ratio < 1:
            # ``debug and ...`` only reads args.ddebug when the caller opted in.
            self._subsample(train_dataset, self.args.dataset.train_ratio,
                            debug=debug and self.args.ddebug)
        if self.args.dataset.val_ratio < 1:
            self._subsample(test_dataset, self.args.dataset.val_ratio)

        return train_dataset, memory_dataset, test_dataset

    def get_target_dataset(self):
        """Return ``(train_dataset, memory_dataset, test_dataset)`` without wrapping them in loaders."""
        return self._build_datasets()

    def get_data_loaders(self):
        """Build the datasets and wrap them via ``store_masked_loaders``.

        Returns:
            the ``(train, memory, test)`` triple produced by ``store_masked_loaders``.
        """
        train_dataset, memory_dataset, test_dataset = self._build_datasets(debug=True)
        # Any backbone other than 'supervised' gets labels_to_paths=True; the
        # effect of that flag is defined inside store_masked_loaders.
        labels_to_paths = self.args.model.backbone != 'supervised'
        train, memory, test = store_masked_loaders(train_dataset, memory_dataset, test_dataset, self, labels_to_paths=labels_to_paths)
        return train, memory, test

    def get_transform(self, is_train=True):
        """Return the single-view ImageNet augmentation from ``get_aug`` (with to-PIL conversion)."""
        transform = get_aug(config=self.args, name='imagenet', is_train=is_train, transform_single=True, to_pil_image=True)
        return transform

    def not_aug_dataloader(self, batch_size):
        """Loader over the training folder with only ToTensor + ImageNet normalization."""
        imagenet_norm = [[0.485, 0.456, 0.406], [0.229, 0.224, 0.225]]
        transform = transforms.Compose([transforms.ToTensor(),
                transforms.Normalize(*imagenet_norm)])

        train_dataset = ImageFolder(os.path.join(self.DATA_ROOT, 'train/'), transform=transform)
        train_loader = get_previous_train_loader(train_dataset, batch_size, self)

        return train_loader