from typing import Callable, Dict, Optional, List
import numpy as np
from models.utils import add_retraining_cost, compute_strategy_cost
from scipy.optimize import dual_annealing

# Display name of this retraining strategy (also used in optimize()'s error
# message). Presumably read by whatever code collects strategy modules — confirm.
name = "Cumulative Threshold"


def run(C: np.ndarray, cum_threshold: float):
    """Simulate the cumulative-threshold retraining policy on a cost matrix.

    The model trained at step 0 is deployed first. For each step
    ``1..T`` (with ``T = C.shape[0] - 1``) the cost of the currently deployed
    model at that step, ``C[retrain_step, step]``, is added to a running total.
    When the total strictly exceeds ``cum_threshold``, a retrain is recorded at
    that step, the model trained there becomes the deployed one, and the
    running total restarts at zero.

    Args:
        C: Cost matrix whose rows are indexed by the step at which the deployed
            model was (re)trained and whose columns are indexed by the current
            step. Only entries with column > row are read.
        cum_threshold: Accumulated cost above which a retrain is triggered.

    Returns:
        dict with ``"retrains"`` (steps at which a model was trained, always
        starting with 0), ``"num_retrains"`` (retrains after the initial one)
        and ``"parameters"`` (the threshold that was used).
    """
    horizon = C.shape[0] - 1
    schedule = [0]
    accumulated = 0
    for step in range(1, horizon + 1):
        # The most recent retrain step identifies the currently deployed model.
        accumulated += C[schedule[-1], step]
        if accumulated > cum_threshold:
            # Budget exceeded: retrain now and start accumulating afresh.
            schedule.append(step)
            accumulated = 0
    return {
        "retrains": schedule,
        "num_retrains": len(schedule) - 1,
        "parameters": {"cum_threshold": cum_threshold},
    }


def run_cost(t, *args):
    """Objective for the threshold search: total strategy cost at threshold ``t``.

    ``scipy.optimize.dual_annealing`` calls the objective as ``f(x, *args)``,
    where ``x`` is a 1-D array (length 1 here, a single bounded parameter). It
    also receives ``args[0]``, the cost matrix to evaluate on. The threshold is
    unwrapped to a plain float. Without this, ``run`` would compare against a
    1-element array on every step and would record an ndarray in its
    ``"parameters"`` dict. A scalar ``t`` is accepted too.

    Args:
        t: Candidate threshold, as a scalar or a length-1 array.
        *args: ``args[0]`` is the cost matrix passed to ``run`` and
            ``compute_strategy_cost``.

    Returns:
        Whatever ``compute_strategy_cost`` returns for the resulting schedule.
    """
    C = args[0]
    threshold = float(np.ravel(t)[0])
    return compute_strategy_cost(run(C, threshold), C)


def optimize(C, retrain_cost):
    """Search for the cumulative threshold that minimises total strategy cost.

    The retraining cost is folded into the cost matrix via
    ``add_retraining_cost``. ``dual_annealing`` then searches a single
    threshold in ``[min/2, max*T]``, where min and max are taken over the
    upper-triangular entries of that augmented matrix. Each candidate is scored
    with ``run_cost`` on the augmented matrix, and the final schedule is
    produced by ``run`` on the original ``C``.

    Args:
        C: Square cost matrix of size ``(T+1, T+1)`` (assumed square, as
            ``np.triu_indices(T+1)`` is used to index it).
        retrain_cost: Passed through to ``add_retraining_cost``.

    Returns:
        The ``run`` result dict for the best threshold found.

    Raises:
        RuntimeError: If ``dual_annealing`` reports that it did not succeed.
    """
    T = C.shape[0] - 1
    iu = np.triu_indices(T + 1)
    _C = add_retraining_cost(C, retrain_cost)

    lower = _C[iu].min() / 2
    # For non-negative costs, no model can accumulate more than max * T over
    # the horizon, so the top of the range means "never retrain".
    upper = _C[iu].max() * T
    if upper <= lower:
        # Degenerate interval, e.g. a single time step (T == 0, upper == 0) or
        # an all-zero matrix. dual_annealing rejects bounds with
        # lower >= upper, and there is nothing to search anyway.
        return run(C, float(upper))

    res = dual_annealing(run_cost, bounds=[(lower, upper)], args=(_C,))
    if not res.success:
        raise RuntimeError("Did not find optimal Cumulative Threshold")
    # NOTE(review): the threshold is tuned on _C but applied to C. This is
    # equivalent only if add_retraining_cost leaves the strictly-upper-
    # triangular entries (the only ones run() reads) unchanged — confirm.
    return run(C, float(res.x[0]))
