#define GUROBI
#include "src/env/ged_env.hpp"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

namespace py = pybind11;

using NodeLabel = int;
//using EdgeLabel = ged::NoLabel;
using EdgeLabel = int;

using Data = std::pair< std::vector< NodeLabel >, std::vector< std::pair< int, int >>>;

/**
 * Edit cost model with one constant cost per edit operation.
 *
 * Insertions and deletions always cost the configured constant. Relabelling
 * costs the configured constant when the two labels differ and nothing when
 * they are equal.
 */
class GEDEditCosts: public ged::EditCosts<NodeLabel, EdgeLabel>
{

public:

	// Defaults; the constructor overwrites every one of them.
	double node_ins_cost = 1;
	double node_del_cost = 1;
	double node_rel_cost = 1;
	double edge_ins_cost = 1;
	double edge_del_cost = 1;
	double edge_rel_cost = 0;

	/**
	 * @param costs Exactly six values, in this order: node insertion, node
	 *              deletion, node relabelling, edge insertion, edge deletion,
	 *              edge relabelling.
	 * @throws std::invalid_argument if costs does not hold exactly six values.
	 */
	explicit GEDEditCosts(const std::vector< double > &costs)
	{
		if (costs.size() != 6) {
			throw std::invalid_argument("costs.size() != 6");
		}
		node_ins_cost = costs[0];
		node_del_cost = costs[1];
		node_rel_cost = costs[2];
		edge_ins_cost = costs[3];
		edge_del_cost = costs[4];
		edge_rel_cost = costs[5];
	}


	/** Cost of inserting a node; independent of its label. */
	double node_ins_cost_fun(const NodeLabel& /*node_label*/) const
	{
		return node_ins_cost;
	}
	
	/** Cost of deleting a node; independent of its label. */
	double node_del_cost_fun(const NodeLabel& /*node_label*/) const
	{
		return node_del_cost;
	}
	
	/** Cost of relabelling a node: node_rel_cost if the labels differ, else 0. */
	double node_rel_cost_fun(const NodeLabel& node_label_1, const NodeLabel& node_label_2) const
	{
		return (node_label_1 != node_label_2) ? node_rel_cost : 0.0;
	}
	
	/** Cost of inserting an edge; independent of its label. */
	double edge_ins_cost_fun(const EdgeLabel& /*edge_label*/) const
	{
		return edge_ins_cost;
	}
	
	/** Cost of deleting an edge; independent of its label. */
	double edge_del_cost_fun(const EdgeLabel& /*edge_label*/) const
	{
		return edge_del_cost;
	}
	
	/**
	 * Cost of relabelling an edge: edge_rel_cost if the labels differ, else 0.
	 *
	 * Identical labels must be free, mirroring node_rel_cost_fun; otherwise a
	 * non-zero edge_rel_cost would be charged for edges that are left untouched.
	 */
	double edge_rel_cost_fun(const EdgeLabel& edge_label_1, const EdgeLabel& edge_label_2) const
	{
		return (edge_label_1 != edge_label_2) ? edge_rel_cost : 0.0;
	}
};

/**
 * Translate a textual method name into the corresponding
 * ged::Options::GEDMethod value.
 *
 * Recognised names: "anchor_aware_ged", "blp_no_edge_labels", "branch",
 * "f2" and "ipfp". The comparison is exact (case-sensitive).
 *
 * @throws std::invalid_argument if name is none of the recognised names.
 */
ged::Options::GEDMethod method_name_to_option(std::string name)
{
	using Method = ged::Options::GEDMethod;

	// Lookup table: accepted spelling -> method enumerator.
	const std::pair< const char*, Method > known_methods[] = {
		{"anchor_aware_ged", Method::ANCHOR_AWARE_GED},
		{"blp_no_edge_labels", Method::BLP_NO_EDGE_LABELS},
		{"branch", Method::BRANCH},
		{"f2", Method::F2},
		{"ipfp", Method::IPFP},
	};

	for (const auto& entry: known_methods) {
		if (name == entry.first) {
			return entry.second;
		}
	}
	throw std::invalid_argument("unknown method");
}


class GraphEditDistanceCalculator{
	public:
		ged::GEDEnv< int, NodeLabel, EdgeLabel > env;

