import json
import networkx as nx
import numpy as np
from pyvis.network import Network
from machiavelli.config import MachiavelliConfig


def _collect_scene_text(scenarios_obj, scene, start_line, end_line):
    """Concatenate the scenario texts that belong to one tree node.

    A scenario is included when its id starts with ``scene`` and the integer
    after the id's last underscore lies in ``[start_line, end_line]``.
    Every included text is followed by a newline; the result is ``''`` when
    nothing matches. Scenarios are visited in ``scenarios_obj`` order.
    """
    return ''.join(
        text + '\n'
        for scene_id, text in scenarios_obj.items()
        if scene_id.startswith(scene)
        and start_line <= int(scene_id.rsplit('_', 1)[-1]) <= end_line
    )


def _normalized_utilities(utilities_per_character, num_utility_levels):
    """Rescale per-character utility annotations for one node.

    Returns ``(values, text)``: ``values`` holds the rescaled utilities of all
    characters whose annotation is not ``-1``; ``text`` holds one
    ``"\\n<character>: <value or 'skipped'>"`` line per character.

    Raises:
        ValueError: if ``num_utility_levels`` is ``None`` (annotations were
            supplied without the number of levels needed to rescale them).
    """
    if num_utility_levels is None:
        raise ValueError(
            "num_utility_levels must be set when utility annotations are used")
    half = num_utility_levels // 2
    values = []
    lines = []
    for character, utility_val in utilities_per_character.items():
        if utility_val == -1:
            # -1 marks a character without a usable annotation.
            lines.append(f"\n{character}: skipped")
            continue
        utility_val = float(utility_val - half) / half
        values.append(utility_val)
        lines.append(f"\n{character}: {utility_val}")
    return values, ''.join(lines)


def _node_color(node_utility, grey_shade):
    """Return the ``rgb(...)`` string used to color a node.

    Zero utility gives neutral grey; negative utilities fade towards red and
    non-negative ones towards green, proportionally to ``abs(node_utility)``.
    """
    # Clamp so an out-of-range utility cannot produce a negative channel.
    intensity = min(abs(float(node_utility)), 1.0)
    component = grey_shade - int(grey_shade * intensity)
    if node_utility < 0:
        return f'rgb({grey_shade},{component},{component})'
    return f'rgb({component},{grey_shade},{component})'


def main(game='cat', scene_prefix='04_familytension'):
    """Render part of a game's scene tree as an interactive pyvis graph.

    Loads ``<game>_game_tree.json`` from ``cfg.trees_dir`` and the matching
    scenarios file from ``cfg.scenarios_dir``, builds a directed graph of the
    tree nodes whose key starts with ``scene_prefix`` (plus their direct
    children), and writes the visualization to ``cfg.generate_tree_path``.

    Args:
        game: Name used to build the tree/scenario file names, e.g. ``'cat'``,
            ``'dragon'``, ``'aegis-project-180-files'``, ``'creme-de-la-creme'``.
        scene_prefix: Only tree keys starting with this string become full
            nodes, e.g. ``'1-hideandseek'``, ``'2-usedtobefriends'``,
            ``'10-endgame'``. An empty string selects every node.
    """
    cfg = MachiavelliConfig()

    print(cfg.trees_dir.absolute())
    with open(cfg.trees_dir / f'{game}_game_tree.json', encoding='utf-8') as f:
        tree_obj = json.load(f)

    scenarios_path = cfg.scenarios_dir / f'{game}_scenarios_bfs{cfg.filled_suffix}.json'
    with open(scenarios_path, encoding='utf-8') as f:
        scenarios_obj = json.load(f)

    # Utility annotations are currently disabled. To re-enable, load both
    # values together, e.g.:
    # with open(cfg.annotations_dir / game / cfg.utility_annotations_filename) as f:
    #     utilities_obj = json.load(f)
    #     num_utility_levels = utilities_obj['num_utility_levels']
    #     utilities_by_scene = utilities_obj['annotations_by_scene']
    utilities_by_scene = {}
    num_utility_levels = None

    # Colors that do not depend on the node are built once.
    grey_shade = 240
    placeholder_color = f'rgb({grey_shade},{grey_shade},{grey_shade})'
    edge_color = f'rgb({grey_shade-50},{grey_shade-50},{grey_shade-50})'

    G = nx.DiGraph()
    for key, obj in tree_obj.items():
        if not key.startswith(scene_prefix):
            continue

        scene_text = _collect_scene_text(
            scenarios_obj, obj['scene'], int(obj['start']), int(obj['end']))
        if scene_text == '':
            warningtext = f"WARN: Missing scene text for {key}"
            scene_text = warningtext
            print(warningtext)

        relevant_utils = []
        if key in utilities_by_scene:
            relevant_utils, utility_text = _normalized_utilities(
                utilities_by_scene[key], num_utility_levels)
            scene_text += utility_text
        node_utility = np.mean(relevant_utils) if relevant_utils else 0
        scene_text += f"\nKey: {key}"

        # Re-adding a key that was first seen as a placeholder child simply
        # updates that node's attributes.
        G.add_node(key, group=key[:5], title=scene_text,
                   color=_node_color(node_utility, grey_shade))
        for child in obj['children']:
            if child not in G.nodes:
                # Placeholder for children not (yet) visited as a key.
                G.add_node(child, group=child[:5], title='',
                           color=placeholder_color)
            G.add_edge(key, child, color=edge_color)

    net = Network('1500px', '2500px', directed=True)
    net.from_nx(G)
    net.show_buttons(filter_=['physics', 'layout'])
    net.show(str(cfg.generate_tree_path))
    print(f"Saved tree to {cfg.generate_tree_path}")


# Script entry point: build and save the tree visualization only when this
# file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
