# Prints the module docstring whenever this module is loaded. The file does not
# begin with a string literal, so __doc__ is None and this prints "None".
print(__doc__)
import matplotlib
# matplotlib.use('TkAgg')

import heapq
import numpy as np
import pandas as pd
import scipy as sp
import math
from scipy import spatial
import matplotlib.pyplot as plt


class WightedCoverage:
    """Normalized, log-transformed weighted coverage objective.

    For an index list ``A`` the value is::

        norm * log(1 + f_norm * (W[A] @ S[A, :].sum(axis=1)))

    where ``f_norm = alpha / (W[V] @ S[V, :].sum(axis=1))`` and
    ``norm = 1 / log(1 + alpha)``, so evaluating all of ``V`` gives exactly 1.

    Args
    - S: np.array, 2-D; row j is summed to give element j's raw contribution
    - V: list of int, ground set used for normalization
    - W: np.array, per-element weights indexed like the rows of S
      (assumed 1-D — TODO confirm)
    - alpha: float, scale applied inside the log
    """

    def __init__(self, S, V, W, alpha=1.):
        self.S = S
        self.V = V
        self.W = W
        self.curVal = 0
        self.gains = []
        self.alpha = alpha
        # NOTE: this rebinds ``self.f_norm`` from the method to a float scale.
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) is the normalization request; it returns log(1 + alpha).
        self.norm = 1. / self.inc(V, [])

    def f_norm(self, sset):
        """Raw (untransformed) weighted coverage of the index list ``sset``."""
        return np.matmul(self.W[sset], self.S[sset, :].sum(axis=1))

    @staticmethod
    def _is_norm_request(ndx):
        """True only for the empty-list sentinel that ``__init__`` passes.

        A bare ``not ndx`` also matches the valid index 0 (or ``np.int64(0)``).
        That replaced element 0's marginal gain with a constant.
        """
        return isinstance(ndx, (list, tuple)) and len(ndx) == 0

    def _value(self, members):
        """Normalized, log-transformed score of the index list ``members``."""
        raw = np.matmul(self.W[members], self.S[members, :].sum(axis=1))
        return self.norm * math.log(1 + self.f_norm * raw)

    def inc(self, sset, ndx):
        """Return the marginal gain of adding ``ndx`` to ``sset``.

        Passing ``ndx == []`` returns the normalization constant log(1 + alpha).
        """
        if self._is_norm_request(ndx):  # normalization
            return math.log(1 + self.alpha * 1)
        return self._value(sset + [ndx]) - self.curVal

    def add(self, sset, ndx):
        """Commit ``ndx`` to ``sset``, record the gain, and return the new value."""
        cur_old = self.curVal
        self.curVal = self._value(sset + [ndx])
        self.gains.append(self.curVal - cur_old)
        return self.curVal


class FacilityLocation1:
    """Unnormalized facility-location objective (earlier variant).

    ``inc``/``add`` score ``A = sset + [ndx]`` in one of two ways:
    - one element: ``D[:, A].sum()``
    - more than one: ``D[:, A].max(axis=1).sum() - gamma * D[:, ndx].sum()``

    Args
    - D: np.array, 2-D; presumably an [N, N] similarity matrix as in
      FacilityLocation (TODO confirm)
    - V: list of int, stored but not used by inc/add
    - gamma: float, weight of the penalty on the newest element's column sum.
      The default 0 disables it. Previously ``self.gamma`` was never set, so
      ``add`` raised AttributeError on the second element.
    """

    def __init__(self, D, V, gamma=0.):
        self.D = D
        self.V = V
        self.curVal = 0
        self.gains = []
        self.gamma = gamma

    def _value(self, sset, ndx):
        """Score of ``sset + [ndx]``, shared by inc and add so they agree."""
        members = sset + [ndx]
        if len(members) > 1:
            subset_sim = self.D[:, ndx].sum()
            return self.D[:, members].max(axis=1).sum() - self.gamma * subset_sim
        return self.D[:, members].sum()

    def inc(self, sset, ndx):
        """Return the gain of adding ``ndx`` relative to the committed value."""
        return self._value(sset, ndx) - self.curVal

    def add(self, sset, ndx):
        """Commit ``ndx`` to ``sset``, record the gain, and return the new value."""
        cur_old = self.curVal
        self.curVal = self._value(sset, ndx)
        self.gains.append(self.curVal - cur_old)
        return self.curVal


