import json
import csv
import re
import os

# ========== File Paths ==========
# Inputs: the question set plus the node and triplet tables.
# Output: the question set with a 'path' field added to each entry.

json_path = './data/PrimeKGQA/contraindication.json'
nodes_path = './data/nodes.csv'
triplet_path = './data/triplet.csv'
output_json_path = './data/PrimeKGQA_path/Contraindication.json'

# Create the folder that will hold the output file; a no-op if it is
# already there.
output_dir = os.path.split(output_json_path)[0]
os.makedirs(output_dir, exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Builds three lookups from the node table:
#   index_to_name      node index -> node name
#   index_to_type      node index -> node type
#   drug_name_to_index lower-cased drug name -> node index (drug rows only)
print("Reading nodes.csv...")
drug_name_to_index = {}
index_to_name = {}
index_to_type = {}

with open(nodes_path, mode='r', newline='', encoding='utf-8') as f:
    for node_row in csv.DictReader(f):
        node_id = int(node_row['node_index'])
        name = node_row['node_name']
        kind = node_row['node_type']
        index_to_name[node_id] = name
        index_to_type[node_id] = kind
        if kind != 'drug':
            continue
        # Drug names are keyed in lower case; a later row with the same
        # lower-cased name overwrites the earlier one.
        drug_name_to_index[name.lower()] = node_id
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv and build relationship maps ==========
# Each map goes from a source node index to the set of target node indices
# reachable over one kind of edge. Only edges whose (x type, y type,
# display_relation) combination is listed in `edge_targets` are kept.
print("Reading triplet.csv...")
# Forward maps from drug/protein
drug_to_diseases = {}
drug_to_proteins = {}
drug_to_phenotypes = {}
protein_to_pathways = {}
protein_to_bps = {}

# Inverse maps from disease
disease_to_proteins = {}

# (x node type, y node type, display_relation code) -> map that records the
# edge. The combinations are mutually exclusive, so one lookup per row is
# enough.
edge_targets = {
    # drug -> disease (contraindication, relation 8)
    ('drug', 'disease', 8): drug_to_diseases,
    # drug -> protein (relations 10-target, 12-enzyme, 13-transporter)
    ('drug', 'gene/protein', 10): drug_to_proteins,
    ('drug', 'gene/protein', 12): drug_to_proteins,
    ('drug', 'gene/protein', 13): drug_to_proteins,
    # drug -> phenotype (side effect, relation 7)
    ('drug', 'effect/phenotype', 7): drug_to_phenotypes,
    # disease -> protein (relation 6) for inverse map
    ('disease', 'gene/protein', 6): disease_to_proteins,
    # protein -> pathway (relation 2)
    ('gene/protein', 'pathway', 2): protein_to_pathways,
    # protein -> bp (relation 2)
    ('gene/protein', 'biological_process', 2): protein_to_bps,
}

# Rows that cannot be parsed are skipped, but counted so that a schema
# mismatch (which would otherwise leave every map empty) is visible.
skipped_rows = 0

with open(triplet_path, mode='r', newline='', encoding='utf-8') as f:
    reader = csv.DictReader(f)
    for row in reader:
        try:
            display_relation = int(row['display_relation'])
            x_index = int(row['x_index'])
            y_index = int(row['y_index'])
        except (ValueError, KeyError, TypeError):
            # ValueError: non-numeric field; KeyError: column missing;
            # TypeError: short row, where DictReader fills the gap with None.
            skipped_rows += 1
            continue

        # Indices absent from nodes.csv yield None and match no table entry.
        x_type_str = index_to_type.get(x_index)
        y_type_str = index_to_type.get(y_index)

        target_map = edge_targets.get((x_type_str, y_type_str, display_relation))
        # Compare against None explicitly: a still-empty map is falsy.
        if target_map is not None:
            target_map.setdefault(x_index, set()).add(y_index)

if skipped_rows:
    print(f"Skipped {skipped_rows} malformed row(s) in triplet.csv.")
print("Finished reading triplet.csv.")

# ========== 3. Read JSON, process each entry, and find associated paths ==========
# For every question entry: parse the drug name, look the drug up, and store a
# human-readable summary in entry['path'] listing
#   - each contraindicated disease and the proteins it shares with the drug
#     (with the pathways / biological processes of each shared protein), and
#   - the drug's side effects.
# Entries whose question cannot be parsed, or whose drug is unknown, get a
# one-line explanation in entry['path'] instead.
print(f"Reading and processing {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# The drug name is everything between the fixed question prefix and the first
# '?' that follows it. A permissive capture is used because drug names may
# contain punctuation (commas, parentheses, apostrophes, ...) that a
# word/space/hyphen character class would reject. DOTALL keeps names that
# contain a line break matchable.
pattern = r'Which disease is contraindication for (.+?)\?'
question_re = re.compile(pattern, re.DOTALL)

for entry in data:
    input_text = entry.get('question', '')
    match = question_re.search(input_text)
    if not match:
        entry['path'] = "Could not parse drug name from input."
        continue

    # Drug names are indexed in lower case, so normalise before the lookup.
    drug_name = match.group(1).strip().lower()

    drug_index = drug_name_to_index.get(drug_name)
    if drug_index is None:
        entry['path'] = f"Drug '{drug_name}' not found in nodes.csv."
        continue

    path_info_lines = [f"Drug: {index_to_name.get(drug_index, drug_name)}"]

    # Get all primary connections for the drug for use in intersections
    direct_diseases = drug_to_diseases.get(drug_index, set())
    direct_proteins = drug_to_proteins.get(drug_index, set())
    direct_phenotypes = drug_to_phenotypes.get(drug_index, set())

    if not direct_diseases:
        path_info_lines.append("\n  - No contraindicated diseases found for this drug.")

    # Diseases and proteins are visited in ascending index order so the
    # generated text is deterministic.
    for disease_id in sorted(direct_diseases):
        disease_name = index_to_name.get(disease_id, f"Unknown Disease (ID: {disease_id})")
        path_info_lines.append(f"\n  -> Contraindicated Disease: {disease_name}")

        # Path 1: Find shared Proteins (Disease -> Protein <- Drug)
        proteins_from_disease = disease_to_proteins.get(disease_id, set())
        shared_proteins = direct_proteins.intersection(proteins_from_disease)
        path_info_lines.append("     - Shared Proteins (with Drug):")

        if not shared_proteins:
            path_info_lines.append("       - None found.")
            continue

        for prot_id in sorted(shared_proteins):
            prot_name = index_to_name.get(prot_id, f"Unknown Protein (ID: {prot_id})")
            path_info_lines.append(f"       -> Protein: {prot_name}")

            # Drilldown Path: Protein -> Pathway / BP
            pathways = protein_to_pathways.get(prot_id, set())
            bps = protein_to_bps.get(prot_id, set())

            if pathways:
                path_names = ", ".join(sorted(index_to_name.get(p, 'Unknown') for p in pathways))
                path_info_lines.append(f"          - Pathways: {path_names}")
            if bps:
                bp_names = ", ".join(sorted(index_to_name.get(b, 'Unknown') for b in bps))
                path_info_lines.append(f"          - Biological Processes: {bp_names}")
            if not pathways and not bps:
                path_info_lines.append("          - No associated pathways or BPs found.")

    # Path 2: List the drug's direct side effects
    path_info_lines.append("\n" + "=" * 10 + " Associated Side Effects (Phenotypes) " + "=" * 10)
    if direct_phenotypes:
        pheno_names = ", ".join(sorted(index_to_name.get(p, 'Unknown') for p in direct_phenotypes))
        path_info_lines.append(f"  - {pheno_names}")
    else:
        path_info_lines.append("  - No associated side effects found.")

    # Join all collected lines into a single string and add it to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== 4. Save the updated data to a new JSON file ==========
# Writes every entry (now carrying its 'path' field) to the output file as
# JSON indented by four spaces.
print(f"Saving updated data to {output_json_path}...")
with open(output_json_path, 'w', encoding='utf-8') as f:
    f.write(json.dumps(data, indent=4))

print("Processing complete.")