from queue import PriorityQueue
from configurations import *
import pandas as pd
import numpy as np
from numba import njit
from itertools import combinations

# our learning-augmented algorithm
@njit
def run_alg(requests, predictions, L):
    """Serve all requests by greedily following the supplied predictions.

    For every request ``r`` that is not already covered by the current
    configuration, each server ``p`` of the current configuration is tried as
    the one that moves to ``r``.  The swap minimising
    ``predictions[t, candidate] + distances[p, r]`` is taken.

    Args:
        requests: indexed by time step ``t`` in ``range(T)``; each entry is
            iterated to obtain the request points of that step.
        predictions: 2-D table read as ``predictions[t, config_id]``
            (presumably a predicted cost of being in that configuration at
            step ``t`` -- confirm with the caller).
        L: total number of individual requests; asserted to equal the number
            of requests actually processed.

    Returns:
        An ``(L, k)`` int32 array whose row ``j`` is the server configuration
        held right after serving the ``j``-th request.
    """
    trajectory = np.zeros((L, k), dtype=np.int32)
    current_id = start_config_id
    step = 0

    for t in range(T):
        for r in requests[t]:
            current_mask = configBinary[current_id]

            # bit r clear <=> no server currently sits on the requested point
            if (current_mask & (1 << r)) == 0:
                best_cost = 1e9
                best_id = -1

                for p in config_array[current_id]:
                    # id of the configuration obtained by swapping p with r
                    swapped_id = configBinaryInverse[current_mask ^ (1 << r) ^ (1 << p)]
                    cost = predictions[t, swapped_id] + distances[p, r]

                    if cost < best_cost:
                        best_cost = cost
                        best_id = swapped_id

                assert best_id != -1
                current_id = best_id

            trajectory[step, :] = config_array[current_id]
            step += 1

    assert step == L
    return trajectory


# Work function algorithm
@njit
def run_wfa(initial_wf, requests, L):
    """Serve all requests with the work function algorithm.

    The work function is passed through ``update_wf`` (defined outside this
    block) for every request.  When the request ``r`` is not covered by the
    current configuration, the server ``p`` minimising
    ``wf[candidate] + distances[p, r]`` is moved onto ``r``, where
    ``candidate`` is the configuration obtained by swapping ``p`` for ``r``.

    Args:
        initial_wf: work function values before any request, indexed by
            configuration id.
        requests: indexed by time step ``t`` in ``range(T)``; each entry is
            iterated to obtain the request points of that step.
        L: total number of individual requests; asserted to equal the number
            of requests actually processed.

    Returns:
        A pair ``(alg, opt)``: ``alg`` is an ``(L, k)`` int32 array holding the
        configuration after each request, and ``opt`` is the minimum entry of
        the final work function.
    """
    trajectory = np.zeros((L, k), dtype=np.int32)
    work_function = initial_wf
    current_id = start_config_id
    step = 0

    for t in range(T):
        for r in requests[t]:
            # the work function is advanced on every request, served or not
            work_function = update_wf(work_function, r)
            current_mask = configBinary[current_id]

            # bit r clear <=> no server currently sits on the requested point
            if (current_mask & (1 << r)) == 0:
                best_value = 1e9
                best_id = -1

                for p in config_array[current_id]:
                    # id of the configuration obtained by swapping p with r
                    swapped_id = configBinaryInverse[current_mask ^ (1 << r) ^ (1 << p)]
                    value = work_function[swapped_id] + distances[p, r]

                    if value < best_value:
                        best_value = value
                        best_id = swapped_id

                assert best_id != -1
                current_id = best_id

            trajectory[step, :] = config_array[current_id]
            step += 1

    assert step == L

    return trajectory, np.min(work_function)


@njit
def run_double_coverage(requests, L):
    """Serve all requests with the double-coverage rule on the line.

    Starting from a copy of ``start_config_array``:

    * a request left of the first server or right of the last server pulls
      that extreme server onto the request;
    * a request strictly between two adjacent servers moves both of them
      towards it by the smaller of the two gaps, so the nearer one lands on
      the request;
    * a request on a point that already holds a server changes nothing.

    Args:
        requests: indexed by time step in ``range(T)``; each entry is iterated
            to obtain the request points of that step.
        L: total number of individual requests; asserted to equal the number
            of requests actually processed.

    Returns:
        An ``(L, k)`` int32 array whose row ``j`` is the server configuration
        held right after serving the ``j``-th request.
    """
    trajectory = np.zeros((L, k), dtype=np.int32)
    servers = start_config_array.copy()
    last = k - 1
    step = 0

    for t in range(T):
        for r in requests[t]:
            if r not in servers:
                if r < servers[0]:
                    # request lies left of the leftmost server
                    servers[0] = r
                elif r > servers[last]:
                    # request lies right of the rightmost server
                    servers[last] = r
                else:
                    # find the adjacent pair of servers enclosing the request
                    for left in range(last):
                        right = left + 1
                        if r > servers[left] and servers[right] > r:
                            shift = min(r - servers[left], servers[right] - r)
                            servers[left] += shift
                            servers[right] -= shift
                            break

            trajectory[step, :] = servers
            step += 1

    assert step == L
    return trajectory


