import numpy as np


class VisitStrategy:
    """Base strategy that hands out the nodes in ``self.nodes`` in batches.

    Subclasses fill ``self.nodes`` with the concepts to visit. Every node
    returned by :meth:`get_next_batch` is also appended to ``self.visited``.

    Args:
        ontology: The ontology being visited (stored as ``self.ontology``).
        batch_size: Number of nodes per batch. If ``None``, it is fixed on
            the first call to :meth:`get_next_batch` to the number of nodes
            available at that moment, and that value is kept afterwards.
    """

    def __init__(self, ontology, batch_size=None):
        self.ontology = ontology
        self.batch_size = batch_size
        self.nodes = []
        self.visited = []
        # Index into self.nodes of the first node not yet handed out.
        self.position = 0

    def get_next_batch(self):
        """Return the next slice of ``self.nodes`` and mark it as visited.

        Returns an empty batch once the current nodes are exhausted.
        """
        if self.batch_size is None:
            # Default to one batch holding all the current nodes; the value
            # is stored, so later frontiers reuse this same batch size.
            self.batch_size = len(self.nodes)

        start = self.position
        end = start + self.batch_size
        self.position = end

        batch = self.nodes[start:end]
        self.visited.extend(batch)
        return batch

    def init_frontier(self):
        """Return the first batch of nodes to visit."""
        return self.get_next_batch()

    def update_frontier(self, intersection, union, act_sum, max_iou):
        """Return the next batch; the base strategy ignores the scores."""
        return self.get_next_batch()


class Flat(VisitStrategy):
    """Visit every non-placeholder concept of the ontology, without pruning.

    Nodes come from ``ontology.to_list()`` (in that order), skipping any
    concept whose ``is_placeholder()`` is true. ``max_n`` is their count.
    """

    def __init__(self, ontology, batch_size=None):
        super().__init__(ontology, batch_size)
        kept = []
        for concept in self.ontology.to_list():
            if concept.is_placeholder():
                continue
            kept.append(concept)
        self.nodes = kept
        self.max_n = len(kept)


class Leaves(VisitStrategy):
    """Visit only the leaves of the ontology, without pruning.

    Nodes are the items yielded by ``ontology.get_leaves()``; ``max_n`` is
    their count.
    """

    def __init__(self, ontology, batch_size=None):
        super().__init__(ontology, batch_size)
        leaves = [*ontology.get_leaves()]
        self.nodes = leaves
        self.max_n = len(leaves)


class TopDown(VisitStrategy):
    """Visit the ontology from the top, expanding only promising concepts.

    The initial frontier is ``ontology.to_list(style='BFS',
    max_length=batch_size)``. Once every batch of a frontier has been
    scored, each concept whose upper bound exceeds ``max_iou`` for at least
    one unit contributes its not-yet-visited hyponyms to the next frontier.
    """

    def __init__(self, ontology, batch_size=None):
        super().__init__(ontology, batch_size)
        self.nodes = ontology.to_list(style='BFS', max_length=batch_size)
        # Maximum number of nodes this strategy may visit; set like
        # BottomUp so that every strategy exposes max_n.
        self.max_n = ontology.n
        # Intersections accumulated over the batches of the current
        # frontier: one column per node of self.nodes, in visiting order.
        self.intersection = None

    def update_frontier(self, intersection, union, act_sum, max_iou):
        """Record the scores of the last batch and return the next batch.

        Args:
            intersection: Scores of the last batch, one column per node;
                accumulated along axis 1 across the batches of a frontier.
            union: Unused by this strategy.
            act_sum: Per-unit values; the accumulated intersections are
                divided by ``act_sum[:, None]`` to get the upper bound.
            max_iou: Threshold the per-unit upper bound must exceed for a
                concept's hyponyms to be visited.

        Returns:
            The next batch of the current frontier or, once it is exhausted,
            the first batch of the newly computed frontier (empty when there
            is nothing left to visit).
        """
        # Concatenate results for the upper bound
        if self.intersection is None:
            self.intersection = intersection
        else:
            self.intersection = np.concatenate((self.intersection,
                                                intersection), axis=1)

        # Check new batch
        next_batch = self.get_next_batch()
        if next_batch:
            return next_batch

        # Compute upper bound for each unit in respect of each concept.
        # Units with act_sum == 0 give 0/0 = nan, which never passes the
        # '>' test below, so the division warnings are silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            upper_bound = self.intersection / act_sum[:, None]

        new_nodes = []
        for i, node in enumerate(self.nodes):
            # A concept's hyponyms join the next frontier if the concept has
            # an upper bound over the maximum IoU for at least one unit.
            # NOTE: this could be done by filtering
            #       out units which do not surpass it
            if not np.any(upper_bound[:, i] > max_iou):
                continue
            for hyponym in node.hyponyms:
                # A hyponym may be reachable from several expanded
                # hypernyms; enqueue it only once.
                if hyponym not in self.visited and hyponym not in new_nodes:
                    new_nodes.append(hyponym)

        self.nodes = new_nodes
        self.position = 0
        # Start a fresh accumulator so that column i of the next frontier's
        # intersection refers to self.nodes[i], not to a stale node.
        self.intersection = None

        return self.get_next_batch()


class BottomUp(VisitStrategy):
    """Visit the ontology from the leaves upwards, pruning unpromising paths.

    The initial frontier is the leaves of the ontology. Once every batch of
    a frontier has been scored, each concept whose upper bound exceeds
    ``max_iou`` for at least one unit contributes its not-yet-visited
    hypernyms (each listed once) to the next frontier.
    """

    def __init__(self, ontology, batch_size=None):
        super().__init__(ontology, batch_size)
        self.nodes = list(ontology.get_leaves())
        self.max_n = ontology.n
        # Unions accumulated over the batches of the current frontier:
        # one column per node of self.nodes, in visiting order.
        self.union = None

    def update_frontier(self, intersection, union, act_sum, max_iou):
        """Record the scores of the last batch and return the next batch.

        Args:
            intersection: Unused by this strategy.
            union: Scores of the last batch, one column per node;
                accumulated along axis 1 across the batches of a frontier.
            act_sum: Per-unit values; ``act_sum[:, None]`` is divided by the
                accumulated unions to get the upper bound.
            max_iou: Threshold the per-unit upper bound must exceed for a
                concept's hypernyms to be visited.

        Returns:
            The next batch of the current frontier or, once it is exhausted,
            the first batch of the newly computed frontier (empty when there
            is nothing left to visit).
        """
        self.union = (union if self.union is None
                      else np.concatenate((self.union, union), axis=1))

        pending = self.get_next_batch()
        if pending:
            return pending

        # Upper bound for every (unit, concept) pair of this frontier.
        bounds = act_sum[:, None] / self.union

        frontier = []
        for col, node in enumerate(self.nodes):
            # NOTE: the minimum upper bound over all the hyponyms of a
            #       hypernym could be checked instead, and units that do
            #       not surpass the threshold could be filtered out.
            if not np.any(bounds[:, col] > max_iou):
                continue
            for parent in node.hypernyms:
                if parent in self.visited or parent in frontier:
                    continue
                frontier.append(parent)

        self.nodes = frontier
        self.position = 0
        # Fresh accumulator for the new frontier's columns.
        self.union = None

        return self.get_next_batch()
