from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from tqdm import tqdm

from episode.decoders import Decoder
from episode.encoders import Encoder
from episode.episode_types import EpisodeInput, ExperimentResults
from episode.overseers import Overseer
from episode.paraphrasers import Paraphraser
from episode.payloads import Payload, PayloadInstance
from episode.tasks import Task, TaskInstance
from logger import logger
from natural_episode.decoders.natural_decoder import NaturalDecoder
from natural_episode.encoders.natural_encoder import NaturalEncoder
from natural_episode.natural_episode_types import NaturalExperimentResults
from natural_episode.overseers.natural_overseer import NaturalOverseer
from natural_episode.tasks import NaturalTask, NaturalTaskInstance
from serde import deserialize_from_json, serde_to_dict, serialize
from utils import load_json_file


@dataclass
class RedTeamConfig:
    """An encoder/decoder pair describing one red team.

    Instances are stored under a team name in ``redteam_configs.json`` and
    are rebuilt from that file when an experiment directory is loaded.
    """

    # The team's encoder.
    encoder: Encoder
    # The team's decoder.
    decoder: Decoder


@dataclass
class NaturalRedTeamConfig(RedTeamConfig):
    """A ``RedTeamConfig`` whose fields are re-declared with the natural-episode types."""

    # Same field as in RedTeamConfig, narrowed to NaturalEncoder.
    encoder: NaturalEncoder
    # Same field as in RedTeamConfig, narrowed to NaturalDecoder.
    decoder: NaturalDecoder


@dataclass
class ExperimentConfig:
    """Experiment-wide settings, saved to ``experiment_config.json``.

    When an experiment directory is reused, the stored config is compared
    against the new one to detect conflicting runs.
    """

    # Task the experiment runs on; also supplies ``instance_initializer``
    # when stored task instances are deserialized.
    task: Task
    # Payload type; also supplies ``create_payload_safe`` when stored payload
    # instances are deserialized.
    payload_type: Payload
    # Optional paraphraser (None when not used).
    paraphraser: Optional[Paraphraser] = None
    # Optional overseer (None when not used).
    overseer: Optional[Overseer] = None
    # Optional integer seed; how it is consumed is not visible in this module.
    seed: Optional[int] = None


@dataclass
class NaturalExperimentConfig:
    """Experiment-wide settings for a natural experiment.

    Mirrors ``ExperimentConfig`` but has no ``payload_type`` field and uses
    the natural task/overseer types.
    """

    # Task the experiment runs on; also supplies ``instance_initializer``
    # when stored task instances are deserialized.
    task: NaturalTask
    # Optional paraphraser (None when not used).
    paraphraser: Optional[Paraphraser] = None
    # Optional overseer (None when not used).
    overseer: Optional[NaturalOverseer] = None
    # Optional integer seed; how it is consumed is not visible in this module.
    seed: Optional[int] = None


def save_results(
    results_dir: Path,
    existing_red_teams: dict[str, tuple[RedTeamConfig, ExperimentResults]],
    new_configs: dict[str, RedTeamConfig] | dict[str, NaturalRedTeamConfig],
    new_results: dict[str, ExperimentResults] | dict[str, NaturalExperimentResults],
):
    """Write the combined (existing + new) red team results and configs to disk.

    Entries from ``new_configs`` / ``new_results`` override existing entries
    that share a team name.

    Args:
        results_dir: directory that receives ``results.json`` and
            ``redteam_configs.json``.
        existing_red_teams: previously stored teams, as
            ``team name -> (config, results)``.
        new_configs: configs of the newly run teams, keyed by team name.
        new_results: results of the newly run teams, keyed by team name.
    """
    # Split the stored (config, results) pairs apart, then overlay the new entries.
    merged_configs = {
        name: stored_config for name, (stored_config, _) in existing_red_teams.items()
    } | new_configs
    merged_results = {
        name: stored_results
        for name, (_, stored_results) in existing_red_teams.items()
    } | new_results

    results_path = results_dir / "results.json"
    logger.info(f"Saving results file to: {results_path}")
    results_path.write_text(serialize(merged_results))

    configs_path = results_dir / "redteam_configs.json"
    logger.info(f"Updating red team config file at: {configs_path}")
    configs_path.write_text(serialize(merged_configs))


