import numpy as np
import scipy.sparse as sp
import pandas as pd
from ucimlrepo import fetch_ucirepo
import torch
import math
from tqdm import tqdm
import pdb
import utils
import time

# Run tensor computations on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _build_hyperedges(features):
    """Group node ids into hyperedges, one per distinct value of each feature column.

    Missing values (``pd.isna``) never join a hyperedge. A column with exactly
    one distinct non-missing value is skipped so that no single hyperedge ends
    up connecting (almost) every node. Hyperedge ids are consecutive integers,
    assigned column by column in order of first appearance of each value.

    Returns:
        dict mapping hyperedge id -> list of node (row) indices.
    """
    edgeid_to_nodes = {}
    for col in range(features.shape[1]):
        value_to_nodes = {}
        for row in range(features.shape[0]):
            attr = features[row, col]
            if pd.isna(attr):
                continue
            value_to_nodes.setdefault(attr, []).append(row)
        # a single distinct value would yield one hyperedge covering all its nodes
        if len(value_to_nodes) == 1:
            continue
        for nodes in value_to_nodes.values():
            edgeid_to_nodes[len(edgeid_to_nodes)] = nodes
    return edgeid_to_nodes


def two_way_clustering(numnodes, features, targets, round, args, verbose=False):
    """Split the given nodes into two clusters via a hypergraph random-walk Laplacian.

    Every distinct categorical value of every feature column becomes a
    hyperedge (weight 1). Edge-dependent vertex weights come from
    ``utils.create_edvw_kway``; they define a random walk with transition
    matrix ``P = Dv^-1 W De^-1 R``. The sign pattern of the eigenvector of the
    normalized Laplacian belonging to the smallest non-negative eigenvalue
    (ignoring the trivial smallest one) gives the two-way partition. The sign
    of an eigenvector is arbitrary, so which side is labelled 1 is arbitrary.

    Args:
        numnodes: number of nodes, i.e. rows of ``features``/``targets``.
        features: 2-D array indexed as ``features[node, feature]``.
        targets: 2-D array; only column 0 is read as the class label.
        round: 1-based round number; round 1 additionally prints hypergraph
            statistics and returns ``P`` and ``phi``.
        args: namespace; only ``args.m`` is read here (forwarded to
            ``utils.create_edvw_kway``).
        verbose: print a progress message before the eigen-decomposition.

    Returns:
        ``(labels, P, phi)`` when ``round == 1``, otherwise just ``labels``.
        ``labels`` is an int array of 0/1 per node, ``P`` the transition matrix
        as a SciPy LIL matrix and ``phi`` the stationary distribution as a
        torch tensor on ``device``.

    Raises:
        ValueError: if some node does not belong to any hyperedge, since the
            random walk is undefined for it.
    """
    # construct the hyperedges, with filtering the "connecting all"-hyperedges
    edgeid_to_nodes = _build_hyperedges(features)
    numedges = len(edgeid_to_nodes)

    # all edges have weight 1
    edge_weights = {edgeid: 1 for edgeid in range(numedges)}

    # map every distinct target value to a consecutive cluster id
    actual_clusters = np.zeros(numnodes).astype(int)
    target_to_clusterid = {}
    for i in range(targets.shape[0]):
        attr = targets[i, 0]
        if attr not in target_to_clusterid:
            target_to_clusterid[attr] = len(target_to_clusterid)
        actual_clusters[i] = target_to_clusterid[attr]
    numclusters = len(target_to_clusterid)

    # edge-dependent vertex weight, indexed below as edvw[edge][node]
    edvw = utils.create_edvw_kway(numedges, edgeid_to_nodes, actual_clusters, numclusters, args.m)

    # report the size of the hypergraph once
    if round == 1:
        total_hyperedge_node = sum(len(nodes) for nodes in edgeid_to_nodes.values())
        print("Total Number of Nodes: ", numnodes)
        print("Total Number of Hyperedges: ", numedges)
        print("Total Number of Hyperedge-node Connections: ", total_hyperedge_node)

    # compute the matrices R, W, D_V^-1, D_E^-1, P
    R = sp.lil_matrix((numedges, numnodes))
    W = sp.lil_matrix((numnodes, numedges))
    Dv_inv = sp.lil_matrix((numnodes, numnodes))
    De_inv = sp.lil_matrix((numedges, numedges))

    for edge in edvw:
        for node in edvw[edge]:
            R[edge, node] = edvw[edge][node]

    # node degree = total weight of the hyperedges containing the node
    node_degrees = np.zeros(numnodes)
    for edge, nodes in edgeid_to_nodes.items():
        for node in nodes:
            W[node, edge] = edge_weights[edge]
            node_degrees[node] += edge_weights[edge]

    isolated = np.flatnonzero(node_degrees == 0)
    if isolated.size:
        raise ValueError(
            f"{isolated.size} node(s) belong to no hyperedge (e.g. ids "
            f"{isolated[:5].tolist()}); the random walk is undefined for them"
        )
    for node in range(numnodes):
        Dv_inv[node, node] = 1 / node_degrees[node]

    for edge in range(numedges):
        edge_delta = 0
        for nodeid in edgeid_to_nodes[edge]:
            edge_delta += edvw[edge][nodeid]
        De_inv[edge, edge] = 1 / edge_delta

    P = Dv_inv.tocsr() @ W.tocsr() @ De_inv.tocsr() @ R.tocsr()

    # dense float32 tensor of P^T for faster calculation
    PT_coo = P.T.tocoo()
    indices = torch.as_tensor(np.vstack((PT_coo.row, PT_coo.col)), dtype=torch.long)
    values = torch.as_tensor(PT_coo.data, dtype=torch.float32)
    PT_tensor = torch.sparse_coo_tensor(indices, values, PT_coo.shape).to_dense().to(device)
    P = P.tolil()

    # compute the stationary distribution by power iteration
    phi = torch.from_numpy(np.ones(numnodes) / numnodes).to(device).float()
    for _ in range(500):
        phi = torch.mv(PT_tensor, phi)
        phi = phi / torch.sum(phi)

    # calculate the laplacian L = Pi - (Pi P + P^T Pi) / 2; Pi is diagonal, so
    # the products with it are plain row/column scalings
    P_dense = PT_tensor.t()
    L = torch.diag(phi) - (phi.unsqueeze(1) * P_dense + PT_tensor * phi.unsqueeze(0)) / 2

    # Pi^(-1/2); one bulk transfer instead of one device read per element
    phi_minushalf_inv_arr = np.array([1 / math.sqrt(p) for p in phi.tolist()])
    phi_minushalf_inv = torch.from_numpy(phi_minushalf_inv_arr).to(device).float()
    normalized_L = phi_minushalf_inv.unsqueeze(1) * L * phi_minushalf_inv.unsqueeze(0)
    # remove rounding asymmetry so the symmetric eigen-solver is well defined
    normalized_L = (normalized_L + normalized_L.t()) / 2

    # calculate the eigens and find the smallest non-negative eigenvalue
    if verbose:
        print("Calculate the eigenvalues and eigenvectors of the normalized Laplacian for round ", round)
    # eigh returns real eigenvalues in ascending order, so index 0 really is
    # the smallest one (the trivial all-in-one partition) and can be skipped
    eigvalues, eigvectors = torch.linalg.eigh(normalized_L)
    spe_index = 0
    non_negative = torch.nonzero(eigvalues[1:] >= 0)
    if non_negative.numel() > 0:
        spe_index = int(non_negative[0]) + 1

    best_eigenvector = eigvectors[:, spe_index]
    predicted_clusters = best_eigenvector.to('cpu').numpy()

    # positive entries form cluster 1, the rest cluster 0
    predicted_clusters_zero_one = (predicted_clusters > 0).astype(int)

    if round == 1:
        return predicted_clusters_zero_one, P, phi
    else:
        return predicted_clusters_zero_one



