#include "reduce.h"
#include "allocate.h"
#include "collect.h"
#include "inline.h"
#include "print.h"
#include "rank.h"
#include "report.h"
#include "trail.h"

#include <inttypes.h>
#include <math.h>

// Decide whether a clause-database reduction is due now.  All three must
// hold: the 'reduce' option is enabled, at least one redundant clause
// exists, and the conflict count has reached the scheduled reduce limit.
bool kissat_reducing (kissat *solver) {
  return GET_OPTION (reduce) && solver->statistics.clauses_redundant &&
         CONFLICTS >= solver->limits.reduce.conflicts;
}

// A reduction candidate: the arena offset of a redundant clause together
// with the key it is radix-sorted by.
typedef struct reducible reducible;

struct reducible {
  uint64_t rank; // sort key, built from negated glue and size in
                 // 'collect_reducibles'
  unsigned ref;  // offset of the clause in 'solver->arena' (in 'ward' units)
};

// Accessor yielding the sort key ('rank') of a reducible.
#define RANK_REDUCIBLE(RED) (RED).rank

// Stack of reduction candidates.
// clang-format off
typedef STACK (reducible) reducibles;
// clang-format on

/* Gather reduction candidates from the arena, starting at 'start_ref', onto
   'reds'.

   First the prefix of irredundant or kept clauses is skipped, and
   'solver->first_reducible' is set to the first clause after it (or to
   INVALID_REF if there is none, in which case false is returned).

   A clause becomes a candidate if it is redundant, not garbage, not a reason
   and not kept.  A candidate whose 'used' counter is non-zero has that
   counter decremented, and is spared this round if its glue does not exceed
   the 'tier2' option.

   The rank stores the negated glue in the upper 32 bits and the negated size
   in the lower 32 bits.  Smaller ranks therefore mean higher glue, and for
   equal glue larger size.

   Returns true iff at least one candidate was pushed. */
static bool collect_reducibles (kissat *solver, reducibles *reds,
                                reference start_ref) {
  assert (start_ref != INVALID_REF);
  assert (start_ref <= SIZE_STACK (solver->arena));
  ward *const arena = BEGIN_STACK (solver->arena);
  const clause *const end = (clause *) END_STACK (solver->arena);
  clause *first = (clause *) (arena + start_ref);
  assert (first < end);

  for (; first != end; first = kissat_next_clause (first))
    if (first->redundant && !first->keep)
      break;

  if (first == end) {
    solver->first_reducible = INVALID_REF;
    LOG ("no reducible clause candidate left");
    return false;
  }

  const reference redundant = (ward *) first - arena;
#ifdef LOGGING
  if (redundant < solver->first_reducible)
    LOG ("updating redundant clauses start from %zu to %zu",
         (size_t) solver->first_reducible, (size_t) redundant);
  else
    LOG ("no update to redundant clauses start %zu",
         (size_t) solver->first_reducible);
#endif
  solver->first_reducible = redundant;

  const unsigned tier2 = GET_OPTION (tier2);
  for (clause *c = first; c != end; c = kissat_next_clause (c)) {
    if (!c->redundant || c->garbage || c->reason || c->keep)
      continue;
    if (c->used) {
      c->used--;
      if (c->glue <= tier2)
        continue;
    }
    assert (!c->garbage);
    assert (kissat_clause_in_arena (solver, c));
    // Keep these exact expressions: '~' is applied before widening, so the
    // size part stays within the low 32 bits.
    const uint64_t negative_size = ~c->size;
    const uint64_t negative_glue = ~c->glue;
    reducible candidate;
    candidate.rank = negative_size | (negative_glue << 32);
    candidate.ref = (ward *) c - arena;
    PUSH_STACK (*reds, candidate);
  }

  const bool found = !EMPTY_STACK (*reds);
  if (!found) {
    LOG ("did not find any reducible redundant clause");
  }
  return found;
}

// Sort key used by 'sort_reducibles': the precomputed 'rank' of a candidate.
#define USEFULNESS RANK_REDUCIBLE

// Radix-sort the candidate stack by 'rank'.  The order is presumably
// ascending (highest glue first, then largest size) -- TODO confirm the
// direction of RADIX_STACK.  'mark_less_useful_clauses_as_garbage' reduces
// candidates from the front of the sorted stack.  'solver' is not used
// explicitly here; presumably RADIX_STACK refers to it internally.
static void sort_reducibles (kissat *solver, reducibles *reds) {
  RADIX_STACK (reducible, uint64_t, *reds, USEFULNESS);
}

