import csv
import os
from io import BytesIO

import numpy as np
import pandas as pd
import requests
from PIL import Image
from sklearn.utils import shuffle

#from data_util.data_loader import TwoPartyDataLoader

def balance_X_y(XA, XB, y, seed=5):
    """Undersample negatives so both classes have the same number of rows.

    Rows whose label is > 0.5 are treated as positive and rows whose label is
    < 0.5 as negative (so both 0/1 and -1/1 encodings work). When there are
    fewer positives than negatives, a random subset of negatives of the same
    size as the positives is kept and the combined rows are shuffled; the same
    row selection is applied to ``XA``, ``XB`` and ``y`` so they stay aligned.
    Otherwise the inputs are returned unchanged (converted to arrays).

    Note: this reseeds NumPy's *global* RNG with ``seed``.

    Args:
        XA: first party's features, indexable by row.
        XB: second party's features, indexable by row.
        y: labels, one per row.
        seed: seed passed to ``np.random.seed``.

    Returns:
        Tuple ``(XA, XB, y)`` of NumPy arrays.
    """
    np.random.seed(seed)
    y = np.asarray(y)

    pos_indexes = [i for (i, _y) in enumerate(y) if _y > 0.5]
    neg_indexes = [i for (i, _y) in enumerate(y) if _y < 0.5]
    # Derive the counts from the same predicates as the index lists; counting
    # `y == 1` / `y == 0` separately disagreed for e.g. -1/1 labels.
    num_pos = len(pos_indexes)
    num_neg = len(neg_indexes)

    print("len(pos_indexes)", len(pos_indexes))
    print("len(neg_indexes)", len(neg_indexes))
    print("num of samples", len(pos_indexes) + len(neg_indexes))
    print("num_pos:", num_pos)
    print("num_neg:", num_neg)

    if num_pos < num_neg:
        np.random.shuffle(neg_indexes)
        # randomly pick negative samples of size equal to that of positive samples
        indexes = pos_indexes + neg_indexes[:num_pos]
        np.random.shuffle(indexes)
        XA = np.asarray(XA)[indexes]
        XB = np.asarray(XB)[indexes]
        y = y[indexes]

    return np.array(XA), np.array(XB), np.array(y)


def get_top_k_labels(data_dir, top_k=5):
    """Return the names of the ``top_k`` labels with the most positive rows.

    Every regular file in ``<data_dir>/NUS_WIDE/Groundtruth/AllLabels`` is
    read as a single-column, headerless CSV; its label name is the last
    ``_``-separated part of the file stem (``Labels_sky.txt`` -> ``sky``) and
    its score is the number of rows equal to 1.

    Args:
        data_dir: dataset root directory (absolute or relative to the CWD).
        top_k: number of label names to return.

    Returns:
        List of label names sorted by descending count; ties keep filename
        order.
    """
    labels_dir = os.path.join(data_dir, "NUS_WIDE/Groundtruth/AllLabels")
    label_counts = {}
    # Sorted listing so ties are broken deterministically (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(labels_dir)):
        file_path = os.path.join(labels_dir, filename)
        if not os.path.isfile(file_path):
            continue
        label = os.path.splitext(filename)[0].split("_")[-1]
        # header=None: the files hold only values, so the first row is data,
        # not a header (matches how get_labeled_data reads label files).
        df = pd.read_csv(file_path, header=None)
        df.columns = ['label']
        label_counts[label] = int((df['label'] == 1).sum())
    # Python's sort is stable (also with reverse=True), preserving filename order on ties.
    ranked = sorted(label_counts.items(), key=lambda kv: kv[1], reverse=True)
    return [name for (name, _) in ranked[:top_k]]


# Low-level image feature blocks, in the order their columns are concatenated into XA.
_NUS_WIDE_IMAGE_FEATURES = ("CM55", "CH", "EDH", "CORR", "WT")


