import torch
from dataclasses import dataclass

class Problem:
    """Base description of a problem type.

    Concrete problems are expected to set the class attributes and override
    the static hooks below.
    """

    # Identifier of the concrete problem; unset (None) on the base class.
    NAME = None
    # Unset (None) on the base class; its meaning is defined by subclasses.
    single_solution_set = None

    @staticmethod
    def get_costs(input, pi):
        """Hook for computing solution costs (presumably of ``pi`` on ``input``).

        The base implementation does nothing and returns None.
        """
        return None

    @staticmethod
    def make_state(*args, **kwargs):
        """Build a state object for this problem; subclasses must override.

        :raises NotImplementedError: always, on the base class.
        """
        message = "define make_state!"
        raise NotImplementedError(message)


@dataclass
class State:
    """Base container for the state of a solution built step by step.

    The fields below become required constructor arguments (generated by
    ``@dataclass``). Subclasses must implement every hook; on this base class
    each one raises ``NotImplementedError``.
    """

    adj: list
    embeddings: torch.Tensor

    # presumably a mask of already-selected items — TODO confirm
    visited: torch.ByteTensor
    sol_ind: torch.IntTensor
    sol_rep: object

    # presumably the previously selected action — TODO confirm
    prev_a: torch.IntTensor

    step: torch.Tensor

    @staticmethod
    def concat_embeddings():
        """Report whether embeddings are concatenated.

        :return: True or False, as decided by the subclass.
        :raises NotImplementedError: on the base class.
        """
        raise NotImplementedError("define concat_embeddings")

    def initialize(self, *args):
        """Set up the state from ``args``; must be overridden."""
        raise NotImplementedError("define initialize!")

    def update(self, selected):
        """Transition function: apply ``selected`` to the state; must be overridden."""
        raise NotImplementedError("define update function!")

    def is_done(self):
        """Return the termination criterion; must be overridden."""
        raise NotImplementedError("define termination criterion!")

    def get_cost(self):
        """Return the cost of the current state; must be overridden."""
        raise NotImplementedError("define a cost function!")


def load_problem(name, *args):
    """Return the problem registered under ``name``.

    ``'GC'`` and ``'MVC'`` map to the ``GC`` / ``MVC`` objects imported from
    ``problems`` and are returned as-is. ``'DefectiveGC'`` and ``'GP'`` are
    instantiated with ``*args``. Only the requested entry is constructed.

    :param name: registry key of the problem.
    :param args: positional arguments forwarded to the constructor of
        ``DefectiveGC`` or ``GP``; ignored for ``GC`` and ``MVC``.
    :return: the problem object.
    :raises AssertionError: if ``name`` is not a supported problem.
    """
    # NOTE(review): imports are function-local, presumably to avoid an import
    # cycle with the ``problems`` package; confirm before hoisting.
    from problems import GC, MVC, GP
    from problems.GC.problem_gc_defective import DefectiveGC

    # Zero-argument factories: building the dict must not construct every
    # problem. Only the selected one is instantiated, so ``*args`` meant for
    # one constructor are never passed to another.
    factories = {
        'GC': lambda: GC,
        'DefectiveGC': lambda: DefectiveGC(*args),
        'MVC': lambda: MVC,
        'GP': lambda: GP(*args),
    }
    factory = factories.get(name)
    if factory is None:
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError("Currently unsupported problem: {}!".format(name))
    return factory()