# File: our_dynamic_alg.py
# Description: Implementation of our proposed learning-augmented algorithm for complete graphs in dynamic streams.


import random
import numpy as np
import math
import networkx as nx
from collections import defaultdict

class OurDynamicAlgorithm:
    """
    Learning-augmented clustering of a complete graph given as an edge stream.

    Vertices are ranked by a uniformly random permutation. A vertex is
    "interesting" when its rank is below a degree-dependent threshold tau;
    only "+" stream edges between interesting vertices are stored. All other
    vertices are assigned to clusters during post-processing.

    Two ID spaces are used:
      * original IDs -- 0..vertices_num-1 as they appear in the input files,
        in ``degrees``, ``original_graph`` and ``predictor``;
      * permuted IDs (ranks) -- ``self.permutation[original_id]``. Nodes of
        ``G_store`` and entries of the pivot list are always permuted IDs.
    Everything written to the output file is in original IDs.
    """

    def __init__(self, vertices_num, degrees, predictor):
        """
        Args:
            vertices_num: number of vertices; IDs are 0..vertices_num-1.
            degrees: indexable by original vertex ID; used as the divisor of
                the threshold tau (see run_approx_algorithm).
            predictor: object exposing ``edges[u, v]["prediction"]``, the
                predicted dissimilarity d_uv for original IDs u != v.
        """
        self.vertices_num = vertices_num
        self.degrees = degrees
        self.predictor = predictor

    def run_approx_algorithm(self, original_graph, streaming_edges_file, approx_ratio, output_file,
                             c=50, eps=0.1):
        """
        Run the (approx_ratio)-approximation streaming algorithm.

        Draws a fresh random permutation and marks vertex i as interesting
        when its rank is below
            tau_i = (c / eps) * n * log10(n) / degrees[i].
        It then stores the "+" stream edges between interesting vertices and
        writes one "vertex cluster" line per vertex to ``output_file``.

        Args:
            original_graph: graph over original IDs; used in post-processing
                (its edges count as "+" labels).
            streaming_edges_file: path of the edge stream (see read_stream).
            approx_ratio: 3 selects the pivot branch; any other value selects
                the prediction-based (2.06) branch.
            output_file: path the clustering is written to.
            c, eps: constants of the threshold tau_i (defaults 50 and 0.1).
        """
        n = self.vertices_num
        self.permutation = np.random.permutation(np.arange(n))
        # Inverse permutation: unpermutation[rank] == original ID.
        self.unpermutation = [0] * n
        for original_id, rank in enumerate(self.permutation):
            self.unpermutation[rank] = original_id

        log_n = math.log10(n) if n > 0 else 0.0
        interesting_vertices = []  # original IDs
        tau_list = []              # indexed by original ID
        for i in range(n):
            degree = self.degrees[i]
            # A zero degree would raise ZeroDivisionError; use the deg -> 0
            # limit (infinite threshold), so such a vertex is always interesting.
            tau_i = (c / eps) * (n * log_n / degree) if degree else math.inf
            tau_list.append(tau_i)
            if self.permutation[i] < tau_i:
                interesting_vertices.append(i)

        # Read and process the streaming edges
        G_store = self.read_stream(streaming_edges_file, interesting_vertices)

        # Post-process and write the clustering result
        self.post_process(original_graph, G_store, approx_ratio, tau_list, output_file)

    def permuted_vertexID(self, u):
        """
        Return the permuted vertex ID (rank) of original vertex ``u``.
        """
        return self.permutation[u]

    def unpermuted_vertexID(self, u):
        """
        Return the original vertex ID of permuted vertex ``u``.
        """
        return self.unpermutation[u]

    def read_stream(self, streaming_edges_file, interesting_vertices):
        """
        Read the edge stream and keep the "+" edges between interesting vertices.

        Args:
            streaming_edges_file: path; the first line is skipped as a header,
                and every following non-blank line is "u v label" in original IDs.
            interesting_vertices: iterable of ORIGINAL vertex IDs.

        Returns:
            nx.Graph whose nodes are the permuted IDs of the interesting
            vertices and whose edges are the "+" stream edges between them.
        """
        G_store = nx.Graph()
        # Stream endpoints are converted to permuted IDs below, so membership
        # must be tested against permuted IDs as well (set for O(1) lookups).
        interesting = {self.permuted_vertexID(i) for i in interesting_vertices}
        G_store.add_nodes_from(sorted(interesting))

        with open(streaming_edges_file, "r") as f:
            f.readline()  # skip header line
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                u, v, label = fields
                if label != "+":
                    continue
                u, v = self.permuted_vertexID(int(u)), self.permuted_vertexID(int(v))
                if u in interesting and v in interesting:
                    G_store.add_edge(u, v, label="+")
        return G_store

    def post_process(self, original_graph, G_store, approx_ratio, tau_list, output_file):
        """
        Assign every vertex to a cluster and write the result.

        Interesting vertices (nodes of ``G_store``) are processed first, then
        all remaining vertices, each group in increasing rank order. Each
        output line is "vertex representative" in original IDs. A vertex that
        is its own representative is written as "v v".

        Args:
            original_graph: graph over original IDs (``neighbors``/``has_edge``).
            G_store: graph returned by read_stream (permuted-ID nodes).
            approx_ratio: 3 -> pivot on stored "+" edges; otherwise the
                prediction-based (2.06) branch.
            tau_list: per-vertex thresholds indexed by original ID.
            output_file: path to write.
        """
        stored = set(G_store.nodes)
        # Pivot-style clustering must visit vertices in increasing rank; the
        # node order of G_store is insertion order, not rank order.
        interesting_vertices = sorted(stored)
        uninteresting_vertices = sorted(set(range(self.vertices_num)) - stored)
        pivots = []         # permuted IDs, appended in increasing rank order
        pivot_set = set()   # same contents as `pivots`, for O(1) lookups

        with open(output_file, "w") as f:
            if approx_ratio == 3:
                for i in interesting_vertices:
                    # Join the lowest-ranked pivot adjacent to i in G_store.
                    pivot = min((j for j in G_store.neighbors(i) if j in pivot_set), default=None)
                    if pivot is None:
                        pivot = i
                        pivots.append(i)
                        pivot_set.add(i)
                    f.write(f"{self.unpermuted_vertexID(i)} {self.unpermuted_vertexID(pivot)}\n")
                for i in uninteresting_vertices:
                    original_i = self.unpermuted_vertexID(i)
                    tau_i = tau_list[original_i]
                    # original_graph yields ORIGINAL IDs; convert them to ranks
                    # before comparing with the (permuted) pivots and tau.
                    ranks = (self.permuted_vertexID(j) for j in original_graph.neighbors(original_i))
                    pivot = min((r for r in ranks if r in pivot_set and r < tau_i), default=None)
                    cluster = original_i if pivot is None else self.unpermuted_vertexID(pivot)
                    f.write(f"{original_i} {cluster}\n")
            else:  # approx_ratio == 2.06
                for i in interesting_vertices:
                    pivot = self._sample_pivot(original_graph, i, pivots)
                    if pivot is None:
                        pivot = i
                        pivots.append(i)
                    f.write(f"{self.unpermuted_vertexID(i)} {self.unpermuted_vertexID(pivot)}\n")
                for i in uninteresting_vertices:
                    original_i = self.unpermuted_vertexID(i)
                    pivot = self._sample_pivot(original_graph, i, pivots,
                                               rank_limit=tau_list[original_i])
                    cluster = original_i if pivot is None else self.unpermuted_vertexID(pivot)
                    f.write(f"{original_i} {cluster}\n")

    def _sample_pivot(self, original_graph, i, pivots, rank_limit=math.inf):
        """
        Return the first pivot (in rank order) that permuted vertex ``i`` joins, or None.

        Pivots with rank below ``rank_limit`` are visited in list order, which
        post_process keeps in increasing rank. Pivot j is joined with
        probability ``probability(d_ij, label_ij)``, where the label is "+"
        iff ``original_graph`` has the edge. One random number is drawn per
        visited pivot.
        """
        original_i = self.unpermuted_vertexID(i)
        for j in pivots:
            if j >= rank_limit:
                break  # pivots are rank-sorted, so no later pivot qualifies
            if j == i:
                continue
            original_j = self.unpermuted_vertexID(j)
            label = "+" if original_graph.has_edge(original_i, original_j) else "-"
            if random.random() < self.probability(self.predict(original_i, original_j), label):
                return j
        return None

    def probability(self, d_uv, label):
        """
        Probability that u and v end up in the same cluster: 1 - f(d_uv, label).
        """
        return 1 - self.f(d_uv, label)

    def predict(self, u, v):
        """
        Return the predicted dissimilarity d_uv stored on the predictor edge
        (u, v) for original IDs u != v; return None when u == v.
        """
        if u == v:
            return None
        return self.predictor.edges[u, v]["prediction"]

    def f(self, d_uv, label):
        """
        Rounding function f(d_uv, label).

        For label "-" it returns d_uv unchanged. For "+" it returns 0 below
        a = 0.19, ((d_uv - a) / (b - a))**2 on [a, b), and 1 at or above
        b = 0.5095.
        """
        a = 0.19
        b = 0.5095
        if label == "-":
            return d_uv
        if d_uv < a:
            return 0
        if d_uv < b:
            return pow((d_uv - a) / (b - a), 2)
        return 1