import os
import json
import pickle as pkl
import numpy as np
import collections
import itertools
import time
import sys
import random
from collections import deque, defaultdict
from typing import Dict, List
from tqdm import tqdm
import csv
import collections


def create_child_to_parents_map(taxo_file_path: str, dataset: str) -> Dict[str, List[str]]:
    """Map each child term to its parent terms and dump the map as JSON.

    ``../data/{dataset}/key_value.json`` (id -> term, relative to the working
    directory) translates the ids found in *taxo_file_path*, whose non-blank
    lines are expected to be ``parent_id<TAB>child_id``. Lines with a
    different number of tab-separated fields are reported and skipped.
    The resulting map is written to ``../data/{dataset}/test_taxo.json``.

    Args:
        taxo_file_path: path of the tab-separated taxonomy file.
        dataset: dataset directory name under ``../data``.

    Returns:
        Mapping child term -> list of parent terms, in file order.

    Raises:
        KeyError: if an id in the taxonomy file is missing from key_value.json.
    """
    with open(f"../data/{dataset}/key_value.json", 'r', encoding='utf-8') as f:
        id_term_map = json.load(f)

    child_to_parents = defaultdict(list)

    print(f"Reading taxonomy from: {taxo_file_path}")
    with open(taxo_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            parts = line.split('\t')
            if len(parts) != 2:
                print(f"Warning: Skipping malformed line: '{line}'")
                continue

            parent_id, child_id = parts
            child_to_parents[id_term_map[child_id]].append(id_term_map[parent_id])

    with open(f"../data/{dataset}/test_taxo.json", "w") as f:
        json.dump(child_to_parents, f, indent=4)

    # Return a plain dict so callers do not get silent defaultdict insertions.
    return dict(child_to_parents)


def terms_to_json(file, dataset):
    """Convert a ``term<TAB>definition`` file into ``../data/{dataset}/defs.json``.

    Blank lines are ignored; lines without a tab-separated definition are
    reported and skipped. Only the first two tab-separated fields are used,
    and if a term repeats its last definition wins.

    Args:
        file: path of the tab-separated definitions file.
        dataset: dataset directory name under ``../data`` (relative to the
            working directory).

    Returns:
        The term -> definition mapping that was written.
    """
    def_dict = {}

    # Use a distinct name for the handle instead of shadowing the `file` argument.
    with open(file, 'r') as handle:
        for line in handle:
            stripped = line.strip()
            if not stripped:
                continue
            contents = stripped.split("\t")
            if len(contents) < 2:
                print(f"Warning: Skipping line without a definition: '{stripped}'")
                continue
            def_dict[contents[0]] = contents[1]

    with open(f"../data/{dataset}/defs.json", "w") as f:
        json.dump(def_dict, f, indent=4)

    return def_dict


def csv_to_json(csv_file, output_path="../data/psychology/defs.json"):
    """Convert a 4-column definitions CSV into a term -> definition JSON file.

    Each non-empty row must have exactly four fields ``index, label, term,
    definition``. The index and label are ignored. There is no header
    detection: every non-empty row is treated as data. Empty rows are skipped.

    Args:
        csv_file: path of the input CSV file (read as UTF-8).
        output_path: where to write the JSON mapping. The default keeps the
            previously hard-coded psychology location.

    Returns:
        The term -> definition mapping that was written.

    Raises:
        ValueError: if a non-empty row does not have exactly four fields.
    """
    def_dict = {}
    # newline='' lets the csv module handle quoted fields that contain newlines.
    with open(csv_file, mode='r', newline='', encoding='utf-8') as file:
        for row in csv.reader(file):
            if not row:
                continue
            _index, _label, term, definition = row
            def_dict[term] = definition

    with open(output_path, 'w') as f:
        json.dump(def_dict, f, indent=4)

    return def_dict


def analyze_parent_child_relationships(filepath: str):
    """Check whether every child in a tab-separated taxonomy file has one parent.

    Each non-blank line of *filepath* is expected to be ``parent_id<TAB>child_id``.
    Lines with a different number of tab-separated fields are reported and
    skipped. Progress and results are printed to stdout.

    Returns:
        ``(child_to_parent, {})`` when every child has exactly one parent,
        where ``child_to_parent`` maps child id -> parent id;
        ``(None, offenders)`` when some children have several parents, where
        ``offenders`` maps each such child id -> list of its parent ids;
        ``(None, None)`` when *filepath* does not exist.
    """
    parents_by_child = collections.defaultdict(list)

    print(f"--- Analyzing file: {filepath} ---")

    try:
        with open(filepath, 'r') as handle:
            for line_no, raw in enumerate(handle, start=1):
                record = raw.strip()
                if not record:
                    continue
                fields = record.split('\t')
                if len(fields) != 2:
                    print(f"Warning: Skipping malformed line #{line_no}: '{record}'")
                    continue
                parent_id, child_id = fields
                parents_by_child[child_id].append(parent_id)
    except FileNotFoundError:
        print(f"Error: The file '{filepath}' was not found.")
        return None, None

    print("\n--- Verifying Unique Parent Constraint ---")

    offenders = {
        child: parents
        for child, parents in parents_by_child.items()
        if len(parents) > 1
    }

    # Guard clause: report every violating child and bail out.
    if offenders:
        print(
            f"Check failed: Found {len(offenders)} children with multiple parents.")
        for child, parents in offenders.items():
            print(
                f"  - Child '{child}' is linked to {len(parents)} parents: {parents}")
        print("\nCannot generate a unique child:parent dictionary due to the errors above.")
        return None, offenders

    print("Every child has a unique parent.")
    child_to_parent = {child: parents[0]
                       for child, parents in parents_by_child.items()}

    print("Generated Dictionary (first 5 items):")
    for child, parent in itertools.islice(child_to_parent.items(), 5):
        print(f"  '{child}': '{parent}'")
    print(f"Total items in dictionary: {len(child_to_parent)}")

    return child_to_parent, {}


def id_to_json(filepath: str, dataset: str):
    """Convert an ``id<TAB>term`` file into ``../data/{dataset}/key_value.json``.

    Blank lines are ignored. Every other line must have exactly two
    tab-separated fields after surrounding whitespace is stripped.

    Args:
        filepath: path of the id/term file.
        dataset: dataset directory name under ``../data`` (relative to the
            working directory).

    Returns:
        The id -> term mapping that was written.

    Raises:
        FileNotFoundError: if *filepath* does not exist (message names it), or
            if the output directory is missing (message names the output path).
        ValueError: if a non-blank line does not have exactly two fields.
    """
    # Only the input read is wrapped, so a missing *output* directory is not
    # misreported as a missing input file.
    try:
        with open(filepath, 'r') as f:
            contents = f.readlines()
    except FileNotFoundError as err:
        raise FileNotFoundError(f"File not found: {filepath}") from err

    key_value = {}
    for line in contents:
        stripped = line.strip()
        if not stripped:
            continue
        id_no, term = stripped.split("\t")
        key_value[id_no] = term

    with open(f"../data/{dataset}/key_value.json", "w") as f:
        json.dump(key_value, f, indent=4)

    return key_value


def pre_process_mag(args, outID=True):
    """Build training/evaluation structures for a MAG-style taxonomy dataset.

    Reads, relative to the working directory, from ``../data/{dataset}/``:
    - ``{dataset}.terms`` (id<TAB>term), converted to ``key_value.json`` via ``id_to_json``
    - ``{dataset}.taxo`` (all parent<TAB>child id pairs)
    - ``{dataset}_train.taxo`` (training pairs)
    - ``{dataset}_test.terms`` (second tab-separated field is the test term)
    - ``test_taxo.json`` (child term -> list of parent terms)
    - ``defs.json`` (lower-cased term -> definition)

    Per-node radius statistics are written to ``../levels/{dataset}_levels.json``.

    Args:
        args: object with ``dataset`` (str) and ``negsamples`` (int, number of
            negative parents per training edge) attributes.
        outID: if True, concepts are integer ids (index in the sorted concept
            list); otherwise their term strings.

    Returns:
        A 20-tuple: (concept_set, concept_id, id_concept, id_context,
        train_concept_set, taxo_dict, sampled_negative_parent_dict,
        child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
        child_neg_parent_pair, val_concepts_ids, val_gts_ids,
        test_concepts_ids, test_gts_ids, normalized_radii).
        sampled_negative_parent_dict, negative_parent_list, path2root,
        child_neg_parent_pair and the val_* lists are always returned empty.
    """
    print("Processing MAG datasets...")
    dataset = args.dataset
    negsamples = args.negsamples
    # Cap on rejection-sampling draws per training edge before falling back to
    # an explicit candidate pool (prevents an endless loop on tiny vocabularies).
    max_attempts_per_pair = 200

    def load_file(filepath: str) -> list[str]:
        try:
            with open(filepath, 'r') as f:
                return f.readlines()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {filepath}") from err

    id_to_json(f"../data/{dataset}/{dataset}.terms", f"{dataset}")
    print("Saved ID-term map as json.")

    with open(f"../data/{dataset}/key_value.json", 'r') as f:
        id_term_map = json.load(f)

    def process_pair(pair):
        """Translate a ``parent_id<TAB>child_id`` line into (parent_term, child_term)."""
        ids = pair.strip().split("\t")
        return (id_term_map[ids[0]], id_term_map[ids[1]])

    full_taxonomy_pairs = load_file(f"../data/{dataset}/{dataset}.taxo")

    all_concept_set_str = set()
    all_taxo_dict_str = collections.defaultdict(list)
    for pair in full_taxonomy_pairs:
        if not pair.strip():
            continue
        parent, child = process_pair(pair)
        all_concept_set_str.add(child)
        all_concept_set_str.add(parent)
        all_taxo_dict_str[parent].append(child)

    concepts = sorted(all_concept_set_str)
    concept_id = {concept: idx for idx, concept in enumerate(concepts)}
    id_concept = {idx: concept for idx, concept in enumerate(concepts)}

    all_taxo_dict = collections.defaultdict(list)
    # child -> parents in the full taxonomy; only populated when outID is True.
    all_taxo_dict_reverse = collections.defaultdict(list)

    all_kids = list()
    if outID:
        concept_set = set(concept_id.values())
        for parent_str, children_str in all_taxo_dict_str.items():
            parent_id = concept_id[parent_str]
            children_ids = [concept_id[c] for c in children_str]
            all_kids.extend(children_ids)
            all_taxo_dict[parent_id].extend(children_ids)
            for child_id in children_ids:
                all_taxo_dict_reverse[child_id].append(parent_id)
    else:
        concept_set = all_concept_set_str
        all_taxo_dict = all_taxo_dict_str

    print(f"Loaded {len(concept_set)} total concepts.")

    train_taxonomy_pairs = load_file(f"../data/{dataset}/{dataset}_train.taxo")

    parent_list, child_list = [], []
    train_concept_set = set()
    chd2par_dict = collections.defaultdict(set)
    taxo_dict = collections.defaultdict(list)
    taxo_edges = []

    print("Processing Training Data...")
    for pair in train_taxonomy_pairs:
        if not pair.strip():
            continue
        parent, child = process_pair(pair)

        if outID:
            parent, child = concept_id[parent], concept_id[child]
        parent_list.append(parent)
        child_list.append(child)
        train_concept_set.add(parent)
        train_concept_set.add(child)

        chd2par_dict[child].add(parent)
        taxo_dict[parent].append(child)
        taxo_edges.append((parent, child))

    all_children = set(all_kids)
    roots = concept_set - all_children
    print(f"Found {len(taxo_edges)} training edges....")
    print(f"Found {len(roots)} root nodes in the training taxonomy.")

    print("Calculating radii based on depth and descendants using single-step normalization...")

    all_children_nodes = {child for parent,
                          children in all_taxo_dict.items() for child in children}
    roots = concept_set - all_children_nodes
    print(f"Found {len(roots)} root nodes for depth calculation.")

    # BFS from every root: depth = shortest distance (in edges) from a root, plus 1.
    depths = {}
    queue = deque([(root, 1) for root in roots])
    visited_for_depth = set(roots)

    for root in roots:
        depths[root] = 1

    while queue:
        current_node, current_depth = queue.popleft()
        children = all_taxo_dict.get(current_node, [])
        for child in children:
            if child not in visited_for_depth:
                visited_for_depth.add(child)
                depths[child] = current_depth + 1
                queue.append((child, current_depth + 1))
    print("Node depth calculation complete.")

    memo = {}

    def get_all_descendants(start_node):
        """Return (and memoize) the set of nodes reachable below start_node."""
        if start_node in memo:
            return memo[start_node]
        descendants = set()
        queue = deque(all_taxo_dict.get(start_node, []))
        visited = set(queue)
        while queue:
            current_node = queue.popleft()
            descendants.add(current_node)
            children = all_taxo_dict.get(current_node, [])
            for c in children:
                if c not in visited:
                    visited.add(c)
                    queue.append(c)
        memo[start_node] = descendants
        return descendants

    # raw score = depth + log2(1 + number of descendants); min-max normalized and
    # inverted below so that the highest-scoring node gets radius 0.
    raw_scores = {}
    for node in tqdm(concept_set, desc="Calculating raw scores (h + log(l+1))"):
        h = depths.get(node, 1)
        l = len(get_all_descendants(node))
        raw_scores[node] = h + (np.log1p(l)/np.log(2))

    all_raw_values = list(raw_scores.values())
    raw_score_min = min(all_raw_values)
    raw_score_max = max(all_raw_values)
    raw_score_range = raw_score_max - \
        raw_score_min if raw_score_max > raw_score_min else 1.0

    normalized_radii = {
        node: {'radii': 1.0 - ((score - raw_score_min) / raw_score_range),
               'depth': depths.get(node),
               'descendents': len(get_all_descendants(node)),
               'raw_score_min': raw_score_min,
               'raw_score_range': raw_score_range,
               }
        for node, score in raw_scores.items()
    }
    print("Final normalized radii calculation complete.")

    with open(f"../data/{args.dataset}/defs.json") as f:
        def_dic = json.load(f)

    # computer_science / psychology contexts are the bare term; other datasets
    # append the definition from defs.json (or repeat the term when missing).
    uses_definitions = args.dataset not in ['computer_science', 'psychology']
    id_context = {}
    definitions_not_found_count = 0
    for cid, concept in id_concept.items():
        concept_lower = concept.lower()
        if not uses_definitions:
            id_context[cid] = f"{concept_lower}"
        elif concept_lower in def_dic:
            id_context[cid] = f"{concept_lower}: {def_dic[concept_lower]}"
        else:
            id_context[cid] = f"{concept_lower}: {concept_lower}"
            definitions_not_found_count += 1

    test_term_lines = load_file(f"../data/{dataset}/{dataset}_test.terms")
    with open(f"../data/{args.dataset}/test_taxo.json") as f:
        test_map = json.load(f)

    print("Processing validation and test sets...")

    def get_eval_data(lines):
        """Return (child ids, list of gold parent ids per child) for known terms."""
        concept_ids, gts_ids = [], []
        for line in lines:
            if not line.strip():
                continue
            _, child_term = line.strip().split("\t")
            parent_ids = list()
            if child_term in concept_id:
                child_id = concept_id[child_term]
                parent_terms = test_map.get(child_term, [])

                for p in parent_terms:
                    if p in concept_id:
                        parent_ids.append(concept_id[p])

                concept_ids.append(child_id)
                gts_ids.append(parent_ids)
            else:
                print(
                    f"Found invalid term {child_term}..removing from test set.")

        return concept_ids, gts_ids

    val_concepts_ids, val_gts_ids = [], []
    test_concepts_ids, test_gts_ids = get_eval_data(test_term_lines)

    sampled_negative_parent_dict = {}
    negative_parent_list = []

    child_parent_pair = [[child, parent]
                         for child, parent in zip(child_list, parent_list)]

    # Materialize the sampling population once; np.random.choice would otherwise
    # convert a freshly built list to an array on every single draw.
    train_concept_list = list(train_concept_set)
    train_concept_array = np.array(train_concept_list)

    count_hard_neg_samples = 0
    count_fallback_samples = 0
    count_short_pairs = 0
    training_triplets = list()
    print(
        f"Definitions for {definitions_not_found_count} concepts not found in the wiki dictionary.")
    print(f"There are {len(child_parent_pair)} pairs in the training set.")
    for child_id, parent_id in tqdm(child_parent_pair, desc="Generating (c, p, n) Triplets"):
        # The child itself, the current parent and every other known training
        # parent of the child are true positives and must never be negatives.
        excluded = {child_id, parent_id} | chd2par_dict[child_id]
        found_negatives = []

        # Hard negatives: siblings of the child and parents of its parent in the
        # full taxonomy, restricted to concepts seen in training.
        hard_candidate_pool = set(all_taxo_dict.get(parent_id, []))
        hard_candidate_pool.update(all_taxo_dict_reverse.get(parent_id, []))
        filtered_hard_candidates = [
            node for node in hard_candidate_pool
            if node in train_concept_set and node not in excluded
        ]

        num_hard_to_sample = min(negsamples, len(filtered_hard_candidates))
        if num_hard_to_sample > 0:
            hard_samples = np.random.choice(
                filtered_hard_candidates,
                size=num_hard_to_sample,
                replace=False
            ).tolist()
            found_negatives.extend(hard_samples)
            count_hard_neg_samples += len(hard_samples)

        # Top up with uniformly random training concepts (rejection sampling).
        attempts = 0
        while len(found_negatives) < negsamples and attempts < max_attempts_per_pair:
            attempts += 1
            random_negative = np.random.choice(train_concept_array)
            if random_negative not in excluded and random_negative not in found_negatives:
                found_negatives.append(random_negative)
                count_fallback_samples += 1

        # Rejection sampling kept failing: sample the remainder from the explicit
        # pool of valid candidates (may still be short on tiny vocabularies).
        needed = negsamples - len(found_negatives)
        if needed > 0:
            taken = set(found_negatives)
            pool = [x for x in train_concept_list
                    if x not in excluded and x not in taken]
            if pool:
                extra = np.random.choice(
                    pool, size=min(needed, len(pool)), replace=False).tolist()
                found_negatives.extend(extra)
                count_fallback_samples += len(extra)
            if len(found_negatives) < negsamples:
                count_short_pairs += 1

        for neg_id in found_negatives:
            training_triplets.append((child_id, parent_id, neg_id))

    child_neg_parent_pair = []
    print("Negative sampling done.")
    print(
        f"Generated {count_hard_neg_samples} hard negative samples.")
    print(
        f"Used fallback random sampling for {count_fallback_samples} samples.")
    if count_short_pairs:
        print(
            f"Warning: {count_short_pairs} training pairs received fewer than {negsamples} negatives.")

    child_parent_negative_parent_triple = training_triplets
    print(
        f"There are {len(child_parent_negative_parent_triple)} samples for training.")
    print(
        f"There are {len(test_concepts_ids)} test concepts with {len(test_gts_ids)} ground truth mappings.")

    path2root = collections.defaultdict(list)
    print("Preprocessing complete.")

    os.makedirs('../levels', exist_ok=True)
    with open(f'../levels/{args.dataset}_levels.json', 'w') as f:
        json.dump(normalized_radii, f, indent=4)

    return (
        concept_set, concept_id, id_concept, id_context, train_concept_set, taxo_dict,
        sampled_negative_parent_dict, child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
        child_neg_parent_pair, val_concepts_ids, val_gts_ids, test_concepts_ids, test_gts_ids, normalized_radii
    )

# Parents are labels and children are image paths. Dataset structured as a forest.


def pre_process_images(args, outID=True):
    """Build training/evaluation structures for an image/label dataset.

    Reads, relative to the working directory, ``../data/{dataset}/{dataset}.taxo``,
    ``{dataset}_train.taxo`` and ``{dataset}_test.taxo``. Each line is split on
    whitespace into ``image_path label``. The label is treated as the parent
    and the image path as the child. Concepts not starting with ``'/'`` are
    collected (as ids) into ``all_labels_set`` and are the only candidates
    for negative parents. Per-node radius statistics are written to
    ``../levels/{dataset}_levels.json``.

    Args:
        args: object with ``dataset`` (str) and ``negsamples`` (int) attributes.
        outID: if True, concepts are integer ids (index in the sorted concept
            list); otherwise their raw strings.

    Returns:
        A 21-tuple: (concept_set, concept_id, id_concept, id_context,
        train_concept_set, taxo_dict, sampled_negative_parent_dict,
        child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
        child_neg_parent_pair, val_concepts_ids, val_gts_ids,
        test_concepts_ids, test_gts_ids, normalized_radii, all_labels_set).
        sampled_negative_parent_dict, negative_parent_list, path2root,
        child_neg_parent_pair and the val_* lists are always empty.
        Training pairs for which ``negsamples`` valid negatives cannot be found
        produce no triplets.
    """
    print("Processing Birds dataset...")
    dataset = args.dataset

    def load_file(filepath):
        try:
            with open(filepath, 'r') as f:
                return f.readlines()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {filepath}") from err

    def process_pair(pair):
        """Return (label, image_path) from a whitespace-separated line."""
        # NOTE(review): str.split() breaks image paths that contain spaces.
        ids = pair.strip().split()
        return (ids[1], ids[0])

    # 0 represents the image path and 1 represents the label
    # Parent is the label and child is the image. This structure aligns image with label.
    full_taxonomy_pairs = load_file(f"../data/{dataset}/{dataset}.taxo")

    all_concept_set_str = set()
    all_taxo_dict_str = collections.defaultdict(list)

    all_labels_set = set()
    for pair in full_taxonomy_pairs:
        parent, child = process_pair(pair)
        all_concept_set_str.add(child)
        all_concept_set_str.add(parent)

        all_taxo_dict_str[parent].append(child)

    concepts = sorted(all_concept_set_str)
    concept_id = {concept: idx for idx, concept in enumerate(concepts)}
    id_concept = {idx: concept for idx, concept in enumerate(concepts)}

    # Anything that does not look like an absolute path is treated as a label.
    for c in concepts:
        if c.startswith('/'):
            continue
        all_labels_set.add(concept_id[c])

    all_taxo_dict = collections.defaultdict(list)
    # child -> parents in the full data; only populated when outID is True.
    all_taxo_dict_reverse = collections.defaultdict(list)

    all_kids = list()

    if outID:
        concept_set = set(concept_id.values())
        for parent_str, children_str in all_taxo_dict_str.items():
            parent_id = concept_id[parent_str]
            children_ids = [concept_id[c] for c in children_str]
            all_kids.extend(children_ids)
            all_taxo_dict[parent_id].extend(children_ids)

            for child_id in children_ids:
                all_taxo_dict_reverse[child_id].append(parent_id)
    else:
        concept_set = all_concept_set_str
        all_taxo_dict = all_taxo_dict_str

    print(f"Loaded {len(concept_set)} total image paths and labels")
    train_taxonomy_pairs = load_file(f"../data/{dataset}/{dataset}_train.taxo")

    parent_list, child_list = [], []
    train_concept_set = set()

    chd2par_dict = collections.defaultdict(set)
    taxo_dict = collections.defaultdict(list)
    taxo_edges = []

    print("Processing training data....")

    for pair in train_taxonomy_pairs:
        parent, child = process_pair(pair)

        if outID:
            parent, child = concept_id[parent], concept_id[child]

        parent_list.append(parent)
        child_list.append(child)
        train_concept_set.add(parent)
        train_concept_set.add(child)

        chd2par_dict[child].add(parent)
        taxo_dict[parent].append(child)

        taxo_edges.append((parent, child))

    all_children = set(all_kids)
    roots = concept_set-all_children

    print(f"Found {len(taxo_edges)} training samples... ")
    print(f"Found {len(roots)} roots in the training tree...")

    print("Calculating radii based on depth and descendants using single-step normalization...")
    all_children_nodes = {child for parent,
                          children in all_taxo_dict.items() for child in children}
    roots = concept_set - all_children_nodes
    print(f"Found {len(roots)} root nodes for depth calculation.")

    # BFS from every root: depth = shortest distance (in edges) from a root, plus 1.
    queue = deque([(root, 1) for root in roots])
    visited_for_depth = set(roots)
    depths = {}
    for root in roots:
        depths[root] = 1

    while queue:
        current_node, current_depth = queue.popleft()
        children = all_taxo_dict.get(current_node, [])
        for child in children:
            if child not in visited_for_depth:
                visited_for_depth.add(child)
                depths[child] = current_depth+1
                queue.append((child, current_depth+1))
    print("Node depth calculation completed.....")

    memo = {}

    def get_all_descendants(start_node):
        """Return (and memoize) the set of nodes reachable below start_node."""
        if start_node in memo:
            return memo[start_node]
        descendants = set()
        queue = deque(all_taxo_dict.get(start_node, []))
        visited = set(queue)
        while queue:
            current_node = queue.popleft()
            descendants.add(current_node)
            children = all_taxo_dict.get(current_node, [])
            for c in children:
                if c not in visited:
                    visited.add(c)
                    queue.append(c)
        memo[start_node] = descendants
        return descendants

    # raw score = depth + log2(1 + number of descendants); min-max normalized and
    # inverted below so that the highest-scoring node gets radius 0.
    raw_scores = {}
    for node in tqdm(concept_set, desc="Calculating raw scores (h + log(l+1))"):
        h = depths.get(node, 1)
        l = len(get_all_descendants(node))
        raw_scores[node] = h + (np.log1p(l)/np.log(2))

    all_raw_values = list(raw_scores.values())
    raw_score_min = min(all_raw_values)
    raw_score_max = max(all_raw_values)
    raw_score_range = raw_score_max - \
        raw_score_min if raw_score_max > raw_score_min else 1.0

    normalized_radii = {
        node: {'radii': 1.0 - ((score - raw_score_min) / raw_score_range),
               'depth': depths.get(node),
               'descendents': len(get_all_descendants(node)),
               'raw_score_min': raw_score_min,
               'raw_score_range': raw_score_range,
               }
        for node, score in raw_scores.items()
    }

    print("Normalization radii calculation completed for images....")

    id_context = id_concept
    print("Processing test split....")

    test_term_lines = load_file(f"../data/{dataset}/{dataset}_test.taxo")

    def get_eval_data(lines):
        """Return (image ids, list of gold label ids per image) for known images."""
        concept_ids, gts_ids = [], []
        for line in lines:
            _, child_term = process_pair(line)

            parent_ids = list()

            if child_term in concept_id:
                child_id = concept_id[child_term]
                parent_ids = all_taxo_dict_reverse.get(child_id, [])
                concept_ids.append(child_id)
                gts_ids.append(parent_ids)

            else:
                print(
                    f"Found invalid label {child_term} since it doesnt have an associated label..removing from test set")

        return concept_ids, gts_ids

    val_concepts_ids, val_gts_ids = [], []
    test_concepts_ids, test_gts_ids = get_eval_data(test_term_lines)

    sampled_negative_parent_dict = {}
    negative_parent_list = []

    child_parent_pair = [[child, parent]
                         for child, parent in zip(child_list, parent_list)]

    training_triplets = list()
    count_fallback_samples = 0
    count_skipped_pairs = 0
    print(
        f"There are {len(child_parent_pair)} training pairs in the training set.")

    max_attempts_per_pair = 200
    k_negatives = args.negsamples
    train_concept_list = list(all_labels_set)
    # Convert once: np.random.choice would otherwise re-convert the list to an
    # array on every draw. An empty label set would make choice() raise.
    train_concept_array = np.array(train_concept_list)
    has_candidates = bool(train_concept_list)

    for child_id, parent_id in tqdm(child_parent_pair, desc="Generating (c, p, n) Triplets"):
        found_negatives = []
        found_set = set()
        attempts = 0
        # The child's training parents do not change while sampling for this pair.
        parent_set = set(chd2par_dict.get(child_id, ()))

        while (has_candidates and len(found_negatives) < k_negatives
               and attempts < max_attempts_per_pair):
            attempts += 1
            random_negative = np.random.choice(train_concept_array)

            if (random_negative != child_id and
                random_negative != parent_id and
                random_negative not in parent_set and
                    random_negative not in found_set):
                found_negatives.append(random_negative)
                found_set.add(random_negative)
                count_fallback_samples += 1

        needed = k_negatives - len(found_negatives)
        if needed > 0:
            # Rejection sampling fell short: sample from the explicit valid pool.
            pool = [x for x in train_concept_list
                    if x != child_id and x != parent_id and x not in parent_set and x not in found_set]

            if pool:
                # sample up to `needed` items (without replacement)
                to_take = min(needed, len(pool))
                chosen = random.sample(pool, to_take)
                found_negatives.extend(chosen)
                found_set.update(chosen)
                count_fallback_samples += len(chosen)

        if len(found_negatives) < k_negatives:
            # Not enough distinct negatives exist: drop this pair entirely.
            count_skipped_pairs += 1
            continue

        for neg_id in found_negatives:
            training_triplets.append((child_id, parent_id, neg_id))

    child_neg_parent_pair = []
    print("Negative sampling done!!!")
    print(
        f"Used random negative sampling for {count_fallback_samples} samples.")
    if count_skipped_pairs:
        print(
            f"Warning: skipped {count_skipped_pairs} training pairs with fewer than {k_negatives} valid negatives.")

    child_parent_negative_parent_triple = training_triplets
    print(
        f"There are {len(child_parent_negative_parent_triple)} samples for training.")
    print(
        f"There are {len(test_concepts_ids)} test concepts with {len(test_gts_ids)} ground truth mappings.")

    path2root = collections.defaultdict(list)
    print("Preprocessing complete.")

    os.makedirs('../levels', exist_ok=True)
    with open(f'../levels/{args.dataset}_levels.json', 'w') as f:
        json.dump(normalized_radii, f, indent=4)

    return (
        concept_set, concept_id, id_concept, id_context, train_concept_set, taxo_dict,
        sampled_negative_parent_dict, child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
        child_neg_parent_pair, val_concepts_ids, val_gts_ids, test_concepts_ids, test_gts_ids, normalized_radii, all_labels_set
    )


def preprocess(args, outID=True):
    """
    Preprocesses taxonomy data for taxonomy construction and evaluation tasks.

    Reads, relative to the working directory, from ``../data/{dataset}/``:
    - ``{dataset}_raw_en.taxo`` (full taxonomy)
    - ``{dataset}_train.taxo`` (training edges)
    - ``dic.json`` (term -> list of definitions; the first one is used)
    - ``{dataset}_eval.terms`` and ``{dataset}_eval.gt`` (aligned line by line)

    Args:
        args: object with ``dataset`` (str) and ``negsamples`` (int) attributes.
        outID (bool): If True, concepts are integer ids (index in the sorted
            concept list); otherwise their names.

    Returns:
        A 27-tuple: (concept_set, concept_id, id_concept, id_context,
        train_concept_set, taxo_dict, negative_parent_dict,
        child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, sibling_dict, cousin_dict, relative_triple,
        test_concepts_id, test_gt_id, all_taxo_dict, path2root, sib_pair,
        child_parent_pair, child_neg_parent_pair, child_sibling_pair,
        val_concept, val_gt, test_concept, test_gt, normalized_radii).
        The eval set (test_concepts_id / test_gt_id) is shuffled and split in
        half into val_* and test_* tuples; these are empty tuples when there
        are too few eval items.

    Raises:
        FileNotFoundError: if an input file is missing.
        KeyError: if an eval term is unknown or a concept has no definition.
        ValueError: if a training node cannot reach the supernode via parents.
    """
    dataset = args.dataset
    # Same test as before: "wordnet" at position 0 or 1 of the dataset name.
    is_wordnet = dataset == "wordnet" or "wordnet" in dataset[:8]

    def load_file(filepath: str) -> list[str]:
        """Helper function to load a file and return lines."""
        try:
            with open(filepath, 'r') as f:
                return f.readlines()
        except FileNotFoundError as err:
            raise FileNotFoundError(f"File not found: {filepath}") from err

    def process_pair(pair: str) -> tuple[str, str]:
        """Return the last two tab-separated fields of a line (swapped for WordNet)."""
        text = pair.strip().split("\t")
        if is_wordnet:
            return (text[-1], text[-2])
        return (text[-2], text[-1])

    taxonomy = [pair for pair in load_file(
        f"../data/{dataset}/{dataset}_raw_en.taxo") if pair.strip()]

    # NOTE(review): the full taxonomy is unpacked as (child, parent) here, while
    # the training file below is unpacked as (parent, child) with the same
    # helper. Confirm that the two files really store their columns in opposite orders.
    edges = [process_pair(pair) for pair in taxonomy]
    concept_set = {name for edge in edges for name in edge}

    concepts = sorted(concept_set)
    concept_id = {concept: idx for idx, concept in enumerate(concepts)}
    id_concept = {idx: concept for concept, idx in concept_id.items()}

    all_taxo_dict = collections.defaultdict(list)
    if outID:
        concept_set = {concept_id[con] for con in concept_set}
    # Parent -> children over the full taxonomy (ids or names, matching outID).
    for child, parent in edges:
        if outID:
            child, parent = concept_id[child], concept_id[parent]
        all_taxo_dict[parent].append(child)

    print("Calculating radii based on depth and descendants using single-step normalization...")

    all_children_nodes = {child for parent,
                          children in all_taxo_dict.items() for child in children}
    roots = concept_set - all_children_nodes
    print(f"Found {len(roots)} root nodes for depth calculation.")

    # BFS from every root: depth = shortest distance (in edges) from a root, plus 1.
    depths = {}
    queue = deque([(root, 1) for root in roots])
    visited_for_depth = set(roots)

    for root in roots:
        depths[root] = 1

    while queue:
        current_node, current_depth = queue.popleft()
        children = all_taxo_dict.get(current_node, [])
        for child in children:
            if child not in visited_for_depth:
                visited_for_depth.add(child)
                depths[child] = current_depth + 1
                queue.append((child, current_depth + 1))
    print("Node depth calculation complete.")

    memo = {}

    def get_all_descendants(start_node):
        """Return (and memoize) the set of nodes reachable below start_node."""
        if start_node in memo:
            return memo[start_node]
        descendants = set()
        queue = deque(all_taxo_dict.get(start_node, []))
        visited = set(queue)
        while queue:
            current_node = queue.popleft()
            descendants.add(current_node)
            children = all_taxo_dict.get(current_node, [])
            for c in children:
                if c not in visited:
                    visited.add(c)
                    queue.append(c)
        memo[start_node] = descendants
        return descendants

    # raw score = depth + log2(1 + number of descendants); min-max normalized and
    # inverted below so that the highest-scoring node gets radius 0.
    raw_scores = {}
    for node in tqdm(concept_set, desc="Calculating raw scores (h + log(l+1))"):
        h = depths.get(node, 1)
        l = len(get_all_descendants(node))
        raw_scores[node] = h + (np.log1p(l)/np.log(2))

    all_raw_values = list(raw_scores.values())
    raw_score_min = min(all_raw_values)
    raw_score_max = max(all_raw_values)
    raw_score_range = raw_score_max - \
        raw_score_min if raw_score_max > raw_score_min else 1.0

    normalized_radii = {
        node: {'radii': 1.0 - ((score - raw_score_min) / raw_score_range),
               'depth': depths.get(node),
               'descendents': len(get_all_descendants(node)),
               'raw_score_min': raw_score_min,
               'raw_score_range': raw_score_range,
               }
        for node, score in raw_scores.items()
    }

    print("Final normalized radii calculation complete.")

    train_taxonomy = [pair for pair in load_file(
        f"../data/{dataset}/{dataset}_train.taxo") if pair.strip()]

    parent_list, child_list = [], []
    train_concept_set = set()
    chd2par_dict = collections.defaultdict(set)
    taxo_dict = collections.defaultdict(list)

    taxo_edges = []
    for pair in train_taxonomy:
        parent, child = process_pair(pair)
        if outID:
            parent, child = concept_id[parent], concept_id[child]
        parent_list.append(parent)
        child_list.append(child)
        train_concept_set.add(parent)
        train_concept_set.add(child)
        chd2par_dict[child].add(parent)
        taxo_dict[parent].append(child)
        taxo_edges.append((parent, child))

    all_children = set(child_list)
    roots = train_concept_set - all_children

    if is_wordnet:
        # WordNet is a forest: add an artificial supernode above every training root.
        supernode = len(concepts)
        concept_id[dataset] = supernode
        id_concept[supernode] = dataset

        for root in roots:
            taxo_dict[supernode].append(root)
            chd2par_dict[root].add(supernode)
    elif outID:
        supernode = concept_id[dataset]
    else:
        # Without ids, concepts are their names, so the root concept is the name itself.
        supernode = dataset

    sibling_dict = collections.defaultdict(set)
    for parent, children in taxo_dict.items():
        for child in children:
            sibling_dict[child].update(set(children) - {child})

    if is_wordnet:
        observe_nodes = train_concept_set - \
            {supernode} - set(taxo_dict[supernode])
    else:
        observe_nodes = train_concept_set

    sib_pair = [[k, sib] for k, children in sibling_dict.items()
                for sib in children]

    # Cousins: siblings of a parent (uncles, excluding the node's own parents)
    # plus those uncles' children, minus the node's own siblings.
    cousin_dict = collections.defaultdict(set)
    for node in observe_nodes:
        pars = chd2par_dict[node]
        for par in pars:
            cousins = sibling_dict[par] - pars
            cousin_dict[node].update(cousins)
            for uncle in cousins:
                cousin_dict[node].update(taxo_dict[uncle])
            cousin_dict[node] -= sibling_dict[node]

    relative_triple = [[node, s, c]
                       for node in observe_nodes for s in sibling_dict[node] for c in cousin_dict[node]]

    negative_parent_dict = {
        cid: sibling_dict[cid] | cousin_dict[cid] for cid in id_concept}

    negative_parent_list = []
    sampled_negative_parent_dict = {}

    # NOTE(review): a child that appears in several training edges is resampled
    # each time; only the last sample is kept in sampled_negative_parent_dict.
    for cid in child_list:
        negative_parents = list(negative_parent_dict[cid])
        if len(negative_parents) > args.negsamples:
            negative_parents = list(np.random.choice(
                negative_parents, args.negsamples, replace=False))
        sampled_negative_parent_dict[cid] = negative_parents
        negative_parent_list.extend(negative_parents)

    child_parent_negative_parent_triple = [
        [child, parent, neg]
        for child, parent in zip(child_list, parent_list)
        for neg in sampled_negative_parent_dict[child]
    ]

    child_parent_pair = [[child, parent]
                         for child, parent in zip(child_list, parent_list)]

    child_neg_parent_pair = [
        [cid, neg]
        for cid in child_list
        for neg in sampled_negative_parent_dict[cid]
    ]

    child_sibling_pair = [
        [cid, sib]
        for cid in child_list
        for sib in sibling_dict[cid]
    ]

    with open(f"../data/{dataset}/dic.json") as f:
        def_dic = {key.lower(): value for key, value in json.load(f).items()}
    if is_wordnet and dataset not in def_dic:
        def_dic[dataset] = ["Supernode"]

    id_context = {
        cid: f"{concept.lower()}: {def_dic[concept.lower()][0]}"
        for cid, concept in id_concept.items()
    }

    test_terms = load_file(f"../data/{dataset}/{dataset}_eval.terms")
    test_gt_lines = load_file(f"../data/{dataset}/{dataset}_eval.gt")

    test_concepts_id = [concept_id[term.strip()] for term in test_terms]
    test_gt_id = [concept_id[term.strip()] for term in test_gt_lines]

    shuffled_data = list(zip(test_concepts_id, test_gt_id))
    np.random.shuffle(shuffled_data)
    split_idx = len(shuffled_data) // 2

    def unzip_pairs(pairs):
        """Split [(c, g), ...] into ((c, ...), (g, ...)); empty input -> ((), ())."""
        return tuple(zip(*pairs)) if pairs else ((), ())

    val_concept, val_gt = unzip_pairs(shuffled_data[:split_idx])
    test_concept, test_gt = unzip_pairs(shuffled_data[split_idx:])

    # Follow one (arbitrary, first-iterated) parent per step up to the supernode.
    path2root = collections.defaultdict(list)
    for node in train_concept_set:
        path = path2root[node]
        current = node
        while current != supernode:
            path.append(current)
            if len(path) > len(train_concept_set):
                raise ValueError(
                    f"Cycle detected on the path from {node!r} to the supernode.")
            parents = chd2par_dict.get(current)
            if not parents:
                raise ValueError(
                    f"Node {current!r} has no parent on its path to the supernode {supernode!r}.")
            current = next(iter(parents))
        path.append(supernode)

    return (
        concept_set, concept_id, id_concept, id_context, train_concept_set, taxo_dict,
        negative_parent_dict, child_parent_negative_parent_triple, parent_list, child_list,
        negative_parent_list, sibling_dict, cousin_dict, relative_triple, test_concepts_id,
        test_gt_id, all_taxo_dict, path2root, sib_pair, child_parent_pair, child_neg_parent_pair,
        child_sibling_pair, val_concept, val_gt, test_concept, test_gt, normalized_radii
    )


def create_image_data(args):
    """Preprocess an image taxonomy dataset and pickle the result.

    Calls ``pre_process_images(args)`` and stores its outputs in a dict that is
    pickled to
    ``../data/<args.dataset>/processed/taxonomy_data_<args.expID><args.negsamples>_.pkl``.

    NOTE: unlike the other ``create_*`` builders, this file name does not
    include ``args.seed``; it is left unchanged so existing loaders still find it.

    Args:
        args: namespace providing at least ``dataset``, ``expID`` and
            ``negsamples`` (plus whatever ``pre_process_images`` reads).
    """
    print("Waiting for preprocess image data consisting of paths and labels...")
    (concept_set, concept_id, id_concept, id_context, train_concept_set, taxo_dict,
     negative_parent_dict, child_parent_negative_parent_triple, parent_list,
     child_list, negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
     child_neg_parent_pair, val_concept, val_gt, test_concepts_id, test_gt,
     node_levels, all_labels_set) = pre_process_images(args)

    save_data = {
        "concept_set": concept_set,
        "concept2id": concept_id,
        "id2concept": id_concept,
        "id2context": id_context,
        "train_concept_set": train_concept_set,
        "train_taxo_dict": taxo_dict,
        "all_taxo_dict": all_taxo_dict,
        "train_negative_parent_dict": negative_parent_dict,
        "train_child_parent_negative_parent_triple": child_parent_negative_parent_triple,
        "train_parent_list": parent_list,
        "train_child_list": child_list,
        "train_negative_parent_list": negative_parent_list,
        "test_concepts_id": test_concepts_id,
        "test_gt_id": test_gt,
        "path2root": path2root,
        "child_parent_pair": child_parent_pair,
        "child_neg_parent_pair": child_neg_parent_pair,
        "val_concept": val_concept,
        "val_gt": val_gt,
        # pre_process_images returns a single test split, so both test keys
        # point at the same objects.
        "test_concept": test_concepts_id,
        "test_gt": test_gt,
        "node_levels": node_levels,
        "all_labels_set": all_labels_set,
    }

    out_dir = os.path.join("../data", str(args.dataset), "processed")
    # processed/ may not exist yet (e.g. fresh checkout); without it,
    # open(..., "wb") raises FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir, f"taxonomy_data_{args.expID}{args.negsamples}_.pkl")

    print("Waiting for saving processed data....")
    with open(out_path, "wb") as f:
        pkl.dump(save_data, f)
    print("Done!")
    print(
        f"From processed data, there are :{len(child_parent_negative_parent_triple)} training instances")
    print(f"From processed data, there are :{len(test_gt)} test instances")


def create_mag_data(args):
    """Preprocess a MAG taxonomy dataset and pickle the result.

    Calls ``pre_process_mag(args)`` and stores its outputs in a dict that is
    pickled to
    ``../data/<args.dataset>/processed/taxonomy_data_<args.expID><args.negsamples><args.seed>_.pkl``.

    Args:
        args: namespace providing at least ``dataset``, ``expID``,
            ``negsamples`` and ``seed`` (plus whatever ``pre_process_mag`` reads).
    """
    print("Waiting for preprocess data....")

    (concept_set, concept_id, id_concept, id_context, train_concept_set, taxo_dict,
     negative_parent_dict, child_parent_negative_parent_triple, parent_list,
     child_list, negative_parent_list, all_taxo_dict, path2root, child_parent_pair,
     child_neg_parent_pair, val_concept, val_gt, test_concepts_id, test_gt,
     node_levels) = pre_process_mag(args)

    save_data = {
        "concept_set": concept_set,
        "concept2id": concept_id,
        "id2concept": id_concept,
        "id2context": id_context,
        "train_concept_set": train_concept_set,
        "train_taxo_dict": taxo_dict,
        "all_taxo_dict": all_taxo_dict,
        "train_negative_parent_dict": negative_parent_dict,
        "train_child_parent_negative_parent_triple": child_parent_negative_parent_triple,
        "train_parent_list": parent_list,
        "train_child_list": child_list,
        "train_negative_parent_list": negative_parent_list,
        "test_concepts_id": test_concepts_id,
        "test_gt_id": test_gt,
        "path2root": path2root,
        "child_parent_pair": child_parent_pair,
        "child_neg_parent_pair": child_neg_parent_pair,
        "val_concept": val_concept,
        "val_gt": val_gt,
        # pre_process_mag returns a single test split, so both test keys
        # point at the same objects.
        "test_concept": test_concepts_id,
        "test_gt": test_gt,
        "node_levels": node_levels,
    }

    out_dir = os.path.join("../data", str(args.dataset), "processed")
    # processed/ may not exist yet (e.g. fresh checkout); without it,
    # open(..., "wb") raises FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        f"taxonomy_data_{args.expID}{args.negsamples}{args.seed}_.pkl")

    print("Waiting for saving processed data....")
    with open(out_path, "wb") as f:
        pkl.dump(save_data, f)
    print("Done!")
    print(
        f"From processed data, there are :{len(child_parent_negative_parent_triple)} training instances")
    print(f"From processed data, there are :{len(test_gt)} test instances")


def create_data(args, maxlimit=None):
    """Preprocess a text taxonomy dataset and pickle the result.

    Calls ``preprocess(args)`` and stores its outputs in a dict that is pickled
    to
    ``../data/<args.dataset>/processed/taxonomy_data_<args.expID><args.negsamples><args.seed>_.pkl``.

    Args:
        args: namespace providing at least ``dataset``, ``expID``,
            ``negsamples`` and ``seed`` (plus whatever ``preprocess`` reads).
        maxlimit: unused; kept only so existing callers passing it still work.
    """
    print("Waiting for preprocess data....")
    (concept_set, concept_id, id_concept, id_context, train_concept_set,
     train_taxo_dict, negative_parent_dict, train_child_parent_negative_parent_triple,
     train_parent_list, train_child_list, train_negative_parent_list,
     train_sibling_dict, train_cousin_dict, train_relative_triple,
     test_concepts_id, test_gt_id, all_taxo_dict, path2root, sib_pair,
     child_parent_pair, child_neg_parent_pair, child_sibling_pair,
     val_concept, val_gt, test_concept, test_gt, levels) = preprocess(args)
    print("Done!")

    save_data = {
        "concept_set": concept_set,
        "concept2id": concept_id,
        "id2concept": id_concept,
        "id2context": id_context,
        "all_taxo_dict": all_taxo_dict,
        "train_concept_set": train_concept_set,
        "train_taxo_dict": train_taxo_dict,
        "train_negative_parent_dict": negative_parent_dict,
        "train_child_parent_negative_parent_triple": train_child_parent_negative_parent_triple,
        "train_parent_list": train_parent_list,
        "train_child_list": train_child_list,
        "train_negative_parent_list": train_negative_parent_list,
        "train_sibling_dict": train_sibling_dict,
        "train_cousin_dict": train_cousin_dict,
        "train_relative_triple": train_relative_triple,
        # Full (unsplit) evaluation set ...
        "test_concepts_id": test_concepts_id,
        "test_gt_id": test_gt_id,
        "path2root": path2root,
        "sib_pair": sib_pair,
        "child_parent_pair": child_parent_pair,
        "child_neg_parent_pair": child_neg_parent_pair,
        "child_sibling_pair": child_sibling_pair,
        # ... and its validation / test split as returned by preprocess.
        "val_concept": val_concept,
        "val_gt": val_gt,
        "test_concept": test_concept,
        "test_gt": test_gt,
        # NOTE(review): this is the last value returned by preprocess; confirm
        # it actually holds per-node levels.
        "node_levels": levels,
    }

    out_dir = os.path.join("../data", str(args.dataset), "processed")
    # processed/ may not exist yet (e.g. fresh checkout); without it,
    # open(..., "wb") raises FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(
        out_dir,
        f"taxonomy_data_{args.expID}{args.negsamples}{args.seed}_.pkl")

    print("Waiting for saving processed data....")
    with open(out_path, "wb") as f:
        pkl.dump(save_data, f)
    print("Done!")
    print(
        f"From processed data, there are :{len(train_child_parent_negative_parent_triple)} training instances")
    print(f"From processed data, there are :{len(test_gt_id)} test instances")


if __name__ == '__main__':
    # Script entry point: build the child -> parents term map for the
    # psychology taxonomy (written to ../data/psychology/test_taxo.json).
    create_child_to_parents_map(
        taxo_file_path="../data/psychology/psychology.taxo",
        dataset="psychology",
    )
