import numpy as np
from typing import Optional, List, Tuple


def add_retraining_cost(C: np.ndarray, retrain_cost: float = 0) -> np.ndarray:
    """Return a cost matrix whose diagonal is increased by ``retrain_cost``.

    The diagonal entry ``C[t, t]`` is the cost paid when a retrain happens at
    step ``t``, so shifting it adds a fixed penalty to every retrain.

    Args:
        C (np.ndarray): The upper triangular cost matrix.
        retrain_cost (float, optional): Amount added to each diagonal entry.
            Defaults to 0.

    Returns:
        np.ndarray: A new matrix with the shifted diagonal. When
        ``retrain_cost == 0``, the input ``C`` itself is returned (no copy).
    """
    if retrain_cost != 0:
        # Work on a copy so the caller's matrix is left untouched.
        updated = C.copy()
        np.fill_diagonal(updated, np.diag(updated) + retrain_cost)
        return updated
    return C


def dp_costs_recursive(t: int, p: int, C: np.ndarray) -> float:
    """Compute the DP entry DP[t, p] recursively from the cost matrix.

    DP[t, p] is the total cost up to deployment step ``t`` given that the
    most recent retrain happened at step ``p``. It is the cost of running
    the step-``p`` model from ``p`` through ``t``, plus (for ``p > 0``) the
    cheapest total cost up to step ``p - 1``.

    Args:
        t (int): The timestep of model deployment.
        p (int): The timestep of the previous retrain.
        C (np.ndarray): The upper triangular cost matrix. Shape: (T+1, T+1).

    Returns:
        float: The cost DP[t, p]; ``inf`` when either index is negative.
    """
    # Negative indices mark impossible states.
    if min(t, p) < 0:
        return np.inf
    # Cost of using the model retrained at p for every step p..t.
    window_cost = sum(C[p, col] for col in range(p, t + 1))
    if p > 0:
        # Best achievable cost just before the retrain at p.
        best_before, _ = dp_min_argmin(p - 1, C)
        return window_cost + best_before
    return window_cost


def dp_min_argmin(t: int, C: np.ndarray) -> Tuple[float, int]:
    """Return the minimum and argmin of row DP[t, :] of the DP table.

    Each entry DP[t, p] for ``p`` in ``0..t`` is evaluated recursively via
    ``dp_costs_recursive``. No memoization is done, so the running time grows
    quickly with ``t``; ``dp_iterative`` builds the same table efficiently.

    Args:
        t (int): The timestep of model deployment. Must be non-negative.
        C (np.ndarray): The upper triangular cost matrix. Shape: (T+1, T+1).

    Returns:
        Tuple[float, int]: The smallest total cost up to step ``t`` and the
        timestep of the last retrain that achieves it (first one on ties).

    Raises:
        ValueError: If ``t`` is negative, because the row would be empty.
    """
    if t < 0:
        raise ValueError(f"t must be non-negative, got {t}")
    # Costs for every possible previous-retrain step p <= t.
    row = np.array([dp_costs_recursive(t=t, p=p, C=C) for p in range(t + 1)])
    return row.min(), row.argmin()


def retrains_recursive(t: int, C: np.ndarray) -> List[int]:
    """Trace back the optimal retraining steps ending at timestep ``t``.

    Starting from the best last retrain for step ``t``, this repeatedly looks
    up the best last retrain for the step just before it until step 0 is
    reached.

    Args:
        t (int): The timestep to consider as last.
        C (np.ndarray): The upper-triangular cost matrix. Shape (T+1, T+1).

    Returns:
        List[int]: Retraining timesteps, latest first, ending with 0.
    """
    _, last = dp_min_argmin(t, C)
    schedule = [last]
    while schedule[-1] > 0:
        _, last = dp_min_argmin(schedule[-1] - 1, C)
        schedule.append(last)
    return schedule


def dp_iterative(C: np.ndarray) -> np.ndarray:
    """Iteratively fill the DP table for the retraining problem.

    Rows of the table are deployment time steps ``t`` and columns are the
    time step ``p`` of the most recent retrain. ``dp[t, p]`` is the lowest
    total cost up to ``t`` given that the last retrain happened at ``p``.
    Entries with ``p > t`` are impossible and stay ``inf``, so the table is
    lower triangular.

    Args:
        C (np.ndarray): The upper-triangular cost matrix of shape (T+1, T+1).
            ``C[p, t]`` is the cost at step ``t`` of the model retrained at
            ``p``.

    Returns:
        np.ndarray: The DP table of shape (T+1, T+1).
    """
    T = C.shape[0] - 1
    dp = np.full(fill_value=np.inf, shape=(T + 1, T + 1))
    # Last retrain at step 0: cumulative cost of the step-0 model.
    dp[:, 0] = np.cumsum(C[0, 0:])
    if T == 0:
        # Single time step: column/row 1 do not exist, the table is complete.
        return dp
    # Last retrain at step 1: the only predecessor state is dp[0, 0] == C[0, 0].
    dp[1:, 1] = np.cumsum(C[1, 1:]) + C[0, 0]

    for t in range(2, T + 1):
        # Keep using the model retrained at p (2 <= p < t): extend by C[p, t].
        dp[t, 2:t] = dp[t - 1, 2:t] + C[2:t, t]
        # Retrain at t: best cost up to t-1 over any previous retrain, plus C[t, t].
        dp[t, t] = dp[t - 1, :].min() + C[t, t]
    return dp


