import logging
import random
import time
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import List, Optional

import pandas as pd
import pyrallis
import wandb

from machiavelli.config import MachiavelliConfig
from machiavelli.annotate.check_status import get_salient_scene_ids
from machiavelli.utils import get_slurm_id, random_seed


@dataclass
class LabelConfig:
    """Command-line configuration for one model-labelling run on a single topic."""

    # Annotation topic: "money" / "morality" / "social" / "utility" / "watts"
    topic: str
    # Game to label; "all" selects every working game
    game: str = "all"
    random_seed: int = 0
    # Number of scenes sent to the model per request
    batch_size: int = 10
    # Attempts per batch before the batch is flagged with `error_label`
    max_retries: int = 5
    eval: bool = False
    # Upper bound on the number of scenes labelled per game in one run
    max_scenes_per_game: int = 1000000

    # Model used for labelling. Earlier alternatives:
    #   "dummy"            - dummy model for testing
    #   "text-davinci-002" - GPT3
    model: str = "text-alpha-002-latest" # XYZ

    # Log file shared by all games; defaults to `topic/output/common.log`
    common_logfile: Optional[Path] = None
    # Sentinel prediction meaning "this label must be re-done"
    error_label: int = -100
    # Time to wait before retrying a failed request
    retry_interval_sec: int = 150

    # Restrict labelling to the scenes that have gold labels
    gold_only: bool = False
    # Restrict labelling to scenes marked positive so far (for model+ labelling)
    salient_only: bool = False

    # --- Run metadata ---
    # Experiment group name; defaults to `topic`
    group: Optional[str] = None
    # Tags attached to this run
    tags: List[str] = field(default_factory=list)
    # Type of job
    job_type: str = "label"

    def __post_init__(self):
        """Fill in the defaults that are derived from ``topic``."""
        self.group = self.topic if self.group is None else self.group
        if self.common_logfile is None:
            module_dir = Path(__file__).parent
            self.common_logfile = module_dir / f"{self.topic}/output/common.log"


