from collections import Counter
from PIL import Image
import os

from torch.utils.data import Dataset

from unlabeled_extrapolation.datasets import breeds, domainnet

def fill_default_data_path(args):
    """Populate ``args.data_path`` from the environment when it is unset.

    Nothing happens if ``args.data_path`` is already something other than
    ``None``. Otherwise the path is read from the ``BREEDS_ROOT`` environment
    variable for the ``'breeds'`` dataset and from ``DOMAINNET_ROOT`` for
    ``'domainnet'``. If the variable is not set, ``args.data_path`` stays
    ``None`` (the result of ``os.environ.get``).

    Args:
        args: Object with ``data_path`` and ``dataset_name`` attributes;
            ``data_path`` is modified in place.

    Raises:
        NotImplementedError: If ``data_path`` is ``None`` and
            ``dataset_name`` is neither ``'breeds'`` nor ``'domainnet'``.
    """
    # An explicitly supplied path always wins over the environment.
    if args.data_path is not None:
        return

    name = args.dataset_name
    if name == 'breeds':
        env_var = 'BREEDS_ROOT'
    elif name == 'domainnet':
        env_var = 'DOMAINNET_ROOT'
    else:
        raise NotImplementedError()
    args.data_path = os.environ.get(env_var)

class CustomDataset(Dataset):
    """Two-way dataset that relabels the samples of two datasets as 0 and 1.

    The sample lists of ``ds1`` and ``ds2`` are read from the attribute named
    ``data_attr_name``. Only the first element of each sample (the image
    path) is kept; the original label is discarded and replaced by ``0`` for
    samples from ``ds1`` and ``1`` for samples from ``ds2``.

    Args:
        ds1: Object whose ``data_attr_name`` attribute is an iterable of
            indexable samples with the image path at index 0.
        ds2: Same as ``ds1``; its samples receive label 1.
        dataset_name: ``'breeds'`` or ``'domainnet'``. Controls how the
            stored path is resolved when an item is loaded.
        data_path: Root directory joined in front of the stored path for
            ``'domainnet'``; not used for ``'breeds'``.
        data_attr_name: Name of the attribute on ``ds1``/``ds2`` that holds
            the samples.
        transform: Callable applied to the RGB ``PIL`` image in
            ``__getitem__``.
    """

    def __init__(
        self,
        ds1, ds2,
        dataset_name, data_path, data_attr_name,
        transform,
    ):
        super().__init__()
        ds1_samples = getattr(ds1, data_attr_name)
        ds2_samples = getattr(ds2, data_attr_name)
        # Keep only the path and attach the new binary label.
        self.samples = [(item[0], 0) for item in ds1_samples] + \
            [(item[0], 1) for item in ds2_samples]
        self._transform = transform
        self._dataset_name = dataset_name
        self._data_path = data_path # only necessary for domainnet

    def __getitem__(self, i):
        """Load sample ``i`` and return ``(transformed_image, label)``.

        Raises:
            NotImplementedError: If ``dataset_name`` is neither ``'breeds'``
                nor ``'domainnet'``.
        """
        if self._dataset_name == 'breeds':
            # Stored paths are opened as-is.
            path, y = self.samples[i]
        elif self._dataset_name == 'domainnet':
            # Stored paths are resolved against the data root.
            rel_path, y = self.samples[i]
            path = os.path.join(self._data_path, rel_path)
            y = int(y)
        else:
            raise NotImplementedError()
        # Image.open is lazy. The context manager closes the underlying file
        # as soon as convert() has produced an independent in-memory image,
        # rather than leaving the handle to be cleaned up later.
        with Image.open(path) as img:
            x = img.convert('RGB')
        x = self._transform(x)
        return x, y

    def __len__(self):
        """Return the total number of samples from both datasets."""
        return len(self.samples)

def filter_to_single_class(dataset, class_to_use, data_attr_name):
    """Restrict ``dataset`` in place to the samples of a single class.

    The sample collection stored under ``data_attr_name`` is replaced by a
    new list that keeps, in order, only the samples whose label (element 1,
    converted with ``int``) equals ``class_to_use``.

    Args:
        dataset: Object holding the samples; modified in place.
        class_to_use: Label value to keep.
        data_attr_name: Name of the attribute on ``dataset`` that holds the
            samples.
    """
    kept = [
        sample
        for sample in getattr(dataset, data_attr_name)
        if int(sample[1]) == class_to_use
    ]
    setattr(dataset, data_attr_name, kept)

