#include<pybind11/pybind11.h>
#include<pybind11/numpy.h>

#include<algorithm>
#include<atomic>
#include<cassert>
#include<cmath>
#include<exception>
#include<functional>
#include<iostream>
#include<map>
#include<set>
#include<stdexcept>
#include<thread>
#include<utility>
#include<vector>

using std::vector;
using std::pair;
using std::cout;
using std::endl;
using std::map;
using std::set;

namespace py = pybind11;

/// Disjoint-set forest over the element ids [bias, bias + n), using union by
/// size and path compression.
///
/// Union() takes element ids (bias included). Find() also takes an element id,
/// but returns the representative as a zero-based slot index (bias removed).
/// Representatives are therefore only meaningful when compared with other
/// Find() results; do not feed them back into Find()/Union() when bias != 0.
class UnionFind
{
public:
	/// Creates n singleton sets for element ids 0 .. n-1.
	UnionFind(int n) : UnionFind(n, 0) {}

	/// Creates n singleton sets for element ids bias .. bias+n-1.
	UnionFind(int n, int bias)
		: _n(n), _bias(bias), _parents(n), _sizes(n, 1)
	{
		for (int i = 0; i < n; ++i)
			_parents[i] = i;
	}

	/// Returns the representative (zero-based slot) of the set containing
	/// element id u.
	/// @throws std::out_of_range if u is outside [bias, bias + n). Without this
	///         check an invalid id would index past the end of _parents.
	int Find(int u) {
		u -= _bias;
		if (u < 0 || u >= _n)
			throw std::out_of_range("UnionFind::Find: element id out of range");

		int root = u;
		while (_parents[root] != root)
			root = _parents[root];

		// Path compression: re-hang every node on the walked path directly
		// under the root, so later Find() calls on them finish in one step.
		// Roots (and therefore all results) are unchanged.
		while (_parents[u] != root) {
			const int next = _parents[u];
			_parents[u] = root;
			u = next;
		}

		return root;
	}

	/// Merges the sets containing element ids u and v. This is a no-op if they
	/// are already in the same set. The smaller set's root is attached under
	/// the larger set's root.
	void Union(int u, int v) {
		assert(u >= _bias);
		assert(v >= _bias);

		assert(u < _n + _bias);
		assert(v < _n + _bias);

		int c1 = Find(u);
		int c2 = Find(v);

		if (c1 == c2)
			return;

		if (_sizes[c1] < _sizes[c2])
			std::swap(c1, c2);

		_parents[c2] = c1;
		_sizes[c1] += _sizes[c2];
	}

private:
	int _n;                      // number of elements
	int _bias;                   // id of the first element
	std::vector<int> _parents;   // parent slot per slot; a root points to itself
	std::vector<int> _sizes;     // set size, only meaningful at root slots
};

/// Comparator for (edge index, weight) pairs: ascending by weight.
///
/// std::sort requires a strict weak ordering. With NaN weights, a plain
/// `u.second < v.second` is not one, because NaN would be "equivalent" to
/// every value, and sorting with it is undefined behaviour. Here NaN compares
/// greater than every number and equivalent to other NaNs, so NaN-weighted
/// edges sort to the end. For non-NaN weights the result equals
/// `u.second < v.second`.
bool cmp_second(const std::pair<int, float>& u, const std::pair<int, float>& v) {
	const bool u_nan = std::isnan(u.second);
	const bool v_nan = std::isnan(v.second);

	if (u_nan || v_nan)
		return !u_nan && v_nan;   // number < NaN; NaN is never < anything

	return u.second < v.second;
}

/// Returns true iff the sorted ranges [first1, last1) and [first2, last2)
/// share at least one element.
///
/// Both ranges must be in ascending order, which std::set iteration
/// guarantees. The walk is merge-style and stops at the first match, so it
/// visits each element at most once.
bool has_common_elem(std::set<int>::iterator first1, std::set<int>::iterator last1, std::set<int>::iterator first2, std::set<int>::iterator last2)
{
	for (; first1 != last1 && first2 != last2; ) {
		const int lhs = *first1;
		const int rhs = *first2;

		if (lhs == rhs)
			return true;

		// advance whichever side currently holds the smaller value
		if (lhs < rhs)
			++first1;
		else
			++first2;
	}

	return false;
}

