# print(__doc__)

# matplotlib.use('TkAgg')

import heapq
import numpy as np
import math


class FacilityLocation:
    """Normalized facility-location objective with an optional diversity penalty.

    For a selection S (a list of column indices of D), ``add`` sets
    ``curVal`` to::

        norm * log(1 + f_norm * (sum_i max_{j in S} D[i, j]
                                 - self.gamma * sum_{j in S} sum_{k in S} D[j, k]))

    where ``f_norm = alpha / sum_i max_{j in V} D[i, j]``,
    ``norm = 1 / log(1 + alpha)`` and ``self.gamma = gamma / N``. With
    ``gamma == 0`` the full ground set V therefore has value exactly 1.
    """

    def __init__(self, D, V, alpha=1.0, gamma=0.0):
        """
        Args
        - D: np.array, shape [N, N], similarity matrix
        - V: list of int, indices of columns of D (the ground set); must be
          non-empty, since ``f_norm`` takes a max over these columns
        - alpha: float, scale of the normalized coverage inside the log;
          log(1 + alpha) must be non-zero (e.g. alpha > 0)
        - gamma: float, weight of the pairwise-similarity penalty that
          encourages diversity; divided by N before use
        """
        self.D = D
        self.curVal = 0  # objective value of the current selection
        self.curMax = np.zeros(len(D))  # per-row max similarity to the selection
        self.gains = []  # gain recorded by each call to add()
        self.alpha = alpha
        # NOTE: this instance attribute shadows the f_norm() method. After
        # construction, ``self.f_norm`` is the float coverage scale used by
        # inc() and add().
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) is the normalization sentinel returning log(1 + alpha).
        self.norm = 1.0 / self.inc(V, [])
        self.gamma = gamma / len(self.D)  # encouraging diversity

    def f_norm(self, sset):
        """Return the raw coverage ``sum_i max_{j in sset} D[i, j]``."""
        return self.D[:, sset].max(axis=1).sum()

    def inc(self, sset, ndx):
        """Return the marginal gain of adding column ``ndx`` to ``sset``.

        Assumes ``self.curMax`` and ``self.curVal`` describe ``sset``, i.e. the
        elements of ``sset`` were committed through ``add``.

        Passing ``ndx=[]`` with a non-empty ``sset`` is the normalization
        sentinel used by ``__init__``. It returns ``log(1 + alpha)`` unscaled.
        """
        candidate = sset + [ndx]
        if len(candidate) > 1:
            # Only an explicit empty list is the sentinel. An integer index 0
            # is also falsy and must be scored like any other candidate.
            if isinstance(ndx, list) and not ndx:  # normalization
                return math.log(1 + self.alpha)
            coverage = np.maximum(self.curMax, self.D[:, ndx]).sum()
            penalty = self.gamma * self.D[candidate][:, candidate].sum()
            return (
                self.norm * math.log(1 + self.f_norm * (coverage - penalty))
                - self.curVal
            )
        # NOTE(review): unlike add(), the singleton case applies no gamma
        # penalty (add() subtracts gamma * D[ndx, ndx] here). Confirm this is
        # intended.
        return (
            self.norm * math.log(1 + self.f_norm * self.D[:, ndx].sum())
            - self.curVal
        )

    def add(self, sset, ndx):
        """Commit column ``ndx`` to the selection and return the new value.

        ``sset`` is the selection *before* ``ndx`` is added. Updates
        ``curMax`` and ``curVal``, and appends the resulting gain to ``gains``.
        """
        cur_old = self.curVal
        candidate = sset + [ndx]
        if len(candidate) > 1:
            self.curMax = np.maximum(self.curMax, self.D[:, ndx])
        else:
            self.curMax = self.D[:, ndx]
        coverage = self.curMax.sum()
        penalty = self.gamma * self.D[candidate][:, candidate].sum()
        self.curVal = self.norm * math.log(1 + self.f_norm * (coverage - penalty))
        self.gains.append(self.curVal - cur_old)
        return self.curVal


def lazy_greedy(F, ndx, B):
    """
    Select up to B elements with the lazy (accelerated) greedy algorithm.

    Candidates live in a min-heap keyed on ``1 / gain``, so the element with
    the largest positive cached gain is popped first. After the first pick, a
    popped element's gain is recomputed against the current selection. The
    element is accepted if that fresh gain is at least the cached gain of the
    next candidate in the heap. Otherwise it is pushed back with its fresh
    gain.

    Args
    - F: FacilityLocation, or any object with ``inc(sset, v)`` returning the
      marginal gain of v and ``add(sset, v)`` committing v and returning the
      new objective value
    - ndx: indices of all points
    - B: int, number of points to select

    Returns
    - sset: list of selected indices, in selection order. If fewer than B
      elements were picked greedily, it is padded with elements whose gain
      was not positive, in the order they were rejected.
    - vals: list of the values returned by ``F.add`` for each greedy pick
      (padding elements add no entry).
    """
    eps = 1e-15  # keeps 1 / gain finite when a gain is exactly zero
    sset = []
    vals = []
    order = []
    for v in ndx:
        marginal = F.inc(sset, v) + eps
        heapq.heappush(order, (1.0 / marginal, v, marginal))

    not_selected = []  # rejected elements, kept in rejection order for padding
    while order and len(sset) < B:
        _, v, cached_gain = heapq.heappop(order)
        # Until something is selected, the cached gains (computed against the
        # empty selection) are still exact.
        improv = F.inc(sset, v) + eps if sset else cached_gain

        if not improv > eps:
            not_selected.append(v)
            continue

        # order[0] is the next candidate the heap would pop; its cached gain
        # is the bar this element's fresh gain must clear.
        if not order or improv >= order[0][2]:
            vals.append(F.add(sset, v))
            sset.append(v)
        else:
            heapq.heappush(order, (1.0 / improv, v, improv))

    # If the number of items selected is less than desired, pad with the
    # rejected items.
    if len(sset) < B:
        sset.extend(not_selected[: B - len(sset)])
    return sset, vals

