from configurations import *
from constants import *
import pandas as pd
import numpy as np
from numba import njit
from itertools import combinations
from algorithms import *

@njit
def compute_optimal_duals(requests):
    """Compute the future work function at the end of every time block of one day.

    The day's requests are processed backwards in time, starting from an
    all-zero vector after the last block (``w[T-1]``). For each block ``i``
    (from ``T-1`` down to ``1``), the requests of block ``i`` are fed to
    ``update_wf`` from last to first, and the resulting vector becomes
    ``w[i-1]``. Hence ``w[k]`` reflects only the requests of blocks
    ``k+1 .. T-1``; the requests of block 0 are never read.

    Args:
        requests: indexable sequence of per-block request arrays (as built by
            ``create_requests_for_today`` in ``train``); assumed to have
            length ``T`` -- TODO confirm for other callers.

    Returns:
        A ``(T, num_config)`` array. If ``TAKE_AVERAGE_WF`` is truthy, row
        ``i-1`` is the mean of every work-function vector visited while
        processing block ``i`` (the starting vector ``w[i]`` plus one vector
        per request), and row ``T-1`` stays zero. Otherwise the int32 array
        ``w`` described above is returned.
    """
    w = np.zeros((T, num_config), dtype=np.int32)
    average_w = np.zeros((T, num_config))

    # Boundary condition: no future requests after the last block.
    # (Already zero from np.zeros above; kept explicit.)
    w[T-1] = np.zeros(num_config, dtype=np.int32)

    for i in range(T-1, 0, -1):
        # Copy so the stored row w[i] stays untouched even if update_wf
        # happens to modify its argument in place (not visible here).
        curr_w = w[i].copy()
        # int64 accumulator, wider than the int32 rows -- presumably to avoid
        # overflow while summing many vectors.
        sum_w = np.zeros(num_config, dtype=np.int64)
        sum_w += curr_w
        to_divide = 1  # number of vectors accumulated in sum_w

        # Walk this block's requests from last to first (backwards in time).
        # update_wf is defined elsewhere; presumably it applies one request
        # to the work function -- confirm in its module.
        for r in requests[i][::-1]:
            curr_w = update_wf(curr_w, r)
            sum_w += curr_w
            to_divide += 1

        # Stored into the int32 array w (values are cast on assignment).
        w[i-1] = curr_w
        average_w[i-1] = sum_w / to_divide

    # NOTE(review): TAKE_AVERAGE_WF is a module-level global; numba typically
    # freezes globals at compile time, so changing it after the first call
    # likely has no effect -- confirm.
    if TAKE_AVERAGE_WF:
        return average_w
    else:
        return w


# Assume for now that the Markov chain has a single state (i.e. the predictions depend only on the current time block).
# Compute the average future work function (wf) over all training days.
def learn(optimal_duals_by_day):
    """Average the per-day future work functions into a single predictor.

    Args:
        optimal_duals_by_day: non-empty sequence of equally shaped 2-D arrays
            (one per training day, e.g. as returned by
            ``compute_optimal_duals``).

    Returns:
        float64 array with the same shape as each input, holding the
        element-wise mean over all days.

    Raises:
        ValueError: if no days are given (the mean would be undefined, and the
            previous implementation silently produced NaNs).
    """
    duals = list(optimal_duals_by_day)
    if not duals:
        raise ValueError("learn() needs at least one day of optimal duals")

    # Accumulate in float64 one day at a time, in input order, so the sum is
    # bit-identical to an element-wise sequential loop.
    total = np.zeros(np.shape(duals[0]), dtype=np.float64)
    for w in duals:
        total += w

    return total / len(duals)

