import torch
import os
import numpy as np

from PIL import Image, ImageFile
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
import torchvision.transforms as transforms


# Module-wide Pillow switch: allow truncated image files to be loaded instead
# of raising while decoding. This affects every PIL image opened after import.
ImageFile.LOAD_TRUNCATED_IMAGES = True

def encode_onehot(labels, num_classes=10):
    """
    Convert class indices into a one-hot matrix.
    Args:
        labels (numpy.ndarray): labels, one class index per sample.
        num_classes (int): Number of classes (width of the result).
    Returns:
        onehot_labels (numpy.ndarray): array of shape (len(labels), num_classes)
            filled with zeros except for a 1 at each sample's label column.
    """
    encoded = np.zeros((len(labels), num_classes))

    # One row per sample; switch on the column named by that sample's label.
    for row, label in enumerate(labels):
        encoded[row, label] = 1

    return encoded


class Onehot:
    """Callable target transform that turns a class index into a one-hot tensor."""

    def __call__(self, sample, num_classes=10):
        """
        Args:
            sample: Index used to select the entry that is set to 1.
            num_classes (int): Length of the returned vector.
        Returns:
            torch.Tensor: zero vector of length ``num_classes`` with a 1 at ``sample``.
        """
        encoded = torch.zeros(num_classes)
        encoded[sample] = 1
        return encoded