/* Mark the leading candidates of the (sorted) stack 'reds' as garbage.

   The number of candidates is 'reducefraction' percent of the stack size,
   clipped to the stack size.  The number of clauses actually marked is added
   to the 'clauses_reduced' statistic. */
static void mark_less_useful_clauses_as_garbage (kissat *solver,
                                                 reducibles *reds) {
  const size_t candidates = SIZE_STACK (*reds);
  const double fraction = GET_OPTION (reducefraction) / 100.0;
  const size_t target = candidates * fraction;
#ifndef QUIET
  statistics *statistics = &solver->statistics;
  const size_t clauses =
      statistics->clauses_irredundant + statistics->clauses_redundant;
  kissat_phase (solver, "reduce", GET (reductions),
                "reducing %zu (%.0f%%) out of %zu (%.0f%%) "
                "reducible clauses",
                target, kissat_percent (target, candidates), candidates,
                kissat_percent (candidates, clauses));
#endif
  // 'target' may exceed the stack size if the fraction is above one.
  const size_t limit = target < candidates ? target : candidates;
  ward *const arena = BEGIN_STACK (solver->arena);
  const reducible *const first = BEGIN_STACK (*reds);
  unsigned reduced = 0;
  for (size_t i = 0; i < limit; i++) {
    clause *const c = (clause *) (arena + first[i].ref);
    assert (kissat_clause_in_arena (solver, c));
    assert (!c->garbage);
    assert (!c->keep);
    assert (!c->reason);
    assert (c->redundant);
    LOGCLS (c, "reducing");
    kissat_mark_clause_as_garbage (solver, c);
    reduced++;
  }
  ADD (clauses_reduced, reduced);
}

/* Decide whether a reduction should compact.  'kissat_reduce' then passes
   'compact' to 'kissat_sparse_collect' and sweeps from offset 0 instead of
   from 'solver->first_reducible'.

   Compaction requires the 'compact' option.  It is chosen when the number
   of inactive variables exceeds 'compactlim' percent of all variables. */
static bool compacting (kissat *solver) {
  if (!GET_OPTION (compact))
    return false;
  const unsigned inactive = solver->vars - solver->active;
  const unsigned limit = GET_OPTION (compactlim) / 1e2 * solver->vars;
  const bool compact = (inactive > limit);
  // Print the comparison that actually holds.  The old message always said
  // "<=", even when compaction was chosen because 'inactive > limit'.
  LOG ("%u inactive variables %.0f%% %s limit %u %.0f%%", inactive,
       kissat_percent (inactive, solver->vars), compact ? ">" : "<=", limit,
       kissat_percent (limit, solver->vars));
  return compact;
}

/* Perform one complete clause-database reduction.

   Steps:
     1. Decide whether to compact; this determines the arena offset 'start'
        (0 when compacting, otherwise 'solver->first_reducible').
     2. Flush units and mark reason clauses from 'start'.
     3. Collect the candidates and sort them.
     4. Mark the least useful fraction of them as garbage.
     5. Sparse-collect from 'start'.

   The same 'start' is used for marking reasons and for the following
   'kissat_sparse_collect' / 'kissat_unmark_reason_clauses' call, even though
   'collect_reducibles' may advance 'solver->first_reducible' meanwhile.

   The next reduce conflict limit is always rescheduled.  Returns 20 if the
   solver became inconsistent (presumably the usual UNSAT result code) and
   0 otherwise. */
int kissat_reduce (kissat *solver) {
  START (reduce);
  INC (reductions);
  kissat_phase (solver, "reduce", GET (reductions),
                "reduce limit %" PRIu64 " hit after %" PRIu64 " conflicts",
                solver->limits.reduce.conflicts, CONFLICTS);
  bool compact = compacting (solver);
  reference start = compact ? 0 : solver->first_reducible;
  if (start != INVALID_REF) {
#ifndef QUIET
    size_t arena_size = SIZE_STACK (solver->arena);
    size_t words_to_sweep = arena_size - start;
    size_t bytes_to_sweep = sizeof (word) * words_to_sweep;
    kissat_phase (solver, "reduce", GET (reductions),
                  "reducing clauses after offset %" REFERENCE_FORMAT
                  " in arena",
                  start);
    kissat_phase (solver, "reduce", GET (reductions),
                  "reducing %zu words %s %.0f%%", words_to_sweep,
                  FORMAT_BYTES (bytes_to_sweep),
                  kissat_percent (words_to_sweep, arena_size));
#endif
    // A false return means flushing units made the solver inconsistent.
    if (kissat_flush_and_mark_reason_clauses (solver, start)) {
      reducibles reds;
      INIT_STACK (reds);
      if (collect_reducibles (solver, &reds, start)) {
        sort_reducibles (solver, &reds);
        mark_less_useful_clauses_as_garbage (solver, &reds);
        RELEASE_STACK (reds);
        kissat_sparse_collect (solver, compact, start);
      } else if (compact)
        // Nothing to reduce, but compaction was requested anyway.
        kissat_sparse_collect (solver, compact, start);
      else
        // No collection: undo the reason marks set above.
        kissat_unmark_reason_clauses (solver, start);
    } else
      assert (solver->inconsistent);
  } else
    kissat_phase (solver, "reduce", GET (reductions), "nothing to reduce");
  UPDATE_CONFLICT_LIMIT (reduce, reductions, SQRT, false);
  REPORT (0, '-');
  STOP (reduce);
  return solver->inconsistent ? 20 : 0;
}

