import torch
import numpy as np
from tqdm import tqdm
import pdb
from ucimlrepo import fetch_ucirepo


def calc_second_dominant_eigenvector(M: torch.Tensor, num_iters: int = 500,
                                     num_deflated_iters: int = 2000) -> torch.Tensor:
    """Approximate the eigenvector of ``M`` belonging to its second eigenvalue.

    Method:
      1. Power (von Mises) iteration gives the dominant eigenpair (l1, v1).
      2. Wielandt deflation builds ``B = M - l1 * v1 x^T`` with
         ``x_i = 1 / (n * v1_i)``, so that ``x^T v1 = 1`` and ``B v1 = 0``.
      3. Power iteration on ``B`` gives an eigenvector ``w`` of ``B`` with
         eigenvalue l2.
      4. ``w`` is mapped back to an eigenvector of ``M`` via
         ``v2 = (l2 - l1) w + l1 (x^T w) v1``.

    Args:
        M: square (n, n) matrix on any device. Floating dtypes are kept;
            other dtypes use float32 work vectors.
        num_iters: power-iteration steps for the dominant eigenvector.
        num_deflated_iters: power-iteration steps on the deflated matrix.

    Returns:
        Length-n tensor proportional to the second eigenvector. Its sign and
        scale are arbitrary.

    Notes:
        Power iteration assumes the eigenvalue being sought is strictly
        dominant in magnitude (among the remaining ones for step 3).
        Step 2 divides by every entry of v1, so v1 must have no zero entries.
    """
    n = M.shape[0]
    dtype = M.dtype if M.is_floating_point() else torch.float32

    # Power (von Mises) iteration from the uniform vector.
    # ``M.device`` (not ``M.get_device()``, which is -1 on CPU) keeps this
    # working for CPU tensors as well as CUDA ones.
    dominant_eigenvector = torch.from_numpy(np.ones(n) / n).to(device=M.device, dtype=dtype)
    for _ in range(num_iters):
        dominant_eigenvector = torch.mv(M, dominant_eigenvector)
        dominant_eigenvector = dominant_eigenvector / torch.norm(dominant_eigenvector, 2)
    dominant_eigenvalue = (torch.dot(dominant_eigenvector, torch.mv(M, dominant_eigenvector))
                           / torch.dot(dominant_eigenvector, dominant_eigenvector))

    # Wielandt deflation: x^T v1 = sum_i v1_i / (n v1_i) = 1.
    x = 1 / (n * dominant_eigenvector)
    deflated_matrix = M - dominant_eigenvalue * torch.matmul(
        dominant_eigenvector.view(n, 1), x.view(1, n))

    # The deflated matrix annihilates v1, so starting from a vector parallel
    # to v1 would collapse to zero/NaN. The uniform vector *is* v1 for any
    # row-stochastic M. A non-uniform start avoids that.
    de_of_deflated = torch.from_numpy(np.arange(1, n + 1, dtype=np.float64)).to(
        device=M.device, dtype=dtype)
    de_of_deflated = de_of_deflated / torch.norm(de_of_deflated, 2)
    for _ in range(num_deflated_iters):
        de_of_deflated = torch.mv(deflated_matrix, de_of_deflated)
        de_of_deflated = de_of_deflated / torch.norm(de_of_deflated, 2)

    # ``w`` is an eigenvector of the deflated matrix, not of M, so its
    # eigenvalue l2 must be measured against the deflated matrix.
    second_eigenvalue = (torch.dot(de_of_deflated, torch.mv(deflated_matrix, de_of_deflated))
                         / torch.dot(de_of_deflated, de_of_deflated))

    # Map w back to an eigenvector of M.
    second_eigenvector = ((second_eigenvalue - dominant_eigenvalue) * de_of_deflated
                          + dominant_eigenvalue * torch.dot(x, de_of_deflated) * dominant_eigenvector)
    return second_eigenvector



