from os import listdir
from os.path import isfile, isdir, join
import os
import os.path as osp
from utils import save_dict, load_dict
import pandas as pd
import numpy as np


def statistic_user_data(dict, file):
    """Write a per-user summary of class sizes to a text file.

    One line is written per user, in the form
    ``User_<name>-> {<class>: <count>, ...}``, where ``<count>`` is
    ``len()`` of the data stored under that class key.

    Args:
        dict: Mapping of user name to a mapping of class key to a sized
            collection of samples. The parameter name shadows the builtin
            but is kept so existing keyword callers keep working.
        file: Path of the summary file; it is created or overwritten.
    """
    # The context manager guarantees the handle is closed even when a
    # write (or a len() call on malformed data) raises part-way through.
    with open(file, 'w') as f:
        for uname, item in dict.items():
            n_cls = {cl: len(data) for cl, data in item.items()}
            f.write('User_{}-> {}\n'.format(uname, n_cls))


if __name__ == "__main__":
    # Builds a {user: {class_id: [relative image paths]}} index for the
    # ChestX dataset and stores it as a pickle plus a text summary.
    data_path = "E:/Datasets/Target/ChestX"
    csv_path = data_path + "/Data_Entry_2017.csv"
    savename = "ChestX"
    # Stored paths are relative to the dataset root and always use '/'.
    img_path = "images/"

    # Single source of truth for the kept classes: the integer class id is
    # the position of the label name in this list.
    used_labels = ["Atelectasis", "Cardiomegaly", "Effusion", "Infiltration", "Mass"]
    labels_maps = {label_name: idx for idx, label_name in enumerate(used_labels)}

    # The first CSV row is skipped and columns are addressed by position.
    data_info = pd.read_csv(csv_path, skiprows=[0], header=None)

    # First column contains the image names, second the finding labels.
    image_name_all = np.asarray(data_info.iloc[:, 0])
    labels_all = np.asarray(data_info.iloc[:, 1])

    image_name = []
    labels = []
    for name, label in zip(image_name_all, labels_all):
        # Keep a row only when its whole label field is exactly one of the
        # used labels. Fields combining several findings with '|' never
        # match, and neither do unused labels or non-string cells.
        if label in labels_maps:
            labels.append(labels_maps[label])
            image_name.append(name)

    image_name = np.asarray(image_name)
    labels = np.asarray(labels)

    user_data = {}
    users = ['0']

    # Every user receives the complete selection; users are keyed by their
    # position in the list, as a string.
    for i, _ in enumerate(users):
        user_dict = {}

        for class_id in labels_maps.values():
            image_name_c = image_name[labels == class_id]
            user_dict[int(class_id)] = [
                osp.join(img_path, iname).replace(os.sep, '/')
                for iname in image_name_c
            ]

        user_data[str(i)] = user_dict

    # write summary file
    statistic_user_data(user_data, file='{}_summary.txt'.format(savename))

    # write pkl file
    save_dict(user_data, '{}.pkl'.format(savename))

    # To sanity-check the result, reload it with
    # load_dict('{}.pkl'.format(savename)).
