import numpy as np

from lp_freq_distributor import *
from greedy_k_center import *
from loaddatasets import *
from lp_bera import *
from lp_ahmadian import *
import math
import re

import faircluster

# ---- Experiment configuration -------------------------------------------
k = 25          # number of centers requested from every solver below
epsilon = 0.1
delta = 0.3     # NOTE(review): replaced by None at the end of this section before any use

# Load the dataset and make sure the point matrix is a float64 array
# (it is rescaled in place further down).
C, groups, colors, name = get_creditcard()
C = np.array(C, dtype="float64")

m = len(groups)  # number of groups returned by the loader
print(len(C))
print(m)

# One alpha / beta entry per group: alpha is 0.8 everywhere, beta is 0.
alpha = np.full(m, 0.8)
print(alpha)
beta = np.zeros(m)

# The solvers below are called with delta=None.
delta = None


class TreeNodeHST:
    """A single cell of the quadtree.

    Attributes:
        lower: Lower corner of the cell (``Lower`` argument).
        upper: Upper corner of the cell (``Upper`` argument).
        Dia: Diameter value assigned to the cell (``D`` argument).
        id: Indices of the data points that fall in this cell (``id_range``).
        childrenlist: Child cells; ``None`` until a tree builder attaches them.
    """

    def __init__(self, Lower, Upper, D, id_range):
        # A freshly created node has no children yet.
        self.childrenlist = None
        self.id = id_range
        self.Dia = D
        self.lower, self.upper = Lower, Upper

