import os
import pandas as pd
from typing import List, Optional, Tuple
import multiprocessing


from utils.constants import (
    RECURSIVE,
    PARALLEL,
    LATEST_MESSAGE,
    ALL_MESSAGES,
    # LLMs
    GPT4O,
    GPT3_5,
    GPT4O_MINI,
    # TOOLBOX
    WEBSCRAPER,
    CSV_READER,
    CODER,
    EMAIL_READER,
    PDF_READER,
    # Threat types
    SCAM,
    DATA_THEFT,
    MALWARE,
    AVAILABILITY,
    MANIPULATED_CONTENT,
    DISINFORMATION_QA,
    DISINFORMATION_ARC,
    HEALTHY_ARC,
    HEALTHY_QA,
    MODEL_INFECTION,
    EXTERNAL_INFECTION,
    # Task types
    EMAIL,
    PDF,
    WEB,
    # Defense types
    NO_DEFENSE,
    SANDWICH,
    INSTRUCTION_DEFENSE,
    RANDOM_SEQUENCE_ENCLOSURE,
    DELIMITING_DATA,
    MARKING,
    MODEL_DELIMITER,
)
from utils.pretty_print import pretty_print
from utils.parse_args import parse_args
from experiments.graph import GraphExecutor
from experiments.agent_configs import get_agent_configs


def get_first_and_last_agents(
    threat_type: str, task_type: str
) -> Tuple[Optional[List[Optional[str]]], Optional[str]]:
    """Pick the pinned first agent(s) and last agent for an experiment.

    Args:
        threat_type: Threat constant being evaluated.
        task_type: Task constant (``EMAIL``, ``PDF`` or ``WEB``) used to look
            up the entry agent.

    Returns:
        A ``(first_agents, last_agent)`` pair:

        * ``(None, None)`` for ``MODEL_INFECTION``.
        * ``([WEBSCRAPER], None)`` for ``DISINFORMATION_QA`` and
          ``DISINFORMATION_ARC``, regardless of ``task_type``.
        * ``([first_agent, CSV_READER], CODER)`` for ``DATA_THEFT``.
        * ``([first_agent], None)`` for every other threat type.

        ``first_agent`` is the reader mapped to ``task_type``; it is ``None``
        when ``task_type`` is not one of the mapped task types.
    """
    if threat_type == MODEL_INFECTION:
        return None, None

    if threat_type in {DISINFORMATION_QA, DISINFORMATION_ARC}:
        return [WEBSCRAPER], None

    task_to_agent = {EMAIL: EMAIL_READER, PDF: PDF_READER, WEB: WEBSCRAPER}
    # `.get` yields None for an unmapped task type instead of raising.
    first_agent = task_to_agent.get(task_type)

    if threat_type == DATA_THEFT:
        return [first_agent, CSV_READER], CODER

    return [first_agent], None


def run_experiment(
    executor: GraphExecutor,
    dataset_filename: str,
    log_filename: str,
    message_log_filename: str,
    num_agents: int,
    start_data_index: int = 0,
    end_data_index: int = 20,
    defense_type: str = None,
    counterattack: bool = False,
) -> None:
    """Announce the agent count and run the executor over a dataset slice.

    Every argument after ``executor`` is forwarded positionally, in the
    order declared here, to ``executor.run_multiple_instructions``.

    Args:
        executor: Object exposing ``run_multiple_instructions``.
        dataset_filename: Dataset path handed to the executor.
        log_filename: Result-log path handed to the executor.
        message_log_filename: Message-log path handed to the executor.
        num_agents: Agent count; also shown in the printed banner.
        start_data_index: Start of the dataset slice (bound semantics are
            defined by the executor).
        end_data_index: End of the dataset slice (bound semantics are
            defined by the executor).
        defense_type: Defense identifier, or ``None``.
        counterattack: Flag forwarded unchanged to the executor.
    """
    banner = f"  N agents: {num_agents}"
    pretty_print(banner, color="cyan", size="small")

    # Kept positional on purpose: the executor's parameter names are not
    # visible from this module.
    forwarded = (
        dataset_filename,
        log_filename,
        message_log_filename,
        num_agents,
        start_data_index,
        end_data_index,
        defense_type,
        counterattack,
    )
    executor.run_multiple_instructions(*forwarded)


