from argparse import Namespace
import copy
from tqdm import tqdm

import torch
from torch import nn

from mas_sat.graph.base import BaseGraph
from mas_sat.learn.buffer import ReplayBuffer
from mas_sat.learn.base import BaseLearner
from mas_sat.utils.scatter import scatter_reduce

class DQNLearner(BaseLearner):
    """DQN learner with a target network, a replay buffer and a linear epsilon schedule.

    Transitions are fed in through :meth:`add_transition`. Every
    ``learn_interval`` stored transitions (once the buffer is full, or the
    exploration phase is over and the buffer reports it is ready), one learning
    step (``self.step()``) is scheduled. :meth:`learn` then runs all scheduled
    steps, copying the online model's weights into the target model whenever
    ``self.counter`` hits a multiple of ``target_update_interval``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        device: torch.device,
        # Quoted: ``import torch`` does not load ``torch.utils.tensorboard``,
        # so an eager annotation would fail unless some other module had
        # already imported that submodule.
        writer: "torch.utils.tensorboard.SummaryWriter",
        args: Namespace
    ) -> None:
        super().__init__(model, device, writer, args)

        # additional components
        # Target network: a copy of the online model used only to compute the
        # bootstrap targets in ``get_loss``; refreshed in ``learn``.
        self.target_model = copy.deepcopy(model)
        self.target_model.eval()
        self.buffer = ReplayBuffer(args)
        self.loss_fn = nn.MSELoss()

        # additional hyper-parameters
        self.gamma = args.gamma
        self.learn_interval = args.learn_interval
        self.target_update_interval = args.target_update_interval
        self.exploration_step = args.exploration_step
        self.eps_decay_step = args.eps_decay_step
        self.eps_init = args.eps_init
        self.eps_final = args.eps_final
        # Per-transition change of epsilon inside the decay window
        # [exploration_step, eps_decay_step]. An empty window (the two steps
        # equal, or decay ending before exploration) would otherwise divide by
        # zero; in that case get_eps evaluates the linear branch at offset 0
        # at most, so the slope value does not matter.
        decay_span = self.eps_decay_step - self.exploration_step
        self.slope = (self.eps_final - self.eps_init) / decay_span if decay_span > 0 else 0.0

        # additional internal states
        self.transition_counter = 0
        # Learning steps scheduled by add_transition but not yet run by learn.
        self.learn_pending = 0
        # Exploration progress bar; created lazily while transitions are
        # within the exploration phase and closed when it ends.
        self.tbar = None

    def state_dict(self) -> dict:
        """Return the base learner state extended with DQN-specific state.

        Adds ``"transition_counter"`` (drives the epsilon schedule and the
        exploration phase) and ``"target_model"`` (target network weights), so
        a resumed run bootstraps from the same target network.
        """
        state_dict = super().state_dict()
        state_dict["transition_counter"] = self.transition_counter
        state_dict["target_model"] = self.target_model.state_dict()
        return state_dict

    def load_state_dict(self, state_dict) -> None:
        """Restore state produced by :meth:`state_dict`.

        Checkpoints that predate the ``"target_model"`` entry are still
        accepted: the target network is then re-synchronised from
        ``self.model`` instead of keeping its construction-time weights.
        NOTE(review): this assumes ``BaseLearner.load_state_dict`` restores
        ``self.model`` — confirm.
        """
        super().load_state_dict(state_dict)
        self.transition_counter = state_dict["transition_counter"]
        if "target_model" in state_dict:
            self.target_model.load_state_dict(state_dict["target_model"])
        else:
            self.target_model.load_state_dict(self.model.state_dict())

    # internal methods
    def get_qs(self, model, s, a=None):
        """Evaluate Q-values of ``model`` on the batched state ``s``.

        Args:
            model: network whose output dict holds per-candidate Q-values under
                the ``"heuristic"`` key.
            s: batched state (project type) providing ``get_candidate_batch``,
                ``get_candidate_ptr`` and ``len``.
            a: optional chosen actions; presumably one index per graph,
                relative to that graph's first candidate (it is added to
                ``get_candidate_ptr()``) — confirm against ReplayBuffer.

        Returns:
            If ``a`` is None, the maximum Q-value over each graph's candidates
            (``scatter_reduce`` with ``reduce="amax"`` and ``dim_size=len(s)``);
            otherwise the Q-values at the selected candidate positions.
        """
        ret_dict = model(s)
        qs_all = ret_dict["heuristic"]

        if a is None:
            # for target, will get the maximum of all candidates
            batch = s.get_candidate_batch()
            qs = scatter_reduce(qs_all, batch, reduce="amax", dim=0, dim_size=len(s))
        else:
            # for current action, will get the q corresponding to the actions
            qs = qs_all[s.get_candidate_ptr() + a]
        return qs

    def get_loss(self) -> torch.Tensor:
        """Sample a batch from the replay buffer and return the TD regression loss.

        The target is ``r + non_term * gamma * max Q_target(s_next)``, computed
        without gradients. The loss is ``self.loss_fn`` (MSE) between it and
        ``Q_model(s, a)``. Average target/action Q-values and the loss are
        logged to the writer at step ``self.counter``.
        """
        s, a, r, s_next, non_term = self.buffer.sample()
        s, s_next = s.to(self.device), s_next.to(self.device)
        a, r = a.to(self.device), r.to(self.device)
        non_term = non_term.to(self.device)

        with torch.no_grad():
            target_qs = self.get_qs(self.target_model, s_next)
            # non_term zeroes the bootstrap term for terminal transitions.
            target_qs = r + non_term * self.gamma * target_qs
        qs = self.get_qs(self.model, s, a)

        self.writer.add_scalar("average_q/target", target_qs.mean().item(), self.counter)
        self.writer.add_scalar("average_q/action", qs.mean().item(), self.counter)

        loss = self.loss_fn(qs, target_qs)
        self.writer.add_scalar("heuristic_loss", loss.item(), self.counter)
        return loss

    def clear(self):
        """No-op: the replay buffer persists across episodes."""
        pass # do nothing

    # main interfaces
    def get_eps(self) -> float:
        """Return the current exploration rate and log it.

        Epsilon is ``eps_init`` before ``exploration_step`` transitions,
        ``eps_final`` after ``eps_decay_step`` transitions, and linearly
        interpolated in between.
        """
        if self.transition_counter < self.exploration_step:
            eps = self.eps_init
        elif self.transition_counter > self.eps_decay_step:
            eps = self.eps_final
        else:
            eps = self.slope * (self.transition_counter - self.exploration_step) + self.eps_init
        self.writer.add_scalar("eps", eps, self.transition_counter)
        return eps

    def _update_exploration_bar(self) -> None:
        """Advance the exploration progress bar, creating or closing it as needed."""
        if self.transition_counter > self.exploration_step:
            return
        if self.tbar is None:
            # Created lazily from the current counter, so a learner resumed
            # mid-exploration (whose counter never passes through 1) still
            # gets a bar instead of hitting a missing attribute.
            self.tbar = tqdm(
                total=self.exploration_step,
                initial=self.transition_counter - 1,
                desc="Exploring",
            )
        self.tbar.update()
        if self.transition_counter == self.exploration_step:
            # The final update above lets the bar reach its total before closing.
            self.tbar.close()
            self.tbar = None

    def add_transition(
        self,
        graph: BaseGraph,
        action_idx: int,
        reward: float,
        terminal: bool,
        original_graph: BaseGraph,
        ret_dict: dict
    ):
        """Store one transition in the replay buffer and schedule learning.

        Transitions whose ``graph.is_trivial()`` is true are ignored entirely.
        ``original_graph`` and ``ret_dict`` are part of the learner interface
        but are not used by DQN. Every ``learn_interval`` stored transitions,
        one learning step is scheduled if the buffer is full, or if
        exploration is over and the buffer is ready.
        """
        if graph.is_trivial():
            return
        self.buffer.add_transition(graph, action_idx, reward, terminal)
        self.transition_counter += 1
        self._update_exploration_bar()
        if self.transition_counter % self.learn_interval == 0:
            if self.buffer.is_full() or (
                self.transition_counter > self.exploration_step and self.buffer.is_ready()
            ):
                self.learn_pending += 1

    def learn(self) -> int:
        """Run every learning step scheduled by :meth:`add_transition`.

        After each ``self.step()``, the online weights are copied into the
        target model when ``self.counter`` (presumably the base learner's step
        counter advanced by ``step()``) is a multiple of
        ``target_update_interval``.

        Returns:
            The number of learning steps performed.
        """
        steps_done = 0
        while self.learn_pending > 0:
            self.step()
            self.learn_pending -= 1
            steps_done += 1
            if self.counter % self.target_update_interval == 0:
                self.target_model.load_state_dict(self.model.state_dict())
        return steps_done