import numpy as np
from typing import List


class InventorySimulator:
    """Single-item inventory model with backlog, exposed as a finite MDP.

    A state ``s`` encodes the inventory level ``s - max_backlog``:
      * states ``0 .. max_backlog - 1`` are backlog (negative inventory),
      * state ``max_backlog`` is zero stock,
      * states above ``max_backlog`` hold positive inventory up to ``max_inventory``.
    An action is an order quantity in ``0 .. max_order``. Demand ``d`` occurs
    with probability ``demands_prob[d]``.
    """

    def __init__(self, demands_prob: np.ndarray,
                 purchase_cost: float, delivery_cost: float, holding_cost: float, backlog_cost: float,
                 sale_price: float,
                 max_inventory: int, max_backlog: int, max_order: int) -> None:
        """Store the model parameters.

        Args:
            demands_prob: Probability of each demand value ``0 .. len - 1``;
                must sum to 1. Any array-like (e.g. a list) is accepted.
            purchase_cost: Cost per unit actually ordered.
            delivery_cost: Fixed cost charged whenever a positive quantity is ordered.
            holding_cost: Cost per unit of positive inventory left after demand.
            backlog_cost: Cost per unit of backlog left after demand.
            sale_price: Revenue per unit sold (served or backlogged).
            max_inventory: Largest representable inventory level.
            max_backlog: Largest representable backlog; demand beyond it is lost.
            max_order: Largest order quantity an action may request.

        Raises:
            AssertionError: if ``demands_prob`` does not sum to 1.
        """
        # Normalise to an ndarray: build_mdp relies on ndarray semantics.
        self.demands_prob = np.asarray(demands_prob, dtype=float)
        self.purchase_cost = purchase_cost
        self.delivery_cost = delivery_cost
        self.holding_cost = holding_cost
        self.backlog_cost = backlog_cost
        self.sale_price = sale_price
        self.max_inventory = max_inventory
        self.max_backlog = max_backlog
        self.max_order = max_order
        assert np.isclose(np.sum(self.demands_prob), 1), f"demands_prob must sum to 1, now: {np.sum(self.demands_prob)}"

    def transaction_demand(self, current_state: int, action_order: int, demand: int):
        """Simulate one period for a fixed demand.

        Args:
            current_state: Encoded state index (inventory + max_backlog).
            action_order: Requested order quantity.
            demand: Demand realised in this period (>= 0).

        Returns:
            ``(reward, next_state)`` where ``reward`` is revenue minus costs and
            ``next_state`` is the encoded index of the resulting inventory level.
        """
        assert current_state >= 0 and current_state < self.state_count()
        assert action_order >= 0 and action_order < self.action_count()
        assert demand >= 0
        # States 0..max_backlog-1 are backlog, max_backlog is zero stock,
        # and larger states hold positive inventory.
        current_inventory = current_state - self.max_backlog
        # Replenishment can only reach up to max_inventory.
        adjusted_order = min(action_order, self.max_inventory - current_inventory)
        # next_inventory cannot be lower than -max_backlog; excess demand is lost.
        next_inventory = max(adjusted_order + current_inventory - demand, -self.max_backlog)
        assert next_inventory <= self.max_inventory
        # Units sold = demand that was served or backlogged (equals `demand`
        # unless the backlog limit clipped it).
        sold_amount = current_inventory - next_inventory + adjusted_order
        revenue = sold_amount * self.sale_price
        # purchase cost + fixed delivery cost + holding cost + backlog cost
        expense = (adjusted_order * self.purchase_cost +
                   (self.delivery_cost if adjusted_order > 0 else 0) +
                   self.holding_cost * max(next_inventory, 0) +
                   self.backlog_cost * -min(next_inventory, 0))
        reward = revenue - expense
        return reward, next_inventory + self.max_backlog

    def state_count(self) -> int:
        """Return the number of encoded states (backlog levels, zero, inventory levels)."""
        return self.max_inventory + self.max_backlog + 1

    def action_count(self) -> int:
        """Return the number of actions (order quantities 0..max_order)."""
        return self.max_order + 1

    def build_mdp(self, n=None, rng=None):
        """Build the transition and reward tensors of the MDP.

        Args:
            n: If None, use the exact demand distribution. Otherwise, for every
                (state, action) pair draw ``n`` fresh demand samples and use
                their empirical frequencies instead. Must be positive.
            rng: Optional object with a NumPy-style ``choice`` method (e.g.
                ``np.random.default_rng(seed)``) used when ``n`` is given;
                defaults to NumPy's global random state.

        Returns:
            ``(P, R)``, both shaped ``(state_count, action_count, state_count)``.
            ``P[s, a, s2]`` is the probability of moving from ``s`` to ``s2``
            under action ``a``; ``R[s, a, s2]`` is the reward of that transition.

        Raises:
            ValueError: if ``n`` is given but not positive.
        """
        if n is not None and n <= 0:
            raise ValueError(f"n must be a positive integer or None, got {n}")
        num_states = self.state_count()
        num_actions = self.action_count()
        num_demands = len(self.demands_prob)
        sample = np.random.choice if rng is None else rng.choice
        P = np.zeros((num_states, num_actions, num_states))
        R = np.zeros((num_states, num_actions, num_states))
        for statefrom in range(num_states):
            for action in range(num_actions):
                if n is None:
                    demands_prob = self.demands_prob
                else:
                    samples = sample(num_demands, size=n, p=self.demands_prob)
                    demands_prob = np.bincount(samples, minlength=num_demands) / n
                for demand, probability in enumerate(demands_prob):
                    reward, stateto = self.transaction_demand(statefrom, action, demand)
                    P[statefrom, action, stateto] += probability
                    # Several demands reach the same stateto only when clipped at
                    # -max_backlog, and those share the same sold amount and costs,
                    # so the reward is a function of (statefrom, action, stateto).
                    # Assign rather than accumulate: summing would multiply the
                    # reward by the number of demands collapsing onto that state.
                    R[statefrom, action, stateto] = reward
        return P, R