def calc_ncut_kway(predicted, numclusters, P, phi):
    """Return the k-way normalized cut of the labelling ``predicted``.

    For each label c in 1..numclusters, let C = {v : predicted[v] == c}:

        cut(C) = sum over i in C, j not in C of phi[i] * P[i, j]
        vol(C) = sum over i in C of phi[i]

    and the result is the sum over all c of cut(C) / vol(C).

    ``P`` must support ``P[i, j]`` indexing and ``phi`` ``phi[i]`` indexing.
    A label with no members makes both sums the integer 0, so the division
    raises ZeroDivisionError.
    """
    node_ids = range(len(predicted))
    ncut = 0
    for label in range(1, numclusters + 1):
        # Evaluate membership once per node, then split ids into the cluster
        # and its complement (both in ascending node order).
        membership = [predicted[v] == label for v in node_ids]
        members = [v for v, inside in zip(node_ids, membership) if inside]
        outsiders = [v for v, inside in zip(node_ids, membership) if not inside]

        volume = sum(phi[v] for v in members)
        cut = sum(phi[i] * P[i, j] for i in members for j in outsiders)
        ncut += cut / volume
    return ncut



def _pairwise_f1_matrix(predicted, actual, numclusters):
    """Return F1[i, j] between predicted label i+1 and actual label j+1 (0 when they share no node)."""
    f1 = np.zeros((numclusters, numclusters))
    for i in range(numclusters):
        pred_mask = predicted == i + 1
        n_pred = int(pred_mask.sum())
        for j in range(numclusters):
            act_mask = actual == j + 1
            overlap = int((pred_mask & act_mask).sum())
            if overlap == 0:
                continue  # precision = recall = 0 -> F1 = 0 (also covers empty clusters)
            precision = overlap / n_pred
            recall = overlap / int(act_mask.sum())
            f1[i, j] = 2 * precision * recall / (precision + recall)
    return f1


def _greedy_cluster_matching(f1, numclusters):
    """Match predicted to actual cluster indices one-to-one, greedily by descending F1."""
    remaining = f1.copy()
    mapping = {}
    while len(mapping) < numclusters:
        i, j = divmod(int(np.argmax(remaining)), numclusters)
        if remaining[i, j] <= 0:
            break  # only zero-F1 pairs are left
        mapping[i] = j
        # Retire the row AND the column. Otherwise argmax can keep landing
        # on an already-taken actual cluster and the greedy pass stalls.
        remaining[i, :] = 0
        remaining[:, j] = 0
    # Pair whatever is left (zero F1 everywhere) in ascending index order.
    taken = set(mapping.values())
    free_pred = [i for i in range(numclusters) if i not in mapping]
    free_actual = [j for j in range(numclusters) if j not in taken]
    mapping.update(zip(free_pred, free_actual))
    return mapping


def calc_acc_kway(predicted, actual, numclusters):
    """Score a k-way clustering against ground truth after greedy label matching.

    Labels in both ``predicted`` and ``actual`` are expected to be
    1..numclusters. Inputs may be lists or arrays, including (n, 1) columns;
    they are flattened. Each predicted cluster is matched one-to-one to an
    actual cluster. The pairs are chosen greedily by descending pairwise F1,
    and any clusters still unmatched are paired in index order.

    Returns:
        (TPs, precisions, recalls, F1s, actual_counts, predicted_counts),
        each an array of length numclusters indexed by predicted cluster
        (label i+1 at position i). ``actual_counts[i]`` is the size of the
        actual cluster matched to predicted cluster i. Precision, recall and
        F1 are 0 when there are no true positives, including for empty
        clusters.
    """
    predicted = np.asarray(predicted).ravel()
    actual = np.asarray(actual).ravel()
    mapping = _greedy_cluster_matching(
        _pairwise_f1_matrix(predicted, actual, numclusters), numclusters)

    TPs = np.zeros(numclusters).astype(int)
    precisions = np.zeros(numclusters)
    recalls = np.zeros(numclusters)
    F1s = np.zeros(numclusters)
    actual_counts = np.zeros(numclusters).astype(int)
    predicted_counts = np.zeros(numclusters).astype(int)
    for i in range(numclusters):
        pred_mask = predicted == i + 1
        act_mask = actual == mapping[i] + 1
        predicted_counts[i] = pred_mask.sum()
        actual_counts[i] = act_mask.sum()
        TPs[i] = (pred_mask & act_mask).sum()
        if TPs[i] == 0:
            continue  # precision, recall and F1 stay 0
        precisions[i] = TPs[i] / predicted_counts[i]
        recalls[i] = TPs[i] / actual_counts[i]
        F1s[i] = (2 * precisions[i] * recalls[i]) / (precisions[i] + recalls[i])
    return TPs, precisions, recalls, F1s, actual_counts, predicted_counts





