from src.xlogomini.components.goal.objective import Objective


class Goal(object):
    """A task goal: a collection of objectives grouped by objective name.

    ``objs`` maps an objective name to the objectives sharing that name.
    ``init_from_json`` always stores a list per name; ``get_cnfs`` and
    ``__repr__`` additionally tolerate a single non-list objective as a value.
    """

    def __init__(self, objs):
        assert isinstance(objs, dict)
        self.objs = objs
        # Number of distinct objective names (not the total number of objectives).
        self.n_objs = len(self.objs)

    @classmethod
    def init_from_json(cls, goal_json):
        """Build a Goal from an iterable of objective JSON payloads.

        Each payload is parsed with ``Objective.init_from_json``; the resulting
        objectives are grouped by their ``obj_name`` in first-seen order.
        """
        objs = {}
        for obj_js in goal_json:
            obj = Objective.init_from_json(obj_js)
            objs.setdefault(obj.obj_name, []).append(obj)
        return cls(objs)

    def _get_list_of_obj_names(self):
        """Return a view of the objective names in this goal."""
        return self.objs.keys()

    def get_cnfs(self):
        """Return the CNFs of every objective, grouped by objective name.

        Returns:
            dict mapping each objective name to a flat list containing the
            ``cnf`` of every spec of every objective stored under that name.
        """
        cnfs = {}
        for name, objectives in self.objs.items():
            if not isinstance(objectives, list):
                # A single objective stored without a wrapping list.
                objectives = [objectives]
            cnfs[name] = [spec.cnf for obj in objectives for spec in obj.specs]
        return cnfs

    def to_json(self):
        """Serialize the goal as a flat list of objective JSON payloads.

        Returns ``None`` when the goal contains a ``'draw'`` objective.
        """
        if 'draw' in self._get_list_of_obj_names():
            return None
        return [obj.to_json() for objectives in self.objs.values() for obj in objectives]

    def toPytorchTensor(self):
        """Concatenate the per-objective tensors into one goal tensor.

        Fewer than 3 objectives are zero-padded (with zeros shaped like the
        first objective's tensor) up to 3 before concatenation. The result is
        asserted to have length 564 along dim 0 (presumably 3 x 188 per
        objective -- TODO confirm against ``Objective.toPytorchTensor``).

        Raises:
            ValueError: if the goal has no objectives (no tensor to pad from).
        """
        import torch as th
        goal_tensor = [obj.toPytorchTensor()
                       for objectives in self.objs.values()
                       for obj in objectives]
        if not goal_tensor:
            raise ValueError("cannot build a tensor for a goal with no objectives")
        while len(goal_tensor) < 3:
            goal_tensor.append(th.zeros_like(goal_tensor[0]))
        goal_tensor = th.concat(goal_tensor)
        assert goal_tensor.shape[0] == 564
        return goal_tensor

    def __getitem__(self, obj_name):
        """Return the objectives stored under ``obj_name``.

        Raises:
            ValueError: if no objective with that name exists.
        """
        if obj_name not in self.objs:
            raise ValueError(f"{obj_name} not exists")
        return self.objs[obj_name]

    def __len__(self):
        """Return the number of distinct objective names."""
        return len(self.objs)

    def __repr__(self):
        """Render the goal as human-readable text, objectives separated by spaces.

        A ``'forbid'`` objective that follows text already containing
        "without" has its "without standing on a" phrase replaced by "and",
        so the constraint is not repeated.
        """
        parts = []
        seen_without = False
        for objectives in self.objs.values():
            if isinstance(objectives, list):
                for obj in objectives:
                    obj_str = str(obj)
                    if obj.obj_name == 'forbid' and seen_without:
                        # replace duplicated `without crossing` if already exists
                        obj_str = obj_str.replace('without standing on a', 'and')
                    parts.append(obj_str)
                    seen_without = seen_without or 'without' in obj_str
            else:
                obj_str = str(objectives)
                parts.append(obj_str)
                seen_without = seen_without or 'without' in obj_str
        return ' '.join(parts).strip()

    def __hash__(self):
        # NOTE(review): repr depends on dict insertion order while __eq__ does
        # not, so equal goals built with different key orders may hash
        # differently -- confirm whether callers rely on that.
        return hash(repr(self))

    def __eq__(self, other):
        """Two goals are equal when they have the same objective names and,
        for each name, the same objectives in the same order."""
        if not isinstance(other, Goal):
            return False
        # Same set of objective names (also guarantees equal n_objs).
        if self.objs.keys() != other.objs.keys():
            return False
        for name, mine in self.objs.items():
            theirs = other.objs[name]
            # Length check keeps the comparison symmetric and avoids IndexError.
            if len(mine) != len(theirs):
                return False
            if any(a != b for a, b in zip(mine, theirs)):
                return False
        return True
