from argparse import Namespace

import torch

from mas_sat.graph.base import BaseGraph
from mas_sat.learn.base import BaseLearner

class StandaloneLearner(BaseLearner):
    """Learner that accumulates assignment scores and averages them into a loss.

    Scores are collected one transition at a time via :meth:`add_transition`
    and reduced to a single scalar by :meth:`get_loss`. Call :meth:`clear` to
    start a new accumulation window.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        writer: torch.utils.tensorboard.SummaryWriter,
        args: Namespace
    ) -> None:
        """Forward all arguments to ``BaseLearner`` and create the score buffer.

        Args:
            model: Network used to decode assignments; this class reads
                ``model.latent_name`` and calls ``model.decode_assignment``.
            device: Passed through to the base class.
            writer: TensorBoard writer; used here for scalar logging.
            args: Parsed command-line options, passed through to the base class.
        """
        super().__init__(model, device, writer, args)
        # Created here (not only in clear()) so add_transition()/get_loss()
        # do not hit an AttributeError when called before the first clear().
        self.scores = []

    def clear(self) -> None:
        """Discard every score accumulated so far."""
        self.scores = []

    def add_transition(
        self,
        graph: BaseGraph,
        action: int,
        reward: float,
        terminal: bool,
        original_graph: BaseGraph,
        ret_dict: dict
    ) -> None:
        """Record the score(s) belonging to one transition.

        When ``original_graph`` is given, its stored latent is decoded by the
        model and the resulting score is recorded. Otherwise the scores are
        taken from ``ret_dict["scores"]``.

        Args:
            graph: Unused by this learner.
            action: Unused by this learner.
            reward: Unused by this learner.
            terminal: Unused by this learner.
            original_graph: Graph to score, or ``None`` to use ``ret_dict``.
            ret_dict: Must contain a ``"scores"`` entry when
                ``original_graph`` is ``None``; ignored otherwise.
        """
        if original_graph is not None:
            latent = original_graph[self.model.latent_name].latent
            # decode_assignment returns a 3-tuple; only the middle element
            # (the score) is needed here.
            _, score, _ = self.model.decode_assignment(original_graph, latent)
            new_scores = [score]
        else:
            new_scores = ret_dict["scores"]
        self.scores.extend(new_scores)

    def get_loss(self) -> torch.Tensor:
        """Return the mean of all accumulated scores and log it.

        The value is written to TensorBoard under ``assignment_loss/train``
        using ``self.counter`` as the step (``counter`` is not set in this
        class; presumably maintained by ``BaseLearner`` - confirm).

        Returns:
            The mean of the stacked scores.

        Raises:
            RuntimeError: If no scores have been accumulated.
        """
        if not self.scores:
            # torch.stack would raise RuntimeError on an empty list anyway;
            # keep the exception type but make the cause obvious.
            raise RuntimeError(
                "get_loss() called with no accumulated scores; "
                "call add_transition() first"
            )
        assignment_loss = torch.stack(self.scores).mean()
        self.writer.add_scalar("assignment_loss/train", assignment_loss.item(), self.counter)
        return assignment_loss