def uciml_mushroom_dataset():
    """Load UCI dataset id 73 (mushroom) without its 'stalk-root' column.

    Returns:
        (numnodes, features, targets): the row count, the remaining feature
        columns as an ndarray, and the targets as an ndarray.
    """
    repo = fetch_ucirepo(id=73)
    frame = repo.data.features
    row_count = frame.shape[0]
    feature_array = frame.drop(columns=['stalk-root']).values
    target_array = repo.data.targets.values
    return row_count, feature_array, target_array


def uciml_covertype_dataset(num):
    """Load a two-class subset of UCI dataset id 31 (covertype) with quantized features.

    Args:
        num: '45' keeps rows whose Cover_Type is 4 or 5; '67' keeps 6 or 7.
            The integers 45 / 67 are accepted as well.

    Returns:
        (numnodes, features, targets):
        - features: every column of ``data.features`` except the last. All of
          those except their own last column are replaced by
          ``int(value / column_max * 10)``, with the max taken over the
          whole dataset rather than only the kept rows.
        - targets: shape (n, 1), holding Cover_Type minus the smaller class
          of the pair.
        NOTE(review): this assumes Cover_Type is the last column of
        ``data.features`` — confirm against the ucimlrepo metadata.

    Raises:
        ValueError: if ``num`` names no supported class pair. The check runs
            before any download.
    """
    class_pairs = {'45': (4, 5), '67': (6, 7)}
    key = str(num)
    if key not in class_pairs:
        raise ValueError(f"unsupported covertype subset {num!r}: expected '45' or '67'")
    first_class, second_class = class_pairs[key]

    covertype = fetch_ucirepo(id=31)
    all_features = covertype.data.features
    cover_type = all_features['Cover_Type']
    subset = all_features[np.logical_or(cover_type == first_class, cover_type == second_class)].values
    features = subset[:, :-1].copy()
    targets = subset[:, -1].copy() - first_class
    numnodes = features.shape[0]

    # Quantize each column against its maximum over the whole dataset. The
    # full array is built once instead of once per column.
    full_values = all_features.values
    for col in range(features.shape[1] - 1):
        maximum = np.max(full_values[:, col])
        for row in range(numnodes):
            features[row, col] = int(features[row, col] / maximum * 10)
    # targets is 1-D here; add a trailing axis so it is (n, 1).
    targets = targets[..., np.newaxis]
    return numnodes, features, targets


def uciml_spambase_dataset():
    """Load UCI dataset id 94 (spambase) with its features quantized.

    Every feature column except the last is overwritten with
    ``int(value / column_max * 10)``; the last column is left as fetched.
    NOTE(review): ``.values`` may share memory with the fetched DataFrame, in
    which case that frame is rewritten too.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    repo = fetch_ucirepo(id=94)
    features = repo.data.features.values
    targets = repo.data.targets.values
    n_rows, n_cols = features.shape
    for col in range(n_cols - 1):
        # Column ``col`` has not been rescaled yet, so this is its raw maximum.
        col_max = np.max(features[:, col])
        for row in range(n_rows):
            features[row, col] = int(features[row, col] / col_max * 10)
    return n_rows, features, targets


def uciml_rice_dataset():
    """Load UCI dataset id 545 (rice) with its features quantized.

    Every feature column except the last is overwritten with
    ``int(value / column_max * 10)``; the last column is left as fetched.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    repo = fetch_ucirepo(id=545)
    features = repo.data.features.values
    targets = repo.data.targets.values
    n_rows, n_cols = features.shape
    for col in range(n_cols - 1):
        # Column ``col`` has not been rescaled yet, so this is its raw maximum.
        col_max = np.max(features[:, col])
        for row in range(n_rows):
            features[row, col] = int(features[row, col] / col_max * 10)
    return n_rows, features, targets


