from torch.utils.data import Dataset
import os
import pathlib
from PIL import Image


# Every domain name accepted for the 'full' dataset version.
VALID_DOMAINS = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']

# Domains accepted for the 'sentry' version: the full list minus infograph and quickdraw.
SENTRY_DOMAINS = [name for name in VALID_DOMAINS if name not in ('infograph', 'quickdraw')]

# Class count keyed by dataset version (not referenced by the loaders in this file).
NUM_CLASSES_DICT = dict(full=345, sentry=40)

# Split names accepted by the dataset classes; they become part of the index-file names.
VALID_SPLITS = ['train', 'test']

# Dataset versions accepted by the dataset classes (same keys, same order, as NUM_CLASSES_DICT).
VALID_VERSIONS = list(NUM_CLASSES_DICT)

# Filesystem locations. They are empty here and presumably meant to be set per
# installation (TODO confirm).
#   ROOT: directory holding '{domain}_{split}.txt' index files for the 'full'
#         version, and the default image root of the dataset classes.
#   SENTRY_SPLITS_ROOT: directory holding '{domain}_{split}_mini.txt' index files.
#   SENTRY_CLASSNAMES_FILE: text file read (one class name per line) by get_classnames.
ROOT = ''
SENTRY_SPLITS_ROOT = ''
SENTRY_CLASSNAMES_FILE = ''


def load_dataset(domains, split, version):
    """Read the index entries for one or more DomainNet domains.

    Args:
        domains: Sequence of domain names. A single-element sequence ``['all']``
            expands to SENTRY_DOMAINS for the 'sentry' version and to
            VALID_DOMAINS otherwise.
        split: Split name, inserted into the index-file name (e.g. 'train').
        version: 'sentry' reads ``{domain}_{split}_mini.txt`` from
            SENTRY_SPLITS_ROOT; any other value reads ``{domain}_{split}.txt``
            from ROOT.

    Returns:
        A list with one entry per non-blank index line, in file order. Each
        entry is the list of whitespace-separated tokens on that line
        (callers unpack it as ``[path, label]``).

    Raises:
        OSError: if an index file cannot be opened.
    """
    if len(domains) == 1 and domains[0] == 'all':
        domains = SENTRY_DOMAINS if version == 'sentry' else VALID_DOMAINS

    data = []
    for domain in domains:
        if version == 'sentry':
            idx_file = os.path.join(SENTRY_SPLITS_ROOT, f'{domain}_{split}_mini.txt')
        else:
            idx_file = os.path.join(ROOT, f'{domain}_{split}.txt')
        with open(idx_file, 'r', encoding='utf-8') as f:
            # Skip stray blank lines: they split to [] and would otherwise break
            # the ``path, y = ...`` unpacking done when a sample is fetched.
            data.extend(tokens for tokens in (line.split() for line in f) if tokens)
    return data


