# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
from torchvision import datasets, transforms

from timm.data.constants import \
    IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data import create_transform
import torch
from torchvision.datasets.folder import make_dataset
from torchvision.datasets.utils import list_dir
from torchvision import transforms
from PIL import Image
import random

class RemoveHighFreqTransform:
    """Low-pass filter an image by discarding high spatial frequencies.

    The image is taken to the frequency domain with a 2-D FFT over its last
    two dimensions, every coefficient outside a centered square window of
    half-width ``threshold`` is zeroed, and the result is transformed back.
    """

    def __init__(self, threshold):
        """Store ``threshold``, the half-width (in frequency bins) of the kept window."""
        self.threshold = threshold

    def filter_tensor(self, img_tensor):
        """Return ``img_tensor`` with frequencies outside the window removed.

        Args:
            img_tensor: Real tensor whose last two dimensions are the spatial
                ones (the FFT is taken over ``dim=(-2, -1)``).

        Returns:
            The real part of the inverse FFT, same shape as the input.
            Values are NOT clamped here.
        """
        freq_domain = torch.fft.fftn(img_tensor, dim=(-2, -1))
        # Shift so the zero-frequency (DC) term sits at index (h // 2, w // 2).
        freq_domain = torch.fft.fftshift(freq_domain, dim=(-2, -1))

        h, w = freq_domain.shape[-2:]
        center_row, center_col = h // 2, w // 2
        # Clamp the window start at 0. A negative slice start would be
        # interpreted relative to the END of the axis, so a threshold larger
        # than half the image would keep only a thin strip at the far edge
        # instead of the whole spectrum. (An end index past the axis length is
        # already truncated by slicing.)
        row_start = max(center_row - self.threshold, 0)
        col_start = max(center_col - self.threshold, 0)
        row_stop = center_row + self.threshold
        col_stop = center_col + self.threshold

        mask = torch.zeros_like(freq_domain)
        mask[..., row_start:row_stop, col_start:col_stop] = 1

        low_freq_domain = torch.fft.ifftshift(freq_domain * mask, dim=(-2, -1))
        return torch.fft.ifftn(low_freq_domain, dim=(-2, -1)).real

    def __call__(self, img):
        """Filter ``img`` and return the result as a PIL image.

        The input is converted with ``to_tensor``, filtered, clamped to
        ``[0, 1]`` (the filter can overshoot that range) and converted back
        with ``to_pil_image``.
        """
        img_tensor = transforms.functional.to_tensor(img)
        img_filtered = self.filter_tensor(img_tensor)
        return transforms.functional.to_pil_image(img_filtered.clamp(0, 1))
        
    
class FilteredImageFolder(datasets.ImageFolder):
    """``ImageFolder`` restricted to a chosen subset of class directories.

    Only directories directly under ``root`` whose names are in
    ``selected_folders`` become classes. Class indices are assigned over the
    sorted selected names, so they run contiguously from 0.

    Args:
        root: Dataset root containing one sub-directory per class.
        selected_folders: Names of the class sub-directories to keep.
        transform: Optional transform forwarded to ``ImageFolder``.
    """

    def __init__(self, root, selected_folders, transform=None):
        # Must be assigned BEFORE the base initializer runs: torchvision's
        # DatasetFolder resolves the class list during __init__ (through
        # find_classes() in current releases, _find_classes() in older ones),
        # and both are overridden below to read this attribute. With the
        # filtering done there, the base class builds classes, class_to_idx,
        # samples, targets and imgs for the selected folders in a single scan.
        self.selected_folders = selected_folders
        super().__init__(root, transform=transform)

    def find_classes(self, directory):
        """Return ``(classes, class_to_idx)`` for the selected sub-directories.

        ``classes`` is the sorted list of directory names under ``directory``
        that appear in ``self.selected_folders``; ``class_to_idx`` maps each
        name to its position in that list.
        """
        classes = sorted(
            entry.name
            for entry in os.scandir(directory)
            if entry.is_dir() and entry.name in self.selected_folders
        )
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def _find_classes(self, dir):
        """Alias of :meth:`find_classes`, kept under the name this class originally exposed."""
        return self.find_classes(dir)

