import json
from collections import defaultdict, Counter
import json


def augment_dataset_with_synonyms(train_file, original_terms_file, synonym_mapping_file, output_file):
    """
    Augment a pairs dataset by substituting the 'child' and 'parent' labels with their synonyms.

    Each field is expanded independently: for every entry and for each of the keys
    ``"child"`` and ``"parent"``, one copy of the entry is emitted per known surface form
    of that field's ID (synonyms from the mapping file in file order, followed by the
    original term), with only that field replaced. A field whose label is found in neither
    mapping triggers a warning and contributes no rows.

    Args:
        train_file (str): Path to the input dataset (JSONL, one object per line).
        original_terms_file (str): Path to a JSON object mapping ID -> original term.
        synonym_mapping_file (str): Path to a JSON object mapping term -> ID.
        output_file (str): Path to write the augmented dataset (JSONL).

    Returns:
        None
    """
    import json
    from collections import defaultdict

    # Load original terms and synonym mappings
    with open(original_terms_file, "r", encoding="utf-8") as f:
        id_to_original = json.load(f)
    with open(synonym_mapping_file, "r", encoding="utf-8") as f:
        term_to_id = json.load(f)

    # Map every known label (original terms, then synonyms which take precedence) to its ID
    label_to_id = {v: k for k, v in id_to_original.items()}
    label_to_id.update(term_to_id)

    # Precompute ID -> synonyms once instead of rescanning term_to_id for every label.
    id_to_synonyms = defaultdict(list)
    for term, entity_id in term_to_id.items():
        id_to_synonyms[entity_id].append(term)

    # Load the train file; blank lines (e.g. a trailing empty line) are ignored.
    with open(train_file, "r", encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f if line.strip()]

    augmented_data = []
    for entry in train_data:
        for key in ("child", "parent"):  # Process both 'child' and 'parent' fields
            label = entry[key]
            if label not in label_to_id:
                print(f"Warning: Label '{label}' not found in mappings. Skipping.")
                continue

            entity_id = label_to_id[label]
            # Synonyms first, then the original term (or the label itself if the ID has
            # no original term). dict.fromkeys drops repeats -- e.g. an original term that
            # is also listed in the synonym mapping -- while preserving order.
            candidates = id_to_synonyms.get(entity_id, []) + [id_to_original.get(entity_id, label)]
            synonyms = list(dict.fromkeys(candidates))

            # One augmented entry per synonym, with only this field replaced
            for synonym in synonyms:
                augmented_entry = entry.copy()
                augmented_entry[key] = synonym
                augmented_data.append(augmented_entry)

    # Save the augmented dataset
    with open(output_file, "w", encoding="utf-8") as f:
        for entry in augmented_data:
            f.write(json.dumps(entry) + "\n")

    print(f"Augmented dataset saved to {output_file}")


def generate_exhaustive_dataset(train_file, original_terms_file, synonym_mapping_file, output_file):
    """
    Generate an exhaustive dataset by replacing all terms with all their synonyms.

    For each input row, every combination of (child surface form, parent surface form)
    is emitted, carrying over the row's ``label``. A label found in neither mapping is
    kept as its only surface form. Each surface form appears once per ID, even if the
    original term is also listed in the synonym mapping, so no duplicate combinations
    are produced from a single row.

    Args:
        train_file (str): Path to the input train dataset (JSONL format).
        original_terms_file (str): Path to the original terms JSON file (ID -> term).
        synonym_mapping_file (str): Path to the synonym mappings JSON file (term -> ID).
        output_file (str): Path to save the exhaustive dataset.

    Returns:
        None
    """
    import json
    from itertools import product
    from collections import defaultdict

    # Load original terms and synonym mappings
    with open(original_terms_file, "r", encoding="utf-8") as f:
        id_to_original = json.load(f)
    with open(synonym_mapping_file, "r", encoding="utf-8") as f:
        term_to_id = json.load(f)

    # Create a mapping from labels to IDs (synonym entries override original terms)
    label_to_id = {v: k for k, v in id_to_original.items()}
    label_to_id.update(term_to_id)

    # Inverse mapping: ID -> all surface forms (synonyms first, then the original term).
    # A dict serves as an insertion-ordered set so a term listed both as a synonym and
    # as the original term is not repeated (which would duplicate output rows).
    id_to_synonyms = defaultdict(dict)
    for term, entity_id in term_to_id.items():
        id_to_synonyms[entity_id][term] = None
    for entity_id, original_term in id_to_original.items():
        id_to_synonyms[entity_id][original_term] = None

    def synonyms_of(label):
        """Return all surface forms sharing label's ID, or [label] if the label is unmapped."""
        if label in label_to_id:
            return list(id_to_synonyms[label_to_id[label]])
        return [label]

    # Load the train file; blank lines are ignored.
    with open(train_file, "r", encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f if line.strip()]

    exhaustive_data = []
    for entry in train_data:
        # Create all combinations of synonyms (child varies slowest)
        for child_syn, parent_syn in product(synonyms_of(entry["child"]), synonyms_of(entry["parent"])):
            exhaustive_data.append(
                {
                    "child": child_syn,
                    "parent": parent_syn,
                    "label": entry["label"],
                }
            )

    # Save the exhaustive dataset
    with open(output_file, "w", encoding="utf-8") as f:
        for entry in exhaustive_data:
            f.write(json.dumps(entry) + "\n")

    print(f"Exhaustive dataset saved to {output_file}")


