from collections import defaultdict
import numpy as np

try:
    GUROBI = True
    import gurobipy as gp
    from gurobipy import GRB
except ImportError:
    GUROBI = False

import pandas as pd
import sys
import math
import random

def RandomSetSystem(nbElements, nbSets, threshold, random_seed=0):
    """Draw a random set system.

    Each of the ``nbSets`` sets receives ``int(threshold * nbElements)``
    distinct elements, drawn without replacement from ``0..nbElements-1`` by a
    NumPy generator seeded with ``random_seed``.

    Returns:
        A tuple ``(elem, newSets, ptsInSets, nbSets)``:
        - elem: the elements, in increasing order, that lie in at least one set.
        - newSets: dict mapping each element of ``elem`` to the list of sets
          containing it.
        - ptsInSets: ``defaultdict(list)`` mapping each set to the elements
          drawn for it (NumPy integers, as returned by ``rng.choice``).
        - nbSets: the number of sets requested.
    """
    generator = np.random.default_rng(random_seed)
    universe = list(range(nbElements))
    sample_size = int(threshold * len(universe))  # identical for every set

    ptsInSets = defaultdict(list)  # set -> elements it contains
    setsOf = defaultdict(list)     # element -> sets containing it

    for set_id in range(nbSets):
        members = generator.choice(universe, sample_size, replace=False)
        for member in members:
            setsOf[member].append(set_id)
            ptsInSets[set_id].append(member)

    # Elements that never got drawn are left out of the instance.
    used = [e for e in universe if setsOf.get(e)]
    return used, {e: setsOf[e] for e in used}, ptsInSets, nbSets

def makeSetSystemFromHS(file=None):
    """Build a set-cover instance from a hitting-set style input file.

    Roles are swapped with respect to the file: the i-th line after the header
    (i = 0, 1, ...) becomes element ``i`` of the set-cover instance, and the
    integers on that line are the ids of the sets containing it. The header's
    third and fourth whitespace-separated fields are read as ``p_HS`` (number
    of sets) and ``S_HS`` (number of elements); at most ``S_HS`` lines are read.

    Args:
        file: path of the input file. Defaults to ``sys.argv[1]`` (the original
            command-line behaviour).

    Returns:
        ``(elements, Sets, ptsInSets, p_HS)`` where ``elements`` lists the
        element ids (increasing) that belong to at least one set, ``Sets`` maps
        element -> list of sets, ``ptsInSets`` maps set -> list of elements,
        and ``p_HS`` is the set count announced in the header. A warning is
        printed when the counts found differ from the header.
    """
    if file is None:
        file = sys.argv[1]
    with open(file, 'r') as f:
        content = f.readlines()

    Sets = defaultdict(list)  # indexed by elements
    ptsInSets = defaultdict(list)  # indexed by sets

    # split() (not split(" ")) tolerates repeated or trailing whitespace.
    p_HS, S_HS = [int(x) for x in content[0].split()[2:]]
    for element, line in zip(range(S_HS), content[1:]):
        sets_containing_element = []
        for token in line.split():
            try:
                sets_containing_element.append(int(token))
            except ValueError:
                pass  # non-integer tokens are ignored

        Sets[element] = sets_containing_element
        for s in sets_containing_element:
            ptsInSets[s].append(element)

    # clean up: drop elements / sets without any incidence
    Sets = {e: sets for e, sets in Sets.items() if sets}
    ptsInSets = {s: pts for s, pts in ptsInSets.items() if pts}

    if len(Sets) != S_HS or len(ptsInSets) != p_HS:
        print('something does not add up')

    # Only return elements that survived the clean-up: callers index Sets[e]
    # for every returned element, which would raise KeyError otherwise.
    return [e for e in range(S_HS) if e in Sets], Sets, ptsInSets, p_HS

