from torchvision import transforms
from PIL import Image
import os
import torch
import glob
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST, ImageFolder
import numpy as np
import torch.multiprocessing
from torch.utils.data import Dataset
from PIL import ImageFilter, Image, ImageOps
from torchvision.datasets.folder import default_loader
import os
import random
import pydicom

# Select torch.multiprocessing's 'file_system' strategy for sharing tensors
# between processes (e.g. DataLoader workers). NOTE(review): presumably chosen
# to avoid running out of file descriptors with many workers — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')


def get_data_transforms(size, isize, mean_train=None, std_train=None):
    """Build the evaluation image pipeline and the ground-truth mask pipeline.

    Args:
        size: side length both inputs are resized to (``(size, size)``).
        isize: side length of the centre crop taken after resizing.
        mean_train: per-channel mean; defaulted when ``None`` but NOT applied
            (the image pipeline contains no Normalize step).
        std_train: per-channel std; defaulted when ``None`` but NOT applied.

    Returns:
        ``(data_transforms, gt_transforms)``. The image pipeline is
        resize -> ToTensor -> centre crop. The mask pipeline is
        resize -> centre crop -> ToTensor.
    """
    if mean_train is None:
        mean_train = [0.485, 0.456, 0.406]
    if std_train is None:
        std_train = [0.229, 0.224, 0.225]
    # Note: mean_train / std_train are not used below; no normalisation step.
    square = (size, size)
    image_steps = [
        transforms.Resize(square),
        transforms.ToTensor(),
        transforms.CenterCrop(isize),
    ]
    mask_steps = [
        transforms.Resize(square),
        transforms.CenterCrop(isize),
        transforms.ToTensor(),
    ]
    return transforms.Compose(image_steps), transforms.Compose(mask_steps)