/* Mark the clauses at the 'size' arena offsets in 'refs' as garbage.  The
   number of clauses actually marked is added to the 'clauses_reduced'
   statistic.

   The offsets come from outside through 'kissat_reduce_step', presumably
   chosen by the RL agent from the 'clause_refs' exported by
   'kissat_reduce_collect' (TODO confirm).  An offset listed more than once
   is only marked the first time.  Marking an already garbage clause again
   would trip the '!c->garbage' assertion in debug builds.  In release
   builds it would pass the same clause to 'kissat_mark_clause_as_garbage'
   twice. */
static void mark_ref_clauses_as_garbage (kissat *solver, const unsigned *refs,
                                         unsigned size) {
  unsigned reduced = 0;
  ward *arena = BEGIN_STACK (solver->arena);
  for (unsigned i = 0; i < size; i++) {
    clause *c = (clause *) (arena + refs[i]);
    assert (kissat_clause_in_arena (solver, c));
    if (c->garbage) {
      LOGCLS (c, "skipping already reduced");
      continue;
    }
    assert (!c->keep);
    assert (!c->reason);
    assert (c->redundant);
    LOGCLS (c, "reducing");
    kissat_mark_clause_as_garbage (solver, c);
    reduced++;
  }
  ADD (clauses_reduced, reduced);
}

/* RL hook, first step of a reduction.  It mirrors the opening of
   'kissat_reduce':
     - start the 'reduce' profile and count the reduction,
     - pick the sweep offset (0 when compacting, otherwise
       'solver->first_reducible'),
     - flush units and mark reason clauses from that offset.
   'kissat_reduce_step' later stops the profile.

   NOTE(review): the process exits with status -1 (not a return code) when
   there is no reducible offset.  It also exits when flushing reports
   failure, which according to 'kissat_reduce' means the solver became
   inconsistent, i.e. a legitimate UNSAT outcome.  Callers cannot observe
   either case. */
void kissat_reduce_prestep (kissat *solver, rl_state *state) {
  (void) state; // not used by this step
  START (reduce);
  INC (reductions);
  kissat_phase (solver, "reduce", GET (reductions),
                "reduce limit %" PRIu64 " hit after %" PRIu64 " conflicts",
                solver->limits.reduce.conflicts, CONFLICTS);

  const bool compact = compacting (solver);
  const reference start = compact ? 0 : solver->first_reducible;
  if (start == INVALID_REF)
    exit (-1);

#ifndef QUIET
  const size_t arena_size = SIZE_STACK (solver->arena);
  const size_t words_to_sweep = arena_size - start;
  const size_t bytes_to_sweep = sizeof (word) * words_to_sweep;
  kissat_phase (solver, "reduce", GET (reductions),
                "reducing clauses after offset %" REFERENCE_FORMAT
                " in arena",
                start);
  kissat_phase (solver, "reduce", GET (reductions),
                "reducing %zu words %s %.0f%%", words_to_sweep,
                FORMAT_BYTES (bytes_to_sweep),
                kissat_percent (words_to_sweep, arena_size));
#endif

  const bool flushed = kissat_flush_and_mark_reason_clauses (solver, start);
  if (!flushed)
    exit (-1);
}

