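# File: utils.py
# Description: Utility classes and functions shared by the correlation-clustering algorithms:
# a bounded priority queue, a random vertex-ID permutation, pivot post-processing, and cost evaluation.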


import heapq
import random
import networkx as nx
import numpy as np


class TruncatedPriorityQueue:
    """
    A bounded priority queue that stores at most ``maxsize`` items.

    Smaller priority values are treated as higher priority. Once the queue
    is full, each push discards whichever entry (stored or incoming) has the
    largest priority value. The queue therefore always holds the ``maxsize``
    items with the smallest priority values pushed so far.
    """

    def __init__(self, maxsize):
        self.maxsize = maxsize
        # Min-heap of (-priority, item) tuples: the root is the entry with the
        # LARGEST priority value, i.e. the next candidate for eviction.
        self.queue = []

    def push(self, item, priority):
        """Insert ``item``; when full, drop the entry with the largest priority value."""
        entry = (-priority, item)
        if len(self.queue) < self.maxsize:
            heapq.heappush(self.queue, entry)
        else:
            # Push-then-pop in one step: removes the largest priority value
            # among the stored entries and the incoming one.
            heapq.heappushpop(self.queue, entry)

    def pop_all(self):
        """
        Remove and return every stored item, ordered by ascending priority value.

        The queue is empty afterwards.
        """
        items = []
        while self.queue:
            items.append(heapq.heappop(self.queue)[1])
        # Heap pops run from the largest to the smallest priority value;
        # reverse so the smallest (highest-priority) item comes first.
        items.reverse()
        return items

    def num_elements(self):
        """Return the number of items currently stored."""
        return len(self.queue)
    

class Vertexpermutator:
    """
    Random relabelling of the vertex IDs ``0 .. vertices_num - 1``.

    ``permutation[v]`` is the permuted ID of original vertex ``v``, and
    ``unpermutation`` is the inverse table, so that
    ``get_unpermuted_vertexID(get_permuted_vertexID(v)) == v``.
    """

    def __init__(self, vertices_num):
        self.permutation = np.random.permutation(np.arange(vertices_num))
        # Inverse lookup; it has one spare trailing slot (left as 0) past the
        # last valid permuted ID.
        inverse = [0] * (vertices_num + 1)
        for original_id, permuted_id in enumerate(self.permutation):
            inverse[permuted_id] = original_id
        self.unpermutation = inverse

    def get_permuted_vertexID(self, vertexID):
        """Map an original vertex ID to its permuted ID."""
        return self.permutation[vertexID]

    def get_unpermuted_vertexID(self, permuted_vertexID):
        """Map a permuted ID back to the original vertex ID."""
        return self.unpermutation[permuted_vertexID]
    

def post_process(vertices_num, priority_queue, permutator, output_file):
    """
    Select a pivot for each vertex and write the assignment to ``output_file``.

    Vertices are visited in ID order ``0 .. vertices_num - 1``. Presumably
    these are permuted IDs, since output goes through
    ``get_unpermuted_vertexID``. Each vertex scans the candidates returned
    by its queue's ``pop_all()`` in order:

    - it joins the first candidate that is already a pivot;
    - if it reaches itself first, it becomes a new pivot;
    - if neither happens, it is written as its own cluster but is NOT
      recorded as a pivot, so later vertices cannot join it.

    Args:
        vertices_num: number of vertices to process.
        priority_queue: indexable by vertex ID; ``priority_queue[i].pop_all()``
            yields vertex ``i``'s candidates (and empties that queue).
        permutator: provides ``get_unpermuted_vertexID`` used to translate
            IDs back for output.
        output_file: path overwritten with one ``"<vertex> <pivot>"`` line
            per vertex, in unpermuted IDs.
    """
    # A set gives O(1) membership tests in the inner loop (was a list scan).
    pivots = set()
    unpermute = permutator.get_unpermuted_vertexID
    with open(output_file, "w") as f:
        for i in range(vertices_num):
            pivot = i  # fallback: singleton cluster, not registered as a pivot
            for j in priority_queue[i].pop_all():
                if j in pivots:
                    pivot = j
                    break
                if j == i:
                    pivots.add(i)
                    break
            f.write(f"{unpermute(i)} {unpermute(pivot)}\n")


def load_prediction(prediction_file):
    """
    Load a prediction file into an undirected NetworkX graph.

    Each non-blank line is ``"u v value"``. Edge ``(u, v)`` is added with
    the edge attribute ``prediction=float(value)``. A repeated pair updates
    the existing edge's attribute, per ``nx.Graph.add_edge``. Blank lines,
    such as a trailing empty line at the end of the file, are skipped
    instead of raising ``IndexError``.

    Returns:
        nx.Graph: graph whose edges carry the ``prediction`` attribute.
    """
    predictor = nx.Graph()
    with open(prediction_file, "r") as f:
        for raw_line in f:
            fields = raw_line.split()
            if not fields:
                continue  # tolerate blank lines
            u, v, pred_val = int(fields[0]), int(fields[1]), float(fields[2])
            predictor.add_edge(u, v, prediction=pred_val)
    return predictor


def alg_pay_cost(vertices_num, streaming_file, output_file):
    """
    Compute the Correlation Clustering cost of a pivot assignment.

    Inputs:

    - ``output_file`` holds ``"<vertex> <pivot>"`` lines. Two vertices are
      in the same cluster iff they map to the same pivot.
    - ``streaming_file`` has a header line, which is skipped, followed by
      ``"u v label"`` lines.

    Costing: a ``"-"`` pair inside one cluster costs 1, and a ``"+"`` pair
    split across clusters costs 1. Any other combination or label costs
    nothing. Blank lines in either file are ignored.

    NOTE(review): vertices missing from ``output_file`` keep the placeholder
    pivot -1 and are therefore treated as one shared cluster.

    Returns:
        int: number of disagreeing pairs.
    """
    which_pivot = [-1] * vertices_num
    with open(output_file, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            u, pivot = fields
            which_pivot[int(u)] = int(pivot)

    cost = 0
    with open(streaming_file, "r") as f:
        f.readline()  # skip header line
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            u, v, label = fields
            same_cluster = which_pivot[int(u)] == which_pivot[int(v)]
            if (same_cluster and label == "-") or (not same_cluster and label == "+"):
                cost += 1
    return cost


def calculate_cost(streaming_file, complete_graph_relation_file):
    """
    Compute the fractional Correlation Clustering cost.

    Inputs:

    - ``streaming_file`` has a header line, which is skipped, followed by
      ``"u v label"`` lines. Only ``"+"`` pairs are recorded, as unordered
      pairs.
    - ``complete_graph_relation_file`` has ``"u v x_uv"`` lines. A recorded
      ``"+"`` pair contributes ``x_uv``; every other pair contributes
      ``1 - x_uv``.

    Blank lines in either file are ignored.

    Returns:
        The summed cost: a float, or int 0 if the relation file lists no pairs.
    """
    # Canonical (min, max) tuples give the same undirected membership test as
    # a networkx Graph's has_edge, without building a full graph structure.
    positive_pairs = set()
    with open(streaming_file, "r") as f:
        f.readline()  # skip header line
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            u, v, label = fields
            if label == "+":
                a, b = int(u), int(v)
                positive_pairs.add((min(a, b), max(a, b)))

    cost = 0
    with open(complete_graph_relation_file, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines
            u, v, x_uv = fields
            a, b = int(u), int(v)
            if (min(a, b), max(a, b)) in positive_pairs:
                cost += float(x_uv)
            else:
                cost += 1 - float(x_uv)
    return cost


def calculate_beta(streaming_file, OPT_file, prediction_file):
    """
    Return the beta ratio, i.e. the prediction cost divided by the optimal cost.

    Both costs are fractional Correlation Clustering costs from
    ``calculate_cost``, evaluated against the same ``streaming_file``. The
    optimal cost is computed first. An optimal cost of 0 raises
    ``ZeroDivisionError``.
    """
    opt_cost = calculate_cost(streaming_file, OPT_file)
    prediction_cost = calculate_cost(streaming_file, prediction_file)
    return prediction_cost / opt_cost