import pandas as pd
import numpy as np
import time
import math
import pulp as p
import cplex
import networkx as nx
import matplotlib.pyplot as plt
import time

from scipy.spatial import distance_matrix

def read_reuters():
    """Load the C50 dataset from ``data/c50.csv``.

    Returns:
        (features, groups, colors, name) where
        - features: numpy array of every CSV column except the first one
          (NOTE(review): if 'color' is not the first column it is included in
          the features — confirm against the CSV layout);
        - groups: dict mapping each color value 0..max(color) to the list of
          row-index labels having that color (empty list for unused values);
        - colors: the 'color' column (colors[i] is the color of point i);
        - name: the dataset tag 'c50'.
    """
    frame = pd.read_csv('data/c50.csv')
    colors = frame['color']
    features = np.array(frame.iloc[:, 1:])
    groups = {
        color: frame.loc[colors == color].index.tolist()
        for color in range(colors.max() + 1)
    }
    return features, groups, colors, 'c50'

def greedy_k_center(C, F, k, groups, alpha, beta, delta, epsilon):
    """Unconstrained greedy k-center baseline, evaluated for fairness.

    Centres are picked from F by ``greedy_helper`` and every point of C is
    assigned to its nearest centre, ignoring fairness.

    Args:
        C: array of points to cluster.
        F: array of candidate centres (rows indexed by greedy_helper).
        k: number of centres.
        groups: dict group index -> list of point indices.
        alpha, beta: per-group upper/lower proportion bounds; only used when
            delta is None.
        delta: if not None, alpha/beta are recomputed via calculate_alpha_beta.
        epsilon: unused here (kept for a signature shared with the fair variant).

    Returns:
        (violations, time_taken, cost): the maximum additive fairness
        violation, the seconds spent choosing centres and assigning points, and
        the k-center cost (largest point-to-nearest-centre distance).

    Raises:
        Exception: if delta is None and alpha or beta is None.
    """
    if delta is not None:
        alpha, beta = calculate_alpha_beta(C, F, k, groups, delta)
    elif alpha is None or beta is None:
        raise Exception("alpha, beta, and delta cannot be all None")

    start = time.time()

    centers = F[greedy_helper(F, k)]
    dist = distance_matrix(C, centers)

    clusters = {c: [] for c in range(k)}
    radius = 0
    for idx, row in enumerate(dist):
        nearest = np.argmin(row)
        clusters[nearest].append(idx)
        radius = max(radius, row[nearest])

    elapsed = time.time() - start

    # membership[point] -> list of groups containing that point
    membership = {idx: [] for idx in range(len(C))}
    for g in range(len(groups.keys())):
        for member in groups[g]:
            membership[member].append(g)

    violations = max_add_violation(C, centers, groups, clusters, alpha, beta, membership)

    return violations, elapsed, radius



def greedy_helper(X, k):
    """Choose centre indices from X by farthest-first traversal.

    X: Numpy array which is Nxd. N points, each with d features.
    k: Number of centers to return. At least one index is always returned,
       even when k <= 0.

    Returns: *Index* list of the chosen centres. The first index is drawn with
    ``np.random.randint``; each subsequent one is the point whose distance to
    its nearest already-chosen centre is largest (lowest index on ties). When
    k > N, already-chosen points are repeated (index 0 once all distances are 0).
    """
    length_of_X = len(X)
    first = np.random.randint(0, length_of_X)  # pick the first centre at random
    k_centers = [first]
    # nearest[i]: distance from X[i] to its closest centre chosen so far.
    # Updated incrementally so each step costs O(N*d) instead of recomputing
    # the distances to every chosen centre.
    nearest = distance_matrix(X, X[[first]])[:, 0]
    while len(k_centers) < k:
        best_center = int(np.argmax(nearest))  # argmax returns the first maximum
        k_centers.append(best_center)
        nearest = np.minimum(nearest, distance_matrix(X, X[[best_center]])[:, 0])

    return k_centers

