from abc import abstractmethod

import torch
from torch import nn

from mas_sat.model.mlp import get_mlp

class BaseModel(nn.Module):
    """Base class for iterative message-passing models with decoding heads.

    Subclasses implement the graph-specific hooks (``get_latent_dict``,
    ``get_edge_index_dict``, ``get_kwargs`` and ``step``); this class runs the
    step loop and decodes the ``latent_name`` latents through up to two heads,
    selected by ``args.head``:

    * ``"heuristic"``  -- heuristic head, decoded once at the last executed step.
    * ``"assignment"`` -- assignment head, decoded at every step; the loop stops
      early once every entry of the ``solved`` mask is true.
    * ``"multi"``      -- both heads.
    """

    def __init__(self, args) -> None:
        """Read hyperparameters from ``args`` and build the decoder heads.

        Args:
            args: namespace providing ``graph``, ``num_step``, ``recurrent``,
                ``grad_alpha``, ``head`` and ``dim``.

        Raises:
            ValueError: if ``args.graph`` is not ``"vcg"`` or ``"vcgl"``.
        """
        super().__init__()
        if args.graph in ["vcg", "vcgl"]:
            self.latent_name = "variable"
            self.out_dim = 2
        else:
            raise ValueError(f"Unrecognized graph: {args.graph}")

        # hyperparameters
        self.n_step = args.num_step
        self.recurrent = args.recurrent
        self.grad_alpha = args.grad_alpha
        self.heuristic = args.head in ["heuristic", "multi"]
        self.assignment = args.head in ["assignment", "multi"]

        # decoders (None when the corresponding head is disabled); built in the
        # same order as before so parameter init and state_dict keys match.
        self.decoder_heuristic = (
            self._build_decoder(args.dim) if self.heuristic else None
        )
        self.decoder_assignment = (
            self._build_decoder(args.dim) if self.assignment else None
        )

    def _build_decoder(self, dim: int) -> nn.Module:
        """Return ``get_mlp(dim, dim // 2, 0)`` followed by Linear(dim // 2 -> out_dim)."""
        return nn.Sequential(
            get_mlp(dim, dim // 2, 0),
            nn.Linear(dim // 2, self.out_dim),
        )

    def set_step(self, step: int):
        """Set the maximum number of message-passing steps used by ``forward``."""
        self.n_step = step

    # basic operations
    # NOTE: nn.Module does not use ABCMeta, so @abstractmethod is NOT enforced
    # at instantiation; the hooks raise NotImplementedError instead of silently
    # returning None when a subclass forgets to override them.
    @abstractmethod
    def get_latent_dict(self, graph):
        """Return the initial dict of latents keyed by node type."""
        raise NotImplementedError

    @abstractmethod
    def get_edge_index_dict(self, graph):
        """Return the edge-index dict passed unchanged to every ``step`` call."""
        raise NotImplementedError

    @abstractmethod
    def get_kwargs(self, graph):
        """
        get additional kwargs other than latent_dict
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, latent_dict, edge_index_dict, **kwargs):
        """
        return updated latent_dict
        """
        raise NotImplementedError

    def reduce_gradient(self, latent_dict, alpha):
        """Scale gradients flowing back through each latent by ``alpha``.

        The forward value is unchanged (``alpha*v + (1-alpha)*v == v``), but
        only the ``alpha * v`` term carries gradient since the other is detached.
        """
        return {k: alpha * v + (1 - alpha) * v.detach() for k, v in latent_dict.items()}

    def update_graph(self, graph, latent_dict):
        """Store each latent on ``graph[k].latent``."""
        for k, v in latent_dict.items():
            graph[k].latent = v

    def decode_heuristic(self, latent):
        """Return the flattened output of the heuristic decoder."""
        return self.decoder_heuristic(latent).flatten()

    def decode_assignment(self, graph, latent):
        """Decode an assignment and score it with ``graph.get_score``.

        Returns:
            ``(flattened decoder output, soft score, solved mask)`` where the
            mask is ``graph.get_score(..., hard=True) > 0.5``.
        """
        assignment = self.decoder_assignment(latent)
        score = graph.get_score(assignment)
        hard_scores = graph.get_score(assignment, hard=True)
        solved = hard_scores > 0.5
        return assignment.flatten(), score, solved

    def _extract_solution(self, graph, assignment):
        """Return candidate indices whose softmax probability exceeds 0.5."""
        probs = assignment.reshape(-1, self.out_dim).softmax(1)
        in_solution = (probs > 0.5).flatten()
        return graph.get_candidate_indices()[in_solution]

    # forward function
    def forward(self, graph):
        """Run up to ``self.n_step`` steps and decode the enabled heads.

        Returns:
            dict with ``"model_step"`` (steps actually run) and ``"solution"``
            (candidate indices, set only for a single fully-solved instance,
            else None); plus ``"updated_graph"`` when recurrent, ``"heuristic"``
            when any head is enabled, and ``"assignment"``, ``"solved"``
            (count of solved entries) and ``"scores"`` (stacked per-step scores)
            when the assignment head is enabled.
        """
        latent_dict = self.get_latent_dict(graph)
        edge_index_dict = self.get_edge_index_dict(graph)
        kwargs = self.get_kwargs(graph)

        n_step = 0
        scores = []
        heuristic = None
        solution = None
        # Stay None if no step runs (n_step == 0) instead of being unbound.
        assignment = None
        solved = None
        for i_step in range(self.n_step):
            latent_dict = self.step(latent_dict, edge_index_dict, **kwargs)
            n_step += 1
            all_solved = False

            # for assignment, decode at every step
            if self.assignment:
                assignment, score, solved = self.decode_assignment(
                    graph, latent_dict[self.latent_name]
                )
                scores.append(score)
                all_solved = bool(torch.all(solved))
                if all_solved and len(solved) == 1:
                    solution = self._extract_solution(graph, assignment)

            # for heuristic, decode at the last *executed* step -- including an
            # early exit, otherwise "multi" mode would return heuristic=None.
            if self.heuristic and (all_solved or i_step + 1 == self.n_step):
                heuristic = self.decode_heuristic(latent_dict[self.latent_name])

            if all_solved:
                # Exit before reduce_gradient, matching the original ordering.
                break

            # reduce the gradient norm
            if self.recurrent:
                latent_dict = self.reduce_gradient(latent_dict, self.grad_alpha)

        # prepare ret_dict
        ret_dict = {
            "model_step": n_step,
            "solution": solution,
        }
        if self.recurrent:
            self.update_graph(graph, latent_dict)
            ret_dict["updated_graph"] = graph
        if self.heuristic:
            ret_dict["heuristic"] = heuristic
        if self.assignment:
            ret_dict["assignment"] = assignment
            if not self.heuristic:
                ret_dict["heuristic"] = assignment
            ret_dict["solved"] = solved.sum().item() if solved is not None else 0
            ret_dict["scores"] = torch.stack(scores) if scores else torch.empty(0)
        return ret_dict