# --- For Pairs dataset ---
def generate_exhaustive_dataset_with_flag(train_file, original_terms_file, synonym_mapping_file, output_file):
    """
    Generate an exhaustive dataset by replacing all terms with all their synonyms and add a 'synonym' flag.

    For each input row, every combination of (child surface form, parent surface form)
    is emitted with the row's ``label``. ``synonym`` is 0 for the combination identical to
    the input row and 1 for every other combination. A label found in neither mapping is
    kept as its only surface form. Each surface form appears once per ID, so the flag-0
    row is emitted exactly once per input row.

    Args:
        train_file (str): Path to the input train dataset (JSONL format).
        original_terms_file (str): Path to the original terms JSON file (ID -> term).
        synonym_mapping_file (str): Path to the synonym mappings JSON file (term -> ID).
        output_file (str): Path to save the exhaustive dataset.

    Returns:
        None
    """
    import json
    from itertools import product
    from collections import defaultdict

    # Load original terms and synonym mappings
    with open(original_terms_file, "r", encoding="utf-8") as f:
        id_to_original = json.load(f)
    with open(synonym_mapping_file, "r", encoding="utf-8") as f:
        term_to_id = json.load(f)

    # Create a mapping from labels to IDs (synonym entries override original terms)
    label_to_id = {v: k for k, v in id_to_original.items()}
    label_to_id.update(term_to_id)

    # Inverse mapping: ID -> all surface forms (synonyms first, then the original term).
    # A dict serves as an insertion-ordered set so a term listed both as a synonym and
    # as the original term is not repeated (which would duplicate output rows).
    id_to_synonyms = defaultdict(dict)
    for term, entity_id in term_to_id.items():
        id_to_synonyms[entity_id][term] = None
    for entity_id, original_term in id_to_original.items():
        id_to_synonyms[entity_id][original_term] = None

    def synonyms_of(label):
        """Return all surface forms sharing label's ID, or [label] if the label is unmapped."""
        if label in label_to_id:
            return list(id_to_synonyms[label_to_id[label]])
        return [label]

    # Load the train file; blank lines are ignored.
    with open(train_file, "r", encoding="utf-8") as f:
        train_data = [json.loads(line) for line in f if line.strip()]

    exhaustive_data = []
    for entry in train_data:
        child_label = entry["child"]
        parent_label = entry["parent"]

        # Create all combinations of synonyms (child varies slowest)
        for child_syn, parent_syn in product(synonyms_of(child_label), synonyms_of(parent_label)):
            is_original = child_syn == child_label and parent_syn == parent_label
            exhaustive_data.append(
                {
                    "child": child_syn,
                    "parent": parent_syn,
                    "label": entry["label"],
                    "synonym": 0 if is_original else 1,
                }
            )

    # Save the exhaustive dataset
    with open(output_file, "w", encoding="utf-8") as f:
        for entry in exhaustive_data:
            f.write(json.dumps(entry) + "\n")

    print(f"Exhaustive dataset with 'synonym' flag saved to {output_file}")


