# -*- coding: utf-8 -*-
# 这个文件用于存放一些常用的函数
import numpy as np
import os
import random
import torch
import pandas as pd
from scipy.io import arff
from torch_geometric.data import Data
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.impute import SimpleImputer

def compute_ece(probs, isright, num_bins=10):
    """
    Compute the Expected Calibration Error (ECE).

    Confidences are grouped into ``num_bins`` equal-width bins over [0, 1]. ECE is
    the sum over bins of |accuracy - mean confidence|, each weighted by the bin's
    share of all samples.

    Bins are right-closed, ``(lower, upper]``, except the first one, which is
    ``[0, upper]``. A confidence of exactly 0.0 is therefore counted in the first
    bin instead of falling into no bin while still inflating the denominator.
    Values outside [0, 1] fall into no bin but are still part of the denominator.

    Args:
        probs (array-like): predicted confidence per sample, shape (n_samples,).
        isright (array-like): per-sample correctness, 1 = correct, anything else
            = incorrect; same length as ``probs``.
        num_bins (int): number of confidence bins (default 10).

    Returns:
        float: the ECE value (0.0 when no sample falls into any bin).
    """
    # Accept lists as well as arrays; dtype is left untouched.
    probs = np.asarray(probs)
    isright = np.asarray(isright)
    n_total = len(probs)

    ece = 0.0
    bin_boundaries = np.linspace(0.0, 1.0, num_bins + 1)
    for bin_idx, (bin_lower, bin_upper) in enumerate(zip(bin_boundaries[:-1], bin_boundaries[1:])):
        # The first bin is closed on the left so that a confidence of exactly 0.0 is binned.
        if bin_idx == 0:
            above_lower = probs >= bin_lower
        else:
            above_lower = probs > bin_lower
        in_bin = above_lower & (probs <= bin_upper)
        bin_size = np.sum(in_bin)

        if bin_size > 0:
            accuracy_in_bin = np.mean(isright[in_bin] == 1)
            avg_confidence_in_bin = np.mean(probs[in_bin])
            # Gap between accuracy and confidence, weighted by the bin's sample share.
            ece += np.abs(accuracy_in_bin - avg_confidence_in_bin) * (bin_size / n_total)
    return ece


def MV(adj, n_cls):
    """
    Majority voting over a label matrix.

    Args:
        adj (numpy.ndarray): 2-D integer matrix, one row per task and one column per
            worker (as built by ``load_data``); -1 marks "no label", any other entry
            is used as an index into the ``n_cls`` vote counters.
        n_cls (int): number of classes.

    Returns:
        tuple(numpy.ndarray, numpy.ndarray):
            - winning label per row (``np.argmax``, so ties go to the smallest label;
              rows without any label get 0);
            - the winner's vote share per row, 0 for rows without any label.
    """
    labels, shares = [], []
    for row in adj:
        votes = np.zeros(n_cls)
        for label in row:
            if label == -1:
                continue
            votes[label] += 1
        winner = np.argmax(votes)
        labels.append(winner)
        shares.append(votes[winner] / sum(votes) if votes[winner] != 0 else 0)
    return np.array(labels), np.array(shares)


def MV2(adj, n_cls):
    """
    Majority voting that also returns each row's normalized vote distribution.

    Args:
        adj (numpy.ndarray): 2-D integer matrix, one row per task and one column per
            worker (as built by ``load_data``); -1 marks "no label", any other entry
            is used as an index into the ``n_cls`` vote counters.
        n_cls (int): number of classes.

    Returns:
        tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray):
            - winning label per row (``np.argmax``; ties go to the smallest label,
              rows without any label get 0);
            - the winner's vote share per row (0 for rows without any label);
            - per-row vote counts divided by their total, shape (n_rows, n_cls);
              an all-zero row for rows without any label.
    """
    labels, shares, distributions = [], [], []
    for row in adj:
        votes = np.zeros(n_cls)
        for label in row:
            if label == -1:
                continue
            votes[label] += 1
        winner = np.argmax(votes)
        labels.append(winner)
        if votes[winner] == 0:
            # Nobody labelled this task: no vote share, empty distribution.
            shares.append(0)
            distributions.append(np.zeros(n_cls))
            continue
        total = sum(votes)
        shares.append(votes[winner] / total)
        distributions.append(votes / total)
    return np.array(labels), np.array(shares), np.array(distributions)


