import torch
from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.environments.environment import Environment
import policy_selection_for_inventories.utils
import scipy
import numpy as np

class MPC(Algorithm):
    """Control policy that solves a fresh MILP at every decision step.

    On each call to ``get_control`` the stored environment is copied, the
    copy is re-based on the current time step and state, a MILP is built
    from it with ``build_MILP`` and solved with ``scipy.optimize.milp``.
    Only the first-period control of each product is returned.
    """

    def __init__(self, environment: Environment, horizon: int) -> None:
        """Store the planning inputs.

        Args:
            environment: Environment copied (via ``get_copy``) on every
                ``get_control`` call; the stored instance is never modified
                by this class.
            horizon: Forwarded unchanged to ``build_MILP``.
        """
        self.environment = environment
        self.horizon = horizon
        # Most recent control returned by get_control; None before the first call.
        self.parameter = None

    def get_control(self, t: int, state: torch.Tensor) -> torch.Tensor:
        """Return the first-period control of the MILP built at step ``t``.

        Args:
            t: Current time step; every per-product series is rolled by
                ``-t`` before the MILP is built.
            state: Concatenated product states. It is sliced along its first
                dimension into consecutive chunks of ``product.state_dim``,
                in ``products_list`` order.

        Returns:
            A ``torch.FloatTensor`` with one entry per product, holding the
            solver value of the variable named ``t0_p{i}_control``.

        Raises:
            AssertionError: If the solver returns no solution, even after
                the verbose retry.
        """
        copied_env = self.environment.get_copy()

        # Per-product series that are re-based so index 0 corresponds to
        # step t. NOTE(review): roll() is cyclic, so entries before t wrap
        # around to the end -- presumably intended; confirm against the
        # environment's horizon handling.
        rolled_fields = (
            "demands",
            "purchase_cost",
            "holding_cost",
            "selling_price",
            "outdating_cost",
            "penalty_cost",
        )

        state_idx = 0
        for product in copied_env.products_list:
            product.initial_state = state[state_idx:state_idx + product.state_dim]
            for field in rolled_fields:
                setattr(product, field, getattr(product, field).roll(-t))
            state_idx += product.state_dim

        (global_costs, global_names, global_integralities,
         global_lower_bounds, global_upper_bounds, global_A,
         global_b_l, global_b_u) = policy_selection_for_inventories.utils.build_MILP(
            copied_env, self.horizon
        )

        # Build the solver inputs once; they are shared by both attempts.
        bounds = scipy.optimize.Bounds(global_lower_bounds, global_upper_bounds)
        constraints = scipy.optimize.LinearConstraint(global_A, global_b_l, global_b_u)

        def solve(**extra):
            return scipy.optimize.milp(
                global_costs,
                integrality=global_integralities,
                bounds=bounds,
                constraints=constraints,
                **extra,
            )

        opt = solve()
        if opt.x is None:
            # Re-run with solver output enabled so the failure is visible
            # in the console log.
            opt = solve(options={"disp": True})

        if opt.x is None:
            # Raised explicitly (instead of using an assert statement) so
            # the check survives `python -O`; the exception type is kept
            # for backward compatibility.
            raise AssertionError(
                "Solution not found. (status={}: {})".format(opt.status, opt.message)
            )

        name_to_index = {name: idx for (idx, name) in enumerate(global_names)}
        control = torch.FloatTensor([
            opt.x[name_to_index["t0_p{}_control".format(p_idx)]]
            for p_idx in range(len(copied_env.products_list))
        ])

        self.parameter = control
        return control