void calc_barcodes_part_cycles(int n_nodes, py::array_t<float>& edges, py::array_t<float>& w, 
			int edges_from_id, int edges_to_id, int nodes_from_id,
			py::array_t<int>& h0, py::array_t<int>& h0_e, py::array_t<int>& h1, bool filter_cycles, int graph_num)
{
	int n_edges = edges_to_id - edges_from_id;
	vector<pair<int, float>> edges_sorted(n_edges);

	for (int i = 0; i < n_edges; ++i)
	{
		int edge_idx = edges_from_id + i;
		edges_sorted[i].first = edge_idx;
		edges_sorted[i].second = *w.data(edge_idx);
	}

	std::sort(edges_sorted.begin(), edges_sorted.end(), cmp_second); 

	UnionFind uf = UnionFind(n_nodes);
	int h0_idx = nodes_from_id;
	//int h1_idx = edges_from_id;
	
	std::map<int, std::set<int>> adj;
	std::set<pair<int, int>> taken_edges;

	if (filter_cycles) 
		for (int i = 0; i < n_nodes; ++i)
			adj.insert(std::pair<int, std::set<int>>(i, std::set<int>()));

	for (int i = 0; i < (int)edges_sorted.size(); ++i)
	{
		int edge_idx = edges_sorted[i].first;
		int v1 = *edges.data(0, edge_idx) - nodes_from_id;
		int v2 = *edges.data(1, edge_idx) - nodes_from_id;

		pair<int, int> v_pair = (v1 < v2 ? pair<int, int>(v1, v2) : pair<int, int>(v2, v1));

		if (uf.Find(v1) != uf.Find(v2)) {
			uf.Union(v1, v2);
			*h0.mutable_data(h0_idx) = edge_idx;
			h0_idx++;

			*h0_e.mutable_data(edge_idx) = graph_num;
		}
		else {

			bool add_cycle = (taken_edges.find(v_pair) == taken_edges.end());

			if (filter_cycles)
				add_cycle = add_cycle and (not has_common_elem(adj[v1].begin(), adj[v1].end(), adj[v2].begin(), adj[v2].end()));

			if (add_cycle) {
				//*h1.mutable_data(h1_idx) = edge_idx;
				//h1_idx++;
				*h1.mutable_data(edge_idx) = graph_num;
			}
		}

		taken_edges.insert(v_pair);

		if (filter_cycles) {
			adj.at(v1).insert(v2); 
			adj.at(v2).insert(v1); 
		}
	}
}

void calc_barcodes_part(int n_nodes, py::array_t<float>& edges, py::array_t<float>& w, 
			int edges_from_id, int edges_to_id, int nodes_from_id,
			py::array_t<int>& h0)
{

	//cout << "running " << edges_from_id << " " << edges_to_id << endl;

	int n_edges = edges_to_id - edges_from_id;
	vector<pair<int, float>> edges_sorted(n_edges);

	for (int i = 0; i < n_edges; ++i)
	{
		int edge_idx = edges_from_id + i;
		edges_sorted[i].first = edge_idx;
		edges_sorted[i].second = *w.data(edge_idx);
	}

	std::sort(edges_sorted.begin(), edges_sorted.end(), cmp_second); 
	//cout << "sorted" << endl;

	//for (int i = 0; i < (int)edges_sorted.size(); ++i)
	//`	std::cout << edges_sorted[i].first << " " << edges_sorted[i].second << endl;

	//cout << "n_nodes " << n_nodes << endl;

	UnionFind uf = UnionFind(n_nodes);
	int h0_idx = nodes_from_id;
	//cout << "union find" << endl;

	for (int i = 0; i < (int)edges_sorted.size(); ++i){
		int edge_idx = edges_sorted[i].first;
		//cout << "edge_idx " << edge_idx << endl;
		int v1 = *edges.data(0, edge_idx);
		int v2 = *edges.data(1, edge_idx);

		v1 -= nodes_from_id;
		v2 -= nodes_from_id;

		//cout << v1 << " " << v2 << " " << uf.Find(v1) << " " << uf.Find(v2) << endl;

		if (uf.Find(v1) != uf.Find(v2)) {
			uf.Union(v1, v2);
			*h0.mutable_data(h0_idx) = edge_idx;
			h0_idx++;
		}
	}

	//cout << "in c++: " << h0_idx << endl;
}

