import torch
from torch.nn.functional import relu
from typing import List
from abc import ABC, abstractmethod

from policy_selection_for_inventories.environments.product import Product

class Environment(ABC) :
    """
    Abstract environment class. Encapsulates a list of products.

    Two flat vector layouts are implied by the index bookkeeping built in
    ``__init__``. Both are ordered by position in ``products_list``:

    - the *state* vector: each product contributes ``product.state_dim``
      consecutive entries;
    - the *state-control* ("sc") vector: each product contributes
      ``product.state_dim`` state entries followed by one control entry,
      i.e. ``product.state_dim + 1`` consecutive entries.

    Attributes computed at construction (every list is ordered by product):

    - ``state_indexes[k]``: offset of product ``k``'s block in the state vector.
    - ``state_control_indexes[k]``: offset of product ``k``'s block in the
      sc vector.
    - ``all_state_indexes_of_sc``: sc positions of every state entry.
    - ``all_control_indexes_of_sc``: sc position of each product's control
      entry.
    - ``all_on_hand_sc_indexes``: sc positions of the first
      ``product.lifetime`` entries of each product's block.
    - ``all_just_received_sc_indexes``: sc position of entry
      ``product.lifetime - 1`` of each product's block.

    NOTE(review): these assume ``1 <= product.lifetime <= product.state_dim``.
    Otherwise the "on hand" ranges would run into the control entry or
    another product's block. TODO confirm against ``Product``.
    """
    def __init__(self, products_list : List[Product]) :
        """
        Store the products and precompute their offsets in the state and
        state-control vectors.

        :param products_list: products making up the environment. Each one
            is read for its ``state_dim`` and ``lifetime`` attributes.
        """
        self.products_list = products_list
        self.state_indexes = []
        self.state_control_indexes = []
        self.all_control_indexes_of_sc = []
        self.all_state_indexes_of_sc = []
        self.all_on_hand_sc_indexes = []
        self.all_just_received_sc_indexes = []

        # Running start offsets of the current product's block in the state
        # vector and in the state-control vector respectively.
        state_offset = 0
        sc_offset = 0
        for product in products_list :
            self.state_indexes.append(state_offset)
            self.state_control_indexes.append(sc_offset)

            # Within a product's sc block the state entries come first and the
            # single control entry comes right after them.
            self.all_state_indexes_of_sc.extend(
                range(sc_offset, sc_offset + product.state_dim)
            )
            self.all_control_indexes_of_sc.append(sc_offset + product.state_dim)

            self.all_on_hand_sc_indexes.extend(
                range(sc_offset, sc_offset + product.lifetime)
            )
            self.all_just_received_sc_indexes.append(
                sc_offset + product.lifetime - 1
            )

            # Advance past this product: state_dim state entries, plus one
            # control entry in the sc layout.
            state_offset += product.state_dim
            sc_offset += product.state_dim + 1

    def get_initial_state(self) :
        """
        Return the environment's initial state, for use e.g. by a simulator.

        This is the concatenation of every product's ``initial_state`` tensor,
        in ``products_list`` order.

        NOTE(review): ``torch.concat`` rejects an empty list, so this
        presumably requires at least one product.
        """
        return torch.concat([product.initial_state for product in self.products_list])

    @abstractmethod
    def get_copy(self) :
        """
        Return a copy of this environment.

        Concrete subclasses decide how deep the copy is (e.g. whether
        products are duplicated).
        """
        raise NotImplementedError("")

    @abstractmethod
    def transition(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """
        Apply ``control`` to ``state`` at time step ``t`` and return the
        resulting state.

        Presumably ``state`` follows the state-vector layout described on the
        class. TODO confirm in concrete subclasses.
        """
        raise NotImplementedError("")

    @abstractmethod
    def cost(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """
        Return the cost associated with taking ``control`` in ``state`` at
        time step ``t``.

        The exact cost definition (and its shape) is left to subclasses.
        """
        raise NotImplementedError("")