import json
from pathlib import Path

from in_context_learning.icl_learner import ICLLearner
from in_context_learning.multi_part_icl_learner import (
    MultiPartICLLearner,
    MultiPartLearnerConfig,
)

from episode.overseers import OverseerType, get_overseer_by_type
from episode.paraphrasers import (
    CodingParaphraser,
    ParaphraserType,
    get_task_paraphraser,
)
from episode.payloads import MultiPartPayload
from episode.tasks import (
    AquaRatTask,
    HumanEvalTask,
    MMLUTask,
    QATaskSimple,
    Task,
    TaskType,
)
from learning.run_icl_utils import (
    get_hparams,
    run_multipart_icl_training,
    run_parameterized_icl,
)
from logger import logger
from models.claude import ClaudeModelType
from models.GPT import GPTModelType
from run_experiment import ExperimentConfig
from serde import deserialize_from_json
from utils import generate_timestamped_results_dir


def load_learner(file_path: Path) -> ICLLearner:
    """Rebuild an ``ICLLearner`` from a JSON dump stored on disk.

    Args:
        file_path: Location of the JSON file containing the serialized learner.

    Returns:
        Whatever ``deserialize_from_json`` produces for ``ICLLearner`` from the
        parsed file contents.
    """
    raw_text = file_path.read_text()
    parsed_dump = json.loads(raw_text)
    return deserialize_from_json(ICLLearner, parsed_dump)


def run_simple_icl(task: Task):
    """Run one multi-part ICL training session on ``task``.

    The function obtains a results directory from
    ``generate_timestamped_results_dir`` (named after the task's class) and
    makes sure it exists, then assembles a paraphraser, a
    ``MultiPartLearnerConfig``, a ``MultiPartPayload``, a payload-detection
    overseer and an ``ExperimentConfig``. A fresh ``MultiPartICLLearner`` built
    from these is passed to ``run_multipart_icl_training``; the results
    directory is logged once that call returns.

    Args:
        task: Task to train on. Its ``TASK_TYPE`` decides which paraphraser is
            used.
    """
    # --- Results directory ---------------------------------------------
    results_dir = generate_timestamped_results_dir(
        f"icl_{task.__class__.__name__}_MB2"
    )
    # Alternative: write into a fixed directory instead.
    # base_dir = Path("results", "24-09-14_mb-ar-no")
    # save_dir = base_dir / "claude-0"
    results_dir.mkdir(exist_ok=True, parents=True)

    # --- Paraphraser ---------------------------------------------------
    # Coding tasks use the dedicated coding paraphraser; every other task
    # type goes through the generic task-paraphraser factory.
    is_coding_task = task.TASK_TYPE == TaskType.CODING
    if is_coding_task:
        task_paraphraser = CodingParaphraser(True, True)
    else:
        task_paraphraser = get_task_paraphraser(
            task,
            ParaphraserType.ContextAwareParaphraser,
            model=ClaudeModelType.CLAUDE_3_SONNET,
        )

    # --- Learner configuration -----------------------------------------
    # Alternative GPT-based setup:
    #   learner_model_type=GPTModelType.GPT4TURBO,
    #   encoder_model_type=GPTModelType.GPT4O,
    #   decoder_model_type=GPTModelType.GPT4O,
    config = MultiPartLearnerConfig(
        learner_model_type=ClaudeModelType.CLAUDE_35_SONNET,
        encoder_model_type=ClaudeModelType.CLAUDE_35_SONNET,
        decoder_model_type=ClaudeModelType.CLAUDE_35_SONNET,
        n_episodes=50,
        try_improve_last_strategy=False,
    )
    multi_part_payload = MultiPartPayload([2])

    # --- Overseer ------------------------------------------------------
    # Two candidate overseer models are listed; the later assignment
    # overrides the earlier one, so the last line is the model in effect.
    overseer_model = GPTModelType.GPT35TURBO
    overseer_model = ClaudeModelType.CLAUDE_3_SONNET
    payload_overseer = get_overseer_by_type(
        model_type=overseer_model,
        payload=multi_part_payload,
        task=task,
        overseer_type=OverseerType.PAYLOAD_DETECTION,
        save_dir=results_dir,
    )

    # --- Experiment + learner ------------------------------------------
    experiment = ExperimentConfig(
        task,
        multi_part_payload,
        task_paraphraser,
        overseer=payload_overseer,
        seed=0,
    )
    icl_learner = MultiPartICLLearner(config, experiment, results_dir)

    # Alternative: resume from a previously dumped learner.
    # base_learner = load_learner(
    #     Path(
    #         "/home/user/steganography/results/24-09-14_mb-ar-no/run-0/mb_2-2/2024-09-14_10-14-32_icl-agent-dump.json"
    #     )
    # )
    # base_learner.base_save_dir = save_dir
    # learner = base_learner
    # # learner.config.try_improve_last_strategy = True
    # learner.save_dir = save_dir / str(learner.experiment_config.payload_type)

    # --- Training ------------------------------------------------------
    run_multipart_icl_training(icl_learner, results_dir)

    logger.info(f"\nAll done! Results saved to: {results_dir}\n")


# Main run script
if __name__ == "__main__":
    # Number of samples, passed both to the task constructor and to
    # get_hparams below.
    num_samples = 20

    # Candidate tasks, wrapped in zero-argument factories so that only the
    # selected task is actually constructed. Previously every task was
    # instantiated in turn and all but the last were discarded, which did
    # needless work and let a failure in an unused task abort the run.
    task_factories = {
        "qa_simple": lambda: QATaskSimple(num_samples),
        "mmlu": lambda: MMLUTask(num_samples, n_sentences=None),
        "human_eval": lambda: HumanEvalTask(num_samples),
        "aqua_rat": lambda: AquaRatTask(num_samples),
    }
    # Change this key to switch tasks.
    selected_task = "aqua_rat"
    task = task_factories[selected_task]()

    # Run required ICL training type
    # run_simple_icl(task)
    run_parameterized_icl(get_hparams(num_samples, task))