class FixedRandomCenterCrop(transforms.CenterCrop):
    """Crop every image at one fixed location chosen once at construction.

    The crop center is either supplied by the caller or drawn at random when
    the transform is created; afterwards the same box is used for every image.

    Sizes given as a pair are indexed ``[0]`` for the x (horizontal) axis and
    ``[1]`` for the y (vertical) axis, matching how they are combined with
    ``center_x`` / ``center_y`` into the PIL crop box.

    Args:
        output_size: Crop size; an int means a square crop.
        image_size: Size of the images that will be cropped; an int means
            square images. Only used to bound the random center.
        center_x: Horizontal crop center, or ``None`` to draw it at random.
        center_y: Vertical crop center, or ``None`` to draw it at random.
    """

    def __init__(self, output_size, image_size, center_x=None, center_y=None):
        super().__init__(output_size)
        self.output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
        self.image_size = (image_size, image_size) if isinstance(image_size, int) else image_size

        # Randomly choose any missing center coordinate (once, during
        # initialization). Each axis is handled on its own so that supplying
        # only one coordinate no longer leaves the other as None, which made
        # __call__ fail. x is drawn before y, so the random stream is consumed
        # in the same order as before when both are omitted.
        if center_x is None:
            center_x = random.uniform(self.output_size[0] / 2, self.image_size[0] - self.output_size[0] / 2)
        if center_y is None:
            center_y = random.uniform(self.output_size[1] / 2, self.image_size[1] - self.output_size[1] / 2)
        self.center_x = center_x
        self.center_y = center_y

        print(f"Random Center crop: {self.center_x}, {self.center_y}")

    def __call__(self, img):
        """
        img (PIL Image): Image to be cropped.
        img should have the size specified by self.image_size

        Returns the region of ``output_size`` centered on
        ``(center_x, center_y)``.
        """
        # Calculate the top-left corner based on the stored center
        left = int(self.center_x - self.output_size[0] / 2)
        top = int(self.center_y - self.output_size[1] / 2)
        right = left + self.output_size[0]
        bottom = top + self.output_size[1]

        return img.crop((left, top, right, bottom))

def build_dataset(is_train, args, random_setting):
    """Build the dataset selected by ``args.data_set``.

    Supported values of ``args.data_set``:

    * ``'CIFAR'`` - ``datasets.CIFAR100`` under ``args.data_path``
      (downloaded if needed), 100 classes.
    * ``'IMNET'`` - ``ImageFolder`` on ``<args.data_path>/train`` or
      ``<args.data_path>/val``, 1000 classes.
    * ``'image_folder'`` - ``ImageFolder`` on ``args.data_path`` (training) or
      ``args.eval_data_path`` (evaluation), restricted to
      ``args.selected_folders`` when that is not ``None``. The number of
      classes found must equal ``args.nb_classes``.

    Args:
        is_train: Whether to build the training split and training transform.
        args: Namespace holding the options read above and by
            ``build_transform``.
        random_setting: Forwarded unchanged to ``build_transform``.

    Returns:
        ``(dataset, nb_classes)``.

    Raises:
        ValueError: ``image_folder`` data whose class count differs from
            ``args.nb_classes``.
        NotImplementedError: Unknown ``args.data_set``.
    """
    transform = build_transform(is_train, args, random_setting)

    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform, download=True)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        print("reading from datapath", args.data_path)
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        if args.selected_folders is not None:
            dataset = FilteredImageFolder(root, args.selected_folders, transform=transform)
            print(f"Using {len(dataset)} images from {args.selected_folders} classes")
        else:
            dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = args.nb_classes
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, and a bare assert gives no hint about the mismatch.
        found_classes = len(dataset.class_to_idx)
        if found_classes != nb_classes:
            raise ValueError(
                f"args.nb_classes is {nb_classes} but {root!r} provides {found_classes} classes"
            )
    else:
        raise NotImplementedError(f"Unsupported data_set: {args.data_set!r}")
    print("Number of the class = %d" % nb_classes)

    return dataset, nb_classes


