
import numpy as np 
import math 
import torch
import heapq

## FL MAX CODE
class FacilityLocation:
    """Log-saturated, normalized facility-location objective for greedy selection.

    For a selected set ``S`` of column indices the tracked value is::

        norm * log(1 + f_norm * sum_i max_{j in S} D[i, j])

    where ``f_norm = alpha / sum_i max_{j in V} D[i, j]`` and
    ``norm = 1 / log(1 + alpha)``. Selecting exactly the columns ``V`` given at
    construction therefore yields a value of 1.

    The object is stateful: ``inc`` measures a candidate's marginal gain against
    the state accumulated by earlier ``add`` calls, so callers must ``add`` the
    elements in the same order in which they append them to ``sset``.
    """

    def __init__(self, D, V, alpha=1.):
        '''
        Args
        - D: np.array, shape [N, N], similarity matrix
        - V: non-empty list of int, indices of columns of D
        - alpha: float; norm divides by log(1 + alpha), so alpha == 0 raises
          ZeroDivisionError
        '''
        self.D = D
        self.curVal = 0  # objective value of the currently selected set
        self.curMax = np.zeros(len(D))  # per-row max similarity to the selection
        self.gains = []  # marginal gain recorded by each add() call
        self.alpha = alpha
        # The scale factor keeps its historical attribute name for backward
        # compatibility. After this line, self.f_norm is this float and shadows
        # the f_norm() method on the instance.
        self.f_norm = self.alpha / self.f_norm(V)
        # inc(V, []) is the normalization path and returns log(1 + alpha).
        self.norm = 1. / self.inc(V, [])

    def f_norm(self, sset):
        """Return ``sum_i max_{j in sset} D[i, j]``.

        Only reachable through the instance during ``__init__``; afterwards
        the float attribute of the same name shadows this method.
        """
        return self.D[:, sset].max(axis=1).sum()

    def inc(self, sset, ndx):
        """Return the marginal gain of adding column ``ndx`` to the selection.

        Args
        - sset: current selection (only its length is inspected)
        - ndx: column index of the candidate

        Special case: ``ndx == []`` together with a non-empty ``sset`` returns
        the normalization constant ``log(1 + alpha)`` (used by ``__init__``).
        """
        if len(sset) > 0:
            # Check the empty-list sentinel by type. ``np.int64(3) == []``
            # produces an empty boolean array whose truth value is deprecated
            # (an error on recent NumPy), which broke numpy-integer indices.
            if isinstance(ndx, list) and not ndx:  # normalization
                return math.log(1 + self.alpha * 1)
            merged = np.maximum(self.curMax, self.D[:, ndx])
            return self.norm * math.log(1 + self.f_norm * merged.sum()) - self.curVal
        # Empty selection: the candidate's column alone is the coverage.
        return self.norm * math.log(1 + self.f_norm * self.D[:, ndx].sum()) - self.curVal

    def add(self, sset, ndx):
        """Commit column ``ndx``, record its gain in ``gains``, return the new value.

        ``sset`` is the selection *before* ``ndx`` is added; it is not modified.
        """
        cur_old = self.curVal
        if len(sset) > 0:
            self.curMax = np.maximum(self.curMax, self.D[:, ndx])
        else:
            self.curMax = self.D[:, ndx]
        self.curVal = self.norm * math.log(1 + self.f_norm * self.curMax.sum())
        self.gains.append(self.curVal - cur_old)
        return self.curVal


def _heappush_max(heap, item):
    heap.append(item)
    heapq._siftdown_max(heap, 0, len(heap)-1)


def _heappop_max(heap):
    """Maxheap version of a heappop."""
    lastelt = heap.pop()  # raises appropriate IndexError if heap is empty
    if heap:
        returnitem = heap[0]
        heap[0] = lastelt
        heapq._siftup_max(heap, 0)
        return returnitem
    return lastelt


