import importlib
import logging
import re

def manipulate_entity(context_dict: dict, entity_to_manipulate: str, perturb_strategy: str, doc_name: str):
    """
    Dispatches entity manipulation to the correct function based on the strategy.

    Args:
        context_dict: Context dictionary handed to the selected handler.
        entity_to_manipulate: Name of the entity to perturb.
        perturb_strategy: Either 'drop' or 'inject'.
        doc_name: Document name forwarded to the synonym injection; must not
            be None when the strategy is 'inject'.

    Returns:
        Whatever the selected handler returns.

    Raises:
        ValueError: If the strategy is not recognised, or if 'inject' is
            requested without a doc_name.
    """
    if perturb_strategy == 'drop':
        return drop_entity_context(context_dict, entity_to_manipulate)

    if perturb_strategy == 'inject':
        # Report the real problem instead of claiming 'inject' is unknown.
        if doc_name is None:
            raise ValueError("Strategy 'inject' requires a doc_name, but none was provided.")
        return inject_synonym_context(context_dict, entity_to_manipulate, doc_name)

    raise ValueError(f"Unknown strategy: {perturb_strategy}")


def manipulate_relation(context_dict: dict, relation_to_manipulate: dict, perturb_strategy: str, doc_name: str):
    """
    Route a relation perturbation to the handler for the requested strategy.

    'drop' is handled by drop_relation_context. 'inject' is handled by
    inject_relation_synonym_context and is only accepted when doc_name is
    not None. Every other combination raises ValueError.
    """
    wants_drop = perturb_strategy == 'drop'
    wants_inject = perturb_strategy == 'inject' and doc_name is not None

    # Reject anything that is neither a drop nor a fully specified inject.
    if not (wants_drop or wants_inject):
        raise ValueError(f"Unknown strategy for relation manipulation: {perturb_strategy} or missing doc_name.")

    if wants_drop:
        return drop_relation_context(context_dict, relation_to_manipulate)
    return inject_relation_synonym_context(context_dict, relation_to_manipulate, doc_name)


def drop_entity_context(context_dict: dict, entity_to_drop: str):
    """
    Removes all occurrences of a specific entity from the context.

    This includes the entity object, any relations involving it, and its name
    from text chunks and from the descriptions of the remaining entities and
    relations. Name removal is case-insensitive and matches whole names only
    (the name must not be directly adjacent to a word character). After each
    removal, runs of whitespace are collapsed to a single space and the text
    is stripped.

    Args:
        context_dict: Context with optional "entities_context",
            "relations_context" and "text_units_context" lists of dicts.
            It is modified in place.
        entity_to_drop: Name of the entity to remove.

    Returns:
        The same context_dict object, after modification.
    """
    logging.info(f"Dropping entity '{entity_to_drop}' from context.")

    # Lookarounds rather than \b: \b requires a word character on one side, so
    # names that start or end with punctuation (e.g. "C++", ".NET") would
    # never match as standalone tokens. For ordinary names the two forms are
    # equivalent. Compiled once and reused for every text field.
    name_pattern = re.compile(
        r'(?<!\w)' + re.escape(entity_to_drop) + r'(?!\w)',
        flags=re.IGNORECASE,
    )

    def scrub(text: str) -> str:
        """Delete the entity name from text and tidy the leftover whitespace."""
        return re.sub(r'\s\s+', ' ', name_pattern.sub('', text)).strip()

    # 1. Remove the entity itself and clean the descriptions of the others.
    remaining_entities = []
    for entity in context_dict.get("entities_context", []):
        if entity.get("entity") == entity_to_drop:
            continue
        if isinstance(entity.get("description"), str):
            entity["description"] = scrub(entity["description"])
        remaining_entities.append(entity)
    context_dict["entities_context"] = remaining_entities

    # 2. Remove relations touching the entity and clean the remaining ones.
    remaining_relations = []
    for relation in context_dict.get("relations_context", []):
        if entity_to_drop in (relation.get("entity1"), relation.get("entity2")):
            continue
        if isinstance(relation.get("description"), str):
            relation["description"] = scrub(relation["description"])
        remaining_relations.append(relation)
    context_dict["relations_context"] = remaining_relations

    # 3. Remove the name from the text chunks (text units are never dropped).
    for text_unit in context_dict.get("text_units_context", []):
        if isinstance(text_unit.get("content"), str):
            text_unit["content"] = scrub(text_unit["content"])

    return context_dict


