# File: main.py
# Description: Main script for evaluating our proposed algorithms on the synthetic (SBM), Facebook, EmailCore, and LastFM datasets.


import os
import numpy as np
from insertion_pivot_alg import *
from our_insertion_alg import *
from icml21_alg import *
from our_dynamic_alg import *


def run_insertion_pivot_alg(num_of_nodes, streaming_file, k_list, output_dir, repeated_times):
    """
    Run the (3 + eps)-approximation algorithm in insertion-only streams.

    For every k in k_list the algorithm is executed repeated_times times.
    Run number i seeds both `random` and `np.random` with i, and all runs
    for a given k share the same output path inside output_dir.

    Returns:
        A dict mapping each k to a (mean cost, std of cost) tuple taken over
        the repeated runs.
    """
    def trial_cost(k, seed):
        # One seeded execution for a given k; returns the cost reported by alg_pay_cost.
        result_path = os.path.join(output_dir, f"pivotAlg_output_k_{k}.txt")

        random.seed(seed)
        np.random.seed(seed)

        solver = PivotAlgorithm(num_of_nodes, k)
        solver.run(streaming_file, result_path)
        return alg_pay_cost(num_of_nodes, streaming_file, result_path)

    summary = {}
    for k in k_list:
        trial_costs = [trial_cost(k, seed) for seed in range(repeated_times)]
        summary[k] = (np.mean(trial_costs), np.std(trial_costs))
    return summary

def _run_our_insertion_trials(num_of_nodes, streaming_file, predictor, k, output_dir, run_tag, repeated_times):
    """
    Execute OurInsertionAlgorithm repeated_times times for one configuration.

    Run number i seeds both `random` and `np.random` with i. Each run writes two
    output files (named from run_tag inside output_dir) and contributes the
    smaller of their two costs.

    Returns:
        (mean, std) of the per-run minimum cost.
    """
    output_file1 = os.path.join(output_dir, f"ourAlg_output1_{run_tag}.txt")
    output_file2 = os.path.join(output_dir, f"ourAlg_output2_{run_tag}.txt")

    cost = []
    for seed in range(repeated_times):
        random.seed(seed)
        np.random.seed(seed)

        alg = OurInsertionAlgorithm(num_of_nodes, k, predictor)
        alg.run(streaming_file, output_file1, output_file2)
        cost1 = alg_pay_cost(num_of_nodes, streaming_file, output_file1)
        cost2 = alg_pay_cost(num_of_nodes, streaming_file, output_file2)
        cost.append(min(cost1, cost2))
    return np.mean(cost), np.std(cost)


def run_our_insertion_alg(num_of_nodes, streaming_file, ground_truth_file, prediction_dir, pertubation_list, num_clusters_list, k_list, output_dir, repeated_times):
    """
    Run our proposed learning-augmented algorithm for complete graphs in insertion-only streams.

    The function has two modes, selected by num_clusters_list:

    * num_clusters_list is None: iterate over pertubation_list. The prediction
      for each value is read from "prediction_gt_<p>.txt", falling back to
      "prediction_opt_<p>.txt" when the former does not exist, and beta is
      computed with calculate_beta against ground_truth_file.
    * otherwise: iterate over num_clusters_list and read the prediction from
      "prediction_se_<num>.txt"; ground_truth_file and pertubation_list are
      not used.

    Returns:
        (beta_list, cost_dict) in the first mode, (cluster_num, cost_dict) in
        the second. Both are dicts keyed by k and then by perturbation /
        cluster count; cost_dict values are (mean, std) tuples.
    """
    # Identity check: `== None` would be evaluated element-wise for array-like
    # arguments and make the `if` raise instead of selecting a mode.
    if num_clusters_list is None:
        beta_list = {}
        cost_dict = {}
        for k in k_list:
            beta_list[k] = {}
            cost_dict[k] = {}
            for pertubation in pertubation_list:
                prediction_file = os.path.join(prediction_dir, f"prediction_gt_{pertubation}.txt")
                if not os.path.exists(prediction_file):
                    prediction_file = os.path.join(prediction_dir, f"prediction_opt_{pertubation}.txt")

                predictor = load_prediction(prediction_file)
                beta_list[k][pertubation] = calculate_beta(streaming_file, ground_truth_file, prediction_file)

                cost_dict[k][pertubation] = _run_our_insertion_trials(
                    num_of_nodes, streaming_file, predictor, k, output_dir,
                    f"k_{k}_pertubation_{pertubation}", repeated_times,
                )
        return beta_list, cost_dict

    cluster_num = {}
    cost_dict = {}
    for k in k_list:
        cluster_num[k] = {}
        cost_dict[k] = {}
        for num in num_clusters_list:
            prediction_file = os.path.join(prediction_dir, f"prediction_se_{num}.txt")

            predictor = load_prediction(prediction_file)
            cluster_num[k][num] = num

            cost_dict[k][num] = _run_our_insertion_trials(
                num_of_nodes, streaming_file, predictor, k, output_dir,
                f"k_{k}_cluster_num_{num}", repeated_times,
            )
    return cluster_num, cost_dict
    