def _move_split_members(labels, counts, members, split_labels, source_cluster, new_cluster):
    """Relabel, in place, the members flagged 1 in split_labels from source_cluster to new_cluster."""
    moved = members[np.asarray(split_labels) == 1]
    labels[moved] = new_cluster
    counts[new_cluster] += len(moved)
    counts[source_cluster] -= len(moved)


def main(args):
    """Cluster a dataset into as many clusters as it has classes and report the quality.

    Starting with all nodes in cluster 1, one cluster is bipartitioned per
    round with ``two_way_clustering`` until the number of clusters equals the
    number of distinct target values. Without ``args.longtail`` the currently
    largest predicted cluster is split each round; with it, every existing
    cluster is tentatively split and the split with the smallest NCut is kept.

    Args:
        args: namespace providing ``seed``, ``dataset`` and ``longtail``
            (and whatever ``two_way_clustering`` reads from it).
    """
    np.random.seed(args.seed)

    numnodes, features, targets = utils.parse_dataset(args.dataset)

    # one pass: map each distinct target value to a cluster id and count it
    actual_clusters = np.zeros(numnodes).astype(int)
    target_to_clusterid = {}
    target_to_count = {}
    for i in range(targets.shape[0]):
        attr = targets[i, 0]
        if attr not in target_to_clusterid:
            target_to_clusterid[attr] = len(target_to_clusterid)
            target_to_count[attr] = 0
        target_to_count[attr] += 1
        actual_clusters[i] = target_to_clusterid[attr]
    numclusters = len(target_to_clusterid)

    # cluster ids are 1-based from here on
    actual_clusters = actual_clusters + 1

    current_round = 1

    # monotonic clock: unaffected by system clock adjustments
    start = time.perf_counter()

    max_round = len(target_to_count)
    print("Number of classes: ", max_round)
    print("Number of instances for each class: ", target_to_count)
    predicted_clusters_int = np.ones(numnodes).astype(int)
    clustered_class_num = {1: numnodes}
    for i in range(2, max_round + 1):
        clustered_class_num[i] = 0

    # transition matrix and stationary distribution of the full hypergraph;
    # only produced by the first split, so they stay None if no split happens
    P = None
    phi = None

    # clustering
    if not args.longtail:
        while current_round < max_round:
            # split the predicted cluster with the largest number of instances
            # (first one wins on ties)
            max_cluster_id = max(range(1, max_round + 1), key=clustered_class_num.__getitem__)
            to_clustering = predicted_clusters_int == max_cluster_id
            numnodes_2way = np.sum(to_clustering)
            features_2way = features[to_clustering, :]
            targets_2way = targets[to_clustering, :]

            result = two_way_clustering(numnodes_2way, features_2way, targets_2way, current_round, args, verbose=True)
            if current_round == 1:
                predicted_clusters_zero_one, P, phi = result
            else:
                predicted_clusters_zero_one = result

            _move_split_members(
                predicted_clusters_int, clustered_class_num, np.flatnonzero(to_clustering),
                predicted_clusters_zero_one, max_cluster_id, current_round + 1,
            )

            current_round += 1
    else:
        while current_round < max_round:
            min_ncut = np.inf
            best_predicted_clusters_int = None
            best_clustered_class_num = None
            print("Round ", current_round)
            for cluster_to_split in range(1, current_round + 1):
                to_clustering = predicted_clusters_int == cluster_to_split
                numnodes_2way = np.sum(to_clustering)
                # a cluster with fewer than two nodes cannot be bipartitioned
                if numnodes_2way < 2:
                    continue
                features_2way = features[to_clustering, :]
                targets_2way = targets[to_clustering, :]

                result = two_way_clustering(numnodes_2way, features_2way, targets_2way, current_round, args)
                if current_round == 1:
                    predicted_clusters_zero_one, P, phi = result
                else:
                    predicted_clusters_zero_one = result

                # work on copies so a rejected candidate leaves the state untouched
                candidate_labels = predicted_clusters_int.copy()
                candidate_counts = clustered_class_num.copy()
                _move_split_members(
                    candidate_labels, candidate_counts, np.flatnonzero(to_clustering),
                    predicted_clusters_zero_one, cluster_to_split, current_round + 1,
                )

                # calculate the ncut
                ncut = utils.calc_ncut_kway(candidate_labels, current_round + 1, P, phi.to('cpu').numpy())
                if ncut < min_ncut:
                    min_ncut = ncut
                    best_predicted_clusters_int = candidate_labels
                    best_clustered_class_num = candidate_counts

            # keep the current assignment when no candidate split was usable,
            # instead of wiping all labels
            if best_predicted_clusters_int is not None:
                predicted_clusters_int = best_predicted_clusters_int
                clustered_class_num = best_clustered_class_num

            current_round += 1

    end = time.perf_counter()

    # match the predicted clusters to the actual clusters
    print("Testing......")
    TPs, precisions, recalls, F1s, actual_counts, predicted_counts = utils.calc_acc_kway(predicted_clusters_int, actual_clusters, numclusters)
    print("Size of the predicted clusters: ", predicted_counts)
    print("Size of the actual clusters (mapped by population): ", actual_counts)
    print("True positive predictions of each predicted cluster: ", TPs)
    print("Precisions of each predicted cluster: ", np.round(precisions, 4))
    print("Recalls of each predicted cluster: ", np.round(recalls, 4))
    print("F1 score of the predicted clusters, compared to corresponded actual clusters: ", np.round(F1s, 4))
    print("weighted Sum of F1s: ", np.dot(actual_counts, F1s)/np.sum(actual_counts))
    if P is not None:
        print("NCut value of the global clustering: ", utils.calc_ncut_kway(predicted_clusters_int, numclusters, P, phi.to('cpu').numpy()))
    else:
        # no split was performed (e.g. fewer than two classes), so there is
        # no transition matrix to evaluate the cut on
        print("NCut value of the global clustering: ", "not available (no split was performed)")
    print("Execution time (in seconds): ", end - start)



# Script entry point: only runs when the file is executed directly, not on import.
if __name__ == '__main__':
    # presumably parses the run options (seed, dataset, longtail, m, ...) — confirm in utils
    args = utils.params()
    # NOTE(review): looks like a validity check of the dataset name — confirm in utils
    utils.check_dataset(args.dataset)
    main(args)