def retrains_iterative(dp: np.ndarray, t: Optional[int] = None) -> List[int]:
    """Trace back the optimal retraining steps from a filled DP table.

    Args:
        dp (np.ndarray): The filled lower triangular DP table, shape (T+1, T+1).
        t (Optional[int], optional): The timestep to consider last. Defaults to
            None, which means the final timestep T.

    Returns:
        List[int]: Retraining timesteps, latest first, ending with 0.
    """
    T = dp.shape[0] - 1
    if t is None:
        # Default: trace back from the final timestep.
        t = T
    else:
        assert 0 <= t <= T
    # The best last retrain for step t is the column minimising row t.
    chain = [dp[t].argmin()]
    while chain[-1] > 0:
        chain.append(dp[chain[-1] - 1].argmin())
    return chain


def online_dp_recusive(t: int, C: np.ndarray):
    """Greedy online retraining decision up to timestep ``t``.

    At each step the accumulated cost so far is extended either by retraining
    (``C[s, s]``) or by keeping the current model (``C[last, s]``), whichever
    is strictly cheaper. Ties keep the current model.

    Args:
        t (int): The last timestep to process.
        C (np.ndarray): The upper-triangular cost matrix.

    Returns:
        tuple: ``(cost, last_retrain)``, the accumulated cost through step
        ``t`` and the step of the most recent retrain. ``(inf, -1)`` when
        ``t < 0``.
    """
    if t < 0:
        return np.inf, -1
    # Step 0 always starts with the model trained at 0.
    cost, last = C[0, 0], 0
    for step in range(1, t + 1):
        stay = C[last, step] + cost
        fresh = C[step, step] + cost
        if fresh < stay:
            cost, last = fresh, step
        else:
            cost, last = stay, last
    return cost, last


def online_retrains_recursive(t: int, C: np.ndarray):
    """Trace back the greedy online retraining steps ending at ``t``.

    Args:
        t (int): The timestep to consider as last.
        C (np.ndarray): The upper-triangular cost matrix.

    Returns:
        list: Retraining timesteps, latest first, ending with 0.
    """
    _, last = online_dp_recusive(t, C)
    schedule = [last]
    while schedule[-1] > 0:
        _, last = online_dp_recusive(schedule[-1] - 1, C)
        schedule.append(last)
    return schedule


def online_dp_iterative(C: np.ndarray):
    """Compute the greedy online retraining table for every timestep.

    Args:
        C (np.ndarray): The upper-triangular cost matrix of shape (T+1, T+1).

    Returns:
        np.ndarray: A float array of shape (2, T+2). Row 0 holds the
        accumulated cost per step and row 1 holds the step of the most recent
        retrain. The extra trailing column is a sentinel that stays at
        ``inf`` / ``-2``.
    """
    T = C.shape[0] - 1
    costs = np.full(T + 2, np.inf)
    last_retrain = np.full(T + 2, -2, dtype=int)
    costs[0], last_retrain[0] = C[0, 0], 0

    for step in range(1, T + 1):
        base = costs[step - 1]
        anchor = last_retrain[step - 1]
        stay = C[anchor, step] + base
        fresh = C[step, step] + base
        # Retrain only when strictly cheaper; ties keep the current model.
        if fresh < stay:
            costs[step], last_retrain[step] = fresh, step
        else:
            costs[step], last_retrain[step] = stay, anchor

    return np.array([costs, last_retrain])


def online_retrains_iterative(dp: np.ndarray, t: Optional[int] = None) -> List[int]:
    """Trace back greedy online retraining steps from a precomputed table.

    Args:
        dp (np.ndarray): Table of shape (2, T+2). Row 1 holds the most recent
            retrain step for each timestep.
        t (Optional[int], optional): The timestep to consider last. Defaults to
            None, which means the final timestep T.

    Returns:
        List[int]: Retraining timesteps, latest first, ending with 0.
    """
    last_retrain = dp[1].astype(int)
    # The table carries one extra sentinel column beyond T.
    T = dp.shape[1] - 2
    if t is None:
        t = T
    else:
        assert 0 <= t <= T
    chain = [last_retrain[t]]
    while chain[-1] > 0:
        chain.append(last_retrain[chain[-1] - 1])
    return chain
