import json
import csv
import re
import os

# ========== File Paths ==========
# Inputs: the PrimeKGQA off-label-use question set plus the knowledge-graph
# node and edge tables. Output: the same questions, each annotated with a
# 'path' field describing the graph neighbourhood of the asked-about disease.
json_path, nodes_path, triplet_path = (
    './data/PrimeKGQA/Off-label use.json',
    './data/nodes.csv',
    './data/triplet.csv',
)
output_json_path = './data/PrimeKGQA_path/Off-label use.json'

# Create the destination folder up front so the final write cannot fail on a
# missing directory (no-op if it already exists).
os.makedirs(os.path.dirname(output_json_path), exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Lookup tables built from the node table:
#   index_to_name          node_index (int) -> node_name
#   index_to_type          node_index (int) -> node_type
#   disease_name_to_index  lowercased disease name -> node_index
disease_name_to_index = {}
index_to_name = {}
index_to_type = {}

print("Reading nodes.csv...")
with open(nodes_path, mode='r', newline='', encoding='utf-8') as nodes_file:
    for record in csv.DictReader(nodes_file):
        idx = int(record['node_index'])
        name, kind = record['node_name'], record['node_type']
        index_to_name[idx] = name
        index_to_type[idx] = kind
        if kind != 'disease':
            continue
        # Keyed by lowercase so question text matches regardless of letter case.
        disease_name_to_index[name.lower()] = idx
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv ==========
# Adjacency lists built from the edge table. Each maps a source node index to
# the list of target node indices, in file order (duplicate edges are kept).
disease_to_protein = {}  # disease_index -> list of protein indices
protein_to_pathway = {}  # protein_index -> list of pathway indices
protein_to_bp = {}  # protein_index -> list of bioprocess indices
ppi_relation = {}  # protein_index -> list of protein indices
disease_to_drug = {}  # disease_index -> list of drug indices

# Dispatch table: (x node_type from nodes.csv, display_relation code, y_type code)
# -> adjacency dict that receives the edge. The keys are mutually exclusive, so
# each row feeds at most one dict.
# NOTE(review): the integer codes come from triplet.csv's own encoding (e.g.
# y_type 1 = protein, 8 = pathway, 0 = biological process, 6 = drug); confirm
# against the script that produced triplet.csv.
_edge_targets = {
    ('disease', 6, 1): disease_to_protein,        # Disease -> Protein
    ('gene/protein', 2, 8): protein_to_pathway,   # Protein -> Pathway
    ('gene/protein', 2, 0): protein_to_bp,        # Protein -> Biological Process
    ('gene/protein', 3, 1): ppi_relation,         # Protein -> Protein (PPI)
    ('disease', 14, 6): disease_to_drug,          # Disease -> Drug
}

print("Reading triplet.csv...")
skipped_rows = 0
with open(triplet_path, mode='r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        try:
            display_relation = int(row['display_relation'])
            x_index = int(row['x_index'])
            y_index = int(row['y_index'])
            y_type = int(row['y_type'])
        except (ValueError, KeyError, TypeError):
            # TypeError: DictReader fills fields missing from a short row with
            # None, and int(None) raises TypeError rather than ValueError.
            skipped_rows += 1
            continue  # Skip malformed rows

        x_type = index_to_type.get(x_index)  # None for indices absent from nodes.csv
        target = _edge_targets.get((x_type, display_relation, y_type))
        if target is not None:
            target.setdefault(x_index, []).append(y_index)

# Surface skipped rows instead of dropping them silently: if the column encoding
# ever changes, this is the only hint that the adjacency lists are empty.
if skipped_rows:
    print(f"Warning: skipped {skipped_rows} malformed row(s) in triplet.csv.")
print("Finished reading triplet.csv.")

# ========== 3. Read JSON, process each entry, and find associated paths ==========
print(f"Reading and processing {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Extracts the disease name from the question string. Capture everything up to
# the first '?' (non-greedy) so names containing apostrophes, commas, parentheses
# or slashes (e.g. "Alzheimer's disease") are accepted; a [\w\s\-] character
# class rejected them. Compiled once because it is used for every entry.
pattern = re.compile(r'Which drug is used off-label for (.+?)\?')

# Per-protein sections printed under each protein: (label, protein_index -> node indices).
_protein_sections = (
    ("Pathways", protein_to_pathway),
    ("Biological Processes", protein_to_bp),
    ("Interacts With", ppi_relation),
)

for entry in data:
    # `or ''` guards against an explicit JSON null, which re.search would reject.
    input_text = entry.get('question') or ''
    match = pattern.search(input_text)

    if not match:
        entry['path'] = "Could not parse disease name from input."
        continue

    disease_name = match.group(1).strip().lower()

    # Find the disease's index (keys are lowercased disease names)
    disease_index = disease_name_to_index.get(disease_name)
    if disease_index is None:
        entry['path'] = f"Disease '{disease_name}' not found in nodes.csv."
        continue

    path_info_lines = [f"Disease: {index_to_name.get(disease_index, disease_name)}"]

    # --- Disease -> Protein -> (Pathways, BPs, other Proteins) ---
    path_info_lines.append("\n" + "=" * 10 + " Associated Proteins & Connections " + "=" * 10)
    protein_indices = disease_to_protein.get(disease_index, [])
    if not protein_indices:
        path_info_lines.append("  - No associated proteins found.")
    for protein_index in protein_indices:
        protein_name = index_to_name.get(protein_index, f"Unknown (Index: {protein_index})")
        path_info_lines.append(f"\n  -> Protein: {protein_name}")
        for label, neighbours_of in _protein_sections:
            neighbour_indices = neighbours_of.get(protein_index, [])
            if neighbour_indices:
                names = ", ".join(index_to_name.get(n, 'Unknown') for n in neighbour_indices)
                path_info_lines.append(f"     - {label}: {names}")
            else:
                path_info_lines.append(f"     - {label}: None found.")

    # --- Disease -> Drug ---
    path_info_lines.append("\n" + "=" * 10 + " Associated Drugs " + "=" * 10)
    drug_indices = disease_to_drug.get(disease_index, [])
    if drug_indices:
        for drug_index in drug_indices:
            drug_name = index_to_name.get(drug_index, f"Unknown (Index: {drug_index})")
            path_info_lines.append(f"  - {drug_name}")
    else:
        path_info_lines.append("  - No directly associated drugs found.")

    # Join all collected lines into a single string and add it to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== 4. Save the updated data to a new JSON file ==========
# ensure_ascii=False keeps non-ASCII node names (e.g. accented disease or drug
# names) readable in the 'path' text instead of escaping them to \uXXXX; this
# is safe because the file is opened with UTF-8 encoding.
print(f"Saving updated data to {output_json_path}...")
with open(output_json_path, 'w', encoding='utf-8') as f:
    json.dump(data, f, indent=4, ensure_ascii=False)

print("Processing complete.")