/* RL hook, second step of a reduction (after 'kissat_reduce_prestep'):
   export the current reduction candidates into 'state'.

   A candidate is a redundant clause that is not garbage, not a reason and
   not kept.  As in 'collect_reducibles', a candidate with a non-zero 'used'
   counter has that counter decremented.  It is skipped this round if its
   glue is at most the 'tier2' option.

   On return:
     - 'clause_glues[k]' and 'clause_refs[k]' hold the glue and arena offset
       of the k-th exported clause;
     - 'literal_indices' / 'clause_indices' have one entry per literal
       occurrence: the raw literal and the index k of its clause;
     - 'literal_num' is LITS, 'clause_num' is the number of exported clauses,
       and 'edge_num' is the number of literal occurrences.
   All four stacks are cleared and all three counts are set even when no
   candidate exists.  Previously 'literal_num' and 'edge_num' kept stale
   values from an earlier call in that case.

   Unlike 'collect_reducibles', this function does NOT modify
   'solver->first_reducible'.  'kissat_reduce_step' recomputes its sweep
   offset from that field.  'kissat_reduce' always passes
   'kissat_sparse_collect' (or 'kissat_unmark_reason_clauses') the same
   offset that 'kissat_flush_and_mark_reason_clauses' used, and here that
   marking happened in the prestep.  Advancing the field here would make the
   step sweep from a different offset than reasons were marked from.
   Setting it to INVALID_REF when nothing is found would make the step sweep
   from an invalid offset. */
void kissat_reduce_collect (kissat *solver, rl_state *state) {
  CLEAR_STACK (state->clause_glues);
  CLEAR_STACK (state->literal_indices);
  CLEAR_STACK (state->clause_indices);
  CLEAR_STACK (state->clause_refs);

  // Keep the counts consistent with the (now empty) stacks on every path.
  state->literal_num = LITS;
  state->clause_num = 0;
  state->edge_num = 0;

  const bool compact = compacting (solver);
  const reference start_ref = compact ? 0 : solver->first_reducible;
  REDUCE_ASSERT (start_ref != INVALID_REF, state);
  REDUCE_ASSERT (start_ref <= SIZE_STACK (solver->arena), state);
  ward *const arena = BEGIN_STACK (solver->arena);
  clause *start = (clause *) (arena + start_ref);
  const clause *const end = (clause *) END_STACK (solver->arena);
  REDUCE_ASSERT (start < end, state);

  // Skip the prefix of irredundant and kept clauses.
  while (start != end && (!start->redundant || start->keep))
    start = kissat_next_clause (start);
  if (start == end) {
    LOG ("no reducible clause candidate left");
    return;
  }
  LOG ("first reducible clause candidate at offset %zu "
       "(first reducible start %zu left unchanged)",
       (size_t) ((ward *) start - arena), (size_t) solver->first_reducible);

  const unsigned tier2 = GET_OPTION (tier2);
  unsigned c_idx = 0;
  for (clause *c = start; c != end; c = kissat_next_clause (c)) {
    if (!c->redundant)
      continue;
    if (c->garbage)
      continue;
    if (c->reason)
      continue;
    if (c->keep)
      continue;
    if (c->used) {
      c->used--;
      if (c->glue <= tier2)
        continue;
    }
    assert (!c->garbage);
    assert (kissat_clause_in_arena (solver, c));
    const unsigned size = c->size;
    for (unsigned i = 0; i < size; i++) {
      PUSH_STACK (state->literal_indices, c->lits[i]);
      PUSH_STACK (state->clause_indices, c_idx);
    }
    PUSH_STACK (state->clause_glues, c->glue);
    PUSH_STACK (state->clause_refs, (ward *) c - arena);
    c_idx += 1;
  }
  // NOTE(review): 'collect_reducibles' treats "no candidate" as a normal
  // outcome, whereas here it is asserted -- confirm this is intended.
  REDUCE_ASSERT (c_idx > 0, state);

  state->clause_num = c_idx;
  state->edge_num = SIZE_STACK (state->literal_indices);
}

/* RL hook, final step of a reduction.  It runs the following in order:
     - mark the 'size' clauses at the arena offsets in 'refs' as garbage,
     - sparse-collect the arena,
     - reschedule the reduce limit,
     - report,
     - stop the 'reduce' profile started by 'kissat_reduce_prestep'.
   Returns 20 if the solver is inconsistent, else 0.

   NOTE(review): the sweep offset is recomputed here from 'compacting' and
   'solver->first_reducible'.  'kissat_reduce' uses one offset both for
   marking reasons and for sweeping.  This therefore only matches the
   prestep if neither the compaction decision nor 'first_reducible' changed
   since then -- confirm. */
int kissat_reduce_step (kissat *solver, unsigned *refs, unsigned size) {
  mark_ref_clauses_as_garbage (solver, refs, size);

  const bool compact = compacting (solver);
  const reference sweep_start = compact ? 0 : solver->first_reducible;
  kissat_sparse_collect (solver, compact, sweep_start);

  UPDATE_CONFLICT_LIMIT (reduce, reductions, SQRT, false);
  REPORT (0, '-');
  STOP (reduce);

  if (solver->inconsistent)
    return 20;
  return 0;
}
