#include "search.h"
#include "analyze.h"
#include "bump.h"
#include "decide.h"
#include "eliminate.h"
#include "internal.h"
#include "logging.h"
#include "print.h"
#include "probe.h"
#include "propsearch.h"
#include "reduce.h"
#include "reluctant.h"
#include "report.h"
#include "restart.h"
#include "terminate.h"
#include "trail.h"
#include "walk.h"

#include <inttypes.h>

// Begin a new search.
//
// Opens the 'search' profiling section and counts the search. Then it picks
// the initial mode: stable iff the 'stable' option equals 2, focused
// otherwise. It resets the averages, and in stable mode it also calls
// 'kissat_init_reluctant' and 'kissat_update_scores'. Finally it
// initializes the limits and reseeds 'solver->random' from the 'seed'
// option.
//
// Unless compiled with QUIET, it also prints which conflict/decision limits
// are active and opens the profiling and report section of the chosen mode.
static void start_search (kissat *solver) {
  START (search);
  INC (searches);
  REPORT (0, '*');

  const bool stable = GET_OPTION (stable) == 2;
  solver->stable = stable;
  kissat_phase (solver, "search", GET (searches),
                "initializing %s search after %" PRIu64 " conflicts",
                stable ? "stable" : "focus", CONFLICTS);

  kissat_init_averages (solver, &AVERAGES);
  if (solver->stable) {
    kissat_init_reluctant (solver);
    kissat_update_scores (solver);
  }

  kissat_init_limits (solver);

  const unsigned seed = GET_OPTION (seed);
  solver->random = seed;
  LOG ("initialized random number generator with seed %u", seed);

#ifndef QUIET
  const limited *const active = &solver->limited;
  const limits *const bound = &solver->limits;
  if (active->conflicts && active->decisions)
    kissat_very_verbose (
        solver,
        "starting search with decisions limited to %" PRIu64
        " and conflicts limited to %" PRIu64,
        bound->decisions, bound->conflicts);
  else if (active->conflicts)
    kissat_very_verbose (
        solver, "starting search with conflicts limited to %" PRIu64,
        bound->conflicts);
  else if (active->decisions)
    kissat_very_verbose (
        solver, "starting search with decisions limited to %" PRIu64,
        bound->decisions);
  else
    kissat_very_verbose (solver, "starting unlimited search");

  // Open the mode-specific section; 'stop_search' closes it again.
  if (stable) {
    START (stable);
    REPORT (0, '[');
  } else {
    START (focused);
    REPORT (0, '{');
  }
#endif
}

// Tear down a search started by 'start_search'.
//
// It clears any active conflict/decision limit and a pending externally
// flagged termination. It then closes the profiling/report section of the
// current mode and leaves stable mode. Finally it reports the result and
// stops the 'search' profiling section.
//
// 'res' is the search result: 10 is reported as '1', 20 as '0', and
// anything else (e.g. 0 for an interrupted search) as '?'.
static void stop_search (kissat *solver, int res) {
  if (solver->limited.conflicts) {
    LOG ("reset conflict limit");
    solver->limited.conflicts = false;
  }

  if (solver->limited.decisions) {
    LOG ("reset decision limit");
    solver->limited.decisions = false;
  }

  if (solver->termination.flagged) {
    kissat_very_verbose (solver, "termination forced externally");
    solver->termination.flagged = 0;
  }

#ifndef QUIET
  LOG ("search result %d", res);
  if (solver->stable) {
    REPORT (0, ']');
    STOP (stable);
    solver->stable = false;
  } else {
    REPORT (0, '}');
    STOP (focused);
  }
  char type = (res == 10 ? '1' : res == 20 ? '0' : '?');
  REPORT (0, type);
#else
  (void) res;
  // Leaving stable mode is solver state, not reporting. It must therefore
  // also happen when reporting and profiling are compiled out.
  solver->stable = false;
#endif

  STOP (search);
}