def filter_dataset_with_duplicates_removal(input_file, output_file, max_occurrences=10):
    """
    Deduplicate an exhaustive pairs dataset and cap how many synonym rows each child term contributes.

    Rows are processed in input order:
      * a row whose (child, parent, label) triple was already kept is dropped;
      * a row flagged ``"synonym": 1`` is kept only while fewer than ``max_occurrences``
        synonym rows with the same ``child`` value have been kept so far;
      * any other row (flag 0 or no flag) is always kept.
    The ``synonym`` key is stripped from every written row.

    Args:
        input_file (str): Path to the exhaustive dataset (JSONL format).
        output_file (str): Path to save the filtered dataset.
        max_occurrences (int): Maximum number of kept synonym rows per child term.

    Returns:
        None
    """
    import json
    from collections import defaultdict

    with open(input_file, "r") as src:
        rows = [json.loads(raw) for raw in src]

    kept_per_child = defaultdict(int)  # synonym rows kept so far, keyed by child term
    seen_triples = set()               # (child, parent, label) already written
    kept_rows = []

    for row in rows:
        triple = (row["child"], row["parent"], row["label"])
        if triple in seen_triples:
            continue

        flagged = row.get("synonym", 0) == 1
        if flagged and kept_per_child[row["child"]] >= max_occurrences:
            continue  # this child has used up its quota of synonym rows
        if flagged:
            kept_per_child[row["child"]] += 1

        kept_rows.append({field: value for field, value in row.items() if field != "synonym"})
        seen_triples.add(triple)

    with open(output_file, "w") as dst:
        dst.writelines(json.dumps(row) + "\n" for row in kept_rows)

    print(f"Filtered dataset saved to {output_file}")


# --- For Triplets dataset ---
def generate_exhaustive_triplets_dataset_with_flag(
    triplets_file, original_terms_file, synonym_mapping_file, output_file, expand_negative=False
):
    """
    Generate an exhaustive triplets dataset by replacing all terms with all their synonyms.

    For each input row, every combination of (child, parent, negative) surface forms is
    emitted. The negative is only expanded when ``expand_negative`` is True; otherwise
    it is kept as-is. ``synonym`` is 0 for the combination identical to the input row and
    1 otherwise. A label found in neither mapping is kept as its only surface form. Each
    surface form appears once per ID, so no duplicate combinations come from one row.

    Args:
        triplets_file (str): Path to the input triplets dataset (JSONL format).
        original_terms_file (str): Path to the original terms JSON file (ID -> term).
        synonym_mapping_file (str): Path to the synonym mappings JSON file (term -> ID).
        output_file (str): Path to save the exhaustive dataset.
        expand_negative (bool): Whether to expand the "negative" field with synonyms. Defaults to False.

    Returns:
        None
    """
    import json
    from itertools import product
    from collections import defaultdict

    # Load original terms and synonym mappings
    with open(original_terms_file, "r", encoding="utf-8") as f:
        id_to_original = json.load(f)
    with open(synonym_mapping_file, "r", encoding="utf-8") as f:
        term_to_id = json.load(f)

    # Create a mapping from labels to IDs (synonym entries override original terms)
    label_to_id = {v: k for k, v in id_to_original.items()}
    label_to_id.update(term_to_id)

    # Inverse mapping: ID -> all surface forms (synonyms first, then the original term).
    # A dict serves as an insertion-ordered set so a term listed both as a synonym and
    # as the original term is not repeated (which would duplicate output rows).
    id_to_synonyms = defaultdict(dict)
    for term, entity_id in term_to_id.items():
        id_to_synonyms[entity_id][term] = None
    for entity_id, original_term in id_to_original.items():
        id_to_synonyms[entity_id][original_term] = None

    def synonyms_of(label):
        """Return all surface forms sharing label's ID, or [label] if the label is unmapped."""
        if label in label_to_id:
            return list(id_to_synonyms[label_to_id[label]])
        return [label]

    # Load the triplets file; blank lines are ignored.
    with open(triplets_file, "r", encoding="utf-8") as f:
        triplets_data = [json.loads(line) for line in f if line.strip()]

    exhaustive_data = []
    for entry in triplets_data:
        child_label = entry["child"]
        parent_label = entry["parent"]
        negative_label = entry["negative"]

        child_synonyms = synonyms_of(child_label)
        parent_synonyms = synonyms_of(parent_label)
        # Only expand the negative when requested
        negative_synonyms = synonyms_of(negative_label) if expand_negative else [negative_label]

        # Create all combinations of synonyms
        for child_syn, parent_syn, negative_syn in product(child_synonyms, parent_synonyms, negative_synonyms):
            exhaustive_data.append(
                {
                    "child": child_syn,
                    "parent": parent_syn,
                    "negative": negative_syn,
                    "synonym": int(
                        child_syn != child_label or parent_syn != parent_label or negative_syn != negative_label
                    ),
                }
            )

    # Save the exhaustive dataset
    with open(output_file, "w", encoding="utf-8") as f:
        for entry in exhaustive_data:
            f.write(json.dumps(entry) + "\n")

    print(f"Exhaustive triplets dataset saved to {output_file}")



