import json
import os
from collections import defaultdict
import yaml
from pathlib import Path

# Use sentence-transformers for semantic similarity, which is more robust than lexical similarity.
from .explainer.encoder import Encoder
from requests import Session

from scipy import spatial
from .utils import load_config

# Exempt the loopback address from any configured HTTP proxy. Both spellings
# are set because tools differ in which one they read.
# NOTE(review): presumably this is for a locally hosted embedding service -- confirm.
os.environ.update({'no_proxy': '127.0.0.1', 'NO_PROXY': '127.0.0.1'})

# Absolute path of the config.yaml that sits next to this module.
CONFIG_PATH = Path(__file__).resolve().with_name("config.yaml")


def find_connected_components(nodes, edges):
    """Group the nodes of an undirected graph into connected components using BFS.

    Args:
        nodes: Iterable of hashable node identifiers to start searches from.
        edges: Adjacency structure indexed by node (a dict, defaultdict, or
            list of neighbour lists). A node without an entry is treated as
            having no neighbours.

    Returns:
        A list of components. Each component is a list of nodes in BFS visit
        order, and components appear in the order their first member occurs
        in ``nodes``.
    """
    # Local import: the module's top-level imports only pull in defaultdict.
    from collections import deque

    visited = set()
    components = []
    for node in nodes:
        if node in visited:
            continue

        visited.add(node)
        component = []
        # deque gives O(1) pops from the left; list.pop(0) is O(n) per pop.
        queue = deque([node])
        while queue:
            curr = queue.popleft()
            component.append(curr)
            try:
                neighbors = edges[curr]
            except (KeyError, IndexError):
                # Isolated node with no adjacency entry.
                neighbors = ()
            for neighbor in neighbors:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        components.append(component)
    return components

def deduplicate_entities(data, similarity_threshold: int):
    """
    Deduplicates entities based on name similarity and updates relations.

    Entity names are embedded with ``Encoder(model_name="nomic-embed-text")``.
    Two entities of the same ``type`` are linked when the cosine similarity of
    their name embeddings, expressed as a percentage, is at least
    ``similarity_threshold``. Every connected group of linked entities is
    collapsed into one canonical entity, and relations are rewritten to use
    the canonical names.

    Args:
        data: Dict with an ``'entities_context'`` list and an optional
            ``'relations_context'`` list. The code reads the keys ``'entity'``,
            ``'type'``, ``'id'`` and ``'description'`` on entities and
            ``'entity1'`` / ``'entity2'`` on relations.
            NOTE(review): ``'id'`` is negated for tie-breaking, so it is
            assumed to be numeric -- confirm against the producer of ``data``.
        similarity_threshold: Minimum similarity on a 0-100 percentage scale
            for two entities to be considered duplicates.

    Returns:
        A ``(result, merge_logs)`` tuple. ``result`` is a dict with the keys
        ``'entities_context'`` and ``'relations_context'``; ``merge_logs`` is
        a list with one message per merged group. When there are no entities,
        ``data`` itself is returned as ``result`` with an empty log list.

    Note:
        Canonical entity dicts and relation dicts from ``data`` are updated in
        place and reused in the result.
    """
    merge_logs = []

    entities = data['entities_context']
    # Tolerate both a missing key and an explicit null for relations.
    relations = data.get('relations_context') or []

    if not entities:
        # Return the same (result, logs) shape as the normal path so callers
        # can always unpack two values.
        return data, merge_logs

    encoder = Encoder(model_name="nomic-embed-text")

    # --- 1. Generate embeddings for all entity names in a single batch ---
    entity_names = [e['entity'] for e in entities]
    embeddings = encoder.encode(entity_names)

    # --- 2. Group similar entities using cosine similarity ---
    entity_count = len(entities)
    adjacency_list = defaultdict(list)
    for i in range(entity_count):
        for j in range(i + 1, entity_count):
            # Never merge across types, e.g. person vs. location.
            if entities[i]['type'] != entities[j]['type']:
                continue

            # Cosine similarity is 1 - cosine distance and lies in [-1, 1].
            similarity_score = 1 - spatial.distance.cosine(embeddings[i], embeddings[j])

            if (similarity_score * 100) >= similarity_threshold:
                print(f"Found potential match: '{entities[i]['entity']}' and '{entities[j]['entity']}' (Similarity: {similarity_score:.2%})")
                adjacency_list[i].append(j)
                adjacency_list[j].append(i)

    # Similarity is treated as transitive: A~B and B~C puts A, B, C in one group.
    merge_groups = find_connected_components(range(entity_count), adjacency_list)

    # --- 3. Merge entities and create mappings ---
    name_map = {}  # Maps old entity name -> new canonical name
    final_entities = []

    # Number of relation endpoints per entity name, used to pick the most
    # connected member of each group as its canonical entity.
    entity_degrees = defaultdict(int)
    for rel in relations:
        if rel.get('entity1'):
            entity_degrees[rel['entity1']] += 1
        if rel.get('entity2'):
            entity_degrees[rel['entity2']] += 1

    # Components are disjoint, so each index is visited exactly once here.
    for group in merge_groups:
        if len(group) == 1:
            # Unique entity: keep it as-is.
            final_entities.append(entities[group[0]])
            continue

        group_entities = [entities[i] for i in group]
        # Highest degree wins; ties go to the lowest ID.
        canonical_entity = max(
            group_entities,
            key=lambda e: (entity_degrees.get(e['entity'], 0), -e['id']),
        )

        log_message = f"Merging group: {[e['entity'] for e in group_entities]} -> Canonical entity: '{canonical_entity['entity']}' (ID: {canonical_entity['id']})"
        print(f"\nMerging group: {[e['entity'] for e in group_entities]}")
        print(f"  -> Canonical entity: '{canonical_entity['entity']}' (ID: {canonical_entity['id']})")
        merge_logs.append(log_message)

        # Drop empty fragments so an empty canonical description does not
        # leave a stray leading '<SEP>' in the merged text.
        all_descriptions = {
            desc for desc in canonical_entity['description'].split('<SEP>') if desc
        }

        for entity_to_merge in group_entities:
            if entity_to_merge['id'] == canonical_entity['id']:
                continue
            # Fold in the merged entity's descriptions.
            all_descriptions.update(
                desc for desc in entity_to_merge['description'].split('<SEP>') if desc
            )
            # Map the old name to the new canonical name.
            name_map[entity_to_merge['entity']] = canonical_entity['entity']

        # Sorted for a deterministic description order.
        canonical_entity['description'] = '<SEP>'.join(sorted(all_descriptions))
        final_entities.append(canonical_entity)

    # --- 4. Update relations by merging entity names, keeping all original relations ---
    updated_relations = []

    for rel in relations:
        # Update entity names in the relation if they were merged.
        entity1 = name_map.get(rel['entity1'], rel['entity1'])
        entity2 = name_map.get(rel['entity2'], rel['entity2'])

        # Skip self-referential relations, such as those created by the merge.
        if entity1 == entity2:
            continue

        # Keep the original relation object but with updated entity names.
        rel['entity1'] = entity1
        rel['entity2'] = entity2
        updated_relations.append(rel)

    return {'entities_context': final_entities, 'relations_context': updated_relations}, merge_logs