def _load_selected_labels(data_dir, selected_label, dtype):
    """Load one 0/1 column per label and keep the usable rows.

    With several labels, only rows where exactly one of them is set are kept;
    with a single label every row is kept. The original row index is preserved
    so it can be used to select matching feature rows.
    """
    data_path = "NUS_WIDE/Groundtruth/TrainTestLabels/"
    dfs = []
    for label in selected_label:
        file = os.path.join(data_dir, data_path, "_".join(["Labels", label, dtype]) + ".txt")
        df = pd.read_csv(file, header=None)
        print("df shape", df.shape)
        df.columns = [label]
        dfs.append(df)
    data_labels = pd.concat(dfs, axis=1)
    if len(selected_label) > 1:
        selected = data_labels[data_labels.sum(axis=1) == 1]
    else:
        selected = data_labels
    print(selected.shape)
    return selected


def _load_image_features(data_dir, dtype):
    """Load and column-concatenate all low-level image feature files (party A)."""
    features_path = "NUS_WIDE/NUS_WID_Low_Level_Features/Low_Level_Features"
    print("data_dir: {0}".format(data_dir))
    print("features_path: {0}".format(features_path))
    # Only "Train" has its own files; any other dtype reads the Test files.
    prefix = "Train" if dtype == "Train" else "Test"
    dfs = []
    for feature in _NUS_WIDE_IMAGE_FEATURES:
        file = "{0}_Normalized_{1}.dat".format(prefix, feature)
        print("get_labeled_data load filename:", file)
        df = pd.read_csv(os.path.join(data_dir, features_path, file), header=None, sep=" ")
        # Presumably removes the empty column produced by a trailing separator;
        # note it drops ANY column containing a NaN.
        df.dropna(axis=1, inplace=True)
        print("b datasets features", len(df.columns))
        dfs.append(df)
    return pd.concat(dfs, axis=1)


def _load_tag_features(data_dir, dtype):
    """Load the tab-separated tag features file (party B) for ``dtype``."""
    tag_path = "NUS_WIDE/NUS_WID_Tags/"
    file = "_".join([dtype, "Tags1k"]) + ".dat"
    tagsdf = pd.read_csv(os.path.join(data_dir, tag_path, file), header=None, sep="\t")
    tagsdf.dropna(axis=1, inplace=True)
    return tagsdf


def get_labeled_data(data_dir, selected_label, n_samples, dtype="Train"):
    """Load aligned image features, tag features and labels from NUS-WIDE.

    Args:
        data_dir: dataset root containing the ``NUS_WIDE`` directory.
        selected_label: label names; each needs a
            ``Labels_<label>_<dtype>.txt`` file.
        n_samples: keep only the first ``n_samples`` rows; ``None`` keeps all.
        dtype: split name used in file names ("Train"; anything else reads
            the "Test" image feature files).

    Returns:
        Tuple ``(XA, XB, y)`` of NumPy arrays: image features, tag features
        and the label matrix (one column per selected label), row-aligned.
    """
    selected = _load_selected_labels(data_dir, selected_label, dtype)

    # XA: image low-level features
    data_X_image_selected = _load_image_features(data_dir, dtype).loc[selected.index]
    print("X image shape:", data_X_image_selected.shape)  # 634 columns (per original note)

    # XB: tags
    data_X_text_selected = _load_tag_features(data_dir, dtype).loc[selected.index]
    print("X text shape:", data_X_text_selected.shape)

    # Slicing with n_samples=None keeps every row, so no separate branch is needed.
    return (
        data_X_image_selected.values[:n_samples],
        data_X_text_selected.values[:n_samples],
        selected.values[:n_samples],
    )


def image_and_text_data(data_dir, selected, n_samples=2000):
    """Convenience wrapper around ``get_labeled_data`` with its default ``dtype``.

    Args:
        data_dir: dataset root directory.
        selected: label names to load.
        n_samples: maximum number of rows to return (``None`` for all).

    Returns:
        The ``(XA, XB, y)`` tuple produced by ``get_labeled_data``.
    """
    image_X, text_X, labels = get_labeled_data(
        data_dir,
        selected_label=selected,
        n_samples=n_samples,
    )
    return image_X, text_X, labels