class DomainNet(Dataset):
    """DomainNet image-classification dataset backed by text index files.

    Each sample is ``(image, label)``. The image is loaded from
    ``os.path.join(root, path)``, converted to RGB and passed through
    ``transform`` if one was given. The label is an int.
    """

    def __init__(self, domain, split='train', root=ROOT,
                 transform=None, unlabeled=False, verbose=True,
                 version='sentry', classes=None):
        """Load the index entries for the requested domains and split.

        Args:
            domain: Comma-separated domain names (e.g. 'clipart,real') or 'all'.
            split: One of VALID_SPLITS.
            root: Directory that the image paths in the index files are
                joined onto.
            transform: Optional callable applied to each RGB PIL image.
            unlabeled: Stored on the instance; __getitem__ still returns labels.
            verbose: If true, print a short summary after loading.
            version: One of VALID_VERSIONS. It selects the index files and the
                allowed domain names.
            classes: Optional collection of int labels. When given, only
                samples whose label is in it are kept.

        Raises:
            ValueError: on an unknown version, domain or split.
        """
        super().__init__()

        if version not in VALID_VERSIONS:
            raise ValueError(f'dataset version must be in {VALID_VERSIONS} but was {version}')
        # version is now known to be 'full' or 'sentry'.
        allowed_domains = SENTRY_DOMAINS if version == 'sentry' else VALID_DOMAINS
        domain_list = domain.split(',')
        for name in domain_list:
            if name != 'all' and name not in allowed_domains:
                raise ValueError(f'domain must be in {allowed_domains} but was {name}')
        if split not in VALID_SPLITS:
            raise ValueError(f'split must be in {VALID_SPLITS} but was {split}')
        self._root_data_dir = root
        self._domain_list = domain_list
        self._split = split
        self._transform = transform
        self._version = version
        self._unlabeled = unlabeled

        self.data = load_dataset(domain_list, split, version)
        # ImageNet channel statistics (the red-channel std is 0.229, not 0.228).
        self.means = [0.485, 0.456, 0.406]
        self.stds = [0.229, 0.224, 0.225]
        if classes is not None:
            # Build the lookup set once instead of scanning `classes` per sample.
            keep = set(classes)
            self.data = [entry for entry in self.data if int(entry[1]) in keep]
        if verbose:
            print(f'Loaded domains {", ".join(domain_list)}, split is {split}')
            print(f'Total number of images: {len(self.data)}')
            print(f'Total number of classes: {self.get_num_classes()}')

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry ``idx``."""
        path, y = self.data[idx]
        # The context manager guarantees the file handle is closed. convert()
        # returns a new, fully loaded image, so it stays usable afterwards.
        with Image.open(os.path.join(self._root_data_dir, path)) as img:
            x = img.convert('RGB')
        if self._transform is not None:
            x = self._transform(x)
        return x, int(y)

    def get_num_classes(self):
        """Return the number of distinct label strings present in ``self.data``."""
        return len({entry[1] for entry in self.data})

    def get_classnames(self):
        """Return class names, one per line of SENTRY_CLASSNAMES_FILE, stripped.

        Raises:
            NotImplementedError: for any version other than 'sentry'.
        """
        if self._version == 'sentry':
            classnames_filename = SENTRY_CLASSNAMES_FILE
        else:
            raise NotImplementedError('get_classnames has not been implemented for non-SENTRY')

        with open(classnames_filename, 'r', encoding='utf-8') as classnames_file:
            return tuple(classname.strip() for classname in classnames_file)

    @staticmethod
    def get_clip_features_path(clip_model_name, split, domain, **other_kwargs):
        """Return the relative path 'DomainNet/<model>/<domain>/<split>/features.pkl'."""
        return pathlib.Path('DomainNet') / clip_model_name / domain / split / 'features.pkl'


def get_dataset_domainet_pretrain(dataset=None, data_dir=None, transform=None, train=True, download=True):
    """Build the training split over clipart, painting, real and sketch.

    The dataset uses DomainNet's default version. Only ``transform`` is
    forwarded. ``dataset``, ``data_dir``, ``train`` and ``download`` are
    accepted for signature compatibility and ignored.
    """
    pretrain_domains = 'clipart,painting,real,sketch'
    return DomainNet(pretrain_domains, split='train', transform=transform)

def get_dataset_domainnet_onedomain(domainname, data_dir=None, transform=None, train=True, download=True):
    """Build a DomainNet dataset (default version) for ``domainname``.

    A truthy ``train`` selects the 'train' split; otherwise 'test' is used.
    ``data_dir`` and ``download`` are ignored.
    """
    split = 'train' if train else 'test'
    return DomainNet(domainname, transform=transform, split=split)



class DomainNetPair(Dataset):
    """Two-class dataset built from one class in each of two DomainNet domains.

    Samples labelled ``classes[0]`` in ``domain_list[0]`` are relabelled 0.
    Samples labelled ``classes[1]`` in ``domain_list[1]`` are relabelled 1.
    Items are ``(image, label)`` with an RGB image (transformed if a transform
    was given) and an int label.
    """

    def __init__(self, domain_list, split='train', root=ROOT,
                 transform=None, unlabeled=False, verbose=True,
                 version='sentry', classes=None, domains=None):
        """Load and relabel the two selected (domain, class) subsets.

        Args:
            domain_list: At least two domain names, as a sequence or a
                comma-separated string (e.g. 'clipart,real'). Only the first
                two are used.
            split: One of VALID_SPLITS.
            root: Directory that the image paths in the index files are
                joined onto.
            transform: Optional callable applied to each RGB PIL image.
            unlabeled: Stored on the instance; __getitem__ still returns labels.
            verbose: If true, print a short summary after loading.
            version: One of VALID_VERSIONS.
            classes: At least two int labels. ``classes[0]`` is taken from the
                first domain and ``classes[1]`` from the second. Required,
                despite the ``None`` default.
            domains: Unused; accepted for interface compatibility.

        Raises:
            ValueError: on an unknown version, split or domain, fewer than two
                domains, or missing/short ``classes``.
        """
        super().__init__()

        if version not in VALID_VERSIONS:
            raise ValueError(f'dataset version must be in {VALID_VERSIONS} but was {version}')
        if split not in VALID_SPLITS:
            raise ValueError(f'split must be in {VALID_SPLITS} but was {split}')
        # Accept 'clipart,real' like DomainNet does, instead of indexing characters.
        if isinstance(domain_list, str):
            domain_list = domain_list.split(',')
        if len(domain_list) < 2:
            raise ValueError(f'domain_list must contain two domains but was {domain_list}')
        allowed_domains = SENTRY_DOMAINS if version == 'sentry' else VALID_DOMAINS
        for name in domain_list[:2]:
            if name != 'all' and name not in allowed_domains:
                raise ValueError(f'domain must be in {allowed_domains} but was {name}')
        if classes is None or len(classes) < 2:
            raise ValueError(f'classes must contain two class labels but was {classes}')

        self._root_data_dir = root
        self._domain_list = domain_list
        self._split = split
        self._transform = transform
        self._version = version
        self._unlabeled = unlabeled
        # ImageNet channel statistics (the red-channel std is 0.229, not 0.228).
        self.means = [0.485, 0.456, 0.406]
        self.stds = [0.229, 0.224, 0.225]

        self.data1 = self._load_class(domain_list[0], split, version, classes[0], '0')
        self.data2 = self._load_class(domain_list[1], split, version, classes[1], '1')
        self.data = self.data1 + self.data2

        if verbose:
            print(f'Loaded domains {", ".join(domain_list)}, split is {split}')
            print(f'Total number of images: {len(self.data)}')
            print(f'Total number of classes: {self.get_num_classes()}')

    @staticmethod
    def _load_class(domain, split, version, label, new_label):
        """Return ``[path, new_label]`` for each entry of ``domain`` whose label equals ``label``."""
        return [[entry[0], new_label]
                for entry in load_dataset([domain], split, version)
                if int(entry[1]) == label]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for entry ``idx``; label is 0 or 1."""
        path, y = self.data[idx]
        # The context manager guarantees the file handle is closed. convert()
        # returns a new, fully loaded image, so it stays usable afterwards.
        with Image.open(os.path.join(self._root_data_dir, path)) as img:
            x = img.convert('RGB')
        if self._transform is not None:
            x = self._transform(x)
        return x, int(y)

    def get_num_classes(self):
        """Return the number of distinct label strings present in ``self.data``."""
        return len({entry[1] for entry in self.data})

    def get_classnames(self):
        """Return class names, one per line of SENTRY_CLASSNAMES_FILE, stripped.

        Raises:
            NotImplementedError: for any version other than 'sentry'.
        """
        if self._version == 'sentry':
            classnames_filename = SENTRY_CLASSNAMES_FILE
        else:
            raise NotImplementedError('get_classnames has not been implemented for non-SENTRY')

        with open(classnames_filename, 'r', encoding='utf-8') as classnames_file:
            return tuple(classname.strip() for classname in classnames_file)

    @staticmethod
    def get_clip_features_path(clip_model_name, split, domain, **other_kwargs):
        """Return the relative path 'DomainNet/<model>/<domain>/<split>/features.pkl'."""
        return pathlib.Path('DomainNet') / clip_model_name / domain / split / 'features.pkl'