def uciml_car_dataset():
    """Load the rows of UCI dataset id 19 (car) whose 'class' is 'good' or 'vgood'.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray). targets
        keep the original string class labels.
    """
    car = fetch_ucirepo(id=19)
    car_class = car.data.targets['class']
    # Build the row filter once and reuse it for features and targets. The
    # previous per-class tallies (unacc/acc/good/vgood) were never used.
    keep = np.logical_or(car_class == 'good', car_class == 'vgood')
    features = car.data.features[keep].values
    targets = car.data.targets[keep].values
    numnodes = features.shape[0]
    return numnodes, features, targets


def uciml_digit24_dataset():
    """Load the rows of UCI dataset id 80 (digits) whose 'class' is 2 or 4.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    repo = fetch_ucirepo(id=80)
    digit_class = repo.data.targets['class']
    keep = np.logical_or(digit_class == 2, digit_class == 4)
    selected_features = repo.data.features[keep].values
    selected_targets = repo.data.targets[keep].values
    return selected_features.shape[0], selected_features, selected_targets


def uciml_zoo_dataset(num):
    """Load UCI dataset id 111 (zoo), optionally restricted to two animal types.

    Args:
        num: '47' keeps rows whose 'type' target is 4 or 7, '27' keeps types
            2 or 7; any other value keeps every row.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    dataset = fetch_ucirepo(id=111)
    if num == '47':
        type_pair = (4, 7)
    elif num == '27':
        type_pair = (2, 7)
    else:
        type_pair = None

    if type_pair is None:
        # Full dataset. The per-type tally this branch used to compute was
        # never used, so it is gone.
        features = dataset.data.features.values
        targets = dataset.data.targets.values
    else:
        animal_type = dataset.data.targets['type']
        keep = np.logical_or(animal_type == type_pair[0], animal_type == type_pair[1])
        features = dataset.data.features[keep].values
        targets = dataset.data.targets[keep].values
    numnodes = features.shape[0]
    return numnodes, features, targets


def uciml_digit_dataset():
    """Load every row of UCI dataset id 80 (digits).

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    data = fetch_ucirepo(id=80).data
    feature_array = data.features.values
    target_array = data.targets.values
    return feature_array.shape[0], feature_array, target_array


def uciml_letter_dataset():  # use --longtail
    """Load the rows of UCI dataset id 59 (letters) labelled I, C, M or L.

    Meant to be run with ``--longtail`` (``params`` forces it for 'letter').

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    repo = fetch_ucirepo(id=59)
    letter = repo.data.targets['lettr']
    keep = np.logical_or(np.logical_or(letter == 'I', letter == 'C'),
                         np.logical_or(letter == 'M', letter == 'L'))
    selected_features = repo.data.features[keep].values
    selected_targets = repo.data.targets[keep].values
    return selected_features.shape[0], selected_features, selected_targets


def uciml_drybean_dataset():  # use --longtail
    """Load UCI dataset id 602 (dry bean) with its features quantized.

    Every feature column except the last is overwritten with
    ``int(value / column_max * 10)``. Meant to be run with ``--longtail``.

    Returns:
        (numnodes, features, targets) with targets of shape (n, 1), matching
        the other loaders in this file.
    """
    dataset = fetch_ucirepo(id=602)
    features = dataset.data.features.values
    targets = dataset.data.targets.values
    numnodes = features.shape[0]
    for col in range(features.shape[1] - 1):
        # Column ``col`` has not been rescaled yet, so this is its raw maximum.
        maximum = np.max(features[:, col])
        for row in range(numnodes):
            features[row, col] = int(features[row, col] / maximum * 10)
    # ``data.targets.values`` is already 2-D: other loaders here index it as
    # ``targets[i, 0]``. Appending an axis unconditionally would yield
    # (n, 1, 1), so only promote 1-D targets.
    if targets.ndim == 1:
        targets = targets[:, np.newaxis]
    return numnodes, features, targets


def uciml_wine_dataset():
    """Load the quality 5/6/7 rows of UCI dataset id 186 (wine), quantized.

    Every feature column is overwritten with ``int(value / column_max * 20)``.
    The max is taken over the whole dataset, not just the kept rows.

    Returns:
        (numnodes, features, targets) as (int, ndarray, ndarray).
    """
    repo = fetch_ucirepo(id=186)
    quality = repo.data.targets['quality']
    keep = (quality == 5).values | (quality == 6).values | (quality == 7).values
    features = repo.data.features[keep].values
    targets = repo.data.targets[keep].values
    all_rows = repo.data.features.values
    n_rows, n_cols = features.shape
    for col in range(n_cols):
        col_max = np.max(all_rows[:, col])
        for row in range(n_rows):
            features[row, col] = int(features[row, col] / col_max * 20)
    return n_rows, features, targets



