# File: ourAlg.py
# This file contains the implementation of our algorithm, whose approximation ratio is (min{2.06beta, 3} + \eps).

from utils import *

class OurStreamingAlgorithm:
    '''
    Single-pass driver that maintains two families of per-vertex truncated
    priority queues while reading a stream of labelled edges:

    * ``Q`` - receives both endpoints of every edge labelled '+'.
    * ``P`` - receives both endpoints of an edge with probability
      ``1 - f(d_uv, label)``, where ``d_uv`` is read from the predictor.

    Each family is handed to ``post_process`` separately, producing one
    output file per sub-algorithm.
    '''

    # Breakpoints of the piecewise function ``f`` used for '+' edges:
    # f is 0 below _F_LOWER, 1 at or above _F_UPPER, and a squared linear
    # ramp in between.
    _F_LOWER = 0.19
    _F_UPPER = 0.5095

    def __init__(self, vertices_num, k, predictor):
        '''
        Create the two queue families, one queue per vertex in each.

        vertices_num: number of vertices; queues are indexed 0..vertices_num-1.
        k: size parameter; ``k / 2`` is passed to every TruncatedPriorityQueue.
        predictor: object whose ``edges[u, v]['prediction']`` is read by ``predict``.
        '''
        self.vertices_num = vertices_num
        self.predictor = predictor

        # NOTE(review): ``k / 2`` is true division, so the queues receive a
        # float - confirm TruncatedPriorityQueue accepts a non-integer size.

        # Truncated priority queues of each vertex for the pivot algorithm.
        # Every vertex starts out in its own queue with priority (index + 1).
        self.Q = [TruncatedPriorityQueue(k / 2) for _ in range(vertices_num)]
        for i in range(vertices_num):
            self.Q[i].push(item=i, priority=i + 1)

        # Truncated priority queues of each vertex for the algorithm with
        # predictions, seeded the same way.
        self.P = [TruncatedPriorityQueue(k / 2) for _ in range(vertices_num)]
        for i in range(vertices_num):
            self.P[i].push(item=i, priority=i + 1)

    def run(self, streaming_edges_file, output_file_sub_alg1, output_file_sub_alg2):
        '''
        Draw a vertex permutation, consume the edge stream, then post-process
        both queue families.

        streaming_edges_file: path of the edge file (see ``read_stream``).
        output_file_sub_alg1: output path passed to ``post_process`` with ``Q``.
        output_file_sub_alg2: output path passed to ``post_process`` with ``P``.
        '''
        # Permute the vertices to get a random order
        self.permutator = Vertexpermutator(self.vertices_num)

        # Read the streaming edges and update the truncated priority queues
        self.read_stream(streaming_edges_file)

        # Post-process and output the final clustering result
        post_process(self.vertices_num, self.Q, self.permutator, output_file_sub_alg1)
        post_process(self.vertices_num, self.P, self.permutator, output_file_sub_alg2)

    def read_stream(self, streaming_edges_file):
        '''
        Read the edge file once and update both queue families.

        The first line of the file is skipped. Every following non-blank line
        must hold three whitespace-separated fields ``u v label`` with integer
        ``u`` and ``v``; a line with a different number of fields raises
        ValueError. Blank lines and self-loops (``u == v``) are ignored.

        Requires ``self.permutator`` to be set (``run`` does this).
        '''
        with open(streaming_edges_file, 'r') as f:
            f.readline()  # skip the first line of the file
            for line in f:
                fields = line.split()
                if not fields:
                    # Blank line: nothing to unpack, so skip it instead of
                    # failing with ValueError.
                    continue

                raw_u, raw_v, label = fields
                raw_u, raw_v = int(raw_u), int(raw_v)
                if raw_u == raw_v:
                    # ``predict`` has no value for u == v, and every queue was
                    # already seeded with its own vertex in __init__, so a
                    # self-loop adds nothing.
                    continue

                u = self.permutator.get_permuted_vertexID(raw_u)
                v = self.permutator.get_permuted_vertexID(raw_v)

                if label == '+':
                    self.Q[u].push(item=v, priority=v + 1)
                    self.Q[v].push(item=u, priority=u + 1)

                # One uniform draw per edge decides whether the edge is kept
                # for the prediction-based queues.
                # NOTE(review): ``random`` is not imported in this file directly;
                # presumably it arrives via ``from utils import *`` - confirm.
                random_number = random.random()
                if random_number < self.probability(self.predict(u, v), label):
                    self.P[u].push(item=v, priority=v + 1)
                    self.P[v].push(item=u, priority=u + 1)

    def predict(self, u, v):
        '''
        Predict the value d_uv of the edge (u, v).

        ``u`` and ``v`` are permuted vertex IDs; they are mapped back through
        the permutator before the predictor is queried. Returns None when both
        map to the same vertex.
        '''
        u = self.permutator.get_unpermuted_vertexID(u)
        v = self.permutator.get_unpermuted_vertexID(v)
        if u == v:
            return None
        return self.predictor.edges[u, v]['prediction']

    def probability(self, d_uv, label):
        '''
        Compute the value of the probability of u, v being in the same cluster for a given label and predicted dissimilarity d_uv.

        Returns ``1 - f(d_uv, label)``.
        '''
        return 1 - self.f(d_uv, label)

    def f(self, d_uv, label):
        '''
        Compute the value of the function f(d_uv, label) for a given label.

        For label '-' the value is ``d_uv`` itself. For any other label it is
        0 when ``d_uv < _F_LOWER``, 1 when ``d_uv >= _F_UPPER``, and
        ``((d_uv - _F_LOWER) / (_F_UPPER - _F_LOWER)) ** 2`` in between.
        '''
        if label == '-':
            return d_uv

        a = self._F_LOWER
        b = self._F_UPPER
        if d_uv < a:
            return 0
        if d_uv < b:
            return ((d_uv - a) / (b - a)) ** 2
        return 1