// This is a demo implementation of the semi-streaming algorithm for Correlation Clustering.
// To keep the code readable, it does not include error checking.

// To compile the code, run the following command:
//    clang++ streaming-CC.cc -O3 -std=c++17 -o streaming-CC

#include <algorithm>
#include <cassert>
#include <fstream>
#include <iostream>
#include <random>
#include <unordered_map>
#include <vector>

// Draw a uniformly random permutation of the numbers 0..n-1.
//
// On return both vectors have size n and are inverse to each other:
//    rank2id[r] is the vertex id that received rank r,
//    id2rank[v] is the rank that vertex id v received.
// Any previous contents of the two output vectors are overwritten.
void GetRandomOrdering(int n, std::vector<int>& id2rank, std::vector<int>& rank2id) {
   // start from the identity permutation ...
   rank2id.resize(n);
   int nextId = 0;
   for (int& id : rank2id) {
      id = nextId++;
   }

   // ... and shuffle it with a generator seeded from the system entropy source
   std::random_device entropy;
   std::mt19937 generator(entropy());
   std::shuffle(rank2id.begin(), rank2id.end(), generator);

   // build the inverse permutation
   id2rank.resize(n);
   for (int rank = 0; rank < n; ++rank) {
      const int id = rank2id[rank];
      id2rank[id] = rank;
   }
}

// Convert a clustering expressed in ranks into one expressed in vertex ids.
//
// clusterIDs[r] is the rank of the pivot that the vertex of rank r belongs to.
// The result is indexed by vertex id: result[v] is the id of the pivot vertex
// of the cluster containing v. rank2id maps a rank to the corresponding id.
std::vector<int> TranslateRanks2IDs(const std::vector<int>& rank2id,
                                    const std::vector<int>&  clusterIDs) {
   const int count = clusterIDs.size();
   std::vector<int> pivotIdOfVertex(count);

   for (int rank = 0; rank < count; ++rank) {
      const int vertexId = rank2id[rank];
      const int pivotId = rank2id[clusterIDs[rank]];
      pivotIdOfVertex[vertexId] = pivotId;
   }

   return pivotIdOfVertex;
}

// One list of neighbors per vertex. In this file the outer index and the
// stored values are vertex ranks (see ReadEdges), not the original vertex ids.
using NeighborList = std::vector<std::vector<int>>;

// Add element u to the bounded priority queue pq of maximum size k.
//
// pq must be a max-heap (in the sense of std::push_heap / std::pop_heap) on
// entry and is a max-heap again on return. While pq holds fewer than k elements
// u is simply inserted. Once it is full, u replaces the current largest element
// if and only if u is strictly smaller, so pq always keeps the k smallest
// elements seen so far. Duplicates are not filtered out.
//
// A capacity k <= 0 means that nothing can be stored; the call is then a no-op
// (this also avoids touching pq[0] / pq[k-1] of a queue that has no such slot).
void AddElement(std::vector<int>& pq, int u, int k) {
   if (k <= 0) {
      return;
   }

   // k is positive here, so the conversion to the unsigned size type is exact
   const auto capacity = static_cast<std::vector<int>::size_type>(k);

   if (pq.size() < capacity) {
      pq.push_back(u);
      std::push_heap(pq.begin(), pq.end());
   } else if (u < pq.front()) {
      // replace the largest (lowest ranked) element with u:
      // pop_heap moves the largest element to the back, where it is overwritten
      std::pop_heap(pq.begin(), pq.end());
      pq.back() = u;
      std::push_heap(pq.begin(), pq.end());
   }
}

// Read the edges of the graph from a text stream of whitespace-separated
// integer pairs "u v" and build, for every vertex, a bounded list of neighbors.
//
// The returned lists are indexed by rank and store ranks: id2rank translates the
// vertex ids read from the stream. Every list initially contains the vertex
// itself; each edge is then offered to both endpoint lists through AddElement,
// which bounds the list size by k. Pairs with an endpoint outside [0, n) are
// skipped. Reading stops at the first token that cannot be parsed as an integer
// or at the end of the stream. Repeated edges are offered again each time they
// appear (no de-duplication is performed).
NeighborList ReadEdges(std::istream& in, int n, std::vector<int>& id2rank, int k) {
   NeighborList lists(n);

   // every vertex is its own neighbor (self-loop)
   int self = 0;
   for (auto& list : lists) {
      list.push_back(self++);
   }

   for (int a, b; in >> a >> b; ) {
      // ignore edges that are not in the graph
      const bool insideGraph = (a >= 0) && (a < n) && (b >= 0) && (b < n);
      if (!insideGraph) {
         continue;
      }

      const int rankA = id2rank[a];
      const int rankB = id2rank[b];
      AddElement(lists[rankA], rankB, k);
      AddElement(lists[rankB], rankA, k);
   }

   return lists;
}

