import ast
import csv
import json
import os

import datasets
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from tqdm import tqdm

from src.dataset.utils.retrieval_pcst import retrieval_via_pcst, retrieval_via_pcst_plus
from src.dataset.utils.retrieval_graphpack import retrieval_graphpack
from src.utils.lm_modeling import load_model, load_text2embedding

# Key into load_model / load_text2embedding selecting the text encoder.
model_name = 'sbert'

# Root of the local datasets and the WebQSP working directory beneath it.
dataset_path = 'dataset'
path = f'{dataset_path}/webqsp'

# Per-sample raw inputs: node tables, edge tables and full graphs.
path_nodes, path_edges, path_graphs = (f'{path}/{sub}' for sub in ('nodes', 'edges', 'graphs'))

# JSON summary of the retrieval results written by preprocess().
file_name = f'{path}/data_gt.json'

# Output caches: retrieved subgraphs (.pt) and their textual descriptions (.txt).
cached_graph = f'{path}/cached_graphs_gt'
cached_desc = f'{path}/cached_desc_gt'


class WebQSPDataset(Dataset):
    """WebQSP QA dataset served from the preprocessing caches.

    Samples come from the train, validation and test splits of the local
    ``RoG-webqsp`` dataset, concatenated in that order. Each item pairs a
    question with the query-augmented subgraph in
    ``{cached_graph}_with_query1`` and the cached text description in
    ``cached_desc``.
    """

    def __init__(self):
        super().__init__()
        self.prompt = 'Please answer the given question.'
        self.graph = None
        self.graph_type = 'Knowledge Graph'
        dataset = datasets.load_dataset(f"{dataset_path}/RoG-webqsp")
        # Global index space: train rows first, then validation, then test.
        self.dataset = datasets.concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])
        # Question embeddings for the same global index space. Not read by the
        # methods of this class.
        self.q_embs = torch.load(f'{path}/q_embs.pt')

    def __len__(self):
        """Return the len of the dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return one sample as a dict.

        Keys:
            id: the global index.
            question: question formatted as ``'Question: ...\\nAnswer: '``.
            label: '|'-joined, lower-cased answers.
            graph: query-augmented subgraph from ``{cached_graph}_with_query1``.
            desc: first blank-line-separated section of the cached description.

        Raises:
            FileNotFoundError: if the cached graph or description is missing.
        """
        data = self.dataset[index]
        question = f'Question: {data["question"]}\nAnswer: '
        # Only the graph that includes the extra query node is used, so the
        # plain cached subgraph is not loaded at all.
        graph = torch.load(f'{cached_graph}_with_query1/{index}.pt')
        with open(f'{cached_desc}/{index}.txt', 'r') as desc_file:
            desc = desc_file.read().split('\n\n')[0]
        label = '|'.join(data['answer']).lower()

        return {
            'id': index,
            'question': question,
            'label': label,
            'graph': graph,
            'desc': desc,
        }

    def get_idx_split(self):
        """Load the saved train/val/test index lists from ``{path}/split``.

        Each ``<split>_indices.txt`` holds one integer per line. Blank lines,
        such as a trailing empty line, are ignored.

        Returns:
            dict: ``{'train': [...], 'val': [...], 'test': [...]}``.
        """
        def read_indices(split_name):
            with open(f'{path}/split/{split_name}_indices.txt', 'r') as file:
                # int() tolerates surrounding whitespace, including the newline.
                return [int(line) for line in file if line.strip()]

        return {split: read_indices(split) for split in ('train', 'val', 'test')}



def _node_names_from_desc(desc):
    """Return the second column of each row of the node table leading ``desc``.

    Assumes ``desc`` starts with a CSV node table: a header line, then one row
    per node, ended by a blank line. TODO confirm this against the output of
    ``retrieval_graphpack``. Rows are parsed with ``csv`` so that a quoted
    value containing commas is kept whole rather than cut at its first comma.
    Rows with fewer than two fields are skipped.
    """
    node_rows = desc.split('\n\n')[0].split('\n')[1:]
    return [row[1] for row in csv.reader(node_rows) if len(row) > 1]


