from pathlib import Path

from analysis.strategy_analysis import get_strategy_df
from episode.decoders import FixedStrategyDecoder
from episode.encoders import (
    FixedStrategyEncoder,
)
from episode.overseers import OverseerType, get_overseer_by_type
from episode.paraphrasers import (
    ParaphraserType,
    get_task_paraphraser,
    get_task_paraphraser_by_model,
)
from episode.payloads import IntegerPayload
from episode.tasks import AquaRatTask, HumanEvalTask, MMLUTask
from experiment_utils import (
    ExperimentConfig,
    RedTeamConfig,
)
from logger import logger
from models.GPT import GPTModelType
from run_experiment import run_experiment
from run_experiment_utils import (
    get_best_strategies_old,
)
from utils import generate_timestamped_results_dir


def run_best_strats_for_paraphraser(
    task,
    payload_type,
    paraphraser,
    overseer,
    output_results_dir,
    strategies,
    *,
    paraphraser_type=None,
):
    """Run one experiment in which every row of ``strategies`` is a red team.

    Each row is turned into a ``RedTeamConfig`` pairing a
    ``FixedStrategyEncoder`` with a ``FixedStrategyDecoder`` (both on
    ``GPTModelType.GPT4O``). All red teams are then run against a single
    ``ExperimentConfig`` built from the given task, payload, paraphraser and
    overseer, and the per-team results are logged.

    Args:
        task: Task passed to the encoders and to ``ExperimentConfig``.
        payload_type: Payload passed to the encoders, decoders and
            ``ExperimentConfig``.
        paraphraser: Paraphraser used by the experiment. It is passed through
            to ``ExperimentConfig`` unchanged.
        overseer: Overseer passed to ``ExperimentConfig``.
        output_results_dir: Base directory (must support the ``/`` operator);
            results go to ``<output_results_dir>/<label>/icl_training_runs``.
        strategies: Object exposing ``iterrows()`` whose rows have
            ``"encoder_strategy"`` and ``"decoder_strategy"`` entries
            (presumably a pandas DataFrame).
        paraphraser_type: Optional object whose ``.value`` is used as the
            output sub-directory name and in log messages. When omitted, a
            module-level ``paraphraser_type`` is used if one exists, otherwise
            the class name of ``paraphraser``.

    Returns:
        Tuple ``(red_teams, experiment_config)``: the dict of red-team configs
        keyed by ``"red_team_<row index>"`` and the experiment config used.
    """
    if paraphraser_type is None:
        # Backward compatibility: older callers relied on a module-level
        # ``paraphraser_type`` global instead of passing it in.
        paraphraser_type = globals().get("paraphraser_type")
    run_label = (
        paraphraser_type.value
        if paraphraser_type is not None
        else type(paraphraser).__name__
    )

    # Create the red teams & experiment config
    red_teams = {
        f"red_team_{i}": RedTeamConfig(
            FixedStrategyEncoder(
                strategy_row["encoder_strategy"], GPTModelType.GPT4O, task, payload_type
            ),
            FixedStrategyDecoder(
                strategy_row["decoder_strategy"], GPTModelType.GPT4O, payload_type
            ),
        )
        for i, strategy_row in strategies.iterrows()
    }
    experiment_config = ExperimentConfig(
        task,
        payload_type,
        # Use the paraphraser supplied by the caller rather than rebuilding one
        # from globals, so model-specific paraphrasers are honoured.
        paraphraser=paraphraser,
        overseer=overseer,
        seed=0,
    )

    # Run the experiment on the validation set using the best strategies
    results_by_redteam = run_experiment(
        experiment_config,
        red_teams,
        (output_results_dir / run_label / "icl_training_runs"),
        thread=True,
    )
    # TODO: Ensure results are saved to file even if run_experiment fails
    for red_team, results in results_by_redteam.items():
        logger.info(
            f"\n--------------------------------------------\nResults for {red_team}:"
        )
        logger.info(f"\n{results.to_string()}")
        logger.info(
            f"\nCompleted run for {run_label} on {red_team}\n============================================\n"
        )
    logger.info(
        f"\nCompleted run for {run_label}\n============================================\n============================================\n"
    )
    return red_teams, experiment_config


if __name__ == "__main__":
    # Script entry point: re-run the best strategies found in an earlier
    # results directory against each paraphraser type listed below.

    # Set strategy parameters
    existing_results_dir = "2024-09-13_icl_AquaRatTask_MultiBit_Combined"
    input_dir = Path("results", existing_results_dir)
    strategies_per_level = 3

    # Set task & experiment parameters
    payload_type = IntegerPayload(0, 1)
    num_samples = 10
    # Only one task is used per run. To target a different task, replace this
    # with HumanEvalTask(num_samples) or MMLUTask(num_samples, n_sentences=None).
    task = AquaRatTask(num_samples)
    experiment_name = f"icl_{task.__class__.__name__}_ppbeststrats"
    output_results_dir = generate_timestamped_results_dir(experiment_name)

    # Set paraphraser and overseer
    overseer = get_overseer_by_type(
        GPTModelType.GPT35TURBO,
        payload_type,
        task,
        overseer_type=OverseerType.PAYLOAD_DETECTION,
        save_dir=output_results_dir,
    )
    paraphraser_kwargs = {
        "do_remove_docstrings": True,
        "do_rename_variables": True,
    }
    paraphraser_types = [
        ParaphraserType.NoParaphraser,
        ParaphraserType.StandardParaphraser,
        ParaphraserType.ContextAwareParaphraser,
    ]

    # Get dataframe of existing strategies
    strategy_df = get_strategy_df(input_dir)

    # Run the best strategies for each paraphraser type.
    # NOTE: keep the loop variable named ``paraphraser_type`` at module scope;
    # run_best_strats_for_paraphraser looks it up as a global to name the
    # output sub-directory and to label its log messages.
    for paraphraser_type in paraphraser_types:
        strategies = get_best_strategies_old(
            strategy_df, strategies_per_level, paraphraser_type
        )
        paraphraser = get_task_paraphraser_by_model(
            task, paraphraser_type, GPTModelType.GPT35TURBO, **paraphraser_kwargs
        )
        run_best_strats_for_paraphraser(
            task,
            payload_type,
            paraphraser,
            overseer,
            output_results_dir,
            strategies,
        )

    logger.info("All done!")
