# actor_factory.py

import warnings
from collections import defaultdict
from itertools import count
from typing import Any, Dict, List, Optional, Type

import json5
import simpy

from organisation.env.clinical_trial.core.actors import (
    Investigator,
    LegalTeam,
    Actor,
    Sponsor,
    Statistician,
)

from organisation.env.clinical_trial.core.external_factors import RegulatoryAgency

from organisation.env.config import TASKS_FILE, ACTORS_FILE
from .incentives import generate_random_incentive


# (role string, Actor subclass) pairs, in lookup-table order. The role string
# is the value of "assigned_org_role" in the task and actor spec files.
_ROLE_CLASS_PAIRS = (
    ("Regulatory Agency", RegulatoryAgency),
    ("Investigator", Investigator),
    ("Legal Team", LegalTeam),
    ("Sponsor", Sponsor),
    ("Statistician", Statistician),
)

# Lookup from an org role string to the Actor subclass instantiated for it.
ROLE_TO_CLASS: dict[str, Type[Actor]] = dict(_ROLE_CLASS_PAIRS)


class ActorFactory:
    """
    Factory that reads the task tree and actor specs and returns Actor
    instances, distributing each org role's tasks across the actors declared
    for that role.
    """

    def __init__(
        self,
        env: simpy.Environment,
        simulation: Any,
        llm_client: Any,
        llm_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            env: SimPy environment handed to every actor; its ``now`` is
                recorded as each actor's creation time.
            simulation: Owning simulation. ``create_actors`` passes it to each
                actor and writes into its ``actor_created`` mapping.
            llm_client: LLM client shared by every actor.
            llm_kwargs: LLM keyword arguments shared by every actor; passed
                through unchanged, including ``None``.
        """
        self.env = env
        self.simulation = simulation
        self.llm_client = llm_client
        self.llm_kwargs = llm_kwargs
        # Source of auto-assigned actor IDs. It lives on the instance, so
        # repeated create_actors() calls keep counting instead of restarting.
        self._id_counter = count(start=0)

    def load_task_tree(self) -> list[dict]:
        """
        Load and return the task list stored in ``TASKS_FILE`` (parsed as JSON5).

        ``create_actors`` requires every task to have an
        ``"assigned_org_role"`` key.
        """
        with open(TASKS_FILE, "r", encoding="utf-8") as f:
            return json5.load(f)

    def load_actor_specs(self) -> List[dict]:
        """
        Load and return the actor specs stored in ``ACTORS_FILE`` (parsed as JSON5).

        Each entry must have at least:
          - "assigned_org_role": the role string (a key of ROLE_TO_CLASS)
        Optional:
          - "actor_id": force a specific ID instead of an auto-assigned one
          - "external_factor": truthy to return the actor as an external factor
        """
        with open(ACTORS_FILE, "r", encoding="utf-8") as f:
            return json5.load(f)

    @staticmethod
    def _group_by_role(entries: List[dict]) -> Dict[str, List[dict]]:
        """Group entries by their "assigned_org_role", keeping input order."""
        grouped: Dict[str, List[dict]] = defaultdict(list)
        for entry in entries:
            grouped[entry["assigned_org_role"]].append(entry)
        return grouped

    @staticmethod
    def _reserved_actor_ids(actor_specs: List[dict]) -> set:
        """Collect the explicit "actor_id" values; raise ValueError on duplicates."""
        reserved = set()
        for spec in actor_specs:
            if "actor_id" not in spec:
                continue
            actor_id = spec["actor_id"]
            if actor_id in reserved:
                raise ValueError(f"Duplicate actor_id in actors.json: {actor_id!r}")
            reserved.add(actor_id)
        return reserved

    def _next_auto_id(self, reserved_ids: set) -> int:
        """Return the next counter value not claimed by an explicit actor_id."""
        actor_id = next(self._id_counter)
        while actor_id in reserved_ids:
            actor_id = next(self._id_counter)
        return actor_id

    def create_actors(self) -> tuple[List[Actor], List[Actor]]:
        """
        Build one actor per spec in ``ACTORS_FILE`` and hand out the tasks.

        Tasks are grouped by ``"assigned_org_role"``. Within a role they are
        dealt round-robin across that role's specs: task ``j`` goes to the
        ``j % n``-th spec, where ``n`` is the number of specs for the role.
        Actors are created role by role, in the order each role first appears
        in the spec file, and within a role in spec-file order.

        Each actor's creation time (``env.now``) is recorded in
        ``simulation.actor_created`` under its ``actor_id``.

        Returns:
            ``(actors, external_factors)``. Actors whose spec has a truthy
            ``"external_factor"`` go in the second list; all others go in
            the first. Creation order is preserved in both lists.

        Raises:
            ValueError: If a spec names a role missing from ``ROLE_TO_CLASS``,
                or two specs request the same explicit ``"actor_id"``. Both
                are checked before any actor is created.
        """
        # 1) group tasks and specs by role
        tasks_by_role = self._group_by_role(self.load_task_tree())
        actor_specs = self.load_actor_specs()
        specs_by_role = self._group_by_role(actor_specs)

        # 2) validate the whole config before instantiating anything, so a bad
        #    entry cannot leave some actors already registered with the
        #    simulation.
        for role in specs_by_role:
            if role not in ROLE_TO_CLASS:
                raise ValueError(f"Unknown org role in actors.json: {role!r}")
        reserved_ids = self._reserved_actor_ids(actor_specs)

        # Tasks for a role that has no spec would otherwise vanish silently.
        for role, role_tasks in tasks_by_role.items():
            if role not in specs_by_role:
                warnings.warn(
                    f"{len(role_tasks)} task(s) assigned to org role {role!r} "
                    "have no actor in actors.json and were not assigned",
                    stacklevel=2,
                )

        actors: List[Actor] = []
        external_factors: List[Actor] = []

        # 3) for each role, deal its tasks round-robin across its specs
        for role, specs in specs_by_role.items():
            cls = ROLE_TO_CLASS[role]
            role_tasks = tasks_by_role.get(role, [])
            n = len(specs)  # always >= 1: a role only appears here via a spec

            for i, spec in enumerate(specs):
                # Draw from the counter only when no explicit ID is given.
                # Passing next(...) as a .get() default would consume an ID
                # even for explicit specs.
                if "actor_id" in spec:
                    actor_id = spec["actor_id"]
                else:
                    actor_id = self._next_auto_id(reserved_ids)
                # Role-specific incentive text for this actor.
                incentive = generate_random_incentive(role)

                actor = cls(
                    env=self.env,
                    simulation=self.simulation,
                    actor_id=actor_id,
                    llm_client=self.llm_client,  # shared LLM client
                    llm_kwargs=self.llm_kwargs,  # shared LLM kwargs
                    incentive_text=incentive,
                    task_dictionary=role_tasks[i::n],  # tasks i, i+n, i+2n, ...
                )
                actor.external_factor = spec.get("external_factor", False)

                # record birth time
                self.simulation.actor_created[actor.actor_id] = self.env.now

                if actor.external_factor:
                    external_factors.append(actor)
                else:
                    actors.append(actor)

        return actors, external_factors