def get_class_datasets(dataset_name, data_path, domain, version,
                       transform, class_1, class_2):
    """Build train/test datasets contrasting two classes of one domain.

    Each split is loaded twice, filtered to ``class_1`` and ``class_2``
    respectively, and combined into a ``CustomDataset`` (which labels the
    first group 0 and the second group 1). Sample counts are printed.

    Args:
        dataset_name: ``'breeds'`` or ``'domainnet'``.
        data_path: Dataset root passed to the dataset constructors.
        domain: Domain passed to the dataset constructors.
        version: For breeds, ``'source'`` selects the source split and any
            other value the target split. For domainnet it is forwarded as
            ``version``.
        transform: Transform handed to ``CustomDataset``.
        class_1: Original label of the samples kept in the first group.
        class_2: Original label of the samples kept in the second group.

    Returns:
        Tuple ``(train_ds, test_ds)`` of ``CustomDataset`` objects.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'breeds':
        data_attr = '_image_paths_by_class'
        is_source = (version == 'source')

        def load(split):
            return breeds.Breeds(
                data_path, domain,
                source=is_source, target=(not is_source), split=split)

        eval_split = 'val'
    elif dataset_name == 'domainnet':
        data_attr = 'data'

        def load(split):
            return domainnet.DomainNet(
                domain, split=split, root=data_path, version=version)

        eval_split = 'test'
    else:
        raise ValueError(f'Unsupported dataset: {dataset_name}.')

    # Filtering mutates the dataset object, so each class needs its own copy
    # of every split.
    train_1 = load('train')
    test_1 = load(eval_split)
    train_2 = load('train')
    test_2 = load(eval_split)

    per_class = (
        (train_1, class_1),
        (test_1, class_1),
        (train_2, class_2),
        (test_2, class_2),
    )
    for ds, keep_class in per_class:
        filter_to_single_class(ds, keep_class, data_attr)

    print(f'Class {class_1} contains {len(train_1)} training and {len(test_1)} testing samples.')
    print(f'Class {class_2} contains {len(train_2)} training and {len(test_2)} testing samples.')

    train_ds = CustomDataset(
        train_1, train_2, dataset_name, data_path, data_attr, transform)
    test_ds = CustomDataset(
        test_1, test_2, dataset_name, data_path, data_attr, transform)
    return train_ds, test_ds

def get_diff_domain_datasets(dataset_name, data_path, domain_1, domain_2, version,
                             transform, class_1, class_2):
    """Build train/test datasets contrasting two classes across two domains.

    The first group holds ``class_1`` samples from the first domain and the
    second group holds ``class_2`` samples from the second domain. For
    breeds both groups are built from ``domain_1``: the first uses its
    source split and the second its target split, so ``domain_2`` only
    appears in the printed summary. The groups are combined into a
    ``CustomDataset`` (labels 0 and 1) and sample counts are printed.

    Args:
        dataset_name: ``'breeds'`` or ``'domainnet'``.
        data_path: Dataset root passed to the dataset constructors.
        domain_1: Domain of the first group.
        domain_2: Domain of the second group (domainnet only).
        version: Forwarded to ``DomainNet``; unused for breeds.
        transform: Transform handed to ``CustomDataset``.
        class_1: Original label of the samples kept from the first group.
        class_2: Original label of the samples kept from the second group.

    Returns:
        Tuple ``(train_ds, test_ds)`` of ``CustomDataset`` objects.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'breeds':
        data_attr = '_image_paths_by_class'

        def build(is_source, split):
            return breeds.Breeds(
                data_path, domain_1,
                source=is_source, target=(not is_source), split=split)

        train_1 = build(True, 'train')
        test_1 = build(True, 'val')
        # domain_2 == domain_1 for breeds
        train_2 = build(False, 'train')
        test_2 = build(False, 'val')
    elif dataset_name == 'domainnet':
        data_attr = 'data'

        def build(domain, split):
            return domainnet.DomainNet(
                domain, split=split, root=data_path, version=version)

        train_1 = build(domain_1, 'train')
        test_1 = build(domain_1, 'test')
        train_2 = build(domain_2, 'train')
        test_2 = build(domain_2, 'test')
    else:
        raise ValueError(f'Unsupported dataset: {dataset_name}.')

    # Keep a single class per group, in the same order the datasets were built.
    for ds, keep_class in (
        (train_1, class_1),
        (test_1, class_1),
        (train_2, class_2),
        (test_2, class_2),
    ):
        filter_to_single_class(ds, keep_class, data_attr)

    print(f'Domain {domain_1} contains {len(train_1)} training and {len(test_1)} testing samples of class {class_1}.')
    print(f'Domain {domain_2} contains {len(train_2)} training and {len(test_2)} testing samples of class {class_2}.')

    train_ds = CustomDataset(
        train_1, train_2, dataset_name, data_path, data_attr, transform)
    test_ds = CustomDataset(
        test_1, test_2, dataset_name, data_path, data_attr, transform)
    return train_ds, test_ds