def params(argv=None):
    """Parse the command-line hyperparameters.

    Args:
        argv: argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with seed, patience, dataset, rdub, m and longtail.
        ``longtail`` is forced to True for the 'zoo' and 'letter' datasets
        (case-insensitive, like ``check_dataset``).
    """
    import argparse

    parser = argparse.ArgumentParser('hyperclus_local_conductance')
    parser.add_argument('-s', '--seed', type=int, default=42,
                        help='Random seed')
    parser.add_argument('--patience', type=int, default=100)
    parser.add_argument('--dataset', type=str, default='mushroom')
    parser.add_argument('--rdub', type=int, default=50,
                        help='random upper bound')
    parser.add_argument('--m', type=int, default=5,
                        help='EDVW method/strategy')
    parser.add_argument('--longtail', action='store_true')
    args = parser.parse_args(argv)

    if args.dataset.lower() in ('zoo', 'letter'):
        args.longtail = True

    return args



def parse_dataset(dataset_name):
    """Load the named dataset and return ``(numnodes, features, targets)``.

    The name is matched case-insensitively, consistent with ``check_dataset``.
    'covertype' loads the Cover_Type 4/5 subset and 'zoo' loads every zoo row.

    Raises:
        NotImplementedError: for an unknown dataset name.
    """
    # Lambdas defer the name lookup until a loader is actually chosen.
    loaders = {
        'mushroom': lambda: uciml_mushroom_dataset(),
        'covertype': lambda: uciml_covertype_dataset('45'),
        'rice': lambda: uciml_rice_dataset(),
        'car': lambda: uciml_car_dataset(),
        '24': lambda: uciml_digit24_dataset(),
        'zoo': lambda: uciml_zoo_dataset(0),
        'digit': lambda: uciml_digit_dataset(),
        'letter': lambda: uciml_letter_dataset(),
        'wine': lambda: uciml_wine_dataset(),
    }
    key = dataset_name.lower()
    if key not in loaders:
        raise NotImplementedError(f'unknown dataset name: {dataset_name!r}')
    numnodes, features, targets = loaders[key]()
    return numnodes, features, targets



def check_dataset(dataset_name):
    """Raise NotImplementedError unless ``dataset_name`` (case-insensitive) is supported.

    On failure the message is printed (as before) and also attached to the
    exception. It lists only the names this check actually accepts.
    """
    supported = ['mushroom', 'covertype', 'rice', 'car', '24', 'zoo', 'digit', 'letter', 'wine']
    if dataset_name.lower() not in supported:
        message = f'illegal dataset name {dataset_name!r}: must be one of ' + ', '.join(supported)
        print(message)
        raise NotImplementedError(message)



def create_edvw_kway(numedges, edgeid_to_nodes, actual_clusters, numclusters, m):
    """Build edge-dependent vertex weights (EDVW) from ground-truth clusters.

    Only strategy ``m == 5`` is implemented. Inside hyperedge e, node v gets
    weight equal to the number of e's nodes that share v's cluster (v
    itself included).

    Args:
        numedges: number of hyperedges; ids are 0..numedges-1.
        edgeid_to_nodes: maps an edge id to the nodes in that edge.
        actual_clusters: maps a node to its cluster index in 0..numclusters-1.
        numclusters: number of clusters.
        m: EDVW strategy id.

    Returns:
        dict ``{edge_id: {node: weight}}`` with float (numpy) weights.

    Raises:
        NotImplementedError: for any strategy other than ``m == 5``.
    """
    if m != 5:
        raise NotImplementedError(f'EDVW strategy m={m} is not implemented; only m=5 is supported')
    edvw = {}
    for edgeid in range(numedges):
        members = edgeid_to_nodes[edgeid]
        # Per-cluster node counts within this hyperedge.
        counts = np.zeros(numclusters)
        for node in members:
            counts[actual_clusters[node]] += 1
        edvw[edgeid] = {node: counts[actual_clusters[node]] for node in members}
    return edvw