#include "decide.h"
#include "inlineframes.h"
#include "inlineheap.h"
#include "inlinequeue.h"
#include "print.h"

#include <inttypes.h>

// Find an unassigned variable by starting at the queue's cached search
// index and following 'prev' links while the visited variable is assigned.
// If the cached variable itself was assigned, the variable finally reached
// is handed to 'kissat_update_queue'.  Returns the index of that variable.
static unsigned last_enqueued_unassigned_variable (kissat *solver) {
  assert (solver->unassigned);
  const value *const values = solver->values;
  const links *const links = solver->links;
  unsigned idx = solver->queue.search.idx;
  if (values[LIT (idx)]) {
    // The cached variable is assigned: step backwards at least once and
    // stop at the first variable without a value.
    for (;;) {
      idx = links[idx].prev;
      assert (!DISCONNECTED (idx));
      if (!values[LIT (idx)])
        break;
    }
    kissat_update_queue (solver, links, idx);
  }
#ifdef LOGGING
  const unsigned stamp = links[idx].stamp;
  LOG ("last enqueued unassigned %s stamp %u", LOGVAR (idx), stamp);
#endif
#ifdef CHECK_QUEUE
  // Everything linked after the result has to be assigned.
  for (unsigned next = links[idx].next; !DISCONNECTED (next);
       next = links[next].next)
    assert (VALUE (LIT (next)));
#endif
  return idx;
}

// Pop assigned variables off the top of the score heap until its maximum
// element is unassigned, and return that element.  Popped variables are
// not put back here.
static unsigned largest_score_unassigned_variable (kissat *solver) {
  const value *const values = solver->values;
  heap *scores = SCORES;
  unsigned res;
  for (res = kissat_max_heap (scores); values[LIT (res)];
       res = kissat_max_heap (scores))
    kissat_pop_max_heap (solver, scores);
#if defined(LOGGING) || defined(CHECK_HEAP)
  const double score = kissat_get_heap_score (scores, res);
#endif
  LOG ("largest score unassigned %s score %g", LOGVAR (res), score);
#ifdef CHECK_HEAP
  // No active unassigned variable may have a larger score than the result.
  for (all_variables (idx)) {
    if (ACTIVE (idx) && !VALUE (LIT (idx))) {
      const double idx_score = kissat_get_heap_score (scores, idx);
      assert (score >= idx_score);
    }
  }
#endif
  return res;
}

// Start (or report the continuation of) a random decision sequence.
//
// Does nothing unless the 'randec' option is set together with the option
// matching the current mode ('randecstable' in stable mode, otherwise
// 'randecfocused').  If 'solver->randec' is already non-zero only a
// message is printed.  Otherwise the sequence counter is incremented,
// 'solver->randec' is set to 'randeclength * LOGN (counter)' and the
// 'randec' conflict limit is updated.
void kissat_start_random_sequence (kissat *solver) {
  if (!GET_OPTION (randec))
    return;

  const bool enabled_in_mode = solver->stable
                                   ? GET_OPTION (randecstable)
                                   : GET_OPTION (randecfocused);
  if (!enabled_in_mode)
    return;

  if (solver->randec) {
    kissat_very_verbose (solver,
                         "continuing random decision sequence "
                         "at %s conflicts",
                         FORMAT_COUNT (CONFLICTS));
    return;
  }

  INC (random_sequences);
  const uint64_t sequences = solver->statistics.random_sequences;
  const unsigned length = GET_OPTION (randeclength) * LOGN (sequences);
  kissat_very_verbose (solver,
                       "starting random decision sequence "
                       "at %s conflicts for %s conflicts",
                       FORMAT_COUNT (CONFLICTS), FORMAT_COUNT (length));
  solver->randec = length;

  UPDATE_CONFLICT_LIMIT (randec, random_sequences, LOGN, false);
}

