import json
import os
import pandas as pd

# hierarchy is a dictionary with keys 'id', 'name', 'children' (and optionally 'index'), where 'children' is a list of dictionaries with the same keys
# flatten this hierarchy into a list of dictionaries with keys 'id', 'name', 'index', 'parent_id'

def flatten_hierarchy(hierarchy, parent_id=None):
    """Flatten the descendants of *hierarchy* into a list of records.

    The node passed in is NOT included in the output; only the nodes found
    (recursively) under its 'children' key are. Records are emitted in
    pre-order: a node first, then all of its descendants, then its next
    sibling.

    Args:
        hierarchy: dict with an optional 'children' list. Every child must
            have 'id' and 'name'; 'index' and 'children' are optional.
        parent_id: value stored as 'parent_id' on the direct children of
            *hierarchy*. Deeper levels get the 'id' of their own parent.

    Returns:
        A list of dicts with keys 'id', 'name', 'index', 'parent_id'.
        'index' is None for nodes that do not define one.

    Raises:
        KeyError: if a child node lacks 'id' or 'name'.
    """
    result = []

    # A missing (or None) 'children' entry means "no children", so a bare
    # leaf passed in at the top level yields an empty list instead of failing.
    for node in hierarchy.get('children') or []:
        result.append({
            'id': node['id'],
            'name': node['name'],
            # Read with a default rather than writing 'index' back into the
            # caller's node, so the input hierarchy is left untouched.
            'index': node.get('index'),
            'parent_id': parent_id,
        })
        if 'children' in node:
            result.extend(flatten_hierarchy(node, parent_id=node['id']))
    return result

class Node:
    """A tree node built recursively from a child -> parent mapping.

    Children are found by comparing this node's ``name`` against the values
    of ``son2parent``; each matching key becomes a child node. The whole
    subtree is constructed eagerly inside ``__init__``.

    Attributes:
        id: the ``cid`` given at construction.
        name: the value matched against parents in ``son2parent``.
        son2parent: the shared child -> parent mapping.
        children: list of child ``Node`` objects, or None for a leaf.
        n_children: number of direct children (0 for a leaf).
        n_descendants: number of nodes below this one at any depth.
    """

    def __init__(self, name, cid, son2parent):
        """Create the node and recursively build everything beneath it.

        Args:
            name: key used to look this node up as a parent in *son2parent*.
            cid: identifier stored as ``self.id``.
            son2parent: dict mapping each child key to its parent key.
        """
        self.id = cid
        self.name = name
        self.son2parent = son2parent
        self.children = self.construct_children()

        self.n_children = len(self.children) if self.children is not None else 0
        self.n_descendants = self.get_n_descendants()

    def construct_children(self):
        """Return the child nodes of this node, or None if it has none."""
        # The mapping key is the only identifier available for a child, and
        # it is also what the child's own children are matched against, so it
        # is passed as both name and id.
        children = [
            Node(son, son, self.son2parent)
            for son, parent in self.son2parent.items()
            if parent == self.name
        ]
        # Leaves keep the original convention of None rather than [].
        return children or None

    def get_n_descendants(self):
        """Return the number of nodes below this one, at any depth."""
        if self.n_children == 0:
            return 0
        return self.n_children + sum(child.get_n_descendants() for child in self.children)

    def __str__(self) -> str:
        return f'{self.name}, n_children: {self.n_children}, n_descendants: {self.n_descendants}'

# Identifier of the synthetic root; parentless nodes are attached to it below.
root_id = 'Root'

# Step 1: read the raw hierarchy, then stamp root metadata onto its top-level dict.
with open('imagenet_hierarchy.json', 'r') as f:
    hierarchy = json.load(f)

hierarchy['name'] = 'Root'
hierarchy['id'] = root_id
hierarchy['index'] = None
hierarchy['parent_id'] = None

# Add an extra 'SubRoot' node as a direct child of the root.
sub_root = dict(name='SubRoot', id='SubRoot', index=None, parent_id=root_id)
hierarchy['children'].append(sub_root)

# Step 2: flatten the tree. flatten_hierarchy only emits the nodes below the
# one it is given, so an explicit record for the root goes in front.
root_node = {'name': 'Root', 'id': root_id, 'index': None, 'parent_id': None}
flattened_hierarchy = [root_node, *flatten_hierarchy(hierarchy)]

# Report how many records have no parent before the fix-up ...
no_parent_list = [entry for entry in flattened_hierarchy if entry['parent_id'] is None]
print('There are', len(no_parent_list), 'nodes with parent_id of None')

# ... attach every parentless record, other than the root itself, to the root ...
for node in flattened_hierarchy:
    if node['parent_id'] is not None or node['id'] == root_id:
        continue
    node['parent_id'] = root_id

# ... and report the count again after the fix-up.
no_parent_list = [entry for entry in flattened_hierarchy if entry['parent_id'] is None]
print('There are', len(no_parent_list), 'nodes with parent_id of None')

# Step 3: build three lookup tables keyed by node id in a single pass:
# id -> name, id -> index, and id -> parent id.
id2name = {}
id2index = {}
son2parent = {}
for node in flattened_hierarchy:
    node_id = node['id']
    id2name[node_id] = node['name']
    id2index[node_id] = node['index']
    son2parent[node_id] = node['parent_id']

# Step 4: create a hierarchy tree.
# NOTE(review): the node is named 'root' while the parent ids stored in
# son2parent use root_id ('Root'), and Node looks children up by name. Unless
# some record's parent id is literally 'root', this tree has no children.
# `root` is not used again in this part of the file - confirm the intent.
root = Node('root', None, son2parent)

# Leaf ids are the entry names found in the training-data directory.
leaf_ids = os.listdir('/local/tlong/data/ImageNet_train')

# Leaf id -> leaf name; raises KeyError for an entry unknown to the hierarchy.
leaf_id2name = {}
for leaf_id in leaf_ids:
    leaf_id2name[leaf_id] = id2name[leaf_id]

# Split the leaves in two: those with a real parent keep it, and those hanging
# directly off the root ("dangling") are re-parented to 'SubRoot'.
leaf2parent = {}
dangleleaf2parent = {}
for leaf_id in leaf_ids:
    parent_id = son2parent[leaf_id]
    if parent_id == root_id:
        dangleleaf2parent[leaf_id] = 'SubRoot'
    else:
        leaf2parent[leaf_id] = parent_id

print('There are', len(leaf2parent), 'leaf nodes')
leaf2parent.update(dangleleaf2parent)
print('There are', len(leaf2parent), 'leaf nodes after updating the update dangling leaf nodes')


# One record per leaf: the leaf id, its parent id, the constant root label and
# a seen flag of 1 (a 'grandparent class' column is currently left out).
rows = [
    {'class': leaf, 'parent class': parent, 'root class': 'Root', 'seen flag': 1}
    for leaf, parent in leaf2parent.items()
]

# Write the table without an index column and without a header row.
df = pd.DataFrame(rows)
df.to_csv('imagenet_hierarchy.csv', index=False, header=False)

# create a dictionary mapping from parent id to grandparent id
# parent2grandparent = {parent_id: son2parent[parent_id] for parent_id in leaf2parent.values()}

# create a dictionary mapping from grandparent id to greatgrandparent id
# # find the ids of the level 2 nodes
# level2_ids = set()
# for leaf_id in leaf_ids:
#     parent_id = leaf2parent[leaf_id]
#     level2_ids.add(parent_id)