def create_requests_for_today(df_day, day):
    """Split one day's requests into consecutive 15-minute blocks.

    Each block covers the half-open window ``[t_start, t_start + 15min)``, so
    a sample stamped exactly on a block boundary belongs to the later block
    only. (Label slicing ``df.loc[a:b]`` includes both endpoints and would put
    such a sample into two blocks.)

    Args:
        df_day: rows for a single day, indexed by timestamp with a
            ``'requests'`` column (schema assumed from usage -- confirm with
            the data loader).
        day: Timestamp at the start of the day.

    Returns:
        List of ``T`` numpy arrays, one per block. If ``USE_MOST_REQ_PER_MIN``
        is set, each array holds the most frequent request of every minute in
        the block (0 for minutes without usable data); otherwise it holds the
        raw request values in the block.
    """
    next_day = day + pd.Timedelta(days=1)
    block_length = pd.Timedelta(minutes=15)
    timestamps = df_day.index
    requests = []

    def _mode_or_zero(minute):
        # Most frequent value in a one-minute bin. Series.mode() drops NaN and
        # returns an empty Series for empty/all-NaN bins; use 0 there instead
        # of failing on .iloc[0].
        modes = minute.mode()
        return modes.iloc[0] if not modes.empty else 0

    # create blocks of requests for the current day, where each block is 15 min long
    for t_start in pd.date_range(start=day, end=next_day, freq="15min", inclusive="left"):
        t_end = t_start + block_length
        block = df_day.loc[(timestamps >= t_start) & (timestamps < t_end)]

        if USE_MOST_REQ_PER_MIN:
            most_freq_requested = block.resample("min")['requests'].agg(_mode_or_zero).values
            requests.append(most_freq_requested.copy())
        else:
            requests.append(block['requests'].values)

    assert len(requests) == T
    return requests


def train(df, start_day, end_day):
    """Learn average future work functions from the days in [start_day, end_day).

    For every training day, build that day's 15-minute request blocks,
    compute the offline future work functions for them with
    ``compute_optimal_duals``, and hand the per-day results to ``learn``.
    """
    span_in_days = (end_day - start_day).days

    def _duals_for(offset):
        current = start_day + pd.Timedelta(days=offset)
        day_frame = df.loc[str(current.date())]
        blocks = create_requests_for_today(day_frame, current)
        return compute_optimal_duals(blocks)

    # One entry per training day: that day's future work function for every time block.
    per_day_duals = [_duals_for(offset) for offset in range(span_in_days)]
    return learn(per_day_duals)


def test(df, start_day, end_day, f):
    """Evaluate the learned algorithm and the baselines on each day in [start_day, end_day).

    Args:
        df: request data indexed by timestamp.
        start_day, end_day: Timestamps bounding the evaluation days (end exclusive).
        f: learned predictions (as returned by ``train``); rows ``0..T-1`` are
            passed to ``run_alg`` as a fresh copy every day.

    Returns:
        DataFrame with one row per day and columns ``'alg'``, ``'wfa'``,
        ``'dc'`` (costs of each algorithm) and ``'opt'`` (the offline optimum
        reported by ``run_wfa``).
    """
    span_in_days = (end_day - start_day).days

    def _validated_cost(blocks, schedule, length):
        # Check the produced schedule against the requests, then price it.
        validate(blocks, schedule, length)
        return compute_cost(schedule, length)

    rows = []
    for offset in range(span_in_days):
        current = start_day + pd.Timedelta(days=offset)
        day_frame = df.loc[str(current.date())]

        blocks = create_requests_for_today(day_frame, current)
        total_requests = sum(len(blocks[b]) for b in range(T))

        # Built anew each day so that nothing run_alg does to it leaks into later days.
        predicted = np.array([f[b, :] for b in range(T)])

        row = {}

        learned_schedule = run_alg(blocks, predicted, total_requests)
        row['alg'] = _validated_cost(blocks, learned_schedule, total_requests)

        # WFA also yields the offline optimal value as a by-product.
        wfa_schedule, offline_opt = run_wfa(initial_wf, blocks, total_requests)
        row['wfa'] = _validated_cost(blocks, wfa_schedule, total_requests)

        dc_schedule = run_double_coverage(blocks, total_requests)
        row['dc'] = _validated_cost(blocks, dc_schedule, total_requests)

        row['opt'] = offline_opt
        rows.append(row)

    return pd.DataFrame(rows)