#from __future__ import print_function, division

import torch
import numpy as np
from sklearn.preprocessing import StandardScaler
import random
from PIL import Image
import torch.utils.data as data
import os
import os.path
import accimage
import contextlib
from torch.utils.data.dataset import Dataset
import copy


@contextlib.contextmanager
def temp_seed(seed):
    """Run the enclosed ``with`` body under a temporary NumPy global seed.

    The state of NumPy's global random generator is captured on entry,
    the generator is re-seeded with ``seed``, and the captured state is
    put back on exit, whether or not the body raised.

    Args:
        seed: Value passed straight to ``np.random.seed``.
    """
    saved_state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Always restore, so the surrounding code's random stream is unaffected.
        np.random.set_state(saved_state)

def make_dataset(image_list, labels, root):
    """Build a list of ``(path, label)`` samples from text lines.

    Args:
        image_list: Sequence of text lines. When ``labels`` is given, each
            line is used (stripped) as the image path. Otherwise each line is
            ``"<path> <label> [<label> ...]"`` and is split on whitespace.
        labels: Optional 2-D array indexed as ``labels[i, :]`` holding the
            label row for ``image_list[i]``. ``None`` (or an empty container)
            means the labels are parsed out of ``image_list`` itself.
        root: Directory joined in front of the path, used only for
            multi-label lines (more than one label per line).

    Returns:
        list: ``(path, label)`` tuples. ``label`` is a row of ``labels``, an
        integer ``np.ndarray`` (multi-label lines) or an ``int``
        (single-label lines). An empty ``image_list`` yields ``[]``.
    """
    # ``if labels:`` is ambiguous for NumPy arrays with more than one element
    # and raises ValueError, so test for presence explicitly.
    if labels is not None and len(labels) > 0:
        return [(line.strip(), labels[i, :]) for i, line in enumerate(image_list)]

    if len(image_list) == 0:
        return []

    # The first line decides whether the whole list is multi-label.
    if len(image_list[0].split()) > 2:
        samples = []
        for line in image_list:
            fields = line.split()
            target = np.array([int(value) for value in fields[1:]])
            samples.append((os.path.join(root, fields[0]), target))
        return samples

    # Single-label lines: the path is used as written (``root`` is not applied).
    samples = []
    for line in image_list:
        fields = line.split()
        samples.append((fields[0], int(fields[1])))
    return samples


def pil_loader(path):
    """Open the image file at ``path`` with PIL and return it converted to RGB."""
    # Pass PIL an explicit file object rather than a path to avoid a
    # ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        return picture.convert('RGB')


def accimage_loader(path):
    """Load ``path`` with accimage, falling back to ``pil_loader`` on IOError."""
    import accimage
    try:
        loaded = accimage.Image(path)
    except IOError:
        # Potentially a decoding problem; let PIL try instead.
        return pil_loader(path)
    return loaded


def default_loader(path):
    """Load the image at ``path`` using the PIL-based loader."""
    # The torchvision image-backend switch (accimage vs. PIL) is disabled
    # here, so PIL is used unconditionally.
    return pil_loader(path)