def calculate_alpha_beta(C, F, k, groups, delta):
    """Derive per-group fairness bounds from the groups' share of C.

    For group g with share r_g = len(groups[g]) / len(C):
        alpha[g] = r_g / (1 - delta)   (upper bound)
        beta[g]  = r_g * (1 - delta)   (lower bound)
    Groups are read as keys 0..len(groups)-1. F and k are unused.

    Returns:
        (alpha, beta): two float numpy arrays of length len(groups).
    """
    total = len(C)
    bounds = [
        (share / (1 - delta), share * (1 - delta))
        for share in (len(groups[g]) / total for g in range(len(groups.keys())))
    ]
    alpha = np.array([upper for upper, _ in bounds], dtype=float)
    beta = np.array([lower for _, lower in bounds], dtype=float)
    return (alpha, beta)

def max_add_violation(C, S, groups, clusters, alpha, beta, points):
    """Return the largest additive fairness violation over all clusters and groups.

    For cluster j of size s and group g with count c (points of the cluster
    whose ``points[...]`` list contains g), the bound is
    beta[g]*s <= c <= alpha[g]*s; a violation is ceil(c - alpha[g]*s) above
    the upper bound or ceil(beta[g]*s - c) below the lower bound. Clusters are
    read as keys 0..len(clusters)-1, groups as 0..len(groups)-1. C and S are
    unused. Returns 0 when no bound is violated.
    """
    worst = 0
    for cluster_id in range(len(clusters)):
        members = clusters[cluster_id]
        size = len(members)
        for group in range(len(groups)):
            count = sum(1 for member in members if group in points[member])
            upper = alpha[group] * size
            lower = beta[group] * size
            if count > upper:
                worst = max(worst, math.ceil(count - upper))
            elif count < lower:
                worst = max(worst, math.ceil(lower - count))
    return worst

def improved_fair_k_center(C, F, k, groups, alpha, beta, delta, epsilon):
    """Fair k-center: greedy centres, then binary search on the assignment radius.

    Centres are chosen once with ``greedy_helper``. The radius ``lamb`` is then
    bisected between 0 and twice the largest point-to-centre distance. Radii
    below the covering radius (some point farther than lamb from every centre)
    are rejected without solving. Every other radius is tested with ``lp``,
    which solves the fair-assignment LP and rounds it. The search ends once
    the bracket is no wider than ``epsilon`` and the last tried radius was
    feasible. It also ends if floating-point bisection can no longer shrink
    the bracket; the latest feasible solution is then used.

    Args:
        C: array of points; F: array of candidate centres; k: number of centres.
        groups: dict group index -> list of point indices.
        alpha, beta: per-group upper/lower proportion bounds (used when delta is None).
        delta: if not None, alpha/beta are recomputed via calculate_alpha_beta.
        epsilon: bisection tolerance on the radius.

    Returns:
        (violations, time_takes, cost): maximum additive fairness violation of
        the final assignment, seconds spent on centre selection and the search,
        and the largest distance from a centre to a point assigned to it.

    Raises:
        Exception: if delta is None and alpha or beta is None.
        RuntimeError: if no tried radius up to the initial upper bound is
            feasible (previously this looped forever).
    """
    if delta is None:
        if alpha is None or beta is None:
            raise Exception("alpha, beta, and delta cannot be all None")
    else:
        alpha, beta = calculate_alpha_beta(C, F, k, groups, delta)

    t = time.time()

    S_index = greedy_helper(F, k)  # indices of the greedy centres in F
    S = F[S_index]
    d = distance_matrix(C, S)
    l, r = 0, 2 * np.max(d)
    # Largest distance from a point to its nearest centre; hoisted out of the loop.
    covering_radius = d.min(axis=1).max()
    feasible = False
    best = None  # (clusters, points) of the most recent feasible solve, i.e. at radius r

    while (r - l > epsilon) or not feasible:
        print("the lower is %f"%l, ",the upper is %f"%r)
        lamb = (l + r) / 2

        if covering_radius > lamb:
            ok = False
        else:
            LP, status, clusters, points = lp(C, S, k, groups, alpha, beta, lamb)
            ok = p.LpStatus[status] == 'Optimal'

        if ok:
            r, feasible = lamb, True
            best = (clusters, points)
            continue

        if lamb <= l or lamb >= r:
            # Floating-point bisection cannot shrink [l, r] any further, so
            # repeating would loop forever on the same infeasible radius.
            if best is None:
                raise RuntimeError(
                    "no feasible fair assignment found for any radius up to %f" % r)
            break
        l, feasible = lamb, False

    clusters, points = best
    time_takes = time.time() - t
    violations = max_add_violation(C, S, groups, clusters, alpha, beta, points)

    cost = max(
        distance_matrix([S[j]], C[clusters[j]]).max() if len(clusters[j]) > 0 else 0
        for j in range(len(S))
    )

    return violations, time_takes, cost



