import random

import torch
from torch.utils.data import Dataset
import scipy.io as sio
from scipy.io import loadmat
from PIL import Image
import numpy as np
import os
from avalanche.benchmarks.datasets.multi_label_dataset.select_data import SelectData
from tools.get_path import get_project_path

class NUS_WIDE(Dataset):
    """NUS-WIDE multi-label image dataset.

    Image paths are read from ``ImageList/TrainImagelist.txt`` or
    ``ImageList/TestImagelist.txt``. Labels are read from the ``"labels"`` key
    of ``mat/Train_labels.mat`` or ``mat/Test_labels.mat``. Both live under
    ``~/data/Datasets/NUS-WIDE/``. The official training list is split
    deterministically (file order, no shuffling) into the first 70% for
    ``"train"`` and the remaining 30% for ``"val"``. ``"test"`` uses the
    official test list unchanged.

    Args:
        imageset: One of ``"train"``, ``"val"`` or ``"test"``.
        transforms: Optional callable applied to each RGB PIL image.
        task_classes: Accepted for interface compatibility; currently unused.
        task_id: Accepted for interface compatibility; currently unused.

    Raises:
        ValueError: If ``imageset`` is not one of the supported values.
    """

    def __init__(self, imageset="train", transforms=None, task_classes=None, task_id=None):
        super().__init__()
        if imageset not in ("train", "val", "test"):
            raise ValueError(
                f"imageset must be 'train', 'val' or 'test', got {imageset!r}"
            )

        usr_root = os.path.expanduser("~")
        self.root = usr_root + "/data/Datasets/NUS-WIDE/"

        # "train" and "val" are both carved out of the official training list.
        prefix = "Test" if imageset == "test" else "Train"
        image_dir = os.path.join(self.root, "ImageList", f"{prefix}Imagelist.txt")
        mat_dir = os.path.join(self.root, "mat", f"{prefix}_labels.mat")

        self.transforms = transforms

        with open(image_dir) as f:
            self.image_lists = f.readlines()

        # One multi-hot row per image, cast to float32 for loss functions.
        self.targets = loadmat(mat_dir)["labels"].astype(np.float32)

        # Deterministic (unshuffled) 70/30 train/val split of the training list.
        split = int(len(self.image_lists) * 0.7)
        if imageset == "train":
            self.image_lists = self.image_lists[:split]
            self.targets = self.targets[:split, :]
        elif imageset == "val":
            self.image_lists = self.image_lists[split:]
            self.targets = self.targets[split:, :]

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.image_lists)

    def _resolve_image_path(self, index):
        """Return the on-disk path of image ``index``.

        Entries in the image list use backslash separators. All components are
        kept and re-joined under ``<root>/NUSWIDE/Flickr``.
        """
        parts = self.image_lists[index].strip().split("\\")
        return os.path.join(self.root, "NUSWIDE", "Flickr", *parts)

    def __getitem__(self, index):
        """Return ``(image, target)`` for ``index``.

        The image is converted to RGB (then passed through ``transforms`` if
        set). The target is the float32 label row for that image.
        """
        path = self._resolve_image_path(index)
        # The context manager releases the file handle deterministically;
        # convert() returns a fully loaded copy, so it outlives the handle.
        with Image.open(path) as raw:
            img = raw.convert("RGB")

        target = self.targets[index]

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target


def make_one_hot_labels(root=None, modes=("Test",)):
    """Build multi-hot NUS-WIDE label matrices and save them as ``.mat`` files.

    Column order is the line order of ``Concepts81.txt``. For each mode, one
    row is created per line of ``ImageList/{mode}Imagelist.txt``. Each
    per-class file in ``TrainTestLabels/`` whose name contains the mode
    (``"Train"`` / ``"Test"``) sets cell ``[line_no, class_id]`` to 1 wherever
    its line holds a positive integer. The result is written as a uint8 matrix
    under key ``"labels"`` to ``mat/{mode}_labels.mat``.

    Args:
        root: Dataset root directory. Defaults to
            ``~/data/Datasets/NUS-WIDE/``, the same location ``NUS_WIDE`` reads.
        modes: Splits to process, each ``"Train"`` or ``"Test"``. Defaults to
            ``("Test",)`` (the previous behavior). Pass ``("Train", "Test")``
            to also produce ``Train_labels.mat``.

    Returns:
        Dict mapping each processed mode to its ``(N_SAMPLE, N_CLASS)`` uint8
        label matrix.

    Raises:
        ValueError: If a mode other than ``"Train"`` or ``"Test"`` is requested.
    """
    if root is None:
        root = os.path.expanduser("~") + "/data/Datasets/NUS-WIDE/"

    print("--- label ---")
    label_p = os.path.join(root, "TrainTestLabels")

    # Class order (column index) is determined by `Concepts81.txt`.
    with open(os.path.join(root, "Concepts81.txt"), "r") as f:
        cls_id = {line.strip(): cid for cid, line in enumerate(f)}
    n_class = len(cls_id)
    print("\n#classes:", n_class)

    class_files = {"Train": [], "Test": []}
    for filename in os.listdir(label_p):
        if "Train" in filename:
            class_files["Train"].append(filename)
        elif "Test" in filename:
            class_files["Test"].append(filename)

    mat_dir = os.path.join(root, "mat")
    results = {}
    for mode in modes:
        if mode not in class_files:
            raise ValueError(f"mode must be 'Train' or 'Test', got {mode!r}")

        # One label row per image listed for this split.
        with open(os.path.join(root, "ImageList", f"{mode}Imagelist.txt")) as f:
            n_sample = sum(1 for _ in f)

        labels = np.zeros((n_sample, n_class), dtype=np.uint8)
        for cf in class_files[mode]:
            # The class name is the second "_"-separated token of the file stem.
            c_name = cf.split(".txt")[0].split("_")[1]
            cid = cls_id[c_name]
            print("->", cid, c_name)
            with open(os.path.join(label_p, cf), "r") as f:
                for sid, line in enumerate(f):
                    if int(line) > 0:
                        labels[sid, cid] = 1
        print("labels:", labels.shape, ", cardinality:", labels.sum())

        os.makedirs(mat_dir, exist_ok=True)
        sio.savemat(
            os.path.join(mat_dir, f"{mode}_labels.mat"),
            {"labels": labels},
            do_compression=True,
        )
        results[mode] = labels

    return results