if __name__ == "__main__":
    # Standalone test run: read the deduplication settings from config.yaml,
    # deduplicate the configured input JSON file, and write the result next to
    # it with a "_deduplicated" suffix.
    if not CONFIG_PATH.exists():
        print(f"Error: Config file not found at {CONFIG_PATH}")
    else:
        config = load_config(CONFIG_PATH)
        dedup_config = config.get('deduplication', {})
        similarity_threshold = dedup_config.get('similarity_threshold', 70)
        input_filename = dedup_config.get('standalone_test_input_file')

        if not input_filename or not os.path.exists(input_filename):
            print("Error: Input file not found. Please specify 'standalone_test_input_file' in config.yaml under 'deduplication'.")
        else:
            with open(input_filename, 'r', encoding='utf-8') as f:
                original_data = json.load(f)

            print("--- Starting Entity Deduplication ---")
            deduplicated_data, merge_logs = deduplicate_entities(original_data, similarity_threshold)

            # --- 5. Save the new JSON file ---
            # Build the output name from the file stem rather than
            # str.replace('.json', ...): replace rewrites every '.json' in
            # the path (directories included) and, when the name has no
            # '.json', leaves it unchanged so the input file is overwritten.
            input_path = Path(input_filename)
            output_filename = str(input_path.with_name(f"{input_path.stem}_deduplicated.json"))
            with open(output_filename, 'w', encoding='utf-8') as f:
                json.dump(deduplicated_data, f, indent=4)

            print("\n--- Deduplication Complete ---")
            print("Merge Logs:")
            for log in merge_logs:
                print(f"- {log}")
            print(f"Original entity count: {len(original_data.get('entities_context', []))}")
            print(f"Deduplicated entity count: {len(deduplicated_data.get('entities_context', []))}")
            print(f"Original relation count: {len(original_data.get('relations_context', []))}")
            print(f"Deduplicated relation count: {len(deduplicated_data.get('relations_context', []))}")
            print(f"\n✅ Successfully saved deduplicated data to: {output_filename}")