
from math import ceil
import torch
from typing import List
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from policy_selection_for_inventories.algorithms.algorithm import Algorithm
from policy_selection_for_inventories.algorithms.base_stock import Base_Stock
from policy_selection_for_inventories.environments.environment import Environment
from policy_selection_for_inventories.simulator import Simulator


def build_MILP(environment: "Environment", horizon: int):
    """Assemble a multi-period MILP by stacking the per-stage data of ``environment``.

    For each stage ``t`` in ``range(horizon)`` the environment supplies costs,
    variable metadata, stage constraints and transition constraints. Stage ``t``
    variables occupy columns ``[t*dim, (t+1)*dim)`` of the global constraint
    matrix, where ``dim`` is the number of variable names returned for a stage.
    Transition constraints of stage ``t`` span the columns of stages ``t`` and
    ``t+1``, so they are only added for ``t < horizon - 1``.

    Args:
        environment: object exposing ``linear_costs``,
            ``linear_decision_variable_structure``, ``stage_linear_constraints``
            and ``transition_linear_constraints``, each taking a stage index.
        horizon: number of stages to stack.

    Returns:
        Tuple ``(costs, names, integralities, lower_bounds, upper_bounds, A,
        b_l, b_u)``. Variable names are prefixed with ``t{t}_``. ``A`` is a 2-D
        numpy array with ``horizon*dim`` columns whose rows line up with the
        entries of ``b_l`` / ``b_u``; for ``horizon == 0`` it is an empty
        ``(0, 0)`` array.

    Raises:
        ValueError: if stages do not all have the same number of variables,
            since the column layout relies on a constant ``dim``.
    """
    global_costs = []
    global_names = []
    global_integralities = []
    global_lower_bounds = []
    global_upper_bounds = []
    global_b_l = []
    global_b_u = []
    # Row slices of the global matrix, concatenated once at the end instead of
    # re-copying the growing matrix at every stage (quadratic in horizon).
    row_blocks = []
    dim = None

    for t in range(horizon):
        costs = environment.linear_costs(t)
        names, integralities, lower_bounds, upper_bounds = environment.linear_decision_variable_structure(t)
        if dim is None:
            dim = len(names)
        elif len(names) != dim:
            raise ValueError(
                "build_MILP expects the same number of variables at every stage: "
                "stage 0 has {} but stage {} has {}".format(dim, t, len(names))
            )
        A, b_l, b_u = environment.stage_linear_constraints(t)
        next_A, next_b_l, next_b_u = environment.transition_linear_constraints(t)

        global_costs += costs
        global_names += ["t{}_".format(t) + name for name in names]
        global_integralities += integralities
        global_lower_bounds += lower_bounds
        global_upper_bounds += upper_bounds

        global_b_l += b_l
        global_b_u += b_u

        if t < horizon - 1:
            # Stage rows first, then transition rows coupling stage t and stage t+1;
            # the bound lists are extended in the same order so rows stay aligned.
            extended_A = np.zeros((len(A) + len(next_A), horizon * dim))
            extended_A[:len(A), t * dim:(t + 1) * dim] = A
            extended_A[len(A):, t * dim:(t + 2) * dim] = next_A
            global_b_l += next_b_l
            global_b_u += next_b_u
        else:
            # The last stage has no successor, so only its own constraints are kept.
            extended_A = np.zeros((len(A), horizon * dim))
            extended_A[:len(A), t * dim:(t + 1) * dim] = A

        row_blocks.append(extended_A)

    global_A = np.concatenate(row_blocks, axis=0) if row_blocks else np.zeros((0, 0))

    return (global_costs, global_names, global_integralities,
        global_lower_bounds, global_upper_bounds, global_A,
        global_b_l, global_b_u)


def find_best_S_policy(environment:Environment, T:int, seasonality:List[int], initial_levels: List[torch.Tensor], nb_iterations:int, learning_rate:float, show_progress_bar:bool=False) :
    """Tune per-product base-stock level patterns by gradient descent on simulated cost.

    For each product ``p`` the level pattern ``initial_levels[p]`` is repeated
    ``ceil(T / seasonality[p])`` times and truncated to ``T`` entries. Presumably
    ``seasonality[p] == len(initial_levels[p])`` — TODO confirm with callers. The
    resulting ``Base_Stock`` policy is run through ``Simulator`` for ``T``
    periods, and the mean of ``costs[0]`` is minimized with ``torch.optim.Adam``.
    Gradients are taken through the simulation via ``cost.backward()``.

    Args:
        environment: environment to simulate; ``len(environment.products_list)``
            gives the number of products ``K``.
        T: number of simulated periods.
        seasonality: per-product pattern period used for tiling.
        initial_levels: per-product starting level tensors (cloned, not modified).
        nb_iterations: number of Adam steps.
        learning_rate: Adam learning rate.
        show_progress_bar: wrap the optimization loop in ``tqdm`` when True.

    Returns:
        List of ``K`` detached tensors, the optimized levels clipped at 0.

    Side effects:
        Plots the per-iteration cost with matplotlib and calls ``plt.show()``.
        Prints the final cost when at least one iteration ran.
    """
    K = len(environment.products_list)

    def get_avg_cost(base_stock_levels : List[torch.Tensor]) :
        # Repeat each product's pattern until it covers T periods, then cut to T.
        policy = Base_Stock(environment, [base_stock_levels[p_idx].tile(ceil(T/seasonality[p_idx]))[:T] for p_idx in range(K)])
        _, _, costs, _, _ = Simulator(environment, [policy], False).run(T)
        # costs[0] belongs to the single simulated policy.
        return costs[0].mean()

    levels = [initial_levels[p_idx].clone().requires_grad_(True) for p_idx in range(K)]
    optimizer = torch.optim.Adam(levels, lr=learning_rate)
    costs = []
    enumeration = tqdm(range(nb_iterations)) if show_progress_bar else range(nb_iterations)
    for _ in enumeration :
        optimizer.zero_grad()
        cost = get_avg_cost(levels)
        costs.append(cost.item())
        cost.backward()
        optimizer.step()

    plt.plot(costs)
    plt.show()
    # With nb_iterations == 0 there is no final cost; costs[-1] would raise IndexError.
    if costs:
        print("Final cost", costs[-1])
    # The optimizer is unconstrained, so negative levels are clipped only on output.
    return [levels[p_idx].clone().detach().clip(min=0.0) for p_idx in range(K)]

def find_best_algorithm_idx(environment: Environment, algorithms: List[Algorithm], T:int) :
    """Simulate every algorithm for ``T`` periods and return the index of the cheapest.

    The mean of ``costs[i]`` from ``Simulator(...).run(T)`` is used as the score
    of ``algorithms[i]``. A bar chart of the scores is drawn with a vertical
    line at the winner, and the winner's score is printed.

    Returns:
        The ``np.argmin`` index of the lowest mean cost.
    """
    nb_algorithms = len(algorithms)
    _, _, cost_traces, _, _ = Simulator(environment, algorithms, True).run(T)

    mean_costs = [cost_traces[i].mean() for i in range(nb_algorithms)]
    best_idx = np.argmin(mean_costs)

    positions = np.arange(nb_algorithms)
    plt.bar(positions, mean_costs)
    plt.axvline(best_idx)
    plt.show()

    print("Final cost", mean_costs[best_idx].item())
    return best_idx