def filter_triplets_dataset_with_duplicates_removal(input_file, output_file, max_occurrences=10):
    """
    Filter the exhaustive triplets dataset to limit synonym occurrences while removing duplicates.

    Rows are processed in input order:
      * a row whose (child, parent, negative) triple was already kept is dropped,
        regardless of its ``synonym`` flag. The flag is stripped on output, so keeping
        both a flag-0 and a flag-1 copy would write the same row twice;
      * a row flagged ``"synonym": 1`` is kept only while fewer than ``max_occurrences``
        synonym rows with the same ``child`` value have been kept so far. This matches
        the pairs filter; the previous per-triple counter could never exceed 1 because
        triples were already deduplicated, so the cap had no effect;
      * any other row (flag 0 or no flag) is always kept.
    The ``synonym`` key is removed from every written row.

    Args:
        input_file (str): Path to the exhaustive triplets dataset (JSONL format).
        output_file (str): Path to save the filtered dataset (JSONL format).
        max_occurrences (int): Maximum number of kept synonym rows per child term.

    Returns:
        None
    """
    import json
    from collections import defaultdict

    # Load the exhaustive dataset; blank lines are ignored.
    with open(input_file, "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]

    synonym_occurrence_counter = defaultdict(int)  # kept synonym rows per child term
    unique_entries = set()  # (child, parent, negative) triples already kept

    filtered_dataset = []
    for entry in dataset:
        entry_tuple = (entry["child"], entry["parent"], entry["negative"])

        # Skip content already added (with either flag value)
        if entry_tuple in unique_entries:
            continue

        if entry.get("synonym", 0) == 1:
            # Synonym rows are capped per child term
            if synonym_occurrence_counter[entry["child"]] >= max_occurrences:
                continue
            synonym_occurrence_counter[entry["child"]] += 1

        # Remove the "synonym" key to prepare for training (without mutating the input row)
        filtered_dataset.append({k: v for k, v in entry.items() if k != "synonym"})
        unique_entries.add(entry_tuple)

    # Save the filtered dataset
    with open(output_file, "w", encoding="utf-8") as f:
        for entry in filtered_dataset:
            f.write(json.dumps(entry) + "\n")

    print(f"Filtered triplets dataset saved to {output_file}")



# RUN
original_terms_file = "original_terms.json"
synonym_mapping_file = "synonym_mapping.json"
dataset_path = "./hpo_datasets_multi_random-Triplets/val.jsonl"
augmented_dataset_path = "./hpo_datasets_multi_random-Triplets/augmented_val.jsonl"
filtered_dataset_path = "./hpo_datasets_multi_random-Triplets/filtered_val.jsonl"

# Which pipeline to run; it must match the row format of `dataset_path`.
#   "pairs":    rows with "child", "parent", "label"
#   "triplets": rows with "child", "parent", "negative"
# Previously both pipelines ran on the same input and wrote to the same output paths,
# so the pairs run either failed on triplet rows or was overwritten by the triplets run.
DATASET_KIND = "triplets"

if __name__ == "__main__":
    if DATASET_KIND == "pairs":
        # Generate full augmented dataset with all possible combinations
        generate_exhaustive_dataset_with_flag(
            dataset_path, original_terms_file, synonym_mapping_file, augmented_dataset_path
        )
        # Filter the dataset to reduce the size while keeping synonyms diversity
        filter_dataset_with_duplicates_removal(augmented_dataset_path, filtered_dataset_path, max_occurrences=5)
    elif DATASET_KIND == "triplets":
        generate_exhaustive_triplets_dataset_with_flag(
            dataset_path, original_terms_file, synonym_mapping_file, augmented_dataset_path, expand_negative=False
        )
        filter_triplets_dataset_with_duplicates_removal(
            augmented_dataset_path, filtered_dataset_path, max_occurrences=5
        )
    else:
        raise ValueError(f"Unknown DATASET_KIND: {DATASET_KIND!r} (expected 'pairs' or 'triplets')")