import torch
from torch.nn.functional import relu
from typing import Iterable, List, Tuple

from policy_selection_for_inventories.environments.environment import Environment
from policy_selection_for_inventories.environments.infinite_warehouse import Infinite_Warehouse
from policy_selection_for_inventories.environments.product import Product

class Finite_Warehouse_Pre(Infinite_Warehouse) :
    """
    Environment representing a warehouse of finite capacity.

    It behaves like ``Infinite_Warehouse`` except that, before the transition and the
    cost of a stage are evaluated, the post-order vector (each product's state slice
    followed by its control) is thresholded so that its volume does not exceed
    ``total_volume``. Excess units are discarded and charged an overflow cost
    (see ``threshold_state_control``).
    """
    def __init__(self, products_list : List[Product], volumes: Iterable[float], total_volume: float, overflow_cost: Iterable[Iterable[float]]) :
        """
        Args:
            products_list: products stored in the warehouse, forwarded to ``Infinite_Warehouse``.
            volumes: volume of one unit of each product; ``volumes[k]`` belongs to
                ``products_list[k]``.
            total_volume: capacity of the warehouse, in the same unit as ``volumes``.
            overflow_cost: per-unit cost of discarded stock; ``overflow_cost[k][t]`` is
                used for product ``k`` at stage ``t``.
        """
        super().__init__(products_list)
        self.volumes = volumes
        self.total_volume = total_volume
        self.overflow_cost = overflow_cost
        # Unit volume of every entry of the concatenated state-control vector: each
        # product's (state_dim + 1)-long segment is filled with that product's volume.
        self.all_volumes_as_sc = torch.zeros(sum(product.state_dim + 1 for product in self.products_list))
        for k_prime, product_varying in enumerate(self.products_list) :
            start = self.state_control_indexes[k_prime]
            self.all_volumes_as_sc[start:start + product_varying.state_dim + 1] = self.volumes[k_prime]

    def get_copy(self) -> Environment :
        """
        Return a copy of this environment whose products are copied.

        ``volumes``, ``total_volume`` and ``overflow_cost`` are shared with the copy,
        not duplicated.
        """
        new_products_list = [product.get_copy() for product in self.products_list]
        # overflow_cost is a required constructor argument and must be forwarded.
        return Finite_Warehouse_Pre(new_products_list, self.volumes, self.total_volume, self.overflow_cost)

    def threshold_state_control(self, t:int, state:torch.Tensor, control:torch.Tensor) -> Tuple :
        """
        Cap the post-order inventory so that it fits in the warehouse.

        The post-order vector is the concatenation, for every product, of its state
        slice followed by its control. Its volume is the sum over products of
        ``volumes[p]`` times the sum of the first ``lifetime`` entries of that
        product's segment. If it exceeds ``total_volume``, the excess is removed from
        the entry at offset ``lifetime - 1`` of each product's segment, draining
        products in list order: product ``p`` loses
        ``min(available_p, max(excess - volume_available_before_p, 0) / volumes[p])``
        units, where ``volume_available_before_p`` is the volume held by that same
        entry for all earlier products (before any removal).

        Args:
            t: stage index, used to read ``overflow_cost[p][t]``.
            state: concatenated states of all products (sliced with ``state_indexes``).
            control: one control value per product.

        Returns:
            ``(thresholded_state, thresholded_control, overflow_cost_incurred)`` where
            the cost is ``sum_p overflow_cost[p][t] * units_removed_p``.
        """
        # Build the post-order vector [state_0, control_0, state_1, control_1, ...]
        # and accumulate its volume.
        post_order_state = torch.zeros(0)
        post_state_volume = 0.0
        for p_idx, product in enumerate(self.products_list) :
            post_order_state = torch.cat([post_order_state, state[self.state_indexes[p_idx]:self.state_indexes[p_idx]+product.state_dim], control[p_idx:p_idx+1]])
            start = self.state_control_indexes[p_idx]
            post_state_volume += self.volumes[p_idx]*post_order_state[start:start + product.lifetime].sum()

        # as_tensor keeps torch.where valid even when no product contributed a tensor
        # volume (it is a no-op when the difference is already a tensor).
        volume_to_remove = torch.as_tensor(post_state_volume - self.total_volume)
        volume_to_remove = torch.where(volume_to_remove >= 0.0, volume_to_remove, 0.0)
        overflow_cost_incurred = 0.0

        if volume_to_remove == 0.0 :
            return post_order_state[self.all_state_indexes_of_sc], post_order_state[self.all_control_indexes_of_sc], overflow_cost_incurred

        new_state_control = post_order_state.clone()
        # Volume still exceeding capacity once the removable entries of all previous
        # products are accounted for (computed from the unmodified quantities).
        remaining = volume_to_remove
        for p_idx, product in enumerate(self.products_list) :
            removable_idx = self.state_control_indexes[p_idx] + product.lifetime - 1
            available = post_order_state[removable_idx]

            to_remove = torch.where(remaining >= 0.0, remaining, 0.0)
            new_value = available - to_remove/self.volumes[p_idx]
            # Cannot remove more units than are available.
            new_value = torch.where(new_value >= 0.0, new_value, 0.0)

            overflow_cost_incurred += self.overflow_cost[p_idx][t]*(available - new_value)
            new_state_control[removable_idx] = new_value
            remaining = remaining - self.volumes[p_idx]*available

        return new_state_control[self.all_state_indexes_of_sc], new_state_control[self.all_control_indexes_of_sc], overflow_cost_incurred

    def transition(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """Apply the parent transition to the thresholded state and control."""
        thresh_state, thresh_control, _ = self.threshold_state_control(t, state, control)
        return super().transition(t, thresh_state, thresh_control)

    def cost(self, t:int, state:torch.Tensor, control:torch.Tensor) -> torch.Tensor :
        """
        Return the stage cost of ``(state, control)`` in the finite warehouse.

        This is the parent cost evaluated on the thresholded state and control, plus
        the overflow cost of discarded units, plus ``purchase_cost[t]`` times the
        number of control units removed by the threshold.
        """
        thresh_state, thresh_control, overflow_cost_incurred = self.threshold_state_control(t, state, control)

        # The parent cost only sees the thresholded control; ordered units that were
        # discarded are presumably still paid for, so their purchase cost is added back.
        additional_control_costs = 0.0
        for p_idx, product in enumerate(self.products_list) :
            additional_control_costs += (control[p_idx]-thresh_control[p_idx])*product.purchase_cost[t]

        return overflow_cost_incurred + additional_control_costs + super().cost(t, thresh_state, thresh_control)

    def stage_linear_constraints(self, t:int) -> Tuple :
        """Not implemented yet for the finite warehouse."""
        # TODO: start from super().stage_linear_constraints(t) and add the linear
        # warehouse volume constraint(s) to (A, b_l, b_u).
        raise NotImplementedError("")
