from argparse import Namespace

import torch

from mas_sat.graph.base import BaseGraph
from mas_sat.learn.reinforce import ReinforceLearner
from mas_sat.learn.standalone import StandaloneLearner

class MultiLearner(ReinforceLearner, StandaloneLearner):
    """Learner that combines the ``ReinforceLearner`` and ``StandaloneLearner`` objectives.

    Lifecycle calls (``clear``, ``add_transition``) are forwarded explicitly to
    both parent learners, and ``get_loss`` returns a weighted sum of the two
    parents' losses. The weights come from ``args.heuristic_loss_weight`` and
    ``args.assignment_loss_weight``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        # Quoted so the annotation is not evaluated at definition time: a plain
        # ``import torch`` does not load the ``torch.utils.tensorboard``
        # submodule, so the unquoted attribute chain could raise AttributeError.
        writer: "torch.utils.tensorboard.SummaryWriter",
        args: Namespace
    ) -> None:
        """Initialize both parent learners and read the loss-mixing weights.

        Args:
            model: Network passed unchanged to both parent learners.
            device: Device passed unchanged to both parent learners.
            writer: Summary writer passed to both parents; also used by
                ``get_loss`` to log ``total_loss``.
            args: Parsed options. Must provide ``heuristic_loss_weight`` and
                ``assignment_loss_weight``, plus whatever the parents read.
        """
        # Each parent is initialized explicitly rather than through one
        # cooperative ``super().__init__`` call.
        # NOTE(review): if both parents set the same attributes (e.g. via a
        # shared base class), the second call overwrites the first; confirm
        # this is intended.
        ReinforceLearner.__init__(self, model, device, writer, args)
        StandaloneLearner.__init__(self, model, device, writer, args)

        # Hyper-parameters: per-objective weights applied in ``get_loss``.
        self.heuristic_loss_weight = args.heuristic_loss_weight
        self.assignment_loss_weight = args.assignment_loss_weight

    def clear(self) -> None:
        """Clear the state of both parent learners (REINFORCE first, then standalone)."""
        ReinforceLearner.clear(self)
        StandaloneLearner.clear(self)

    def add_transition(
        self,
        graph: BaseGraph,
        action_idx: int,
        reward: float,
        terminal: bool,
        original_graph: BaseGraph,
        ret_dict: dict
    ) -> None:
        """Record one transition with both parent learners.

        All arguments are forwarded unchanged, first to
        ``ReinforceLearner.add_transition`` and then to
        ``StandaloneLearner.add_transition``.
        """
        ReinforceLearner.add_transition(
            self, graph, action_idx, reward, terminal, original_graph, ret_dict
        )
        StandaloneLearner.add_transition(
            self, graph, action_idx, reward, terminal, original_graph, ret_dict
        )

    def get_loss(self):
        """Return the weighted sum of the REINFORCE and standalone losses.

        The combined value is ``heuristic_loss_weight * ReinforceLearner loss +
        assignment_loss_weight * StandaloneLearner loss``. It is also logged to
        the writer under ``"total_loss"``.

        NOTE(review): ``self.counter`` is not set in this class; presumably a
        parent learner maintains it. The combined loss must support ``.item()``,
        so the parents are assumed to return tensors; verify this.
        """
        heuristic_loss = ReinforceLearner.get_loss(self)
        assignment_loss = StandaloneLearner.get_loss(self)
        loss = (
            self.heuristic_loss_weight * heuristic_loss
            + self.assignment_loss_weight * assignment_loss
        )
        self.writer.add_scalar("total_loss", loss.item(), self.counter)
        return loss