def build_transform(is_train, args, random_setting):
    """Assemble the preprocessing pipeline for training or evaluation.

    Args:
        is_train: Build the training pipeline (timm ``create_transform``)
            when true, otherwise the deterministic evaluation pipeline.
        args: Namespace providing ``input_size``,
            ``imagenet_default_mean_and_std``, the augmentation options passed
            to ``create_transform``, ``low_filter`` / ``low_filter_threshold``,
            ``pre_center_crop`` / ``pre_center_crop_size``,
            ``pre_random_center_crop`` / ``pre_random_center_crop_size`` and
            ``crop_pct``. ``args.crop_pct`` is set to ``224 / 256`` in place
            when it is ``None`` and the evaluation resize-and-crop path runs.
        random_setting: ``(output_size, image_size, center_x, center_y)`` for
            ``FixedRandomCenterCrop``; only unpacked when
            ``args.pre_random_center_crop`` is set.

    Returns:
        A ``transforms.Compose``.
    """
    needs_resize = args.input_size > 32
    if args.imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    if args.pre_random_center_crop:
        crop_out_size, crop_img_size, crop_cx, crop_cy = random_setting

    if is_train:
        # this should always dispatch to transforms_imagenet_train
        pipeline = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=mean,
            std=std,
        )
        if not needs_resize:
            # Small inputs: swap the first timm step for a padded random crop.
            pipeline.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)

        # Each optional step is PREPENDED, so the pre-crop added second ends
        # up in front of the low-pass filter: [pre-crop, low-pass, ...].
        if args.low_filter:
            low_pass = RemoveHighFreqTransform(threshold=args.low_filter_threshold)
            pipeline = transforms.Compose([low_pass, *pipeline.transforms])
            print(f"training Low-frequency filter with threshold {args.low_filter_threshold}...")

        # pre_center_crop
        if args.pre_center_crop:
            pre_crop = transforms.CenterCrop(args.pre_center_crop_size)
            pipeline = transforms.Compose([pre_crop, *pipeline.transforms])
            print(f"training Pre-centering crop {args.pre_center_crop_size} size input images...")
        elif args.pre_random_center_crop:
            pre_crop = FixedRandomCenterCrop(
                output_size=crop_out_size,
                image_size=crop_img_size,
                center_x=crop_cx,
                center_y=crop_cy,
            )
            pipeline = transforms.Compose([pre_crop, *pipeline.transforms])
            print(f"training Pre-random_centering crop {args.pre_random_center_crop_size} size input images...")

        return pipeline

    # NOTE(review): for inputs of 32 px or less the evaluation pipeline skips
    # the low-pass filter and pre-crop entirely, while training still applies
    # them - confirm this is intended.
    if not needs_resize:
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    steps = []
    # NOTE(review): steps are APPENDED here, giving [low-pass, pre-crop, ...],
    # the reverse of the training order above - confirm which is intended.
    if args.low_filter:
        steps.append(RemoveHighFreqTransform(threshold=args.low_filter_threshold))
        print(f"validation Low-frequency filter with threshold {args.low_filter_threshold}...")

    if args.pre_center_crop:
        steps.append(transforms.CenterCrop(args.pre_center_crop_size))
        print(f"validation Pre-centering crop {args.pre_center_crop_size} size input images...")
    elif args.pre_random_center_crop:
        steps.append(
            FixedRandomCenterCrop(
                output_size=crop_out_size,
                image_size=crop_img_size,
                center_x=crop_cx,
                center_y=crop_cy,
            )
        )
        print(f"validation Pre-random_centering crop {args.pre_random_center_crop_size} size input images...")

    if args.input_size >= 384:
        # warping (no cropping) when evaluated at 384 or larger
        steps.append(
            transforms.Resize(
                (args.input_size, args.input_size),
                interpolation=transforms.InterpolationMode.BICUBIC,
            )
        )
        print(f"Warping {args.input_size} size input images...")
    else:
        if args.crop_pct is None:
            args.crop_pct = 224 / 256
        # to maintain same ratio w.r.t. 224 images
        resize_to = int(args.input_size / args.crop_pct)
        steps.extend([
            transforms.Resize(resize_to, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(args.input_size),
        ])

    steps.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(steps)
