from typing import List, Tuple
import torch
import numpy as np

from torch.nn.functional import relu

from policy_selection_for_inventories.environments.product import Product

class LS_FIFOP_Product(Product):
    """
    Implements a lost sales perishable product with FIFO issuing policy.

    Costs include:
    * Purchase costs (assessed when the order is placed)
    * Holding costs (assessed on all the products available in the inventory just after meeting the demand)
    * Revenue
    * Outdating costs
    * Lost sales penalty costs

    Constraints:
    * lifetime >= 1 (it is the number of periods a product is sellable)
    * leadtime >= 0 (a product ordered at t is received at t+leadtime)
    * lifetime + leadtime >= 2

    For the case lifetime=1, leadtime=0 see the stateless dynamic.

    State layout: the state vector has ``lifetime + leadtime - 1`` entries.
    Once the control (order quantity) is appended, the first ``lifetime``
    entries are the sellable units, index 0 being issued first (FIFO), and the
    remaining ``leadtime`` entries are treated as not yet sellable.

    The cost/demand parameters are tensors indexed by the period ``t``.
    """

    def __init__(self, lifetime: int, leadtime: int, demands: torch.Tensor, purchase_cost: torch.Tensor,
                 holding_cost: torch.Tensor, selling_price: torch.Tensor, outdating_cost: torch.Tensor,
                 penalty_cost: torch.Tensor, state_bound: float):
        """
        Build the product.

        :param lifetime: number of periods a unit is sellable (>= 1).
        :param leadtime: delay between ordering and receiving (>= 0).
        :param demands: demand per period, indexed by ``t``.
        :param purchase_cost: unit purchase cost per period.
        :param holding_cost: unit holding cost per period.
        :param selling_price: unit selling price per period.
        :param outdating_cost: unit outdating cost per period.
        :param penalty_cost: unit lost-sales penalty per period.
        :param state_bound: constant used as big-M in the linear transition constraints.
        :raises AssertionError: if the lifetime/leadtime constraints are violated.
        """
        # The original check comes first so its message is unchanged for
        # inputs it already rejected.
        assert (lifetime + leadtime >= 2), "Error: lifetime + leadtime <= 1."
        # The documented individual constraints were previously unchecked;
        # e.g. lifetime=0, leadtime=2 passed the sum test but breaks slicing.
        assert lifetime >= 1, "Error: lifetime < 1."
        assert leadtime >= 0, "Error: leadtime < 0."

        state_dim = lifetime - 1 + leadtime
        super().__init__(state_dim, torch.zeros(state_dim))

        self.lifetime = lifetime
        self.leadtime = leadtime
        self.demands = demands

        self.purchase_cost = purchase_cost
        self.holding_cost = holding_cost
        self.selling_price = selling_price
        self.outdating_cost = outdating_cost
        self.penalty_cost = penalty_cost

        self.state_bound = state_bound

    def get_copy(self) -> Product:
        """
        Return a new product built from the same constructor arguments.

        The parameter tensors are passed by reference (not cloned).
        NOTE(review): the copy is re-initialised through ``__init__``, so any
        attribute changed after construction (presumably including
        ``initial_state``) is not carried over -- confirm this is intended.
        """
        return LS_FIFOP_Product(
            self.lifetime, self.leadtime, self.demands, self.purchase_cost, self.holding_cost,
            self.selling_price, self.outdating_cost, self.penalty_cost, self.state_bound,
        )

    def transition(self, t: int, state: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        """
        Compute the next state from ``state`` and ``control`` at period ``t``.

        Demand is served oldest-first from the ``lifetime`` sellable slots;
        slot 0 then leaves the system and every other slot moves one position
        towards index 0.

        :return: tensor with ``lifetime + leadtime - 1`` entries.
        """
        post_order_state = torch.cat([state, control])
        # stock_ahead[i] = total quantity in the slots issued before slot i.
        # NOTE: torch.zeros(...) is created on the default device/dtype.
        stock_ahead = torch.cat([torch.zeros(1), post_order_state]).cumsum(dim=0)
        # Demand still unmet when slot i starts being issued.
        unmet_before = relu(self.demands[t] - stock_ahead[:self.lifetime])
        sales = torch.minimum(post_order_state[:self.lifetime], unmet_before)
        # Slot 0 is dropped; units still in the lead-time pipeline are not sold.
        return post_order_state[1:] - torch.cat([sales[1:], torch.zeros(self.leadtime)])

    def cost(self, t: int, state: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        """
        Compute the period-``t`` cost for ``state`` and ``control``.

        With ``x`` the post-order state, ``S = sum(x[:lifetime])`` and ``d`` the
        demand, the cost is::

            purchase * sum(control) + holding * max(S - d, 0) - price * min(S, d)
            + outdating * max(x[0] - d, 0) + penalty * max(d - S, 0)
        """
        # torch.cat (not the torch.concatenate alias) for consistency with transition().
        post_order_state = torch.cat([state, control])
        demand = self.demands[t]
        on_hand = post_order_state[:self.lifetime].sum()

        costs = self.purchase_cost[t] * control.sum()
        # Holding: stock left after serving the demand.
        costs = costs + self.holding_cost[t] * relu(on_hand - demand)
        # Revenue: units actually sold, min(on_hand, demand).
        costs = costs - self.selling_price[t] * torch.where(on_hand <= demand, on_hand, demand)
        # Outdating: what remains in the oldest slot after the demand.
        costs = costs + self.outdating_cost[t] * relu(post_order_state[0] - demand)
        # Lost sales: demand not covered by the sellable stock.
        costs = costs + self.penalty_cost[t] * relu(demand - on_hand)
        return costs

    def linear_decision_variable_structure(self, t: int) -> Tuple[List[str], List[int], List[float], List[float]]:
        """
        Returns a 4-tuple of lists containing:
        * Names of the decision variables
        * Integrality types (1 for the indicator variables, 0 otherwise)
        * Lower bounds
        * Upper bounds

        Variable order: states, control, internal positive parts, internal
        indicators, external indicators, then the four cost variables
        ``holding``, ``neg_revenue``, ``outdating``, ``penalty``.
        """
        n_state = self.lifetime + self.leadtime - 1
        n_aux = self.lifetime - 1

        names = ["state_{}".format(i) for i in range(n_state)]
        names += ["control"]
        names += ["internal_positive_part_{}".format(i) for i in range(n_aux)]
        names += ["internal_positive_part_indicator_{}".format(i) for i in range(n_aux)]
        names += ["external_positive_part_indicator_{}".format(i) for i in range(n_aux)]
        names += ["holding", "neg_revenue", "outdating", "penalty"]

        integralities = [0] * n_state + [0] + [0] * n_aux + [1] * n_aux + [1] * n_aux + [0] * 4

        lower_bounds = [0.0] * len(names)
        # neg_revenue is bounded below by -price * demand; together with the
        # stage constraint this yields neg_revenue >= -price * min(stock, demand).
        lower_bounds[names.index("neg_revenue")] = -self.selling_price.numpy()[t] * self.demands.numpy()[t]

        upper_bounds = (
            [np.inf] * n_state + [np.inf] + [np.inf] * n_aux + [1.0] * n_aux + [1.0] * n_aux + [np.inf] * 4
        )

        return names, integralities, lower_bounds, upper_bounds

    def linear_costs(self, t: int) -> List:
        """
        Return the objective coefficients, in the variable order of
        ``linear_decision_variable_structure``: the purchase cost on the
        control, 1.0 on each of the four cost variables, 0.0 elsewhere.
        """
        n_state = self.lifetime + self.leadtime - 1
        n_aux = self.lifetime - 1

        costs = [0.0] * n_state
        costs += [self.purchase_cost.numpy()[t]]
        costs += [0.0] * (3 * n_aux)
        costs += [1.0] * 4
        return costs

    def stage_linear_constraints(self, t: int) -> Tuple[List[List[float]], List[float], List[float]]:
        """
        Build the within-stage constraints ``b_l <= A x <= b_u`` for period ``t``.

        Rows (in order): for ``t == 0`` only, one equality per state variable
        fixing it to ``self.initial_state``; then one lower-bound row for each
        of ``holding``, ``neg_revenue``, ``outdating`` and ``penalty``.

        Columns 0..lifetime-1 are the sellable slots of the post-order state
        (the control is the column right after the last state variable).

        :return: ``(A, b_l, b_u)`` with ``A`` a list of dense rows.
        """
        A = []
        b_l = []
        b_u = []

        names = self.linear_decision_variable_structure(t)[0]
        column = {name: idx for (idx, name) in enumerate(names)}
        dim = len(names)

        demand = self.demands.numpy()[t]
        holding_cost = self.holding_cost.numpy()[t]
        selling_price = self.selling_price.numpy()[t]
        outdating_cost = self.outdating_cost.numpy()[t]
        penalty_cost = self.penalty_cost.numpy()[t]

        if t == 0:
            # Initial state: state_i == initial_state[i]
            initial_state = self.initial_state.numpy()
            for i in range(self.lifetime + self.leadtime - 1):
                row = [0.0] * dim
                row[column["state_{}".format(i)]] = 1.0
                A.append(row)
                b_l.append(initial_state[i])
                b_u.append(initial_state[i])

        # Holding: holding >= holding_cost * (on_hand - demand)
        row = [0.0] * dim
        row[column["holding"]] = 1.0
        for i in range(self.lifetime):
            row[i] = -holding_cost
        A.append(row)
        b_l.append(-holding_cost * demand)
        b_u.append(np.inf)

        # Negative revenue: neg_revenue >= -selling_price * on_hand
        row = [0.0] * dim
        row[column["neg_revenue"]] = 1.0
        for i in range(self.lifetime):
            row[i] = selling_price
        A.append(row)
        b_l.append(0.0)
        b_u.append(np.inf)

        # Outdating: outdating >= outdating_cost * (oldest_slot - demand)
        row = [0.0] * dim
        row[column["outdating"]] = 1.0
        row[0] = -outdating_cost
        A.append(row)
        b_l.append(-outdating_cost * demand)
        b_u.append(np.inf)

        # Penalty: penalty >= penalty_cost * (demand - on_hand)
        row = [0.0] * dim
        row[column["penalty"]] = 1.0
        for i in range(self.lifetime):
            row[i] = penalty_cost
        A.append(row)
        b_l.append(penalty_cost * demand)
        b_u.append(np.inf)

        return A, b_l, b_u

    def transition_linear_constraints(self, t: int) -> Tuple[List[List[float]], List[float], List[float]]:
        """
        Build the linear constraints linking two consecutive stages.

        Each row has ``2 * dim`` columns, ``dim`` being the number of stage
        variables. Columns ``0..dim-1`` are the period-``t`` variables; column
        ``dim + i`` is presumably ``state_i`` of the following stage (it is the
        only kind of column used in the second half).

        For each sellable slot ``i`` in ``range(lifetime - 1)``, six big-M rows
        (using ``state_bound`` as the bound) encode::

            internal_positive_part_i = max(demand - sum(x[0..i]), 0)
            next_state_i             = max(x[i+1] - internal_positive_part_i, 0)

        For the remaining slots, ``next_state_i == x[i+1]``.

        :return: ``(A, b_l, b_u)`` with ``A`` a list of dense rows.
        """
        A = []
        b_l = []
        b_u = []

        names = self.linear_decision_variable_structure(t)[0]
        column = {name: idx for (idx, name) in enumerate(names)}
        dim = len(names)
        width = 2 * dim

        demand = self.demands.numpy()[t]
        big_m = self.state_bound

        for i in range(self.lifetime - 1):
            positive_part = column["internal_positive_part_{}".format(i)]
            internal_indicator = column["internal_positive_part_indicator_{}".format(i)]
            external_indicator = column["external_positive_part_indicator_{}".format(i)]

            # next_state_i >= x[i+1] - positive_part
            row = [0.0] * width
            row[dim + i] = 1.0
            row[i + 1] = -1.0
            row[positive_part] = 1.0
            A.append(row)
            b_l.append(0.0)
            b_u.append(np.inf)

            # next_state_i <= big_m * external_indicator
            row = [0.0] * width
            row[dim + i] = 1.0
            row[external_indicator] = -big_m
            A.append(row)
            b_l.append(-np.inf)
            b_u.append(0.0)

            # next_state_i <= x[i+1] - positive_part + (big_m + demand) * (1 - external_indicator)
            row = [0.0] * width
            row[dim + i] = 1.0
            row[i + 1] = -1.0
            row[positive_part] = 1.0
            row[external_indicator] = big_m + demand
            A.append(row)
            b_l.append(-np.inf)
            b_u.append(big_m + demand)

            # positive_part >= demand - sum(x[0..i])
            row = [0.0] * width
            row[positive_part] = 1.0
            for j in range(i + 1):
                row[j] = 1.0
            A.append(row)
            b_l.append(demand)
            b_u.append(np.inf)

            # positive_part <= demand * internal_indicator
            row = [0.0] * width
            row[positive_part] = 1.0
            row[internal_indicator] = -demand
            A.append(row)
            b_l.append(-np.inf)
            b_u.append(0.0)

            # positive_part <= demand - sum(x[0..i]) + (i+1) * big_m * (1 - internal_indicator)
            row = [0.0] * width
            row[positive_part] = 1.0
            for j in range(i + 1):
                row[j] = 1.0
            row[internal_indicator] = (i + 1) * big_m
            A.append(row)
            b_l.append(-np.inf)
            b_u.append((i + 1) * big_m + demand)

        # Slots beyond the sellable ones just move one position: next_state_i == x[i+1]
        for i in range(self.lifetime - 1, self.lifetime + self.leadtime - 1):
            row = [0.0] * width
            row[dim + i] = 1.0
            row[i + 1] = -1.0
            A.append(row)
            b_l.append(0.0)
            b_u.append(0.0)

        return A, b_l, b_u