#include <vector>
#include <cassert>
#include <queue>

#include <partition.hxx>


#ifdef TIME
    #include <time.hxx>
#endif

namespace preordering {



// Greedily inserts arcs into the relation stored in `adjacency`, as long as the best
// available insertion has a non-negative gain.
//
// In every iteration the not-yet-present arc best_i->best_j with the largest gain is chosen
// (ties are broken towards the smaller i, then the smaller j). It is inserted together with
// every arc p->s where p is a current predecessor of best_i and s is a current successor of
// best_j. The gain of an arc i->j is the sum of arc_costs(k, l) over all pairs where k is a
// predecessor of i, l is a successor of j and k->l is not present yet.
//
// Parameters:
//   arc_costs  read-only cost table; must provide size(), operator()(i, j), the member type
//              VALUE_TYPE and a constructor taking (size, initial value).
//   adjacency  in/out 0/1 table with size() and operator()(i, j); inserted arcs are set to 1.
//              The insertion step only reaches predecessors/successors that are already
//              stored, so the input is presumably expected to be reflexive and transitively
//              closed (the transitive_closure call below is commented out) -- TODO confirm
//              against the callers.
//
// Returns the sum of arc_costs over all arcs initially present in `adjacency` plus the gain of
// every accepted insertion.
//
// Compile-time switches: TIME prints timing statistics to std::cout. DEBUG recomputes the
// whole gain table from scratch (three nested loops over n) after every iteration and throws
// std::runtime_error on a mismatch.
// NOTE(review): <iostream> / <stdexcept> are not among the standard headers included at the
// top of this file; TIME / DEBUG builds presumably rely on another header pulling them in.
template<typename ARC_COSTS, typename ADJACENCY>
typename ARC_COSTS::VALUE_TYPE greedy_arc_insertion(const ARC_COSTS& arc_costs, ADJACENCY& adjacency)
{
    #ifdef TIME
        Clock clock;
        double t_setup = 0;
        double t_max = 0;
        double t_increase = 0;
        double t_reduce = 0;
        size_t num_iter = 0;
    #endif

    using COST_TYPE = typename ARC_COSTS::VALUE_TYPE;
    assert (arc_costs.size() == adjacency.size());
    // transitive_closure(adjacency);
    const size_t n = arc_costs.size();

    // cost of the arcs that are present on entry, plus explicit neighbour lists so that the
    // update step below can iterate over neighbours instead of scanning whole rows/columns
    COST_TYPE total_cost = 0;
    std::vector<std::vector<size_t>> successors(n);
    std::vector<std::vector<size_t>> predecessors(n);
    for (size_t i = 0; i < n; ++i)
    for (size_t j = 0; j < n; ++j)
    {
        if (adjacency(i, j))
        {
            total_cost += arc_costs(i, j);
            successors[i].push_back(j);
            predecessors[j].push_back(i);
        }
    }

    // Adds the gain table for the current state of `adjacency` onto `target`
    // (`target` has to be all zeros on entry). Shared by the setup and the DEBUG check.
    auto accumulate_gains = [&] (ARC_COSTS& target)
    {
        // outgoing(i, j) is the sum of the costs of all not yet present arcs from i to successors of j
        ARC_COSTS outgoing(n, 0);
        for (size_t j = 0; j < n; ++j)
        {
            for (size_t k = 0; k < n; ++k)
            {
                if (!adjacency(j, k))
                    continue;
                for (size_t i = 0; i < n; ++i)
                {
                    if (adjacency(i, k))
                        continue;
                    outgoing(i, j) += arc_costs(i, k);
                }
            }
        }
        // with this, the change in objective value for setting i->j to 1 is computed as
        // the sum of all outgoing(k, j) for all predecessors k of i
        for (size_t k = 0; k < n; ++k)
        {
            for (size_t i = 0; i < n; ++i)
            {
                if (!adjacency(k, i))
                    continue;
                for (size_t j = 0; j < n; ++j)
                {
                    target(i, j) += outgoing(k, j);
                }
            }
        }
    };

    ARC_COSTS gain(n, 0);
    accumulate_gains(gain);
    // ARC_COSTS gain = arc_costs;

    // Queue entry: a snapshot of gain(i, j) at the time of the push.
    // Ordered by gain; among equal gains the arc with the smaller i, then the smaller j, is on top.
    struct ArcGain
    {
        size_t i;
        size_t j;
        COST_TYPE gain;

        bool operator<(const ArcGain& other) const
        {
            if (gain != other.gain)
                return gain < other.gain;
            if (i != other.i)
                return i > other.i;
            return j > other.j;
        }
    };

    // Function-local switch: defined -> lazy priority queue, undefined -> full scan of the gain
    // table in every iteration. It is #undef'ed at the end of the function so that it does not
    // leak into every translation unit that includes this header.
    #define PREORDERING_GREEDY_ARC_INSERTION_USE_QUEUE
    #ifdef PREORDERING_GREEDY_ARC_INSERTION_USE_QUEUE
        // The queue is lazy: entries are never updated in place. Whenever a gain changes a fresh
        // entry is pushed, and outdated entries are discarded when they are popped.
        std::priority_queue<ArcGain> queue;
        size_t num_pos_gains = 0;  // queue size right after the last rebuild
        auto enqueue_if_candidate = [&] (size_t a, size_t b)
        {
            if (!adjacency(a, b) && gain(a, b) >= 0)
                queue.push({a, b, gain(a, b)});
        };
        auto reset_queue = [&] ()
        {
            queue = std::priority_queue<ArcGain>();
            for (size_t i = 0; i < n; ++i)
            for (size_t j = 0; j < n; ++j)
                enqueue_if_candidate(i, j);
            num_pos_gains = queue.size();
        };
        reset_queue();
    #else
        auto enqueue_if_candidate = [] (size_t, size_t) {};
    #endif

    #ifdef TIME
        t_setup += clock.elapsed();
        clock.reset();
    #endif

    while (true)
    {
        #ifdef TIME
            ++num_iter;
            clock.reset();
        #endif

        // find the arc with the largest gain
        size_t best_i = 0;
        size_t best_j = 0;
        COST_TYPE best_gain = -1;
        #ifdef PREORDERING_GREEDY_ARC_INSERTION_USE_QUEUE
            // rebuild the queue once it holds far more entries than right after the last rebuild
            if (queue.size() > 10*num_pos_gains)
                reset_queue();

            while (!queue.empty())
            {
                const ArcGain candidate = queue.top();
                queue.pop();
                if (adjacency(candidate.i, candidate.j))
                    continue;  // arc has been inserted in the meantime
                if (candidate.gain == gain(candidate.i, candidate.j))
                {
                    // the snapshot is still up to date
                    best_gain = candidate.gain;
                    best_i = candidate.i;
                    best_j = candidate.j;
                    break;
                }
            }
        #else
            for (size_t i = 0; i < n; ++i)
            for (size_t j = 0; j < n; ++j)
            {
                if (!adjacency(i, j) && gain(i, j) > best_gain)
                {
                    best_gain = gain(i, j);
                    best_i = i;
                    best_j = j;
                }
            }
        #endif

        #ifdef TIME
            t_max += clock.elapsed();
            clock.reset();
        #endif
        if (best_gain < 0)
            break;
        assert (!adjacency(best_i, best_j));
        total_cost += best_gain;
        // insert the arc best_i->best_j, i.e. all missing arcs i->j from
        // predecessors i of best_i to successors j of best_j
        for (size_t i = 0; i < n; ++i)
        {
            if (!adjacency(i, best_i))
                continue;
            for (size_t j = 0; j < n; ++j)
            {
                if (!adjacency(best_j, j))
                    continue;
                if (adjacency(i, j))
                    continue;

                // update the gains
                // 1. For all successors u of i and all predecessors v of j, gain(u, v) is decreased by arc_costs(i, j)

                #ifdef TIME
                    clock.reset();
                #endif

                for (size_t u : successors[i])
                {
                    for (size_t v : predecessors[j])
                    {
                        if (u == v)
                            continue;
                        gain(u, v) -= arc_costs(i, j);
                        enqueue_if_candidate(u, v);
                    }
                }
                // the arc is recorded only after step 1 and before step 2: step 1 iterates the
                // neighbour lists as they were before the insertion, step 2 relies on the new arc
                adjacency(i, j) = 1;
                successors[i].push_back(j);
                predecessors[j].push_back(i);

                #ifdef TIME
                    t_reduce += clock.elapsed();
                    clock.reset();
                #endif

                // 2. For all arcs (u, v) in the preorder,
                //      - gain(j, u) is increased by arc_costs(i, v) if (i, v) is not already in the preorder
                //      - gain(v, i) is increased by arc_costs(u, j) if (u, j) is not already in the preorder
                for (size_t u = 0; u < n; ++u)
                {
                    for (size_t v : successors[u])
                    {
                        if (j != u && !adjacency(i, v))
                        {
                            gain(j, u) += arc_costs(i, v);
                            enqueue_if_candidate(j, u);
                        }
                        if (v != i && !adjacency(u, j))
                        {
                            gain(v, i) += arc_costs(u, j);
                            enqueue_if_candidate(v, i);
                        }
                    }
                }
                #ifdef TIME
                    t_increase += clock.elapsed();
                    clock.reset();
                #endif
            }
        }
        // #define DEBUG
        #ifdef DEBUG
            // consistency check: the incrementally maintained gains must match a recomputation from scratch
            ARC_COSTS gain_debug(n, 0);
            accumulate_gains(gain_debug);

            bool diff = false;
            for (size_t i = 0; i < n; ++i)
            for (size_t j = 0; j < n; ++j)
            {
                if (gain_debug(i, j) != gain(i, j))
                {
                    std::cout << i << "->" << j << " " << gain_debug(i, j) << " " << gain(i, j) << "\n";
                    diff = true;
                }
            }
            if (diff)
                throw std::runtime_error("Error!");
        #endif
    }

    #ifdef TIME
        std::cout 
            << "   t_setup = " << t_setup << "\n"
            << "     t_max = " << t_max << "\n"
            << "  t_reduce = " << t_reduce << "\n"
            << "t_increase = " << t_increase << "\n";
        std::cout << "num iter = " << num_iter << "\n";
    #endif

    #undef PREORDERING_GREEDY_ARC_INSERTION_USE_QUEUE
    return total_cost;
}


}  //namespace preordering