# Solve the fair-assignment LP for one radius and round it with a min-cost flow.

# LP values within this distance of an integer are snapped to it before
# flooring, so solver noise (e.g. 1.9999999) does not shift the b-values.
_INTEGRALITY_TOL = 1e-6


def _snap_to_int(values, tol=_INTEGRALITY_TOL):
    """Return *values* with entries within *tol* of an integer replaced by that integer."""
    nearest = np.round(values)
    return np.where(np.abs(values - nearest) <= tol, nearest, values)


def _build_fair_lp(within, points, alpha, beta, m):
    """Build the fairness feasibility LP.

    within[j, i] is True when point j may be assigned to centre i. A variable
    x_ij in [0, 1] is created for each such pair, with constraints
      * sum_i x_ij == 1 for every point j;
      * beta[g]*sum_j x_ij <= sum_{j in g} x_ij <= alpha[g]*sum_j x_ij
        for every centre i and group g.
    Returns (problem, variables), variables mapping (i, j) -> LpVariable.
    """
    n, k = within.shape
    variables = {}
    for i in range(k):
        for j in range(n):
            if within[j, i]:
                variables[(i, j)] = p.LpVariable(str((i, j)), lowBound=0, upBound=1)

    problem = p.LpProblem('Problem', p.LpMaximize)
    problem += 1  # constant objective: only feasibility matters

    for j in range(n):
        problem += p.lpSum([variables[(i, j)] for i in range(k) if within[j, i]]) == 1

    # Points of each group in increasing index order, hoisted out of the centre loop.
    members = [[j for j in range(n) if g in points[j]] for g in range(m)]
    for i in range(k):
        cluster_mass = p.lpSum([variables[(i, j)] for j in range(n) if within[j, i]])
        for g in range(m):
            group_mass = p.lpSum([variables[(i, j)] for j in members[g] if within[j, i]])
            problem += group_mass <= alpha[g] * cluster_mass
            problem += group_mass >= beta[g] * cluster_mass
    return problem, variables