// Handle a pending 'iterating' request.
//
// The caller only invokes this when 'solver->iterating' is set, which is
// checked by the assertion. The function clears the flag and emits an 'i'
// report line.
static void iterate (kissat *solver) {
  assert (solver->iterating);
  solver->iterating = false;
  REPORT (0, 'i');
}

// Return true iff a conflict limit is active and the number of conflicts
// has reached it. On a hit, a very-verbose message is printed.
static bool conflict_limit_hit (kissat *solver) {
  const bool hit =
      solver->limited.conflicts &&
      solver->statistics.conflicts >= solver->limits.conflicts;
  if (hit)
    kissat_very_verbose (
        solver, "conflict limit %" PRIu64 " hit after %" PRIu64 " conflicts",
        solver->limits.conflicts, solver->statistics.conflicts);
  return hit;
}

// Return true iff a decision limit is active and the number of decisions
// has reached it. On a hit, a very-verbose message is printed.
static bool decision_limit_hit (kissat *solver) {
  const bool hit =
      solver->limited.decisions &&
      solver->statistics.decisions >= solver->limits.decisions;
  if (hit)
    kissat_very_verbose (
        solver, "decision limit %" PRIu64 " hit after %" PRIu64 " decisions",
        solver->limits.decisions, solver->statistics.decisions);
  return hit;
}

// Main search loop.
//
// After propagation, each iteration does exactly one thing, in this order
// of priority:
// - analyze a conflict
// - handle an iteration request
// - finish with 10 when nothing is unassigned
// - stop on termination or when the conflict limit is hit
// - run one of reduce / mode switch / restart / rephase / probe / eliminate
//   when it is due
// - stop when the decision limit is hit
// - otherwise decide.
//
// The result starts at 20 for an inconsistent solver and 0 otherwise. The
// loop ends once it becomes nonzero; it stays 0 when the loop was left
// early.
int kissat_search (kissat *solver) {
  start_search (solver);
  int res = 0;
  if (solver->inconsistent)
    res = 20;

  while (!res) {
    clause *const conflict = kissat_search_propagate (solver);
    if (conflict) {
      res = kissat_analyze (solver, conflict);
      continue;
    }
    if (solver->iterating) {
      iterate (solver);
      continue;
    }
    if (!solver->unassigned) {
      res = 10;
      continue;
    }
    if (TERMINATED (search_terminated_1))
      break;
    if (conflict_limit_hit (solver))
      break;

    if (kissat_reducing (solver))
      res = kissat_reduce (solver);
    else if (kissat_switching_search_mode (solver))
      kissat_switch_search_mode (solver);
    else if (kissat_restarting (solver))
      kissat_restart (solver);
    else if (kissat_rephasing (solver))
      kissat_rephase (solver);
    else if (kissat_probing (solver))
      res = kissat_probe (solver);
    else if (kissat_eliminating (solver))
      res = kissat_eliminate (solver);
    else if (decision_limit_hit (solver))
      break;
    else
      kissat_decide (solver);
  }

  stop_search (solver, res);
  return res;
}

// RL
//
// State shared by the externally stepped search interface below. The name
// suggests reinforcement learning, given 'rl_state' and the "outside
// heuristic" decide mode; confirm with the caller. In this interface, an
// outside caller performs reductions and decisions between calls.
//
// NOTE(review): these are file-scope statics, not fields of 'kissat'. Only
// one stepped search per process can be driven at a time, and the interface
// is not thread-safe. Consider moving them into the solver struct.

// Current result of the stepped search. It is 0 while the search is
// undecided. The loops below set it to 10 when nothing is unassigned, and
// the init paths set it to 20 when the solver is already inconsistent.
static int res_rl;
// Set by 'kissat_search_until_reduce' when it stopped because a reduction
// is due and must be performed by the caller.
static bool reducing_rl;
// Set by 'kissat_search_until_decide' when it stopped at a decision point
// handed to the caller.
static bool deciding_rl;

