import csv
import json
import os
import re

# ========== File Paths ==========
# Inputs: the QA entries to annotate and the two knowledge-graph CSV files.
json_path = './data/PrimeKGQA/Indication.json'
nodes_path = './data/nodes.csv'
triplet_path = './data/triplet.csv'
# Output: where the annotated entries are written at the end of the script.
output_json_path = './data/PrimeKGQA_path/Indication.json'

# Ensure the output directory exists before any work is done.
# This needs the `os` module, which must be imported at the top of the file.
# The call is skipped when the output path has no directory component,
# because os.makedirs('') raises FileNotFoundError.
output_dir = os.path.dirname(output_json_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# ========== 1. Read nodes.csv ==========
# Lookup tables filled from the node list:
#   drug_name_to_index: lowercased drug name -> node index
#   index_to_name:      node index -> node name
#   index_to_type:      node index -> node type
drug_name_to_index = {}
index_to_name = {}
index_to_type = {}

print("Reading nodes.csv...")
# newline='' lets the csv module handle line endings itself, so that line
# breaks inside quoted fields are parsed as part of the field.
with open(nodes_path, mode='r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        index = int(row['node_index'])
        node_name = row['node_name']
        node_type = row['node_type']
        index_to_name[index] = node_name
        index_to_type[index] = node_type
        # Drug names are keyed in lowercase so later lookups are
        # case-insensitive. If two drug rows differ only by case, the
        # later row overwrites the earlier one.
        if node_type == 'drug':
            drug_name_to_index[node_name.lower()] = index
print("Finished reading nodes.csv.")

# ========== 2. Read triplet.csv ==========
# Adjacency lists keyed by the x_index of each edge. The relation and type
# codes below are the ones the original comments assign to each edge kind.
target_relations = {}  # drug_index -> list of target indices
pathway_relations = {}  # target_index -> list of pathway indices
pathway_parent_map = {}  # child_pathway_index -> list of parent_pathway_indices
target_disease_relations = {}  # target_index -> list of disease indices

print("Reading triplet.csv...")
# newline='' lets the csv module handle line endings itself, as its
# documentation recommends for any file passed to a reader.
with open(triplet_path, mode='r', encoding='utf-8', newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        display_relation = int(row['display_relation'])
        x_index = int(row['x_index'])
        y_index = int(row['y_index'])
        y_type = int(row['y_type'])

        # A row carries exactly one relation code, so the cases are
        # mutually exclusive and can be chained.
        if display_relation == 10:
            # drug -> target relationship (relation type 10)
            target_relations.setdefault(x_index, []).append(y_index)
        elif display_relation == 2 and y_type == 8:
            # target -> pathway relationship (relation type 2, y is a pathway)
            pathway_relations.setdefault(x_index, []).append(y_index)
        elif display_relation == 5 and y_type == 8:
            # pathway -> parent pathway relationship (relation type 5, y is a pathway)
            pathway_parent_map.setdefault(x_index, []).append(y_index)
        elif display_relation == 6 and y_type == 2:
            # target -> disease relationship (relation type 6, y is a disease)
            target_disease_relations.setdefault(x_index, []).append(y_index)
print("Finished reading triplet.csv.")

# ========== 3. Read JSON and process each entry, and find associated paths  ==========
# Load every QA entry into memory; the entries are annotated in place
# further down and written out at the end of the script.
print(f"Reading and processing {json_path}...")
with open(json_path, encoding='utf-8') as json_file:
    data = json.load(json_file)

# This regex extracts the drug name from the question string.
# The capture accepts every character up to the first '?', so drug names
# containing commas, parentheses, dots, apostrophes or slashes are parsed
# as well as plain word/space/hyphen names.
pattern = r'Which disease can be treated with ([^?]+)\?'
question_regex = re.compile(pattern)  # compiled once, reused for every entry


def _node_label(node_index):
    """Return the node's name, or an 'Unknown (index=...)' placeholder."""
    return index_to_name.get(node_index, f"Unknown (index={node_index})")


# Each entry receives a 'path' string: an indented, human-readable outline
# of drug -> targets -> (pathways -> parent pathways, related diseases).
for entry in data:
    # A missing or null question is treated as an empty string so it lands
    # in the "could not parse" branch instead of raising inside the regex.
    input_text = entry.get('question') or ''
    match = question_regex.search(input_text)

    if not match:
        entry['path'] = "Could not parse drug name from question."
        continue

    # Lookup keys are lowercase, matching how drug_name_to_index was built.
    drug_name = match.group(1).strip().lower()
    drug_index = drug_name_to_index.get(drug_name)

    if drug_index is None:
        # Unknown drug: echo the parsed (lowercased) name and stop here.
        entry['path'] = "\n".join([
            f"Drug: {drug_name}",
            "  - Drug index not found in nodes.csv.",
        ])
        continue

    # Show the name as it is stored in the node table.
    path_info_lines = [f"Drug: {index_to_name.get(drug_index, drug_name)}"]

    # Find associated targets (relation == 10)
    target_indices = target_relations.get(drug_index, [])
    if not target_indices:
        path_info_lines.append("  - No associated targets found.")
        entry['path'] = "\n".join(path_info_lines)
        continue

    path_info_lines.append("  Targets:")
    for target_index in target_indices:
        path_info_lines.append(f"    - {_node_label(target_index)}")

        # For each target, list its pathways and each pathway's parents.
        pathways = pathway_relations.get(target_index, [])
        if pathways:
            path_info_lines.append("      Pathways:")
            for path_index in pathways:
                path_info_lines.append(f"        - {_node_label(path_index)}")

                # Find parent pathways (relation == 5)
                parent_indices = pathway_parent_map.get(path_index, [])
                if parent_indices:
                    path_info_lines.append("          Parent Pathways:")
                    path_info_lines.extend(
                        f"            - {_node_label(parent_index)}"
                        for parent_index in parent_indices
                    )
        else:
            path_info_lines.append("      - No associated pathways found for this target.")

        # Find diseases related to the target (relation == 6)
        diseases = target_disease_relations.get(target_index, [])
        if diseases:
            path_info_lines.append("      Related Diseases:")
            path_info_lines.extend(
                f"        - {_node_label(disease_index)}"
                for disease_index in diseases
            )
        else:
            path_info_lines.append("      - No related diseases found for this target.")

    # Join all collected lines into a single string and add it to the entry
    entry['path'] = "\n".join(path_info_lines)

# ========== 4. Save the updated data to a new JSON file ==========
# Every entry now carries its 'path' field; write the whole list back out
# with 4-space indentation.
print(f"Saving updated data to {output_json_path}...")
with open(output_json_path, mode='w', encoding='utf-8') as out_file:
    json.dump(data, out_file, indent=4)

print("Processing complete.")