from dataclasses import dataclass

import torch
from konductor.data import get_dataset_config
from konductor.init import ExperimentInitConfig
from konductor.models import get_model
from torch import Tensor, nn

from src.dataset.chasing_targets import GymPredictConfig
from src.model.motion_perceiver import MotionPerceiver


@dataclass(slots=True)
class EnvResult:
    """Trial outcome for environment"""

    wins: int = 0
    losses: int = 0
    draws: int = 0


def scenairo_id_tensor_2_str(batch_ids: Tensor) -> list[str]:
    """Convert a batch of character-code scenario ids into strings.

    Each row of ``batch_ids`` holds the integer code points of one scenario
    id. Every value is mapped through ``chr`` and the characters are joined.
    Trailing NUL characters are then stripped (presumably the rows are
    right-padded with zeros; confirm against the dataset encoder).

    Args:
        batch_ids: Integer tensor whose rows are code-point sequences,
            assumed to be shape (batch, max_id_len).

    Returns:
        One decoded string per row, in row order.
    """
    # Convert to host Python ints once. Iterating the tensor directly would
    # create a 0-d tensor per character and, for CUDA tensors, force a
    # device sync for every element.
    rows = batch_ids.tolist()
    return ["".join(map(chr, row)).rstrip("\x00") for row in rows]


def load_model(exp_cfg: ExperimentInitConfig) -> nn.Module:
    """Build the experiment's model on CUDA and restore its latest weights.

    The model is constructed with ``get_model`` and moved to the current CUDA
    device. The state dict is read from the ``"model"`` entry of
    ``<exp_cfg.exp_path>/latest.pt``. The checkpoint is loaded with
    ``weights_only=True``, so only tensors and primitive containers are
    unpickled.

    Args:
        exp_cfg: Experiment configuration; supplies the model definition and
            ``exp_path`` where ``latest.pt`` lives.

    Returns:
        The CUDA-resident model with the checkpoint weights applied.

    Raises:
        RuntimeError: From ``load_state_dict`` if checkpoint keys or shapes do
            not match the model (strict loading is PyTorch's default).
    """
    network: nn.Module = get_model(exp_cfg).cuda()

    # Map the stored tensors straight onto the device the model now lives on.
    target_device = f"cuda:{torch.cuda.current_device()}"
    checkpoint_file = exp_cfg.exp_path / "latest.pt"
    payload = torch.load(
        checkpoint_file,
        map_location=target_device,
        weights_only=True,
    )

    network.load_state_dict(payload["model"])
    return network


def initialize(
    exp_cfg: ExperimentInitConfig,
) -> tuple[MotionPerceiver, GymPredictConfig]:
    """Prepare the model and dataset config used for prediction export.

    The dataset configuration is resolved first. The model is then loaded
    with its latest weights and switched to eval mode. For a
    ``MotionPerceiver``, the encoder's ``random_input_indicies`` is set to 0
    (this presumably disables random input sub-sampling at inference time;
    confirm in the encoder).

    NOTE(review): the return annotation promises a ``MotionPerceiver`` and a
    ``GymPredictConfig``, but neither type is enforced here. Other model or
    dataset types are returned unchanged.

    Args:
        exp_cfg: Experiment configuration used for both model and dataset.

    Returns:
        ``(model, data_cfg)``, with the model in eval mode.
    """
    dataset_cfg = get_dataset_config(exp_cfg)
    net = load_model(exp_cfg).eval()

    if not isinstance(net, MotionPerceiver):
        return net, dataset_cfg

    net.encoder.random_input_indicies = 0
    return net, dataset_cfg