def NBMV(adj, n_cls, m=0):
    """
    Per-task label distribution from smoothed (m-estimate) vote counts.

    For each row of ``adj`` the probability of class j is

        (count_j + m / n_cls) / (sum_k count_k + m),

    i.e. the observed votes plus ``m`` pseudo-votes spread uniformly over the
    classes. A row whose denominator is 0 (no votes and ``m == 0``) gets the
    uniform distribution 1 / n_cls. That is the value the formula gives an
    unlabeled row for any ``m > 0``, and it avoids NaN from 0/0.

    Args:
        adj (numpy.ndarray): 2-D integer matrix, one row per task and one column per
            worker; -1 marks "no label", any other entry is used as an index into
            the ``n_cls`` vote counters.
        n_cls (int): number of classes.
        m (float): number of uniform pseudo-votes (0 = plain normalized vote share).

    Returns:
        numpy.ndarray: one probability row of length ``n_cls`` per row of ``adj``.
    """
    total_probs = []
    for row in adj:
        count = np.zeros(n_cls)
        for label in row:
            if label != -1:
                count[label] += 1
        denom = count.sum() + m
        if denom == 0:
            # No evidence and no pseudo-votes: fall back to the uniform prior.
            total_probs.append(np.full(n_cls, 1 / n_cls))
        else:
            total_probs.append((count + m / n_cls) / denom)
    return np.array(total_probs)


