import louvain.community as community_louvain
from stellargraph.core.graph import StellarGraph
import stellargraph as sg
import numpy as np
from sklearn import preprocessing
import pandas as pd
from src.utils import config


def _split_oversized_groups(partition_groups, groups, group_len_max):
    """Chop every group larger than ``group_len_max`` into consecutive chunks.

    ``partition_groups`` (group id -> list of node ids) and ``groups`` (list of
    group ids) are updated in place; each extra chunk gets a fresh group id
    above the current maximum and is appended to ``groups``.
    """
    new_group_id = max(groups) + 1
    for group_i in list(groups):  # snapshot: ``groups`` grows inside the loop
        members = partition_groups[group_i]
        if len(members) <= group_len_max:
            continue
        partition_groups[group_i] = members[:group_len_max]
        # Chunk the entire tail so that no resulting group exceeds the limit.
        for start in range(group_len_max, len(members), group_len_max):
            partition_groups[new_group_id] = members[start:start + group_len_max]
            groups.append(new_group_id)
            new_group_id += 1


def _assign_groups_to_owners(partition_groups, groups, num_owners, owner_capacity):
    """Greedily hand groups (largest first) to the least-loaded owner.

    An owner is "available" while it holds fewer than ``owner_capacity`` nodes.
    A whole group goes to the least-loaded available owner; when no owner is
    available, the group is scattered node-by-node onto the least-loaded owner.
    Returns a dict owner index -> list of node ids.
    """
    owner_node_ids = {owner: [] for owner in range(num_owners)}

    def load(owner):
        return len(owner_node_ids[owner])

    # ``sorted`` is stable, so equally sized groups keep their ``groups`` order.
    for group_i in sorted(groups, key=lambda g: len(partition_groups[g]), reverse=True):
        available = [owner for owner in range(num_owners) if load(owner) < owner_capacity]
        if available:
            owner_node_ids[min(available, key=load)].extend(partition_groups[group_i])
        else:
            for node in partition_groups[group_i]:
                owner_node_ids[min(range(num_owners), key=load)].append(node)
    return owner_node_ids


def _count_labels_per_owner(whole_graph, node_subjects, owner_node_ids, labels, num_owners):
    """Return, for each owner, a dict label -> node ids (in ``node_subjects`` order).

    Labels are read positionally from ``node_subjects`` using the graph ilocs
    of each owner's nodes, so a real label equal to 0 or "" is still counted.
    """
    index = node_subjects.index
    values = node_subjects.values
    counts = []
    for owner_i in range(num_owners):
        per_label = {label: [] for label in labels}
        # np.unique sorts the ilocs, matching iteration in node_subjects order.
        locs = np.unique(np.asarray(whole_graph.node_ids_to_ilocs(owner_node_ids[owner_i])))
        for loc in locs:
            per_label[values[loc]].append(index[loc])
        counts.append(per_label)
    return counts


def _rebalance_labels(counts, owner_node_ids, labels, min_per_label):
    """Move nodes so every owner holds at least ``min_per_label`` nodes per label.

    Donors only give while they hold more than ``min_per_label`` nodes of that
    label, so a donor never drops below the minimum itself. ``counts`` and
    ``owner_node_ids`` are updated in place.
    """
    for label in labels:
        for owner_i, owner_counts in enumerate(counts):
            needy = owner_counts[label]
            for donor_i, donor_counts in enumerate(counts):
                if len(needy) >= min_per_label:
                    break
                donor = donor_counts[label]
                while len(needy) < min_per_label and len(donor) > min_per_label:
                    node_id = donor.pop()
                    needy.append(node_id)
                    owner_node_ids[owner_i].append(node_id)
                    owner_node_ids[donor_i].remove(node_id)


def _masked_subjects(node_subjects, locs, fill):
    """Copy ``node_subjects`` with every position outside ``locs`` set to ``fill``."""
    masked = node_subjects.copy(deep=True)
    # Assign through ``iloc`` rather than ``.values`` so the write is honoured
    # under pandas Copy-on-Write as well.
    masked.iloc[:] = fill
    masked.iloc[locs] = node_subjects.values[locs]
    return masked


