import os
import torch
import torch.distributed as dist
from torch.utils.data import random_split, Subset
from torchvision import datasets, transforms
from torchvision.transforms import InterpolationMode, AutoAugment, AutoAugmentPolicy

# Absolute directory that contains this module.
THIS_PATH = os.path.dirname(os.path.abspath(__file__))
# Five directory levels above THIS_PATH (normalised by abspath).
ROOT_PATH = os.path.abspath(os.path.join(THIS_PATH, *([os.pardir] * 5)))
# Default parent directory under which each dataset gets its own sub-folder.
DEFAULT_DATA_ROOT = os.path.join(ROOT_PATH, 'datasets/disentanglement')

def is_ddp():
    """Return True when torch.distributed is available and a process group is initialised."""
    # Check availability first: is_initialized() must only be called when the
    # distributed package is actually usable in this build.
    if not dist.is_available():
        return False
    return dist.is_initialized()


def is_main_process():
    """Return True outside distributed runs, or on rank 0 inside one."""
    if is_ddp():
        return dist.get_rank() == 0
    # No process group: the single running process is the main one.
    return True


# Per-dataset channel statistics (passed to transforms.Normalize) and class
# counts, keyed by the lower-case dataset name accepted by this module.
DATASET_STATS = {
    "cifar10": dict(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616),
        num_classes=10,
    ),
    "cifar100": dict(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761),
        num_classes=100,
    ),
    "flowers102": dict(
        mean=(0.434, 0.385, 0.296),
        std=(0.294, 0.228, 0.261),
        num_classes=102,
    ),
    "stanfordcars": dict(
        mean=(0.470, 0.460, 0.455),
        std=(0.276, 0.268, 0.274),
        num_classes=196,
    ),
    "stl10": dict(
        mean=(0.4467, 0.4398, 0.4066),
        std=(0.2603, 0.2566, 0.2713),
        num_classes=10,
    ),
    "food101": dict(
        mean=(0.545, 0.443, 0.324),
        std=(0.268, 0.259, 0.282),
        num_classes=101,
    ),
    "caltech101": dict(
        mean=(0.547, 0.533, 0.504),
        std=(0.289, 0.287, 0.318),
        num_classes=101,
    ),
    "oxfordiiitpet": dict(
        mean=(0.478, 0.445, 0.396),
        std=(0.269, 0.261, 0.274),
        num_classes=37,
    ),
    "fgvcaircraft": dict(
        mean=(0.485, 0.491, 0.446),
        std=(0.232, 0.230, 0.261),
        num_classes=100,
    ),
}

# Hue-only colour jitter used by the test-time augmentation path of
# _build_transforms. Brightness, contrast and saturation are disabled; per the
# torchvision ColorJitter contract, hue=0.5 draws the hue factor from
# [-0.5, 0.5] on every call, so the size of the shift is itself random.
always_hue_shift = transforms.ColorJitter(brightness=0, contrast=0, saturation=0, hue=0.5)

# Pool of geometric perturbations that _build_transforms samples from via
# transforms.RandomChoice. Currently a single entry: a rotation whose angle is
# drawn from (-180, +180) degrees.
geometric_transforms = [transforms.RandomRotation(degrees=180)]