// Stepped variant of the search loop that hands reductions to the caller.
//
// It runs the same loop as 'kissat_search' on the shared 'res_rl' but does
// not reduce by itself. As soon as 'kissat_reducing' fires, it sets
// 'reducing_rl' and returns, so the caller can perform the reduction
// externally. It also returns, without setting 'reducing_rl', in these
// cases:
// - on termination
// - when the conflict or decision limit is hit
// - once 'res_rl' is (or becomes) nonzero, which means the search finished
//   or an inprocessing step reported a result.
//
// The return value is always the current 'res_rl'.
//
// 'cont' is kept for interface compatibility only. Resuming after an
// external reduction needs no special entry point: the loop re-checks
// 'res_rl' first and then continues with propagation, exactly as a jump to
// the end of the loop body would.
int kissat_search_until_reduce (kissat *solver, bool cont) {
  (void) cont;
  while (!res_rl) {
    clause *conflict = kissat_search_propagate (solver);
    if (conflict)
      res_rl = kissat_analyze (solver, conflict);
    else if (solver->iterating)
      iterate (solver);
    else if (!solver->unassigned)
      res_rl = 10;
    else if (TERMINATED (search_terminated_1))
      return res_rl;
    else if (conflict_limit_hit (solver))
      return res_rl;
    else if (kissat_reducing (solver)) {
      // Stop here; the caller performs the reduction and calls us again.
      reducing_rl = true;
      return res_rl;
    } else if (kissat_switching_search_mode (solver))
      kissat_switch_search_mode (solver);
    else if (kissat_restarting (solver))
      kissat_restart (solver);
    else if (kissat_rephasing (solver))
      kissat_rephase (solver);
    else if (kissat_probing (solver))
      res_rl = kissat_probe (solver);
    else if (kissat_eliminating (solver))
      res_rl = kissat_eliminate (solver);
    else if (decision_limit_hit (solver))
      return res_rl;
    else
      kissat_decide (solver);
  }
  // Reached whenever the search produced a result (e.g. 10 or 20). This is
  // a normal outcome and must return the result instead of asserting.
  return res_rl;
}

// Stepped variant of the search loop that hands decisions to the caller.
//
// It runs the same loop as 'kissat_search' on the shared 'res_rl' and
// returns the current 'res_rl' in these cases:
// - on termination, or when the conflict or decision limit is hit
// - once 'res_rl' is (or becomes) nonzero
// - with 'deciding_rl' set, at a decision point that belongs to the caller.
//
// Which decision point belongs to the caller depends on 'interval':
// - interval == 0: the first decision point after a restart performed by
//   this call
// - interval > 0: the decision point reached after 'interval - 1'
//   decisions were made internally (so 1 hands over the very next one)
// - interval < 0: never; all decisions are made internally until the
//   search stops for one of the reasons above.
int kissat_search_until_decide (kissat *solver, int interval) {
  bool restarted = false;
  int step = 1;
  while (!res_rl) {
    clause *conflict = kissat_search_propagate (solver);
    if (conflict)
      res_rl = kissat_analyze (solver, conflict);
    else if (solver->iterating)
      iterate (solver);
    else if (!solver->unassigned)
      res_rl = 10;
    else if (TERMINATED (search_terminated_1))
      return res_rl;
    else if (conflict_limit_hit (solver))
      return res_rl;
    else if (kissat_reducing (solver))
      res_rl = kissat_reduce (solver);
    else if (kissat_switching_search_mode (solver))
      kissat_switch_search_mode (solver);
    else if (kissat_restarting (solver)) {
      kissat_restart (solver);
      restarted = true;
    } else if (kissat_rephasing (solver))
      kissat_rephase (solver);
    else if (kissat_probing (solver))
      res_rl = kissat_probe (solver);
    else if (kissat_eliminating (solver))
      res_rl = kissat_eliminate (solver);
    else if (decision_limit_hit (solver))
      return res_rl;
    else if ((interval == 0 && restarted) ||
             (interval > 0 && step >= interval)) {
      // Hand this decision to the caller.
      deciding_rl = true;
      return res_rl;
    } else {
      kissat_decide (solver);
      // Only count decisions when a positive interval is requested. Here
      // 'step < interval' holds before the increment, so it cannot
      // overflow. Counting unconditionally would overflow the signed
      // counter (undefined behavior) in long runs with interval <= 0.
      if (interval > 0)
        step++;
    }
  }
  return res_rl;
}

