import numpy as np
from utils import add_retraining_cost
from retrain.src.models.models_helper.dp import online_dp_iterative, online_retrains_iterative


# Human-readable label for the strategy implemented in this module.
# NOTE(review): presumably read by whatever code loads strategy modules — confirm.
name = "Markov"


def general_markov_strategy(C: np.ndarray, retrain_cost: float = 0) -> dict:
    """Derive a retraining schedule from the cost matrix via dynamic programming.

    The retraining cost is folded into the cost matrix with
    ``add_retraining_cost``, the result is handed to ``online_dp_iterative``
    to build the DP table, and ``online_retrains_iterative`` reads the
    retraining points back out of that table.

    Args:
        C (np.ndarray): the upper-triangular cost matrix
        retrain_cost (float, optional): the cost of retraining to be added to
            diagonal of cost matrix. Defaults to 0.

    Returns:
        dict: the strategy result dictionary with keys ``"retrains"`` (the
            sequence returned by ``online_retrains_iterative``),
            ``"num_retrains"`` (``len(retrains) - 1``) and ``"parameters"``
            (always an empty dict for this strategy).
    """
    # Augment the matrix exactly once. The helper is applied unconditionally
    # (also when retrain_cost == 0) so the DP always receives the helper's
    # output rather than the caller's own array.
    augmented = add_retraining_cost(C, retrain_cost)

    # gather the DP table and extract the retraining points from it
    dp = online_dp_iterative(augmented)
    retrains = online_retrains_iterative(dp)

    # NOTE(review): the "- 1" presumably discounts an initial training entry
    # at the head of `retrains` — confirm against online_retrains_iterative.
    return {
        "retrains": retrains,
        "num_retrains": len(retrains) - 1,
        "parameters": {},
    }


def run(C: np.ndarray, retrain_cost: float) -> dict:
    """Greedy threshold strategy: retrain whenever the current cost exceeds the retraining cost.

    Starting from a retrain at step 0, every later step ``t`` is compared
    against the most recent retraining step ``t'``; if ``C[t', t]`` is strictly
    greater than ``retrain_cost``, step ``t`` becomes the new retraining step.

    Args:
        C (np.ndarray): the upper triangular cost matrix; only entries
            ``C[t', t]`` with ``t' < t`` are read
        retrain_cost (float): the threshold a cost entry must strictly exceed
            to trigger a retrain

    Returns:
        dict: the strategy dict with keys ``"retrains"`` (list of step
            indices, always starting with 0), ``"num_retrains"`` (number of
            retrains after the initial one) and ``"parameters"`` (echoes
            ``retrain_cost``).
    """
    schedule = [0]
    n_steps = C.shape[0]
    for step in range(1, n_steps):
        # The last scheduled entry is always the most recent retraining step.
        too_costly = C[schedule[-1], step] > retrain_cost
        if too_costly:
            schedule.append(step)

    outcome = {}
    outcome["retrains"] = schedule
    outcome["num_retrains"] = len(schedule) - 1
    outcome["parameters"] = {"retrain_cost": retrain_cost}
    return outcome


def optimize(*args, **kwargs):
    """Forward all positional and keyword arguments to ``run`` and return its result unchanged."""
    outcome = run(*args, **kwargs)
    return outcome
