from abc import abstractmethod
from argparse import Namespace

import torch
from torch.utils.tensorboard import SummaryWriter

from mas_sat.graph.base import BaseGraph

class BaseLearner(object):
    """Shared training loop for all learners.

    A learner owns a model, an Adam optimizer and a cosine-annealing
    learning-rate scheduler. Subclasses provide the loss (``get_loss``),
    per-transition bookkeeping (``add_transition``) and reset logic
    (``clear``). This base class handles back-propagation, gradient clipping,
    optimizer/scheduler stepping and logging of ``lr`` and ``grad_norm`` to
    the ``SummaryWriter``.

    NOTE(review): the class derives from ``object`` rather than ``abc.ABC``,
    so ``@abstractmethod`` is NOT enforced at instantiation time. A subclass
    that forgets an override is only caught when the base method raises
    ``NotImplementedError``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        writer: SummaryWriter,
        args: Namespace,
    ) -> None:
        """Build the optimizer and scheduler from ``args``.

        Reads ``args.lr`` (Adam learning rate), ``args.learn_step`` (``T_max``
        of the cosine-annealing schedule), ``args.grad_clip`` (max gradient
        norm passed to ``clip_grad_norm_``) and ``args.grad_clip_norm_type``
        (the p-norm type used for clipping).

        The ``clear`` hook is invoked as the very last statement. It therefore
        runs during base construction, before any code a subclass places after
        ``super().__init__``.
        """
        self.model, self.device, self.writer = model, device, writer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=args.learn_step
        )

        # gradient-clipping settings consumed by optim()
        self.grad_clip = args.grad_clip
        self.grad_clip_norm_type = args.grad_clip_norm_type

        # number of training steps taken so far; advanced by step()
        self.counter = 0
        self.clear()

    # ------------------------------------------------------------------
    # accessors
    # ------------------------------------------------------------------
    def get_counter(self) -> int:
        """Return how many times ``step`` has been called."""
        return self.counter

    # ------------------------------------------------------------------
    # hooks and internals
    # ------------------------------------------------------------------
    @abstractmethod
    def get_loss(self) -> torch.Tensor:
        """Compute the loss for one training step.

        Implementations may also record diagnostics to ``self.writer``.
        Must return a scalar tensor, because ``optim`` calls ``backward()`` on it.
        """
        raise NotImplementedError

    def optim(self, loss: torch.Tensor) -> tuple[float, float]:
        """Back-propagate ``loss`` and take one optimizer + scheduler step.

        Gradients are clipped in place with ``clip_grad_norm_`` using
        ``self.grad_clip`` / ``self.grad_clip_norm_type`` before the
        optimizer step.

        Returns:
            ``(lr, grad_norm)``:
            - ``lr`` is the learning rate used for this update. It is read
              before ``scheduler.step()`` advances the schedule.
            - ``grad_norm`` is the total gradient norm returned by
              ``clip_grad_norm_``, converted to a Python float.
        """
        # read the lr in effect for this update; get_last_lr has no side effects
        current_lr = self.scheduler.get_last_lr()[0]

        self.optimizer.zero_grad()
        loss.backward()
        total_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            max_norm=self.grad_clip,
            norm_type=self.grad_clip_norm_type,
        )
        self.optimizer.step()
        self.scheduler.step()

        return current_lr, total_norm.item()

    @abstractmethod
    def clear(self):
        """Reset learner state after training.

        In this base class it is called at the end of ``__init__`` and after
        the single step in ``learn``.
        """
        raise NotImplementedError

    def step(self):
        """Run one optimization step and log ``lr`` and ``grad_norm``.

        The counter is advanced before the loss is computed. The scalars are
        therefore logged at the 1-based index of this step.
        """
        self.counter += 1
        lr, grad_norm = self.optim(self.get_loss())
        for tag, value in (("lr", lr), ("grad_norm", grad_norm)):
            self.writer.add_scalar(tag, value, self.counter)

    # ------------------------------------------------------------------
    # public interface
    # ------------------------------------------------------------------
    def get_eps(self) -> float:
        """Return the exploration rate for action selection.

        The base learner is always greedy (0.0). Epsilon-greedy learners
        such as DQN override this to trade off exploration and exploitation.
        """
        return 0.0

    @abstractmethod
    def add_transition(
        self,
        graph: BaseGraph,
        action_idx: int,
        reward: float,
        terminal: bool,
        original_graph: BaseGraph,
        ret_dict: dict,
    ):
        """Record one environment transition; called after every transition.

        NOTE(review): the exact meaning of ``graph`` vs ``original_graph``
        and the schema of ``ret_dict`` are defined by the caller. Confirm
        them there before relying on them.
        """
        raise NotImplementedError

    def learn(self) -> int:
        """Train and return the number of steps performed.

        The default runs exactly one ``step`` and then ``clear``. Learners
        such as DQN may override this to run several steps.
        """
        self.step()
        self.clear()
        steps_run = 1
        return steps_run
