import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from utils import *
# Ablation variant: the first stage is removed.

# Confident learning
def cl_filter(X, y, y_pros):
    """Keep the samples whose confidence reaches the mean confidence of their label.

    Samples are grouped by their label in ``y``. For every group the mean of
    the matching ``y_pros`` values is used as a threshold, and a sample is
    retained when its own confidence is greater than or equal to it.

    Args:
        X: Indexable collection of samples; ``X[i]`` goes with ``y[i]``.
        y: Sequence of hashable labels, one per sample.
        y_pros: Sequence of confidence values, one per sample.

    Returns:
        A tuple ``(clean_X, clean_y, clean_index)`` of NumPy arrays with the
        retained samples, their labels and their positions in the input.
        Entries are ordered label by label (labels in order of first
        appearance) and by original position inside each label.
    """
    # label -> (confidences, positions); dict order is first-appearance order.
    groups = {}
    for position in range(len(y)):
        confidences, positions = groups.setdefault(y[position], ([], []))
        confidences.append(y_pros[position])
        positions.append(position)

    kept_labels = []
    kept_positions = []
    for label, (confidences, positions) in groups.items():
        threshold = np.mean(confidences)
        for confidence, position in zip(confidences, positions):
            if confidence >= threshold:
                kept_labels.append(label)
                kept_positions.append(position)

    clean_X = np.array([X[position] for position in kept_positions])
    clean_y = np.array(kept_labels)
    clean_index = np.array(kept_positions)
    return clean_X, clean_y, clean_index


# Refine for Clean set
def refine(clean_X, clean_y, clean_index, K=3):
    """Drop samples whose nearest neighbours disagree with their label.

    For every sample the Euclidean distance to all samples in ``clean_X``
    is computed and the ``K`` closest ones are selected. The sample itself
    is at distance 0, so it normally takes one of the ``K`` slots. The
    sample is kept only if all selected neighbours carry the same label.

    Distances are computed one row at a time with a vectorised norm instead
    of a Python-level double loop, so no ``n x n`` matrix is materialised.

    Args:
        clean_X: Array-like of samples, first axis indexes the samples.
        clean_y: Array-like of labels, one per sample.
        clean_index: Indexable collection with the original position of each
            sample; passed through for the samples that are kept.
        K: Number of nearest samples (self included) that must agree.

    Returns:
        A tuple ``(X_new, y_new, new_index)`` of NumPy arrays with the kept
        samples, their labels and their original positions, in input order.
    """
    clean_X = np.asarray(clean_X)
    clean_y = np.asarray(clean_y)
    num_samples = clean_X.shape[0]

    # Flatten each sample to a vector; the Euclidean norm of the flattened
    # difference equals the norm of the difference for any feature rank.
    num_features = int(np.prod(clean_X.shape[1:]))
    features = clean_X.reshape(num_samples, num_features)

    X_new = []
    y_new = []
    new_index = []
    for i in range(num_samples):
        # Distances from sample i to every sample (including itself).
        distances = np.linalg.norm(features[i] - features, axis=1)
        neighbors = np.argsort(distances)[:K]
        # Keep the sample only when all of its neighbours share its label.
        if np.all(clean_y[neighbors] == clean_y[i]):
            X_new.append(clean_X[i])
            y_new.append(clean_y[i])
            new_index.append(clean_index[i])

    return np.array(X_new), np.array(y_new), np.array(new_index)

# Network definition
class Net(nn.Module):
    """MLP feature extractor followed by a linear classification head.

    Args:
        input_dim: Size of each input vector.
        hidden_dim: Width of the hidden layer of the feature extractor.
        feature_dim: Size of the feature vector produced by the extractor.
        num_classes: Number of outputs of the classifier.
    """

    def __init__(self, input_dim, hidden_dim, feature_dim, num_classes):
        super().__init__()
        # input -> hidden -> ReLU -> feature
        extractor_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, feature_dim),
        ]
        self.feature_extractor = nn.Sequential(*extractor_layers)
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        """Return ``(feature, logits)`` computed from the input batch ``x``."""
        feature = self.feature_extractor(x)
        return feature, self.classifier(feature)
    

# Train the network
if __name__ == '__main__':
    # Configuration.
    BATCH_SIZE = 1024
    num_epochs = 1001
    filename = 'Music_genre'
    print(filename)

    # One independent run per seed; the final ECE of every run is collected.
    ECE_list = []
    for seed in range(10):
        set_seed(seed)

        # Load the data.
        X_ori, adj, trueLabels, num_tsk, num_wks, num_cls = load_data(filename)
        input_dimX = X_ori.shape[1]

        # Build the network, its optimizer and the loss (created once per
        # run rather than once per batch).
        net1 = Net(input_dimX, 64, 128, num_cls)
        optimizer = optim.Adam(net1.parameters(), lr=1e-3)
        criterion = nn.MSELoss()

        # Majority-voting baseline and its calibration error.
        y_MV, y_MV_pros, y_MV_total_probs = MV2(adj, num_cls)
        isright = (y_MV == trueLabels)
        ece = compute_ece(y_MV_pros, isright)
        print(f"ECE: {ece}")
        print(f"MV的精度为{np.sum(isright) / len(isright)}")

        # Ablation: the confident-learning filter (cl_filter) and the
        # neighbour-based refine step are skipped, so every sample is used
        # and its MV label becomes a hard one-hot training target.
        for sample_idx, label in enumerate(y_MV):
            one_hot = np.zeros(num_cls)
            one_hot[label] = 1
            y_MV_total_probs[sample_idx] = one_hot

        # Convert to tensors.
        X_tensor = torch.tensor(X_ori).float()
        L_tensor = torch.tensor(adj).float()
        y_tensor = torch.tensor(y_MV_total_probs).float()

        # Dataset and loader; the second element (adj rows) is carried along
        # but not consumed by the training loop below.
        train_dataset = TensorDataset(X_tensor, L_tensor, y_tensor)
        dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

        # Training: regress the raw network outputs onto the one-hot targets.
        for _ in range(num_epochs):
            for features, _, targets in dataloader:
                _, logits1 = net1(features)
                loss = criterion(logits1, targets)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # Evaluate the calibration of the trained network.
        with torch.no_grad():
            _, logits1 = net1(X_tensor)
            probs = logits1.cpu().numpy()
            # Normalise every row so that it sums to one.
            # NOTE(review): the rows are raw network outputs, so this does
            # not guarantee values inside [0, 1] (negative entries or a sum
            # close to zero are possible) - confirm this is intended.
            probs = probs / np.sum(probs, axis=1, keepdims=True)
            confidence = np.max(probs, axis=1)
            y_nbmv = np.argmax(probs, axis=1)
            # The correctness mask must belong to the same predictions as
            # `confidence`; the MV mask `isright` describes y_MV, not y_nbmv.
            isright_nbmv = (y_nbmv == trueLabels)
            ece = compute_ece(confidence, isright_nbmv)
            print(f"Current ECE: {ece}")
            print(f"MV修正后的精度为{np.sum(isright_nbmv) / len(trueLabels)}")
            ECE_list.append(ece)
    print(f"ECE_list: {ECE_list}")
    print(f"ECE_list的平均值为{np.mean(ECE_list)}")
    print(f"ECE_list的标准差为{np.std(ECE_list)}")
                    