from dataclasses import dataclass, astuple
import yaml


@dataclass(frozen=True)
class Concept:
    """One concept slot assigned to a single agent.

    Instances are created by ``concept_config``: one per
    (concept name, affected agent, targeted-agent group). Each instance
    occupies the half-open offset range ``[start_idx, end_idx)`` of its
    agent's concept layout, with ``end_idx == start_idx + length``.

    NOTE(review): ``frozen=True`` (with the default ``eq=True``) makes the
    dataclass generate ``__hash__`` from all fields. ``concept_config`` fills
    ``opponents`` straight from the YAML, so if that value is a list,
    ``hash(concept)`` (e.g. using it in a set or as a dict key) raises
    TypeError.
    """

    # Concept name, as listed in ``args.concepts`` / the YAML mapping.
    name: str
    # Agent owning this slot. Annotated str, but the value comes from the
    # YAML "affected_agents" list and may load as another type — confirm.
    agent_id: str
    # "<name>_agent_<agent_id>", as built by concept_config.
    full_name: str
    # Inclusive start offset within the agent's concept layout.
    start_idx: int
    # Exclusive end offset (start_idx + length).
    end_idx: int
    # Number of offsets this slot occupies (the YAML "length" entry).
    length: int
    # The concept's YAML "type" entry.
    concept_type: str
    # One element of the YAML "targeted_agents" entry for this agent;
    # annotated list — presumably a group of targeted agent ids, confirm.
    opponents: list


@dataclass()
class Concepts:
    """All concept slots for one agent, plus its theory-of-mind (ToM) views.

    Instances are built by ``concept_config``, one per id in the YAML's
    ``all_agent_ids``.
    """

    # Agent these concepts belong to.
    agent_id: str
    # The agent's Concept instances, appended in increasing start_idx order.
    configs: list
    # modeled agent id -> {"configs": <that agent's configs list>,
    #                      "total_length": <that agent's total_length>}.
    # concept_config stores the modeled agent's list object itself (no copy).
    # Empty unless the YAML "ToM" section lists modeled agents for this agent.
    tom_configs: dict
    # Total number of offsets used by ``configs`` (sum of their lengths).
    total_length: int


def concept_config(args):
    """Build per-agent concept layouts from a concept YAML file.

    For each concept name in ``args.concepts``, one ``Concept`` is created per
    (affected agent, targeted-agent group) pair. Each agent's concepts are laid
    out back to back: a concept occupies ``[start_idx, start_idx + length)``,
    and that agent's running offset then advances by ``length``.

    The YAML mapping must contain ``all_agent_ids``,
    ``"representative guard"``, ``"representative attacker"`` and one entry per
    concept name. Each concept entry must have ``length``, ``affected_agents``,
    ``type`` and ``targeted_agents``. ``targeted_agents[j]`` lists the target
    groups for ``affected_agents[j]``. An optional ``ToM`` mapping sends a
    modeling agent id to the ids of the agents it models.

    Args:
        args: Namespace-like object with ``concept_yaml`` (path to the YAML
            file), ``scenario`` (only ``"fort_attack"`` is supported) and
            ``concepts`` (iterable of concept names).

    Returns:
        dict mapping every id in ``all_agent_ids`` to a ``Concepts`` instance.
        It also has the aliases ``"guard"`` and ``"attacker"``, which point at
        the same objects as the representative guard and attacker entries.

    Raises:
        ValueError: if ``args.scenario`` is not ``"fort_attack"``.
        KeyError / IndexError: if the YAML is missing required entries or
            ``targeted_agents`` is shorter than ``affected_agents``.
    """
    if args.scenario != "fort_attack":
        # Without this guard the function reached its return with the result
        # never assigned and failed with an opaque UnboundLocalError.
        raise ValueError(
            f"concept_config: unsupported scenario {args.scenario!r}; "
            "only 'fort_attack' is implemented"
        )

    with open(args.concept_yaml, "r", encoding="utf-8") as stream:
        concept_dict = yaml.safe_load(stream)

    all_agent_ids = concept_dict["all_agent_ids"]
    # Next free offset in each agent's concept layout.
    start_index = {agent_id: 0 for agent_id in all_agent_ids}
    per_agent_concepts = {}

    for concept_name in args.concepts:
        current_config = concept_dict[concept_name]
        current_length = current_config["length"]
        affected_agents = current_config["affected_agents"]
        for j, agent_id in enumerate(affected_agents):
            agent_concepts = per_agent_concepts.setdefault(agent_id, [])
            # Index (not zip) so a too-short targeted_agents list still raises
            # IndexError instead of being silently truncated.
            for targeted_agents in current_config["targeted_agents"][j]:
                start = start_index[agent_id]
                agent_concepts.append(
                    Concept(
                        name=concept_name,
                        agent_id=agent_id,
                        full_name=f"{concept_name}_agent_{agent_id}",
                        start_idx=start,
                        end_idx=start + current_length,
                        concept_type=current_config["type"],
                        length=current_length,
                        opponents=targeted_agents,
                    )
                )
                start_index[agent_id] = start + current_length

    # An agent's offset only advances when a concept is appended for it, so
    # agents without concepts get configs=[] and total_length=0 here.
    result = {}
    for agent_id in all_agent_ids:
        result[agent_id] = Concepts(
            agent_id=agent_id,
            configs=per_agent_concepts.get(agent_id, []),
            total_length=start_index[agent_id],
            tom_configs={},
        )

    # Attach ToM views. A missing or empty "ToM" section means no ToM. The
    # modeled agent's configs list is shared, not copied.
    tom_config = concept_dict.get("ToM") or {}
    for agent_id, agent_concepts_obj in result.items():
        if agent_id not in tom_config:
            continue
        for modeled_id in tom_config[agent_id] or ():
            modeled = result[modeled_id]
            agent_concepts_obj.tom_configs[modeled_id] = {
                "configs": modeled.configs,
                "total_length": modeled.total_length,
            }

    # Role aliases point at the representative agents' entries (same objects).
    result["guard"] = result[concept_dict["representative guard"]]
    result["attacker"] = result[concept_dict["representative attacker"]]

    return result