class NusWideDatasetTC21(Dataset):
    """
    Nus-wide dataset, 21 classes, partitioned into seen and unseen classes.

    The image lists and one-hot label files listed in ``img_txt_paths`` /
    ``label_txt_paths`` are read from ``root`` and pooled. Six of the 21
    classes are drawn as "unseen" with ``nb_fold`` as the random seed; every
    image tagged with an unseen class goes to the unseen pool, the rest stay
    in the seen pool. Each pool is then split by alternating rows:

    * ``train``   - even rows of the seen pool, unseen label columns removed.
    * ``gallery`` - odd rows of the seen pool plus odd rows of the unseen
      pool, full 21-column labels.
    * ``query``   - even rows of the unseen pool, full 21-column labels.
    * ``eval``    - odd rows of the seen pool, unseen label columns removed.

    For ``query`` and ``eval``, ``label_mat`` holds a float32 0/1 matrix that
    marks pairs sharing at least one label (query vs. gallery, and eval vs.
    itself, respectively); for the other splits it is ``None``.

    Args
        root(str): Directory containing the txt files; also joined in front
            of every image path when an item is loaded.
        split(str, optional): One of 'train', 'query', 'gallery', 'eval'.
        transform(callable, optional): Transform images.
        num_train(int, optional): Accepted but not used by this class.
        nb_fold(int, optional): Seed selecting which classes are unseen.

    Raises
        ValueError: If ``split`` is not one of the supported names, or the
            number of image paths differs from the number of label rows.
    """
    img_txt_paths = ['test_img.txt',
                     'database_img.txt']
    label_txt_paths = ['test_label_onehot.txt',
                     'database_label_onehot.txt']

    # Total label columns and how many of them are held out as unseen.
    _NUM_CLASSES = 21
    _NUM_UNSEEN_CLASSES = 6

    def __init__(self, root, split='train', transform=None, num_train=None, nb_fold=0):
        num_seen = self._NUM_CLASSES - self._NUM_UNSEEN_CLASSES
        nb_classes_by_split = {
            'train': num_seen,
            'query': self._NUM_UNSEEN_CLASSES,
            'gallery': self._NUM_CLASSES,
            'eval': num_seen,
        }
        # Validate before touching the disk. A bare ``assert`` would vanish
        # under ``python -O`` and leave ``self.imgs`` undefined.
        if split not in nb_classes_by_split:
            raise ValueError(
                "Unknown split {!r}; expected one of {}".format(
                    split, sorted(nb_classes_by_split)))

        self.root = root
        self.transform = transform
        self.nb_classes = nb_classes_by_split[split]

        # Pool the image paths of all list files, in file order.
        img_paths = []
        for img_txt in self.img_txt_paths:
            with open(os.path.join(root, img_txt), 'r') as f:
                for line in f:
                    entry = line.strip()
                    # np.loadtxt skips blank lines in the label files, so skip
                    # them here too to keep paths and label rows aligned.
                    if entry:
                        img_paths.append(entry)
        img_paths = np.array(img_paths)
        self.paths = img_paths.copy()

        # Pool the one-hot label rows in the same file order.
        img_labels = np.vstack([
            np.loadtxt(os.path.join(root, label_txt), dtype=np.int32)
            for label_txt in self.label_txt_paths
        ])
        if img_labels.shape[0] != img_paths.shape[0]:
            raise ValueError(
                "Found {} image paths but {} label rows under {!r}".format(
                    img_paths.shape[0], img_labels.shape[0], root))
        self.label_mat = None
        self.targets = img_labels.copy()

        # Choose the unseen classes according to the fold index. Only this
        # draw needs the temporary seed; nothing below is random.
        with temp_seed(nb_fold):
            unseen_cls_idxs = np.random.choice(
                self._NUM_CLASSES, self._NUM_UNSEEN_CLASSES, replace=False)

        # Move every image tagged with an unseen class out of the seen pool.
        # Classes are processed in draw order, so an image carrying several
        # unseen classes is claimed by the first of them.
        unseen_img_labels = []
        unseen_img_paths = []
        for cls_idx in unseen_cls_idxs:
            idxs = np.where(img_labels[:, cls_idx] == 1)[0]
            unseen_img_labels.append(img_labels[idxs])
            unseen_img_paths.append(img_paths[idxs])
            img_labels = np.delete(img_labels, idxs, axis=0)
            img_paths = np.delete(img_paths, idxs, axis=0)
        unseen_img_labels = np.concatenate(unseen_img_labels, axis=0)
        unseen_img_paths = np.concatenate(unseen_img_paths, axis=0)

        # Split unseen data into query (even rows) and gallery (odd rows).
        self.query_x = unseen_img_paths[::2]
        self.query_y = unseen_img_labels[::2]
        unseen_gallery_x = unseen_img_paths[1::2]
        unseen_gallery_y = unseen_img_labels[1::2]

        # Delete one-hot gt columns of the unseen classes for training labels.
        extracted_img_labels = np.delete(img_labels, unseen_cls_idxs, axis=1)

        # Split seen data into source (even rows) and gallery (odd rows).
        self.source_x = img_paths[::2]
        self.source_y = extracted_img_labels[::2]
        seen_gallery_x = img_paths[1::2]
        seen_gallery_y = img_labels[1::2]

        self.gallery_x = np.concatenate((seen_gallery_x, unseen_gallery_x), axis=0)
        self.gallery_y = np.concatenate((seen_gallery_y, unseen_gallery_y), axis=0)

        # Seen half of the gallery, restricted to the seen label columns.
        seen_eval_y = np.delete(seen_gallery_y, unseen_cls_idxs, axis=1)

        if split == 'train':
            self.imgs = list(zip(self.source_x, self.source_y))
        elif split == 'gallery':
            self.imgs = list(zip(self.gallery_x, self.gallery_y))
        elif split == 'query':
            self.imgs = list(zip(self.query_x, self.query_y))
            # Relevance matrix between query and gallery: 1 where a pair
            # shares at least one label.
            self.label_mat = (np.matmul(self.query_y, np.transpose(self.gallery_y)) > 0).astype(np.float32)
        else:  # split == 'eval'
            self.imgs = list(zip(seen_gallery_x, seen_eval_y))
            self.label_mat = (np.matmul(seen_eval_y, np.transpose(seen_eval_y)) > 0).astype(np.float32)

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` of the selected split."""
        path, target = self.imgs[index]
        # Context manager closes the underlying file once the RGB copy exists.
        with Image.open(os.path.join(self.root, path)) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.imgs)

    def get_onehot_targets(self):
        """Return the pooled one-hot label matrix as a float tensor.

        NOTE(review): ``self.targets`` is the label matrix of the whole pooled
        dataset captured before the seen/unseen split, so its rows do not
        correspond to ``self.imgs`` of the selected split - confirm this is
        what callers expect.
        """
        return torch.from_numpy(self.targets).float()


def _write_split_file(file_path, img_paths, img_labels):
    """Write one ``<dir>/<file> <l0> <l1> ...`` line per sample to ``file_path``.

    Only the last two ``/``-separated components of each image path are
    kept; label values are separated by single spaces with no trailing space.
    """
    with open(file_path, 'w') as f:
        for img_path, label_row in zip(img_paths, img_labels):
            parts = img_path.split('/')
            label_text = ' '.join('{}'.format(value) for value in label_row)
            f.write('{}/{} {}\n'.format(parts[-2], parts[-1], label_text))


if __name__ == '__main__':
    # Export the train / database / test lists and the query-gallery
    # relevance matrix for one fold.
    root = '../data/NUS_WIDE'
    nb_fold = 3
    # The 'query' split is used so that ``label_mat`` is populated; the
    # source_* / gallery_* / query_* arrays are built for every split.
    split_dataset = NusWideDatasetTC21(root, split='query', nb_fold=nb_fold)
    print('done')

    fold_dir = '{}/fold_{}'.format(root, nb_fold)
    # Create the output directory up front; open(..., 'w') does not create it.
    os.makedirs(fold_dir, exist_ok=True)

    _write_split_file('{}/train.txt'.format(fold_dir),
                      split_dataset.source_x, split_dataset.source_y)
    _write_split_file('{}/database.txt'.format(fold_dir),
                      split_dataset.gallery_x, split_dataset.gallery_y)
    _write_split_file('{}/test.txt'.format(fold_dir),
                      split_dataset.query_x, split_dataset.query_y)

    label_mat = torch.from_numpy(split_dataset.label_mat)
    torch.save(label_mat, os.path.join(root, 'fold_%d' % (nb_fold), 'Label_mat.tar'))



