import csv
import json
import re
import os

# ========== File Paths ==========
# Inputs: the QA entries to annotate, plus the node and triplet tables that
# the lookup dictionaries are built from.
json_path = './data/PrimeKGQA/Disease-Protein.json'
nodes_path = './data/nodes.csv'
triplet_path = './data/triplet.csv'

# Output: the same QA entries with a 'path' field added to each one.
output_json_path = './data/PrimeKGQA_path/Disease-Protein.json'

# Create the output folder up front so the final write cannot fail on a
# missing directory; exist_ok makes re-runs harmless.
output_dir = os.path.dirname(output_json_path)
os.makedirs(output_dir, exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Builds three lookup tables from the node table:
#   index_to_name          node_index -> node_name
#   index_to_type          node_index -> node_type
#   disease_name_to_index  lower-cased name -> node_index (disease nodes only)
print("Reading nodes.csv...")
disease_name_to_index = {}
index_to_name = {}
index_to_type = {}

with open(nodes_path, mode='r', newline='', encoding='utf-8') as nodes_file:
    for node_row in csv.DictReader(nodes_file):
        node_id = int(node_row['node_index'])
        name, kind = node_row['node_name'], node_row['node_type']
        index_to_name[node_id], index_to_type[node_id] = name, kind
        if kind != 'disease':
            continue
        # Keys are lower-cased so question text can be matched
        # case-insensitively; a later duplicate name overwrites an earlier one.
        disease_name_to_index[name.lower()] = node_id
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv ==========
print("Reading triplet.csv...")
# These dictionaries will store relationships as sets for efficient intersection operations.
# Each maps a source node index to the set of target node indices reached
# through one relation code.
disease_to_effects = {}
disease_to_proteins = {}
disease_to_drug = {}
protein_to_effects = {}
protein_to_drug = {}
protein_to_pathways = {}
protein_to_bps = {}

# Rows that cannot be parsed are skipped (best effort) but counted, so that a
# systematic problem such as a wrong header or a non-numeric column is visible
# instead of silently producing empty lookup tables.
skipped_rows = 0

with open(triplet_path, mode='r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        try:
            display_relation = int(row['display_relation'])
            x_index = int(row['x_index'])
            y_index = int(row['y_index'])
        except (ValueError, KeyError, TypeError):
            # ValueError: non-numeric cell; KeyError: column missing from the
            # header; TypeError: truncated row, which DictReader fills with None.
            skipped_rows += 1
            continue

        # Node types come from nodes.csv; indices not listed there yield None
        # and therefore match none of the branches below.
        x_type_str = index_to_type.get(x_index)
        y_type_str = index_to_type.get(y_index)

        if x_type_str == 'disease':
            if display_relation == 4:
                # Disease -> Effect/Phenotype (relation 4)
                disease_to_effects.setdefault(x_index, set()).add(y_index)
            elif display_relation == 6:
                # Disease -> Protein (relation 6)
                disease_to_proteins.setdefault(x_index, set()).add(y_index)
            elif display_relation == 11:
                # Disease -> Drug (Indication, relation 11)
                disease_to_drug.setdefault(x_index, set()).add(y_index)
        elif x_type_str == 'gene/protein':
            if display_relation == 6:
                # Protein -> Effect/Phenotype (relation 6)
                # NOTE(review): the target type is not checked here, so every
                # relation-6 neighbour of a protein is recorded - confirm that
                # this is intended.
                protein_to_effects.setdefault(x_index, set()).add(y_index)
            elif display_relation == 10:
                # Protein -> Drug (Target, relation 10)
                protein_to_drug.setdefault(x_index, set()).add(y_index)
            elif display_relation == 2:
                # Relation 2 is split by the type of the target node.
                if y_type_str == 'pathway':
                    # Protein -> Pathway (relation 2)
                    protein_to_pathways.setdefault(x_index, set()).add(y_index)
                elif y_type_str == 'biological_process':
                    # Protein -> Biological Process (relation 2)
                    protein_to_bps.setdefault(x_index, set()).add(y_index)
print("Finished reading triplet.csv.")
if skipped_rows:
    print(f"Skipped {skipped_rows} malformed row(s) in triplet.csv.")

# ========== 3. Read JSON, process each entry, and find associated paths ==========
print(f"Reading and processing {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# This regex extracts the disease name from the question string. The name is
# captured lazily up to the first '?' so that names containing punctuation
# (commas, apostrophes, parentheses, slashes, ...) are recognised as well.
pattern = r'Which protein is associated with (.+?)\?'


def _report_line(label, node_indices):
    """Return one indented report line naming the nodes in node_indices.

    Names are looked up in index_to_name ('Unknown' when an index is missing)
    and sorted; an empty collection yields the "None found." line.
    """
    if not node_indices:
        return f"     - {label}: None found."
    names = ", ".join(sorted(index_to_name.get(i, 'Unknown') for i in node_indices))
    return f"     - {label}: {names}"


# Every entry receives a 'path' field: either a multi-line report or a short
# message explaining why no report could be produced.
for entry in data:
    # A missing or null question is treated as an empty string so that it
    # takes the "could not parse" branch instead of raising inside re.search.
    input_text = entry.get('question') or ''
    # DOTALL lets a name that spans a line break still match.
    match = re.search(pattern, input_text, flags=re.DOTALL)
    if not match:
        entry['path'] = "Could not parse disease name from input."
        continue

    # Normalised the same way as the keys of disease_name_to_index.
    disease_name = match.group(1).strip().lower()

    disease_index = disease_name_to_index.get(disease_name)
    if disease_index is None:
        entry['path'] = f"Disease '{disease_name}' not found in nodes.csv."
        continue

    path_info_lines = [f"Disease: {index_to_name.get(disease_index, disease_name)}"]

    # Get all entities directly related to the disease to use for intersections later
    associated_proteins = disease_to_proteins.get(disease_index, set())
    direct_disease_effects = disease_to_effects.get(disease_index, set())
    direct_disease_drugs = disease_to_drug.get(disease_index, set())

    if not associated_proteins:
        path_info_lines.append("  - No associated proteins found for this disease.")

    # Sorting by node index keeps the report order deterministic.
    for protein_id in sorted(associated_proteins):
        protein_name = index_to_name.get(protein_id, f"Unknown (ID: {protein_id})")
        path_info_lines.append(f"\n  -> Associated Protein: {protein_name}")

        # Shared Effects/Phenotypes (Protein -> Effect <- Disease)
        shared_effects = direct_disease_effects.intersection(
            protein_to_effects.get(protein_id, set()))
        path_info_lines.append(
            _report_line("Shared Effects (with Disease)", shared_effects))

        # Shared Drugs (Protein -> Drug <- Disease)
        shared_drugs = direct_disease_drugs.intersection(
            protein_to_drug.get(protein_id, set()))
        path_info_lines.append(
            _report_line("Shared Drugs (with Disease)", shared_drugs))

        # Protein -> Pathway
        path_info_lines.append(
            _report_line("Pathways", protein_to_pathways.get(protein_id, set())))

        # Protein -> Biological Process
        path_info_lines.append(
            _report_line("Biological Processes", protein_to_bps.get(protein_id, set())))

    # Join all collected lines into a single string and add it to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== 4. Save the updated data to a new JSON file ==========
print(f"Saving updated data to {output_json_path}...")
# The annotated entries are rendered as 4-space-indented JSON text and written
# in one call to the output file.
with open(output_json_path, 'w', encoding='utf-8') as out_file:
    out_file.write(json.dumps(data, indent=4))

print("Processing complete.")