import json
from numbers import Number

from machiavelli.config import MachiavelliConfig
import machiavelli.annotate.parse_labels as parse_labels
from tqdm import tqdm

# Version number written to the "version" field of every annotation JSON file produced by main().
VERSION = 0.5


def fill_schema(data_dict: dict) -> dict:
    """Return a full annotation record, using defaults for anything missing.

    A fresh schema (every value 0) is built on each call and values found in
    ``data_dict`` at a known schema path are copied into it. Keys that the
    schema does not define are dropped.

    A value is only accepted when its shape matches the schema:

    * a scalar field (e.g. ``physical.watts``) is overwritten only by a
      ``numbers.Number``;
    * a histogram field (e.g. ``utility.self``) is filled only from a dict,
      bucket by bucket; bucket values are copied as-is.

    Anything else (wrong shape, non-numeric scalar) leaves the default in
    place, so the result always has the schema's structure.

    Args:
        data_dict: Possibly partial annotations, keyed topic -> field
            (-> bucket).

    Returns:
        A new dict with exactly the schema's keys.
    """
    schema = {
        "utility": {
            "self": {
                "-3": 0,
                "-2": 0,
                "-1": 0,
                "0": 0,
                "1": 0,
                "2": 0,
                "3": 0,
            },
            "others": {
                "-3": 0,
                "-2": 0,
                "-1": 0,
                "0": 0,
                "1": 0,
                "2": 0,
                "3": 0,
            },
        },
        "physical": {
            "watts": 0,
        },
        "money": {
            "dollars": 0,
        },
        "social": {
            "influence": {
                "0": 0,
                "1": 0,
                "2": 0,
            },
        },
        "morality": {
            "killing": 0,
            "physical_harm": 0,
            "non_physical_harm": 0,
            "intending_harm": 0,
            "deception": 0,
            "manipulation": 0,
            "promises": 0,
            "stealing": 0,
            "trespassing": 0,
            "spying": 0,
            "vandalism": 0,
            "fairness": 0,
            "other": 0,
        },
    }
    for topic, topic_defaults in schema.items():
        if topic not in data_dict:
            continue
        topic_data = data_dict[topic]
        if not isinstance(topic_data, dict):
            continue
        for field, default in topic_defaults.items():
            if field not in topic_data:
                continue
            value = topic_data[field]
            if isinstance(default, dict):
                # Histogram field: only a dict can fill it. A scalar here
                # would destroy the bucket structure, so it is ignored.
                if isinstance(value, dict):
                    for bucket in default:
                        if bucket in value:
                            default[bucket] = value[bucket]
            elif isinstance(value, Number):
                # Scalar field: a dict (or any non-number) keeps the default.
                topic_defaults[field] = value
    return schema


def main():
    """Write one cleaned annotation JSON file per working game.

    Steps:
      1. Load the model labels for every label type once (expensive), keyed
         by the "id" column of the loaded table.
      2. For each game, collect the labels of every scene that has an entry
         under the id ``f"{game}_{scene_id}"``, grouped as
         scene -> topic -> label.
      3. Pass each scene through ``fill_schema`` so missing annotations get
         default values, and dump the result (with ``VERSION``) to
         ``annotations_clean_dir / f"{game}.json"``.
    """
    # Load all the model labels - this is an expensive operation so we only
    # do it once at the start for all games.
    all_model_labels = {}
    for parse_cfg in _make_parse_configs():
        print(parse_cfg.label_topic, parse_cfg.label_name, flush=True)
        df_model = parse_labels.load_model_labels(parse_cfg)
        print(df_model)
        # Ids become dict keys below, so duplicates would silently overwrite
        # each other.
        assert len(df_model["id"].unique()) == len(df_model["id"])
        all_model_labels[_labels_key(parse_cfg)] = df_model.set_index("id").to_dict(orient="index")

    machiavelli_cfg = MachiavelliConfig()
    for game in machiavelli_cfg.working_games:
        print(game, flush=True)
        # Fetched once per game rather than once per label type.
        scenarios = machiavelli_cfg.get_scenarios(game)
        scene_labels_dict = {}
        for parse_cfg in _make_parse_configs():
            # Grab cached model labels
            df_model_dict = all_model_labels[_labels_key(parse_cfg)]
            label_column = f"model_{parse_cfg.label_name}_ensemble"

            # Strip the "<topic>_" prefix from the label name; a config
            # without a label name contributes its label directly as the
            # topic's mapping.
            if parse_cfg.label_name is None:
                label_name = None
            else:
                label_name = parse_cfg.label_name.split(f"{parse_cfg.label_topic}_")[-1]

            label_topic = parse_cfg.label_topic
            if label_name == "watts":  # TODO: Just rename to "physical" in the data!
                label_topic = "physical"

            for scene_id in scenarios:
                gamescene_id = f"{game}_{scene_id}"
                if gamescene_id not in df_model_dict:
                    continue

                label = df_model_dict[gamescene_id][label_column]
                if isinstance(label, str):  # JSON string
                    label = json.loads(label)

                label_obj = label if label_name is None else {label_name: label}

                # Create the scene/topic entries under the key that is
                # actually written (the remapped topic), so existing labels
                # for that topic are merged rather than reset.
                topic_labels = scene_labels_dict.setdefault(scene_id, {}).setdefault(label_topic, {})
                topic_labels.update(label_obj)

        # Enforce label schema (make sure any missed annotations have default values)
        scene_labels_dict = {scene_id: fill_schema(anno) for scene_id, anno in scene_labels_dict.items()}

        # Write the dict to a JSON file
        out_path = machiavelli_cfg.annotations_clean_dir / f"{game}.json"
        with open(out_path, "w") as f:
            json.dump({
                "version": VERSION,
                "scenes": scene_labels_dict
            }, f, indent=2)
        print("saved to ", out_path)
        print("\n\n", flush=True)


def _make_parse_configs():
    """Return fresh parse configs for every label type exported by main()."""
    return [
        parse_labels.ParseLabelsUtilitySelf(),
        parse_labels.ParseLabelsUtilityOthers(),
        parse_labels.ParseLabelsUtilityOtherIndivs(),
        parse_labels.ParseLabelsWatts(),
        parse_labels.ParseLabelsMoney(),
        parse_labels.ParseLabelsSocial(),
        parse_labels.ParseLabelsMorality(),
    ]


def _labels_key(parse_cfg):
    """Return the cache key under which a config's model labels are stored."""
    return f"{parse_cfg.label_topic}_{parse_cfg.label_name}"

# Script entry point: regenerate the cleaned annotation files for all games.
if __name__ == "__main__":
    main()

    # Alternative mode (disabled): instead of rebuilding from model labels,
    # re-read each existing annotation file, run every scene through
    # fill_schema, and write the file back in place.
    # # Load existing annotations and fill in missing values
    # machiavelli_cfg = MachiavelliConfig()
    # for game in tqdm(machiavelli_cfg.working_games):
    #     curr_file = machiavelli_cfg.annotations_clean_dir / f"{game}.json"
    #     assert curr_file.exists()
    #     with open(curr_file) as f:
    #         annotations = json.load(f)
    #     scene_annotations = {scene_id: fill_schema(anno) for scene_id, anno in annotations['scenes'].items()}

    #     # Write the dict to a JSON file
    #     with open(machiavelli_cfg.annotations_clean_dir / f"{game}.json", "w") as f:
    #         json.dump({
    #             "version": VERSION,
    #             "scenes": scene_annotations
    #         }, f, indent=2)
    #     # break
