from os.path import join, dirname

import torch
from torch.utils.data import DataLoader
from torchvision import transforms

from data.concat_dataset import ConcatDataset
from data.Loader import CustomDataset, TestCustomDataset, get_split_dataset_info, _dataset_info

class Subset(torch.utils.data.Dataset):
    """Random view over at most ``limit`` items of another dataset.

    The items are chosen once, at construction time, by taking the first
    ``limit`` entries of a random permutation (``torch.randperm``). They are
    therefore distinct, and their order is fixed for the lifetime of the
    object. If ``limit`` exceeds the dataset size, every item is kept in a
    shuffled order.
    """

    def __init__(self, dataset, limit):
        permutation = torch.randperm(len(dataset))
        self.dataset = dataset
        # Stored as an int64 tensor; element access yields a 0-d tensor index.
        self.indices = permutation[:limit]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        source_index = self.indices[idx]
        return self.dataset[source_index]

def get_train_dataloader(args, patches):
    """Build the training and in-domain validation loaders for the source domains.

    For every domain name in ``args.source``, read that domain's train and
    validation split lists for ``args.dataset``. Wrap them in
    ``CustomDataset`` and ``TestCustomDataset`` respectively, then
    concatenate the per-domain datasets.

    Split files, relative to this module's directory:
        * ``'PACS'``: ``pacs/splits/<d>_train_kfold.txt`` and
          ``<d>_crossval_kfold.txt``.
        * ``'OfficeHome'``: ``OfficeHome/<d>_train.txt`` and ``<d>_test.txt``.
        * anything else: the VLCS layout under ``VLCS/Train val splits``.
          An unrecognized name is NOT rejected here.

    Args:
        args: Configuration object. Reads ``source`` (a list of domain
            names), ``dataset``, ``batch_size`` and, through the transform
            builders, ``image_size``.
        patches: Forwarded to ``CustomDataset`` for the training splits only.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops
        the last incomplete batch; the validation loader does neither.
    """
    dataset_list = args.source
    assert isinstance(dataset_list, list)
    split_root = dirname(__file__)
    img_transformer = get_train_transformers(args)
    # The validation transform does not depend on the domain, so build it
    # once instead of rebuilding it on every loop iteration.
    val_transformer = get_val_transformer(args)
    datasets = []
    val_datasets = []
    for dname in dataset_list:
        if args.dataset == 'PACS':
            train_file = join(split_root, 'pacs/splits', '%s_train_kfold.txt' % dname)
            val_file = join(split_root, 'pacs/splits', '%s_crossval_kfold.txt' % dname)
        elif args.dataset == 'OfficeHome':
            # No crossval list is referenced for OfficeHome; the domain's
            # test list serves as the validation split.
            train_file = join(split_root, 'OfficeHome', '%s_train.txt' % dname)
            val_file = join(split_root, 'OfficeHome', '%s_test.txt' % dname)
        else:
            # Any other dataset name falls through to the VLCS file layout.
            train_file = join(split_root, 'VLCS/Train val splits', '%s_train_kfold.txt' % dname)
            val_file = join(split_root, 'VLCS/Train val splits', '%s_crossval_kfold.txt' % dname)

        name_train, labels_train = _dataset_info(train_file)
        name_val, labels_val = _dataset_info(val_file)

        datasets.append(CustomDataset(name_train, labels_train, patches=patches, img_transformer=img_transformer))
        val_datasets.append(TestCustomDataset(name_val, labels_val, img_transformer=val_transformer))

    dataset = ConcatDataset(datasets)
    val_dataset = ConcatDataset(val_datasets)

    loader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=False)
    return loader, val_loader


def get_val_dataloader(args, patches=False):
    """Build the evaluation loader for the held-out target domain.

    Reads the test split list for ``args.target`` from the directory that
    matches ``args.dataset``. ``'PACS'`` and ``'OfficeHome'`` are recognized;
    any other value uses the VLCS layout. The list is wrapped in a
    ``TestCustomDataset`` with the validation transform and then in a
    single-element ``ConcatDataset``.

    Args:
        args: Configuration object. Reads ``dataset``, ``target``,
            ``batch_size`` and, through the transform builder, ``image_size``.
        patches: Forwarded to ``TestCustomDataset``.

    Returns:
        A non-shuffling ``DataLoader`` that keeps the final partial batch.
    """
    if args.dataset == 'PACS':
        subdir, split_name = 'pacs/splits', '%s_test_kfold.txt' % args.target
    elif args.dataset == 'OfficeHome':
        subdir, split_name = 'OfficeHome', '%s_test.txt' % args.target
    else:
        subdir, split_name = 'VLCS/Train val splits', '%s_test_kfold.txt' % args.target

    names, labels = _dataset_info(join(dirname(__file__), subdir, split_name))

    target_dataset = TestCustomDataset(names, labels, patches=patches, img_transformer=get_val_transformer(args))
    return torch.utils.data.DataLoader(
        ConcatDataset([target_dataset]),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
    )


def get_train_transformers(args):
    """Return the training-time augmentation pipeline.

    Follows the augmentation protocol of Gulrajani & Lopez-Paz,
    "In Search of Lost Domain Generalization" (ICLR 2021), as implemented in
    https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py

    The pipeline applies, in order: a random resized crop to
    ``args.image_size`` covering 70-100% of the area, a random horizontal
    flip, color jitter (0.3 for brightness/contrast/saturation/hue), random
    grayscale, tensor conversion, and per-channel normalization.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.RandomResizedCrop(args.image_size, scale=(0.7, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
        transforms.RandomGrayscale(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)


def get_val_transformer(args):
    """Return the evaluation transform, which contains no random augmentation.

    Resizes to a square ``args.image_size`` x ``args.image_size`` image,
    converts it to a tensor, and normalizes it with the same per-channel
    mean/std as the training pipeline.
    """
    side = args.image_size
    return transforms.Compose([
        transforms.Resize((side, side)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