# check that the suggested solution is valid
@njit
def validate(requests, alg, L):
    """Assert that ``alg`` is a feasible solution for ``requests``.

    Checks that ``alg`` has exactly ``L`` rows and that, for the ``j``-th
    request in processing order, row ``j`` has ``k`` entries, contains the
    requested point, and is sorted in increasing order.

    Args:
        requests: indexed by time step in ``range(T)``; each entry is iterated
            to obtain the request points of that step.
        alg: sequence of configurations, one per request.
        L: expected number of configurations in ``alg``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    assert len(alg) == L

    step = 0
    for t in range(T):
        for r in requests[t]:
            servers = alg[step]
            assert len(servers) == k  # exactly k servers are being used
            assert r in servers  # the request is served by some server
            assert (servers == np.sort(servers)).all()  # positions are in increasing order
            step += 1


# compute the cost of a given solution
@njit
def compute_cost(alg, L):
    """Return the total movement cost of a solution.

    The cost is the move from ``start_config_array`` to ``alg[0]`` plus the
    move between every pair of consecutive configurations
    ``alg[i - 1] -> alg[i]`` for ``i`` in ``1 .. L - 1``.

    Args:
        alg: sequence of at least one configuration, each with ``k`` server
            positions listed in increasing order.
        L: number of configurations in ``alg`` to account for.

    Returns:
        The sum of per-server absolute displacements over all steps.
    """
    def cost_per_step(prev_config, curr_config):
        # Servers never cross each other on the line, so the i-th server of
        # one configuration is matched with the i-th server of the next.
        return sum(
            [abs(prev_config[i] - curr_config[i]) for i in range(k)])

    cost = cost_per_step(start_config_array, alg[0])  # cost paid at first step

    # Start at 1: step 0 has already been charged against the start
    # configuration.  Starting at 0 would compare alg[-1] with alg[0]
    # (negative indices wrap around) and add a spurious last-to-first move.
    for i in range(1, L):
        cost += cost_per_step(alg[i - 1], alg[i])

    return cost



# initialize to cone wf
def initialize_wf(initial_config):
    """Build the initial ("cone") work function rooted at ``initial_config``.

    Every ``k``-subset of the ``n`` points is a configuration.  Two
    configurations are neighbours when one is obtained from the other by
    moving a single server from ``p`` to a free point ``r``, at cost
    ``distances[p][r]``.  A Dijkstra-style search with a priority queue
    computes the cheapest cost of reaching each configuration from
    ``initial_config``.

    Args:
        initial_config: frozenset holding the ``k`` starting server positions.

    Returns:
        A length-``num_config`` int32 array whose entry ``i`` is the cost of
        reaching ``config_array[i]`` from ``initial_config``.
        NOTE(review): the int32 dtype truncates non-integer costs -- confirm
        that ``distances`` is integral.
    """
    all_points = frozenset(range(n))

    best = {frozenset(c): 1e9 for c in combinations(range(n), k)}
    best[initial_config] = 0

    frontier = PriorityQueue()
    frontier.put((0, initial_config))

    while not frontier.empty():
        reached_cost, current = frontier.get()

        # stale queue entry: a cheaper cost was recorded after it was pushed
        if reached_cost != best[current]:
            continue

        for p in current:
            for r in all_points - current:
                neighbour = current - {p} | {r}
                through = reached_cost + distances[p][r]

                if best[neighbour] > through:
                    best[neighbour] = through
                    frontier.put((through, neighbour))

    # convert from dict to an array indexed by configuration id
    w_array = np.full(num_config, 1e9, dtype=np.int32)
    for idx in range(num_config):
        w_array[idx] = best[frozenset(config_array[idx])]

    return w_array


# Work function before any request has arrived, indexed by configuration id:
# entry i is the cost of reaching configuration i from start_config_set.
# Evaluated once, when this module is executed.
initial_wf = initialize_wf(start_config_set)