def run_icml21_alg(num_of_nodes, streaming_file, output_dir, repeated_times):
    """
    Run the ICML'21 agreement decomposition algorithm.

    Performs a grid search over beta and lambda, each taking the values
    i * 0.1 for i = 0..10, and keeps the first setting that achieves the
    lowest mean cost.

    Args:
        num_of_nodes: Forwarded to AgreementBasedAlgorithm and alg_pay_cost.
        streaming_file: Input stream handed to the algorithm and to alg_pay_cost.
        output_dir: Directory in which "icml21_output.txt" is placed; the same
            path is reused for every grid point.
        repeated_times: Accepted for interface compatibility but not used;
            every grid point is evaluated exactly once (see note below).

    Returns:
        (best_beta, best_lambda, (mean_cost, std_cost)) for the best grid point.
    """
    step = 0.1
    grid_size = int(1 / step) + 1

    # The original implementation overwrote `repeated_times` with 1 inside the
    # grid loop, so each setting has always been run once. That behaviour is
    # kept (121 grid points are evaluated per call) but made explicit.
    # NOTE(review): presumably a runtime trade-off — confirm before honouring
    # the caller's repeated_times here.
    runs_per_setting = 1

    output_file = os.path.join(output_dir, "icml21_output.txt")

    best_beta = 0
    best_lambda = 0
    # Start from +inf rather than a fixed large number, so that a real result
    # is always selected even when every cost exceeds that number.
    best_cost = (float("inf"), 0)

    for beta in range(grid_size):
        beta_value = beta * step
        for lambda_ in range(grid_size):
            lambda_value = lambda_ * step

            cost = []
            for seed in range(runs_per_setting):
                random.seed(seed)
                np.random.seed(seed)

                alg = AgreementBasedAlgorithm(num_of_nodes, beta_value, lambda_value)
                alg.run(streaming_file, output_file)
                cost.append(alg_pay_cost(num_of_nodes, streaming_file, output_file))

            mean_cost = np.mean(cost)
            # Strict comparison: on ties the earlier grid point wins.
            if mean_cost < best_cost[0]:
                best_beta = beta_value
                best_lambda = lambda_value
                best_cost = (mean_cost, np.std(cost))
    return best_beta, best_lambda, best_cost