def print_success_rates(
    log_filename: str, num_agents: int, communication_mode
) -> None:
    """Print the per-agent success rate for one communication mode.

    Loads the CSV at ``log_filename``, keeps the rows whose
    ``communication_mode`` column equals ``communication_mode``, and for
    each agent index from 2 through ``num_agents`` prints the sum of the
    column named after that index divided by the number of kept rows.

    Args:
        log_filename: Path of the CSV log to read.
        num_agents: Highest agent index to report (inclusive).
        communication_mode: Value matched against the
            ``communication_mode`` column.
    """
    results = pd.read_csv(log_filename)
    mode_mask = results["communication_mode"] == communication_mode
    selected = results[mode_mask]
    row_count = len(selected)

    # Agent columns are labelled with the agent index as a string.
    for agent_column in map(str, range(2, num_agents + 1)):
        rate = selected[agent_column].sum() / row_count
        print(f"Agent {agent_column} success rate: {rate}")


def run_threat_experiments(
    llm_type: str,
    communication_mode: str,
    prompt_infection_mode: str,
    threat_type: str,
    task_type: str,
    num_agents: int,
    log_filename: Optional[str],
    message_log_filename: Optional[str],
    dataset_filename: Optional[str],
    start_data_index: int = 0,
    end_data_index: int = 20,
    defense_type: str = None,
    override_log: bool = False,
    counterattack: bool = False,
) -> None:
    """Run one threat/task combination for every applicable agent count.

    For ``SCAM``, ``MANIPULATED_CONTENT``, ``MALWARE``, ``HEALTHY_QA`` and
    ``DATA_THEFT`` the experiment is repeated for each agent count from 2
    through ``num_agents``; every other threat type runs once with a
    single agent. Each run builds fresh agent configs and a fresh
    ``GraphExecutor`` before delegating to ``run_experiment``.

    Args:
        llm_type: Model identifier passed to the config builder and the
            executor.
        communication_mode: Communication mode passed to the executor.
        prompt_infection_mode: Infection mode passed to the executor.
        threat_type: Threat constant being evaluated.
        task_type: Task constant; also passed to the executor as
            ``tool_type``.
        num_agents: Upper bound (inclusive) of the agent-count sweep.
        log_filename: Result-log path, or ``None``.
        message_log_filename: Message-log path, or ``None``.
        dataset_filename: Dataset path, or ``None``.
        start_data_index: Start of the dataset slice.
        end_data_index: End of the dataset slice.
        defense_type: Defense identifier, or ``None``.
        override_log: Flag forwarded to the executor.
        counterattack: Flag forwarded to ``run_experiment``.
    """
    swept_threats = (
        SCAM,
        MANIPULATED_CONTENT,
        MALWARE,
        HEALTHY_QA,
        DATA_THEFT,
    )
    agent_counts = (
        range(2, num_agents + 1) if threat_type in swept_threats else [1]
    )

    for agent_count in agent_counts:
        # Looked up on every pass so each run gets its own agent list.
        entry_agents, exit_agent = get_first_and_last_agents(
            threat_type, task_type
        )
        configs = get_agent_configs(
            llm_type,
            entry_agents,
            exit_agent,
            agent_count,
            defense_type,
        )
        graph_executor = GraphExecutor(
            configs,
            llm_type,
            communication_mode,
            prompt_infection_mode,
            threat_type,
            override_log,
            tool_type=task_type,
        )

        run_experiment(
            executor=graph_executor,
            dataset_filename=dataset_filename,
            log_filename=log_filename,
            message_log_filename=message_log_filename,
            num_agents=agent_count,
            start_data_index=start_data_index,
            end_data_index=end_data_index,
            defense_type=defense_type,
            counterattack=counterattack,
        )


def get_tool_types_per_threats(threat_type):
    """Return the task types to run for ``threat_type``.

    ``DISINFORMATION_ARC``, ``DISINFORMATION_QA``, ``HEALTHY_ARC`` and
    ``HEALTHY_QA`` map to ``[WEB]`` only; every other threat type maps to
    ``[EMAIL, PDF, WEB]``. A new list is returned on each call.
    """
    web_only_threats = (
        DISINFORMATION_ARC,
        DISINFORMATION_QA,
        HEALTHY_ARC,
        HEALTHY_QA,
    )
    return [WEB] if threat_type in web_only_threats else [EMAIL, PDF, WEB]


