from collections import defaultdict

def s_type_policy(env, s_store, s_factory):
    """
    Single-target s-type (order-up-to) policy.

    Every edge ``(i, j)`` whose head ``j`` is a warehouse carries an order that
    would raise ``j`` to the level ``s_store``; the order is served by the
    edge's source ``i``. When a source cannot cover everything requested from
    it, its available stock is split across its edges in proportion to the
    requested quantities. Each factory then produces enough to bring its
    post-shipping stock up to ``s_factory``, capped at its storage capacity.

    Args:
        env: environment exposing ``time``, ``G.edges``, ``scenario.warehouse``,
            ``scenario.factory``, ``scenario.storage_capacities`` and the
            mappings ``acc``, ``arrival_flow``, ``arrival_prod`` indexed as
            ``[t][node]``.
        s_store: target inventory level for every warehouse.
        s_factory: target inventory level for every factory.

    Returns:
        prod: dict {factory: production_qty}
        ship: dict {(source, warehouse): ship_qty}, one entry per edge into a
            warehouse (possibly 0).
    """
    t = env.time
    warehouses = set(env.scenario.warehouse)

    def available(node):
        # Stock usable today at `node`: previous-period inventory plus
        # production arriving now. The same definition is used for the
        # feasibility check, the rationing and the production decision.
        return env.acc[t-1][node] + env.arrival_prod[t][node]

    # Desired order per edge into a warehouse. A warehouse already at or above
    # s_store requests nothing; a negative order is not a shipment.
    desired = {}
    for (i, j) in env.G.edges:
        if j in warehouses:
            desired[(i, j)] = max(0, s_store - (env.acc[t-1][j] + env.arrival_flow[t][j]))

    # Total requested from each source, and what that source can supply.
    requested = defaultdict(float)
    for (src, _), qty in desired.items():
        requested[src] += qty
    supply = {src: available(src) for src in requested}

    ship = {}
    for (src, dst), qty in desired.items():
        avail, total = supply[src], requested[src]
        if avail >= total:
            # All orders from this source are feasible: execute them as-is.
            ship[(src, dst)] = qty
        elif total > 0:
            # Shortage: ration this source's stock proportionally to the requests.
            ship[(src, dst)] = max(0, avail) * (qty / total)
        else:
            ship[(src, dst)] = 0

    prod = dict()
    for factory in env.scenario.factory:
        # Only shipments leaving this factory reduce its stock.
        shipped = sum(qty for (src, _), qty in ship.items() if src == factory)
        shortfall = max(0, s_factory - (available(factory) - shipped))
        prod[factory] = min(shortfall, env.scenario.storage_capacities[factory])

    return prod, ship
    

from collections import defaultdict

def s_type_policy_multi_factory(env, s_store, s_factory, eps=1e-9):
    """
    Multi-factory extension of the s-type (order-up-to) policy.

    Each warehouse is replenished toward its target level ``s_store``. When the
    factories cannot cover the total need, the available supply is rationed in
    proportion to each warehouse's need. Only warehouses that have at least one
    factory edge into them take part in that rationing. Each warehouse's
    shipment is then split across its connected factories in proportion to
    their remaining stock. Finally, every factory produces enough to bring its
    post-shipping stock up to ``s_factory``, capped at its storage capacity.

    Args:
        env: environment exposing ``time``, ``G.edges`` (pairs ``(i, j)``),
            ``scenario.factory``, ``scenario.warehouse``,
            ``scenario.storage_capacities`` and the mappings ``acc``,
            ``arrival_flow``, ``arrival_prod`` indexed as ``[t][node]``.
        s_store: scalar or dict {warehouse: target_s}.
        s_factory: scalar or dict {factory: target_s}.
        eps: tolerance below which quantities are treated as zero.

    Returns:
        prod: dict {factory: production_qty}
        ship: dict {(factory, warehouse): ship_qty}; only edges carrying more
            than ``eps`` appear.
    """
    t = env.time
    factories = list(env.scenario.factory)
    # Ordered and de-duplicated. Warehouses compete for shared factory stock in
    # this order, so iterating a set would make results depend on hash order.
    warehouses = list(dict.fromkeys(env.scenario.warehouse))

    # ---------- helpers ----------
    def get_target_s_store(w):
        return s_store[w] if isinstance(s_store, dict) else s_store

    def get_target_s_factory(f):
        return s_factory[f] if isinstance(s_factory, dict) else s_factory

    # Stock each factory can ship today (never negative).
    av_factory = {
        f: max(0.0, float(env.acc[t][f] + env.arrival_prod[t][f]))
        for f in factories
    }
    total_avail = sum(av_factory.values())

    # Factories with a shipping edge into each warehouse, built in one pass.
    # Assumption: edges (factory -> warehouse) are the shipping arcs.
    suppliers = {w: [] for w in warehouses}
    for (i, j) in env.G.edges:
        if j in suppliers and i in av_factory and i not in suppliers[j]:
            suppliers[j].append(i)

    # ---------- need per warehouse ----------
    desired_w = {}
    for w in warehouses:
        inv_w = float(env.acc[t][w] + env.arrival_flow[t][w])
        desired_w[w] = max(0.0, float(get_target_s_store(w) - inv_w))
    # Unreachable warehouses cannot receive anything. Counting their need here
    # would give them a share of scarce supply that then never ships.
    desired_sum = sum(desired_w[w] for w in warehouses if suppliers[w])

    # ---------- target shipment per warehouse ----------
    # Full fill if supply suffices, otherwise proportional rationing.
    target_to_w = {w: 0.0 for w in warehouses}
    if desired_sum > eps and total_avail > eps:
        for w in warehouses:
            if not suppliers[w]:
                continue
            if total_avail >= desired_sum:
                target_to_w[w] = desired_w[w]
            else:
                target_to_w[w] = total_avail * (desired_w[w] / desired_sum)

    # ---------- split each target across connected factories ----------
    ship = defaultdict(float)
    for w in warehouses:
        remaining = target_to_w[w]
        connected = suppliers[w]
        # Proportional to remaining factory stock. Repeat in case some factories
        # run out and their share must be redistributed.
        while remaining > eps:
            active = [f for f in connected if av_factory[f] > eps]
            denom = sum(av_factory[f] for f in active)
            if denom <= eps:
                break

            shipped_this_round = 0.0
            for f in active:
                qty = min(remaining * (av_factory[f] / denom), av_factory[f])
                if qty > eps:
                    ship[(f, w)] += qty
                    av_factory[f] -= qty
                    shipped_this_round += qty

            remaining -= shipped_this_round
            if shipped_this_round <= eps:
                break

    # ---------- production decision per factory ----------
    # After shipping, av_factory[f] is the stock left at factory f today.
    prod = {}
    for f in factories:
        target = float(get_target_s_factory(f))
        diff = max(0.0, target - av_factory[f])  # raise stock to target
        cap = float(env.scenario.storage_capacities[f])
        prod[f] = min(diff, cap)

    return prod, dict(ship)