class FacilityLocation:
    """Normalized facility-location objective with a diversity penalty.

    When ``add`` is called with a growing ``sset``, ``curVal`` for the
    selected index list ``A`` equals::

        norm * log(1 + f_norm * (D[:, A].max(axis=1).sum()
                                 - gamma * D[A][:, A].sum()))

    Here ``f_norm = alpha / D[:, V].max(axis=1).sum()`` and
    ``norm = 1 / log(1 + alpha)``. ``gamma`` is the constructor argument
    divided by ``len(D)``. ``cache_sset_sum`` holds ``D[c][:, c].sum()`` for
    ``c = cache_sset``, so the penalty can be updated incrementally.

    NOTE(review): for a single element, ``inc`` omits the penalty, while
    ``add`` subtracts ``gamma * D[ndx, ndx]``. This is harmless when gamma == 0;
    confirm whether it is intended.
    """

    def __init__(self, D, V, alpha=1., gamma=0.1):
        '''
        Args
        - D: np.array, shape [N, N], similarity matrix
        - V: list of int, indices of columns of D
        - alpha: float, scale applied inside the log
        - gamma: float, diversity-penalty weight (divided by N internally)
        '''
        self.D = D
        self.curVal = 0
        self.curMax = np.zeros(len(D))
        self.gains = []
        self.alpha = alpha
        # NOTE: this rebinds ``self.f_norm`` from the method to a float scale.
        self.f_norm = self.alpha / self.f_norm(V)
        self.cache_sset = []
        self.cache_sset_sum = 0
        # inc(V, []) is the normalization request; it returns log(1 + alpha)
        # before self.gamma is needed.
        self.norm = 1. / self.inc(V, [])
        self.gamma = gamma / len(self.D)  # encouraging diversity

    def f_norm(self, sset):
        """Raw facility-location coverage ``D[:, sset].max(axis=1).sum()``."""
        return self.D[:, sset].max(axis=1).sum()

    @staticmethod
    def _is_norm_request(ndx):
        """True only for the empty-list sentinel that ``__init__`` passes.

        A bare ``not ndx`` also matches the valid index 0 (or ``np.int64(0)``).
        That returned log(1 + alpha) as element 0's marginal gain.
        """
        return isinstance(ndx, (list, tuple)) and len(ndx) == 0

    def inc(self, sset, ndx):
        """Return the marginal gain of adding ``ndx`` to ``sset``.

        With ``ndx == []`` and a non-empty ``sset`` it returns the normalization
        constant log(1 + alpha). The multi-element branch uses ``self.curMax``,
        so it assumes ``sset`` is what has been committed through ``add``.
        """
        if len(sset + [ndx]) > 1:
            if self._is_norm_request(ndx):  # normalization
                return math.log(1 + self.alpha * 1)

            # Cache function: refresh sum(D[sset][:, sset]) when sset changed.
            sorted_sset = np.sort(sset)
            if not np.array_equal(sorted_sset, self.cache_sset):
                self.cache_sset_sum = self.D[sset][:, sset].sum()
                self.cache_sset = sorted_sset

            # sum(D[A][:, A]) for A = sset + [ndx] is the cached block plus
            # ndx's row and column, minus the diagonal entry counted twice.
            diversity_sum = (self.cache_sset_sum +
                             self.D[ndx, sset + [ndx]].sum() +
                             self.D[sset + [ndx], ndx].sum() -
                             self.D[ndx, ndx])

            return self.norm * math.log(1 + self.f_norm * (
                    np.maximum(self.curMax, self.D[:, ndx]).sum() -
                    self.gamma * diversity_sum
                    )) - self.curVal
        else:
            return self.norm * math.log(1 + self.f_norm * self.D[:, ndx].sum()) - self.curVal

    def add(self, sset, ndx):
        """Commit ``ndx`` to ``sset``, update the caches, and return the new value."""
        sorted_sset = np.sort(sset)
        if not np.array_equal(sorted_sset, self.cache_sset):
            # Stale cache (not expected in the greedy loop): recompute fully.
            diversity_sum = self.D[sset + [ndx]][:, sset + [ndx]].sum()
        else:
            diversity_sum = (self.cache_sset_sum +
                             self.D[ndx, sset + [ndx]].sum() +
                             self.D[sset + [ndx], ndx].sum() -
                             self.D[ndx, ndx])

        self.cache_sset_sum = diversity_sum
        self.cache_sset = np.sort(sset + [ndx])

        cur_old = self.curVal
        if len(sset + [ndx]) > 1:
            self.curMax = np.maximum(self.curMax, self.D[:, ndx])
        else:
            self.curMax = self.D[:, ndx]
        self.curVal = self.norm * math.log(1 + self.f_norm * (
            self.curMax.sum() - self.gamma * diversity_sum))
        self.gains.append(self.curVal - cur_old)
        return self.curVal