def preprocess(indices=None):
    """Retrieve and cache a subgraph plus text description per sample.

    For each sample index, this loads the node and edge tables and the full
    graph, then runs ``retrieval_graphpack``. The subgraph is saved to
    ``{cached_graph}/{index}.pt`` and the description to
    ``{cached_desc}/{index}.txt``. A summary record for each processed sample
    is written to ``file_name`` as JSON. Samples with an empty node table are
    reported and skipped.

    Args:
        indices: Iterable of global sample indices to process. ``None`` (the
            default) processes every sample of the concatenated dataset.
    """
    os.makedirs(cached_desc, exist_ok=True)
    os.makedirs(cached_graph, exist_ok=True)
    dataset = datasets.load_dataset(f"{dataset_path}/RoG-webqsp")
    dataset = datasets.concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])  # ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices']
    if indices is None:
        indices = range(len(dataset))

    records = []
    q_embs = torch.load(f'{path}/q_embs.pt')
    for index in tqdm(indices):
        nodes = pd.read_csv(f'{path_nodes}/{index}.csv')
        edges = pd.read_csv(f'{path_edges}/{index}.csv')
        if len(nodes) == 0:
            print(f'Empty graph at index {index}')
            continue
        graph = torch.load(f'{path_graphs}/{index}.pt')
        # NOTE: alternative retriever previously used here:
        # retrieval_via_pcst(graph, q_embs[index], nodes, edges, topk=3, topk_e=5, cost_e=0.5)
        subg, desc = retrieval_graphpack(graph, q_embs[index], -1, nodes, edges, topk=2, n=3, load=20)
        torch.save(subg, f'{cached_graph}/{index}.pt')
        with open(f'{cached_desc}/{index}.txt', 'w') as desc_file:
            desc_file.write(desc)

        sample = dataset[index]
        records.append({
            'index': index,
            'question': sample['question'],
            'desc': desc,
            'output': '|'.join(sample['answer']).lower(),
            # Stored as str(list) to keep the existing data_gt.json format.
            'nodes': str(_node_names_from_desc(desc)),
            'label': str(sample['answer']),
        })

    with open(file_name, "w", encoding="utf-8") as json_file:
        json.dump(records, json_file, ensure_ascii=False, indent=4)

def _add_query_node(graph, q_emb, q_e_emb):
    """Return a new graph with one extra node linked to every existing node.

    Args:
        graph: object exposing ``num_nodes``, ``x``, ``edge_index`` and
            ``edge_attr`` (a PyG ``Data`` here).
        q_emb: features of the new node. It is flattened to one row, so it
            must have as many elements as ``graph.x`` has columns.
        q_e_emb: features given to every new edge. It is flattened to one row,
            so it must have as many elements as ``graph.edge_attr`` has columns.

    Returns:
        ``Data`` with ``num_nodes + 1`` nodes. The new node gets index
        ``num_nodes``. For each original node, two directed edges are appended
        after the original edges: all query->node edges first, then all
        node->query edges.
    """
    num_nodes = graph.num_nodes
    x, edge_index, edge_attr = graph.x, graph.edge_index, graph.edge_attr

    # Flatten to a single row so it stacks under the existing node features.
    # Also move it to the graph's device so torch.cat does not mix devices.
    new_x = torch.cat([x, q_emb.reshape(1, -1).to(x.device)], dim=0)

    query_node_idx = num_nodes
    others = torch.arange(num_nodes, dtype=torch.long, device=edge_index.device)
    hub = torch.full((num_nodes,), query_node_idx, dtype=torch.long, device=edge_index.device)
    forward_edges = torch.stack([hub, others], dim=0)    # query node -> every node
    backward_edges = torch.stack([others, hub], dim=0)   # every node -> query node
    new_edge_index = torch.cat([edge_index, forward_edges, backward_edges], dim=1)

    # One copy of the query-edge embedding per new edge (2 * num_nodes rows).
    # reshape(1, -1) handles both 1-D and single-row 2-D embeddings.
    edge_row = q_e_emb.reshape(1, -1).to(edge_attr.device)
    new_edge_attr = torch.cat([edge_attr, edge_row.repeat(2 * num_nodes, 1)], dim=0)

    return Data(x=new_x, edge_index=new_edge_index, edge_attr=new_edge_attr, num_nodes=num_nodes + 1)