class MultiView:
    """Randomly shifted quadtree over a point set, plus candidate radii.

    Constructing an instance builds the tree (stored in ``root_``) and then
    derives a sorted array of candidate radii from it (``radius_list_``).

    Args:
        data: 2-D array with one row per point.  The root cell is the unit
            cube and every point is offset by a random per-dimension shift
            drawn from [0, 0.5) before being compared with cell midpoints,
            so callers presumably pre-scale the data so that shifted points
            stay inside [0, 1] -- TODO confirm against callers.
        k: Number of clusters; stored as ``clusters_``.
        lda: Step parameter of the radius grid ``lda * (1 + lda) ** t``.
            Used as a divisor and inside a logarithm, so it must be non-zero
            with ``1 + lda`` positive and different from 1.
        delta: Divisor applied to a parent's diameter when computing the
            smallest radius considered for each parent/child pair.
        epsilon: Stored as ``epsilon_`` and forwarded to the tree builder,
            which currently does not use it.
        problem: Stored as ``problem_``; not otherwise used by this class.
        args: Stored as ``args_``; not otherwise used by this class.
    """

    def __init__(self, data, k, lda, delta, epsilon, problem, args):
        self.data_ = data
        self.clusters_ = k
        self.lda_ = lda
        self.delta_ = delta
        self.problem_ = problem
        self.args_ = args
        self.epsilon_ = epsilon
        self.root_ = self.shifted_quadtree(self.epsilon_)
        self.radius_list_ = self.range_cover(self.root_)

    def CreateNode1(self, Node, shift, k, epsilon):
        """Split ``Node`` at its midpoint and return its non-empty children.

        Every point of the node is assigned, per dimension, to the lower or
        upper half of the cell depending on whether ``point + shift`` lies
        at/below or strictly above the cell midpoint.  Points sharing the
        same half in every dimension form one child.

        Args:
            Node: Cell to split; needs ``id``, ``lower`` and ``upper``.
            shift: Per-dimension offset added to the data before comparison.
            k: Unused; kept so existing callers keep working.
            epsilon: Unused; kept so existing callers keep working.

        Returns:
            ``None`` when the node holds at most one point or when its
            longest side is shorter than 1e-6.  Otherwise a list of
            ``TreeNodeHST`` children ordered lexicographically by their
            lower/upper pattern (dimension 0 most significant).  A child
            holding a single point gets diameter 0; any other child gets the
            Euclidean length of its cell diagonal.
        """
        if Node.id.shape[0] <= 1:
            return None
        if np.max(Node.upper - Node.lower) < 1E-6:
            return None

        mid = 0.5 * (Node.lower + Node.upper)
        node_data = self.data_[Node.id]
        n_points = node_data.shape[0]

        # in_upper[p, d] is True when point p falls in the upper half of
        # dimension d.
        in_upper = (node_data + shift - mid) > 0

        # Sort the rows lexicographically so that points of the same child
        # become contiguous.  np.lexsort treats the LAST key as the primary
        # one, hence the reversed column order.
        order = np.lexsort(in_upper.T[::-1].astype(np.uint8))
        sorted_bits = in_upper[order]

        # A new child starts wherever two consecutive sorted rows differ.
        changes = np.any(sorted_bits[1:] != sorted_bits[:-1], axis=1)
        starts = np.concatenate(([0], np.flatnonzero(changes) + 1))
        stops = np.concatenate((starts[1:], [n_points]))

        new_node_list = []
        for start, stop in zip(starts, stops):
            upper_half = sorted_bits[start]
            lower_temp = Node.lower.copy()
            upper_temp = Node.upper.copy()
            # Shrink the parent's box to the half selected in each dimension.
            lower_temp[upper_half] = mid[upper_half]
            upper_temp[~upper_half] = mid[~upper_half]
            if stop - start == 1:
                # A cell containing a single point has zero diameter.
                D = 0
            else:
                D = math.sqrt(((upper_temp - lower_temp) ** 2).sum())
            new_node_list.append(
                TreeNodeHST(lower_temp, upper_temp, D, Node.id[order[start:stop]])
            )
        return new_node_list

    def shifted_quadtree(self, epsilon):
        """Build the quadtree over ``self.data_`` and return its root.

        The root is the unit cube holding every point, with diameter
        ``sqrt(dimension)``.  One random shift per dimension is drawn from
        ``np.random.uniform(0, 0.5)`` and shared by all splits.  Cells are
        split repeatedly until ``CreateNode1`` refuses to split them.

        Args:
            epsilon: Forwarded to ``CreateNode1`` (currently unused there).

        Returns:
            The root ``TreeNodeHST`` with ``childrenlist`` filled in
            throughout the tree.
        """
        n_points, dim = self.data_.shape[0], self.data_.shape[1]
        root = TreeNodeHST(
            np.zeros(dim),
            np.ones(dim),
            math.sqrt(dim),
            np.arange(n_points, dtype=int),
        )
        shift = np.random.uniform(0, 0.5, size=dim)

        # Each cell is split independently of the others, so the visiting
        # order does not matter; a stack gives O(1) pops.
        pending = [root]
        while pending:
            node = pending.pop()
            node.childrenlist = self.CreateNode1(node, shift, self.clusters_, epsilon)
            if node.childrenlist:
                pending.extend(node.childrenlist)
        return root

    def range_cover(self, root):
        """Collect the candidate radii induced by the tree below ``root``.

        For every parent/child pair, with ``rH = parent.Dia / lda`` and
        ``rL = max(child.Dia / lda, parent.Dia / delta)``, the radii
        ``lda * (1 + lda) ** (j + 1)`` are added for every integer ``j`` from
        ``ceil(log(rL) / log(1 + lda))`` to
        ``floor(log(rH) / log(1 + lda)) + 1`` inclusive.  The root diameter
        itself is always included.

        Args:
            root: Root node; nodes need ``Dia`` and ``childrenlist``.

        Returns:
            Sorted 1-D numpy array of the distinct radii.  A root without
            children yields just ``[root.Dia]``.
        """
        lda = self.lda_
        delta = self.delta_
        growth = 1 + lda
        log_growth = math.log2(growth)

        radii = {root.Dia}
        # The result is a set, so traversal order is irrelevant.
        pending = [root]
        while pending:
            node = pending.pop()
            children = node.childrenlist
            if not children:
                # Only an unsplit root can get here; it adds nothing more.
                continue
            rH = node.Dia / lda
            tR = math.floor(math.log2(rH) / log_growth)
            for child in children:
                if child.childrenlist:
                    pending.append(child)
                rL = max(child.Dia / lda, node.Dia / delta)
                tL = math.ceil(math.log2(rL) / log_growth)
                for j in range(tL, tR + 2):
                    radii.add(math.pow(growth, j + 1) * lda)

        return np.sort(np.array(list(radii)))





