# File: our_insertion_alg.py
# Description: Implementation of our proposed learning-augmented algorithm for complete graphs in insertion-only streams.


from utils import *
import random
import time

class OurInsertionAlgorithm:
    """Learning-augmented clustering of a complete graph from an insertion-only edge stream.

    Two families of per-vertex truncated priority queues are maintained in
    parallel, one per sub-algorithm:

    * ``Q[u]`` receives every neighbour ``v`` joined to ``u`` by a ``"+"`` edge.
    * ``P[u]`` receives ``v`` with probability ``1 - f(d_uv, label)``, where
      ``d_uv`` is the predictor's value for the edge.

    Every queue is seeded with its own vertex, and items are pushed with
    priority ``vertex_id + 1`` (using permuted vertex IDs). Each family is
    turned into a clustering by ``post_process`` in :meth:`run`.
    """

    # Breakpoints of the piecewise function f applied to "+" edges.
    F_LOWER = 0.19
    F_UPPER = 0.5095

    def __init__(self, vertices_num, k, predictor):
        """
        Args:
            vertices_num: number of vertices; IDs are 0 .. vertices_num - 1.
            k: each queue is constructed with capacity argument ``k / 2``.
            predictor: object whose ``edges[u, v]["prediction"]`` yields d_uv
                for original (unpermuted) vertex IDs.
        """
        self.vertices_num = vertices_num
        self.predictor = predictor

        # NOTE(review): k / 2 is true division (may be a float); confirm that
        # TruncatedPriorityQueue accepts a non-integer capacity.
        self.Q = self._seeded_queues(vertices_num, k / 2)
        self.P = self._seeded_queues(vertices_num, k / 2)

    @staticmethod
    def _seeded_queues(vertices_num, capacity):
        """Create one queue per vertex, each pre-seeded with its own vertex."""
        queues = [TruncatedPriorityQueue(capacity) for _ in range(vertices_num)]
        for vertex, queue in enumerate(queues):
            queue.push(item=vertex, priority=vertex + 1)
        return queues

    def run(self, streaming_edges_file, output_file_sub_alg1, output_file_sub_alg2):
        """Process the stream and write both clusterings.

        For each sub-algorithm (``Q`` -> ``output_file_sub_alg1``, ``P`` ->
        ``output_file_sub_alg2``), prints the stream-processing time plus that
        sub-algorithm's own post-processing time.
        """
        # Generate a random permutation of the vertices.
        self.permutator = Vertexpermutator(self.vertices_num)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        stream_start = time.perf_counter()
        self.read_stream(streaming_edges_file)
        stream_elapsed = time.perf_counter() - stream_start

        for queues, output_file in (
            (self.Q, output_file_sub_alg1),
            (self.P, output_file_sub_alg2),
        ):
            post_start = time.perf_counter()
            post_process(self.vertices_num, queues, self.permutator, output_file)
            post_elapsed = time.perf_counter() - post_start
            print(f"{stream_elapsed + post_elapsed}\n")

    def read_stream(self, streaming_edges_file):
        """Consume the edge stream and update both queue families.

        The first line of the file is skipped as a header. Every other
        non-blank line must be ``"<u> <v> <label>"``. Blank lines and self-loops
        are ignored: each vertex is already seeded in its own queues, and no
        prediction is looked up for a self-pair.

        Raises:
            ValueError: on a line that does not have exactly three fields, or
                whose vertex IDs are not integers.
        """
        permute = self.permutator.get_permuted_vertexID
        with open(streaming_edges_file, "r") as f:
            f.readline()  # header line
            for lineno, line in enumerate(f, start=2):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. a trailing newline
                if len(fields) != 3:
                    raise ValueError(
                        f"Malformed edge line {lineno} in {streaming_edges_file!r}: {line!r}"
                    )
                u_raw, v_raw, label = fields
                u, v = permute(int(u_raw)), permute(int(v_raw))
                if u == v:
                    continue  # self-loop: predict() has no value for it

                if label == "+":
                    self.Q[u].push(item=v, priority=v + 1)
                    self.Q[v].push(item=u, priority=u + 1)

                # NOTE(review): labels other than "+"/"-" are treated like "+"
                # by f() but never reach Q; confirm the input only uses +/-.
                random_number = random.random()
                if random_number < self.probability(self.predict(u, v), label):
                    self.P[u].push(item=v, priority=v + 1)
                    self.P[v].push(item=u, priority=u + 1)

    def predict(self, u, v):
        """Return the predicted value d_uv of the edge (u, v).

        ``u`` and ``v`` are permuted IDs; they are mapped back to original IDs
        before querying the predictor. Returns ``None`` for a self-pair.
        """
        permutator = self.permutator
        u = permutator.get_unpermuted_vertexID(u)
        v = permutator.get_unpermuted_vertexID(v)
        if u == v:
            return None
        return self.predictor.edges[u, v]["prediction"]

    def probability(self, d_uv, label):
        """Probability of placing u, v together: ``1 - f(d_uv, label)``."""
        return 1 - self.f(d_uv, label)

    def f(self, d_uv, label):
        """Piecewise function of the prediction d_uv and the edge label.

        For ``"-"`` edges this returns ``d_uv`` unchanged. For any other label
        it returns 0 below ``F_LOWER``, 1 at or above ``F_UPPER``, and a
        quadratic ramp ``((d_uv - F_LOWER) / (F_UPPER - F_LOWER)) ** 2`` in between.
        """
        if label == "-":
            return d_uv
        lower, upper = self.F_LOWER, self.F_UPPER
        if d_uv < lower:
            return 0
        if d_uv < upper:
            return ((d_uv - lower) / (upper - lower)) ** 2
        return 1