#include <gflags/gflags.h>
#include <ncode/file.h>
#include <ncode/logging.h>
#include <ncode/lp/demand_matrix.h>
#include <ncode/map_util.h>
#include <ncode/net/net_common.h>
#include <ncode/net/net_gen.h>
#include <ncode/strutil.h>
#include <ncode/thread_runner.h>
#include <algorithm>
#include <chrono>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "demand_matrix_input.h"
#include "opt/common.h"
#include "opt/ldr.h"
#include "opt/opt.h"
#include "opt/path_provider.h"
#include "topology_input.h"

using namespace std::chrono;
using namespace nc::net;

// Command-line flags.
// Topology file, parsed with LoadRepetitaOrDie in main(). The output routing
// configuration is written to this path with ".graph" replaced by ".rc".
DEFINE_string(top_file, "", "Topology file to load.");
// Demand matrix file, parsed with DemandMatrix::LoadRepetitaFileOrDie using the
// node order read from the topology file.
DEFINE_string(dm_file, "", "Demand matrix file to load.");
// In this file, only recorded as the routing configuration's optimizer string;
// it does not influence which paths are chosen.
DEFINE_string(wopt_name, "", "Name of weight optimization method.");
// false: each aggregate gets a single shortest path (fraction 1.0).
// true: each aggregate is split over all equal-cost shortest paths.
DEFINE_bool(use_ecmp, false, "Whether to use ECMP to derive routing config.");

// A node of the ECMP path tree built by AddECMPNodesRecursively.
//
// Each tree node stands for one graph node and records the share of the
// aggregate's traffic that reaches it along the tree path from the root.
// Links only point upward (via `parent`); a root-to-leaf path is recovered
// by following `parent` from a leaf until nullptr is reached.
class ECMPTreeNode {
    public:
        ECMPTreeNode(double share, bool leaf, std::string node_id,
                     ECMPTreeNode* up)
            : fraction(share),
              is_leaf(leaf),
              graph_node_id(std::move(node_id)),
              parent(up) {}

        // Fraction of the aggregate's traffic flowing through this node.
        double fraction;
        // Whether this node terminates a path (set by the tree builder).
        bool is_leaf;
        // String id of the graph node this tree node represents.
        std::string graph_node_id;
        // Parent in the tree, or nullptr for the root. Not owned.
        ECMPTreeNode* parent;
};

// Recursively expands `tree_node` into the tree of all shortest (equal-cost)
// paths from its graph node to the graph node named `dst_id`.
//
// A neighbor is a shortest-path next hop when the delay of the link to it plus
// its APSP distance to the destination is minimal among all neighbors. Each
// such next hop becomes a child ECMPTreeNode carrying an equal share of
// tree_node->fraction. Children whose id equals `dst_id` are flagged as leaves
// and appended to `leaves`; all other children are expanded recursively.
//
// Children are allocated with `new` and are reachable only through the
// `parent` chains of the leaves appended to `leaves`; the caller owns them.
// If no neighbor can reach the destination, nothing is added.
void AddECMPNodesRecursively(ECMPTreeNode* tree_node, std::string dst_id,
                             nc::net::GraphStorage& graph, AdjacencyList& adj_list,
                             AllPairShortestPath& apsp, std::vector<ECMPTreeNode*>* leaves) {
    const GraphNodeIndex node = graph.NodeFromStringOrDie(tree_node->graph_node_id);
    const GraphNodeIndex dst = graph.NodeFromStringOrDie(dst_id);

    // Best cost of reaching dst through each neighbor. The link delay must be
    // included: comparing only the neighbors' own distances to dst selects
    // non-shortest next hops whenever link weights differ.
    std::map<GraphNodeIndex, Delay> cost_via_neighbor;
    for (const AdjacencyList::LinkInfo& li : adj_list.GetNeighbors(node)) {
        const Delay neighbor_to_dst = apsp.GetDistance(li.dst_index, dst);
        // Skip neighbors whose distance is saturated (presumably unreachable
        // from dst -- NOTE(review): confirm APSP's sentinel) and keep the sum
        // below from overflowing.
        if (neighbor_to_dst >= Delay::max() - li.delay) {
            continue;
        }
        const Delay cost = li.delay + neighbor_to_dst;
        auto it = cost_via_neighbor.find(li.dst_index);
        if (it == cost_via_neighbor.end() || cost < it->second) {
            cost_via_neighbor[li.dst_index] = cost;
        }
    }

    // Dead end: without this guard min_element would return end(), which
    // must not be dereferenced.
    if (cost_via_neighbor.empty()) {
        return;
    }

    using CostEntry = std::map<GraphNodeIndex, Delay>::value_type;
    const Delay min_cost =
        std::min_element(cost_via_neighbor.begin(), cost_via_neighbor.end(),
                         [](const CostEntry& l, const CostEntry& r) {
                             return l.second < r.second;
                         })->second;

    // Distances are integral durations, so equal-cost hops compare exactly.
    std::vector<GraphNodeIndex> sp_neighbors;
    for (const CostEntry& neighbor_and_cost : cost_via_neighbor) {
        if (neighbor_and_cost.second == min_cost) {
            sp_neighbors.push_back(neighbor_and_cost.first);
        }
    }

    // Traffic is split evenly across all equal-cost next hops.
    const double child_fraction = tree_node->fraction / sp_neighbors.size();

    for (GraphNodeIndex sp_neighbor : sp_neighbors) {
        const std::string sp_neighbor_id = graph.GetNode(sp_neighbor)->id();
        const bool is_leaf = (sp_neighbor_id == dst_id);

        ECMPTreeNode* next = new ECMPTreeNode(child_fraction, is_leaf, sp_neighbor_id, tree_node);
        if (is_leaf) {
            leaves->push_back(next);
        } else {
            AddECMPNodesRecursively(next, dst_id, graph, adj_list, apsp, leaves);
        }
    }
}

