from common_imports import datasets, transforms
from common_use_functions import path_join

def get_imagenet_dataset_without_transform(
    data_path,
    normalize=True,
    norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    val_folder="reform_val",
):
    """
    Build the ImageNet-1k train/eval datasets without any data augmentation.

    Both splits share one deterministic pipeline:
    Resize(256) -> CenterCrop(224) -> ToTensor() -> (optional) Normalize.
    So the training split receives no random augmentation.

    data_path: The path for the folder containing the imagenet dataset. It is
        expected to contain a "train" sub-folder and a `val_folder`
        sub-folder, each loadable by `datasets.ImageFolder`.
    normalize: Boolean determining whether `transforms.Normalize` is appended
        to the pipeline.
    norm_params: A pair `(mean, std)` of per-channel sequences passed to
        `transforms.Normalize`. Only elements [0] and [1] are read. This
        argument is not used at all when `normalize` is False.
    val_folder: Name of the evaluation sub-folder under `data_path`.
        Defaults to "reform_val"; pass e.g. "val" to load a different folder.

    Returns: `(train_dataset, test_dataset)`, both `datasets.ImageFolder`
        instances using the same transform.
    """
    # The default for norm_params is an immutable tuple. A list literal there
    # would be a single object shared by every call.
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if normalize:
        # norm_params is only touched here, so callers may pass anything
        # (even None) when normalize=False.
        steps.append(transforms.Normalize(mean=norm_params[0], std=norm_params[1]))
    transform_largescale = transforms.Compose(steps)

    train_dataset = datasets.ImageFolder(
        path_join(data_path, "train"),
        transform=transform_largescale,
    )
    test_dataset = datasets.ImageFolder(
        path_join(data_path, val_folder),
        transform=transform_largescale,
    )
    return train_dataset, test_dataset

def get_imagenet_dataset_without_transform_for_regnet(
    data_path,
    normalize=True,
    norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    val_folder="reform_val",
):
    """
    Build the ImageNet-1k train/eval datasets (RegNet preprocessing) without
    any data augmentation.

    Both splits share one deterministic pipeline:
    Resize(384, bicubic) -> CenterCrop(384) -> ToTensor() -> (optional)
    Normalize. So the training split receives no random augmentation.

    data_path: The path for the folder containing the imagenet dataset. It is
        expected to contain a "train" sub-folder and a `val_folder`
        sub-folder, each loadable by `datasets.ImageFolder`.
    normalize: Boolean determining whether `transforms.Normalize` is appended
        to the pipeline.
    norm_params: A pair `(mean, std)` of per-channel sequences passed to
        `transforms.Normalize`. Only elements [0] and [1] are read. This
        argument is not used at all when `normalize` is False.
    val_folder: Name of the evaluation sub-folder under `data_path`.
        Defaults to "reform_val"; pass e.g. "val" to load a different folder.

    Returns: `(train_dataset, test_dataset)`, both `datasets.ImageFolder`
        instances using the same transform.
    """
    # The default for norm_params is an immutable tuple. A list literal there
    # would be a single object shared by every call.
    steps = [
        transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(384),
        transforms.ToTensor(),
    ]
    if normalize:
        # norm_params is only touched here, so callers may pass anything
        # (even None) when normalize=False.
        steps.append(transforms.Normalize(mean=norm_params[0], std=norm_params[1]))
    transform_largescale = transforms.Compose(steps)

    train_dataset = datasets.ImageFolder(
        path_join(data_path, "train"),
        transform=transform_largescale,
    )
    test_dataset = datasets.ImageFolder(
        path_join(data_path, val_folder),
        transform=transform_largescale,
    )
    return train_dataset, test_dataset


def get_imagenet_dataset_without_transform_for_vit(
    data_path,
    normalize=True,
    norm_params=((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    val_folder="reform_val",
):
    """
    Build the ImageNet-1k train/eval datasets (ViT preprocessing) without
    any data augmentation.

    Both splits share one deterministic pipeline:
    Resize(384, bicubic) -> CenterCrop(384) -> ToTensor() -> (optional)
    Normalize. So the training split receives no random augmentation.

    data_path: The path for the folder containing the imagenet dataset. It is
        expected to contain a "train" sub-folder and a `val_folder`
        sub-folder, each loadable by `datasets.ImageFolder`.
    normalize: Boolean determining whether `transforms.Normalize` is appended
        to the pipeline.
    norm_params: A pair `(mean, std)` of per-channel sequences passed to
        `transforms.Normalize`. Only elements [0] and [1] are read. This
        argument is not used at all when `normalize` is False.
    val_folder: Name of the evaluation sub-folder under `data_path`.
        Defaults to "reform_val"; pass e.g. "val" to load a different folder.

    Returns: `(train_dataset, test_dataset)`, both `datasets.ImageFolder`
        instances using the same transform.
    """
    # The default for norm_params is an immutable tuple. A list literal there
    # would be a single object shared by every call.
    steps = [
        transforms.Resize(384, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(384),
        transforms.ToTensor(),
    ]
    if normalize:
        # norm_params is only touched here, so callers may pass anything
        # (even None) when normalize=False.
        steps.append(transforms.Normalize(mean=norm_params[0], std=norm_params[1]))
    transform_largescale = transforms.Compose(steps)

    train_dataset = datasets.ImageFolder(
        path_join(data_path, "train"),
        transform=transform_largescale,
    )
    test_dataset = datasets.ImageFolder(
        path_join(data_path, val_folder),
        transform=transform_largescale,
    )
    return train_dataset, test_dataset