# File: icml21_alg.py
# Description: Implementation of the ICML'21 agreement decomposition algorithm.

import math
import random


class AgreementBasedAlgorithm:
    """Multi-pass clustering of a signed edge stream based on neighbourhood agreement.

    The edge file is read several times.  Only edges labelled ``'+'`` are
    used.  ``run`` performs, in order:

    1. ``pass_get_degrees`` - collect degrees and neighbour sets.
    2. ``pass_compute_light_or_not`` - flag vertices that lose too many
       incident edges to the agreement test.
    3. ``pass_connected_components`` - four rounds of max-label propagation
       over the edges that survive the filters.
    4. ``write_result`` - write one ``"<vertex> <representative>"`` line per
       vertex.

    Vertices are indexed ``0 .. vertices_num - 1`` by every per-vertex loop
    in this class.  Most arrays are allocated with ``vertices_num + 1`` slots;
    the extra slot is never visited by those loops.

    Attributes:
        vertices_num: number of vertices.
        beta: agreement parameter; two vertices agree when the symmetric
            difference of their neighbour sets is smaller than
            ``beta * max(degree)``.
        lambda_: a vertex is flagged light when at least
            ``lambda_ * degree`` of its positive edges fail the agreement test.
        neighbors: per-vertex neighbour sets; each vertex contains itself.
        degrees: per-vertex counters, starting at 1 and incremented once per
            incident positive edge read from the stream.
        j: per-vertex values used as divisors by ``pass_compute_S_v`` and
            ``agreement``.  Nothing in this class assigns them, so those two
            methods raise ``ZeroDivisionError`` unless the caller fills ``j``
            with non-zero values first.
        S, S_prime: per-vertex neighbour samples filled by
            ``pass_compute_S_v``.
        light_or_not: 1 for light vertices, 0 otherwise.
        id0, id1: component labels before / after the current propagation
            round.  Vertex ``i`` starts with label ``i + 1``.
    """

    # Constant ``a`` multiplying log(vertices_num) in the sampling
    # probabilities and in the sampled-agreement threshold.
    _A = 600

    def __init__(self, vertices_num, beta, lambda_):
        """Allocate the per-vertex state for ``vertices_num`` vertices."""
        self.vertices_num = vertices_num
        self.beta = beta
        self.lambda_ = lambda_

        self.neighbors = [set() for _ in range(vertices_num + 1)]
        for i in range(vertices_num):
            # Every vertex belongs to its own neighbourhood.
            self.neighbors[i].add(i)
        self.degrees = [1] * (vertices_num + 1)
        self.j = [0] * (vertices_num + 1)

        self.S = [set() for _ in range(vertices_num + 1)]
        self.S_prime = [set() for _ in range(vertices_num + 1)]

        self.light_or_not = [0] * (vertices_num + 1)

        self.id0 = list(range(1, vertices_num + 1))
        # Start every vertex in its own component.  Leaving id1 at zero made
        # all vertices without a retained edge share label 0 and end up in
        # one spurious cluster.  The trailing 0 keeps the historical
        # ``vertices_num + 1`` length.
        self.id1 = self.id0 + [0]

    @staticmethod
    def _positive_edges(streaming_edges_file):
        """Yield ``(u, v)`` for every ``'+'`` edge in the stream file.

        The first line of the file is discarded (presumably a header - the
        code never inspects it).  Each remaining non-blank line must consist
        of exactly three whitespace-separated fields ``u v label``; ``u`` and
        ``v`` are converted with ``int`` for every line, whatever the label.
        Blank lines are skipped.

        Raises:
            ValueError: if a non-blank line does not have exactly three
                fields or an endpoint is not an integer.
        """
        with open(streaming_edges_file, 'r') as f:
            f.readline()
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                u, v, label = fields
                u, v = int(u), int(v)
                if label == '+':
                    yield u, v

    def run(self, streaming_edges_file, output_file):
        """Run all passes over ``streaming_edges_file`` and write ``output_file``.

        This pipeline uses the exact test ``real_agreement``; the sampled
        variant (``pass_compute_S_v`` / ``agreement``) is not invoked.
        """
        self.pass_get_degrees(streaming_edges_file)
        self.pass_compute_light_or_not(streaming_edges_file)
        self.pass_connected_components(streaming_edges_file)
        self.write_result(output_file)

    def pass_get_degrees(self, streaming_edges_file):
        """Accumulate degrees and neighbour sets from the positive edges.

        NOTE: a positive edge that appears more than once in the stream
        increments both degrees again although the neighbour sets do not grow.
        """
        for u, v in self._positive_edges(streaming_edges_file):
            self.degrees[u] += 1
            self.degrees[v] += 1
            self.neighbors[u].add(v)
            self.neighbors[v].add(u)

    def pass_compute_S_v(self, streaming_edges_file):
        """Sample each positive edge into ``S`` and ``S_prime`` of both endpoints.

        For an endpoint ``x`` with other endpoint ``y``, ``y`` is added to
        ``S[x]`` with probability ``min(1, a*log(n) / (beta*j[x]))`` and to
        ``S_prime[x]`` with probability
        ``min(1, a*log(n)*(1-beta) / (beta*j[x]))``.  Four independent draws
        are made per edge, in the order S[u], S_prime[u], S[v], S_prime[v].

        Raises:
            ZeroDivisionError: if ``beta * j[x]`` is zero, which is the case
                for every vertex unless ``j`` was filled in by the caller.
        """
        a = self._A
        for u, v in self._positive_edges(streaming_edges_file):
            for x, y in ((u, v), (v, u)):
                random_number = random.random()
                if random_number <= min(1, a * math.log(self.vertices_num) / (self.beta * self.j[x])):
                    self.S[x].add(y)

                random_number = random.random()
                if random_number <= min(1, a * math.log(self.vertices_num) * (1 - self.beta) / (self.beta * self.j[x])):
                    self.S_prime[x].add(y)

    def agreement(self, u, v):
        """Check if vertices u and v are in agreement, using the samples.

        Returns ``False`` straight away when the degrees differ by more than
        a ``1 - beta`` factor in either direction.  Otherwise the sampled
        symmetric difference is compared against ``0.9 * tao`` where
        ``tao = a * log(n) * max(degree) / j[x]`` and ``x`` is the endpoint
        with the strictly larger degree (``v`` on ties).

        Raises:
            ZeroDivisionError: if ``j[x]`` is zero (see the class docstring),
                or if ``beta == 1``.
        """
        a = self._A
        deg_u, deg_v = self.degrees[u], self.degrees[v]
        if deg_v < (1 - self.beta) * deg_u or deg_v > deg_u / (1 - self.beta):
            return False

        if deg_u > deg_v:
            x, y = u, v
        else:
            x, y = v, u
        tao = a * math.log(self.vertices_num) * max(deg_u, deg_v) / self.j[x]
        # Same j value: compare the S samples; otherwise use y's S_prime.
        other = self.S[y] if self.j[y] == self.j[x] else self.S_prime[y]
        X = len(self.S[x].symmetric_difference(other))
        return X <= 0.9 * tao

    def real_agreement(self, u, v):
        """Exact agreement test on the full neighbour sets.

        Returns ``True`` when ``|N(u) symmetric-difference N(v)|`` is strictly
        smaller than ``beta * max(degrees[u], degrees[v])``.
        """
        difference = len(self.neighbors[u].symmetric_difference(self.neighbors[v]))
        return difference < self.beta * max(self.degrees[u], self.degrees[v])

    def pass_compute_light_or_not(self, streaming_edges_file):
        """Flag as light every vertex that loses too many positive edges.

        An edge is lost when its endpoints fail ``real_agreement``.  Vertex
        ``i`` is flagged (``light_or_not[i] = 1``) when its lost-edge count is
        at least ``lambda_ * degrees[i]``.
        """
        discard_num = [0] * (self.vertices_num + 1)
        for u, v in self._positive_edges(streaming_edges_file):
            if not self.real_agreement(u, v):
                discard_num[u] += 1
                discard_num[v] += 1
        for i in range(self.vertices_num):
            if discard_num[i] >= self.lambda_ * self.degrees[i]:
                self.light_or_not[i] = 1

    def pass_connected_components(self, streaming_edges_file):
        """Run four rounds of max-label propagation over the retained edges.

        An edge is retained when at least one endpoint is not light and the
        endpoints pass ``real_agreement``.  In each round every vertex takes
        the largest previous-round label found among itself and its retained
        neighbours.  Labels are only ever raised within a round, so the
        result does not depend on the order of the edges in the stream.
        """
        n = self.vertices_num
        for _ in range(4):
            # A vertex keeps its own label unless a retained edge raises it.
            self.id1[:n] = self.id0[:n]
            for u, v in self._positive_edges(streaming_edges_file):
                # Discard all edges between two light vertices.
                if self.light_or_not[u] and self.light_or_not[v]:
                    continue
                # Discard all edges whose endpoints are not in agreement.
                if not self.real_agreement(u, v):
                    continue
                label = max(self.id0[u], self.id0[v])
                if label > self.id1[u]:
                    self.id1[u] = label
                if label > self.id1[v]:
                    self.id1[v] = label
            self.id0[:n] = self.id1[:n]

    def write_result(self, output_file):
        """Write one ``"<vertex> <representative>"`` line per vertex.

        The representative of a vertex is the smallest vertex index that
        carries the same ``id1`` label.
        """
        n = self.vertices_num
        # Vertices are scanned in increasing order, so the first vertex seen
        # with a given label is the smallest member of that component.
        representative = {}
        for i in range(n):
            representative.setdefault(self.id1[i], i)
        with open(output_file, 'w') as f:
            f.writelines(f'{i} {representative[self.id1[i]]}\n' for i in range(n))