import json
import time
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional

import pandas as pd
import pyrallis

from machiavelli.config import MachiavelliConfig


def _scene_is_salient(labels, topic):
    """
    Decide whether one scene's annotation dict counts as salient for `topic`.

    Lets KeyError propagate when an expected annotation field is missing so
    the caller can report the offending scene. Unrecognised topics are never
    salient.
    """
    if topic == "utility":
        # Sum the values stored under the extreme buckets for both targets.
        extremes = ['-3', '3']
        own = sum([labels['utility']['self'][bucket] for bucket in extremes])
        others = sum([labels['utility']['others'][bucket] for bucket in extremes])
        return own + others > 0
    if topic == "watts":
        return labels['physical']['watts'] > 10000
    if topic == "money":
        return labels['money']['dollars'] > 10000
    if topic == "social":
        strongest = ['2']
        return sum([labels['social']['influence'][bucket] for bucket in strongest]) > 0
    if topic == "morality":
        # A single positive morality entry is enough.
        return any(score > 0 for _, score in labels['morality'].items())
    return False


def get_salient_scene_ids(game, topic):
    """
    Return the ids of the scenes in `game` whose existing annotations are
    salient (non-negligible) for `topic`.

    Salience rules, per topic:
      - "utility":  the '-3' / '3' entries for 'self' and 'others' sum to > 0
      - "watts":    physical watts above 10000
      - "money":    money dollars above 10000
      - "social":   the '2' influence entry is > 0
      - "morality": any morality entry is > 0

    Scenes whose annotations lack an expected key are reported on stdout and
    left out of the result. Ids are returned in annotation iteration order.
    """
    annotations = MachiavelliConfig().get_annotations(game)
    picked = []
    for scene_id in annotations:
        try:
            keep = _scene_is_salient(annotations[scene_id], topic)
        except KeyError:
            print("KeyError", game, scene_id)
            continue
        if keep:
            picked.append(scene_id)
    return picked

def get_game_completion_rate(topic, game, invalid_label_flag=-100, whitelist_scenes: Optional[List[str]] = None):
    """
    Measure how many scenes of `game` already have a valid label for `topic`.

    Labels are read from `<this file's directory>/<topic>/output/<game>.csv`,
    which must provide `scene_id` and `pred` columns. A row counts as valid
    when its `pred` differs from `invalid_label_flag`.

    Args:
        topic: Label type; selects the output sub-directory.
        game: Game name; selects the CSV file.
        invalid_label_flag: `pred` value marking a row as not validly labelled.
        whitelist_scenes: If given, only these scene ids are considered;
            otherwise every scene returned by
            `MachiavelliConfig().get_scenarios(game)` is considered.

    Returns:
        Tuple `(completion_rate, num_valid_labels, num_scenes)`. The rate is 0
        when there are no scenes to label, or when the CSV is missing or empty.
    """
    # datafile = machiavelli_cfg.annotations_model_dir / topic / "1" / f"{game}.csv"
    datafile = Path(__file__).parent / f"{topic}/output/{game}.csv"

    # The project config is only needed to enumerate scenes, so it is built
    # lazily when the caller did not supply an explicit whitelist.
    if whitelist_scenes is None:
        scenes_to_label = set(MachiavelliConfig().get_scenarios(game).keys())
    else:
        scenes_to_label = set(whitelist_scenes)
    num_scenes = len(scenes_to_label)

    # Nothing to label (e.g. an empty whitelist): bail out before dividing,
    # otherwise the rate computation below raises ZeroDivisionError.
    if num_scenes == 0 or not datafile.exists():
        return 0, 0, num_scenes

    try:
        df = pd.read_csv(datafile)
    except pd.errors.EmptyDataError:
        print(f"Empty data file: {datafile}")
        return 0, 0, num_scenes

    # Boolean mask instead of a row-by-row iterrows() scan.
    labelled = set(df.loc[df["pred"] != invalid_label_flag, "scene_id"])
    valid_labels = labelled.intersection(scenes_to_label)
    return len(valid_labels) / num_scenes, len(valid_labels), num_scenes

@dataclass
class Config:
    """Options for the progress-check script; passed to `main` via `pyrallis.wrap`."""
    topic: str = "money" # "money" / "morality" / "social" / "utility" / "watts"
    salient_only: bool = False # Only consider scenes with salient labels

@pyrallis.wrap()
def main(cfg: Config):
    """
    Print per-game and overall labelling progress for `cfg.topic`.

    For every game in `MachiavelliConfig().working_games` (sorted), prints the
    number of validly labelled scenes out of the scenes considered. When
    `cfg.salient_only` is set, only scenes returned by `get_salient_scene_ids`
    are considered. Finishes with the overall ratio and the local timestamp.
    """
    topic = cfg.topic
    print(f"Checking progress for label type: {topic}\n")

    machiavelli_cfg = MachiavelliConfig()
    games = sorted(machiavelli_cfg.working_games)
    progress = 0
    total = 0
    for game in games:
        # None means "consider every scene of the game".
        whitelist = get_salient_scene_ids(game, topic) if cfg.salient_only else None
        _, num_valid_labels, num_game_scenes = get_game_completion_rate(topic, game, whitelist_scenes=whitelist)
        progress += num_valid_labels
        total += num_game_scenes
        print(f"{game:50}:\t{num_valid_labels} / {num_game_scenes}")

    print("\nOverall:")
    # Guard the ratio: with no games, or no salient scenes at all, total is 0
    # and an unguarded division would raise ZeroDivisionError.
    overall = float(progress) / total if total else 0.0
    print(f"Progress {overall} ({progress} / {total})")
    print("time:", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

# Run the progress report only when this file is executed as a script;
# main() takes no arguments here because pyrallis.wrap supplies the Config.
if __name__ == "__main__":
    main()