from argparse import Namespace

import torch

from mas_sat.graph.base import BaseGraph
from mas_sat.learn.base import BaseLearner

class ReinforceLearner(BaseLearner):
    """Learner implementing the REINFORCE policy-gradient update.

    Transitions are accumulated through :meth:`add_transition`; the loss is
    the negative mean of ``log_prob * (return - mean(return))``, i.e. the
    discounted returns are centred by their batch mean as a baseline.

    ``self.device``, ``self.writer`` and ``self.counter`` are read here but
    not assigned in this class; they are presumably provided by
    ``BaseLearner`` (verify against the base class).
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        # Quoted so the annotation is not evaluated at definition time:
        # ``import torch`` alone does not guarantee that the
        # ``torch.utils.tensorboard`` submodule has been loaded.
        writer: "torch.utils.tensorboard.SummaryWriter",
        args: Namespace
    ) -> None:
        """Forward the shared arguments to the base learner and read ``gamma``.

        Args:
            model: Forwarded unchanged to ``BaseLearner``.
            device: Forwarded unchanged to ``BaseLearner``.
            writer: Forwarded unchanged to ``BaseLearner``.
            args: Parsed options; must expose ``gamma`` (discount factor).
        """
        super().__init__(model, device, writer, args)

        # additional hyper-parameters
        self.gamma = args.gamma

        # Make sure the transition buffers exist even when ``clear()`` has
        # not been called yet, so ``learn()``/``add_transition()`` cannot
        # fail with AttributeError on a fresh instance.
        self.log_probs = []
        self.rewards = []
        self.terminals = []

    def clear(self):
        """Drop every stored transition (log-probs, rewards, terminal flags)."""
        self.log_probs = []
        self.rewards = []
        self.terminals = []

    def add_transition(
        self,
        graph: BaseGraph,
        action_idx: int,
        reward: float,
        terminal: bool,
        original_graph: BaseGraph,
        ret_dict: dict
    ):
        """Record one transition if it carries a log-probability.

        Only ``reward``, ``terminal`` and ``ret_dict["log_prob"]`` are stored;
        ``graph``, ``action_idx`` and ``original_graph`` are accepted but not
        used by this learner. A transition whose ``ret_dict`` has no
        ``"log_prob"`` key is skipped entirely (its reward and terminal flag
        are not recorded either), which keeps the three buffers aligned.
        """
        if "log_prob" not in ret_dict:
            return
        self.log_probs.append(ret_dict["log_prob"])
        self.rewards.append(reward)
        self.terminals.append(terminal)

    # internal methods for REINFORCE algorithm
    def get_discounted_rewards(self) -> torch.Tensor:
        """Return the discounted return of every stored transition.

        The buffers are walked backwards with ``R = reward + gamma * R``;
        ``R`` is reset to zero at every transition flagged terminal so that
        returns never leak across a terminal boundary.

        Returns:
            A 1-D floating-point tensor on ``self.device`` with one entry per
            stored transition, in the original (forward) order.
        """
        returns = []
        running_return = 0
        for reward, terminal in zip(reversed(self.rewards), reversed(self.terminals)):
            if terminal:
                running_return = 0
            running_return = reward + self.gamma * running_return
            # Append + single reverse is O(n); inserting at index 0 on every
            # step would be O(n^2).
            returns.append(running_return)
        returns.reverse()

        discounted = torch.tensor(returns, device=self.device)
        # All-integer rewards with an integer gamma would yield an integer
        # tensor, on which ``.mean()`` raises; promote to the default float.
        if not discounted.is_floating_point():
            discounted = discounted.to(torch.get_default_dtype())
        return discounted

    def get_loss(self) -> torch.Tensor:
        """Compute the REINFORCE loss and log it as ``"heuristic_loss"``.

        Returns:
            The scalar loss ``-(log_probs * advantages).mean()`` where the
            advantages are the discounted returns minus their mean.
        """
        discounted_rewards = self.get_discounted_rewards()
        log_probs = torch.stack(self.log_probs)
        advantages = discounted_rewards - discounted_rewards.mean()

        # If each stored log-prob has shape (1,), stacking gives (N, 1) and
        # multiplying by the (N,) advantages would silently broadcast to
        # (N, N). Align the shapes when the element counts agree; other
        # shapes are left untouched.
        if log_probs.shape != advantages.shape and log_probs.numel() == advantages.numel():
            log_probs = log_probs.reshape(advantages.shape)

        heuristic_loss = -(log_probs * advantages).mean()
        self.writer.add_scalar("heuristic_loss", heuristic_loss.item(), self.counter)
        return heuristic_loss

    def learn(self) -> int:
        """Run the base learner's update, or return 0 when nothing is stored."""
        if not self.log_probs:
            return 0
        return super().learn()
