import json
import csv
import re
import os

# ========== File Paths ==========
# Inputs read by this script.
json_path = './data/PrimeKGQA/Bioprocess.json'
nodes_path = './data/nodes.csv'
triplet_path = './data/triplet.csv'

# Output written by this script.
output_json_path = './data/PrimeKGQA_path/Bioprocess.json'

# Create the folder that will hold the output file up front; it is not an
# error if the folder is already there.
os.makedirs(os.path.split(output_json_path)[0], exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Lookup tables filled from nodes.csv:
#   protein_name_to_index: lowercased node_name -> node_index
#                          (only rows whose node_type is 'gene/protein')
#   index_to_name:         node_index -> node_name (every row)
#   index_to_type:         node_index -> node_type (every row)
protein_name_to_index = {}
index_to_name = {}
index_to_type = {}

print("Reading nodes.csv...")
# newline='' hands line-ending handling to the csv module, as its
# documentation requires; without it, quoted fields that contain line breaks
# can be split incorrectly.
with open(nodes_path, mode='r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        index = int(row['node_index'])
        node_name = row['node_name']
        node_type = row['node_type']
        index_to_name[index] = node_name
        index_to_type[index] = node_type
        # Protein names are keyed in lowercase so lookups are
        # case-insensitive. NOTE: if two gene/protein rows differ only by
        # letter case, the later row silently replaces the earlier one.
        if node_type == 'gene/protein':
            protein_name_to_index[node_name.lower()] = index
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv ==========
# Relationship tables filled from triplet.csv. Values are sets so that the
# later intersection with a protein's interaction partners is cheap.
ppi_relations = {}  # protein_index -> {set of interacting protein_indices}
pathway_relations = {}  # protein_index -> {set of pathway_indices}
bp_relations = {}  # protein_index -> {set of bioprocess_indices}
bp_to_proteins = {}  # bp_index -> {set of protein_indices}
pathway_to_proteins = {}  # pathway_index -> {set of protein_indices}

print("Reading triplet.csv...")
# newline='' hands line-ending handling to the csv module, as its
# documentation requires for files passed to csv readers.
with open(triplet_path, mode='r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        # All four columns are parsed for every row, so a malformed row
        # fails loudly no matter which relation it describes.
        display_relation = int(row['display_relation'])
        x_index = int(row['x_index'])
        y_index = int(row['y_index'])
        y_type = int(row['y_type'])

        if display_relation == 3:
            # Protein-Protein Interaction (ppi) relationship (relation type 3).
            # Only the x -> y direction is recorded here.
            # NOTE(review): this presumably relies on triplet.csv listing each
            # interaction in both directions -- confirm against the data.
            ppi_relations.setdefault(x_index, set()).add(y_index)
        elif display_relation == 2:
            if y_type == 8:
                # Protein -> pathway relationship (relation type 2, y is a pathway)
                pathway_relations.setdefault(x_index, set()).add(y_index)
                pathway_to_proteins.setdefault(y_index, set()).add(x_index)
            elif y_type == 0:
                # Protein -> biological process (bp) relationship
                # (relation type 2, y is a bp)
                bp_relations.setdefault(x_index, set()).add(y_index)
                bp_to_proteins.setdefault(y_index, set()).add(x_index)
print("Finished reading triplet.csv.")

# ========== 3. Read JSON, process each entry, and find associated paths ==========
print(f"Reading and processing {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# This regex extracts the protein name from the question string. The name may
# consist of word characters, whitespace and hyphens only. It is compiled once
# here because it is applied to every entry.
pattern = r'Which biological process is associated with ([\w\s\-]+)\?'
question_regex = re.compile(pattern)


def _shared_protein_lines(shared_proteins, names, none_message):
    """Return the report lines listing *shared_proteins* by name.

    shared_proteins: set of node indices to list (may be empty).
    names: mapping node index -> node name, used to resolve each index.
    none_message: the single line returned when the set is empty.
    """
    if not shared_proteins:
        return [none_message]
    lines = ["     - Shared Interacting Proteins:"]
    for p_idx in sorted(shared_proteins):
        p_name = names.get(p_idx, f"Unknown Protein (index={p_idx})")
        lines.append(f"       - {p_name}")
    return lines


def _association_section_lines(header, label, empty_message, no_shared_message,
                               item_indices, members_by_item,
                               interacting_proteins, names):
    """Return the report lines for one section (biological processes or pathways).

    header: banner line that opens the section.
    label: short item label used in the "-> <label>: <name>" lines and in the
        "Unknown <label>" fallback name.
    empty_message: line emitted when *item_indices* is empty.
    no_shared_message: line emitted under an item that none of the
        interacting proteins belongs to.
    item_indices: set of node indices directly linked to the queried protein.
    members_by_item: mapping item index -> set of protein indices linked to it.
    interacting_proteins: set of protein indices interacting with the queried
        protein.
    names: mapping node index -> node name.
    """
    lines = [header]
    if not item_indices:
        lines.append(empty_message)
        return lines
    # Items are visited in ascending index order so the report is deterministic.
    for item_index in sorted(item_indices):
        item_name = names.get(item_index, f"Unknown {label} (index={item_index})")
        lines.append(f"\n  -> {label}: {item_name}")
        # Interaction partners of the queried protein that are linked to
        # this same item.
        shared = interacting_proteins.intersection(
            members_by_item.get(item_index, set())
        )
        lines.extend(_shared_protein_lines(shared, names, no_shared_message))
    return lines


for entry in data:
    input_text = entry.get('question', '')
    # A missing or non-string question (e.g. JSON null) is treated like an
    # unparseable one instead of crashing the whole run.
    match = question_regex.search(input_text) if isinstance(input_text, str) else None

    if not match:
        entry['path'] = "Could not parse protein name from input."
        continue

    # Lowercased to match the keys of protein_name_to_index.
    protein_name = match.group(1).strip().lower()

    # Find the protein's index
    protein_index = protein_name_to_index.get(protein_name)
    if protein_index is None:
        entry['path'] = f"Protein '{protein_name}' not found in nodes.csv."
        continue

    # Get all proteins that interact with our main protein
    interacting_proteins = ppi_relations.get(protein_index, set())

    # Lines of the report for the current entry
    path_info_lines = [f"Protein: {index_to_name.get(protein_index, protein_name)}"]

    # --- BIOLOGICAL PROCESSES ---
    path_info_lines.extend(_association_section_lines(
        "\n" + "=" * 10 + " Biological Processes " + "=" * 10,
        "BP",
        "  - This protein is not directly associated with any biological processes.",
        "     - No interacting proteins found for this BP.",
        bp_relations.get(protein_index, set()),
        bp_to_proteins,
        interacting_proteins,
        index_to_name,
    ))

    # --- PATHWAYS ---
    path_info_lines.extend(_association_section_lines(
        "\n" + "=" * 15 + " Pathways " + "=" * 15,
        "Pathway",
        "  - This protein is not directly associated with any pathways.",
        "     - No interacting proteins found for this pathway.",
        pathway_relations.get(protein_index, set()),
        pathway_to_proteins,
        interacting_proteins,
        index_to_name,
    ))

    # Join all collected lines into a single string and add it to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== 4. Save the updated data to a new JSON file ==========
print(f"Saving updated data to {output_json_path}...")
with open(output_json_path, mode='w', encoding='utf-8') as out_file:
    # Serialize every entry (now carrying its 'path' field) with a 4-space
    # indent and write the text out in one go.
    out_file.write(json.dumps(data, indent=4))

print("Processing complete.")