def train_transform():
    """
    Build the augmentation pipeline applied to training images.
    Args
        None
    Returns
        transform(torchvision.transforms): random resized crop to 224,
            random horizontal flip, tensor conversion, then per-channel
            normalization.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel mean/std (the usual ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def train_transform_vgg():
    """
    Build the training pipeline for the small-input (64 px) VGG variant.
    Args
        None
    Returns
        transform(torchvision.transforms): random resized crop to 64,
            random horizontal flip, tensor conversion, then per-channel
            normalization.
    """
    steps = [
        transforms.RandomResizedCrop(64),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # Per-channel mean/std (the usual ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def query_transform():
    """
    Build the deterministic pipeline applied to query/retrieval images.
    Args
        None
    Returns
        transform(torchvision.transforms): resize to 256, center crop to 224,
            tensor conversion, then per-channel normalization.
    """
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean/std (the usual ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def query_transform_vgg():
    """
    Build the deterministic query pipeline for the small-input (64 px) VGG variant.
    Args
        None
    Returns
        transform(torchvision.transforms): resize to 64, center crop to 64,
            tensor conversion, then per-channel normalization.
    """
    steps = [
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        # Per-channel mean/std (the usual ImageNet statistics).
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


def load_data(tc, root, num_train, args):
    """
    Loading nus-wide dataset.
    Args:
        tc(int): Top class. Only 21 is supported.
        root(str): Path of image files.
        num_train(int): Number of labeled training data.
        args: Options object. ``args.consistency`` and ``args.model`` are read
            to choose the transforms and the training dataset class.
    Returns
        query_dataset, train_dataset, train_u_dataset, retrieval_dataset(Dataset):
            Datasets (not data loaders). ``query_dataset.label_mat`` is set to
            the query x retrieval relevance matrix.
    Raises
        ValueError: If ``tc`` is not 21, or if ``args.consistency`` is falsy
            and ``args.model`` is not ``'modified_vgg'`` (no configuration
            exists for that combination).
    """
    if tc != 21:
        # An explicit raise (unlike ``assert``) still fires under ``python -O``.
        raise ValueError('Unsupported tc: {!r} (only 21 is supported).'.format(tc))

    # Pick the transform factories and the class used for the training splits.
    if args.consistency:
        make_query_tf, make_train_tf = query_transform, train_transform
        # For alexnet the training splits use the dataset class that yields
        # two transformed views of each image.
        if args.model == 'alexnet':
            train_cls = NusWideDatasetTC21_2
        else:
            train_cls = NusWideDatasetTC21
    elif args.model == 'modified_vgg':
        make_query_tf, make_train_tf = query_transform_vgg, train_transform_vgg
        train_cls = NusWideDatasetTC21
    else:
        # Previously this fell through and crashed later with UnboundLocalError.
        raise ValueError(
            'No dataset configuration for model {!r} without consistency; '
            'expected model \'modified_vgg\' or args.consistency enabled.'.format(args.model)
        )

    # Construction order is kept (query, train, unlabeled train, retrieval)
    # because the training splits consume numpy's global random state.
    query_dataset = NusWideDatasetTC21(
        root,
        'test_img.txt',
        'test_label_onehot.txt',
        transform=make_query_tf(),
    )

    train_dataset = train_cls(
        root,
        'database_img.txt',
        'database_label_onehot.txt',
        transform=make_train_tf(),
        train=True,
        labeled=True,
        num_train=num_train,
    )

    # NOTE(review): num_train is not forwarded here, so the dataset slices its
    # permutation with ``[None:]`` and keeps every database image (shuffled).
    # Confirm that the unlabeled split is meant to include the labeled samples.
    train_u_dataset = train_cls(
        root,
        'database_img.txt',
        'database_label_onehot.txt',
        train=True,
        labeled=False,
        transform=make_train_tf(),
    )

    retrieval_dataset = NusWideDatasetTC21(
        root,
        'database_img.txt',
        'database_label_onehot.txt',
        transform=make_query_tf(),
        train=False,
        labeled=False,
    )

    # label_mat[i, j] is 1.0 when query i and retrieval item j have a positive
    # label dot product (i.e. share at least one active label), else 0.0.
    query_targets = query_dataset.targets
    retrieval_targets = retrieval_dataset.targets
    label_mat = (np.matmul(query_targets, np.transpose(retrieval_targets)) > 0).astype(np.float32)
    query_dataset.label_mat = label_mat

    return query_dataset, train_dataset, train_u_dataset, retrieval_dataset


class NusWideDatasetTc10(Dataset):
    """
    Nus-wide dataset, 10 classes.

    The query/train/retrieval splits are stored on the class itself, so
    ``NusWideDatasetTc10.init(root, num_query, num_train)`` must be called
    once before any instance is created.

    Args
        root(str): Path of dataset.
        mode(str, 'train', 'query', 'retrieval'): Mode of dataset.
        transform(callable, optional): Transform images.
    Raises
        ValueError: If ``mode`` is not one of the three supported values.
        RuntimeError: If the splits have not been populated by ``init``.
    """

    # Split storage shared by every instance; filled in by init().
    QUERY_DATA = None
    QUERY_TARGETS = None
    TRAIN_DATA = None
    TRAIN_TARGETS = None
    RETRIEVAL_DATA = None
    RETRIEVAL_TARGETS = None

    def __init__(self, root, mode, transform=None):
        self.root = root
        self.transform = transform

        if mode == 'train':
            data = NusWideDatasetTc10.TRAIN_DATA
            targets = NusWideDatasetTc10.TRAIN_TARGETS
        elif mode == 'query':
            data = NusWideDatasetTc10.QUERY_DATA
            targets = NusWideDatasetTc10.QUERY_TARGETS
        elif mode == 'retrieval':
            data = NusWideDatasetTc10.RETRIEVAL_DATA
            targets = NusWideDatasetTc10.RETRIEVAL_TARGETS
        else:
            # Plain string: the former raw string leaked a backslash into the message.
            raise ValueError("Invalid arguments: mode, can't load dataset!")

        if data is None or targets is None:
            raise RuntimeError(
                'NusWideDatasetTc10.init(root, num_query, num_train) must be '
                'called before creating a dataset.'
            )

        self.data = data
        self.targets = targets

    def __getitem__(self, index):
        """Return ``(image, target, index)`` for the sample at ``index``."""
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index], index

    def __len__(self):
        return self.data.shape[0]

    def get_targets(self):
        """Return all targets of this split as a float tensor."""
        return torch.from_numpy(self.targets).float()

    @staticmethod
    def init(root, num_query, num_train):
        """
        Initialize dataset.
        Args
            root(str): Path of image files.
            num_query(int): Number of query data.
            num_train(int): Number of training data.
        """
        # Load dataset
        img_txt_path = os.path.join(root, 'img_tc10.txt')
        targets_txt_path = os.path.join(root, 'targets_onehot_tc10.txt')

        # Read files
        with open(img_txt_path, 'r') as f:
            data = np.array([line.strip() for line in f])
        targets = np.loadtxt(targets_txt_path, dtype=np.int64)

        # Split dataset: the first num_query shuffled samples are queries; all
        # remaining samples form the retrieval set, and the training set is
        # the first num_train samples of that retrieval set.
        perm_index = np.random.permutation(data.shape[0])
        query_index = perm_index[:num_query]
        train_index = perm_index[num_query: num_query + num_train]
        retrieval_index = perm_index[num_query:]

        NusWideDatasetTc10.QUERY_DATA = data[query_index]
        NusWideDatasetTc10.QUERY_TARGETS = targets[query_index, :]

        NusWideDatasetTc10.TRAIN_DATA = data[train_index]
        NusWideDatasetTc10.TRAIN_TARGETS = targets[train_index, :]

        NusWideDatasetTc10.RETRIEVAL_DATA = data[retrieval_index]
        NusWideDatasetTc10.RETRIEVAL_TARGETS = targets[retrieval_index, :]


class NusWideDatasetTC21(Dataset):
    """
    Nus-wide dataset, 21 classes.

    Each instance that samples (``train is True``) draws its own random
    permutation, so a labeled and an unlabeled instance are not guaranteed to
    be complementary.

    Args
        root(str): Path of image files.
        img_txt(str): Path of txt file containing image file name.
        label_txt(str): Path of txt file containing image label.
        transform(callable, optional): Transform images.
        train(bool, optional): Return training dataset. Sampling only happens
            when this is exactly ``True``.
        labeled(bool, optional): With ``train=True``, ``True`` keeps the first
            ``num_train`` entries of a random permutation and ``False`` keeps
            the entries from position ``num_train`` onward.
        num_train(int, optional): Number of training data. ``None`` makes the
            slice cover the whole permutation.
    """

    def __init__(self, root, img_txt, label_txt, transform=None, train=None, labeled=True, num_train=None):
        self.root = root
        self.transform = transform
        self.label_mat = None

        # Read files: one image path per line, one label row per line.
        with open(os.path.join(root, img_txt), 'r') as handle:
            names = [line.strip() for line in handle]
        self.data = np.array(names)
        self.targets = np.loadtxt(os.path.join(root, label_txt), dtype=np.float32)

        # Sample training dataset (only for train=True with a boolean `labeled`).
        if train is not True:
            return
        if labeled is True:
            keep = np.random.permutation(len(self.data))[:num_train]
        elif labeled is False:
            keep = np.random.permutation(len(self.data))[num_train:]
        else:
            return
        self.data = self.data[keep]
        self.targets = self.targets[keep]

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``."""
        path = os.path.join(self.root, self.data[index])
        img = Image.open(path).convert('RGB')
        if self.transform is None:
            return img, self.targets[index]
        return self.transform(img), self.targets[index]

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        """Return all targets of this dataset as a float tensor."""
        as_tensor = torch.from_numpy(self.targets)
        return as_tensor.float()