def lazy_greedy_heap(F, V, B):
    """Select up to ``B`` elements of ``V`` by lazy greedy over a max-heap of gains.

    Each candidate starts in a max-heap keyed by its gain w.r.t. the empty
    selection. Each step pops the candidate with the largest (possibly stale)
    gain and re-evaluates it with ``F.inc``:

    - a negative gain drops it for good;
    - a gain at least as large as the best stale gain still in the heap
      commits it via ``F.add``;
    - otherwise it is pushed back with its fresh gain.

    Skipping re-evaluation of the other candidates is only exact when gains
    never increase as the selection grows (submodular ``F``).

    Args
    - F: object exposing inc(sset, idx) -> float and add(sset, idx) -> value
    - V: iterable of candidate indices
    - B: int, maximum number of elements to select

    Returns
    - sset: list of selected indices, in selection order
    - vals: list of the values returned by ``F.add`` after each selection
    """
    sset = []
    vals = []

    # Heap entries are (gain, index); equal gains are ordered by index.
    order = []
    for index in V:
        _heappush_max(order, (F.inc(sset, index), index))

    while order and len(sset) < B:
        _, index = _heappop_max(order)
        improv = F.inc(sset, index)

        # Written as `not >=` so that a NaN gain is dropped as well.
        if not improv >= 0:
            continue

        # order[0] holds the best remaining stale gain; peeking is enough.
        if not order or improv >= order[0][0]:
            vals.append(F.add(sset, index))  # F.add sees sset before the append
            sset.append(index)
        else:
            _heappush_max(order, (improv, index))

    return sset, vals


def stochastic_greedy(F, V, B, epsilon=.01):
    """Select up to ``B`` elements of ``V`` with stochastic greedy.

    Each step draws a uniform random sample, without replacement, of the
    remaining candidates. It scores each sampled candidate with
    ``F.inc(sset, idx)`` and commits the best one via ``F.add``. The nominal
    sample size is ``int((n / B) * log(1 / epsilon))``. It is clamped to
    ``[1, number of remaining candidates]`` at every step so that
    ``np.random.choice(..., replace=False)`` is always valid.

    Args
    - F: object exposing inc(sset, idx) -> float and add(sset, idx)
    - V: iterable of candidate indices (duplicates are collapsed)
    - B: int, maximum number of elements to select
    - epsilon: float, controls the nominal sample size

    Returns
    - list of selected indices (elements of ``V``) in selection order; shorter
      than ``B`` when ``V`` has fewer than ``B`` distinct elements, empty when
      ``B <= 0``.
    """
    remaining = set(V)
    if B <= 0 or not remaining:
        return []
    n = len(remaining)
    # The nominal size can exceed the pool (B < log(1/epsilon)) or round down
    # to 0 (B > n); it is clamped per step below.
    r_size = int((n / B) * math.log(1 / epsilon))

    sset = []
    while len(sset) < B and remaining:
        pool = list(remaining)
        size = min(max(r_size, 1), len(pool))
        # Index into the Python list so sampled candidates keep V's element
        # type instead of becoming numpy scalars.
        sample = [pool[i] for i in np.random.choice(len(pool), size, replace=False)]

        best_idx, best_gain = None, -math.inf
        for idx in sample:
            g = F.inc(sset, idx)
            if best_idx is None or g > best_gain:
                best_idx, best_gain = idx, g

        F.add(sset, best_idx)
        remaining.remove(best_idx)
        sset.append(best_idx)

    return sset


def FL_evaluation(sim, sets, norm=1., device=None):
    """Return the facility-location score of columns ``sets`` of ``sim``, divided by ``norm``.

    Computes ``sum_i max_{j in sets} sim[i, j] / norm`` as a 0-d tensor.
    ``device`` is accepted for call compatibility but is not used.
    """
    best_per_row = sim[:, sets].amax(dim=1)
    return best_per_row.sum() / norm