@pyrallis.wrap()
def main(cfg: LabelConfig):
    """Label game scenes for ``cfg.topic`` with a model and save one CSV per game.

    For every selected game the scene ids to annotate are collected (optionally
    restricted to gold-labelled or salient scenes), ids already present in the
    game's output CSV are skipped (rows whose ``pred`` equals
    ``cfg.error_label`` are dropped and redone), and the remainder is shuffled
    with ``cfg.random_seed`` and capped at ``cfg.max_scenes_per_game``. Scenes
    are sent to the topic's ``get_batched_predictions`` in batches of
    ``cfg.batch_size``; the CSV at ``<this dir>/<topic>/output/<game>.csv`` is
    rewritten after every batch so an interrupted run can be resumed.

    Args:
        cfg: Run configuration, parsed from the command line by pyrallis.

    Raises:
        ValueError: If ``cfg.batch_size`` is smaller than 1.
    """
    # Function-scope import: only needed to load the topic-specific prompt module.
    import importlib

    # A non-positive batch size would send empty batches to the model and
    # could never make progress through the scene list.
    if cfg.batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {cfg.batch_size}")

    # Equivalent to:
    #   from machiavelli.annotate.{cfg.topic}.prompt import get_batched_predictions
    prompt_module = importlib.import_module(f"machiavelli.annotate.{cfg.topic}.prompt")
    get_batched_predictions = prompt_module.get_batched_predictions

    cfg.common_logfile.parent.mkdir(parents=True, exist_ok=True)
    logging.basicConfig(
        filename=cfg.common_logfile,
        level=logging.ERROR,
        format='%(asctime)s %(message)s', datefmt='%m/%d/%Y %I:%M:%S %p'
    )

    wandb.init(
        project="machiavelli",
        job_type=cfg.job_type,
        group=cfg.group,
        tags=cfg.tags,
        config=asdict(cfg),
        # mode='disabled',
    )
    # If using SLURM, take note of job id
    slurm_job_id = get_slurm_id()
    if slurm_job_id:
        wandb.run.name = f"{slurm_job_id}"
        wandb.config.slurm_job_id = slurm_job_id

    machiavelli_cfg = MachiavelliConfig()

    games = sorted(machiavelli_cfg.working_games) if cfg.game == "all" else [cfg.game]

    for game in games:
        print("Running annotation for game:", game)

        # Get set of scenes to annotate
        game_meta = machiavelli_cfg.get_game_metadata(game)
        scenes_obj = machiavelli_cfg.get_scenarios(game)
        print("len(scenes_obj)", len(scenes_obj))
        scenes_to_annotate = set(scenes_obj.keys())
        if cfg.gold_only:
            # Only run on the subset of scenes that we have gold labels for
            labelled_scenes_df = pd.read_csv(machiavelli_cfg.annotations_expert_dir / f"{cfg.topic}.csv")
            scenes_to_annotate = labelled_scenes_df.loc[
                labelled_scenes_df['game'] == game, 'scene_id'
            ].tolist()
        if cfg.salient_only:
            print("Labelling only salient scenes")
            # Only run on scenes that have been labelled as salient (positive labels).
            # NOTE(review): when gold_only is also set, this replaces (rather than
            # intersects with) the gold subset - confirm that is intended.
            scenes_to_annotate = get_salient_scene_ids(game, cfg.topic)
        print("initial len(scenes_to_annotate)", len(scenes_to_annotate))

        # Check if we are resuming a run
        game_results_path = Path(__file__).parent / f"{cfg.topic}/output/{game}.csv"
        game_results_path.parent.mkdir(exist_ok=True, parents=True)
        game_results = []
        scenes_already_annotated = set()
        if game_results_path.exists():
            print("Resuming existing run")
            previous_df = pd.read_csv(game_results_path)
            # Keep existing results, except the labels that need to be re-done
            needs_no_redo = previous_df['pred'] != cfg.error_label
            game_results = previous_df[needs_no_redo].to_dict("records")
            scenes_already_annotated = {row['scene_id'] for row in game_results}
        print("len(scenes_already_annotated)", len(scenes_already_annotated))

        # Sort before the seeded shuffle so the order does not depend on set
        # iteration order.
        random_seed(cfg.random_seed)
        scenes_to_annotate = sorted(set(scenes_to_annotate) - scenes_already_annotated)
        random.shuffle(scenes_to_annotate)
        scenes_to_annotate = scenes_to_annotate[:cfg.max_scenes_per_game]

        print("final len(scenes_to_annotate)", len(scenes_to_annotate))
        print("batch_size", cfg.batch_size)

        tstart = time.time()
        idx = 0
        while idx < len(scenes_to_annotate):
            batch_scene_ids = scenes_to_annotate[idx: idx + cfg.batch_size]
            print("Running batch:", batch_scene_ids)

            # Ids missing from scenes_obj (possible with the gold/salient
            # subsets) are sent as empty text.
            scene_texts = [(scenes_obj[scene_id] if scene_id in scenes_obj else "") for scene_id in batch_scene_ids]

            # Batch prompting and prediction.
            # NOTE(review): cfg.retry_interval_sec is not applied here, so
            # retries are issued back-to-back - confirm whether a wait is wanted.
            answers = []
            for n_attempts_remaining in range(cfg.max_retries, 0, -1):
                answers, _prompt, _response = get_batched_predictions(
                    model=cfg.model,
                    batch_of_scenes=scene_texts,
                    game_title=game_meta['name'],
                    game_summary=game_meta['short_summary'],
                    player_character_name=game_meta['player_character'],
                    player_character_aliases=game_meta['player_character_aliases'],
                )
                if len(answers) > 0:
                    print(f"Obtained {len(answers)}/{len(scene_texts)} predictions")
                    break
                print("No answers found. Retrying...")
                print(f"{n_attempts_remaining - 1} attempts remaining")
            if len(answers) == 0:
                err_msg = f"Failed to get answers after {cfg.max_retries} retries."
                print(err_msg)
                logging.error(f"{slurm_job_id}: {game} {err_msg}")
                # Flag only the first scene of the batch; the remaining scenes
                # are retried as part of the next batch.
                answers = [(cfg.error_label, err_msg)]

            # zip() stops at the shorter sequence, so surplus answers are
            # dropped and scenes without an answer stay queued.
            recorded = list(zip(batch_scene_ids, answers))
            for scene_id, (pred, explanation) in recorded:
                game_results.append({
                    "game": game,
                    "scene_id": scene_id,
                    "pred": pred,
                    "explanation": explanation,
                })

            # Rewrite the whole file after each batch so the run is resumable
            print("Saving results to", game_results_path)
            pd.DataFrame(game_results).to_csv(game_results_path, index=False)

            # Advance by the number of scenes actually recorded, not by
            # len(answers): an over-long answer list must not skip scenes.
            idx += len(recorded)
            print(f"Progress: {idx}/{len(scenes_to_annotate)} | Elapsed time: {(time.time() - tstart)/60.:.2f} minutes", flush=True)
            print("\n\n\n\n\n")
        print(f"DONE {game}")
        print(f"Number of annotations: {len(game_results)} / {len(scenes_obj)}")
        print(f"Number of annotations: {len(game_results)} / {len(scenes_obj)}")

# Script entry point: main() is wrapped by pyrallis, so it is called without
# arguments here and receives its LabelConfig from the wrapper.
if __name__ == "__main__":
    main()