# File: predict.py
# Description: Predict the dissimilarity (d_{uv}) for each edge (u,v) in the graph.

import os
import networkx as nx
import gurobipy as gp
import numpy as np
from sklearn.manifold import SpectralEmbedding


def generate_ground_truth(data_dir, num_nodes_list, p_list):
    """
    Generate ground-truth clusterings for synthetic datasets.

    For every combination of ``num`` in ``num_nodes_list`` and ``prob`` in
    ``p_list`` the file ``<data_dir>/nodes_<num>/prob_<prob>/ground_truth.txt``
    is (over)written with one line ``"i j label"`` per node pair ``i < j``.
    The label is 1 when ``i < num / 2`` and ``j >= num / 2`` (the two nodes
    fall in different halves of the node range) and 0 otherwise.

    Args:
        data_dir: Root directory of the synthetic datasets.
        num_nodes_list: Iterable of node counts.
        p_list: Iterable of probability values; only used to build the path.
    """
    for num in num_nodes_list:
        for prob in p_list:
            ground_truth_file = os.path.join(data_dir, f"nodes_{num}/prob_{prob}/ground_truth.txt")
            # Create the dataset directory on demand so that a missing folder
            # does not abort the run with FileNotFoundError.
            os.makedirs(os.path.dirname(ground_truth_file), exist_ok=True)
            with open(ground_truth_file, "w") as f:
                for i in range(num):
                    for j in range(i + 1, num):
                        label = 1 if i < num / 2 and j >= num / 2 else 0
                        f.write(f"{i} {j} {label}\n")


def get_OPT_from_ILP(input_graph_file, opt_file):
    """
    Solve Integer Linear Program (ILP) to compute the optimal clustering.

    The graph is read from ``input_graph_file`` as an undirected edge list with
    integer node labels. One binary variable ``x[i, j]`` is created per pair
    ``i < j``; the objective counts ``x[i, j]`` for every edge and
    ``1 - x[i, j]`` for every non-edge and is minimized. The solution is
    written to ``opt_file`` as one line ``"i j value"`` per pair.

    Pairs are enumerated over ``range(number_of_nodes)``, so node labels are
    assumed to be exactly ``0 .. n-1``.

    Args:
        input_graph_file: Path of the edge-list file.
        opt_file: Path of the solution file to write.
    """
    graph = nx.read_edgelist(input_graph_file, create_using=nx.Graph(), nodetype=int)

    # Gurobi license parameters (fill in appropriately)
    options = {
        "WLSACCESSID": "",
        "WLSSECRET": "",
        "LICENSEID": None,
    }

    vertices_num = graph.number_of_nodes()
    pairs = [(i, j) for i in range(vertices_num) for j in range(i + 1, vertices_num)]

    # Build the named model inside the licensed environment. Creating a second
    # gp.Model(...) here would silently fall back to the default environment
    # and would never be disposed by the context manager.
    with gp.Env(params=options) as env, gp.Model("CC_ILP_solver", env=env) as model:
        # Add variables
        x = {}
        for i, j in pairs:
            x[i, j] = model.addVar(vtype=gp.GRB.BINARY, name="x(%s,%s)" % (i, j))

        # Triangle inequality constraints
        # NOTE(review): a constraint is only added for triples with exactly two
        # edges and one non-edge, not for every triple -- confirm that this
        # reduced set is sufficient for the intended formulation.
        for i in range(vertices_num):
            for j in range(i + 1, vertices_num):
                for k in range(j + 1, vertices_num):
                    if graph.has_edge(i, j) and graph.has_edge(j, k) and not graph.has_edge(i, k):
                        model.addConstr(x[i, j] + x[j, k] >= x[i, k], "c(%s,%s,%s)" % (i, j, k))
                    elif graph.has_edge(i, j) and graph.has_edge(i, k) and not graph.has_edge(j, k):
                        model.addConstr(x[i, j] + x[i, k] >= x[j, k], "c(%s,%s,%s)" % (i, j, k))
                    elif graph.has_edge(i, k) and graph.has_edge(j, k) and not graph.has_edge(i, j):
                        model.addConstr(x[i, k] + x[j, k] >= x[i, j], "c(%s,%s,%s)" % (i, j, k))

        # Objective; quicksum avoids the repeated expression copies that the
        # built-in sum() performs on Gurobi linear expressions.
        obj = gp.quicksum(x[i, j] for i, j in pairs if graph.has_edge(i, j))
        obj += gp.quicksum(1 - x[i, j] for i, j in pairs if not graph.has_edge(i, j))
        model.setObjective(obj, sense=gp.GRB.MINIMIZE)

        model.optimize()

        # Read the values while the model is still alive; it is disposed as
        # soon as the with-block exits.
        solution = {pair: var.x for pair, var in x.items()}

    with open(opt_file, "w") as f:
        for i, j in pairs:
            f.write(f"{i} {j} {solution[i, j]}\n")


def generate_perturbed_prediction(opt_file, perturbation, prediction_file):
    """
    Generate predictions by performing perturbations on optimal clusterings.

    Every line ``"u v value"`` of ``opt_file`` produces one line ``"u v d"``
    in ``prediction_file``, where ``d`` is ``1 - perturbation`` when ``value``
    encodes 1 and ``perturbation`` when it encodes 0. Blank lines are skipped.

    Args:
        opt_file: Path of the 0/1 clustering file to read.
        perturbation: Amount by which each 0/1 value is moved towards the
            opposite label.
        prediction_file: Path of the prediction file to write.

    Raises:
        ValueError: If a non-blank line does not consist of two integers
            followed by a number.
    """
    with open(opt_file, "r") as file_read, \
        open(prediction_file, "w") as file_write:

        for line in file_read:
            fields = line.split()
            if not fields:
                continue
            u, v, opt_val = fields
            u, v, opt_val = int(u), int(v), float(opt_val)
            # Solver output for a binary variable may be e.g. 0.9999999, so
            # threshold at 0.5 instead of testing for exact equality with 1.
            if opt_val > 0.5:
                file_write.write(f"{u} {v} {1 - perturbation}\n")
            else:
                file_write.write(f"{u} {v} {perturbation}\n")


