import json
import csv
import re
import os

# ========== File Paths ==========
# All locations are relative to the directory the script is launched from.

# Inputs: the question set plus the two CSV exports of the knowledge graph.
json_path = './data/PrimeKGQA/DDI.json'
nodes_path = './data/nodes.csv'
triplet_path = './data/triplet.csv'

# Output: the question set again, with a 'path' field added to each entry.
output_json_path = './data/PrimeKGQA_path/DDI.json'

# Create the destination folder up front so the final write cannot fail on a
# missing directory after all the processing has already been done.
output_dir = os.path.dirname(output_json_path)
os.makedirs(output_dir, exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Builds three lookups from the node table:
#   index_to_name      node index -> node name
#   index_to_type      node index -> node type
#   drug_name_to_index lower-cased drug name -> node index (drug nodes only)
print("Reading nodes.csv...")
drug_name_to_index = {}
index_to_name = {}
index_to_type = {}

with open(nodes_path, mode='r', newline='', encoding='utf-8') as nodes_file:
    for record in csv.DictReader(nodes_file):
        node_id = int(record['node_index'])
        name, kind = record['node_name'], record['node_type']
        index_to_name[node_id] = name
        index_to_type[node_id] = kind
        if kind != 'drug':
            continue
        # Keys are lower-cased so question text can be matched case-insensitively;
        # if two drugs share a lower-cased name, the later row wins.
        drug_name_to_index[name.lower()] = node_id
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv and build relationship maps ==========
# Every map below is keyed by a node index and holds a set of neighbour node
# indices. Node types are looked up in index_to_type; an endpoint that is not
# present there has type None and therefore matches none of the rules.
print("Reading triplet.csv...")
drug_to_proteins = {}  # drug -> {proteins}
drug_to_phenotypes = {}  # drug -> {phenotypes} (side effects)
protein_to_pathways = {}  # protein -> {pathways}
protein_to_bps = {}  # protein -> {biological processes}
ddi_relations = {}  # drug -> {interacting drugs}

# drug -> protein relation codes (10-target, 12-enzyme, 13-transporter)
drug_protein_relations = {10, 12, 13}
skipped_rows = 0  # rows whose relation/index fields could not be parsed

with open(triplet_path, mode='r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        # Only the conversions can fail, so only they live inside the try.
        try:
            display_relation = int(row['display_relation'])
            x_index = int(row['x_index'])
            y_index = int(row['y_index'])
        except (ValueError, KeyError, TypeError):
            # ValueError: non-numeric field; KeyError: column missing;
            # TypeError: csv.DictReader pads a short row with None and
            # int(None) raises TypeError. Skip the row, but keep count so the
            # problem is visible instead of silently yielding empty maps.
            skipped_rows += 1
            continue

        x_type_str = index_to_type.get(x_index)
        y_type_str = index_to_type.get(y_index)

        if x_type_str == 'drug':
            if y_type_str == 'gene/protein' and display_relation in drug_protein_relations:
                # drug -> protein
                drug_to_proteins.setdefault(x_index, set()).add(y_index)
            elif y_type_str == 'effect/phenotype' and display_relation == 7:
                # drug -> phenotype (side effect, relation 7)
                drug_to_phenotypes.setdefault(x_index, set()).add(y_index)
            elif y_type_str == 'drug' and display_relation == 1:
                # drug -> drug (synergistic interaction, relation 1);
                # stored in both directions so lookups work from either drug.
                ddi_relations.setdefault(x_index, set()).add(y_index)
                ddi_relations.setdefault(y_index, set()).add(x_index)
        elif x_type_str == 'gene/protein' and display_relation == 2:
            if y_type_str == 'pathway':
                # protein -> pathway (relation 2)
                protein_to_pathways.setdefault(x_index, set()).add(y_index)
            elif y_type_str == 'biological_process':
                # protein -> bp (relation 2)
                protein_to_bps.setdefault(x_index, set()).add(y_index)
print("Finished reading triplet.csv.")
if skipped_rows:
    print(f"Warning: skipped {skipped_rows} malformed row(s) in triplet.csv.")

# ========== 3. Read JSON, process each entry, and find DDI paths ==========
# For every entry, the drug named in the question ("Drug 1") is looked up and a
# multi-line text summary is stored in entry['path']. The summary lists Drug 1's
# associated proteins and, for each interacting drug ("Drug 2"), its proteins,
# the pathways and biological processes reachable from both drugs' proteins,
# and Drug 2's side effects. Entries that cannot be resolved get a one-line
# explanation in entry['path'] instead.
print(f"Reading and processing {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# The drug name is everything between the fixed question prefix and the
# question mark. A narrower class such as [\w\s\-] rejects names containing
# commas, parentheses, apostrophes, periods, '+' or '/', so any character
# except '?' is accepted here.
pattern = r'Which drug has a drug drug interaction with ([^?]+)\?'
question_re = re.compile(pattern)  # compiled once, reused for every entry


def _format_names(indices):
    """Return the sorted, comma-separated node names, or 'None found.' if empty."""
    if not indices:
        return "None found."
    return ", ".join(sorted(index_to_name.get(i, 'Unknown') for i in indices))


def _reachable(protein_map, proteins):
    """Return the union of protein_map's sets over every protein in *proteins*."""
    return set().union(*(protein_map.get(p_id, set()) for p_id in proteins))


for entry in data:
    # `or ''` also covers an explicit null "input", which re cannot search.
    input_text = entry.get('input') or ''
    match = question_re.search(input_text)
    if not match:
        entry['path'] = "Could not parse drug name from input."
        continue

    # Lower-cased to match the keys of drug_name_to_index.
    drug1_name = match.group(1).strip().lower()

    drug1_index = drug_name_to_index.get(drug1_name)
    if drug1_index is None:
        entry['path'] = f"Drug '{drug1_name}' not found in nodes.csv."
        continue

    # --- Start building path for Drug 1 ---
    path_info_lines = [f"Drug 1: {index_to_name.get(drug1_index, drug1_name)}"]

    # Proteins of Drug 1 and everything reachable from them
    proteins1 = drug_to_proteins.get(drug1_index, set())
    path_info_lines.append(
        f"  - Associated Proteins (Protein Set 1): {_format_names(proteins1)}"
    )
    pathways1 = _reachable(protein_to_pathways, proteins1)
    bps1 = _reachable(protein_to_bps, proteins1)

    # Find all interacting drugs (Drug 2)
    interacting_drugs = ddi_relations.get(drug1_index, set())
    if not interacting_drugs:
        path_info_lines.append("\n  - No drug-drug interactions found.")

    # Sorted by node index so the output is deterministic across runs.
    for drug2_index in sorted(interacting_drugs):
        drug2_name = index_to_name.get(drug2_index, f"Unknown (ID: {drug2_index})")
        path_info_lines.append(f"\n  -> Interacting Drug (Drug 2): {drug2_name}")

        # Proteins of Drug 2 and everything reachable from them
        proteins2 = drug_to_proteins.get(drug2_index, set())
        path_info_lines.append(
            f"     - Associated Proteins (Protein Set 2): {_format_names(proteins2)}"
        )
        pathways2 = _reachable(protein_to_pathways, proteins2)
        bps2 = _reachable(protein_to_bps, proteins2)

        # Functions shared by both drugs, plus Drug 2's own side effects
        shared_pathways = pathways1 & pathways2
        shared_bps = bps1 & bps2
        effects2 = drug_to_phenotypes.get(drug2_index, set())

        path_info_lines.append(f"     - Shared Pathways: {_format_names(shared_pathways)}")
        path_info_lines.append(
            f"     - Shared Biological Processes: {_format_names(shared_bps)}"
        )
        path_info_lines.append(
            f"     - Side Effects of {drug2_name}: {_format_names(effects2)}"
        )

    # Join all lines and save to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== Save the updated data to a new JSON file ==========
# The whole structure is serialised with a 4-space indent and written in one go.
print(f"Saving updated data to {output_json_path}...")
with open(output_json_path, 'w', encoding='utf-8') as out_file:
    serialized = json.dumps(data, indent=4)
    out_file.write(serialized)

print("Processing complete.")