int main(int argc, char** argv) {
    gflags::ParseCommandLineFlags(&argc, &argv, true);
//    LOG(INFO) << "Hello world from weights to routing config tool. Got topology file to read as " << FLAGS_top_file;

    std::vector<std::string> node_order;
    GraphBuilder builder = LoadRepetitaOrDie(nc::File::ReadFileToStringOrDie(FLAGS_top_file), &node_order);
    builder.RemoveMultipleLinks();
    auto graph = nc::make_unique<GraphStorage>(builder);
    auto adj_list = graph->AdjacencyList();

    for (const auto& node_and_neighbors : adj_list.Adjacencies()) {
        for (const AdjacencyList::LinkInfo& link_info :
             *node_and_neighbors.second) {

//          LOG(INFO) << "Read in link from " << link_info.src_index << " to " << link_info.dst_index;
        }
    }


    tm_gen::PathProvider path_provider(graph.get());
    std::unique_ptr<nc::lp::DemandMatrix> demand_matrix = nc::lp::DemandMatrix::LoadRepetitaFileOrDie(FLAGS_dm_file, node_order, graph.get());
    std::unique_ptr<tm_gen::TrafficMatrix> tm =
          tm_gen::TrafficMatrix::DistributeFromDemandMatrix(*demand_matrix);

    AllPairShortestPath apsp({}, adj_list, nullptr, nullptr);
    GraphNodeSet all_nodes = adj_list.AllNodes();

    auto rc_out = nc::make_unique<tm_gen::RoutingConfiguration>(*tm);
    rc_out->set_time_to_compute(milliseconds(0));
    rc_out->set_optimizer_string(FLAGS_wopt_name);

    for (const auto& aggregate_and_demand : tm->demands()) {
        const tm_gen::AggregateId& aggregate_id = aggregate_and_demand.first;

        GraphNodeIndex src = aggregate_id.src();
        std::string src_id = graph->GetNode(src)->id();

        GraphNodeIndex dst = aggregate_id.dst();
        std::string dst_id = graph->GetNode(dst)->id();

        if (! FLAGS_use_ecmp) {
            uint64_t distance = apsp.GetDistance(src, dst).count();
            std::unique_ptr<Walk> sp = apsp.GetPath(src, dst);
            const nc::net::Walk* new_sp = path_provider.TakeOwnership(std::move(sp));

//            LOG(INFO) << "Doing src " << src << ", dst " << dst << ", with distance " << std::to_string(distance);
            rc_out->AddRouteAndFraction(aggregate_id, {{new_sp, 1.0}});
        }

        else {
//            LOG(INFO) << "Doing src " << src_id << ", dst " << dst_id;
            std::vector<tm_gen::RouteAndFraction> rfs;

            ECMPTreeNode* root = new ECMPTreeNode(1., false, src_id, nullptr);
            std::vector<ECMPTreeNode*>* leaves = new std::vector<ECMPTreeNode*>();

            AddECMPNodesRecursively(root, dst_id, *graph, adj_list, apsp, leaves);

            for (ECMPTreeNode* leaf_node : *leaves) {
                double leaf_fraction = leaf_node->fraction;
                std::vector<std::string> nodes_to_parent;

                ECMPTreeNode* ptr = leaf_node;
                while(ptr != nullptr) {
                    std::string node_on_path_id = ptr->graph_node_id;
//                    LOG(INFO) << "Node on path is " << node_on_path_id;
                    nodes_to_parent.push_back(node_on_path_id);
                    ptr = ptr->parent;
                }

                // LOG(INFO) << "Size of nodes to parent is " << nodes_to_parent.size();
                std::reverse(nodes_to_parent.begin(), nodes_to_parent.end());


                nc::net::Links links_on_path;
                for (size_t j = 0; j < nodes_to_parent.size() - 1; ++j) {
                  // LOG(INFO) << "Link from " << nodes_to_parent[j] << " to " << nodes_to_parent[j+1];
                  auto link = graph->LinkOrDie(nodes_to_parent[j], nodes_to_parent[j + 1]);
                  links_on_path.emplace_back(link);
                }

                auto new_path = nc::make_unique<nc::net::Walk>(links_on_path, *graph);
                const nc::net::Walk* path_ptr =
                    path_provider.TakeOwnership(std::move(new_path));

                rfs.push_back({path_ptr, leaf_fraction});


            }

            rc_out->AddRouteAndFraction(aggregate_id, rfs);
            for (ECMPTreeNode* leaf_node : *leaves) {
                delete leaf_node;
            }
            if(root != nullptr) {
                delete root;
            }
            delete leaves;

        }

    }

    std::string serialized = rc_out->SerializeToText(node_order);
    std::string out = nc::StringReplace(FLAGS_top_file, ".graph", ".rc", true);
    nc::File::WriteStringToFileOrDie(serialized, out);


//    LOG(INFO) << "Successfully computed shortest paths. Exiting... ";

}
