from dataclasses import dataclass
from typing import Any

from konductor.data import get_dataset_properties
from konductor.models import MODEL_REGISTRY, ExperimentInitConfig
from konductor.models._pytorch import TorchModelConfig

from . import sc2_perceiver
from .goal_perceiver import GoalPerciever
from .motion_perceiver import MotionPerceiver


@dataclass
@MODEL_REGISTRY.register_module("motion-perceiver")
class MotionPerceiverConfig(TorchModelConfig):
    """Registered ("motion-perceiver") configuration for :class:`MotionPerceiver`.

    The ``encoder`` and ``decoder`` dictionaries are handed unchanged to the
    ``MotionPerceiver`` constructor by :meth:`get_instance`.
    """

    # Encoder sub-configuration, forwarded as-is to MotionPerceiver
    encoder: dict[str, Any]
    # Decoder sub-configuration, forwarded as-is to MotionPerceiver
    decoder: dict[str, Any]

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int = 0) -> Any:
        """Inject dataset-derived values into the model args, then build the config.

        ``config.model[idx].args`` is modified IN PLACE before delegating to
        the parent ``from_config``:

        * ``decoder.adapter.args.image_shape`` becomes ``[s, s]`` where ``s``
          is the ``shape`` attribute of the dataset's ``"occupancy"`` property
          (assumed to be a single side length — TODO confirm).
        * A ``"class-occupancy"`` adapter with no ``names`` entry gets the
          dataset's ``"classes"`` property.
        * A ``"transformer"`` encoder gets ``max_time`` from the dataset's
          ``"n_iter"`` property. An encoder without a ``"type"`` key is left
          unchanged.

        Args:
            config: Experiment configuration. Its model args are mutated.
            idx: Index of the model entry within ``config.model``.

        Returns:
            Whatever ``TorchModelConfig.from_config`` returns for these args.
        """
        dataset_props = get_dataset_properties(config)
        args = config.model[idx].args

        side = dataset_props["occupancy"].shape
        adapter = args["decoder"]["adapter"]
        adapter_args = adapter["args"]
        adapter_args["image_shape"] = [side, side]

        # User-specified class names take precedence over the dataset's list
        is_class_occupancy = adapter["type"] == "class-occupancy"
        if is_class_occupancy and "names" not in adapter_args:
            adapter_args["names"] = dataset_props["classes"]

        encoder_args = args["encoder"]
        if encoder_args.get("type") == "transformer":
            encoder_args["max_time"] = dataset_props["n_iter"]

        return super().from_config(config, idx)

    def get_instance(self, *args, **kwargs) -> Any:
        """Build ``MotionPerceiver(encoder, decoder)`` and pass it through ``_apply_extra``.

        Extra positional/keyword arguments are accepted but ignored.
        """
        model = MotionPerceiver(self.encoder, self.decoder)
        return self._apply_extra(model)


@dataclass
@MODEL_REGISTRY.register_module("goal-perceiver")
class GoalPerceiverConfig(TorchModelConfig):
    """Registered ("goal-perceiver") configuration for :class:`GoalPerciever`.

    The model is built via ``init_auto_filter``, which receives this config
    and the set of ``TorchModelConfig`` base-field names to exclude.
    """

    # Encoder sub-configuration; "max_time" is injected for transformer encoders
    encoder: dict[str, Any]
    # Decoder sub-configuration
    decoder: dict[str, Any]
    # Input adapter sub-configuration
    input_adapter: dict[str, Any]
    # Output adapter sub-configuration
    output_adapter: dict[str, Any]

    @classmethod
    def from_config(cls, config: ExperimentInitConfig, idx: int = 0) -> Any:
        """Inject dataset-derived values into the model args, then build the config.

        If the encoder is a ``"transformer"``, ``config.model[idx].args``
        is mutated IN PLACE: ``encoder.max_time`` is set from the dataset's
        ``"n_iter"`` property. An encoder config without a ``"type"`` key is
        treated as non-transformer and left unchanged.

        Args:
            config: Experiment configuration. Its model args may be mutated.
            idx: Index of the model entry within ``config.model``.

        Returns:
            Whatever ``TorchModelConfig.from_config`` returns for these args.
        """
        props = get_dataset_properties(config)
        encoder_cfg = config.model[idx].args["encoder"]
        # .get() avoids a KeyError when "type" is omitted, matching how
        # MotionPerceiverConfig.from_config inspects the encoder type.
        if encoder_cfg.get("type") == "transformer":
            encoder_cfg["max_time"] = props["n_iter"]

        return super().from_config(config, idx)

    def get_instance(self, *args, **kwargs) -> Any:
        """Build a ``GoalPerciever`` from this config and pass it through ``_apply_extra``.

        Extra positional/keyword arguments are accepted but ignored.
        """
        # Names of the fields declared on the TorchModelConfig base dataclass.
        # Presumably init_auto_filter excludes these from GoalPerciever's
        # constructor kwargs — confirm against konductor.
        known_unused = set(TorchModelConfig.__dataclass_fields__)
        inst = self.init_auto_filter(GoalPerciever, known_unused)
        return self._apply_extra(inst)