class NormalizedSoftFacilityLocation:
    """Normalized log-sum-exp ("soft") facility-location objective.

    For an index list ``A``::

        soft(A)  = (1 / alpha) * log(1 + exp(alpha * D[:, A]).sum())
        value(A) = norm * log(1 + f_norm * soft(A))

    with ``f_norm = alpha / soft(V)`` and ``norm = 1 / log(1 + alpha)``.
    The sum inside ``soft`` runs over every entry of the selected columns.
    """

    def __init__(self, D, V, alpha=1.):
        self.D = D
        self.V = V
        self.curVal = 0
        self.gains = []
        self.alpha = alpha
        # NOTE: this rebinds ``self.f_norm`` from the method to a float scale.
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) is the normalization request; it returns log(1 + alpha).
        self.norm = 1. / self.inc(V, [])

    def f_norm(self, sset):
        """Unnormalized soft score ``soft(sset)``."""
        sum_exp = np.exp(self.alpha * self.D[:, sset]).sum()
        return 1. / self.alpha * math.log(1 + sum_exp)

    @staticmethod
    def _is_norm_request(ndx):
        """True only for the empty-list sentinel (a bare ``not ndx`` also hits index 0)."""
        return isinstance(ndx, (list, tuple)) and len(ndx) == 0

    def _value(self, members):
        """Normalized, log-transformed soft score of the index list ``members``."""
        sum_exp = np.exp(self.alpha * self.D[:, members]).sum()
        soft_fl = 1. / self.alpha * math.log(1 + sum_exp)
        return self.norm * math.log(1 + self.f_norm * soft_fl)

    def inc(self, sset, ndx):
        """Return the marginal gain of adding ``ndx`` to ``sset``.

        The candidate ``ndx`` is always part of the evaluated set. Previously,
        a non-empty ``sset`` was scored without it, so every candidate looked
        identical. ``ndx == []`` with a non-empty ``sset`` returns the
        normalization constant log(1 + alpha).
        """
        if self._is_norm_request(ndx):
            if len(sset) > 0:  # normalization
                return math.log(1 + self.alpha * 1)
            return self._value(sset) - self.curVal
        return self._value(sset + [ndx]) - self.curVal

    def add(self, sset, ndx):
        """Commit ``ndx`` (or just ``sset`` when ``ndx == []``) and return the new value."""
        cur_old = self.curVal
        members = sset if self._is_norm_request(ndx) else sset + [ndx]
        self.curVal = self._value(members)
        self.gains.append(self.curVal - cur_old)
        return self.curVal


class SoftFacilityLocation:
    """Unnormalized log-sum-exp ("soft") facility-location objective.

    For an index list ``A``::

        value(A) = (1 / alpha) * log(1 + exp(alpha * D[:, A]).sum())

    ``f_norm`` and ``norm`` are computed in ``__init__`` but are not used by
    ``inc``/``add``.
    """

    def __init__(self, D, V, alpha=1.):
        self.D = D
        self.V = V
        self.curVal = 0
        self.gains = []
        self.alpha = alpha
        # NOTE: this rebinds ``self.f_norm`` from the method to a float scale.
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) evaluates value(V) (curVal is still 0 here).
        self.norm = 1. / self.inc(V, [])

    def f_norm(self, sset):
        """Hard facility-location coverage ``D[:, sset].max(axis=1).sum()``."""
        return self.D[:, sset].max(axis=1).sum()

    @staticmethod
    def _is_norm_request(ndx):
        """True only for the empty-list sentinel (a bare ``not ndx`` also hits index 0)."""
        return isinstance(ndx, (list, tuple)) and len(ndx) == 0

    def _value(self, members):
        """Soft score of the index list ``members``."""
        sum_exp = np.exp(self.alpha * self.D[:, members]).sum()
        return 1. / self.alpha * math.log(1 + sum_exp)

    def inc(self, sset, ndx):
        """Return value(sset + [ndx]) - curVal, or value(sset) - curVal when ``ndx == []``.

        Previously index 0 was treated like ``[]`` and silently left out.
        """
        members = sset if self._is_norm_request(ndx) else sset + [ndx]
        return self._value(members) - self.curVal

    def add(self, sset, ndx):
        """Commit ``ndx`` (or just ``sset`` when ``ndx == []``) and return the new value."""
        cur_old = self.curVal
        members = sset if self._is_norm_request(ndx) else sset + [ndx]
        self.curVal = self._value(members)
        self.gains.append(self.curVal - cur_old)
        return self.curVal


