from typing import List
import torch
from torch.nn.functional import relu

from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment

class Base_Stock(Algorithm):
    """Policy that orders each product up to a given level.

    For every product of the environment, the order quantity at period ``t``
    is ``relu(level - total)``, where ``level`` is
    ``base_stock_levels[product_idx][t]`` and ``total`` is the sum of the
    slice of the state vector belonging to that product.
    """

    def __init__(self, environment: Environment, base_stock_levels: List[torch.Tensor]):
        """Store the environment and the levels, then clear ``parameter``.

        Args:
            environment: object exposing ``products_list``; each product must
                expose ``state_dim``.
            base_stock_levels: one entry per product, indexed by the period
                ``t`` in :meth:`get_control`.
        """
        self.environment = environment
        self.base_stock_levels = base_stock_levels
        self.reset()

    def get_control(self, t: int, state: torch.Tensor) -> torch.Tensor:
        """Return the order quantity of every product for period ``t``.

        The state is read as consecutive slices, one per product, each of
        length ``product.state_dim`` and taken in ``products_list`` order.

        Side effect: ``self.parameter`` is replaced by a ``FloatTensor``
        holding the levels that were applied at this period.

        Args:
            t: period index used to look up each product's level.
            state: state vector covering all products.

        Returns:
            Tensor with one non-negative order quantity per product.
        """
        products = self.environment.products_list
        orders = torch.zeros(len(products))

        # ``self.parameter`` is a plain list while the levels are collected;
        # it is converted to a tensor once the loop is done.
        self.parameter = applied_levels = []

        offset = 0
        for idx, product in enumerate(products):
            level = self.base_stock_levels[idx][t]
            width = product.state_dim
            state_total = state[offset:offset + width].sum()

            # Order only the shortfall; relu clamps a surplus to zero.
            orders[idx] = relu(level - state_total)
            applied_levels.append(level)
            offset += width

        self.parameter = torch.FloatTensor(applied_levels)
        return orders

    def reset(self):
        """Forget the levels recorded by the last :meth:`get_control` call."""
        self.parameter = None