import json, os
import numpy as np
import torch
from tqdm import tqdm
from collections import Counter

def _pair_features(des_emb, heads, tails, device):
    """Concatenate head and tail description embeddings for every (h, t) pair.

    ``des_emb`` is indexed by the entries of ``heads`` and ``tails``; the two
    lookups are joined along axis 1, giving one row per pair, and the result
    is returned as a float32 tensor on ``device``.
    """
    paired = np.concatenate([des_emb[heads], des_emb[tails]], axis=1)
    return torch.as_tensor(paired, dtype=torch.float32).to(device)


def hybrid_knn(
    dataset,
    filter=0,
    lambdas=(0, 0.01, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0),
    record=0,
    filename_template="DB_{dataset}_hybrid.json",
):
    """Evaluate a k-NN classifier over a mix of two distance measures.

    For every test pair the distance to each training pair is
    ``cdist(subgraph embeddings) + lam * cdist(description features)``
    (both Euclidean), and the predicted label is the most common label among
    the k nearest training pairs, for k in 1, 5, 10, 20 and 30.

    Files read from the current working directory:
        * ``{dataset}-training_data.npz`` and ``{dataset}-test_data.npz``
          (arrays ``ht_embed``, ``r``, ``h``, ``t``),
        * ``DB_drug_emb.npz`` (array ``emb``),
        * ``DB_{dataset}_representative_samples.json`` when ``filter`` is
          truthy.

    Args:
        dataset: Any string; it is normalised to ``'S1'`` if it contains
            ``'S1'``, else ``'S2'`` if it contains ``'S2'``, else ``'S0'``.
        filter: When truthy, the neighbour pool is restricted to the training
            indices listed in the representative-samples JSON file.  (The
            name shadows the builtin; it is kept for caller compatibility.)
        lambdas: Iterable of weights applied to the description distance.
        record: When truthy, the 10 nearest neighbours of every test pair are
            written to a JSON file after each lambda.
        filename_template: ``str.format`` template for the record file.  It
            receives ``dataset`` and ``lam``.  A template without ``{lam}``
            (such as the default) resolves to the same name for every lambda,
            so each lambda overwrites the previous file.  ``'_filtered'`` is
            inserted before ``.json`` when ``filter`` is truthy.

    Returns:
        ``{lam: {k: accuracy}}`` for every lambda and every k.

    Raises:
        ValueError: If the neighbour pool contains no training samples.
    """
    if 'S1' in dataset:
        dataset = 'S1'
    elif 'S2' in dataset:
        dataset = 'S2'
    else:
        dataset = 'S0'

    # Use the GPU when one is usable, otherwise fall back to the CPU so the
    # evaluation can still run on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Drug description embeddings, indexed by the h / t entries of each split.
    with np.load('DB_drug_emb.npz') as archive:
        des_emb = archive['emb']

    # Training split: subgraph embeddings, labels and description features.
    with np.load(f'{dataset}-training_data.npz') as archive:
        train_ht_embed = torch.as_tensor(archive['ht_embed'], dtype=torch.float32).to(device)
        train_labels = archive['r']
        train_descript_feat = _pair_features(des_emb, archive['h'], archive['t'], device)

    # Test split, prepared the same way.
    with np.load(f'{dataset}-test_data.npz') as archive:
        ht_embed = torch.as_tensor(archive['ht_embed'], dtype=torch.float32).to(device)
        test_labels = archive['r']
        test_descript_feat = _pair_features(des_emb, archive['h'], archive['t'], device)

    k_values = [1, 5, 10, 20, 30]
    n_recorded = 10  # number of neighbours stored per test pair when recording
    depth = max(max(k_values), n_recorded)  # deepest neighbour rank ever read

    # Choose the neighbour pool once; it does not depend on lambda.
    if filter:
        with open(f'DB_{dataset}_representative_samples.json', 'r', encoding='utf-8') as f:
            pool_indices = list(json.load(f))
    else:
        pool_indices = list(range(train_ht_embed.shape[0]))

    if not pool_indices:
        raise ValueError("hybrid_knn: the neighbour pool contains no training samples")

    if filter:
        pool_ht_embed = train_ht_embed[pool_indices]
        pool_descript_feat = train_descript_feat[pool_indices]
        pool_labels = train_labels[pool_indices]
    else:
        pool_ht_embed = train_ht_embed
        pool_descript_feat = train_descript_feat
        pool_labels = train_labels

    # Maps a position inside the pool back to the original training index.
    origin_index = np.asarray(pool_indices)

    # Both distance matrices are independent of lambda, so compute them once.
    subgraph_dists = torch.cdist(ht_embed, pool_ht_embed, p=2)
    descript_dists = torch.cdist(test_descript_feat, pool_descript_feat, p=2)

    results = {}
    for lam in lambdas:
        combined_dists = subgraph_dists + lam * descript_dists

        # Only the first `depth` ranks are ever used, so only those columns
        # are copied back to the host.
        nearest = torch.argsort(combined_dists, dim=1)[:, :depth].cpu().numpy()
        nearest_labels = pool_labels[nearest]  # one row of labels per test pair

        # Majority vote among the k nearest labels.  Counter keeps insertion
        # order, so ties go to the label that appears first (i.e. nearest).
        accuracy_at_ks = {}
        for k in k_values:
            preds = np.array(
                [Counter(row[:k].tolist()).most_common(1)[0][0] for row in nearest_labels]
            )
            accuracy_at_ks[k] = np.mean(preds == test_labels)

        results[lam] = accuracy_at_ks

        print(f"lambda={lam}:")
        print(" ".join([f"acc@{k}={v:.3f}" for k, v in accuracy_at_ks.items()]))

        if record:
            neighbor_ids = origin_index[nearest[:, :n_recorded]].tolist()
            neighbor_labels = nearest_labels[:, :n_recorded].tolist()
            records = [
                {
                    'id': idx,
                    'K-neighbor': nbr,
                    'K-neighbor_labels': lbl,
                }
                for idx, (nbr, lbl) in enumerate(zip(neighbor_ids, neighbor_labels))
            ]

            # str.format ignores unused keyword arguments, so templates
            # without a {lam} placeholder resolve exactly as before.
            filename = filename_template.format(dataset=dataset, lam=lam)
            if filter:
                filename = filename.replace('.json', '_filtered.json')
            with open(filename, 'w', encoding='utf-8') as f:
                json.dump(records, f, indent=4)
            print(f"Records saved to {filename}")

    return results

if __name__ == '__main__':
    # Restrict CUDA device visibility for this process to device id 5; this is
    # set before hybrid_knn is called.
    os.environ.update(CUDA_VISIBLE_DEVICES='5')

    # Single run on S1 with one lambda, no filtering of the training set, and
    # neighbour records written to disk.
    lams = [0.5]
    hybrid_knn(dataset='S1', filter=0, lambdas=lams, record=1)
