import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from Dataloader_funcs.DL_registry import *
from torchvision.transforms import RandAugment, RandomResizedCrop, v2, ColorJitter,RandomHorizontalFlip
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from Dataloader_funcs.utils import ThreeAugment

@register_DL('CIFAR100')
def CIFAR100(root_dir, batch_size=32, num_workers=4, val_split=.5, test_transform=None, train_transform=None,
             *, split_seed=42):
    """Build CIFAR-100 train/val/test DataLoaders.

    The official training split is used as-is for training. The official test
    split is randomly partitioned into a validation subset and a test subset.
    The dataset is downloaded into ``root_dir`` if it is not already there.

    Args:
        root_dir: Directory where torchvision stores / looks for CIFAR-100.
        batch_size: Batch size used by all three loaders.
        num_workers: Number of worker processes per loader.
        val_split: Fraction of the official test split assigned to the
            validation subset; must lie in ``[0, 1]``. The remainder becomes
            the test subset.
        test_transform: Transform applied to validation/test images. Defaults
            to ``ToTensor`` followed by CIFAR-100 mean/std normalization.
        train_transform: Transform applied to training images. Defaults to
            horizontal flip, color jitter, random resized crop to 32,
            ``ThreeAugment``, then ``ToTensor`` and normalization.
        split_seed: Seed of the generator used for the val/test partition, so
            the same images land in each subset across runs.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``. Only the training
        loader shuffles.

    Raises:
        ValueError: If ``val_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_split <= 1.0:
        # Outside [0, 1] one of the computed split lengths goes negative, which
        # random_split does not reject -- it would yield meaningless subsets.
        # Checked before any download happens.
        raise ValueError(f"val_split must be in [0, 1], got {val_split!r}")

    cifar100_mean = (0.5071, 0.4865, 0.4409)  # CIFAR-100 per-channel mean
    cifar100_std = (0.2673, 0.2564, 0.2762)   # CIFAR-100 per-channel std

    if train_transform is None:
        train_transform = transforms.Compose([
            RandomHorizontalFlip(),
            ColorJitter(.3, .3, .3, .3),  # brightness, contrast, saturation, hue
            RandomResizedCrop(32),
            ThreeAugment(),
            transforms.ToTensor(),
            transforms.Normalize(cifar100_mean, cifar100_std),
        ])
    if test_transform is None:
        test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar100_mean, cifar100_std),
        ])

    train_dataset = datasets.CIFAR100(root=root_dir, train=True, download=True, transform=train_transform)
    full_testset = datasets.CIFAR100(root=root_dir, train=False, download=True, transform=test_transform)

    # Validation and test subsets are both carved out of the official test split;
    # a seeded generator keeps the partition reproducible.
    val_size = int(len(full_testset) * val_split)
    test_size = len(full_testset) - val_size
    val_dataset, test_dataset = random_split(
        full_testset, [val_size, test_size], generator=torch.Generator().manual_seed(split_seed)
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader

@register_DL('Distributed CIFAR100')
def get_distributed_cifar100_loader(data_path, batch_size, num_workers, sampler=False, rank=0, world_size=1):
    """Build CIFAR-100 train/val DataLoaders, optionally sharded across processes.

    The official training split (augmented with RandAugment) is used for
    training, and the official test split for validation. Both loaders use
    ``pin_memory=True``.

    Args:
        data_path: Directory where torchvision stores / looks for CIFAR-100;
            the dataset is downloaded there if missing.
        batch_size: Batch size of each loader built by this process.
        num_workers: Number of worker processes per loader.
        sampler: If true, wrap each dataset in a ``DistributedSampler`` so each
            rank iterates its own shard. If false, each process iterates the
            full dataset (training data shuffled, validation data in order).
        rank: This process's rank. It is passed to ``DistributedSampler`` and
            also decides which process performs the download when a process
            group is initialized.
        world_size: Total number of processes, passed to ``DistributedSampler``.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Notes:
        With ``sampler=True`` the caller should call
        ``train_loader.sampler.set_epoch(epoch)`` at the start of each epoch.
        Otherwise ``DistributedSampler`` produces the same ordering every
        epoch. ``DistributedSampler`` may also pad the validation shards with
        repeated samples so all ranks get equal length, which slightly skews
        aggregated validation metrics.
    """
    cifar100_mean = (0.5071, 0.4865, 0.4409)  # CIFAR-100 per-channel mean
    cifar100_std = (0.2673, 0.2564, 0.2762)   # CIFAR-100 per-channel std

    train_transform = transforms.Compose([
        RandAugment(num_ops=2, magnitude=14),
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar100_mean, cifar100_std),
    ])

    # Every rank runs this function. If all of them download/extract into the
    # same directory at once, they race on the same files. When a process group
    # is up, rank 0 downloads first while the other ranks wait at a barrier;
    # afterwards they find the files already present. Each process makes
    # exactly one barrier() call, so the call counts match across ranks.
    dist = torch.distributed
    synchronize = dist.is_available() and dist.is_initialized()
    if synchronize and rank != 0:
        dist.barrier()

    train_dataset = datasets.CIFAR100(root=data_path, train=True, download=True, transform=train_transform)
    val_dataset = datasets.CIFAR100(root=data_path, train=False, download=True, transform=test_transform)

    if synchronize and rank == 0:
        dist.barrier()

    if sampler:
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None

    # DataLoader forbids shuffle=True together with a sampler, so the loader
    # shuffles only when no DistributedSampler is doing it.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
                              num_workers=num_workers, sampler=train_sampler, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, sampler=val_sampler, pin_memory=True)

    return train_loader, val_loader