// Pick a random active unassigned variable, or return 'INVALID_IDX' if no
// random decision should be made now.
//
// Random decisions are skipped when there are no variables, when 'randec'
// is off, or when the option for the current mode ('randecstable' or
// 'randecfocused') is off.  If no sequence is running ('solver->randec' is
// zero) a new one is only started at decision level one and once the
// conflict count has reached the 'randec' conflict limit.
static unsigned next_random_decision (kissat *solver) {
  if (!VARS || !GET_OPTION (randec))
    return INVALID_IDX;

  const bool enabled_in_mode = solver->stable
                                   ? GET_OPTION (randecstable)
                                   : GET_OPTION (randecfocused);
  if (!enabled_in_mode)
    return INVALID_IDX;

  if (!solver->randec) {
    assert (solver->level);
    if (solver->level > 1)
      return INVALID_IDX;

    if (CONFLICTS < solver->limits.randec.conflicts)
      return INVALID_IDX;

    kissat_start_random_sequence (solver);
  }

  // Rejection sampling: redraw until the variable is active and has no
  // value yet.
  unsigned idx;
  do
    idx = kissat_next_random32 (&solver->random) % VARS;
  while (!ACTIVE (idx) || solver->values[LIT (idx)]);
  return idx;
}

// Select the next decision variable.  A random decision is tried first;
// if none is produced the variable comes from the score heap in stable
// mode and from the queue otherwise.  The matching decision counter is
// incremented in each case.
static unsigned kissat_next_decision_variable (kissat *solver) {
#ifdef LOGGING
  const char *type = 0;
#endif
  unsigned res = next_random_decision (solver);
  if (res != INVALID_IDX) {
#ifdef LOGGING
    type = "random";
#endif
    INC (random_decisions);
  } else if (solver->stable) {
#ifdef LOGGING
    type = "maximum score";
#endif
    res = largest_score_unassigned_variable (solver);
    INC (score_decisions);
  } else {
#ifdef LOGGING
    type = "dequeued";
#endif
    res = last_enqueued_unassigned_variable (solver);
    INC (queue_decisions);
  }
  LOG ("next %s decision %s", type, LOGVAR (res));
  return res;
}

// Choose the phase (non-zero 'value') used when deciding variable 'idx'.
//
// Sources are consulted in this order, the first non-zero one wins:
//   1. in focused mode, 'INITIAL_PHASE' or its negation for particular
//      values of the 'switched' counter,
//   2. the target phase (unless 'forcephase' is set; needs the 'target'
//      option, and outside stable mode a 'target' value above one),
//   3. the saved phase (unless 'forcephase' is set; needs 'phasesaving'),
//   4. 'INITIAL_PHASE'.
static inline value decide_phase (kissat *solver, unsigned idx) {
  const bool force = GET_OPTION (forcephase);

  // Both stay null when phases are forced or the feature is disabled.
  value *target = 0;
  value *saved = 0;
  if (!force) {
    if (GET_OPTION (target) &&
        (solver->stable || GET_OPTION (target) > 1))
      target = solver->phases.target + idx;
    if (GET_OPTION (phasesaving))
      saved = solver->phases.saved + idx;
  }

  value res = 0;

  if (!solver->stable) {
    const unsigned selector = (solver->statistics.switched >> 1) & 7;
    if (selector == 1)
      res = INITIAL_PHASE;
    else if (selector == 3)
      res = -INITIAL_PHASE;
  }

  if (!res && target) {
    res = *target;
    if (res) {
      LOG ("%s uses target decision phase %d", LOGVAR (idx), (int) res);
      INC (target_decisions);
    }
  }

  if (!res && saved) {
    res = *saved;
    if (res) {
      LOG ("%s uses saved decision phase %d", LOGVAR (idx), (int) res);
      INC (saved_decisions);
    }
  }

  if (!res) {
    res = INITIAL_PHASE;
    LOG ("%s uses initial decision phase %d", LOGVAR (idx), (int) res);
    INC (initial_decisions);
  }
  assert (res);

  return res;
}

// Make one decision: update the decision counters, open a new decision
// level, pick a variable and a phase, push a frame for the resulting
// literal and assign it as a decision.
void kissat_decide (kissat *solver) {
  START (decide);
  assert (solver->unassigned);
  INC (decisions);
  if (solver->stable) {
    INC (stable_decisions);
  } else {
    INC (focused_decisions);
  }
  solver->level++;
  assert (solver->level != INVALID_LEVEL);
  const unsigned idx = kissat_next_decision_variable (solver);
  const value phase = decide_phase (solver, idx);
  // A negative phase selects the negated literal of the variable.
  const unsigned lit = phase < 0 ? NOT (LIT (idx)) : LIT (idx);
  kissat_push_frame (solver, lit);
  assert (solver->level < SIZE_STACK (solver->frames));
  LOG ("decide literal %s", LOGLIT (lit));
  kissat_assign_decision (solver, lit);
  STOP (decide);
}