def initialise_experiment_dir(
    results_dir: Path,
    inputs: list[EpisodeInput] | list[NaturalTaskInstance],
    experiment: ExperimentConfig | NaturalExperimentConfig,
    red_team_configs: dict[str, RedTeamConfig] | dict[str, NaturalRedTeamConfig],
    expect_changing_payloads: bool = False,
) -> dict[
    str,
    tuple[RedTeamConfig, ExperimentResults]
    | tuple[NaturalRedTeamConfig, NaturalExperimentResults],
]:
    """Prepare a results directory for a run, reusing it if it already exists.

    If ``results_dir`` exists, the stored experiment is loaded and checked for
    compatibility with the requested one. Otherwise the directory is created
    and the experiment config, red team configs and inputs are written to it.

    Args:
        results_dir: directory holding (or about to hold) the experiment files.
        inputs: the inputs for this run.
        experiment: the experiment config; its type selects which loader is
            used for an existing directory.
        red_team_configs: configs of the red teams about to be run, by name.
        expect_changing_payloads: if True, the ``payload_type`` entry of the
            config and the stored inputs are excluded from the compatibility
            check.

    Returns:
        The red teams already stored in the directory, as
        ``team name -> (config, results)``; empty for a fresh directory.

    Raises:
        AssertionError: if the existing directory holds a different experiment
            config or different inputs, or already contains a red team with
            one of the given names.
    """
    if results_dir.exists():
        experiment_config, existing_inputs, previous_red_teams = (
            load_experiment(results_dir)
            if isinstance(experiment, ExperimentConfig)
            else load_natural_experiment(results_dir)
        )
        # Check that the experiment is not conflicting with old results
        if expect_changing_payloads:
            # If we expect the payloads to change, we don't need to check the inputs
            assert {
                k: v
                for k, v in serde_to_dict(experiment_config).items()
                if k != "payload_type"
            } == {
                k: v
                for k, v in serde_to_dict(experiment).items()
                if k != "payload_type"
            }, "Attempt to reuse experiment dir with different experiment config"
        else:
            assert serialize(experiment_config) == serialize(
                experiment
            ), "Attempt to reuse experiment dir with different experiment config"
            assert serialize(inputs) == serialize(
                existing_inputs
            ), "Attempt to reuse experiment dir with different inputs"
        assert not any(
            team_name in previous_red_teams for team_name in red_team_configs
        ), "Found an existing red team with the same name."

        return previous_red_teams

    # First time this results dir is being used. Need to initialize it.
    logger.info(f"Creating directory to store results at: {results_dir}")
    results_dir.mkdir(parents=True)

    experiment_config_file_path = results_dir / "experiment_config.json"
    logger.info(f"Saving experiment config file to: {experiment_config_file_path}")
    experiment_config_file_path.write_text(serialize(experiment))

    redteam_config_file_path = results_dir / "redteam_configs.json"
    logger.info(f"Saving red team config file to: {redteam_config_file_path}")
    redteam_config_file_path.write_text(serialize(red_team_configs))

    input_file_path = results_dir / "inputs.json"
    logger.info(f"Saving inputs to: {input_file_path}")
    input_file_path.write_text(serialize(inputs))
    return {}


def load_experiment(
    experiment_dir: Path,
) -> tuple[
    ExperimentConfig,
    list[EpisodeInput],
    dict[str, tuple[RedTeamConfig, ExperimentResults]],
]:
    """Read a stored experiment back from its directory.

    Args:
        experiment_dir: the directory containing the experiment files

    Returns:
        (the experiment config, the inputs, a dict mapping each red team name
        to its (config, results) tuple)
    """

    logger.info(f"Loading experiment from {experiment_dir}")

    # Step 1: read the raw JSON files.
    raw_experiment_config = load_json_file(experiment_dir / "experiment_config.json")
    try:
        raw_team_configs = load_json_file(experiment_dir / "redteam_configs.json")
        raw_team_results = load_json_file(experiment_dir / "results.json")
    except FileNotFoundError:
        # The run didn't finish, so there are no red teams to load.
        raw_team_configs, raw_team_results = {}, {}
    raw_inputs = load_json_file(experiment_dir / "inputs.json")

    # Step 2: turn the JSON into typed objects.
    experiment_config = deserialize_from_json(ExperimentConfig, raw_experiment_config)
    task = experiment_config.task
    payload = experiment_config.payload_type

    def instance_deserializers():
        # Fresh mapping per call, so each deserialization gets its own dict.
        return {
            TaskInstance: task.instance_initializer,
            PayloadInstance: payload.create_payload_safe,
        }

    def with_experiment_types(data):
        # Encoders/decoders are deserialized with the task and payload type added.
        return data | {"task_type": task, "payload_type": payload}

    def build_encoder(data):
        return deserialize_from_json(Encoder, with_experiment_types(data))

    def build_decoder(data):
        return deserialize_from_json(Decoder, with_experiment_types(data))

    inputs = deserialize_from_json(
        list[EpisodeInput],
        raw_inputs,
        deserializers=instance_deserializers(),
    )

    red_teams = {}
    for team_name, team_results in raw_team_results.items():
        if "results" not in team_results:
            # Backwards compatibility: old files stored only the results, so
            # pair them with the experiment-level inputs.
            # NOTE(review): this branch registers no TaskInstance deserializer,
            # unlike the current-format branch — confirm that is intended.
            results_source = {"results": team_results, "inputs": raw_inputs}
            results_deserializers = {PayloadInstance: payload.create_payload_safe}
        else:
            results_source = team_results
            results_deserializers = instance_deserializers()
        team_results_obj = deserialize_from_json(
            ExperimentResults,
            results_source,
            deserializers=results_deserializers,
        )

        team_config_obj = deserialize_from_json(
            RedTeamConfig,
            raw_team_configs[team_name],
            deserializers={Encoder: build_encoder, Decoder: build_decoder},
        )
        red_teams[team_name] = (team_config_obj, team_results_obj)

    return (experiment_config, inputs, red_teams)