def eval_synthetic_datasets(data_dir, output_dir, repeated_times):
    """
    Evaluate our proposed algorithm in insertion-only streams on synthetic datasets.

    For every (number of nodes, probability) combination this computes the
    ground-truth cost, the ICML'21 baseline, the pivot baseline and our
    insertion algorithm, then appends one CSV row per (k, perturbation) to
    "<output_dir>/SBM_statistics.csv". Per-run outputs go under
    "<output_dir>/sbm/nodes_<n>/prob_<p>".

    Args:
        data_dir: Root of the SBM data; each instance is expected under
            "nodes_<n>/prob_<p>/" with "streaming.txt" and "ground_truth.txt".
        output_dir: Directory for the statistics file and per-run outputs.
            It is created if missing.
        repeated_times: Number of repetitions forwarded to the run_* helpers.
    """
    num_of_nodes = [1000]
    probability_list = [0.95]
    # SBM n = 1000, k = 20, pertubation_list = [0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15]
    pertubation_list = [0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15]
    k_list = [20]

    # The statistics file is written before any per-instance directory is
    # created, so make sure its parent exists first.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    statistics_file = os.path.join(output_dir, "SBM_statistics.csv")
    with open(statistics_file, "w") as f:
        f.write("num_of_nodes, ground_truth_cost, icml21_beta, icml21_lambda, icml21_alg_mean, icml21_alg_std, k, pivot_alg_mean, pivot_alg_std, our_alg_beta, our_alg_mean, our_alg_std\n")

    for num in num_of_nodes:
        for probability in probability_list:
            print(f"\n\n Evaluating on the SBM dataset with {num} nodes.")
            streaming_file = os.path.join(data_dir, f"nodes_{num}/prob_{probability}/streaming.txt")
            run_dir = os.path.join(output_dir, f"sbm/nodes_{num}/prob_{probability}")
            os.makedirs(run_dir, exist_ok=True)

            ground_truth_file = os.path.join(data_dir, f"nodes_{num}/prob_{probability}/ground_truth.txt")
            ground_truth_cost = calculate_cost(streaming_file, ground_truth_file)
            print("Ground truth cost calculated.")

            icml21_beta, icml21_lambda, icml21_cost = run_icml21_alg(num, streaming_file, run_dir, repeated_times)
            print("ICML21 cost calculated.")

            pivot_alg_cost = run_insertion_pivot_alg(num, streaming_file, k_list, run_dir, repeated_times)
            print("PivotAlg cost calculated.")

            # Prediction files live next to the instance's streaming file.
            prediction_dic = os.path.join(data_dir, f"nodes_{num}/prob_{probability}")
            our_alg_beta, our_alg_cost = run_our_insertion_alg(num, streaming_file, ground_truth_file, prediction_dic, pertubation_list, None, k_list, run_dir, repeated_times)
            print("OurInsertionAlg cost calculated.")

            # Append after each instance so finished results survive a later failure.
            with open(statistics_file, "a") as f:
                for k in k_list:
                    for pertubation in pertubation_list:
                        f.write(f"{num}, {ground_truth_cost}, {icml21_beta}, {icml21_lambda}, {icml21_cost[0]}, {icml21_cost[1]}, {k}, {pivot_alg_cost[k][0]}, {pivot_alg_cost[k][1]}, {our_alg_beta[k][pertubation]}, {our_alg_cost[k][pertubation][0]}, {our_alg_cost[k][pertubation][1]}\n")
            
            
