import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig


from diffusion_co_design.common.rl.mappo import MAPPOCoDesign
from diffusion_co_design.common.design import DesignerParams
from diffusion_co_design.rware.static import GROUP_NAME
import diffusion_co_design.rware.schema as schema
import diffusion_co_design.rware.design as design
from diffusion_co_design.rware.env import create_env
from diffusion_co_design.rware.model.rl import rware_models


class Trainer(
    MAPPOCoDesign[
        schema.DesignerConfig,
        schema.ScenarioConfig,
        schema.ActorCriticConfig,
        schema.TrainingConfig,
    ]
):
    """RWARE-specific MAPPO co-design trainer.

    Binds the generic ``MAPPOCoDesign`` loop to the RWARE factories for
    designers, environments and actor/critic models. Each hook below is a thin
    adapter that forwards its arguments to the matching RWARE helper.
    """

    def __init__(self, cfg: schema.TrainingConfig):
        """Initialise the base trainer with an RWARE-specific name string.

        Args:
            cfg: Full training configuration. ``cfg.scenario_name`` is embedded
                in the name handed to the base class (presumably used as the
                experiment/project name — confirm against ``MAPPOCoDesign``).
        """
        run_name = f"diffusion-co-design-rware-{cfg.scenario_name}"
        super().__init__(cfg, run_name)

    def create_designer(self, scenario, designer, ppo, artifact_dir, device):
        """Build the RWARE designer via ``design.create_designer``.

        ``ppo`` is forwarded under the keyword ``ppo_cfg``; every other
        argument is passed through under its own name.
        """
        return design.create_designer(
            device=device,
            artifact_dir=artifact_dir,
            ppo_cfg=ppo,
            designer=designer,
            scenario=scenario,
        )

    def create_env(self, mode, scenario, designer, device, render=False):
        """Create an RWARE environment with the configured representation.

        The representation is read from ``self.cfg.designer.representation``;
        all other arguments are forwarded unchanged.
        """
        representation = self.cfg.designer.representation
        # Inside a method body, ``create_env`` resolves to the module-level
        # function imported from ``diffusion_co_design.rware.env``. Class
        # attributes are not in a method's name-lookup scope, so this is not
        # a recursive call to this method.
        return create_env(
            mode=mode,
            scenario=scenario,
            designer=designer,
            representation=representation,
            device=device,
            render=render,
        )

    def create_actor_critic_models(self, reference_env, actor_critic_config, device):
        """Delegate actor/critic model construction to ``rware_models``.

        ``reference_env`` is passed as ``env`` and ``actor_critic_config`` as
        ``cfg``; whatever ``rware_models`` returns is returned unchanged.
        """
        return rware_models(
            cfg=actor_critic_config,
            env=reference_env,
            device=device,
        )

    def create_placeholder_designer(self, scenario):
        """Return a ``design.RandomDesigner`` built from placeholder params.

        The designer settings come from ``DesignerParams.placeholder`` for
        ``scenario``. The representation comes from ``self.cfg.designer``.
        """
        placeholder_params = DesignerParams.placeholder(scenario=scenario)
        return design.RandomDesigner(
            designer_setting=placeholder_params,
            representation=self.cfg.designer.representation,
        )

    @property
    def group_name(self):
        """Name of the agent group: the ``GROUP_NAME`` constant from ``rware.static``."""
        return GROUP_NAME


@hydra.main(version_base=None, config_path="conf", config_name="random")
def run(config: DictConfig):
    """Hydra entry point: build a ``Trainer`` from ``config`` and run it.

    Hydra composes ``config`` from the ``random`` config in the ``conf``
    directory, plus any command-line overrides. The raw config is converted
    with ``schema.TrainingConfig.from_raw`` before it is passed to ``Trainer``.
    """
    job_name = HydraConfig.get().job.name
    print(f"Running job {job_name}")
    training_cfg = schema.TrainingConfig.from_raw(config)
    Trainer(training_cfg).run()


if __name__ == "__main__":
    # No arguments are passed here; the ``hydra.main`` decorator supplies
    # ``config`` itself.
    run()