def set_seed(seed):
    """
    Seed every random number generator used by the project, for reproducibility.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch. When CUDA is
    available it also seeds the current GPU and every GPU. Finally it puts cuDNN in
    deterministic mode with benchmark autotuning disabled.

    Args:
        seed (int): the seed value.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)      # current device
        torch.cuda.manual_seed_all(seed)  # all visible devices (multi-GPU)
    cudnn = torch.backends.cudnn
    # Deterministic cuDNN kernels, and no benchmark-driven algorithm selection.
    cudnn.deterministic = True
    cudnn.benchmark = False
    # torch.use_deterministic_algorithms(True) would be stricter (it raises on
    # nondeterministic ops); intentionally not enabled here.


# Load a dataset: features X from the .arff file, worker labels from .response.txt, and ground-truth labels from .gold.txt
def _read_tsv_rows(path):
    """Read a tab-separated text file into a list of field lists, skipping blank lines."""
    with open(path, "r") as file:
        return [line.strip().split('\t') for line in file if line.strip()]


def load_data(filename, data_dir='./datasets'):
    """
    Load a crowdsourcing dataset: features, worker-label matrix and ground truth.

    Three files are read from ``data_dir``:
      * ``{filename}.arff``         -- instance features, passed through ``preprocess``;
      * ``{filename}.gold.txt``     -- tab-separated lines ``task_id, true_label``;
      * ``{filename}.response.txt`` -- tab-separated lines ``worker_id, task_id, label``.
    Blank lines in the two text files are ignored.

    Args:
        filename (str): dataset name without extension.
        data_dir (str): directory holding the dataset files (default ``'./datasets'``).

    Returns:
        tuple: ``(X, adj, true_labels, num_tsk, num_wkr, num_cls)`` where
            X (numpy.ndarray): ``preprocess`` output for the arff data;
            adj (numpy.ndarray): int32 matrix of shape (num_tsk, num_wkr).
                Rows follow gold-file line order and columns follow the sorted
                unique worker ids. Entry [t, w] is worker w's label for task t,
                or -1 if w did not label t;
            true_labels (list[int]): gold label per task, in gold-file order;
            num_tsk (int): number of tasks (non-blank gold-file lines);
            num_wkr (int): number of distinct worker ids;
            num_cls (int): number of distinct labels given by workers.

    Raises:
        KeyError: if a response refers to a task id absent from the gold file.
        ValueError: if a label field is not an integer.
    """
    data, _meta = arff.loadarff(os.path.join(data_dir, '{}.arff'.format(filename)))
    # NOTE(review): row i of X is assumed to describe the task on gold line i -- confirm.
    X = preprocess(pd.DataFrame(data))

    gold_rows = _read_tsv_rows(os.path.join(data_dir, '{}.gold.txt'.format(filename)))
    response_rows = _read_tsv_rows(os.path.join(data_dir, '{}.response.txt'.format(filename)))

    # Task index = position of the task id in the gold file.
    num_tsk = len(gold_rows)
    true_labels = [int(row[1]) for row in gold_rows]
    tsk_dist = {row[0]: i for i, row in enumerate(gold_rows)}

    # Worker index = position of the worker id among the sorted unique ids.
    worker_ids = np.unique([row[0] for row in response_rows])
    worker_dist = {wid: i for i, wid in enumerate(worker_ids)}
    num_wkr = len(worker_ids)

    # num_cls is counted from the same column (index 2) that is written into adj,
    # so it always agrees with adj's values.
    response_labels = [int(row[2]) for row in response_rows]
    num_cls = len(np.unique(response_labels))

    # Label matrix; -1 marks "not labelled by this worker".
    adj = np.full((num_tsk, num_wkr), -1, dtype=np.int32)
    for row, label in zip(response_rows, response_labels):
        adj[tsk_dist[row[1]], worker_dist[row[0]]] = label

    return X, adj, true_labels, num_tsk, num_wkr, num_cls


def preprocess(df, target_col='class'):
    """
    Turn a raw feature DataFrame into a standardized numeric matrix.

    Steps:
      1. drop the label column ``target_col`` (if present);
      2. drop feature columns whose values are all missing;
      3. impute missing values: most frequent value for object (nominal) columns,
         mean for float columns;
      4. label-encode the object columns;
      5. standardize every column to zero mean and unit variance.

    Args:
        df (pandas.DataFrame): raw data, e.g. built from ``scipy.io.arff.loadarff``.
        target_col (str): name of the label column to exclude (default ``'class'``).

    Returns:
        numpy.ndarray: array of shape (n_rows, n_kept_features).
    """
    # Explicit copy so that the column assignments below act on an independent
    # frame (avoids pandas' SettingWithCopyWarning on a derived frame).
    features = df[[c for c in df.columns if c != target_col]].copy()
    features = features.dropna(axis=1, how='all')

    # Nominal attributes: fill missing values with the mode.
    # NOTE(review): scipy's arff reader may encode a missing nominal value as b'?'
    # instead of NaN; such values are not imputed here and become their own
    # category -- confirm for the datasets in use.
    nominal_cols = features.select_dtypes(include=['object']).columns
    if len(nominal_cols) > 0:
        imputer = SimpleImputer(strategy='most_frequent')
        features[nominal_cols] = imputer.fit_transform(features[nominal_cols])

    # Continuous attributes: fill missing values with the mean.
    numeric_cols = features.select_dtypes(include=['float']).columns
    if len(numeric_cols) > 0:
        imputer = SimpleImputer(strategy='mean')
        features[numeric_cols] = imputer.fit_transform(features[numeric_cols])

    # Nominal values -> integer codes.
    for column in features.select_dtypes(include=['object']).columns:
        features[column] = LabelEncoder().fit_transform(features[column])

    return np.asarray(StandardScaler().fit_transform(features))

def one_hot(target, n_classes):
    """
    Encode integer class labels as one-hot rows.

    Args:
        target: a single integer label or an array-like of integer labels of any
            shape (it is flattened). Labels are used directly as row indices into
            an ``n_classes`` identity matrix.
        n_classes (int): number of classes, i.e. the width of each row.

    Returns:
        numpy.ndarray: float array of shape (n_labels, n_classes); a scalar
        ``target`` yields a single row.
    """
    identity = np.eye(n_classes)
    flat_labels = np.ravel(np.array([target]))
    return identity[flat_labels]