class FacilityLocation_L:
    """Normalized facility-location objective plus a per-element term ``L``.

    For ``A = sset + [ndx]``, ``add`` sets::

        curVal = norm * log(1 + f_norm * cov(A)) + L[ndx]

    ``cov(A)`` is ``D[:, A].max(axis=1).sum()`` (or ``D[:, A].sum()`` for a
    single element), ``f_norm = alpha / D[:, V].max(axis=1).sum()`` and
    ``norm = 1 / log(1 + alpha)``.

    NOTE(review): ``curVal`` includes ``L`` of the most recently added element
    only, not the sum over ``A``. Confirm whether ``L`` was meant to accumulate.
    """

    def __init__(self, D, V, L, alpha=1.):
        self.D = D
        self.V = V
        self.curVal = 0
        self.gains = []
        self.alpha = alpha
        self.L = L
        # NOTE: this rebinds ``self.f_norm`` from the method to a float scale.
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) is the normalization request; it returns log(1 + alpha).
        self.norm = 1. / self.inc(V, [])

    def f_norm(self, sset):
        """Raw facility-location coverage ``D[:, sset].max(axis=1).sum()``."""
        return self.D[:, sset].max(axis=1).sum()

    @staticmethod
    def _is_norm_request(ndx):
        """True only for the empty-list sentinel (a bare ``not ndx`` also hits index 0)."""
        return isinstance(ndx, (list, tuple)) and len(ndx) == 0

    def _score(self, sset, ndx):
        """Normalized, log-transformed coverage of ``sset + [ndx]`` (without ``L``)."""
        members = sset + [ndx]
        if len(members) > 1:
            coverage = self.D[:, members].max(axis=1).sum()
        else:
            coverage = self.D[:, members].sum()
        return self.norm * math.log(1 + self.f_norm * coverage)

    def inc(self, sset, ndx):
        """Return the gain of adding ``ndx`` to ``sset``, including ``L[ndx]``."""
        if len(sset + [ndx]) > 1 and self._is_norm_request(ndx):  # normalization
            return math.log(1 + self.alpha * 1)
        return self._score(sset, ndx) - self.curVal + self.L[ndx]

    def add(self, sset, ndx):
        """Commit ``ndx`` to ``sset``, record the gain, and return the new value."""
        cur_old = self.curVal
        self.curVal = self._score(sset, ndx) + self.L[ndx]
        self.gains.append(self.curVal - cur_old)
        return self.curVal


def lazy_greedy(F, ndx, B):
    '''Select up to ``B`` elements of ``ndx`` greedily, using lazy re-evaluation.

    Args
    - F: FacilityLocation, or any object whose ``inc(sset, v)`` returns the
      gain of ``v`` given ``sset`` and whose ``add(sset, v)`` commits ``v`` and
      returns the new objective value
    - ndx: indices of all points
    - B: int, number of points to select

    Returns
    - sset: list of selected indices, in selection order
    - vals: values returned by ``F.add`` after each selection
    '''
    eps = 1e-15
    sset = []
    vals = []
    # Emulate a max-heap with heapq's min-heap by keying on -gain. Each entry
    # is (-gain, index, gain), and the stored gain is an upper bound refreshed
    # lazily. The old key ``1.0 / marginal`` pushed negative gains to the
    # front and raised ZeroDivisionError for a gain of exactly 0.
    order = []
    for v in ndx:
        marginal = F.inc(sset, v) + eps
        heapq.heappush(order, (-marginal, v, marginal))

    while order and len(sset) < B:
        _, cand, bound = heapq.heappop(order)
        # With nothing selected yet, the cached gain is still exact.
        improv = bound if not sset else F.inc(sset, cand) + eps

        if not improv > eps:
            # No positive gain: drop the candidate permanently.
            continue

        if not order:
            vals.append(F.add(sset, cand))
            sset.append(cand)
            continue

        top = heapq.heappop(order)
        if improv >= top[2]:
            # The fresh gain beats every remaining upper bound, so select it.
            vals.append(F.add(sset, cand))
            sset.append(cand)
        else:
            heapq.heappush(order, (-improv, cand, improv))
        heapq.heappush(order, top)

    return sset, vals