class NusWideDatasetTC21_2(Dataset):
    """
    Nus-wide dataset, 21 classes, yielding two views of every image.

    ``__getitem__`` applies ``transform`` twice to the same image, so a random
    transform produces two differently augmented views. Each instance that
    samples (``train is True``) draws its own random permutation.

    Args
        root(str): Path of image files.
        img_txt(str): Path of txt file containing image file name.
        label_txt(str): Path of txt file containing image label.
        transform(callable, optional): Transform images. When ``None`` both
            views are the untransformed image.
        train(bool, optional): Return training dataset. Sampling only happens
            when this is exactly ``True``.
        labeled(bool, optional): With ``train=True``, ``True`` keeps the first
            ``num_train`` entries of a random permutation and ``False`` keeps
            the entries from position ``num_train`` onward.
        num_train(int, optional): Number of training data. ``None`` makes the
            slice cover the whole permutation.
    """

    def __init__(self, root, img_txt, label_txt, transform=None, train=None, labeled=True, num_train=None):
        self.root = root
        self.transform = transform

        img_txt_path = os.path.join(root, img_txt)
        label_txt_path = os.path.join(root, label_txt)
        self.label_mat = None

        # Read files
        with open(img_txt_path, 'r') as f:
            self.data = np.array([line.strip() for line in f])
        self.targets = np.loadtxt(label_txt_path, dtype=np.float32)

        # Sample training dataset
        if train is True and labeled is True:
            perm_index = np.random.permutation(len(self.data))[:num_train]
            self.data = self.data[perm_index]
            self.targets = self.targets[perm_index]
        elif train is True and labeled is False:
            perm_index = np.random.permutation(len(self.data))[num_train:]
            self.data = self.data[perm_index]
            self.targets = self.targets[perm_index]

    def _make_views(self, img):
        """Return two views of ``img``: two transform calls, or ``img`` twice without a transform."""
        if self.transform is None:
            # Previously img1/img2 were never bound here, raising UnboundLocalError.
            return img, img
        first = self.transform(img)
        second = self.transform(img)
        return first, second

    def __getitem__(self, index):
        """Return ``(view1, view2, target)`` for the sample at ``index``."""
        img = Image.open(os.path.join(self.root, self.data[index])).convert('RGB')
        img1, img2 = self._make_views(img)
        return img1, img2, self.targets[index]

    def __len__(self):
        return len(self.data)

    def get_onehot_targets(self):
        """Return all targets of this dataset as a float tensor."""
        return torch.from_numpy(self.targets).float()