def _build_transforms(name, augment=True, normalize=True, image_size=224, test_augmentation=False):
    """Build the training and evaluation transform pipelines for a dataset.

    Args:
        name: Key into ``DATASET_STATS``; selects the mean/std used for
            normalisation.
        augment: When True, the training pipeline starts with
            RandomResizedCrop, RandomHorizontalFlip and AutoAugment (ImageNet
            policy). When False, the training pipeline applies no spatial
            transform at all.
        normalize: When True, ``transforms.Normalize(mean, std)`` is appended
            after ``ToTensor`` in both pipelines; when False the step is
            simply left out.
        image_size: Output side length of the crops.
        test_augmentation: When True, the evaluation pipeline additionally
            applies ``always_hue_shift`` and one transform chosen from
            ``geometric_transforms`` after the centre crop.

    Returns:
        Tuple ``(train_tf, eval_tf)`` of ``transforms.Compose`` objects.

    Raises:
        KeyError: If ``name`` is not a key of ``DATASET_STATS``.
    """
    stats = DATASET_STATS[name]

    def tensor_steps():
        # Omit normalisation entirely instead of inserting an identity
        # ``Lambda``: a lambda inside the pipeline makes it unpicklable, which
        # breaks DataLoader workers started with the "spawn" method.
        steps = [transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize(stats["mean"], stats["std"]))
        return steps

    train_steps = []
    if augment:
        train_steps = [
            transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0), interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(0.5),
            AutoAugment(AutoAugmentPolicy.IMAGENET),
        ]
    # NOTE(review): with augment=False the training images are neither resized
    # nor cropped, so ``image_size`` is ignored on that path and datasets with
    # variable image sizes would presumably not batch - confirm this is intended.
    train_tf = transforms.Compose(train_steps + tensor_steps())

    # Evaluation always resizes with the 256/224 ratio and then centre-crops
    # to ``image_size``.
    eval_steps = [
        transforms.Resize(int(image_size * 256 / 224), interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(image_size),
    ]
    if test_augmentation:
        # Perturb the evaluation images: hue jitter followed by one randomly
        # chosen geometric transform.
        eval_steps += [
            always_hue_shift,
            transforms.RandomChoice(geometric_transforms),
        ]
    eval_tf = transforms.Compose(eval_steps + tensor_steps())

    return train_tf, eval_tf

def _safe_download_wrapper(dataset_creation_fn):

    try:
        return dataset_creation_fn(download=False)
    except RuntimeError as e:
        msg = str(e).lower()
        if "not found" not in msg and "corrupted" not in msg:
            raise e

    if is_main_process():
        print(f"Dataset not found. Downloading on main process...")

        _ = dataset_creation_fn(download=True, transform_override=None)

    if is_ddp():
        dist.barrier()

    return dataset_creation_fn(download=False)

# Datasets that ship their own train/test partition:
#   name -> (torchvision.datasets class name, train kwargs, test kwargs).
# Class names are stored as strings and resolved lazily so that a torchvision
# build lacking one of them only fails when that dataset is requested.
_TRAIN_TEST_SPLITS = {
    "cifar100": ("CIFAR100", {"train": True}, {"train": False}),
    "cifar10": ("CIFAR10", {"train": True}, {"train": False}),
    "flowers102": ("Flowers102", {"split": "train"}, {"split": "test"}),
    "stanfordcars": ("StanfordCars", {"split": "train"}, {"split": "test"}),
    "stl10": ("STL10", {"split": "train"}, {"split": "test"}),
    "food101": ("Food101", {"split": "train"}, {"split": "test"}),
    "oxfordiiitpet": ("OxfordIIITPet", {"split": "trainval"}, {"split": "test"}),
    "fgvcaircraft": ("FGVCAircraft", {"split": "train"}, {"split": "test"}),
}


def _load_split(dataset_cls, dataset_root, transform, **split_kwargs):
    """Build one dataset split through ``_safe_download_wrapper``."""
    return _safe_download_wrapper(
        lambda download, **kwargs: dataset_cls(
            dataset_root,
            download=download,
            # The wrapper passes transform_override=None for its download-only call.
            transform=kwargs.get('transform_override', transform),
            **split_kwargs,
        )
    )


def _load_caltech101(dataset_root, train_tf, eval_tf, val_ratio, seed):
    """Split Caltech101 into a training subset and a held-out subset."""
    # Caltech101 has no predefined splits, so it is downloaded (if needed) and
    # partitioned manually. This first instance is only used for its length.
    full_dataset_raw = _load_split(datasets.Caltech101, dataset_root, None)

    num_total = len(full_dataset_raw)
    num_val = int(val_ratio * num_total)
    num_train = num_total - num_val
    # Seeded generator keeps the partition reproducible for a given seed.
    generator = torch.Generator().manual_seed(seed)
    train_indices, val_indices = random_split(range(num_total), [num_train, num_val], generator=generator)

    # Two dataset instances so each subset gets its own transform pipeline.
    # NOTE(review): Caltech101 reportedly contains some grayscale images;
    # confirm the 3-channel Normalize step handles them.
    train_ds = Subset(datasets.Caltech101(dataset_root, download=False, transform=train_tf), train_indices)
    held_out_ds = Subset(datasets.Caltech101(dataset_root, download=False, transform=eval_tf), val_indices)
    return train_ds, held_out_ds


def get_cls_factory_dataloader(
        name,
        root=DEFAULT_DATA_ROOT,
        val_ratio=0.1,
        augment=True,
        normalize=True,
        seed=42,
        image_size=224,
        test_aug=False
):
    """Return the training and test datasets for a supported classification set.

    Despite the name, this returns ``Dataset`` objects, not ``DataLoader``s.
    The dataset is stored under ``os.path.join(root, name)`` and downloaded on
    the main process when it is missing.

    Args:
        name: Dataset name (case-insensitive); must be a key of
            ``DATASET_STATS``.
        root: Parent directory holding one sub-directory per dataset.
        val_ratio: Fraction of Caltech101 held out from training. Only used
            for ``"caltech101"``.
        augment: Forwarded to ``_build_transforms`` (training augmentation).
        normalize: Forwarded to ``_build_transforms`` (mean/std normalisation).
        seed: Seed for the Caltech101 random partition. Only used for
            ``"caltech101"``.
        image_size: Forwarded to ``_build_transforms`` (crop side length).
        test_aug: Forwarded to ``_build_transforms`` as ``test_augmentation``.

    Returns:
        Tuple ``(train_ds, test_ds)``. For Caltech101, which has no separate
        test set, ``test_ds`` is the held-out ``val_ratio`` subset. Validation
        splits shipped by Flowers102 and FGVCAircraft are not loaded because
        only the train and test datasets are returned.

    Raises:
        ValueError: If ``name`` is not supported.
    """
    name = name.lower()
    if name not in DATASET_STATS:
        raise ValueError(f"Dataset '{name}' is not supported.")

    dataset_root = os.path.join(root, name)

    # Only the main process creates the directory; the other ranks wait so
    # they never observe a missing dataset root.
    if is_main_process():
        os.makedirs(dataset_root, exist_ok=True)
    if is_ddp():
        dist.barrier()

    train_tf, eval_tf = _build_transforms(name, augment, normalize, image_size, test_augmentation=test_aug)

    if name == "caltech101":
        return _load_caltech101(dataset_root, train_tf, eval_tf, val_ratio, seed)

    spec = _TRAIN_TEST_SPLITS.get(name)
    if spec is None:
        # Defensive: a name present in DATASET_STATS but without a loader.
        raise ValueError(f"{name} dataloader not implemented!")

    cls_name, train_kwargs, test_kwargs = spec
    dataset_cls = getattr(datasets, cls_name)
    train_ds = _load_split(dataset_cls, dataset_root, train_tf, **train_kwargs)
    test_ds = _load_split(dataset_cls, dataset_root, eval_tf, **test_kwargs)

    return train_ds, test_ds