# print("Ours")
# fair = fair_k_clustering(C,C,k, groups, alpha=alpha, beta=beta, delta=delta, epsilon=epsilon)
# print(fair)

# `time` is used below but is not imported explicitly at the top of the file;
# import it here so the script does not rely on a wildcard import exposing it.
import time

# Global extrema over every coordinate of every point.
axis_d_min = np.min(C)
axis_d_max = np.max(C)

# Costs measured on the rescaled data are multiplied by this factor later on.
scaler = (4 * (axis_d_max - axis_d_min))

# Affine rescaling applied in place: every coordinate ends up in [1/4, 1/2].
C[:] = (C - axis_d_min) / scaler + 1.0 / 4

# Build the shifted quadtree and its candidate radii, timing the construction.
tscaling = time.time()
Solver = MultiView(C, k, 0.195, 0.2, 0.01, "kcenter_with_outliers", [None])
radius_list = Solver.radius_list_

tscaling = time.time() - tscaling

itera = 5  # number of repetitions per method


def _run_trials(label, solve, repeats):
    """Run ``solve`` ``repeats`` times and collect its results.

    ``label`` is printed before every run.  Each result is indexed as
    ``result[0]`` (stored as the fairness value), ``result[1]`` (stored as
    the running time) and ``result[2]`` (stored as the cost after being
    multiplied by the module-level ``scaler``).

    Returns:
        Three numpy arrays ``(costs, fairness, times)`` with one entry per run.
    """
    costs, fairness, times = [], [], []
    for _ in range(repeats):
        print(label)
        outcome = solve()
        costs.append(outcome[2] * scaler)
        times.append(outcome[1])
        fairness.append(outcome[0])
    return np.array(costs), np.array(fairness), np.array(times)


# Fair clustering restricted to the candidate radii from the quadtree.
cost_list1, fair_list1, time_list1 = _run_trials(
    "Ours",
    lambda: fair_k_clustering(
        C, C, k, groups, alpha=alpha, beta=beta, delta=delta, epsilon=epsilon,
        lb=radius_list, ub=radius_list, flag=1,
    ),
    itera,
)

# Same solver without radius bounds.
cost_list2, fair_list2, time_list2 = _run_trials(
    "KFC",
    lambda: fair_k_clustering(
        C, C, k, groups, alpha=alpha, beta=beta, delta=delta, epsilon=epsilon,
        lb=None, ub=None, flag=0,
    ),
    itera,
)

# Greedy k-center baseline.
cost_list3, fair_list3, time_list3 = _run_trials(
    "Greedy Algorithm",
    lambda: greedy_k_center(
        C, C, k, groups, alpha=alpha, beta=beta, delta=delta, epsilon=epsilon,
    ),
    itera,
)

# if colors is None:
#     ahmadian = (-1, -1, -1)
# else:
#     print("Ahmadian et al Algorithm")
#     ahmadian = lp_ahmadian(C, colors, k,  max(colors)+1, alpha=alpha)
#     print(ahmadian)

# print("Bera et al")
# bera = lp_bera(k, alpha, beta, delta, dataset='c50', final_code = 'bera/')
# print(bera)

print('======Result======')
print("additive violation, time, cost:")

# One summary line per method, each value averaged over the repeated runs.
for label, costs, fairness, times in (
    ("KFC: ", cost_list2, fair_list2, time_list2),
    ("greedy: ", cost_list3, fair_list3, time_list3),
    ("ours: ", cost_list1, fair_list1, time_list1),
):
    print(label, "cost", np.mean(costs), "fairness", np.mean(fairness), "time", np.mean(times))

print("scaler", scaler)

# print("ahmadian: ", ahmadian)

# print("bera: ", bera)