void calc_barcodes_part2(int n_nodes, py::array_t<float>& edges, py::array_t<float>& w, 
			int edges_from_id, int edges_to_id, int nodes_from_id,
			py::array_t<int>& h0, py::array_t<int>& h0_e, int graph_num)
{

	//cout << "running " << edges_from_id << " " << edges_to_id << endl;
	//
	int n_edges = edges_to_id - edges_from_id;
	vector<pair<int, float>> edges_sorted(n_edges);

	for (int i = 0; i < n_edges; ++i)
	{
		int edge_idx = edges_from_id + i;
		edges_sorted[i].first = edge_idx;
		edges_sorted[i].second = *w.data(edge_idx);
	}

	std::sort(edges_sorted.begin(), edges_sorted.end(), cmp_second); 
	

	/*int n_edges = edges_to_id - edges_from_id;
	std::map<pair<int, int>, int> nodes2edges;

	for (int i = 0; i < n_edges; ++i)
	{
		int edge_idx = edges_from_id + i;
		int v1 = *edges.data(0, edge_idx);
		int v2 = *edges.data(1, edge_idx);

		pair<int, int> key(v1, v2);
		nodes2edges.insert(pair<pair<int, int>, int>(key, edge_idx));
	}

	vector<pair<int, float>> edges_sorted;

	for (int i = 0; i < n_edges; ++i)
	{
		int edge_idx = edges_from_id + i;
		int v1 = *edges.data(0, edge_idx);
		int v2 = *edges.data(1, edge_idx);

		int edge_idx_sym = nodes2edges[pair<int, int>(v2, v1)];
		float e_w = *w.data(edge_idx);
		float e_w_sym = *w.data(edge_idx_sym);

		if (e_w < e_w_sym)
			edges_sorted.push_back(pair<int, float>(edge_idx, e_w));
	}

	std::sort(edges_sorted.begin(), edges_sorted.end(), cmp_second);
	*/
	
	//cout << "sorted" << endl;

	//for (int i = 0; i < (int)edges_sorted.size(); ++i)
	//`	std::cout << edges_sorted[i].first << " " << edges_sorted[i].second << endl;

	//cout << "n_nodes " << n_nodes << endl;

	UnionFind uf = UnionFind(n_nodes);
	int h0_idx = nodes_from_id;
	//cout << "union find" << endl;

	for (int i = 0; i < (int)edges_sorted.size(); ++i){
		int edge_idx = edges_sorted[i].first;
		//cout << "edge_idx " << edge_idx << endl;
		int v1 = *edges.data(0, edge_idx);
		int v2 = *edges.data(1, edge_idx);

		v1 -= nodes_from_id;
		v2 -= nodes_from_id;

		//cout << v1 << " " << v2 << " " << uf.Find(v1) << " " << uf.Find(v2) << endl;

		if (uf.Find(v1) != uf.Find(v2)) {
			uf.Union(v1, v2);
			*h0.mutable_data(h0_idx) = edge_idx;
			*h0_e.mutable_data(edge_idx) = graph_num;
			h0_idx++;
		}
	}

	//cout << "in c++: " << h0_idx << endl;
}

void calc_barcodes(int n_nodes, py::array_t<float> edges, py::array_t<float> w, py::array_t<int> h0){

	calc_barcodes_part(n_nodes, edges, w, 0, edges.size(), 0, h0);
}