def eval_facebook_datasets(data_dir, output_dir, repeated_times):
    """
    Evaluate our proposed algorithm in insertion-only streams on Facebook datasets.

    For every dataset this computes the cost of the stored optimal solution,
    the ICML'21 baseline, the pivot baseline and our insertion algorithm, then
    appends one CSV row per (k, perturbation) to
    "<output_dir>/facebook_statistics.csv". Per-run outputs go under
    "<output_dir>/facebook/<dataset>".

    Args:
        data_dir: Root of the Facebook data; each dataset directory is
            expected to hold "streaming.txt" (whose first line is the number
            of nodes) and "OPT_sol.txt".
        output_dir: Directory for the statistics file and per-run outputs.
            It is created if missing.
        repeated_times: Number of repetitions forwarded to the run_* helpers.
    """
    datasets = ["facebook3980"]
    # facebook0, k=50, pertubation_list = [0.002, 0.004, 0.006, 0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02]
    # facebook414, k=30, pertubation_list = [0.005, 0.006, 0.007, 0.008, 0.009, 0.01, 0.011, 0.012, 0.013, 0.014]
    # facebook3980, k=10, pertubation_list = [0.005, 0.006, 0.007, 0.008, 0.009, 0.01, 0.011, 0.012, 0.013, 0.014]
    # NOTE(review): the active list below matches the facebook0 entry above,
    # while the active dataset is facebook3980 — confirm that prediction files
    # exist for these values before relying on the results.
    pertubation_list = [0.002, 0.004, 0.006, 0.008, 0.01, 0.012, 0.014, 0.016, 0.018, 0.02]
    k_list = [10]

    # The statistics file is written before any per-dataset directory is
    # created, so make sure its parent exists first.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    statistics_file = os.path.join(output_dir, "facebook_statistics.csv")
    with open(statistics_file, "w") as f:
        f.write("dataset, num_of_nodes, opt_cost, icml21_beta, icml21_lambda, icml21_alg_mean, icml21_alg_std, k, pivot_alg_mean, pivot_alg_std, our_alg_beta, our_alg_mean, our_alg_std\n")

    for dataset in datasets:
        print(f"\n\n Evaluating on the {dataset} dataset.")
        streaming_file = os.path.join(data_dir, f"{dataset}/streaming.txt")
        run_dir = os.path.join(output_dir, f"facebook/{dataset}")
        os.makedirs(run_dir, exist_ok=True)

        opt_file = os.path.join(data_dir, f"{dataset}/OPT_sol.txt")
        opt_cost = calculate_cost(streaming_file, opt_file)
        print("Optimal cost calculated.")

        # The first line of the streaming file carries the number of nodes.
        with open(streaming_file, "r") as f:
            num = int(f.readline())

        icml21_beta, icml21_lambda, icml21_cost = run_icml21_alg(num, streaming_file, run_dir, repeated_times)
        print("ICML21 cost calculated.")

        pivot_alg_cost = run_insertion_pivot_alg(num, streaming_file, k_list, run_dir, repeated_times)
        print("PivotAlg cost calculated.")

        # Prediction files live in the dataset's own data directory; the
        # optimal solution plays the role of the ground truth for beta.
        prediction_dic = os.path.join(data_dir, f"{dataset}")
        our_alg_beta, our_alg_cost = run_our_insertion_alg(num, streaming_file, opt_file, prediction_dic, pertubation_list, None, k_list, run_dir, repeated_times)
        print("OurInsertionAlg cost calculated.")

        # Append after each dataset so finished results survive a later failure.
        with open(statistics_file, "a") as f:
            for k in k_list:
                for pertubation in pertubation_list:
                    f.write(f"{dataset}, {num}, {opt_cost}, {icml21_beta}, {icml21_lambda}, {icml21_cost[0]}, {icml21_cost[1]}, {k}, {pivot_alg_cost[k][0]}, {pivot_alg_cost[k][1]}, {our_alg_beta[k][pertubation]}, {our_alg_cost[k][pertubation][0]}, {our_alg_cost[k][pertubation][1]}\n")