def SetCoverGreedy(elements, Sets, ptsInSets):
    """Greedy set cover of ``elements`` using buckets of marginal gains.

    Repeatedly picks a set covering the largest number of still-uncovered
    requested elements. Sets are kept in buckets keyed by their current number
    of uncovered elements, so picks and updates are cheap.

    Args:
        elements: elements to cover (may be a subset of the universe).
        Sets: maps element -> list of sets containing it.
        ptsInSets: maps set -> list of elements it contains.

    Returns:
        List of chosen sets in the order they were picked (non-increasing
        marginal gain). Empty when there is nothing to cover.

    Raises:
        ValueError: if some requested element is not contained in any set of
            ``ptsInSets``.
    """
    elem = set(elements)
    # Uncovered requested elements of each set.
    pts = {S: {e for e in ptsInSets[S] if e in elem} for S in ptsInSets}
    collection = []
    efficiency = defaultdict(set)  # number of uncovered elements -> sets

    for S, uncovered in pts.items():
        efficiency[len(uncovered)].add(S)

    # default=0 handles an empty ptsInSets (max() of nothing would raise).
    for most_eff in range(max(efficiency, default=0), 0, -1):
        while efficiency[most_eff]:
            S = efficiency[most_eff].pop()  # a most efficient set...
            collection.append(S)            # ...is bought
            for e in pts[S]:
                for other_S in Sets[e]:
                    # Only demote a set when it really loses e; otherwise its
                    # bucket would no longer match its size (e.g. when Sets[e]
                    # lists a set twice).
                    if other_S != S and e in pts[other_S]:
                        sz = len(pts[other_S])
                        pts[other_S].remove(e)
                        efficiency[sz].discard(other_S)
                        efficiency[sz - 1].add(other_S)

    # check that the collection covers all requested elements
    remaining = elem - {e for S in collection for e in ptsInSets[S]}
    if remaining:
        raise ValueError(
            f"greedy set cover left {len(remaining)} element(s) uncovered")

    return collection


# Greedy baseline used to build a prediction.
def decompositionFromGreedy(elements, Sets, ptsInSets):
    """Return the greedy cover of ``elements`` as a one-layer decomposition.

    The only layer is ``(number of sets, sets)``, with the sets listed in the
    order SetCoverGreedy picked them.
    """
    chosen = SetCoverGreedy(elements, Sets, ptsInSets)
    only_layer = (len(chosen), chosen)
    return [only_layer]

def getSelected(x,y):
    """Read a partial-cover solution off solved variables.

    Args:
        x: mapping set -> variable with a solution value ``.X``.
        y: mapping element -> variable with a solution value ``.X``.

    Returns:
        ``(selectedSets, covered, uncovered)``: keys of ``x`` whose value
        exceeds 0.5, keys of ``y`` whose value exceeds 0.5, and the remaining
        keys of ``y``; each list keeps the mapping's iteration order.
    """
    chosen = [key for key, var in x.items() if var.X > 0.5]
    covered, uncovered = [], []
    for key, var in y.items():
        (covered if var.X > 0.5 else uncovered).append(key)
    return chosen, covered, uncovered

def partialSetCover(elements, Sets, nb_to_cover):
    """Solve a partial set cover MIP with Gurobi.

    Chooses sets (cost 1 each) so that at least ``nb_to_cover`` of
    ``elements`` are marked covered; an element can only be marked covered if
    a chosen set contains it.

    Args:
        elements: elements that may be covered.
        Sets: maps element -> list of sets containing it. Every set appearing
            in any ``Sets[e]`` gets a binary decision variable.
        nb_to_cover: minimum number of elements to mark as covered.

    Returns:
        ``(number of chosen sets, chosen sets, elements marked covered,
        elements not marked covered)``.

    The solver runs with a 200 s time limit and a 5% relative MIP gap, so the
    returned solution is not necessarily optimal.
    """
    # Small reward per covered element so the MIP marks an element as covered
    # when a chosen set already covers it "for free". The total reward must
    # stay below the cost of one set (1); with the former fixed 1e-4, instances
    # with about 10^4 or more elements could make buying an extra set look
    # profitable. For up to 9999 elements the coefficient is unchanged.
    cover_reward = min(1e-4, 1 / (len(elements) + 1))

    with gp.Env(empty=True) as env:
        env.setParam('OutputFlag', 0)
        env.start()
        with gp.Model(env=env) as m:
            m.Params.LogToConsole = 0
            m.setParam("Heuristics", 1)
            m.params.Timelimit = 200
            m.params.MIPGap = 0.05

            m.ModelSense = GRB.MINIMIZE
            Sets_names = list(set([s for e in Sets.keys() for s in Sets[e]]))
            x = m.addVars(Sets_names, vtype = GRB.BINARY, obj = 1)
            y = m.addVars(elements, vtype = GRB.BINARY, obj = -cover_reward)
            for e in elements:
                # e may only count as covered if some chosen set contains it
                m.addConstr(gp.quicksum(x[s] for s in Sets[e]) >= y[e])
            m.addConstr(gp.quicksum(y[e] for e in elements) >= nb_to_cover)
            m.optimize()
            selectedSets, covered, uncovered = getSelected(x,y)

            return len(selectedSets), selectedSets, covered, uncovered
        