def run_process(
    llm_type=GPT4O,
    prompt_infection_mode=PARALLEL,
    communication_mode=ALL_MESSAGES,
    override_log=False,
    start_data_index=0,
    end_data_index=20,
    threat_types=None,
    tool_types=None,
    num_agents=5,
    defense_type=None,
    counterattack=False,
):
    """Run every requested threat type against every applicable task type.

    For each (threat, task) pair this derives the log and dataset paths,
    creates the log directory if needed, and calls
    ``run_threat_experiments``.

    Paths:
        * Log directory:
          ``<parent_dir>/logs/<task_type>/<prompt_infection_mode>``, where
          ``parent_dir`` is the parent of the directory containing this
          file.
        * Without a defense: ``<threat_type>_new.csv`` and
          ``<threat_type>_messages_new.csv`` inside the log directory.
        * With a truthy ``defense_type``: ``defense/<threat_type>.csv``
          inside the log directory, and no message log (``None``).
        * Dataset (relative path, not joined with ``parent_dir``):
          ``datasets/<task_type>/<threat_type>.csv`` for ``HEALTHY_QA``,
          otherwise
          ``datasets/<task_type>/<prompt_infection_mode>/<threat_type>.csv``.

    Args:
        llm_type: Model identifier.
        prompt_infection_mode: Infection mode; also a path component.
        communication_mode: Communication mode.
        override_log: Forwarded to ``run_threat_experiments``.
        start_data_index: Start of the dataset slice.
        end_data_index: End of the dataset slice.
        threat_types: Iterable of threat constants. Despite the ``None``
            default it must be supplied; iterating ``None`` raises
            ``TypeError``.
        tool_types: Task types to run for every threat, or ``None`` to
            pick them per threat via ``get_tool_types_per_threats``.
        num_agents: Upper bound of the agent-count sweep.
        defense_type: Defense identifier; any falsy value means no
            defense.
        counterattack: Forwarded to ``run_threat_experiments``.
    """
    parent_dir = os.path.dirname(os.path.dirname(__file__))

    # Run experiments for each threat type and task type
    for threat_type in threat_types:
        # Resolve into a per-iteration local. Assigning back to
        # `tool_types` would make the first threat's defaults stick for
        # every later threat type.
        if tool_types is None:
            current_tool_types = get_tool_types_per_threats(threat_type)
        else:
            current_tool_types = tool_types

        pretty_print(threat_type, color="red")
        pretty_print(communication_mode, color="white")
        if defense_type:
            pretty_print(defense_type, color="white")
        for task_type in current_tool_types:
            pretty_print(task_type, color="magenta", size="tiny")

            log_directory = os.path.join(
                parent_dir, f"logs/{task_type}/{prompt_infection_mode}"
            )
            os.makedirs(log_directory, exist_ok=True)
            log_filename = os.path.join(
                log_directory, f"{threat_type}_new.csv"
            )
            message_log_filename = os.path.join(
                log_directory, f"{threat_type}_messages_new.csv"
            )

            if defense_type:
                # Defense runs log to a dedicated sub-directory and do not
                # record a message log.
                log_filename = os.path.join(
                    log_directory,
                    "defense",
                    f"{threat_type}.csv",
                )
                os.makedirs(
                    os.path.join(log_directory, "defense"), exist_ok=True
                )
                message_log_filename = None

            if threat_type in [HEALTHY_QA]:
                dataset_filename = f"datasets/{task_type}/{threat_type}.csv"
            else:
                dataset_filename = f"datasets/{task_type}/{prompt_infection_mode}/{threat_type}.csv"

            run_threat_experiments(
                llm_type=llm_type,
                communication_mode=communication_mode,
                prompt_infection_mode=prompt_infection_mode,
                threat_type=threat_type,
                task_type=task_type,
                num_agents=num_agents,
                log_filename=log_filename,
                message_log_filename=message_log_filename,
                dataset_filename=dataset_filename,
                override_log=override_log,
                start_data_index=start_data_index,
                end_data_index=end_data_index,
                defense_type=defense_type,
                counterattack=counterattack,
            )


