import os
import torch
from torch.utils.data import Subset, random_split
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode, AutoAugment, AutoAugmentPolicy
import ssl

# NOTE(security): replaces the factory used for default HTTPS contexts with the
# unverified one, so TLS certificate verification is disabled process-wide for
# every stdlib HTTPS request made after this module is imported.
# Presumably a workaround for certificate errors during the dataset download --
# TODO confirm it is still required, and scope it to the download if so.
ssl._create_default_https_context = ssl._create_unverified_context

# Directory that contains this module.
THIS_PATH = os.path.dirname(__file__)
# Absolute path five directory levels above this module.
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *(['..'] * 5)))
# Default dataset root handed to torchvision for EuroSAT.
EUROSAT_PATH = os.path.join(ROOT_PATH, 'datasets/disentanglement/eurosat')

# Per-dataset constants, looked up by dataset name: the three-value mean/std
# tuples fed to transforms.Normalize, and the number of classes.
DATASET_STATS = {
    "eurosat": {
        "mean": (0.3444, 0.3803, 0.4078),
        "std": (0.2034, 0.1367, 0.1158),
        "num_classes": 10,
    },
}
# Hue-only colour jitter: brightness, contrast and saturation are left untouched.
# NOTE: despite the name, ColorJitter draws the hue factor at random (within
# [-0.5, 0.5] per the torchvision docs), so the shift varies per sample.
always_hue_shift = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5)

# Pool of geometric perturbations; the evaluation pipeline picks one entry
# through transforms.RandomChoice when test-time augmentation is enabled.
geometric_transforms = [transforms.RandomRotation(degrees=180)]

def _build_eurosat_transforms(name="eurosat", augment=False, normalize=True, image_size=64, test_augmentation=False):
    """Build the training and evaluation transform pipelines.

    Args:
        name: Key into ``DATASET_STATS`` selecting the normalization statistics.
        augment: If True, the training pipeline starts with a bicubic
            ``RandomResizedCrop`` to ``image_size``, a horizontal flip (p=0.5)
            and ``AutoAugment`` with the ImageNet policy. If False, the training
            pipeline performs no resizing at all, so ``image_size`` then only
            affects the evaluation pipeline.
        normalize: If True, append ``Normalize`` with the dataset mean/std after
            ``ToTensor`` in both pipelines.
        image_size: Output side length of the training crop (when ``augment``)
            and of the evaluation center crop.
        test_augmentation: If True, the evaluation pipeline additionally applies
            ``always_hue_shift`` and one transform drawn from
            ``geometric_transforms`` after the center crop.

    Returns:
        A ``(train_tf, test_tf)`` tuple of ``transforms.Compose`` objects.

    Raises:
        KeyError: If ``name`` is not a key of ``DATASET_STATS``.
    """
    stats = DATASET_STATS[name]

    train_steps = []
    if augment:
        train_steps = [
            transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0),
                                         interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(0.5),
            AutoAugment(AutoAugmentPolicy.IMAGENET),
        ]

    # Tail shared by both pipelines. Normalize is appended only when requested
    # instead of substituting an identity Lambda: a lambda inside the pipeline
    # cannot be pickled, which breaks DataLoader workers started with "spawn".
    tensor_steps = [transforms.ToTensor()]
    if normalize:
        tensor_steps.append(transforms.Normalize(stats["mean"], stats["std"]))

    # Evaluation always resizes with the 256/224 ratio and center-crops back to
    # image_size; the optional perturbations run on the cropped image.
    test_steps = [
        transforms.Resize(int(image_size * 256 / 224), interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
    ]
    if test_augmentation:
        test_steps += [
            always_hue_shift,
            transforms.RandomChoice(geometric_transforms),
        ]

    train_tf = transforms.Compose(train_steps + tensor_steps)
    test_tf = transforms.Compose(test_steps + tensor_steps)

    return train_tf, test_tf


def get_eurosat_loaders(
        name="eurosat",
        root=EUROSAT_PATH,
        batch_size=128,
        num_workers=4,
        val_ratio=0.1,
        train_ratio=0.7,
        augment=True,
        normalize=True,
        download=True,
        pin_memory=True,
        persistent_workers=True,
        seed=42,
        image_size=64,
        test_aug=False
):
    """Create seeded train and test subsets of EuroSAT.

    The dataset is split into train / validation / test index sets with a
    generator seeded by ``seed``. The validation share is carved out of the
    data (so it never appears in the train or test subsets) but is not
    returned.

    Args:
        name: Dataset key (case-insensitive) used to look up ``DATASET_STATS``.
        root: Dataset root directory passed to ``datasets.EuroSAT``.
        batch_size, num_workers, pin_memory, persistent_workers: Accepted for
            call-site compatibility; currently unused because this function
            returns datasets rather than ``DataLoader`` objects.
        val_ratio: Fraction of samples held out as the validation split.
        train_ratio: Fraction of samples assigned to the training split.
        augment: Forwarded to ``_build_eurosat_transforms`` for the train pipeline.
        normalize: Forwarded to ``_build_eurosat_transforms``.
        download: Whether to download the data; forced to False when
            ``<root>/eurosat`` already exists and is non-empty.
        seed: Seed for the split generator.
        image_size: Forwarded to ``_build_eurosat_transforms``.
        test_aug: Enables the test-time perturbations of the evaluation pipeline.

    Returns:
        A ``(train_set, test_set)`` tuple of ``torch.utils.data.Subset`` objects;
        the train subset uses the training transforms, the test subset the
        evaluation transforms.

    Raises:
        KeyError: If ``name`` is not a key of ``DATASET_STATS``.
        ValueError: If ``train_ratio`` / ``val_ratio`` are negative or together
            exceed the dataset size.
    """
    # Skip the download only when the extracted folder is really there.
    # Checking just `root` and then listing `<root>/eurosat` raised
    # FileNotFoundError whenever root existed without that subfolder.
    extracted_dir = os.path.join(root, "eurosat")
    if os.path.isdir(extracted_dir) and os.listdir(extracted_dir):
        download = False
    name = name.lower()

    train_tf, test_tf = _build_eurosat_transforms(
        name,
        augment=augment,
        normalize=normalize,
        image_size=image_size,
        test_augmentation=test_aug,
    )

    # This instance only triggers the (optional) download and provides the size.
    full_dataset_raw = datasets.EuroSAT(root=root, download=download, transform=None)

    num_total = len(full_dataset_raw)
    num_val = int(val_ratio * num_total)
    num_train = int(train_ratio * num_total)
    num_test = num_total - num_train - num_val
    if min(num_train, num_val, num_test) < 0:
        # A negative length would otherwise slip through and silently
        # truncate the splits instead of failing.
        raise ValueError(
            f"Invalid split ratios: train_ratio={train_ratio}, val_ratio={val_ratio} "
            f"(both must be >= 0 and sum to at most 1)"
        )

    generator = torch.Generator().manual_seed(seed)
    # Splitting range(num_total) makes each split's `.indices` the dataset
    # indices themselves; the validation part is kept in the split so the
    # train/test membership stays the same for a given seed.
    train_split, _val_split, test_split = random_split(
        range(num_total), [num_train, num_val, num_test], generator=generator
    )

    # Separate dataset instances are needed because the transform is
    # attached to the dataset object, not to the Subset.
    train_set = Subset(
        datasets.EuroSAT(root=root, download=False, transform=train_tf),
        train_split.indices,
    )
    test_set = Subset(
        datasets.EuroSAT(root=root, download=False, transform=test_tf),
        test_split.indices,
    )

    return train_set, test_set