def decompositionFromIP(elements, Sets, ptsInSets): #computes a simplified (1,1)-decomposition
    """Build a layered decomposition of ``elements`` with partial-cover MIPs.

    The first layer is a partial cover of at least half of ``elements``.
    Each later layer covers at least half of what is still uncovered; a binary
    search picks the smallest coverage target whose partial-cover cost is at
    least twice the previous layer's cost (or all remaining elements if no
    target reaches that cost).

    Returns:
        List of layers ``(cost, list_of_sets)``, in the order they were built.
        Every layer holds a *list* of sets, as onlineSC iterates over it.
    """
    if len(elements) == 1:
        # A single element is covered by any one of its sets.
        return [(1, [Sets[elements[0]][0]])]

    decomposition = []

    val, selectedSets, _, uncovered = partialSetCover(elements, Sets, (len(elements)+1)//2)
    decomposition.append((val, selectedSets))
    curr_elements = uncovered
    last_cost = val

    while len(curr_elements) > 0:
        if len(curr_elements) == 1:
            # Cover the one remaining element (not elements[0]) with one of its sets.
            decomposition.append((1, [Sets[curr_elements[0]][0]]))
            break
        lower = (len(curr_elements)+1)//2 #we need to cover at least half the remaining elements
        upper = len(curr_elements)
        while lower < upper: #binary search to find the smallest subset of remaining elements with cost at least twice preceding decomp.
            middle = (lower + upper)//2
            val, _, _, _ = partialSetCover(curr_elements, Sets, middle)
            if val < 2*last_cost:
                lower = middle+1
            else:
                upper = middle

        val, selectedSets, _, uncovered = partialSetCover(curr_elements, Sets, lower)
        decomposition.append((val, selectedSets))
        last_cost = val
        curr_elements = uncovered
    return decomposition


def getSelectedSets(x):
    """Return the keys of ``x`` whose solution value ``.X`` exceeds 0.2.

    ``x`` maps set -> solved variable; the keys are returned in the
    mapping's iteration order.
    """
    return [key for key, var in x.items() if var.X > 0.2]

def SetCoverIP(elements, Sets, ptsInSets):
    """Solve set cover of ``elements`` as a binary program with Gurobi.

    Every set appearing in some ``Sets[e]`` gets a binary variable of cost 1;
    each element of ``elements`` must be contained in a chosen set. Gurobi runs
    with a 200 s time limit and a 1% relative MIP gap, so the cover returned
    may be slightly larger than optimal.

    Returns:
        The chosen sets, as read by getSelectedSets.
    """
    with gp.Env(empty=True) as solver_env:
        solver_env.setParam('OutputFlag', 0)
        solver_env.start()
        with gp.Model(env=solver_env) as model:
            model.Params.LogToConsole = 0
            model.setParam("Heuristics", 1)
            model.params.Timelimit = 200
            model.params.MIPGap = 0.01
            model.ModelSense = GRB.MINIMIZE

            candidate_sets = list(set([s for e in Sets.keys() for s in Sets[e]]))
            pick = model.addVars(candidate_sets, vtype = GRB.BINARY, obj = 1)
            for element in elements:
                covering = gp.quicksum(pick[s] for s in Sets[element])
                model.addConstr(covering >= 1)

            model.optimize()
            return getSelectedSets(pick)
        

def onlineSC(elements, Sets, ptsInSets, random_seed=0, prediction=None):
    """Run the randomized online set cover algorithm, optionally with a prediction.

    Elements arrive in the order of ``elements``. Every set keeps a weight
    (initially 1/#sets) and a random threshold (the minimum of
    ``int(max(1, ln(len(elements))))`` uniform draws). When an uncovered
    element arrives, the weights of the sets containing it are multiplied by
    the smallest power of two (never less than 1) that lifts their total to at
    least 1. Then one set containing the element is bought:
    - with a prediction, the eligible (weight >= threshold) predicted set
      appearing earliest in the prediction;
    - otherwise a uniformly random eligible set;
    - if none is eligible, the first set listed in ``Sets[e]``.
    With a prediction, every purchase is also paired with buying the next
    not-yet-consumed set of the prediction (in layer order).

    Args:
        elements: arrival order of the elements to cover.
        Sets: maps element -> list of sets containing it.
        ptsInSets: maps set -> list of elements it contains.
        random_seed: seed of the NumPy generator.
        prediction: optional list of layers ``(cost, list_of_sets)`` as built
            by decompositionFromGreedy / decompositionFromIP. ``None`` or an
            empty list means no prediction.

    Returns:
        Number of distinct sets bought (0 when ``elements`` is empty).
    """
    if prediction is None:
        prediction = []
    if len(elements) == 0:
        # Nothing arrives, so nothing is bought (also avoids math.log(0)).
        return 0

    rng = np.random.default_rng(random_seed)
    x_S = {s: 1 / len(ptsInSets) for s in ptsInSets}
    nb_draws = int(max(1, math.log(len(elements))))
    thresh_S = {s: min(rng.uniform(0, 1) for _ in range(nb_draws)) for s in ptsInSets}
    covered = set()
    selected = set()
    predictedSetsInOrder = [s for (_, layer) in prediction for s in layer]
    next_pred = 0  # index of the next predicted set to buy alongside a purchase

    # Position of each set in the prediction (last occurrence wins).
    orderOfPredictedSets = {S: ind for ind, S in enumerate(predictedSetsInOrder)}

    for e in elements:
        if e in covered:
            continue

        total = sum(x_S[s] for s in Sets[e])  # total weight covering e
        # Smallest power of two lifting the total to at least 1. Clamped at 0:
        # weights must never shrink (a total >= 2 used to halve them).
        exponent = max(0, math.ceil(math.log2(1 / total)))
        for s in Sets[e]:
            x_S[s] *= 2 ** exponent

        set_covering_e = None

        if prediction:
            # Among eligible predicted sets covering e, take the earliest one.
            eligible_predicted = [
                (orderOfPredictedSets[S], S) for S in Sets[e]
                if x_S[S] >= thresh_S[S] and S in orderOfPredictedSets
            ]
            if eligible_predicted:
                set_covering_e = min(eligible_predicted)[1]

        if set_covering_e is None:
            # Any eligible (weight above threshold) set, uniformly at random.
            eligible = [s for s in Sets[e] if x_S[s] >= thresh_S[s]]
            if eligible:
                set_covering_e = rng.choice(eligible, 1)[0]

        if set_covering_e is None:
            # No eligible set: fall back to the first set containing e.
            set_covering_e = Sets[e][0]

        selected.add(set_covering_e)
        covered.add(e)
        covered.update(ptsInSets[set_covering_e])

        # With a prediction, whenever we spend 1 we also spend 1 on the next
        # predicted set (equivalent to Algorithm 1).
        if prediction and next_pred < len(predictedSetsInOrder):
            additional_S = predictedSetsInOrder[next_pred]
            next_pred += 1
            selected.add(additional_S)
            covered.update(ptsInSets[additional_S])

    return len(selected)

def downsample(arr, fraction, random_seed = 0):
    """Draw ``int(fraction * len(arr))`` items of ``arr`` without replacement.

    Sampling uses a NumPy generator seeded with ``random_seed``; the result is
    a Python list of the values returned by ``Generator.choice``.
    """
    sample_size = int(fraction * len(arr))
    generator = np.random.default_rng(random_seed)
    picked = generator.choice(arr, sample_size, replace = False)
    return list(picked)




# Fraction of the elements drawn into each set of a random instance
# (passed as `threshold` to RandomSetSystem).
FRAC = 0.05
# The element list is cut at len(elements)//SPLIT; the first part is used to
# build the predicted decomposition, the rest is "unpredicted".
SPLIT = 2
# Number of random instances drawn for each value in FRACTIONS.
NB_RESAMPLING = 10
# Number of onlineSC runs (different seeds) averaged per instance, both with
# and without the prediction.
ITERATIONS = 10

# Values of `frac` swept by experiment(): an instance keeps a (1 - frac) share
# of the predicted elements and adds a frac share of the unpredicted ones.
FRACTIONS = [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]

def experiment():
    """Run the online set cover experiment configured on the command line.

    Usage: ``<script> <seed | hitting-set file> <IP | Greedy>``.

    - If the first argument is all digits it seeds a random instance
      (1000 elements, 100 sets, each set holding a FRAC share of the
      elements); otherwise it is read with makeSetSystemFromHS and the
      elements are shuffled with seed 1.
    - The first len(elements)//SPLIT elements build the predicted
      decomposition; the rest are "unpredicted".
    - For every frac in FRACTIONS and NB_RESAMPLING resamples, an instance is
      formed from a (1 - frac) share of the predicted elements plus a frac
      share of the others. The offline cover size ('opt') and the average cost
      of onlineSC over ITERATIONS seeds without ('classical') and with ('LA')
      the prediction are recorded.

    Results are printed to stdout as CSV.
    """
    if len(sys.argv) < 3:
        print('usage: <script> <seed | hitting-set file> <IP | Greedy>', file=sys.stderr)
        return 0

    # Validate the method before building a (possibly large) instance.
    method = str(sys.argv[2])
    if method == 'IP':
        if not GUROBI:
            print('Gurobi is not installed, but is required.')
            return
        make_decomp = decompositionFromIP
        make_opt = SetCoverIP
    elif method == 'Greedy':
        make_decomp = decompositionFromGreedy
        make_opt = SetCoverGreedy
    else:
        print(f"unknown method {method!r}; expected 'IP' or 'Greedy'", file=sys.stderr)
        return 0

    if sys.argv[1].isdigit():
        NB_ELEMENTS = 1000
        NB_SETS = 100
        random_seed = int(sys.argv[1])
        NAME = random_seed
        elements, Sets, ptsInSets, nbSets = RandomSetSystem(NB_ELEMENTS, NB_SETS, FRAC, random_seed)
    else:
        random_seed = 1
        elements, Sets, ptsInSets, nbSets = makeSetSystemFromHS()
        random.seed(random_seed)
        random.shuffle(elements)
        NAME = str(sys.argv[1]).split("/")[-1].split(".")[0]
        NB_ELEMENTS = len(elements)
        NB_SETS = len(ptsInSets)

    # The head of the element list feeds the prediction.
    cut = len(elements) // SPLIT
    pred_elements, rem_elements = elements[:cut], elements[cut:]

    decomp = make_decomp(pred_elements, Sets, ptsInSets)

    results = []

    columns = ['NAME', 'METHOD', 'NB_ELEMENTS', 'NB_SETS', 'SEED', 'SPLIT', 'subsample_frac', 'resample', 'opt', 'classical', 'LA']

    for frac in FRACTIONS:
        for resampl in range(NB_RESAMPLING):
            sample_seed = 1000 * random_seed + resampl
            actual_elements = (downsample(pred_elements, 1 - frac, sample_seed)
                               + downsample(rem_elements, frac, sample_seed))
            opt = len(make_opt(actual_elements, Sets, ptsInSets))
            run_seeds = [1000 * sample_seed + it for it in range(ITERATIONS)]
            classical = sum(onlineSC(actual_elements, Sets, ptsInSets, seed) for seed in run_seeds) / ITERATIONS
            LA = sum(onlineSC(actual_elements, Sets, ptsInSets, seed, decomp) for seed in run_seeds) / ITERATIONS

            results.append([NAME, method, NB_ELEMENTS, NB_SETS, random_seed, SPLIT, frac, resampl, opt, classical, LA])

    data = pd.DataFrame(results, columns = columns)
    print(data.to_csv(header=True))


if __name__=="__main__":
    # Script entry point; experiment() reads its configuration from sys.argv
    # (argv[1]: seed or input file, argv[2]: 'IP' or 'Greedy').
    experiment()