// void kissat_decide2 (kissat *solver) {
//   START (decide);
//   assert (solver->unassigned);
//   INC (decisions);
//   if (solver->stable)
//     INC (stable_decisions);
//   else
//     INC (focused_decisions);
//   solver->level++;
//   assert (solver->level != INVALID_LEVEL);
//   int n = kissat_candidate_num(solver);
//   for (int i = 0; i < MIN(n, 2); i++) {
//     const unsigned idx = kissat_next_decision_variable (solver);
//     const value value = decide_phase (solver, idx);
//     unsigned lit = LIT (idx);
//     if (value < 0)
//       lit = NOT (lit);
//     kissat_push_frame (solver, lit);
//     assert (solver->level < SIZE_STACK (solver->frames));
//     LOG ("decide literal %s", LOGLIT (lit));
//     kissat_assign_decision (solver, lit);
//   }
//   STOP (decide);
// }

// int kissat_candidate_num(kissat *solver) {
//   int n = 0;
//   for (unsigned int i = 0; i < LITS; i++) {
//     if (!ACTIVE (i))
//       continue;
//     unsigned lit = LIT (i);
//     if (solver->values[lit])
//       continue;
//     n = n + 1;
//   }
//   return n;
// }

// First half of a split decision: starts the 'decide' profile, updates
// the decision counters and opens a new decision level.  No literal is
// chosen or assigned here and 'STOP (decide)' is not called in this
// function.
void kissat_decide_prestep (kissat *solver, rl_state *state) {
  (void) state; // Not used in this function.
  START (decide);
  assert (solver->unassigned);
  INC (decisions);
  if (solver->stable) {
    INC (stable_decisions);
  } else {
    INC (focused_decisions);
  }
  solver->level += 1;
  assert (solver->level != INVALID_LEVEL);
}

// Rebuild the two per-literal stacks in 'state' from the solver: for each
// literal below 'LITS', 'literal_values' receives its current value and
// 'literal_candidates' receives whether its variable is active and the
// literal has no value.
void kissat_decide_collect (kissat *solver, rl_state *state) {
  CLEAR_STACK (state->literal_values);
  CLEAR_STACK (state->literal_candidates);
  for (unsigned lit = 0; lit < LITS; lit++) {
    const value assigned = solver->values[lit];
    PUSH_STACK (state->literal_values, assigned);
    const bool candidate = ACTIVE (IDX (lit)) && !assigned;
    PUSH_STACK (state->literal_candidates, candidate);
  }
}

// Second half of a split decision: pushes a frame for the decision
// literal, assigns it as a decision and stops the 'decide' profile that
// 'kissat_decide_prestep' started.
//
//   literal       decision literal supplied by the caller; ignored when
//                 'use_original' is true.
//   use_original  when true, the variable and phase are chosen by the
//                 solver's own heuristics instead.
//
// A caller-supplied literal must be unassigned, the same precondition
// 'kissat_internal_assume' asserts for its literal.
void kissat_decide_step (kissat *solver, unsigned literal,
                         bool use_original) {
  unsigned lit = literal;
  if (use_original) { // use the original algorithm
    const unsigned idx = kissat_next_decision_variable (solver);
    const value phase = decide_phase (solver, idx);
    lit = LIT (idx);
    if (phase < 0)
      lit = NOT (lit);
  } else {
    // Catch an already assigned external literal before a frame is
    // pushed for it.
    assert (!solver->values[lit]);
  }
  kissat_push_frame (solver, lit);
  assert (solver->level < SIZE_STACK (solver->frames));
  LOG ("decide literal %s", LOGLIT (lit));
  kissat_assign_decision (solver, lit);
  STOP (decide);
}

// Assume the unassigned literal 'lit': open a new decision level, push a
// frame for the literal and assign it as a decision.  Unlike
// 'kissat_decide' this function does not touch the decision counters or
// the 'decide' profile.
void kissat_internal_assume (kissat *solver, unsigned lit) {
  assert (solver->unassigned);
  assert (!VALUE (lit));

  solver->level += 1;
  assert (solver->level != INVALID_LEVEL);

  kissat_push_frame (solver, lit);
  assert (solver->level < SIZE_STACK (solver->frames));

  LOG ("assuming literal %s", LOGLIT (lit));
  kissat_assign_decision (solver, lit);
}