def eval_email_lastfm_datasets(data_dir, output_dir, repeated_times):
    """
    Evaluate our proposed algorithm in insertion-only streams on EmailCore and LastFM datasets.

    For every dataset this runs the ICML'21 baseline, the pivot baseline and
    our insertion algorithm (with "prediction_se_<num>.txt" predictions), then
    prints and appends one CSV row per (k, cluster count) to
    "<output_dir>/email_lastfm_statistics.csv". Per-run outputs go under
    "<output_dir>/<dataset>".

    Args:
        data_dir: Root of the data; each dataset directory is expected to hold
            "streaming.txt", whose first line is the number of nodes.
        output_dir: Directory for the statistics file and per-run outputs.
            It is created if missing.
        repeated_times: Number of repetitions forwarded to the run_* helpers.
    """
    # datasets = ["emailcore", "lastfm"]
    # EmailCore: k = 20, num_clusters_list = [600, 650, 700, 750, 800, 850, 900, 950, 1000]
    # LastFM: k = 100, num_clusters_list = [2000, 2500, 3000, 3500, 4000]
    datasets = ["lastfm"]
    num_clusters_list = [3000]
    k_list = [100]

    # The statistics file is written before any per-dataset directory is
    # created, so make sure its parent exists first.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    statistics_file = os.path.join(output_dir, "email_lastfm_statistics.csv")
    with open(statistics_file, "w") as f:
        f.write("dataset, num_of_nodes, icml21_beta, icml21_lambda, icml21_alg_mean, icml21_alg_std, k, pivot_alg_mean, pivot_alg_std, cluster_num, our_alg_mean, our_alg_std\n")

    for dataset in datasets:
        print(f"\n\n Evaluating on the {dataset} dataset.")
        streaming_file = os.path.join(data_dir, f"{dataset}/streaming.txt")
        run_dir = os.path.join(output_dir, f"{dataset}")
        os.makedirs(run_dir, exist_ok=True)

        # The first line of the streaming file carries the number of nodes.
        with open(streaming_file, "r") as f:
            vertices_num = int(f.readline())

        icml21_beta, icml21_lambda, icml21_cost = run_icml21_alg(vertices_num, streaming_file, run_dir, repeated_times)
        print("ICML21 cost calculated.")

        pivot_alg_cost = run_insertion_pivot_alg(vertices_num, streaming_file, k_list, run_dir, repeated_times)
        print("PivotAlg cost calculated.")

        # Passing a cluster-count list (and no perturbation list) selects the
        # "prediction_se_<num>.txt" mode of run_our_insertion_alg.
        prediction_dic = os.path.join(data_dir, f"{dataset}")
        cluster_num, our_alg_cost = run_our_insertion_alg(vertices_num, streaming_file, None, prediction_dic, None, num_clusters_list, k_list, run_dir, repeated_times)
        print("OurInsertionAlg cost calculated.")
        for k in k_list:
            for num in num_clusters_list:
                print(f"{k}, {cluster_num[k][num]}, {our_alg_cost[k][num][0]}, {our_alg_cost[k][num][1]}\n")

        # Append after each dataset so finished results survive a later failure.
        with open(statistics_file, "a") as f:
            for k in k_list:
                for num in num_clusters_list:
                    f.write(f"{dataset}, {vertices_num}, {icml21_beta}, {icml21_lambda}, {icml21_cost[0]}, {icml21_cost[1]}, {k}, {pivot_alg_cost[k][0]}, {pivot_alg_cost[k][1]}, {cluster_num[k][num]}, {our_alg_cost[k][num][0]}, {our_alg_cost[k][num][1]}\n")