def louvain_graph_cut(whole_graph: StellarGraph, node_subjects, graph_split_seed):
    """Split ``whole_graph`` into ``config.num_owners`` local graphs via Louvain.

    Steps:
    1. Detect communities with ``community_louvain.best_partition``.
    2. Chop communities larger than ``n_nodes // num_owners - delta`` into
       chunks of at most that size.
    3. Greedily assign the groups to owners.
    4. Move labelled nodes between owners so that, where donors allow, every
       owner holds at least 2 nodes of every label.

    Every local graph keeps all nodes and all edges of ``whole_graph``.
    Nodes not owned by that owner get zero features, a blank subject
    (``""`` for string labels, else ``0``) and an all-zero target row.
    Only source/target columns are passed for edges, so edge weights/types
    are not carried over.

    Args:
        whole_graph: the full graph.
        node_subjects: pandas Series of node labels. ASSUMPTION (TODO
            confirm with callers): it is indexed by node id and ordered like
            ``whole_graph.nodes()``, because graph ilocs index it positionally.
        graph_split_seed: ``random_state`` passed to Louvain.

    Returns:
        ``(local_G, local_node_subj, local_target, local_nodes_ids)``: lists
        with one entry per owner, holding respectively the local
        StellarGraph, the masked subject Series, the masked one-hot target
        array and the owned node ids.
    """
    # Build the edge frame straight from the edge tuples; np.copy would coerce
    # heterogeneous node ids (e.g. int and str) to one common dtype.
    edges = list(whole_graph.edges())
    edge_df = pd.DataFrame({
        "source": [edge[0] for edge in edges],
        "target": [edge[1] for edge in edges],
    })

    G = whole_graph.to_networkx()
    partition = community_louvain.best_partition(G, random_state=graph_split_seed)

    groups = sorted(set(partition.values()))
    partition_groups = {group_i: [] for group_i in groups}
    for node, group_i in partition.items():
        partition_groups[group_i].append(node)

    n_nodes = len(whole_graph.nodes())
    owner_nodes_len = n_nodes // config.num_owners
    # Clamp to >= 1: a non-positive chunk size would never shrink a group.
    group_len_max = max(1, owner_nodes_len - config.delta)

    _split_oversized_groups(partition_groups, groups, group_len_max)
    owner_node_ids = _assign_groups_to_owners(
        partition_groups, groups, config.num_owners, owner_nodes_len + config.delta
    )

    for owner_i, owned in owner_node_ids.items():
        print(f"nodes len for {owner_i} = {len(owned)}")

    labels = list(node_subjects.unique())  # deterministic first-seen order
    counts = _count_labels_per_owner(
        whole_graph, node_subjects, owner_node_ids, labels, config.num_owners
    )
    _rebalance_labels(counts, owner_node_ids, labels, min_per_label=2)

    target = preprocessing.LabelBinarizer().fit_transform(node_subjects)
    all_features = whole_graph.node_features()
    nodes_id = whole_graph.nodes()
    fill = "" if isinstance(node_subjects.values[0], str) else 0

    local_G = []
    local_node_subj = []
    local_target = []
    local_nodes_ids = []
    for owner_i in range(config.num_owners):
        partition_i = owner_node_ids[owner_i]
        locs_i = whole_graph.node_ids_to_ilocs(partition_i)

        local_node_subj.append(_masked_subjects(node_subjects, locs_i, fill))

        local_target_i = np.zeros(target.shape, np.int32)
        local_target_i[locs_i] = target[locs_i]
        local_target.append(local_target_i)
        local_nodes_ids.append(partition_i)

        feats_i = np.zeros(all_features.shape)
        feats_i[locs_i] = all_features[locs_i]
        local_G.append(StellarGraph(nodes=sg.IndexedArray(feats_i, nodes_id), edges=edge_df))

    return local_G, local_node_subj, local_target, local_nodes_ids