		GraphEditDistanceCalculator(std::vector<Data> graphs, std::string method_name, std::vector< std::string > method_args, std::vector< double > costs){
			for (int i = 0; i < graphs.size(); i++){
				auto gi = env.add_graph();
				const auto& g_x = graphs[i].first;
				const auto& g_edge_index = graphs[i].second;
				for (int i = 0; i < (int)g_x.size(); ++i) {
					env.add_node(gi, i, g_x[i]);
				}
				for (const auto& p: g_edge_index) {
					//env.add_edge(gi, p.first, p.second, ged::NoLabel());
					env.add_edge(gi, p.first, p.second, 0);
				}
			}
			env.set_edit_costs(new GEDEditCosts(costs));
			env.init();
			env.set_method(method_name_to_option(method_name), method_args[0]);
			env.init_method();
			std::cout << "Initialization time for environment (seconds): " << env.get_init_time() << std::endl;
			// std::cout << "Number of graphs in the environment: " << env.num_graphs() << std::endl;
			//std::pair<GEDGraph::GraphID, GEDGraph::GraphID> ids = env.graph_ids();
			//std::cout << "Indices of the graphs in the environment: " << ids.first << " to " << ids.second << std::endl;
		}

		std::tuple<double, double, double> 
		calcged(int i, int j){
			env.run_method(i, j);
			double lb = env.get_lower_bound(i, j);
			double ub = env.get_upper_bound(i, j);
			double time = env.get_runtime(i ,j);
			return std::make_tuple(lb, ub, time);
		}

};





// std::tuple< double, double > calcged(const Data& g, const Data& h, std::vector< std::string > method_name, std::vector< std::string > method_args, std::vector< double > costs)
// {
// 	ged::GEDEnv< int, NodeLabel, EdgeLabel > env;
	
// 	auto gi = env.add_graph();
// 	const auto& g_x = g.first;
// 	const auto& g_edge_index = g.second;
// 	for (int i = 0; i < (int)g_x.size(); ++i) {
// 		env.add_node(gi, i, g_x[i]);
// 	}
// 	for (const auto& p: g_edge_index) {
// 		//env.add_edge(gi, p.first, p.second, ged::NoLabel());
// 		env.add_edge(gi, p.first, p.second, 0);
// 	}
	
// 	auto hi = env.add_graph();
// 	const auto& h_x = h.first;
// 	const auto& h_edge_index = h.second;
// 	for (int i = 0; i < (int)h_x.size(); ++i) {
// 		env.add_node(hi, i, h_x[i]);
// 	}
// 	for (const auto& p: h_edge_index) {
// 		//env.add_edge(hi, p.first, p.second, ged::NoLabel());
// 		env.add_edge(hi, p.first, p.second, 0);
// 	}
	
// 	// quick-fix: remove
// 	if (method_name[0] == "ged_f2") {
// 		env.set_edit_costs(new GEDEditCosts(costs));
// 		method_name[0] = "f2";
// 	} else if (method_name[0] == "ged_branch") {
// 		env.set_edit_costs(new GEDEditCosts(costs));
// 		method_name[0] = "branch";
// 	} else {
// 		env.set_edit_costs(new GEDEditCosts(costs));
// 	}
	
// 	env.init();
// 	double lb, ub;
// 	if (method_name.size() == 1) {
// 		env.set_method(method_name_to_option(method_name[0]), method_args[0]);
// 		env.init_method();
// 		env.run_method(gi, hi);
// 		lb = env.get_lower_bound(gi, hi);
// 		ub = env.get_upper_bound(gi, hi);
// 	} else if (method_name.size() == 2) {
// 		env.set_method(method_name_to_option(method_name[0]), method_args[0]);
// 		env.init_method();
// 		env.run_method(gi, hi);
// 		lb = env.get_lower_bound(gi, hi);
// 		env.set_method(method_name_to_option(method_name[1]), method_args[1]);
// 		env.init_method();
// 		env.run_method(gi, hi);
// 		ub = env.get_upper_bound(gi, hi);
// 	}
	
// 	return std::make_tuple(lb, ub);
// }


// Python module "pyged": exposes GraphEditDistanceCalculator with its
// four-argument constructor and the calcged(i, j) method.
PYBIND11_MODULE(pyged, m) {
	// Legacy free-function binding, kept for reference:
	// m.def("calcged", &calcged);
	using Calculator = GraphEditDistanceCalculator;

	py::class_<Calculator> calculator(m, "GraphEditDistanceCalculator");
	calculator.def(py::init< std::vector<Data>, std::string, std::vector< std::string >, std::vector< double > >());
	calculator.def("calcged", &Calculator::calcged);
}

//std::vector<Data> graphs, std::string method_name, std::vector< std::string > method_args, std::vector< double > costs