import torch
from typing import List, Tuple

from policy_selection_for_inventories.environments.environment import Environment
from policy_selection_for_inventories.environments.product import Product

class Infinite_Warehouse(Environment) :
    """
    Environment representing an infinite warehouse.

    The warehouse is a collection of products that are handled independently
    of one another. Transitions and costs are the concatenation / sum of the
    per-product ones. The linear-programming pieces (decision variables,
    costs, constraints) are assembled product by product, with no constraint
    coupling two different products (block-diagonal constraint matrices).
    Presumably "infinite" means there is no shared capacity limit
    (TODO confirm).

    Layout conventions relied on by the methods below:

    * ``self.state_indexes[p]`` is the offset of product ``p``'s slice in the
      full state vector. It is presumably filled in by
      ``Environment.__init__`` (not visible here). That slice has length
      ``product.state_dim``.
    * The control vector holds exactly one entry per product, in
      ``products_list`` order.
    * LP decision variables are laid out product after product. Each product
      contributes ``len(product.linear_decision_variable_structure(t)[0])``
      variables.
    """

    def __init__(self, products_list : List[Product]) :
        super().__init__(products_list)

    def get_copy(self) -> Environment :
        """Return a new warehouse built from a copy of every product."""
        new_products_list = [product.get_copy() for product in self.products_list]
        return Infinite_Warehouse(new_products_list)

    def _split_by_product(self, state : torch.Tensor, control : torch.Tensor) :
        """
        Yield ``(product, product_state, product_control)`` for each product.

        ``product_state`` is the product's slice of ``state``, starting at
        ``self.state_indexes[p_idx]`` with length ``product.state_dim``.
        ``product_control`` is the length-1 slice of ``control`` at the
        product's index.
        """
        for p_idx, product in enumerate(self.products_list) :
            start = self.state_indexes[p_idx]
            yield product, state[start:start + product.state_dim], control[p_idx:p_idx + 1]

    def transition(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """
        Apply every product's transition and concatenate the results.

        Args:
            t: time step, forwarded unchanged to each product.
            state: full warehouse state; per-product slices located via
                ``self.state_indexes``.
            control: one control entry per product.

        Returns:
            ``torch.cat`` of the per-product transition outputs, in
            ``products_list`` order.
        """
        return torch.cat([
            product.transition(t, product_state, product_control)
            for product, product_state, product_control in self._split_by_product(state, control)
        ])

    def cost(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """
        Return the sum of every product's cost at time ``t``.

        Each product receives its own state slice and its single control
        entry. If ``products_list`` is empty, the Python float ``0.0`` is
        returned.
        """
        total_cost = 0.0
        for product, product_state, product_control in self._split_by_product(state, control) :
            # Out-of-place addition: avoids in-place updates on autograd-tracked
            # tensors and lets broadcast-compatible cost shapes combine.
            total_cost = total_cost + product.cost(t, product_state, product_control)
        return total_cost

    def linear_decision_variable_structure(self, t:int) -> Tuple :
        """
        Concatenate the products' LP decision-variable descriptions.

        Returns:
            ``(names, integralities, lower_bounds, upper_bounds)``, each a list
            covering all products in order. Every name is prefixed with
            ``p{product_index}_`` so names stay unique across products.
        """
        names, integralities, lower_bounds, upper_bounds = [], [], [], []

        for p_idx, product in enumerate(self.products_list) :
            p_names, p_integralities, p_lower_bounds, p_upper_bounds = product.linear_decision_variable_structure(t)

            names += ["p{}_{}".format(p_idx, name) for name in p_names]
            integralities += p_integralities
            lower_bounds += p_lower_bounds
            upper_bounds += p_upper_bounds

        return names, integralities, lower_bounds, upper_bounds

    def linear_costs(self, t:int) -> List :
        """Concatenate the products' LP cost coefficients, in product order."""
        costs = []
        for product in self.products_list :
            costs += product.linear_costs(t)
        return costs

    def _decision_variable_counts(self, t:int) -> List[int] :
        """Return the number of LP decision variables of each product at time ``t``."""
        return [len(product.linear_decision_variable_structure(t)[0]) for product in self.products_list]

    def stage_linear_constraints(self, t:int) -> Tuple :
        """
        Assemble the block-diagonal stage constraints ``b_l <= A x <= b_u``.

        Each product row (length = that product's variable count) is embedded
        at the product's column offset in a zero row of length ``dim``, where
        ``dim`` is the total number of decision variables.

        Returns:
            ``(A, b_l, b_u)`` as lists.

        Raises:
            ValueError: if a product returns a row whose length does not match
                its number of decision variables. Without this check, the
                slice assignment would silently resize the row and misalign
                the columns.
        """
        counts = self._decision_variable_counts(t)
        dim = sum(counts)
        A, b_l, b_u = [], [], []

        offset = 0
        for p_idx, (product, p_dim) in enumerate(zip(self.products_list, counts)) :
            p_A, p_b_l, p_b_u = product.stage_linear_constraints(t)

            for p_row in p_A :
                if len(p_row) != p_dim :
                    raise ValueError(
                        "product {}: stage constraint row has {} coefficients, expected {}".format(
                            p_idx, len(p_row), p_dim))
                row = [0.0] * dim
                row[offset:offset + p_dim] = p_row
                A.append(row)

            offset += p_dim
            b_l += p_b_l
            b_u += p_b_u

        return A, b_l, b_u

    def transition_linear_constraints(self, t:int) -> Tuple :
        """
        Assemble the block-diagonal transition constraints over ``2 * dim`` columns.

        Each product row must have ``2 * p_dim`` coefficients. The first
        ``p_dim`` are placed at the product's offset in the first half of the
        warehouse row (columns ``[0, dim)``). The last ``p_dim`` are placed at
        the same offset in the second half (columns ``[dim, 2 * dim)``).
        Presumably the two halves are the variables of two consecutive stages
        (TODO confirm). Note that both halves are sized from the variable
        counts at time ``t``.

        Returns:
            ``(A, b_l, b_u)`` as lists.

        Raises:
            ValueError: if a product row does not have exactly ``2 * p_dim``
                coefficients. Such a row would otherwise silently produce a
                ragged, misaligned constraint row.
        """
        counts = self._decision_variable_counts(t)
        dim = sum(counts)
        A, b_l, b_u = [], [], []

        offset = 0
        for p_idx, (product, p_dim) in enumerate(zip(self.products_list, counts)) :
            p_A, p_b_l, p_b_u = product.transition_linear_constraints(t)

            for p_row in p_A :
                if len(p_row) != 2 * p_dim :
                    raise ValueError(
                        "product {}: transition constraint row has {} coefficients, expected {}".format(
                            p_idx, len(p_row), 2 * p_dim))
                row = [0.0] * (2 * dim)
                row[offset:offset + p_dim] = p_row[:p_dim]
                row[dim + offset:dim + offset + p_dim] = p_row[p_dim:]
                A.append(row)

            offset += p_dim
            b_l += p_b_l
            b_u += p_b_u

        return A, b_l, b_u