def preprocess_II():
    """Add a query node to every cached subgraph and save the result.

    For each index with a ``{cached_graph}/{index}.pt`` file, this adds a node
    carrying that sample's question embedding from ``q_embs.pt``. The node is
    connected in both directions to every existing node, and each new edge
    carries the embedding of the text ``'query'``. The result is saved to
    ``{cached_graph}_with_query1/{index}.pt``. Indices without a cached graph
    are reported and skipped.
    """
    query_graph_dir = f'{cached_graph}_with_query1'
    os.makedirs(cached_graph, exist_ok=True)
    # torch.save does not create parent directories, so create the output
    # directory up front.
    os.makedirs(query_graph_dir, exist_ok=True)
    dataset = datasets.load_dataset(f"{dataset_path}/RoG-webqsp")
    dataset = datasets.concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])  # ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices']

    # Embedding shared by all query edges: the text 'query' encoded with the
    # `model_name` encoder.
    model, tokenizer, device = load_model[model_name]()
    text2embedding = load_text2embedding[model_name]
    q_e_emb = text2embedding(model, tokenizer, device, ['query'])[0]

    q_embs = torch.load(f'{path}/q_embs.pt')
    for index in tqdm(range(len(dataset))):
        graph_file = f'{cached_graph}/{index}.pt'
        if not os.path.exists(graph_file):
            print(f'Empty graph at index {index}')
            continue
        graph = torch.load(graph_file)
        new_graph = _add_query_node(graph, q_embs[index], q_e_emb)
        torch.save(new_graph, f'{query_graph_dir}/{index}.pt')

def _as_entity_list(value):
    """Normalize a 'nodes' or 'label' field to a list of entities.

    ``preprocess`` writes these fields as ``str(list)`` (for example
    ``"['a', 'b']"``). Iterating such a string would yield single characters,
    so strings are parsed with ``ast.literal_eval`` instead. If the parsed
    value is a list, tuple or set, it is returned as a list. Any other string
    is treated as one entity. ``None`` becomes an empty list, and other
    iterables are converted with ``list()``.
    """
    if value is None:
        return []
    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            return [value]
        if isinstance(parsed, (list, tuple, set)):
            return list(parsed)
        return [value]
    return list(value)


def calculate_metrics(data):
    """Compute precision, recall and F1 of retrieved entities per sample.

    Args:
        data (list): JSON records. Each record is a dict with optional
            'nodes' (retrieved entities) and 'label' (gold entities) fields.
            Either field may be a list or the ``str(list)`` form written by
            ``preprocess``.

    Returns:
        dict: ``{'individual_results': [...], 'average_metrics': {...}}``.
        Each individual result holds 'nodes', 'label', 'precision', 'recall'
        and 'f1'. The averages are unweighted means over samples, and are 0
        when ``data`` is empty.
    """
    results = []
    for item in data:
        nodes = set(_as_entity_list(item.get("nodes", [])))  # retrieved entities
        label = set(_as_entity_list(item.get("label", [])))  # gold entities

        tp = len(nodes & label)  # correctly retrieved
        fp = len(nodes - label)  # retrieved but not gold
        fn = len(label - nodes)  # gold but not retrieved

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        results.append({
            # Sorted so the output order does not depend on set hashing.
            "nodes": sorted(nodes, key=str),
            "label": sorted(label, key=str),
            "precision": precision,
            "recall": recall,
            "f1": f1,
        })

    num_samples = len(results)

    def average(key):
        return sum(r[key] for r in results) / num_samples if num_samples > 0 else 0

    return {
        "individual_results": results,
        "average_metrics": {
            "average_precision": average("precision"),
            "average_recall": average("recall"),
            "average_f1": average("f1"),
        },
    }

if __name__ == '__main__':

    preprocess()
    preprocess_II()

    dataset = WebQSPDataset()

    data = dataset[1]
    for k, v in data.items():
        if k != 'desc':
            print(f'{k}: {v}')

    split_ids = dataset.get_idx_split()
    for k, v in split_ids.items():
        print(f'# {k}: {len(v)}')

    # Average node count per sample. Samples skipped during preprocessing have
    # no cached files; they are excluded from both the sum and the count, so
    # the average is taken only over graphs that were actually loaded.
    all_nodes = 0
    num_graphs = 0
    for i in range(len(dataset)):
        try:
            all_nodes += dataset[i]['graph'].num_nodes
        except FileNotFoundError:
            continue
        num_graphs += 1
    print(all_nodes / num_graphs if num_graphs else 0)

    with open(file_name, "r", encoding="utf-8") as f:
        data = json.load(f)
    metrics = calculate_metrics(data)
    print("\nAverage Metrics:")
    print(f"  Average Precision: {metrics['average_metrics']['average_precision']:.4f}")
    print(f"  Average Recall: {metrics['average_metrics']['average_recall']:.4f}")
    print(f"  Average F1: {metrics['average_metrics']['average_f1']:.4f}")