import os
import numpy as np
from pathlib import Path
import torch.nn as nn
import re
import sys
from typing import Any, Dict, List, Optional, Tuple
import torchvision.transforms.functional as TF
from torchvision import transforms
from PIL import Image
import random
import torch
import glob
import cv2
from my_transforms import get_imagenet_transforms, get_imagenet_transforms_simsiam
import torchvision
import torchvision.transforms as T

def pil_loader(path: str) -> Image.Image:
    """Read the image stored at *path* and return it converted to RGB.

    The file is opened explicitly (rather than handing the path to PIL) so the
    handle is closed deterministically; see
    https://github.com/python-pillow/Pillow/issues/835 (ResourceWarning).
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("RGB")

def pil_loader_gray(path: str) -> Image.Image:
    """Read the image stored at *path* and return it as single-channel ("L") PIL image.

    Mirrors ``pil_loader`` but converts to 8-bit grayscale; the explicit
    ``open`` avoids a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, "rb") as handle:
        return Image.open(handle).convert("L")

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folders in a dataset directory.

    Every immediate sub-directory of ``directory`` is treated as one class;
    plain files are ignored. Class indices follow sorted folder-name order.

    See :class:`DatasetFolder` for details.

    Args:
        directory: root folder containing one sub-directory per class.

    Returns:
        ``(classes, class_to_idx)`` -- the sorted class names and a mapping
        from class name to its integer index.

    Raises:
        FileNotFoundError: if ``directory`` contains no sub-directories.
    """
    # Use the scandir iterator as a context manager so its OS handle is
    # released immediately instead of waiting for garbage collection.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(
            f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


class ImageNetteDataset(torch.utils.data.Dataset):
    """Class-per-folder image dataset that can return several views per image.

    Each item is ``(views, label, image_path)``:

    * ``transform`` is None -> ``views`` is the raw PIL image;
    * ``num_views`` is 0    -> ``views`` is the single transformed image;
    * otherwise             -> ``views`` is a list of ``1 + num_views``
      transformed images.

    The first view is always ``transform(image)``. Extra views are either
    ``transform(image)`` again or ``transform`` applied to pre-rendered views
    found under ``aug_root_path/<image stem>/``.

    Args:
        root: directory containing one sub-directory per class.
        num_views: number of extra views returned alongside the first one.
        aug_root_path: optional directory of pre-rendered views, one
            sub-directory per image stem.
        transform: callable applied to every PIL image.
        target_transform: stored but not applied by this class.
        randomize: shuffle the list of views before returning it.
        ordered: take ``num_views`` consecutive pre-rendered views (wrapping
            around) starting at a random index instead of sampling them.
        adampi_prob: when ``ordered`` is False, probability of ignoring the
            pre-rendered views and using ``transform(image)`` instead.
    """

    def __init__(self, root: str, num_views: int = 0, aug_root_path: str = None, transform=None, target_transform=None, randomize = False, ordered = False, adampi_prob = 0.0):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        # Flat list of (image_path, class_index) over all class folders.
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        self.aug_root_path = aug_root_path
        if self.aug_root_path:
            # image stem -> sorted list of pre-rendered view paths
            self.aug_samples = {}
            for dir_ in os.listdir(self.aug_root_path):
                self.aug_samples[dir_] = [os.path.join(self.aug_root_path, dir_, i) for i in sorted(os.listdir(
                    os.path.join(self.aug_root_path, dir_)))]
        self.randomize = randomize
        self.ordered = ordered
        self.adampi_prob = adampi_prob
        print("adampi_prb: ", self.adampi_prob)

    def __len__(self):
        return len(self.all_samples)

    def _extra_views(self, image, image_path):
        """Return the ``num_views`` additional transformed views of one sample.

        Falls back to re-transforming ``image`` when there is no
        ``aug_root_path`` or no (non-empty) pre-rendered folder for this image.
        """
        aug_paths = None
        if self.aug_root_path:
            image_name = os.path.basename(os.path.splitext(image_path)[0])
            aug_paths = self.aug_samples.get(image_name)

        if not aug_paths:
            return [self.transform(image) for _ in range(self.num_views)]

        if self.ordered:
            # randrange excludes the upper bound, so every start index is
            # equally likely (randint(0, n) double-counted index 0).
            start = random.randrange(len(aug_paths))
            chosen = [aug_paths[(start + j) % len(aug_paths)]
                      for j in range(self.num_views)]
        elif random.random() <= self.adampi_prob:
            return [self.transform(image) for _ in range(self.num_views)]
        else:
            chosen = random.sample(aug_paths, k=self.num_views)
        return [self.transform(pil_loader(path)) for path in chosen]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        if self.transform is None:
            return image, label, image_path

        image_t = self.transform(image)
        if not self.num_views:
            # A single view is returned as-is; shuffling a tensor in place with
            # random.shuffle would corrupt it, so randomize does not apply.
            return image_t, label, image_path

        image_all = [image_t] + self._extra_views(image, image_path)
        if self.randomize:
            random.shuffle(image_all)
        return image_all, label, image_path

class ImageNetteDatasetSwAV(torch.utils.data.Dataset):
    """Class-per-folder dataset producing SwAV-style multi-crop views.

    Each item is ``(views, label, image_path)``. With a transform and
    ``aug_root_path`` pointing at pre-rendered views, ``views`` is the
    ``torch.stack`` of: ``transform(image)``, one more full-size view, and
    ``low_res_views`` small ``RandomResizedCrop`` crops.

    NOTE: the ``num_views`` constructor argument is overridden to 1.
    """

    def __init__(self, root: str, num_views: int = 0, aug_root_path: str = None, transform=None, target_transform=None, randomize = False, ordered = False, adampi_prob = 0.0):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        self.aug_root_path = aug_root_path
        if self.aug_root_path:
            # image stem -> sorted list of pre-rendered view paths
            self.aug_samples = {}
            for dir_ in os.listdir(self.aug_root_path):
                self.aug_samples[dir_] = [os.path.join(self.aug_root_path, dir_, i) for i in sorted(os.listdir(
                    os.path.join(self.aug_root_path, dir_)))]
        self.randomize = randomize
        self.ordered = ordered
        self.adampi_prob = adampi_prob
        # Deliberately overrides the num_views argument.
        self.num_views = 1
        print("self.num_views: ", self.num_views)
        self.low_res_views = 6
        self.high_res_views = 2
        print("adampi_prb: ", self.adampi_prob)

        crop_sizes = [128, 96]
        crop_min_scales = [0.14, 0.05]
        crop_max_scales = [1.0, 0.14]

        # Small crop used for the low-resolution SwAV views.
        self.multi_crop_transform = T.RandomResizedCrop(
                crop_sizes[1],
                scale=(crop_min_scales[1], crop_max_scales[1])
            )

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        if self.transform is not None:
            image_t = self.transform(image)
            if self.num_views:
                image_all = [image_t]
                if self.aug_root_path:
                    image_name = os.path.basename(os.path.splitext(image_path)[0])
                    if image_name not in self.aug_samples:
                        # NOTE(review): this path yields 1 + num_views views,
                        # while the branches below yield 2 + low_res_views;
                        # mixed counts will not batch -- confirm intent.
                        for _ in range(self.num_views):
                            image_all.append(self.transform(image))
                    elif random.random() <= self.adampi_prob:
                        for _ in range(self.num_views):
                            image_all.append(self.transform(image))
                        # NOTE(review): multi_crop_transform has no ToTensor /
                        # normalisation, so for a PIL input these crops stay PIL
                        # images (and 96 px) and torch.stack below will reject
                        # them -- confirm the intended pipeline.
                        for _ in range(self.low_res_views):
                            image_all.append(self.multi_crop_transform(image))
                    else:
                        aug_images_path = random.sample(
                            self.aug_samples[image_name], k=self.num_views)
                        for aug_path in aug_images_path:
                            image_all.append(self.transform(pil_loader(aug_path)))
                        # Half of the small crops come from the original image,
                        # half from the first pre-rendered view.
                        half = self.low_res_views // 2
                        for _ in range(half):
                            image_all.append(self.multi_crop_transform(image))
                        temp_image = pil_loader(aug_images_path[0])
                        for _ in range(self.low_res_views - half):
                            image_all.append(self.multi_crop_transform(temp_image))
                else:
                    for _ in range(self.num_views):
                        image_all.append(self.transform(image))
            else:
                image_all = image_t
            # random.shuffle on a bare tensor would corrupt it; only shuffle lists.
            if self.randomize and isinstance(image_all, list):
                random.shuffle(image_all)
        else:
            image_all = image

        # torch.stack needs a sequence of tensors; a single tensor or a raw PIL
        # image (no transform) is returned unchanged instead of crashing.
        if isinstance(image_all, list):
            image_all = torch.stack(image_all)
        return image_all, label, image_path

def depth_loader(disp_path):
    """Load a disparity/depth map as a PIL image with values divided by 65535.

    The file is read unchanged (``cv2.IMREAD_UNCHANGED``, i.e. flag ``-1``) and
    divided by ``2**16 - 1``; this presumably assumes 16-bit unsigned maps so
    the result lies in [0, 1] -- TODO confirm for other bit depths.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``disp_path`` (``cv2.imread``
            returns None instead of raising).
    """
    disp = cv2.imread(disp_path, cv2.IMREAD_UNCHANGED)
    if disp is None:
        raise FileNotFoundError(f"Could not read depth map: {disp_path}")
    return Image.fromarray(disp / (2 ** 16 - 1))


class ImageNetteDatasetDepth(torch.utils.data.Dataset):
    """Class-per-folder dataset returning RGB-D views (RGB channels + 1 depth channel).

    Depth maps are looked up by image stem under ``depth_path/<any folder>/``.
    Each view applies ``resize_transform`` jointly to the image and depth map
    (random resized crop + flip when ``train``, resize + center crop
    otherwise), then ``transform`` to the RGB part and converts depth to a
    ``(1, H, W)`` float tensor, concatenated along dim 0.

    Items are ``(views, label, image_path)``; ``views`` is a single tensor when
    ``num_views`` is 0, else a list of ``1 + num_views`` tensors.

    ``args`` must provide ``randomize``, ``ordered``, ``drop_depth``,
    ``use_pfm``, ``depth_0_255`` and ``resize_depth``.
    """

    def __init__(self, root: str, num_views: int = 0, depth_path: str = None, aug_root_path: str = None, transform=None, target_transform=None, 
           train=False, args = None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.depth_path = depth_path

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        self.aug_root_path = aug_root_path
        if self.aug_root_path:
            self.aug_samples = {}
            for dir_ in os.listdir(self.aug_root_path):
                self.aug_samples[dir_] = [os.path.join(self.aug_root_path, dir_, i) for i in sorted(os.listdir(
                    os.path.join(self.aug_root_path, dir_)))]
        if self.depth_path:
            # image stem -> depth map path (later folders win on duplicate stems)
            self.depth_index = {}
            for dir_ in os.listdir(self.depth_path):
                for img in sorted(os.listdir(os.path.join(self.depth_path, dir_))):
                    image_name = os.path.splitext(img)[0]
                    self.depth_index[image_name] = os.path.join(self.depth_path, dir_, img)
            print(len(self.depth_index))

        self.randomize = args.randomize
        self.ordered = args.ordered
        self.rrc_transform = transforms_rrc_hflip(input_size= 128, min_scale=0.08, h_flip_prob=0.5, drop_depth = args.drop_depth, use_pfm=False)
        self.cc_transform = transforms_cc_resize(input_size= 128, use_pfm=args.use_pfm)
        self.resize_transform = self.rrc_transform if train else self.cc_transform
        self.depth_loader = pil_loader_gray if args.depth_0_255 else depth_loader
        self.depth_transform = lambda x: torch.from_numpy(np.array(x)).float().unsqueeze(0)
        self.resize_depth = args.resize_depth
        self.drop_depth = args.drop_depth
        self.reshape_transform = transforms.Resize(size=160)
        self.use_pfm = False

    def renormalize_depth(self, x):
        # NOTE(review): dataset_min / dataset_max are never set in this class;
        # callers must assign them before using this method.
        return (x - self.dataset_min)/ (self.dataset_max - self.dataset_min)

    def __len__(self):
        return len(self.all_samples)

    def _rgbd_view(self, image, depth_image):
        """Build one RGB-D tensor view; depth is zeroed with probability ``drop_depth``."""
        image_v, depth_v = self.resize_transform(image, depth_image)
        image_v = self.transform(image_v)
        depth_v = self.depth_transform(depth_v)
        if random.random() < self.drop_depth:
            depth_v = torch.zeros_like(depth_v).float()
        return torch.cat((image_v, depth_v), 0)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        depth_image = None
        if self.depth_path:
            image_name = os.path.basename(os.path.splitext(image_path)[0])
            depth_image = self.depth_loader(self.depth_index[image_name])
            if self.resize_depth:
                depth_image = self.reshape_transform(depth_image)

        if self.transform is None:
            return image, label, image_path
        if depth_image is None:
            # Previously an UnboundLocalError.
            raise ValueError("depth_path must be set when a transform is used")

        image_all = self._rgbd_view(image, depth_image)
        if self.num_views:
            image_all = [image_all] + [self._rgbd_view(image, depth_image)
                                       for _ in range(self.num_views)]
            # Only lists are shuffled; random.shuffle on a tensor corrupts it.
            if self.randomize:
                random.shuffle(image_all)
        return image_all, label, image_path

class ImageNetteDatasetDepth3DViews(torch.utils.data.Dataset):
    """RGB-D dataset whose extra views may come from pre-rendered novel views.

    The first view is built from the original image and its depth map (see
    ``_rgbd_view``). Each of the ``num_views`` extra views is either another
    RGB-D view of the original image, or -- when pre-rendered views exist under
    ``aug_root_path/<image stem>/`` and the ``adampi_prob`` draw fails -- an
    RGB-D view of a sampled pre-rendered image whose depth map is loaded from
    ``depth_aug_root_path/<same file name>``.

    ``args`` must provide ``randomize``, ``ordered``, ``drop_depth``,
    ``use_pfm``, ``depth_0_255`` and ``resize_depth``; ``adampi_prob`` is
    optional (default 0.0).
    """

    def __init__(self, root: str, num_views: int = 0, depth_path: str = None, aug_root_path: str = None, depth_aug_root_path:str = None, 
                transform=None, target_transform=None, 
                train=False, args = None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.depth_path = depth_path
        self.depth_aug_root_path = depth_aug_root_path

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        self.aug_root_path = aug_root_path
        if self.aug_root_path:
            # image stem -> sorted list of pre-rendered view paths
            self.aug_samples = {}
            for dir_ in os.listdir(self.aug_root_path):
                self.aug_samples[dir_] = [os.path.join(self.aug_root_path, dir_, i) for i in sorted(os.listdir(
                    os.path.join(self.aug_root_path, dir_)))]

        if self.depth_path:
            # image stem -> depth map path
            self.depth_index = {}
            for dir_ in os.listdir(self.depth_path):
                for img in sorted(os.listdir(os.path.join(self.depth_path, dir_))):
                    image_name = os.path.splitext(img)[0]
                    self.depth_index[image_name] = os.path.join(self.depth_path, dir_, img)
            print(len(self.depth_index))

        self.randomize = args.randomize
        self.ordered = args.ordered
        # Was never set before although __getitem__ reads it.
        self.adampi_prob = getattr(args, "adampi_prob", 0.0)
        self.rrc_transform = transforms_rrc_hflip(input_size= 128, min_scale=0.08, h_flip_prob=0.5, drop_depth = args.drop_depth, use_pfm=False)
        self.cc_transform = transforms_cc_resize(input_size= 128, use_pfm=args.use_pfm)
        self.resize_transform = self.rrc_transform if train else self.cc_transform
        self.depth_loader = pil_loader_gray if args.depth_0_255 else depth_loader
        self.depth_transform = lambda x: torch.from_numpy(np.array(x)).float().unsqueeze(0)
        self.resize_depth = args.resize_depth
        self.drop_depth = args.drop_depth
        self.reshape_transform = transforms.Resize(size=160)
        self.use_pfm = False

    def renormalize_depth(self, x):
        # NOTE(review): dataset_min / dataset_max are never set in this class;
        # callers must assign them before using this method.
        return (x - self.dataset_min)/ (self.dataset_max - self.dataset_min)

    def __len__(self):
        return len(self.all_samples)

    def _rgbd_view(self, image, depth_image):
        """Build one RGB-D tensor view; depth is zeroed with probability ``drop_depth``."""
        image_v, depth_v = self.resize_transform(image, depth_image)
        image_v = self.transform(image_v)
        depth_v = self.depth_transform(depth_v)
        if random.random() < self.drop_depth:
            depth_v = torch.zeros_like(depth_v).float()
        return torch.cat((image_v, depth_v), 0)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)
        image_name = os.path.basename(os.path.splitext(image_path)[0])

        depth_image = None
        if self.depth_path:
            depth_image = self.depth_loader(self.depth_index[image_name])
            if self.resize_depth:
                depth_image = self.reshape_transform(depth_image)

        if self.transform is None:
            return image, label, image_path
        if depth_image is None:
            raise ValueError("depth_path must be set when a transform is used")

        first_view = self._rgbd_view(image, depth_image)
        if not self.num_views:
            return first_view, label, image_path

        image_all = [first_view]
        aug_paths = self.aug_samples.get(image_name) if self.aug_root_path else None
        # Exactly num_views extra views (the old nested loops produced
        # num_views**2), all in the same 4-channel RGB-D format.
        if (not aug_paths or not self.depth_aug_root_path
                or random.random() <= self.adampi_prob):
            for _ in range(self.num_views):
                image_all.append(self._rgbd_view(image, depth_image))
        else:
            for aug_path in random.sample(aug_paths, k=self.num_views):
                # Pre-rendered depth shares the pre-rendered view's file name.
                aug_depth = self.depth_loader(
                    os.path.join(self.depth_aug_root_path, os.path.basename(aug_path)))
                image_all.append(self._rgbd_view(pil_loader(aug_path), aug_depth))

        if self.randomize:
            random.shuffle(image_all)
        return image_all, label, image_path

# Consistent (identically parameterised) cropping for the RGB and depth channels
class transforms_rrc_hflip(nn.Module):
    """Random resized crop + horizontal flip applied identically to an image and its depth map.

    Args:
        input_size: output side length; 224 uses bicubic interpolation, other
            sizes use RandomResizedCrop's default.
        min_scale: lower bound of the crop area fraction (upper bound is 1.0).
        h_flip_prob: probability that both outputs are flipped horizontally.
        drop_depth: stored only; depth dropping is done by the datasets.
        use_pfm: accepted for API compatibility but ignored (always False).
    """

    def __init__(self, input_size, min_scale, h_flip_prob, drop_depth, use_pfm):
        super().__init__()
        if input_size==224:
            # InterpolationMode is torchvision's documented type for this
            # argument (PIL integer constants are deprecated there).
            self.rrc = transforms.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0),
                                                    interpolation=transforms.InterpolationMode.BICUBIC)
        else:
            self.rrc = transforms.RandomResizedCrop(size=input_size, scale=(min_scale, 1.0))

        self.h_flip_prob = h_flip_prob
        self.hflip = TF.hflip
        self.input_size = input_size
        self.drop_depth = drop_depth
        # PFM depth input is not supported; the use_pfm argument is ignored.
        self.use_pfm = False

    def forward(self, img, depth):
        """Return ``(img, depth)`` cropped with the same box and flipped together."""
        # Sample the crop once and reuse it so RGB and depth stay aligned.
        i, j, h, w = self.rrc.get_params(img, self.rrc.scale, self.rrc.ratio)
        transformed_image = TF.resized_crop(img, i, j, h, w, self.rrc.size, self.rrc.interpolation)
        transformed_depth = TF.resized_crop(depth, i, j, h, w, self.rrc.size, self.rrc.interpolation)

        # Flip with probability h_flip_prob (the old `>` flipped with 1 - p).
        if random.random() < self.h_flip_prob:
            transformed_image = self.hflip(transformed_image)
            transformed_depth = self.hflip(transformed_depth)

        return transformed_image, transformed_depth


class transforms_cc_resize(nn.Module):
    """Deterministic resize + center crop applied identically to an image and its depth map.

    For ``input_size == 224`` the shorter side is resized to 256 before the
    224 center crop; for any other size it is resized straight to
    ``input_size``. ``use_pfm`` is accepted but ignored (always False).
    """

    def __init__(self, input_size, use_pfm):
        super().__init__()
        resize_size = 256 if input_size == 224 else input_size
        self.val_transforms = transforms.Compose(
            [transforms.Resize(resize_size), transforms.CenterCrop(input_size)]
        )
        self.use_pfm = False
        self.input_size = input_size

    def forward(self, img, depth):
        """Return ``(img, depth)`` after the same resize + center crop."""
        cropped_img = self.val_transforms(img)
        cropped_depth = self.val_transforms(depth)
        return cropped_img, cropped_depth

class ImageNetteDatasetDepthSwAV(torch.utils.data.Dataset):
    """RGB-D dataset producing multi-crop views for training, one center crop for eval.

    In training, ``multi_crop_transforms`` yields several jointly-cropped
    (image, depth) pairs; each becomes ``cat(transform(image), depth_tensor)``
    and the item's views are a list. In eval, a single resize + center crop
    view tensor is returned. Items are ``(views, label, image_path)``.

    ``args`` must provide ``drop_depth``, ``use_pfm``, ``depth_0_255`` and
    ``resize_depth``.
    """

    def __init__(self, root: str, num_views: int = 0, depth_path: str = None, aug_root_path: str = None, transform=None, target_transform=None, 
           train=False, args = None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.depth_path = depth_path

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        self.aug_root_path = aug_root_path
        if self.depth_path:
            # image stem -> depth map path
            self.depth_index = {}
            for dir_ in os.listdir(self.depth_path):
                for img in sorted(os.listdir(os.path.join(self.depth_path, dir_))):
                    image_name = os.path.splitext(img)[0]
                    self.depth_index[image_name] = os.path.join(self.depth_path, dir_, img)
            print(len(self.depth_index))

        self.randomize = False
        self.ordered = False
        self.train = train
        self.rrc_transform = multi_crop_transforms(drop_depth=args.drop_depth)
        self.cc_transform = transforms_cc_resize(input_size= 128, use_pfm=args.use_pfm)
        self.resize_transform = self.rrc_transform if train else self.cc_transform
        self.depth_loader = pil_loader_gray if args.depth_0_255 else depth_loader
        self.depth_transform = lambda x: torch.from_numpy(np.array(x)).float().unsqueeze(0)

        self.resize_depth = args.resize_depth
        self.drop_depth = args.drop_depth
        self.use_pfm = False

    def __len__(self):
        return len(self.all_samples)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        depth_image = None
        if self.depth_path:
            image_name = os.path.basename(os.path.splitext(image_path)[0])
            depth_image = self.depth_loader(self.depth_index[image_name])

        if self.transform is None:
            return image, label, image_path
        if depth_image is None:
            # Previously an UnboundLocalError.
            raise ValueError("depth_path must be set when a transform is used")

        image_t, depth_t = self.resize_transform(image, depth_image)
        if self.train:
            # multi-crop: lists of crops, paired index by index
            image_all = [torch.cat((self.transform(crop), self.depth_transform(depth_crop)), 0)
                         for crop, depth_crop in zip(image_t, depth_t)]
        else:
            image_all = torch.cat((self.transform(image_t), self.depth_transform(depth_t)), 0)
        # random.shuffle would corrupt a bare tensor; only shuffle lists.
        if self.randomize and isinstance(image_all, list):
            random.shuffle(image_all)
        return image_all, label, image_path


class multi_crop_transforms(nn.Module):
    """Multi-crop RandomResizedCrop + flip applied identically to an image and its depth map.

    For each index ``i``, ``crop_counts[i]`` crops of size ``crop_sizes[i]``
    are taken with area scale in ``(crop_min_scales[i], crop_max_scales[i])``.
    Each crop pair is flipped together with probability ``h_flip_prob``; the
    depth crop is replaced by ``np.zeros_like`` with probability
    ``drop_depth``.

    ``forward`` returns ``(list_of_images, list_of_depths)`` in crop order.
    """

    def __init__(self, crop_sizes=(128, 64), crop_counts=(2, 6), crop_min_scales=(0.14, 0.05), crop_max_scales=(1.0, 0.14), h_flip_prob=0.5, drop_depth=0.0):
        super().__init__()
        # Tuple defaults avoid shared mutable default arguments; any indexable
        # sequence is accepted as before.
        crop_transforms = []
        for i in range(len(crop_sizes)):
            random_resized_crop = transforms.RandomResizedCrop(
                crop_sizes[i],
                scale=(crop_min_scales[i], crop_max_scales[i])
            )
            # The same transform object is reused; parameters are re-sampled
            # per call in forward via get_params.
            crop_transforms.extend([random_resized_crop] * crop_counts[i])

        print(len(crop_transforms))
        self.crop_transforms = crop_transforms
        self.h_flip_prob = h_flip_prob
        self.hflip = TF.hflip
        self.drop_depth = drop_depth
        self.input_size = crop_sizes[0]
        self.use_pfm = False

    def forward(self, img, depth):
        all_image = []
        all_depth = []
        for rrc_crop in self.crop_transforms:
            # One crop box per pair keeps RGB and depth aligned.
            i, j, h, w = rrc_crop.get_params(img, rrc_crop.scale, rrc_crop.ratio)
            transformed_image = TF.resized_crop(img, i, j, h, w, rrc_crop.size, rrc_crop.interpolation)
            transformed_depth = TF.resized_crop(depth, i, j, h, w, rrc_crop.size, rrc_crop.interpolation)

            # Flip with probability h_flip_prob (the old `>` flipped with 1 - p).
            if random.random() < self.h_flip_prob:
                transformed_image = self.hflip(transformed_image)
                transformed_depth = self.hflip(transformed_depth)
            if random.random() < self.drop_depth:
                transformed_depth = np.zeros_like(transformed_depth)
            all_image.append(transformed_image)
            all_depth.append(transformed_depth)

        return all_image, all_depth



class ImageNetDataset(torch.utils.data.Dataset):
    """Class-per-folder dataset returning one or more transformed views per image.

    ``self.transform`` is an indexable sequence of transforms; view ``k``
    (0 = first view) uses ``self.transform[k % len(self.transform)]``. Items
    are ``(views, label, image_path)`` where ``views`` is one transformed image
    when ``num_views`` is 0, else a list of ``1 + num_views`` views.

    Args:
        root: directory with one sub-directory per class.
        aug_root_path: optional directory of pre-rendered views, one
            sub-directory per image stem.
        num_views: number of extra views per sample.
        transform: a single transform used for every view; when None, a
            default is built from ``args.method`` / ``args.std_transforms``.
        target_transform: stored but not applied by this class.
        args: optional namespace; ``ordered``, ``randomize`` and
            ``adampi_prob`` are stored (only ``adampi_prob`` is read here).
    """

    def __init__(self, root: str, aug_root_path: str = None, num_views: int = 0, transform=None, target_transform=None, args=None):
        self.root = root
        self.aug_root_path = aug_root_path
        if transform is None:
            # getattr keeps args=None usable (it used to raise AttributeError).
            if getattr(args, "method", None) == "simsiam":
                self.transform = get_imagenet_transforms_simsiam(224)
            elif not getattr(args, "std_transforms", False):
                self.transform = get_imagenet_transforms(224)
            else:
                normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                 std=[0.229, 0.224, 0.225])
                print("Using Standard Transforms")
                transform1 = transforms.Compose([
                   transforms.RandomResizedCrop(224),
                   transforms.RandomHorizontalFlip(),
                   transforms.ToTensor(),
                   normalize,
                   ])
                transform2 = transforms.Compose([
                   transforms.RandomResizedCrop(224),
                   transforms.RandomHorizontalFlip(),
                   transforms.ToTensor(),
                   normalize,
                   ])
                self.transform = (transform1, transform2)
        else:
            self.transform = [transform]
        self.target_transform = target_transform

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        if self.aug_root_path:
            # image stem -> sorted list of pre-rendered view paths
            self.aug_samples = {}
            for dir_ in os.listdir(self.aug_root_path):
                self.aug_samples[dir_] = [os.path.join(self.aug_root_path, dir_, i) for i in sorted(os.listdir(
                    os.path.join(self.aug_root_path, dir_)))]
            print(f"Length of self.aug_samples: {len(self.aug_samples)}")

        self.num_views = num_views
        if args is None:
            self.ordered = False
            self.randomize = False
            self.adampi_prob = 1
        else:
            self.ordered = args.ordered
            self.randomize = args.randomize
            self.adampi_prob = args.adampi_prob

    def __len__(self):
        return len(self.all_samples)

    def _view_transform(self, view_idx):
        """Return the transform for view ``view_idx``, cycling through ``self.transform``.

        The modulo avoids the IndexError the old ``self.transform[idx + 1]``
        raised for a single user transform or more views than transforms.
        """
        return self.transform[view_idx % len(self.transform)]

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        if self.transform is None:
            return image, label, image_path

        image_t = self._view_transform(0)(image)
        if not self.num_views:
            return image_t, label, image_path

        image_all = [image_t]
        aug_paths = None
        if self.aug_root_path:
            image_name = os.path.basename(os.path.splitext(image_path)[0])
            aug_paths = self.aug_samples.get(image_name)

        # No (non-empty) pre-rendered folder, or the adampi_prob draw says to
        # use the standard transforms on the original image.
        if not aug_paths or random.random() <= self.adampi_prob:
            for view in range(self.num_views):
                image_all.append(self._view_transform(view + 1)(image))
        else:
            aug_images_path = random.sample(aug_paths, k=self.num_views)
            for view, aug_path in enumerate(aug_images_path):
                image_all.append(self._view_transform(view + 1)(pil_loader(aug_path)))
        return image_all, label, image_path

class ImageNetDatasetDepth(torch.utils.data.Dataset):
    """Class-per-folder dataset returning RGB-D views at 224 px.

    View ``k`` (0 = first view) jointly crops image and depth with
    ``resize_transform``, applies ``self.transform[k % len(self.transform)]``
    to the RGB part, converts depth to a ``(1, H, W)`` float tensor, optionally
    zeroes depth (``drop_depth``) or RGB (``drop_rgb``), and concatenates along
    dim 0. Items are ``(views, label, image_path)``.

    ``args`` must provide ``drop_depth`` and ``depth_0_255``; ``drop_rgb`` is
    optional (default 0).
    """

    def __init__(self, root: str, depth_path: str = None, num_views: int = 0, train =False, transform=None, target_transform=None, args = None):
        self.root = root
        if transform is None:
            self.transform = get_imagenet_transforms(224)
            if depth_path:
                # Drop the first two ops of each default pipeline (cropping is
                # done jointly with depth by resize_transform instead) --
                # assumes both entries are transforms.Compose; TODO confirm.
                new_transform = transforms.Compose(self.transform[1].transforms[2:])
                new_transform_0 = transforms.Compose(self.transform[0].transforms[2:])
                self.transform = (new_transform_0, new_transform)
                print(self.transform)
        else:
            self.transform = [transform]
        self.target_transform = target_transform
        self.depth_path = depth_path

        directory = os.path.expanduser(root)
        _, class_to_idx = find_classes(directory)
        self.class_to_idx = class_to_idx
        self.samples = [(os.path.join(directory, i), class_to_idx[i])
                        for i in sorted(os.listdir(directory)) if os.path.isdir(os.path.join(directory, i))]
        self.all_samples = []
        for dir_, class_idx in self.samples:
            for image_path in sorted(os.listdir(dir_)):
                self.all_samples.append(
                    (os.path.join(dir_, image_path), class_idx))

        self.num_views = num_views
        if self.depth_path:
            # image stem -> depth map path
            self.depth_index = {}
            for dir_ in os.listdir(self.depth_path):
                for img in sorted(os.listdir(os.path.join(self.depth_path, dir_))):
                    image_name = os.path.splitext(img)[0]
                    self.depth_index[image_name] = os.path.join(self.depth_path, dir_, img)
            print(len(self.depth_index))
        self.rrc_transform = transforms_rrc_hflip(input_size=224, min_scale=0.08, h_flip_prob=0.5, drop_depth = args.drop_depth, use_pfm=False)
        self.cc_transform = transforms_cc_resize(input_size=224, use_pfm=False)
        self.resize_transform = self.rrc_transform if train else self.cc_transform
        self.depth_loader = pil_loader_gray if args.depth_0_255 else depth_loader
        self.depth_transform = lambda x: torch.from_numpy(np.array(x)).float().unsqueeze(0)

        self.drop_depth = args.drop_depth
        # getattr replaces a bare `except:` that could hide unrelated errors.
        self.drop_rgb = getattr(args, "drop_rgb", 0)

    def __len__(self):
        return len(self.all_samples)

    def _view_transform(self, view_idx):
        """Return the RGB transform for view ``view_idx``, cycling through ``self.transform``."""
        return self.transform[view_idx % len(self.transform)]

    def _rgbd_view(self, image, depth_image, view_idx):
        """Build RGB-D view number ``view_idx`` with optional depth / RGB dropping."""
        image_v, depth_v = self.resize_transform(image, depth_image)
        image_v = self._view_transform(view_idx)(image_v)
        depth_v = self.depth_transform(depth_v)
        if random.random() < self.drop_depth:
            depth_v = torch.zeros_like(depth_v).float()
        if random.random() < self.drop_rgb:
            image_v = torch.zeros_like(image_v).float()
        return torch.cat((image_v, depth_v), 0)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        image_path, label = self.all_samples[idx]
        image = pil_loader(image_path)

        depth_image = None
        if self.depth_path:
            image_name = os.path.basename(os.path.splitext(image_path)[0])
            depth_image = self.depth_loader(self.depth_index[image_name])

        if self.transform is None:
            return image, label, image_path
        if depth_image is None:
            # Previously an UnboundLocalError.
            raise ValueError("depth_path must be set when a transform is used")

        image_all = self._rgbd_view(image, depth_image, 0)
        if self.num_views:
            # Extra view k uses transform k (mod len), matching ImageNetDataset;
            # the old `transform[idx]` reused transform 0 for the second view
            # and raised IndexError for num_views >= 3.
            image_all = [image_all] + [self._rgbd_view(image, depth_image, view + 1)
                                       for view in range(self.num_views)]
        return image_all, label, image_path