// The Pivot step of the algorithm.
//
// neighbors[i] lists (by rank) the stored neighbors of the vertex of rank i.
// Vertices are scanned in increasing rank. A vertex joins the lowest-ranked
// entry of its list that is either the vertex itself or an already processed
// vertex that became a pivot. Vertices that find no such entry end up as
// singleton clusters.
//
// Returns a vector c where c[i] is the rank of the pivot of the cluster that
// contains the vertex of rank i (c[i] == i for pivots and singletons).
std::vector<int> Pivot(const NeighborList& neighbors) {
   const int count = neighbors.size();

   // The value `count` is not a valid rank; it marks a vertex that has not
   // been assigned to any pivot.
   const int unassigned = count;
   std::vector<int> pivotOf(count, unassigned);

   for (int v = 0; v < count; ++v) {
      // find the smallest neighbor w of v which is a pivot (or v itself)
      int best = unassigned;
      for (const int w : neighbors[v]) {
         const bool isSelf = (w == v);
         const bool isEarlierPivot = (w < v) && (pivotOf[w] == w);
         if ((isSelf || isEarlierPivot) && (w < best)) {
            best = w;
         }
      }
      pivotOf[v] = best;
   }

   // Assign "singleton" cluster ids to all unassigned vertices. This has to be
   // a separate pass: during the scan above an unassigned vertex must not look
   // like a pivot to the vertices that come after it.
   for (int v = 0; v < count; ++v) {
      if (pivotOf[v] == unassigned) {
         pivotOf[v] = v;
      }
   }

   return pivotOf;
}

// The main algorithm: Read the graph from the stream and cluster it
std::vector<int> StreamingCC(std::istream& in, int n, int k) {
   std::vector<int> id2rank;
   std::vector<int> rank2id;
   GetRandomOrdering(n, id2rank, rank2id);
   NeighborList neighbors = ReadEdges(in, n, id2rank, k);
   std::vector<int> clusterIDs = Pivot(neighbors);
   
   return TranslateRanks2IDs(rank2id, clusterIDs);
}


////////////////////////////////////////////////////////////////////////////////////////
// The following functions are used for debugging purposes. 
// They are not a part of the algorithm described in the paper.

// Compute the cost of the clustering, i.e. the number of disagreements:
// pairs of distinct vertices inside one cluster that are NOT joined by an edge,
// plus edges whose endpoints lie in different clusters.
//
// in          - text stream of whitespace-separated integer pairs "u v"
// clusterIDs  - clusterIDs[v] is the cluster label of vertex v
//
// Lines with an endpoint outside [0, clusterIDs.size()) are ignored, and so are
// self-loops "u u" (a vertex paired with itself is not one of the counted
// pairs, so it must not reduce the cost). Every other line is counted as one
// edge; the function does not detect repeated edges, so each edge is expected
// to appear in the stream once.
//
// The arithmetic is done in 64 bits so that size * (size - 1) cannot overflow
// for large clusters; the result is converted to int on return.
int ComputeCost(std::istream& in, const std::vector<int>& clusterIDs) {
   const int n = static_cast<int>(clusterIDs.size());

   // use a hash table to count the number of vertices in each cluster
   // if the number of clusters is large, it is better to use a vector instead
   std::unordered_map<int, long long> clusterSizes;
   for (const int id : clusterIDs) {
      ++clusterSizes[id];
   }

   long long cost = 0;

   // first, we assume that all pairs inside a cluster are non-edges
   for (const auto& cluster : clusterSizes) {
      const long long size = cluster.second;
      cost += size * (size - 1) / 2;
   }

   // now read the input stream and correct the count edge by edge
   int u, v;
   while (in >> u >> v) {
      const bool insideGraph = (u >= 0) && (u < n) && (v >= 0) && (v < n);
      if (!insideGraph || (u == v)) {
         continue;
      }

      if (clusterIDs[u] == clusterIDs[v]) {
         --cost; // same cluster: this pair is an edge after all
      }
      else {
         ++cost; // different clusters: the edge is cut
      }
   }

   return static_cast<int>(cost);
}