def load_natural_experiment(
    experiment_dir: Path,
) -> tuple[
    NaturalExperimentConfig,
    list[TaskInstance],
    dict[str, tuple[NaturalRedTeamConfig, NaturalExperimentResults]],
]:
    """Read a stored natural experiment back from its directory.

    Args:
        experiment_dir: the directory containing the experiment files

    Returns:
        (the experiment config, the inputs, a dict mapping each red team name
        to its (config, results) tuple)
    """
    # NOTE(review): the inputs are deserialized as list[NaturalTaskInstance],
    # while the return annotation says list[TaskInstance] — confirm which is meant.

    logger.info(f"Loading experiment from {experiment_dir}")

    # Step 1: read the raw JSON files.
    raw_experiment_config = load_json_file(experiment_dir / "experiment_config.json")
    try:
        raw_team_configs = load_json_file(experiment_dir / "redteam_configs.json")
        raw_team_results = load_json_file(experiment_dir / "results.json")
    except FileNotFoundError:
        # The run didn't finish, so there are no red teams to load.
        raw_team_configs, raw_team_results = {}, {}
    raw_inputs = load_json_file(experiment_dir / "inputs.json")

    # Step 2: turn the JSON into typed objects.
    experiment_config = deserialize_from_json(
        NaturalExperimentConfig, raw_experiment_config
    )
    task = experiment_config.task

    def build_encoder(data):
        # Encoders are deserialized with the task added to their data.
        return deserialize_from_json(NaturalEncoder, data | {"task_type": task})

    def build_decoder(data):
        # Decoders are deserialized with the task added to their data.
        return deserialize_from_json(NaturalDecoder, data | {"task_type": task})

    inputs = deserialize_from_json(
        list[NaturalTaskInstance],
        raw_inputs,
        deserializers={NaturalTaskInstance: task.instance_initializer},
    )

    red_teams = {}
    for team_name, team_results in raw_team_results.items():
        team_results_obj = deserialize_from_json(
            NaturalExperimentResults,
            team_results,
            deserializers={NaturalTaskInstance: task.instance_initializer},
        )

        # NOTE(review): the target type here is RedTeamConfig although the
        # return annotation promises NaturalRedTeamConfig — confirm intent.
        team_config_obj = deserialize_from_json(
            RedTeamConfig,
            raw_team_configs[team_name],
            deserializers={Encoder: build_encoder, Decoder: build_decoder},
        )
        red_teams[team_name] = (team_config_obj, team_results_obj)

    return (experiment_config, inputs, red_teams)


def run_episodes(thread: bool, episodes: dict, *args):
    """Run every episode over the given inputs and collect the results per team.

    Teams are processed one after another; ``thread`` only controls whether
    the inputs of a single episode are mapped through a thread pool.

    Args:
        thread: if True, map each episode's ``run`` over the inputs with a
            ``ThreadPoolExecutor``; otherwise use the builtin ``map``.
        episodes: mapping of team name to an episode object exposing ``run``.
        *args: parallel iterables; the i-th elements of each are passed
            together as the arguments of one ``run`` call.

    Returns:
        A dict mapping each team name to the list of its ``run`` results.
    """
    # Each progress bar wraps the results of ONE episode, so its total is the
    # number of inputs, not the number of teams. Fall back to an unknown
    # total when there are no iterables or one of them has no length.
    try:
        total = min(len(iterable) for iterable in args)
    except (TypeError, ValueError):
        total = None

    if thread:
        with ThreadPoolExecutor() as executor:
            return {
                team_name: list(
                    tqdm(executor.map(episode.run, *args), total=total)
                )
                for team_name, episode in episodes.items()
            }
    return {
        team_name: list(tqdm(map(episode.run, *args), total=total))
        for team_name, episode in episodes.items()
    }


if __name__ == "__main__":
    # Ad-hoc manual check: load one hard-coded experiment directory.
    example_experiment_dir = Path(
        r"/home/user/steganography/results/24-08-22_aqua-rat-ca-mb/2-bit/seed_0/icl_training_runs"
    )
    load_experiment(example_experiment_dir)