// Forward 'num' externally supplied entries of 'indices' to
// 'kissat_bump_influence'. These are presumably variable indices; confirm
// against 'kissat_bump_influence'.
//
// 'state' is accepted for API symmetry with the other step functions and
// is not used.
void kissat_search_influence (kissat *solver, rl_state *state, unsigned *indices, unsigned num) {
  (void) state;
  kissat_bump_influence (solver, indices, num);
}

// One step of the externally driven reduction interface.
//
// - size == 0: initialization call. Starts a new search and sets 'res_rl'
//   to 20 for an inconsistent solver, 0 otherwise.
// - size > 0: applies 'kissat_reduce_step' to the 'size' clause references
//   in 'refs' and stores its result in 'res_rl'.
//
// Afterwards the search runs until the next reduction is due (or it
// stops). If a reduction is due, 'state' is filled via
// 'kissat_reduce_prestep', 'kissat_reduce_collect' and
// 'kissat_decide_collect'. Otherwise 'state->literal_num' is set to 0 to
// signal that the search is over; the caller then closes it with
// 'kissat_search_close'.
//
// NOTE(review): a genuinely empty selection (size == 0) cannot be told
// apart from the initialization call and would restart the search.
void kissat_search_reduce_step (kissat *solver, rl_state *state, unsigned *refs, unsigned size) {
  const bool initializing = (size == 0);
  if (initializing) {
    start_search (solver);
    res_rl = solver->inconsistent ? 20 : 0;
  } else
    res_rl = kissat_reduce_step (solver, refs, size);

  reducing_rl = false;
  res_rl = kissat_search_until_reduce (solver, !initializing);

  if (reducing_rl) {
    kissat_reduce_prestep (solver, state);
    kissat_reduce_collect (solver, state);
    kissat_decide_collect (solver, state);
  } else
    state->literal_num = 0; // finished: nothing left to hand out
}

// One step of the externally driven decision interface.
//
// mode:
//   0      initialization call: start a new search and set 'res_rl' to 20
//          for an inconsistent solver, 0 otherwise
//   1      decide via 'kissat_decide_step' using the original heuristic
//   other  decide via 'kissat_decide_step' using the outside heuristic,
//          passing 'literal'
//
// interval (see 'kissat_search_until_decide'):
//   > 0    return at the next (interval)th decision
//   = 0    return at the first decision after the next restart
//   < 0    return only after the search is done
//
// If the search stopped at a decision point for the caller, 'state' is
// filled via 'kissat_decide_prestep', 'kissat_reduce_collect' and
// 'kissat_decide_collect'. Otherwise 'state->literal_num' is set to 0 to
// signal that the search is over; the caller then closes it with
// 'kissat_search_close'.
void kissat_search_decide_step (kissat *solver, rl_state *state, unsigned literal, int mode, int interval) {
  if (mode) {
    const bool use_original = (mode == 1);
    kissat_decide_step (solver, literal, use_original);
  } else {
    start_search (solver);
    res_rl = solver->inconsistent ? 20 : 0;
  }

  deciding_rl = false;
  res_rl = kissat_search_until_decide (solver, interval);

  if (deciding_rl) {
    kissat_decide_prestep (solver, state);
    kissat_reduce_collect (solver, state);
    kissat_decide_collect (solver, state);
  } else
    state->literal_num = 0; // finished: nothing left to hand out
}

// Finish a stepped search: run 'stop_search' with the last stepped result
// and return that result.
int kissat_search_close (kissat *solver) {
  const int result = res_rl;
  stop_search (solver, result);
  return result;
}