def _heappush_max(heap, item):
    heap.append(item)
    heapq._siftdown_max(heap, 0, len(heap) - 1)


def _heappop_max(heap):
    """Maxheap version of a heappop."""
    lastelt = heap.pop()  # raises appropriate IndexError if heap is empty
    if heap:
        returnitem = heap[0]
        heap[0] = lastelt
        heapq._siftup_max(heap, 0)
        return returnitem
    return lastelt


def lazy_greedy_heap(F, V, B):
    """Select up to ``B`` elements of ``V`` with lazy greedy and a max-heap.

    Args
    - F: object with ``inc(sset, v)`` (gain of ``v`` given ``sset``) and
      ``add(sset, v)`` (commit ``v``, return the new objective value)
    - V: iterable of candidate indices
    - B: int, maximum number of elements to select

    Returns
    - (sset, vals): the selected indices in selection order, and the value
      returned by ``F.add`` after each selection
    """
    selected = []
    history = []

    # Max-heap of (upper bound on gain, index). The bounds start out exact.
    heap = []
    for cand in V:
        _heappush_max(heap, (F.inc(selected, cand), cand))

    while heap and len(selected) < B:
        _, cand = _heappop_max(heap)
        gain = F.inc(selected, cand)

        # Candidates whose fresh gain is not >= 0 are discarded for good.
        if not gain >= 0:  # TODO <====
            continue

        if not heap:
            history.append(F.add(selected, cand))
            selected.append(cand)
            continue

        rival = _heappop_max(heap)
        if gain >= rival[0]:
            history.append(F.add(selected, cand))
            selected.append(cand)
        else:
            # Stale bound: re-queue with the refreshed gain.
            _heappush_max(heap, (gain, cand))
        _heappush_max(heap, rival)

    return selected, history


def unconstrained(F, V):
    """Run one double-greedy style pass over ``V`` with no cardinality limit.

    For each ``i`` in ``V`` (in order) it compares:
    - ``a = F.inc(sset, i)``
    - ``b = F.inc([], Y - {i}) - F.inc([], Y)``

    If ``a >= b``, ``i`` is kept in ``sset``; otherwise ``i`` is removed from
    the working set ``Y``. This appears to follow the deterministic double-greedy
    scheme (TODO confirm).

    Args
    - F: object exposing ``inc(sset, ndx)``, where ``ndx`` may be an int or a list
    - V: list of candidate indices; not modified

    Returns
    - sset: list of kept indices, in the order they appear in ``V``
    """
    sset = []
    # Work on a copy. Aliasing ``Y = V`` and removing from it while iterating
    # over ``V`` skipped the element after every removal and also mutated the
    # caller's list.
    Y = list(V)
    for i in V:
        a = F.inc(sset, i)
        b = F.inc([], list(set(Y) - {i})) - F.inc([], Y)
        if a >= b:
            sset.append(i)
        else:
            Y.remove(i)
    # NOTE(review): F.add is never called, so F.curVal (and any other state F
    # keeps) does not follow ``sset``. ``a`` is a true marginal gain only if
    # F.inc depends on its ``sset`` argument alone; confirm this is intended.
    return sset


def test():
    """Demo: greedy selection followed by an unconstrained pass.

    It builds a random 100x100 similarity matrix. For each diversity weight
    gamma in (0, 10, 100) it prints:
    - the indices chosen by ``lazy_greedy_heap`` (budget 80) and their count
    - the indices kept by ``unconstrained`` applied to that selection, and
      their count
    """
    n, B = 100, 80
    np.random.seed(10)
    X = np.random.rand(n, n)
    # Elementwise product with the transpose gives a symmetric matrix.
    D = X * X.T

    for gamma in (0, 10, 100):
        F = FacilityLocation(D, list(range(n)), alpha=1, gamma=gamma)
        sset, _ = lazy_greedy_heap(F, list(np.arange(n)), B)
        print(sset)
        print(len(sset))
        # NOTE(review): curMax / cache_sset / cache_sset_sum are not reset
        # here, so F still reflects the greedy run. Confirm this is intended.
        F.sset, F.curVal, F.gains = [], 0, []
        kept = unconstrained(F, sset)
        print(kept)
        print(len(kept))


if __name__ == '__main__':
    # Run the demo only when this file is executed as a script.
    test()
