import json
import numpy as np
import pandas as pd
from machiavelli.openai_helpers import num_tokens

from machiavelli.config import MachiavelliConfig
from tqdm import tqdm

def get_token_counts(samples_per_game=100):
    """Print average character/word/token counts for scenes, completions and prompts.

    For every game in ``MachiavelliConfig().working_games`` this:

    * draws ``samples_per_game`` scene ids (with replacement, seeded with 0 for
      each game) and records the scene text length in characters, words and
      tokens;
    * for each annotation topic whose CSV exists under
      ``annotations_model_dir / topic / "0" / f"{game}.csv"``, records the
      length and token count of up to the first ``samples_per_game`` values of
      the ``explanation`` column;
    * builds that topic's prompt through its ``get_batched_predictions``
      function and records the prompt's length and token count.

    The averages are printed to stdout; nothing is returned.

    Args:
        samples_per_game: Number of scenes sampled per game, and the maximum
            number of annotation rows read per (game, topic) CSV.
    """
    # Function-scope import so this block does not depend on the file header.
    import importlib

    machiavelli_cfg = MachiavelliConfig()
    all_topics = ["money", "morality", "social", "utility", "watts"]

    scene_lengths = []
    scene_words = []
    scene_toks = []
    completion_lengths = {topic: [] for topic in all_topics}
    completion_toks = {topic: [] for topic in all_topics}
    prompt_lengths = {topic: [] for topic in all_topics}
    prompt_toks = {topic: [] for topic in all_topics}

    def _mean_or_nan(values):
        # np.mean([]) also yields nan but emits a RuntimeWarning; a topic with
        # no annotation files legitimately has no samples.
        return np.mean(values) if values else float("nan")

    for game in tqdm(machiavelli_cfg.working_games):
        # Re-seed per game so each game's sample is reproducible on its own.
        np.random.seed(0)
        id_2_scene_text = machiavelli_cfg.get_scenarios(game)
        scene_ids = list(id_2_scene_text.keys())
        # np.random.choice raises ValueError on an empty population.
        if scene_ids:
            for scene_id in np.random.choice(scene_ids, samples_per_game):
                scene_text = id_2_scene_text[scene_id]
                scene_lengths.append(len(scene_text))
                scene_words.append(len(scene_text.split()))
                scene_toks.append(num_tokens(scene_text))

        # Loaded lazily, at most once per game, and only when some topic has
        # annotations for it.
        game_meta = None
        for topic in all_topics:
            p = machiavelli_cfg.annotations_model_dir / topic / "0" / f"{game}.csv"
            if not p.exists():
                continue
            df = pd.read_csv(p)

            # head() tolerates CSVs shorter than samples_per_game (positional
            # indexing would raise IndexError); dropna() skips empty cells,
            # which pandas reads as float NaN and len() cannot handle.
            explanations = df["explanation"].head(samples_per_game).dropna()
            for completion_text in explanations:
                completion_lengths[topic].append(len(completion_text))
                completion_toks[topic].append(num_tokens(completion_text))

            # Get prompt
            if game_meta is None:
                game_meta = machiavelli_cfg.get_game_metadata(game)
            prompt_module = importlib.import_module(
                f"machiavelli.annotate.{topic}.prompt")
            get_batched_predictions = prompt_module.get_batched_predictions
            # Called with a dummy model and no scenes; only the middle element
            # of the returned 3-tuple (used here as the prompt text) is kept.
            # NOTE(review): presumably no model call happens for an empty
            # batch -- confirm against the prompt modules.
            _, prompt, _ = get_batched_predictions(
                model="dummy",
                batch_of_scenes=[],
                game_title=game_meta['name'],
                game_summary=game_meta['short_summary'],
                player_character_name=game_meta['player_character'],
                player_character_aliases=game_meta['player_character_aliases'],
            )
            prompt_lengths[topic].append(len(prompt))
            prompt_toks[topic].append(num_tokens(prompt))

    print("avg_scene_length", _mean_or_nan(scene_lengths))
    print("avg_scene_words", _mean_or_nan(scene_words))
    print("avg_scene_tokens", _mean_or_nan(scene_toks))
    for topic in all_topics:
        print(f"{topic}_avg_completion_length:", _mean_or_nan(completion_lengths[topic]))
        print(f"{topic}_avg_completion_tokens:", _mean_or_nan(completion_toks[topic]))
        print(f"{topic}_avg_prompt_length:", _mean_or_nan(prompt_lengths[topic]))
        print(f"{topic}_avg_prompt_tokens:", _mean_or_nan(prompt_toks[topic]))


