from typing import Callable, Dict, Optional, List
import numpy as np
from models.utils import add_retraining_cost, compute_strategy_cost
from scipy.optimize import dual_annealing


# Strategy identifier exposed at module level; presumably read by callers
# elsewhere to label this strategy's results — confirm against usage.
name: str = "Threshold"


def run(C: np.ndarray, threshold: float) -> dict:
    """Simulate the threshold retraining strategy over a cost matrix.

    Starting with model 0, step through time indices ``1 .. C.shape[0] - 1``.
    At each step, if the cost of the currently deployed model at that step
    (``C[model, step]``) strictly exceeds ``threshold``, retrain: the step
    becomes the new deployed model.

    Args:
        C: cost matrix indexed as ``C[model, step]``.
        threshold: retrain whenever the current model's cost exceeds this.

    Returns:
        dict with ``"retrains"`` (list of retrain steps, always starting with
        0), ``"num_retrains"`` (retrains after the initial model) and
        ``"parameters"`` (``{"threshold": threshold}``).
    """
    last_step = C.shape[0] - 1
    retrain_points = [0]
    for step in range(1, last_step + 1):
        # The most recent retrain point is the model currently deployed.
        active_model = retrain_points[-1]
        if C[active_model, step] > threshold:
            retrain_points.append(step)
    return {
        "retrains": retrain_points,
        "num_retrains": len(retrain_points) - 1,
        "parameters": {"threshold": threshold},
    }


def run_cost(t, *args):
    """Objective for ``dual_annealing``: cost of the Threshold strategy at ``t``.

    Args:
        t: candidate threshold. ``dual_annealing`` passes the parameter vector
            as a length-1 array; a plain scalar is accepted as well.
        *args: ``args[0]`` is the cost matrix forwarded to ``run`` and to
            ``compute_strategy_cost``.

    Returns:
        The value of ``compute_strategy_cost`` for the schedule ``run``
        produces at this threshold.
    """
    C = args[0]
    # Unwrap the 1-element parameter vector so ``run`` compares each cost
    # against a scalar (instead of building a 1-element bool array per step
    # and relying on its truthiness) and records a plain float in
    # its "parameters" entry.
    threshold = float(np.ravel(t)[0])
    return compute_strategy_cost(run(C, threshold), C)


def optimize(C, retrain_cost) -> dict:
    """Search for the cost-minimising threshold and return its schedule.

    The retraining cost is folded into ``C`` with ``add_retraining_cost``.
    ``dual_annealing`` then minimises ``run_cost`` over a threshold interval
    spanning the upper-triangular entries (diagonal included) of the
    augmented matrix.

    Args:
        C: square cost matrix indexed ``C[model, step]`` (as used by ``run``).
        retrain_cost: value forwarded to ``add_retraining_cost``.

    Returns:
        The ``run`` result dict for the best threshold found.

    Raises:
        RuntimeError: if ``dual_annealing`` reports it did not succeed.
    """
    iu = np.triu_indices(C.shape[0])
    _C = add_retraining_cost(C, retrain_cost)
    values = _C[iu]
    lo, hi = float(values.min()), float(values.max())
    # Lower bound lies strictly below the smallest entry whenever lo != 0, so
    # the "retrain at every step" schedule is reachable. For lo >= 0 this
    # equals the original lo / 2; unlike lo / 2 it stays below lo when
    # entries are negative instead of overshooting (possibly past hi).
    lower = lo - abs(lo) / 2
    # Upper bound: ``run`` uses a strict ">" on entries with step > model, so
    # a threshold equal to the largest entry never triggers a retrain.
    if lower >= hi:
        # Only reachable when every entry is 0: the interval is a single
        # point, which dual_annealing rejects as inconsistent bounds.
        return run(_C, hi)
    res = dual_annealing(run_cost, bounds=((lower, hi),), args=(_C,))
    if not res.success:
        raise RuntimeError(f"Did not find optimal Threshold: {res.message}")
    return run(_C, float(res.x[0]))