// Count the number of vertices and edges in the input stream.
//
// The stream is read as whitespace-separated integer pairs until parsing fails
// or the stream ends. Returns {number of vertices, number of edges}, where the
// number of vertices is the largest endpoint seen plus one (0 if no
// non-negative endpoint was seen) and the number of edges is the number of
// pairs read.
std::pair<int,int> CountVerticesAndEdges(std::istream& in) {
   int largestVertex = -1;
   int edgeCount = 0;

   for (int a, b; in >> a >> b; ++edgeCount) {
      largestVertex = std::max({largestVertex, a, b});
   }

   return {largestVertex + 1, edgeCount};
}

// Print the usage text to standard output (this happens on every call) and
// terminate the process with exit code 1 if the number of command line
// arguments is wrong, i.e. if argc < 2 or argc > 3.
void PrintHelpMessage(int argc, const char* programName) {
   std::cout << std::endl
             << "Usage: " << programName << " <filename> <number of vertices>" << std::endl
             << std::endl
             << "The file should contain a list of edges in the text format. "
             << "All vertices are represented by numbers from 0 to n-1. " << std::endl
             << "You can omit the number of vertices. In this case, "
             << "the program will make an extra pass to count them." << std::endl
             << std::endl;

   const bool filenameMissing = argc < 2;
   if (filenameMissing) {
      std::cout << "Please specify the filename." << std::endl;
      exit(1);
   }

   const bool tooManyArguments = argc > 3;
   if (tooManyArguments) {
      std::cout << "Incorrect number of arguments. Please, provide 1 or 2 arguments (see above)." << std::endl;
      exit(1);
   }
}

// Program entry point.
//
// Usage: <program> <filename> [<number of vertices>]
//
// Reads the edge list from the given text file, runs the streaming clustering
// algorithm on it and prints the cost of the resulting clustering. If the
// number of vertices is not given, an extra pass over the file determines it.
// Returns 0 on success and 1 if the file cannot be opened or the number of
// vertices is not positive.
int main(int argc, char** argv) {

   // print a help message & exit if the number of arguments is incorrect
   PrintHelpMessage(argc, argv[0]);

   const int k = 10; //number of neighbors to keep

   const std::string filename = argv[1];

   std::ifstream in(filename);
   if (!in) {
      std::cout << "Cannot open file: "<< filename << std::endl;
      return 1;
   }

   std::cout << "Filename: " << filename << std::endl;

   int n = 0;

   if (argc == 3) {
      // the number of vertices is given on the command line
      n = atoi(argv[2]);

      if (n <= 0) {
         std::cout << "Incorrect number of vertices: " << n << std::endl;
         return 1;
      }

      std::cout << "Number of vertices: " << n << std::endl;
   }
   else if (argc == 2) {
      // extra pass over the file to count the vertices
      const auto count = CountVerticesAndEdges(in);
      n = count.first;

      // Report an unusable input (e.g. an empty file) the same way as an
      // invalid command line value instead of relying on assert, which aborts
      // the program and is compiled out when NDEBUG is defined.
      if (n <= 0) {
         std::cout << "Incorrect number of vertices: " << n << std::endl;
         return 1;
      }

      std::cout << "Number of vertices: " << n << " number of edges: " << count.second << std::endl;

      // rewind by reopening the file for the clustering pass
      in.close();
      in.open(filename);
   }

   //run the streaming algorithm
   std::vector<int> clusterIDs = StreamingCC(in, n, k);
   in.close();

   //compute the cost of the clustering (one more pass over the file)
   in.open(filename);
   int cost = ComputeCost(in, clusterIDs);
   in.close();
   std::cout << "Cost: " << cost << std::endl;

   return 0;
}