def run_self_replication_experiments():
    """Run the self-replication sweep over models and communication modes.

    For each model (``GPT4O``, ``GPT3_5``) and communication mode
    (``LATEST_MESSAGE``, ``ALL_MESSAGES``), one ``run_process`` child
    process is started per (infection mode, tool type) pair. The whole
    batch is joined before the next combination starts. Prints a
    completion message once everything has finished.
    """
    threat_types = [
        SCAM,
        MALWARE,
        MANIPULATED_CONTENT,
        DATA_THEFT,
    ]
    # Flattened (infection mode, tool type) grid, infection mode outermost.
    mode_tool_pairs = [
        (infection_mode, tool_type)
        for infection_mode in [RECURSIVE, PARALLEL]
        for tool_type in [EMAIL, PDF, WEB]
    ]

    for model in [GPT4O, GPT3_5]:
        for communication_mode in [LATEST_MESSAGE, ALL_MESSAGES]:
            workers = []
            for infection_mode, tool_type in mode_tool_pairs:
                pretty_print(infection_mode, color="yellow")
                worker = multiprocessing.Process(
                    target=run_process,
                    kwargs=dict(
                        llm_type=model,
                        prompt_infection_mode=infection_mode,
                        communication_mode=communication_mode,
                        override_log=True,
                        start_data_index=0,
                        end_data_index=34,
                        threat_types=threat_types,
                        tool_types=[tool_type],
                        num_agents=5,
                        defense_type="",
                    ),
                )
                workers.append(worker)
                worker.start()

            # Block until the whole batch has finished.
            for worker in workers:
                worker.join()

    print("All processes completed.")


def run_defense_experiments():
    """Run the defense sweep for ``MODEL_INFECTION``.

    For each model and each base defense, the experiment is run twice:
    once with the ``_MODEL_DELIMITER`` suffix appended to the defense name
    (LLM tagging on) and once with the plain defense name. Each run starts
    one ``run_process`` child process per (infection mode, tool type) pair
    and joins the batch before continuing. Prints a completion message
    once the whole sweep has finished.
    """
    models = [GPT4O, GPT3_5]
    threat_types = [MODEL_INFECTION]
    infection_modes = [RECURSIVE]
    communication_modes = [LATEST_MESSAGE]
    defense_types = [
        NO_DEFENSE,
        SANDWICH,
        INSTRUCTION_DEFENSE,
        MARKING,
        RANDOM_SEQUENCE_ENCLOSURE,
        DELIMITING_DATA,
    ]

    for model in models:
        for base_defense_type in defense_types:
            for use_llm_tagging in [True, False]:
                # Derive the effective name from the untouched base name.
                # Reassigning the loop variable would leak the suffix into
                # the untagged pass, so the plain defense would never run.
                if use_llm_tagging:
                    defense_type = base_defense_type + "_MODEL_DELIMITER"
                else:
                    defense_type = base_defense_type
                for communication_mode in communication_modes:
                    processes = []
                    for infection_mode in infection_modes:
                        for tool_type in [EMAIL, PDF, WEB]:
                            pretty_print(infection_mode, color="yellow")
                            p = multiprocessing.Process(
                                target=run_process,
                                kwargs={
                                    "llm_type": model,
                                    "prompt_infection_mode": infection_mode,
                                    "communication_mode": communication_mode,
                                    "override_log": True,
                                    "start_data_index": 0,
                                    "end_data_index": 40,
                                    "threat_types": threat_types,
                                    "tool_types": [tool_type],
                                    "num_agents": 5,
                                    "defense_type": defense_type,
                                },
                            )
                            processes.append(p)
                            p.start()

                    # Wait for all processes to complete
                    for p in processes:
                        p.join()

    # Printed once, after every model has finished, rather than once per
    # model.
    print("All processes completed.")


if __name__ == "__main__":
    # Script entry point: run the self-replication sweep first, then the
    # defense sweep.
    for experiment_suite in (
        run_self_replication_experiments,
        run_defense_experiments,
    ):
        experiment_suite()