def get_label_counts(cfg=None):
    """
    Count and print how many scenes carry a salient annotation of each kind.

    Args:
        cfg: Optional config object providing ``working_games``,
            ``get_annotations(game)``, the ``confidence_threshold``,
            ``watts_threshold`` and ``dollars_threshold`` values, and the
            ``utility_salient_labels`` / ``social_salient_labels`` key lists.
            When omitted, a fresh ``MachiavelliConfig()`` is created, so the
            function no longer depends on a module-level ``cfg`` that only
            exists when this file is run as a script.

    The counts are printed to stdout; nothing is returned. A scene whose
    annotation is missing an expected key is reported as
    ``KeyError <game> <scene>``; categories already counted for that scene
    before the missing key was hit remain counted.

    # Each json file will look like this
    "<SCENE_ID>": {
        "utility": {
            "self": {
                "-3": <float>,
                "-2": <float>,
                "-1": <float>,
                "0": <float>,
                "1": <float>,
                "2": <float>,
                "3": <float>
            },
            "others": {
                "-3": <float>,
                "-2": <float>,
                "-1": <float>,
                "0": <float>,
                "1": <float>,
                "2": <float>,
                "3": <float>
            }
        },
        "physical": {
            "watts": <float>
        },
        "money": {
            "dollars": <float>
        },
        "social": {
            "influence": {
                "0": <float>,
                "1": <float>,
                "2": <float>
            }
        },
        "morality": {
            "killing": <float>,
            "physical_harm": <float>,
            "non_physical_harm": <float>,
            "intending_harm": <float>,
            "deception": <float>,
            "manipulation": <float>,
            "promises": <float>,
            "stealing": <float>,
            "trespassing": <float>,
            "spying": <float>,
            "vandalism": <float>,
            "fairness": <float>,
            "other": <float>
        }
    }
    """
    if cfg is None:
        cfg = MachiavelliConfig()

    # Count the number of salient (highest-level) annotations for each scene
    utility = 0
    physical = 0
    money = 0
    social = 0
    morality = {
        "killing": 0,
        "physical_harm": 0,
        "non_physical_harm": 0,
        "intending_harm": 0,
        "deception": 0,
        "manipulation": 0,
        "promises": 0,
        "stealing": 0,
        "trespassing": 0,
        "spying": 0,
        "vandalism": 0,
        "fairness": 0,
        "other": 0
    }
    total = 0
    for game in tqdm(cfg.working_games):
        game_annotations = cfg.get_annotations(game)
        total += len(game_annotations)
        for scene, labels in game_annotations.items():
            try:
                # The any() calls deliberately take list comprehensions rather
                # than generators: every salient key is looked up, so a
                # missing key is reported even if an earlier one already
                # exceeded the threshold.
                if "utility" in labels:
                    if any([labels["utility"]["self"][k] > cfg.confidence_threshold for k in cfg.utility_salient_labels]) \
                        or any([labels["utility"]["others"][k] > cfg.confidence_threshold for k in cfg.utility_salient_labels]):
                        utility += 1
                if "physical" in labels:
                    if labels["physical"]["watts"] > cfg.watts_threshold:
                        physical += 1
                if "money" in labels:
                    if labels["money"]["dollars"] > cfg.dollars_threshold:
                        money += 1
                if "social" in labels:
                    if any([labels["social"]["influence"][k] > cfg.confidence_threshold for k in cfg.social_salient_labels]):
                        social += 1
                if "morality" in labels:
                    for k, confidence in labels["morality"].items():
                        if confidence > cfg.confidence_threshold:
                            # An unknown morality category raises KeyError
                            # here and is reported below.
                            morality[k] += 1
            except KeyError:
                print("KeyError", game, scene)
    print("utility", utility)
    print("physical", physical)
    print("money", money)
    print("social", social)
    for k in morality:
        print(k, morality[k])
    print("total", total)

if __name__ == "__main__":
    # Script entry point: print dataset-level statistics.
    # `cfg` is bound at module level because get_label_counts() reads it.
    cfg = MachiavelliConfig()
    print("num_games", len(cfg.working_games))

    scene_total = 0
    for game in cfg.working_games:
        scene_total += len(cfg.get_scenarios(game))
    print("total_num_scenes", scene_total)

    # Optional: count how many tokens were used to label the dataset, for
    # cost estimates. Uncomment to run.
    # get_token_counts()

    # Tally how many of the various labels for each topic appear in the dataset.
    get_label_counts()