def run_dynamic_alg(data_dir, prediction_dir, output_dir, repeated_times):
    """
    Evaluate our proposed algorithm in dynamic streams.

    For each dataset the function writes, to the statistics CSV:
      1. one row for the ICML'21 baseline (best grid-search result),
      2. one row for OurDynamicAlgorithm with approx_ratio 3 and no predictor,
      3. for every cluster count, one row for approx_ratio 2.06 with the
         "prediction_se_<cluster_num>.txt" predictor, followed by one "min"
         row taking the per-run minimum of (2) and (3).

    Note: 
    This function can only evaluate on **one dataset at a time**.
    You need to **comment or uncomment** the corresponding code blocks manually
    depending on which dataset you want to evaluate.

    Args:
        data_dir: Root directory of the datasets (with or without a trailing
            separator).
        prediction_dir: Root directory of the prediction files.
        output_dir: Directory for per-run algorithm outputs; created if missing.
        repeated_times: Number of seeded repetitions per configuration.
    """
    # SBM synthetic datasets
    # perturbation_list = [0.1, 0.12, 0.14, 0.16, 0.18, 0.2, 0.22, 0.24, 0.26, 0.28]
    # num_of_nodes = [100]
    
    # Facebook datasets
    # datasets = ["facebook414"]
    # perturbation_list = [0.005, 0.006, 0.007, 0.008, 0.009, 0.01, 0.011, 0.012, 0.013, 0.014]

    # datasets = ["emailcore"]
    # num_clusters_list = [600, 650, 700, 750, 800, 850, 900, 950, 1000]

    datasets = ["lastfm"]
    num_clusters_list = [7000, 7100, 7200, 7300, 7400, 7500, 7600]

    # statistics_file = f"./results/statistics_dynamic_vary_bata.csv"
    # NOTE: the statistics path is fixed and does not depend on output_dir.
    statistics_file = "./results/dynamic_lastfm.csv"

    # Both the statistics file and run_icml21_alg's output file are written
    # before any dataset sub-directory is created, so create their parents now.
    os.makedirs(os.path.dirname(statistics_file), exist_ok=True)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(statistics_file, "w") as f:
        # f.write(f"dataset, nodes, OPT_cost, repeated_times, perturbation, beta, approx_ratio, average_cost, deviation\n")
        f.write("dataset, nodes, repeated_times, cluster_num, approx_ratio, average_cost, deviation\n")

    # for num in num_of_nodes:
    for dataset in datasets:
        # graph_file = os.path.join(data_dir, f"sbm/nodes_{num}/prob_0.7/edges.txt")
        # graph_file = os.path.join(data_dir, f"facebook/{dataset}/edges.txt")
        # graph_file = os.path.join(data_dir, f"{dataset}/email-Eu-core.txt")
        graph_file = os.path.join(data_dir, f"{dataset}/lastfm_asia_edges.csv")
        # NOTE(review): read_edgelist is called without a delimiter, so the
        # file is presumably whitespace-separated despite its .csv name — confirm.
        original_graph = nx.read_edgelist(graph_file, create_using=nx.Graph, nodetype=int)
        # Degrees are looked up by index, which assumes the node labels are
        # exactly 0..n-1 — TODO confirm for every dataset.
        degrees = [original_graph.degree[i] for i in range(original_graph.number_of_nodes())]

        # streaming_file = os.path.join(data_dir, f"sbm/nodes_{num}/prob_0.7/streaming.txt")
        # streaming_file = os.path.join(data_dir, f"facebook/{dataset}/streaming.txt")
        streaming_file = os.path.join(data_dir, f"{dataset}/streaming.txt")
        # The first line of the streaming file carries the number of nodes.
        with open(streaming_file, "r") as f:
            lines = f.readline()
            num = int(lines.strip())
        
        # ground_truth_file = os.path.join(data_dir, f"sbm/nodes_{num}/prob_0.95/ground_truth.txt")
        # ground_truth_cost = calculate_cost(streaming_file, ground_truth_file)
        # OPT_file = os.path.join(data_dir, f"sbm/nodes_{num}/prob_0.7/opt_solution.txt")
        # OPT_cost = calculate_cost(streaming_file, OPT_file)

        icml21_beta, icml21_lambda, icml21_cost = run_icml21_alg(num, streaming_file, output_dir, repeated_times)
        print("ICML21 cost calculated.")
        print(f"ICML21 beta: {icml21_beta}, ICML21 lambda: {icml21_lambda}, {icml21_cost[0]}, {icml21_cost[1]}\n")
        with open(statistics_file, "a") as f:
            # The baseline row records repeated_times as 1 and approx_ratio as 701.
            # f.write(f"synthetic, {num}, {OPT_cost}, 1, N/A, N/A, 701, {icml21_cost[0]}, {icml21_cost[1]}\n")
            f.write(f"{dataset}, {num}, 1, N/A, 701, {icml21_cost[0]}, {icml21_cost[1]}\n")

        # Per-dataset output directory, shared by both phases below.
        # dataset_out_dir = os.path.join(output_dir, f"sbm/nodes_{num}/prob_0.7")
        # dataset_out_dir = os.path.join(output_dir, f"facebook/{dataset}")
        dataset_out_dir = os.path.join(output_dir, f"{dataset}")
        os.makedirs(dataset_out_dir, exist_ok=True)

        # Phase 1: approx_ratio 3 without a predictor.
        approx_ratio = 3
        predictor = None

        cost_3 = []
        output_file = os.path.join(dataset_out_dir, f"dynamic_approx_{approx_ratio}.txt")
        for seed in range(repeated_times):
            random.seed(seed)
            np.random.seed(seed)
            alg = OurDynamicAlgorithm(num, degrees, predictor)
            alg.run_approx_algorithm(original_graph, streaming_file, approx_ratio, output_file)
            cost_3.append(alg_pay_cost(num, streaming_file, output_file))
            
        with open(statistics_file, "a") as f:
            # f.write(f"synthetic, {num}, {ground_truth_cost}, {repeated_times}, N/A, N/A, {approx_ratio}, {np.mean(cost_3)}, {np.std(cost_3)}\n")
            # f.write(f"synthetic, {num}, {OPT_cost}, {repeated_times}, N/A, N/A, {approx_ratio}, {np.mean(cost_3)}, {np.std(cost_3)}\n")
            f.write(f"{dataset}, {num}, {repeated_times}, N/A, {approx_ratio}, {np.mean(cost_3)}, {np.std(cost_3)}\n")

        # Phase 2: approx_ratio 2.06 with one predictor per cluster count.
        approx_ratio = 2.06
        # for perturbation in perturbation_list:
        for cluster_num in num_clusters_list:
            # prediction_file = os.path.join(prediction_dir, f"sbm/nodes_{num}/prob_0.7/prediction_opt_{perturbation}.txt")
            # prediction_file = os.path.join(prediction_dir, f"facebook/{dataset}/prediction_opt_{perturbation}.txt")
            prediction_file = os.path.join(prediction_dir, f"{dataset}/prediction_se_{cluster_num}.txt")
            predictor = load_prediction(prediction_file)
            # beta = calculate_beta(streaming_file, ground_truth_file, prediction_file)
            # beta = calculate_beta(streaming_file, OPT_file, prediction_file)

            cost_2 = []
            # output_file = os.path.join(dataset_out_dir, f"dynamic_num_{num}_approx_{approx_ratio}_perturbation_{perturbation}.txt")
            # output_file = os.path.join(dataset_out_dir, f"dynamic_approx_{approx_ratio}_perturbation_{perturbation}.txt")
            output_file = os.path.join(dataset_out_dir, f"approx_{approx_ratio}_cluster_num_{cluster_num}.txt")
            for seed in range(repeated_times):
                random.seed(seed)
                np.random.seed(seed)
                alg = OurDynamicAlgorithm(num, degrees, predictor)
                alg.run_approx_algorithm(original_graph, streaming_file, approx_ratio, output_file)
                cost_2.append(alg_pay_cost(num, streaming_file, output_file))

            with open(statistics_file, "a") as f:
                # f.write(f"synthetic, {num}, {ground_truth_cost}, {repeated_times}, {perturbation}, {beta}, {approx_ratio}, {np.mean(cost_2)}, {np.std(cost_2)}\n")
                # f.write(f"synthetic, {num}, {OPT_cost}, {repeated_times}, {perturbation}, {beta}, {approx_ratio}, {np.mean(cost_2)}, {np.std(cost_2)}\n")
                f.write(f"{dataset}, {num}, {repeated_times}, {cluster_num}, {approx_ratio}, {np.mean(cost_2)}, {np.std(cost_2)}\n")
        
            # Runs with the same index share a seed, so pair them up and keep
            # the cheaper of the two solutions for each run.
            cost_min = [min(a, b) for a, b in zip(cost_3, cost_2)]
            with open(statistics_file, "a") as f:
                # f.write(f"synthetic, {num}, {ground_truth_cost}, {repeated_times}, {perturbation}, {beta}, min, {np.mean(cost_min)}, {np.std(cost_min)}\n")
                # f.write(f"synthetic, {num}, {OPT_cost}, {repeated_times}, {perturbation}, {beta}, min, {np.mean(cost_min)}, {np.std(cost_min)}\n")
                f.write(f"{dataset}, {num}, {repeated_times}, {cluster_num}, min, {np.mean(cost_min)}, {np.std(cost_min)}\n")


if __name__ == "__main__":
    # Settings shared by every experiment below.
    repeated_times = 20
    output_dir = "./results/"

    # Dynamic streams: data and predictions are read from the same root.
    run_dynamic_alg("../data/", "../data/", output_dir, repeated_times)

    # Insertion-only streams: each evaluator is paired with its data root and
    # executed in the listed order.
    insertion_only_evaluations = (
        ("../data/sbm", eval_synthetic_datasets),
        ("../data/facebook", eval_facebook_datasets),
        ("../data/", eval_email_lastfm_datasets),
    )
    for data_dir, evaluate in insertion_only_evaluations:
        evaluate(data_dir, output_dir, repeated_times)