import json, os
import numpy as np
import torch
from tqdm import tqdm
from collections import Counter

def _pair_features(des_emb, heads, tails):
    """Stack ``[des_emb[h], des_emb[t]]`` concatenations, one row per (h, t) pair."""
    return np.array([np.concatenate([des_emb[h], des_emb[t]]) for h, t in zip(heads, tails)])


def hybrid_knn(dataset, filter=0, lambdas=(0, 0.01, 0.1, 0.3, 0.5, 0.7, 1.0, 1.5, 2.0), record=0, filename_template="TS_{dataset}_hybrid.json"):
    """Evaluate a k-NN label retriever that mixes two distances.

    For every test pair, the training pairs are ranked by a weighted sum of
    (a) the L2 distance between ``ht_embed`` vectors and (b) the L2 distance
    between concatenated description embeddings of the head/tail entities
    (the smaller of the ``[h, t]`` and ``[t, h]`` orderings is used).
    recall@k is the fraction of a test sample's true labels that appear among
    the labels of its k nearest training samples, averaged over test samples.

    Files read from the current working directory:
        ``{dataset}-training_data.npz`` and ``{dataset}-test_data.npz``
        (arrays ``ht_embed``, ``r``, ``h``, ``t``), ``TS_drug_emb.npz``
        (array ``emb``) and, when ``filter`` is truthy,
        ``TS_{dataset}_representative_samples.json``.

    Args:
        dataset: Any string; normalised to ``'S1'`` / ``'S2'`` if it contains
            that token, otherwise ``'S0'``.
        filter: If truthy, restrict the candidate training samples to the
            indices listed in the representative-samples JSON file. (The name
            shadows the builtin; it is kept for caller compatibility.)
        lambdas: Iterable of mixing weights. For ``'S2'`` the weight scales the
            ``ht_embed`` distance; otherwise it scales the description
            distance.
        record: If truthy, dump the 20 nearest training indices (and their
            label ids) of every test sample to a JSON file.
        filename_template: ``str.format`` template (field ``dataset``) for the
            record file; ``_filtered`` is inserted before ``.json`` when
            ``filter`` is truthy.

    Returns:
        ``{lambda: {k: recall_at_k}}`` for k in ``[1, 5, 10, 20, 30, 50]``.

    Note:
        The record file name does not depend on lambda, so with several
        lambdas the file ends up holding the neighbours of the last one.
    """
    if 'S1' in dataset:
        dataset = 'S1'
    elif 'S2' in dataset:
        dataset = 'S2'
    else:
        dataset = 'S0'

    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def to_device(array):
        return torch.tensor(array, dtype=torch.float32, device=device)

    # Load the training-set pair embeddings; `with` closes the npz handle.
    with np.load(f'{dataset}-training_data.npz') as training_data:
        train_ht_embed = to_device(training_data['ht_embed'])
        train_labels = training_data['r']
        train_h = training_data['h']
        train_t = training_data['t']

    # Load the per-entity description embeddings.
    with np.load('TS_drug_emb.npz') as drug_data:
        des_emb = drug_data['emb']

    # Description features of the training pairs.
    train_descript_feat = to_device(_pair_features(des_emb, train_h, train_t))

    # Load the test data.
    with np.load(f'{dataset}-test_data.npz') as test_data:
        ht_embed = to_device(test_data['ht_embed'])
        test_labels = test_data['r']
        test_h = test_data['h']
        test_t = test_data['t']

    # Description features of the test pairs in both (h, t) and (t, h) order.
    test_descript_feat = to_device(_pair_features(des_emb, test_h, test_t))
    test_descript_feat_reverse = to_device(_pair_features(des_emb, test_t, test_h))

    # Shared settings.
    k_values = [1, 5, 10, 20, 30, 50]
    n_recorded = 20  # number of neighbours stored per test sample
    n_test = ht_embed.shape[0]
    results = {}

    # Maps a position in the filtered training set back to its original index.
    index_map = None
    if filter:
        with open(f'TS_{dataset}_representative_samples.json', 'r') as f:
            representative_samples_set = json.load(f)
        filtered_indices = sorted(representative_samples_set)
        train_ht_embed = train_ht_embed[filtered_indices]
        train_descript_feat = train_descript_feat[filtered_indices]
        index_map = np.asarray(filtered_indices, dtype=np.int64)

    # The output file name is the same for every lambda, so build it once.
    filename = filename_template.format(dataset=dataset)
    if filter:
        filename = filename.replace('.json', '_filtered.json')

    for lam in lambdas:
        recall_at_ks = {k: 0.0 for k in k_values}
        neighbors = []
        neighbors_labels = []

        for j in range(n_test):
            # Distance between pair embeddings.
            subgraph_dist = torch.nn.functional.pairwise_distance(
                ht_embed[j].unsqueeze(0), train_ht_embed, p=2
            )

            # Distance between description features; take the closer of the
            # two head/tail orderings.
            descript_dist = torch.min(
                torch.nn.functional.pairwise_distance(
                    test_descript_feat[j].unsqueeze(0), train_descript_feat, p=2
                ),
                torch.nn.functional.pairwise_distance(
                    test_descript_feat_reverse[j].unsqueeze(0), train_descript_feat, p=2
                ),
            )

            # Mixed distance.
            if dataset == 'S2':
                combined_dist = lam * subgraph_dist + descript_dist
            else:
                combined_dist = subgraph_dist + lam * descript_dist

            # Full ranking expressed in ORIGINAL training indices. Mapping the
            # whole ranking (not just the first 20) keeps recall@30/@50
            # correct when filtering is on.
            ranked = torch.argsort(combined_dist).cpu().numpy()
            if index_map is not None:
                ranked = index_map[ranked]

            # Record neighbour info as plain Python ints so that json.dump
            # can serialise it.
            top = ranked[:n_recorded]
            neighbors.append(top.tolist())
            neighbors_labels.append([np.nonzero(row)[0].tolist() for row in train_labels[top]])

            # Recall for this sample.
            true_label = np.nonzero(test_labels[j])[0].tolist()
            if not true_label:
                # No positive label: recall is undefined; count it as 0
                # instead of dividing by zero.
                continue
            for k in k_values:
                pred = set()
                for row in train_labels[ranked[:k]]:
                    pred.update(np.nonzero(row)[0].tolist())
                hits = sum(1 for r in true_label if r in pred)
                recall_at_ks[k] += hits / len(true_label) / n_test

        # Store the result.
        results[lam] = recall_at_ks

        # Print the result.
        print(f"lambda={lam}:")
        print(" ".join([f"rec@{k}={v:.3f}" for k, v in recall_at_ks.items()]))

        # Write detailed neighbour records.
        if record:
            records = [{
                'id': idx,
                'K-neighbor': list(nbr),
                'K-neighbor_labels': lbl
            } for idx, (nbr, lbl) in enumerate(zip(neighbors, neighbors_labels))]

            with open(filename, 'w') as f:
                json.dump(records, f, indent=4)
            print(f"Records saved to {filename}")

    return results

if __name__ == '__main__':
    # Select GPU id 4 through the CUDA_VISIBLE_DEVICES environment variable
    # before the evaluation creates any CUDA tensor.
    os.environ['CUDA_VISIBLE_DEVICES'] = '4'

    # Lambda sweeps tried earlier, kept for reference:
    #   [0, 0.1, 0.3, 0.5, 0.8, 1.0, 2.0]
    #   [3.0, 5.0]
    #   [0, 0.01, 0.05, 0.1, 0.15]
    lams = [0.1]

    # Unfiltered run on S1 that also writes the neighbour records to disk.
    hybrid_knn('S1', filter=0, lambdas=lams, record=1)