void calc_barcodes_batch2(int batch_size, py::array_t<float> edges, py::array_t<float> w, py::array_t<int> edge_ptr, py::array_t<int> node_ptr, py::array_t<int> h0, py::array_t<int> h0_e, int multiprocessing)
{
	vector<std::thread> threads;

	for (int i = 0; i < batch_size; ++i) {
		int edges_from_id = *edge_ptr.data(i);
		int edges_to_id = *edge_ptr.data(i+1);

		int nodes_from_id = *node_ptr.data(i);
		int nodes_to_id = *node_ptr.data(i+1);

		int n_nodes = nodes_to_id - nodes_from_id;

		if (multiprocessing)
			threads.push_back(std::thread(calc_barcodes_part2, n_nodes, std::ref(edges), std::ref(w), \
						edges_from_id, edges_to_id, nodes_from_id, std::ref(h0), std::ref(h0_e), i));
		else
			calc_barcodes_part2(n_nodes, edges, w, edges_from_id, edges_to_id, nodes_from_id, h0, h0_e, i);
	}

	if (multiprocessing)
		for (int i = 0; i < batch_size; ++i)
			threads[i].join();
}

void calc_barcodes_batch(int batch_size, py::array_t<float> edges, py::array_t<float> w, py::array_t<int> edge_ptr, py::array_t<int> node_ptr, py::array_t<int> h0, int multiprocessing)
{
	vector<std::thread> threads;

	for (int i = 0; i < batch_size; ++i) {
		int edges_from_id = *edge_ptr.data(i);
		int edges_to_id = *edge_ptr.data(i+1);

		int nodes_from_id = *node_ptr.data(i);
		int nodes_to_id = *node_ptr.data(i+1);

		int n_nodes = nodes_to_id - nodes_from_id;

		if (multiprocessing)
			threads.push_back(std::thread(calc_barcodes_part, n_nodes, std::ref(edges), std::ref(w), \
						edges_from_id, edges_to_id, nodes_from_id, std::ref(h0)));
		else
			calc_barcodes_part(n_nodes, edges, w, edges_from_id, edges_to_id, nodes_from_id, h0);
	}

	if (multiprocessing)
		for (int i = 0; i < batch_size; ++i)
			threads[i].join();
}



void calc_barcodes_batch_cycles(int batch_size, py::array_t<float> edges, py::array_t<float> w, py::array_t<int> edge_ptr, py::array_t<int> node_ptr, py::array_t<int> h0, py::array_t<int> h0_e, py::array_t<int> h1, bool filter_cycles, int multiprocessing)
{
	vector<std::thread> threads;

	for (int i = 0; i < batch_size; ++i) {
		int edges_from_id = *edge_ptr.data(i);
		int edges_to_id = *edge_ptr.data(i+1);

		int nodes_from_id = *node_ptr.data(i);
		int nodes_to_id = *node_ptr.data(i+1);

		int n_nodes = nodes_to_id - nodes_from_id;

		if (multiprocessing)
			threads.push_back(std::thread(calc_barcodes_part_cycles, n_nodes, std::ref(edges), std::ref(w), \
						edges_from_id, edges_to_id, nodes_from_id, std::ref(h0), std::ref(h0_e), std::ref(h1), filter_cycles, i));
		else
			calc_barcodes_part_cycles(n_nodes, edges, w, edges_from_id, edges_to_id, nodes_from_id, h0, h0_e, h1, filter_cycles, i);
	}

	if (multiprocessing)
		for (int i = 0; i < batch_size; ++i)
			threads[i].join();
}

// Python entry points of module `ph_simple`.
// Output arrays (h0, h0_e, h1) are written in place by the C++ code.
// NOTE(review): if an output array's dtype is not int32 (or an input's is not
// float32), pybind11 may pass the function a converted copy, and in-place
// writes would then be lost. Confirm that callers pass matching dtypes.
PYBIND11_MODULE(ph_simple, mod) {
	mod.def("calc_barcodes", &calc_barcodes);
	mod.def("calc_barcodes_part", &calc_barcodes_part);
	mod.def("calc_barcodes_batch", &calc_barcodes_batch);
	mod.def("calc_barcodes_batch2", &calc_barcodes_batch2);
	mod.def("calc_barcodes_batch_cycles", &calc_barcodes_batch_cycles);  // main variant
}
