# -*- coding: utf-8 -*-
# This file collects commonly used helper functions.
import numpy as np
import os
import random
import pandas as pd
from scipy.io import arff
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer


def MV(adj, n_cls):
    """Aggregate every row of a label matrix by majority voting.

    Args:
        adj: 2-D integer array (indexed as ``adj[i][j]``) whose entries are
            class indices in ``[0, n_cls)``; the value ``-1`` marks a missing
            label and is ignored. Presumably rows are tasks and columns are
            workers -- confirm against the caller.
        n_cls: number of classes, i.e. the length of the per-row vote tally.

    Returns:
        A pair ``(labels, confidences)`` of numpy arrays with one entry per
        row of ``adj``:

        * ``labels[i]`` is the class with the most votes in row ``i``
          (``np.argmax`` picks the lowest class index on ties, and class 0
          when the row holds no votes at all);
        * ``confidences[i]`` is the winning class's share of the votes cast in
          row ``i``, or 0 when the row holds no votes.
    """
    winners = []
    shares = []
    for row in range(adj.shape[0]):
        tally = np.zeros(n_cls)
        for col in range(adj.shape[1]):
            label = adj[row][col]
            if label == -1:
                # Missing label: contributes no vote.
                continue
            tally[label] += 1

        top = np.argmax(tally)
        winners.append(top)

        top_votes = tally[top]
        if top_votes == 0:
            # No votes in this row, so there is no meaningful share.
            shares.append(0)
        else:
            shares.append(top_votes / sum(tally))
    return np.array(winners), np.array(shares)


# Load a dataset: X is read from the .arff file, the worker labels (L) from the .response.txt file.
def load_data(filename):
    """Load the features, worker labels and gold labels of one dataset.

    Three files are read from the ``realData`` directory (relative to the
    current working directory):

    * ``<filename>.arff`` -- attribute data, turned into the feature matrix by
      ``preprocess``;
    * ``<filename>.gold.txt`` -- one tab-separated line per task: task id,
      then the integer true label;
    * ``<filename>.response.txt`` -- one tab-separated line per annotation:
      worker id, task id, then the integer label given by that worker.

    Blank lines in the two text files are ignored.

    Args:
        filename: dataset base name (without directory or extension).

    Returns:
        A tuple ``(X, adj, true_labels, num_cls)`` where

        * ``X`` is the array returned by ``preprocess``;
        * ``adj`` is a list with one entry per task (in gold-file order), each
          entry being the list of integer labels that task received, in
          response-file order. Tasks without annotations get an empty list;
          worker identity is not retained;
        * ``true_labels`` is the list of integer gold labels, in gold-file
          order;
        * ``num_cls`` is the number of distinct values found in the last
          column of the response file.

    Raises:
        KeyError: if the response file mentions a task id that is absent from
            the gold file.
    """

    def read_rows(path):
        # `with` guarantees the handle is closed even if reading fails;
        # blank lines are skipped so they cannot trigger an IndexError later.
        with open(path, "r") as handle:
            return [line.strip().split('\t') for line in handle if line.strip()]

    data, _meta = arff.loadarff('realData/{}.arff'.format(filename))
    X = preprocess(pd.DataFrame(data))

    # Gold (true) labels first, then the worker responses.
    gold_rows = read_rows('realData/{}.gold.txt'.format(filename))
    response_rows = read_rows('realData/{}.response.txt'.format(filename))

    # The class count is taken from the last field of every response row.
    num_cls = len(np.unique([row[-1] for row in response_rows]))

    # Map each task id to its row position and collect the true labels.
    true_labels = []
    task_index = {}
    for position, row in enumerate(gold_rows):
        true_labels.append(int(row[1]))
        task_index[row[0]] = position

    # Group the received labels by task.
    adj = [[] for _ in gold_rows]
    for row in response_rows:
        adj[task_index[row[1]]].append(int(row[2]))

    return X, adj, true_labels, num_cls


def preprocess(df):
    """Convert a raw attribute table into a standardized numeric matrix.

    The steps, in order:

    1. drop the column named ``'class'`` (if present);
    2. drop every column whose values are all missing;
    3. fill missing values -- most frequent value for ``object`` columns,
       column mean for ``float`` columns;
    4. label-encode the ``object`` columns into integers;
    5. standardize all columns with ``StandardScaler``.

    Imputation is best-effort: if an imputer fails, a warning is issued and
    processing continues with the values left as they were.

    Args:
        df: pandas DataFrame holding the attributes (the input frame is not
            modified).

    Returns:
        numpy array with the standardized features, one row per row of ``df``.
    """
    import warnings

    feature_cols = [c for c in df.columns if c != 'class']
    # Work on an explicit copy so the column assignments below operate on an
    # independent frame and do not raise chained-assignment warnings.
    features = df[feature_cols].dropna(axis=1, how='all').copy()

    # Nominal attributes: fill gaps with the most frequent value.
    nominal_cols = features.select_dtypes(include=['object']).columns
    if len(nominal_cols) > 0:
        try:
            imputer = SimpleImputer(strategy='most_frequent')
            features[nominal_cols] = imputer.fit_transform(features[nominal_cols])
        except Exception as exc:
            warnings.warn(
                'preprocess: most-frequent imputation of nominal attributes '
                'failed: {}'.format(exc),
                stacklevel=2,
            )

    # Numerical attributes: fill gaps with the column mean.
    numeric_cols = features.select_dtypes(include=['float']).columns
    if len(numeric_cols) > 0:
        try:
            imputer = SimpleImputer(strategy='mean')
            features[numeric_cols] = imputer.fit_transform(features[numeric_cols])
        except Exception as exc:
            warnings.warn(
                'preprocess: mean imputation of numerical attributes '
                'failed: {}'.format(exc),
                stacklevel=2,
            )

    # Encode the nominal attributes as integers. The column list is
    # re-evaluated here because the imputation step rewrote those columns.
    for column in features.select_dtypes(include=['object']).columns:
        features[column] = LabelEncoder().fit_transform(features[column])

    # Standardize every feature and hand back a plain array.
    scaled = StandardScaler().fit_transform(features)
    return np.array(scaled)

def one_hot(target, n_classes):
    """One-hot encode one or more integer class indices.

    Args:
        target: a single integer class index or an array-like of them; any
            nesting is flattened.
        n_classes: number of classes, i.e. the width of each encoded row.

    Returns:
        Float array of shape ``(number_of_targets, n_classes)`` in which row
        ``k`` has a 1 at the position given by the ``k``-th flattened target
        and 0 elsewhere.
    """
    flat_indices = np.array([target]).reshape(-1)
    # Row k of the identity matrix is exactly the one-hot vector of class k.
    identity = np.eye(n_classes)
    return identity[flat_indices]