import os
import torchvision
from torchvision import transforms
from ddg import DDG
from torchvision.datasets import ImageFolder
from torch.utils.data import ConcatDataset

# Per-dataset metadata: number of classes plus the root directories of the
# original images and of several augmentation/synthetic-data sources.
# get_dataset() below reads only 'num_classes' and 'orig'; 'cdp_and_cip' and
# 'syncdp' are referenced only by the disabled DDG path noted inside it, and
# 'diffusemix_dir' / 'dafusion_dir' are not read by this function at all.
# NOTE(review): all paths are placeholders. 'orig' is absolute ("/path/to/...")
# while every other root is relative ("path/to/...") -- confirm which form is
# intended when filling them in. "StandfordCar" is kept exactly as written;
# it must match the on-disk directory name.
_DATASET_CONFIGS = {
    'car': {
        'num_classes': 196,
        'orig': "/path/to/StandfordCar",
        'cdp_and_cip': "path/to/icml24/lang-segment-anything/fore_and_back/StandfordCar",
        'syncdp': "path/to/icml24/LayerDiffuse_DiffusersCLI/generated_fore/StandfordCar",
        'diffusemix_dir': 'path/to/icml24/diffuseMix/result-car/blended',
        'dafusion_dir': 'path/to/icml24/textual_inversion_sdxl/generated_images/StandfordCar',
    },
    'cub': {
        'num_classes': 200,
        'orig': "/path/to/CUB_200_2011",
        'cdp_and_cip': "path/to/iclr25/sam_segment/cdp_and_cip/CUB_200_2011",
        'syncdp': "path/to/iclr25/inversion/generated_fore/CUB_200_2011-0.4-1",
        'diffusemix_dir': 'path/to/icml24/diffuseMix/result-cub/blended',
        'dafusion_dir': 'path/to/icml24/textual_inversion_sdxl/generated_images/CUB_200_2011',
    },
    'aircraft': {
        'num_classes': 100,
        'orig': "/path/to/Aircraft",
        'cdp_and_cip': "path/to/icml24/lang-segment-anything/fore_and_back/Aircraft",
        'syncdp': "path/to/icml24/LayerDiffuse_DiffusersCLI/generated_fore/Aircraft",
        'diffusemix_dir': 'path/to/icml24/diffuseMix/result-aircraft/blended',
        'dafusion_dir': 'path/to/icml24/textual_inversion_sdxl/generated_images/Aircraft',
    },
}


def _build_transforms(resize_shape, crop_shape):
    """Return ``(train_transform, test_transform)`` for the given sizes.

    Both pipelines resize (an int size makes torchvision scale the shorter
    image side to that length), crop to ``crop_shape``, convert to a tensor
    and normalize with the ImageNet mean/std values commonly used with
    torchvision pretrained models. Training uses a random crop plus random
    horizontal flip for augmentation; testing uses a deterministic center crop.
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    train_transform = transforms.Compose([
        transforms.Resize(resize_shape),
        transforms.RandomCrop(crop_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.Resize(resize_shape),
        transforms.CenterCrop(crop_shape),
        transforms.ToTensor(),
        normalize,
    ])
    return train_transform, test_transform


def get_dataset(configs):
    """Build the train and test ImageFolder datasets for a configured dataset.

    Args:
        configs: Mapping with key ``'dataset'``, one of ``'car'``, ``'cub'``
            or ``'aircraft'``. Optional keys ``'resize_shape'`` (default 512)
            and ``'crop_shape'`` (default 448) override the preprocessing
            sizes.

    Returns:
        ``(train_data, test_data, num_classes)``. The two datasets are
        ``ImageFolder`` instances rooted at the ``train`` and ``test``
        subdirectories of the dataset's ``'orig'`` directory.

    Raises:
        ValueError: If ``configs['dataset']`` is not a supported dataset.
    """
    dataset_name = configs['dataset']
    # Validate before printing or building anything, so a bad name fails fast
    # and the message says what was wrong and what would have been accepted.
    if dataset_name not in _DATASET_CONFIGS:
        supported = ', '.join(sorted(_DATASET_CONFIGS))
        raise ValueError(
            f"Dataset not supported: {dataset_name!r} "
            f"(expected one of: {supported})"
        )

    print(f'==> Preparing {dataset_name} data..')
    paths = _DATASET_CONFIGS[dataset_name]
    train_transform, test_transform = _build_transforms(
        configs.get('resize_shape', 512),
        configs.get('crop_shape', 448),
    )

    test_data = ImageFolder(root=os.path.join(paths['orig'], "test"), transform=test_transform)

    # NOTE: configs['train_mode'] is currently ignored. The 'vanilla'/'ddg'
    # switch was disabled, so training always uses the plain ImageFolder below.
    # The disabled DDG wiring was:
    #
    #   train_data = DDG(
    #       root_orig=os.path.join(paths['orig'], "train"),
    #       prob_aug=configs['prob_aug'],
    #       root_cdp=os.path.join(paths['cdp_and_cip'], "cdp"),
    #       root_cip=os.path.join(paths['cdp_and_cip'], "cip_pad"),
    #       root_syncdp=paths['syncdp'],
    #       prob_syn=configs['prob_syn'],
    #       prob_mix=configs['prob_mix'],
    #       num_syn=configs['num_syn'],
    #       transform=train_transform,
    #   )
    #
    # NOTE(review): the DDG signature recorded alongside that code was
    # (root_orig, root_cdp, root_cip, root_syn_cdp=None, prob_aug=0.5,
    # prob_syn=0.25, prob_mix=0.5, num_syn=3, transform=None, strength=None,
    # beta_alpha=1.0). It names the synthetic root `root_syn_cdp`, while the
    # call passes `root_syncdp`. Reconcile these before re-enabling.
    train_data = ImageFolder(root=os.path.join(paths['orig'], "train"), transform=train_transform)
    return train_data, test_data, paths['num_classes']