def _round_with_min_cost_flow(d, values, points, k, m):
    """Round a fractional LP assignment to an integral one with a min-cost flow.

    values maps (i, j) -> LP value of x_ij. With T[i, g] the LP mass of group
    g at centre i and T_i = sum_g T[i, g], the network (tuple-named nodes) is
      E_1: 's' -> ('client', j)               capacity 1, weight 0
      E_2: ('client', j) -> ('group', i, g)   capacity 1, weight int(d[j, i]), if x_ij > 0
      E_3: ('group', i, g) -> ('facility', i) capacity 1, weight 0, if T[i, g] is fractional
      E_4: ('facility', i) -> 't'             capacity 1, weight 0, if T_i is fractional
    with networkx demands -n at 's', floor(T[i, g]) at ('group', i, g),
    floor(T_i) - sum_g floor(T[i, g]) at ('facility', i) and
    n - sum_i floor(T_i) at 't'. Weights are truncated to integers because
    networkx's min-cost flow is only reliable with integer weights.
    Returns clusters: dict centre index -> list of point indices (ascending).
    """
    n = d.shape[0]
    # T[i, g] (T_i^h in the paper). As before, only the first group of each
    # point is counted here, while E_2 edges use all of the point's groups.
    T = np.zeros((k, m))
    for (i, j), x in values.items():
        T[i, points[j][0]] += x
    T = _snap_to_int(T)
    T_floor = np.floor(T)
    T_center = _snap_to_int(T.sum(axis=1))  # T_i: total LP mass at centre i
    T_center_floor = np.floor(T_center)

    G = nx.DiGraph()
    G.add_node('s', demand=-n)  # networkx: negative demand = supply (opposite of b-values)
    G.add_node('t', demand=int(n - T_center_floor.sum()))
    for j in range(n):
        G.add_edge('s', ('client', j), capacity=1, weight=0)
    for i in range(k):
        G.add_node(('facility', i), demand=int(T_center_floor[i] - T_floor[i].sum()))
        # E_4 depends on T_i itself being fractional; b-values are always integral.
        if T_center[i] > T_center_floor[i]:
            G.add_edge(('facility', i), 't', capacity=1, weight=0)
        for g in range(m):
            G.add_node(('group', i, g), demand=int(T_floor[i, g]))
            if T[i, g] > T_floor[i, g]:
                G.add_edge(('group', i, g), ('facility', i), capacity=1, weight=0)
    for (i, j), x in values.items():
        if x > 0:
            for g in points[j]:
                G.add_edge(('client', j), ('group', i, g), capacity=1, weight=int(d[j, i]))

    flow = nx.min_cost_flow(G)

    clusters = {i: [] for i in range(k)}
    for j in range(n):
        for target, units in flow[('client', j)].items():
            if units > 0:
                clusters[target[1]].append(j)  # target is ('group', i, g)
    return clusters


def lp(C, S, k, groups, alpha, beta, lamb, reps=5):
    """Solve the fair-assignment LP for radius ``lamb`` and round it.

    Point j may be assigned to centre S[i] only if their distance is at most
    2*lamb. The fractional LP (see _build_fair_lp) is solved with CPLEX via
    PuLP and rounded with a min-cost flow (see _round_with_min_cost_flow).

    Args:
        C: array of points; S: array of chosen centres.
        k: ignored; the number of centres is taken from len(S).
        groups: dict group index -> list of point indices (keys 0..len-1).
        alpha, beta: per-group upper/lower proportion bounds.
        lamb: radius guess.
        reps: unused.

    Returns:
        (Lp_prob, status, clusters, points) on an optimal LP solve, where
        clusters maps centre index -> list of point indices and points maps
        point index -> list of its groups; (None, status, None, None) otherwise.
        A solver exception is reported as status 0 ('Not Solved').
    """
    d = distance_matrix(C, S)
    m = len(groups)
    n = len(C)
    k = len(S)

    # points[j]: groups containing point j
    points = {j: [] for j in range(n)}
    for g in range(m):
        for point in groups[g]:
            points[point].append(g)

    within = d <= 2 * lamb  # within[j, i]: point j may be assigned to centre i
    Lp_prob, variables = _build_fair_lp(within, points, alpha, beta, m)

    solver = p.CPLEX_PY(msg=0)
    try:
        status = Lp_prob.solve(solver)
    except Exception:
        # Best effort: any solver failure (e.g. CPLEX unavailable) counts as
        # "Not Solved", so the caller treats this radius as infeasible.
        status = 0

    if p.LpStatus[status] != 'Optimal':
        return None, status, None, None

    values = {key: var.value() for key, var in variables.items()}
    clusters = _round_with_min_cost_flow(d, values, points, k, m)

    # Invariant: the integral flow assigns every point exactly once.
    assert sum(len(members) for members in clusters.values()) == n

    return Lp_prob, status, clusters, points