def drop_relation_context(context_dict: dict, relation_to_drop: dict):
    """
    Delete every relation that links the two entities named in relation_to_drop.

    Direction is ignored: a relation is removed whether it is stored as A->B
    or as B->A. If either entity name is missing or empty, a warning is logged
    and the context is returned untouched. Otherwise "relations_context" is
    replaced with the list of surviving relations and the same dict is returned.
    """
    first = relation_to_drop.get("entity1")
    second = relation_to_drop.get("entity2")

    # Both endpoints are required to identify the relation.
    if not (first and second):
        logging.warning(f"Relation to drop is missing 'entity1' or 'entity2': {relation_to_drop}")
        return context_dict

    logging.info(f"Attempting to drop all relations between '{first}' and '{second}'.")

    kept = []
    removed = 0
    for relation in context_dict.get("relations_context", []):
        source = relation.get("entity1")
        target = relation.get("entity2")
        forward = source == first and target == second
        backward = source == second and target == first
        if not (forward or backward):
            kept.append(relation)
            continue
        logging.info(f"  - Dropping relation: {relation}")
        removed += 1

    if removed:
        logging.info(f"Dropped {removed} relation(s).")
    else:
        logging.warning(f"Could not find any relations to drop between '{first}' and '{second}'.")

    context_dict["relations_context"] = kept
    return context_dict


def inject_synonym_context(context_dict: dict, entity_to_inject: str, doc_name: str) -> dict:
    """
    Renames an entity to its mapped synonym in the structured context.

    The synonym is looked up in a dictionary named
    ``<DOC_NAME>_SYNONYMS`` (doc_name upper-cased, '-' replaced by '_') that
    is loaded from the ``synonym_map`` module. Only the "entity" field of
    entities and the "entity1"/"entity2" fields of relations are rewritten;
    descriptions and text units are left as they are.

    Args:
        context_dict: Context with optional "entities_context" and
            "relations_context" lists of dicts. It is modified in place.
        entity_to_inject: Entity name to replace.
        doc_name: Document name used to derive the synonym dictionary name.

    Returns:
        The same context_dict object. It is returned unchanged when the
        dictionary has no (non-empty) synonym for entity_to_inject.

    Raises:
        ImportError: If ``synonym_map`` cannot be imported or does not define
            the expected synonym dictionary.
    """
    # Built before the try block so the error message can always reference it,
    # even when the import itself is what fails.
    synonym_dict_name = f"{doc_name.upper().replace('-', '_')}_SYNONYMS"
    try:
        # Dynamically import the synonym map module and fetch the dictionary
        synonym_map_module = importlib.import_module("synonym_map")
        synonym_dict = getattr(synonym_map_module, synonym_dict_name)
    except (ImportError, AttributeError) as e:
        raise ImportError(
            f"Could not load synonym dictionary '{synonym_dict_name}' from synonym_map.py: {e}."
        ) from e

    new_name = synonym_dict.get(entity_to_inject)
    if not new_name:
        # If no synonym is found for this specific entity, do nothing.
        return context_dict

    logging.info(f"Injecting: '{entity_to_inject}' -> '{new_name}'")

    # Replace in entities_context
    for entity_info in context_dict.get("entities_context", []):
        if entity_info.get("entity") == entity_to_inject:
            entity_info["entity"] = new_name

    # Replace in relations_context (either endpoint may be the entity)
    for relation_info in context_dict.get("relations_context", []):
        for endpoint in ("entity1", "entity2"):
            if relation_info.get(endpoint) == entity_to_inject:
                relation_info[endpoint] = new_name

    return context_dict


def inject_relation_synonym_context(context_dict: dict, relation_to_inject: dict, doc_name: str) -> dict:
    """
    Replaces a relation's label with a synonym from a predefined map.

    The synonym is looked up in a dictionary named
    ``<DOC_NAME>_RELATION_SYNONYMS`` (doc_name upper-cased, '-' replaced by
    '_') loaded from the ``synonym_map`` module. Only the first relation in
    "relations_context" that compares equal to relation_to_inject is updated.

    This is best-effort: a missing label, an unavailable synonym dictionary,
    or a label without a synonym each log a warning and leave the context
    unchanged.

    Args:
        context_dict: Context with an optional "relations_context" list of
            dicts. It is modified in place.
        relation_to_inject: The relation whose "label" should be replaced.
        doc_name: Document name used to derive the synonym dictionary name.

    Returns:
        The same context_dict object.
    """
    relation_label = relation_to_inject.get("label")
    if not relation_label:
        logging.warning("Relation to inject has no 'label' field. Skipping.")
        return context_dict

    # Built before the try block so the warning below can always reference it,
    # even when the import itself is what fails.
    synonym_dict_name = f"{doc_name.upper().replace('-', '_')}_RELATION_SYNONYMS"
    try:
        synonym_map_module = importlib.import_module("synonym_map")
        synonym_dict = getattr(synonym_map_module, synonym_dict_name)
    except (ImportError, AttributeError):
        logging.warning(f"Could not load relation synonym dictionary '{synonym_dict_name}' from synonym_map.py. Skipping injection.")
        return context_dict

    new_label = synonym_dict.get(relation_label)
    if not new_label:
        logging.warning(f"No synonym found for relation label '{relation_label}' in '{synonym_dict_name}'. Skipping.")
        return context_dict

    logging.info(f"Injecting synonym for relation label: '{relation_label}' -> '{new_label}'")

    # Find the specific relation in the context and update its label
    for relation_info in context_dict.get("relations_context", []):
        if relation_info == relation_to_inject:
            relation_info["label"] = new_label
            break

    return context_dict