def get_strong_transforms(size, isize, mean_train=None, std_train=None):
    """Build the augmenting image pipeline.

    The steps are, in order:

    * resize to ``(size, size)``;
    * flip horizontally at random;
    * convert to a tensor;
    * centre-crop to ``isize``;
    * normalise with ``mean_train`` / ``std_train``.

    When either statistic is ``None``, the defaults below are used; they match
    the usual ImageNet mean/std values.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    mean = [0.485, 0.456, 0.406] if mean_train is None else mean_train
    std = [0.229, 0.224, 0.225] if std_train is None else std_train
    steps = [
        transforms.Resize((size, size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.CenterCrop(isize),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(steps)

class IMAGENET30_TEST_DATASET(Dataset):
    """ImageNet-30 test images laid out as ``root/<class>/<instance>/*.JPEG``.

    Each item is ``(image, label)``. ``image`` is loaded with torchvision's
    ``default_loader`` and optionally passed through ``transform``. ``label``
    is the index of the class directory in sorted order.
    """

    def __init__(self, root_dir="PATH_TO_ROOT", transform=None):   # Specify path
        """
        Args:
            root_dir (string): Directory with all the classes.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img_path_list = []
        self.targets = []

        # Only directories are classes. A stray file here would otherwise
        # crash the os.listdir() call on it below.
        class_names = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        # Map each class to an index.
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}

        # Walk the tree in sorted order so the dataset order is reproducible.
        for class_name in class_names:
            class_path = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            for instance_folder in sorted(os.listdir(class_path)):
                instance_path = os.path.join(class_path, instance_folder)
                # Skip damaged / metadata entries that are not directories
                # (e.g. a file "airliner/._1.JPEG"), wherever root_dir points.
                if not os.path.isdir(instance_path):
                    continue
                for img_name in sorted(os.listdir(instance_path)):
                    # "._*" entries are AppleDouble metadata, not images.
                    if img_name.endswith('.JPEG') and not img_name.startswith('._'):
                        self.img_path_list.append(os.path.join(instance_path, img_name))
                        self.targets.append(label)

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``."""
        img_path = self.img_path_list[idx]
        image = default_loader(img_path)
        label = self.targets[idx]
        if self.transform:
            image = self.transform(image)
        return image, label

def center_paste(large_img, small_img):
    """Return a copy of ``large_img`` with ``small_img`` pasted at its centre.

    The offset uses floor division, so any odd leftover pixel ends up on the
    right/bottom. ``large_img`` itself is left unmodified.
    """
    (big_w, big_h), (tiny_w, tiny_h) = large_img.size, small_img.size
    offset = ((big_w - tiny_w) // 2, (big_h - tiny_h) // 2)
    canvas = large_img.copy()
    canvas.paste(small_img, offset)
    return canvas


class MVTecDataset(torch.utils.data.Dataset):
    """MVTec-AD style dataset returning ``(image, mask, label, path)``.

    Expected layout: ``root/train/<type>/*``, ``root/test/<type>/*`` and
    ``root/ground_truth/<type>/*.png``. ``label`` is 0 for the ``good`` type
    and 1 for every other type.

    When ``shrink_factor`` is truthy, each image is scaled by that factor and
    pasted onto the centre of a random ImageNet-30 test image.
    """

    def __init__(self, root, transform, gt_transform, phase, shrink_factor=None):
        if phase == 'train':
            self.img_path = os.path.join(root, 'train')
        else:
            self.img_path = os.path.join(root, 'test')
            self.gt_path = os.path.join(root, 'ground_truth')
        self.transform = transform
        self.gt_transform = gt_transform
        # load dataset
        self.img_paths, self.gt_paths, self.labels, self.types = self.load_dataset()  # self.labels => good : 0, anomaly : 1
        self.shrink_factor = shrink_factor
        # The ImageNet-30 pool is only used by the shrink-and-paste step. Building
        # it unconditionally walks (and requires) its root directory even when
        # shrink_factor is None.
        self.imagenet30_testset = IMAGENET30_TEST_DATASET() if shrink_factor else None
        print(f"self.shrink_factor: {self.shrink_factor}")

    def load_dataset(self):
        """Collect image paths, mask paths, labels and defect types.

        Returns:
            Four parallel lists ``(img_paths, gt_paths, labels, types)``.
            For ``good`` images the mask entry is the sentinel ``0``; an
            all-zero mask is built in ``__getitem__``. For defect types the
            mask entry is the ground-truth PNG paired with the image by
            sorted order.

        Raises:
            AssertionError: if a defect type has a different number of images
                than ground-truth masks.
        """
        img_tot_paths = []
        gt_tot_paths = []
        tot_labels = []
        tot_types = []

        # Sorted so the dataset order does not depend on the file system.
        for defect_type in sorted(os.listdir(self.img_path)):
            type_dir = os.path.join(self.img_path, defect_type)
            img_paths = sorted(glob.glob(os.path.join(type_dir, "*.png")) +
                               glob.glob(os.path.join(type_dir, "*.JPG")))
            if defect_type == 'good':
                gt_paths = [0] * len(img_paths)
                label = 0
            else:
                gt_paths = sorted(glob.glob(os.path.join(self.gt_path, defect_type, "*.png")))
                # Check per type: matching totals alone can hide misaligned pairs.
                assert len(img_paths) == len(gt_paths), \
                    f"Something wrong with test and ground truth pair for '{defect_type}'!"
                label = 1
            img_tot_paths.extend(img_paths)
            gt_tot_paths.extend(gt_paths)
            tot_labels.extend([label] * len(img_paths))
            tot_types.extend([defect_type] * len(img_paths))

        return img_tot_paths, gt_tot_paths, tot_labels, tot_types

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask, label, image_path)`` for index ``idx``."""
        img_path, gt, label = self.img_paths[idx], self.gt_paths[idx], self.labels[idx]
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.shrink_factor:
            # A random ImageNet-30 image, resized to the original size, is the canvas.
            pad_img, _ = self.imagenet30_testset[random.randrange(len(self.imagenet30_testset))]
            pad_img = pad_img.resize(img.size)
            new_size = (int(img.size[0] * self.shrink_factor), int(img.size[1] * self.shrink_factor))
            img = center_paste(pad_img, img.resize(new_size))

        img = self.transform(img)
        if gt == 0:
            # Normal sample: all-zero mask with the image's spatial size (H, W).
            gt = torch.zeros([1, img.size()[-2], img.size()[-1]])
        else:
            with Image.open(gt) as gt_img:
                gt = self.gt_transform(gt_img)

        assert img.size()[1:] == gt.size()[1:], "image.size != gt.size !!!"

        return img, gt, label, img_path


def center_paste_2(large_img, small_img, shrink_factor):
    """Shrink ``small_img`` and paste it at the centre of a resized ``large_img``.

    ``large_img`` is first resized to ``small_img``'s size, so the result
    always has ``small_img``'s dimensions. ``small_img`` is then scaled by
    ``shrink_factor`` (truncated to whole pixels) and pasted, centred, onto a
    copy of that background.
    """
    width, height = small_img.size
    background = large_img.resize((width, height))
    foreground = small_img.resize((int(width * shrink_factor), int(height * shrink_factor)))
    fg_w, fg_h = foreground.size
    offset = ((width - fg_w) // 2, (height - fg_h) // 2)
    result = background.copy()
    result.paste(foreground, offset)
    return result


class IMAGENET30_TEST_DATASET_2(Dataset):
    """ImageNet-30 test images (``root/<class>/<instance>/*.JPEG``), unlabeled.

    An item is the image alone, opened lazily with ``Image.open`` (no RGB
    conversion) and passed through ``transform`` when one is given.
    ``targets`` still records each image's class index (sorted class order).
    """

    def __init__(self, root_dir="PATH_TO_ROOT", transform=None):
        """
        Args:
            root_dir (string): Directory with all the classes.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.img_path_list = []
        self.targets = []

        # Only directories are classes. A stray file here would otherwise
        # crash the os.listdir() call on it below.
        class_names = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        # Map each class to an index.
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}

        # Walk the tree in sorted order so the dataset order is reproducible.
        for class_name in class_names:
            class_path = os.path.join(root_dir, class_name)
            label = self.class_to_idx[class_name]
            for instance_folder in sorted(os.listdir(class_path)):
                instance_path = os.path.join(class_path, instance_folder)
                # Skip damaged / metadata entries that are not directories
                # (e.g. a file "airliner/._1.JPEG"), wherever root_dir points.
                if not os.path.isdir(instance_path):
                    continue
                for img_name in sorted(os.listdir(instance_path)):
                    # "._*" entries are AppleDouble metadata, not images.
                    if img_name.endswith('.JPEG') and not img_name.startswith('._'):
                        self.img_path_list.append(os.path.join(instance_path, img_name))
                        self.targets.append(label)

    def __len__(self):
        return len(self.img_path_list)

    def __getitem__(self, idx):
        """Return the (lazily opened) image at ``idx``, transformed if configured."""
        img = Image.open(self.img_path_list[idx])
        if self.transform is not None:
            img = self.transform(img)
        return img

class Train_MVTecDataset(Dataset):
    """MVTec-AD images for one category, returning ``(image, target)``.

    In training mode every ``good`` training image appears twice: once as is
    and once marked for augmentation. The augmented copy is shrunk by a random
    factor and pasted onto the centre of a random ImageNet-30 image. In test
    mode all test images are returned unaugmented. ``target`` is 0 for images
    whose directory ends in ``good`` and 1 otherwise.
    """

    # Candidate shrink factors for the paste augmentation. The order is kept
    # so random.choice picks the same factor under a fixed seed.
    _PASTE_FACTORS = (0.98, 0.95, 0.93, 0.91, 0.88, 0.82, 0.90, 0.97, 0.85, 0.80)

    def __init__(self, root, category, transform=None, shrink_factor=1, train=True, count=-1):
        self.transform = transform
        # Stored for API compatibility; __getitem__ does not use it.
        self.shrink_factor = shrink_factor
        print("category MVTecDataset:", category)
        if train:
            # Sort before truncating so the subset kept for ``count`` is reproducible.
            good_images = sorted(glob.glob(os.path.join(root, category, "train", "good", "*.png")),
                                 key=str.lower)
            if count != -1:
                good_images = sorted(self._match_count(good_images, count), key=str.lower)
            originals = [(img, False) for img in good_images]   # not augmented
            augmented = [(img, True) for img in good_images]    # marked for augmentation
            self.image_files = originals + augmented
        else:
            test_images = sorted(glob.glob(os.path.join(root, category, "test", "*", "*.png")))
            self.image_files = [(img, False) for img in test_images]
        self.train = train
        # Augmented items exist only in training mode, so only build (and walk)
        # the ImageNet-30 pool then.
        self.imagenet_30 = IMAGENET30_TEST_DATASET_2() if train else None

    @staticmethod
    def _match_count(paths, count):
        """Truncate ``paths`` to ``count`` items, or pad it with random repeats.

        Raises:
            ValueError: if padding is needed but ``paths`` is empty.
        """
        if count <= len(paths):
            return paths[:count]
        if not paths:
            raise ValueError(f"cannot resample {count} images from an empty image list")
        return paths + [random.choice(paths) for _ in range(count - len(paths))]

    def __getitem__(self, index):
        """Return ``(image, target)``; augmented items get the shrink-and-paste."""
        image_file, augment = self.image_files[index]
        with Image.open(image_file) as raw:
            image = raw.convert('RGB')

        if augment:
            background = self.imagenet_30[random.randrange(len(self.imagenet_30))].convert('RGB')
            image = center_paste_2(background, image, random.choice(self._PASTE_FACTORS))

        if self.transform is not None:
            image = self.transform(image)
        target = 0 if os.path.dirname(image_file).endswith("good") else 1
        return image, target

    def __len__(self):
        return len(self.image_files)




class RSNATRAIN(torch.utils.data.Dataset):
    """Normal-only training images stored as DICOM files under ``root/train/normal``.

    Every item is ``(image, 0)``. ``image`` is the DICOM pixel array converted
    to an RGB PIL image and passed through ``transform`` when one is given.

    Args:
        transform: optional callable applied to each PIL image.
        root: dataset root; defaults to the original placeholder path.
    """

    def __init__(self, transform, root='PATH_TO_ROOT'):
        self.transform = transform
        # Sorted so the sample order does not depend on the file system.
        self.image_paths = sorted(glob.glob(os.path.join(root, 'train', 'normal', '*')))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, 0)`` for the DICOM file at ``idx``."""
        pixels = pydicom.dcmread(self.image_paths[idx]).pixel_array
        image = Image.fromarray(pixels).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, 0

class RSNATEST(torch.utils.data.Dataset):
    """Evaluation set returning ``(image, gt, has_anomaly, path)``.

    ``test_id`` selects the source:

    * 1 (default, and any other value besides 2 and 3): DICOM files from
      ``PATH_TO_ROOT/test/{normal,anomaly}``.
    * 2: images from ``PATH_TO_SHIFTED`` (folder ``1`` normal; ``0``, ``2``
      and ``3`` anomalous).
    * 3: images from ``PATH_TO_SHIFTED2/test/{NORMAL,PNEUMONIA}``.

    ``gt`` is a placeholder mask; no pixel-level annotations are loaded.
    """

    def __init__(self, transform, test_id=1):
        self.transform = transform
        self.test_id = test_id

        if self.test_id == 2:
            normal_paths = glob.glob('PATH_TO_SHIFTED/Test/1/*')
            # NOTE(review): two patterns contain a "4. Operations Department/"
            # segment and one does not — confirm the intended layout.
            anomaly_paths = (glob.glob('PATH_TO_SHIFTED/4. Operations Department/Test/0/*') +
                             glob.glob('PATH_TO_SHIFTED/Test/2/*') +
                             glob.glob('PATH_TO_SHIFTED/4. Operations Department/Test/3/*'))
        elif self.test_id == 3:
            normal_paths = glob.glob('PATH_TO_SHIFTED2/test/NORMAL/*')
            anomaly_paths = glob.glob('PATH_TO_SHIFTED2/test/PNEUMONIA/*')
        else:
            normal_paths = glob.glob('PATH_TO_ROOT/test/normal/*')
            anomaly_paths = glob.glob('PATH_TO_ROOT/test/anomaly/*')

        self.test_path = normal_paths + anomaly_paths
        self.test_label = [0] * len(normal_paths) + [1] * len(anomaly_paths)

    def __len__(self):
        return len(self.test_path)

    @staticmethod
    def _placeholder_gt(image):
        """Return a ``(1, H, W)`` mask for tensor ``image``, with columns 1-2 set to 1.

        NOTE(review): presumably non-empty so that pixel-level metrics do not
        fail on an all-zero mask — confirm.
        """
        gt = torch.zeros([1, image.size()[-2], image.size()[-1]])
        gt[:, :, 1:3] = 1
        return gt

    def __getitem__(self, idx):
        """Return ``(image, gt, has_anomaly, path)`` for index ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        path = self.test_path[idx]
        if self.test_id == 1:
            # Split 1 is stored as DICOM; convert the pixel array to RGB.
            image = Image.fromarray(pydicom.dcmread(path).pixel_array).convert('RGB')
        else:
            with Image.open(path) as raw:
                image = raw.convert('RGB')

        # The transform must yield a tensor: the placeholder mask reads .size().
        if self.transform is not None:
            image = self.transform(image)

        has_anomaly = 0 if self.test_label[idx] == 0 else 1
        return image, self._placeholder_gt(image), has_anomaly, path



class Train_Visa(torch.utils.data.Dataset):
    """Normal training images matched by a glob pattern, all labelled 0.

    With probability ``imagenet_percent``, an image is shrunk by a random
    factor and pasted onto the centre of a random ImageNet-30 test image.

    Args:
        root: glob pattern selecting the image files.
        transform: optional callable applied to each (possibly pasted) image.
        imagenet_percent: probability of applying the paste augmentation.
        count: if truthy, truncate to ``count`` images or pad with random
            repeats.

    Raises:
        ValueError: if ``count`` requires padding but ``root`` matched nothing.
    """

    # Candidate shrink factors for the paste augmentation. The order is kept
    # so random.choice picks the same factor under a fixed seed.
    _PASTE_FACTORS = (0.98, 0.95, 0.93, 0.91, 0.88, 0.82, 0.90, 0.97, 0.85, 0.80)

    def __init__(self, root, transform=None, imagenet_percent=0.05, count=None):
        self.transform = transform
        self.imagenet_percent = imagenet_percent
        # Sorted so the subset kept by ``count`` does not depend on glob order.
        self.img_paths = sorted(glob.glob(root))
        if count:
            if count < len(self.img_paths):
                self.img_paths = self.img_paths[:count]
            elif self.img_paths:
                # The right-hand side is built before the list is extended,
                # so repeats are drawn only from the original images.
                self.img_paths += [random.choice(self.img_paths)
                                   for _ in range(count - len(self.img_paths))]
            else:
                raise ValueError(f"cannot resample {count} images: no files match {root!r}")

        self.labels = [0] * len(self.img_paths)
        print("len(Train_Visa)", len(self.img_paths))
        # Only build (and walk) the ImageNet-30 pool when the paste can fire.
        self.imagenet_30 = IMAGENET30_TEST_DATASET() if imagenet_percent > 0 else None

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``, applying the paste augmentation at random."""
        img_path, label = self.img_paths[idx], self.labels[idx]
        with Image.open(img_path) as raw:
            image = raw.convert('RGB')
        if random.uniform(0, 1) < self.imagenet_percent:
            background, _ = self.imagenet_30[random.randrange(len(self.imagenet_30))]
            image = center_paste_2(background.convert('RGB'), image,
                                   random.choice(self._PASTE_FACTORS))

        if self.transform is not None:
            image = self.transform(image)
        return image, label

