import random
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from pathlib import Path

from in_context_learning.distributed_natural_icl_learners import (
    DistributedLearnerConfig,
    NaturalDistributedICLLearner,
)
from in_context_learning.icl_learner import NaturalICLLearner

from episode.paraphrasers import LLMParaphraser
from logger import logger
from models import ClaudeModelType
from models.GPT import GPTModelType
from natural_episode.overseers import CVLLMOverseer
from natural_episode.tasks import CVTask
from run_natural_experiment import NaturalExperimentConfig


def run_icl_training(
    learner: NaturalICLLearner,
    experiment_config: NaturalExperimentConfig,
    save_dir: Path,
    n_episodes: "int | None" = None,
):
    """Run one learner's strategy-refinement loop.

    Seeds ``random`` from ``experiment_config.seed``, calls
    ``learner.analyze_data_initial()``, makes sure ``save_dir`` exists, then
    calls ``learner.refine_strategy()`` up to ``n_episodes`` times, stopping
    early as soon as it returns a truthy value.

    Any exception raised inside the refinement loop is logged and followed by
    ``learner.log_messages("icl-log-failed.txt")``; it is not re-raised.
    Exceptions from the setup steps before the loop do propagate.

    Args:
        learner: Learner to drive.
        experiment_config: Only its ``seed`` attribute is read here.
        save_dir: Directory created (with parents) before the loop starts.
        n_episodes: Maximum number of ``refine_strategy`` calls. When omitted,
            the value is read from the module-level ``learner_config`` that the
            ``__main__`` script defines, which keeps existing three-argument
            callers working.

    Note:
        ``random.seed`` seeds the generator shared by the whole process. When
        several runs execute concurrently in threads (as the script below
        does), their draws interleave, so the seed alone does not make an
        individual run reproducible.
    """
    random.seed(experiment_config.seed)
    learner.analyze_data_initial()
    save_dir.mkdir(exist_ok=True, parents=True)

    try:
        if n_episodes is None:
            # Legacy fallback: this global only exists when the file is run as
            # a script, and it is whichever config the setup loop built last.
            # Kept inside the ``try`` so a missing global is reported the same
            # way it always was (logged as a failed seed) rather than raised.
            # NOTE(review): the learner presumably carries its own config;
            # reading the budget from it would remove this global dependency.
            n_episodes = learner_config.n_episodes  # noqa: F821
        for _ in range(n_episodes):
            if learner.refine_strategy():
                break
    except Exception:
        logger.exception(f"Seed {experiment_config.seed} failed to finish.")
        # Presumably writes the learner's message history for post-mortem.
        learner.log_messages("icl-log-failed.txt")


# Main run script: builds one learner per (model, seed) pair and trains them
# all concurrently.
if __name__ == "__main__":
    ### SELECT TASK AND NUMBER OF SAMPLES ###
    num_samples = 30  # passed to CVTask
    n_steps = 20  # used as n_episodes for every learner
    task = CVTask(num_samples)

    # Root directory for this sweep; each model gets a sub-directory and each
    # seed a "seed-<i>" directory below that.
    save_dir_name = "24-09-12_natural-icl-llm-overseer-6"
    base_dir = Path("results", save_dir_name)
    base_dir.mkdir(exist_ok=True, parents=True)

    n_reps = 5  # seeds 0..n_reps-1 per model

    # Parallel lists: index k describes the same run in all three.
    all_learners = []
    all_experiment_configs = []
    all_save_dirs = []

    for model_name, model in [
        ("gpt-4o-mini", GPTModelType.GPT4OMINI),
        ("gpt-4o", GPTModelType.GPT4O),
        # ("claude-3-haiku", ClaudeModelType.CLAUDE_3_HAIKU),
        ("claude-35-sonnet-nc", ClaudeModelType.CLAUDE_35_SONNET),
    ]:
        model_dir = base_dir / model_name
        model_dir.mkdir(exist_ok=True, parents=True)

        # One experiment config per seed. The task is deep-copied so that runs
        # do not share a task object; no paraphraser is used in this sweep.
        experiment_configs = [
            NaturalExperimentConfig(
                task=deepcopy(task),
                overseer=CVLLMOverseer(GPTModelType.GPT4OMINI),
                paraphraser=None,
                seed=i,
            )
            for i in range(n_reps)
        ]
        all_experiment_configs += experiment_configs

        save_dirs = [model_dir / f"seed-{i}" for i in range(n_reps)]
        all_save_dirs += save_dirs

        # The same model plays learner, encoder and decoder.
        # NOTE: run_icl_training reads the module-level name `learner_config`
        # for its episode count, so this variable must keep its name and stay
        # at module scope.
        learner_config = DistributedLearnerConfig(
            learner_model_type=model,
            encoder_model_type=model,
            decoder_model_type=model,
            n_episodes=n_steps,
            # reflection_type=ReflectionType.NoReflection,
            # clear_and_summarize=False,
            # n_result_examples=2,
        )
        all_learners += [
            NaturalDistributedICLLearner(learner_config, experiment_config, save_dir)
            for experiment_config, save_dir in zip(experiment_configs, save_dirs)
        ]

    # Launch every run in its own thread. Leaving the `with` block waits for
    # all of them to finish.
    with ThreadPoolExecutor(max_workers=128) as pool:
        futures = [
            pool.submit(run_icl_training, learner, experiment_config, save_dir)
            for learner, experiment_config, save_dir in zip(
                all_learners, all_experiment_configs, all_save_dirs
            )
        ]

    # run_icl_training only guards its refinement loop, so errors from its
    # setup or from its failure handler end up stored on the future. Report
    # them instead of dropping them, without aborting the whole sweep.
    for save_dir, future in zip(all_save_dirs, futures):
        error = future.exception()
        if error is not None:
            logger.error(f"Run for {save_dir} raised an unhandled error: {error!r}")

    logger.info(f"\nAll done! Results saved to: {base_dir}\n")