def get_domain_datasets(dataset_name, data_path, domain_1, domain_2, version, transform):
    """Build train/test datasets contrasting two domains over all classes.

    No class filtering is applied. For breeds both groups are built from
    ``domain_1``: the first uses its source split and the second its target
    split, so ``domain_2`` only appears in the printed summary. The groups
    are combined into a ``CustomDataset`` (first group labeled 0, second
    labeled 1) and sample counts are printed.

    Args:
        dataset_name: ``'breeds'`` or ``'domainnet'``.
        data_path: Dataset root passed to the dataset constructors.
        domain_1: Domain of the first group.
        domain_2: Domain of the second group (domainnet only).
        version: Forwarded to ``DomainNet``; unused for breeds.
        transform: Transform handed to ``CustomDataset``.

    Returns:
        Tuple ``(train_ds, test_ds)`` of ``CustomDataset`` objects.

    Raises:
        ValueError: If ``dataset_name`` is not supported.
    """
    if dataset_name == 'breeds':
        data_attr = '_image_paths_by_class'
        from_source = dict(source=True, target=False)
        from_target = dict(source=False, target=True)
        train_1 = breeds.Breeds(data_path, domain_1, split='train', **from_source)
        test_1 = breeds.Breeds(data_path, domain_1, split='val', **from_source)
        # domain_2 == domain_1 for breeds
        train_2 = breeds.Breeds(data_path, domain_1, split='train', **from_target)
        test_2 = breeds.Breeds(data_path, domain_1, split='val', **from_target)
    elif dataset_name == 'domainnet':
        data_attr = 'data'
        shared = dict(root=data_path, version=version)
        train_1 = domainnet.DomainNet(domain_1, split='train', **shared)
        test_1 = domainnet.DomainNet(domain_1, split='test', **shared)
        train_2 = domainnet.DomainNet(domain_2, split='train', **shared)
        test_2 = domainnet.DomainNet(domain_2, split='test', **shared)
    else:
        raise ValueError(f'Unsupported dataset: {dataset_name}.')

    print(f'Domain {domain_1} contains {len(train_1)} training and {len(test_1)} testing samples.')
    print(f'Domain {domain_2} contains {len(train_2)} training and {len(test_2)} testing samples.')

    return (
        CustomDataset(train_1, train_2, dataset_name, data_path, data_attr, transform),
        CustomDataset(test_1, test_2, dataset_name, data_path, data_attr, transform),
    )

def get_domainnet_off_limits(domains):
    """Find, per domain, the classes with too few training samples.

    For every domain the ``'sentry'`` version of the DomainNet training
    split is loaded and its labels (element 1 of each entry in ``data``) are
    counted. A class is off-limits when its count is below the cutoff for
    that domain: 50 for ``'clipart'``, 100 for ``'real'``, ``'sketch'`` and
    ``'painting'``.

    Only labels that occur at least once are counted, so a class with no
    training samples at all in a domain is not reported.

    Args:
        domains: Iterable of domain names.

    Returns:
        Dict mapping each domain name to the list of its off-limits class
        labels.

    Raises:
        KeyError: If a domain without a configured cutoff has at least one
            training sample.
    """
    min_train_samples = {
        'real': 100,
        'sketch': 100,
        'painting': 100,
        'clipart': 50
    }
    result = {}
    for name in domains:
        train_split = domainnet.DomainNet(name, split='train', version='sentry', verbose=False)
        label_counts = Counter(entry[1] for entry in train_split.data)
        result[name] = [
            label
            for label, count in label_counts.items()
            if count < min_train_samples[name]
        ]
    return result