def generate_spectral_embedding_prediction(input_graph_file, cluster_num, prediction_file):
    """
    Generate predictions using spectral embedding.

    The graph is read from ``input_graph_file`` as an undirected edge list with
    integer node labels, its adjacency matrix is embedded into ``cluster_num``
    dimensions, and for every pair ``u < v`` the value
    ``1 - |cosine similarity of the two embeddings|`` is written to
    ``prediction_file`` as a line ``"u v value"``.

    Pairs are enumerated over ``range(number_of_nodes)``, so node labels are
    assumed to be exactly ``0 .. n-1``.

    Args:
        input_graph_file: Path of the edge-list file.
        cluster_num: Number of embedding dimensions (``n_components``).
        prediction_file: Path of the prediction file to write.
    """
    graph = nx.read_edgelist(input_graph_file, nodetype=int, create_using=nx.Graph())
    num_nodes = graph.number_of_nodes()

    # Order the rows by node label. Without an explicit nodelist the rows
    # follow the order in which nodes first appear in the edge list, so row u
    # would not necessarily belong to node u.
    adjacency = nx.to_numpy_array(graph, nodelist=sorted(graph.nodes()))

    # NOTE(review): with affinity="nearest_neighbors" the adjacency rows are
    # treated as feature vectors rather than as a precomputed affinity matrix
    # -- confirm this is intended.
    se = SpectralEmbedding(n_components=cluster_num,
                           affinity="nearest_neighbors",
                           random_state=0)
    embeddings = se.fit_transform(adjacency)

    # Each norm is needed for n - 1 pairs; compute it once per node.
    norms = [np.linalg.norm(row) for row in embeddings]

    with open(prediction_file, "w") as f:
        for u in range(num_nodes):
            for v in range(u + 1, num_nodes):
                similarity = np.dot(embeddings[u], embeddings[v]) / (norms[u] * norms[v])
                normalized_similarity = abs(similarity)
                f.write(f"{u} {v} {1 - normalized_similarity}\n")


if __name__ == "__main__":
    sbm_root = "../data/sbm"
    facebook_root = "../data/facebook"
    # Both SBM experiments use the same perturbation levels.
    sbm_perturbations = [0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28]

    # SBM datasets with p=0.95: predictions are perturbed ground-truth labels.
    dense_sizes = [100, 500, 1000, 1500, 2000, 2500]
    dense_probs = [0.95]
    generate_ground_truth(sbm_root, dense_sizes, dense_probs)
    for size in dense_sizes:
        for p in dense_probs:
            truth_path = os.path.join(sbm_root, f"nodes_{size}/prob_{p}/ground_truth.txt")
            for eps in sbm_perturbations:
                out_path = os.path.join(sbm_root, f"nodes_{size}/prob_{p}/prediction_gt_{eps}.txt")
                generate_perturbed_prediction(truth_path, eps, out_path)

    # SBM datasets with p<=0.9: predictions are perturbed ILP optima.
    sparse_sizes = [100]
    sparse_probs = [0.9, 0.8, 0.7]
    for size in sparse_sizes:
        for p in sparse_probs:
            edges_path = os.path.join(sbm_root, f"nodes_{size}/prob_{p}/edges.txt")
            solution_path = os.path.join(sbm_root, f"nodes_{size}/prob_{p}/opt_solution.txt")
            get_OPT_from_ILP(edges_path, solution_path)

            for eps in sbm_perturbations:
                out_path = os.path.join(sbm_root, f"nodes_{size}/prob_{p}/prediction_opt_{eps}.txt")
                generate_perturbed_prediction(solution_path, eps, out_path)

    # Facebook datasets: perturbed ILP optima with much smaller perturbations.
    facebook_perturbations = [0.002, 0.004, 0.006, 0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02]
    for name in ["facebook0", "facebook414", "facebook3980"]:
        edges_path = os.path.join(facebook_root, f"{name}/edges.txt")
        solution_path = os.path.join(facebook_root, f"{name}/OPT_sol.txt")
        get_OPT_from_ILP(edges_path, solution_path)

        for eps in facebook_perturbations:
            out_path = os.path.join(facebook_root, f"{name}/prediction_opt_{eps}.txt")
            generate_perturbed_prediction(solution_path, eps, out_path)

    # EmailCore and LastFM datasets: spectral-embedding predictions, one file
    # per embedding dimension. Entries are (output dir, edge list, dimensions).
    spectral_jobs = [
        ("../data/emailcore", "../data/emailcore/email-Eu-core.txt",
         [600, 650, 700, 750, 800, 850, 900, 950, 1000]),
        ("../data/lastfm", "../data/lastfm/edges.txt",
         [5500, 6000, 6500]),
    ]
    for out_dir, edge_list, dimensions in spectral_jobs:
        for dim in dimensions:
            out_path = os.path.join(out_dir, f"prediction_se_{dim}.txt")
            generate_spectral_embedding_prediction(edge_list, dim, out_path)
