"""
This module implements the EDDP algorithm for solving the 1/sqrt{N}-scale SP derived from the Gaussian stochastic program, together with the utility functions
used in the numerical experiments.

Part (A): Utility functions
Part (B): EDDP Functions
Part (C): Naive way to solve the SAA, which is a very large LP and is therefore feasible only for small problems
Part (D): Rounding functions
Part (E): Simulation functions

"""

import numpy as np
import pulp as pulp
from pulp import *
import gurobipy as gp
from gurobipy import GRB
import random
import pickle
import time
from tqdm import tqdm
from joblib import Parallel, delayed


###############################################################################
# (A) UTILITY FUNCTIONS
###############################################################################

def solve_lp_standard(P0,P1,R0,R1,alpha_list,T,init,print_res=False):
    """Solve the two-action fluid LP with PuLP.

    Variables Y[t][a][s] in [0, 1] (stage t, action a, state s). The LP
    maximizes sum_{t,a,s} R_a[s] * Y[t][a][s] subject to:
      * sum_s Y[t][1][s] == alpha_list[t] for every stage t,
      * sum_a Y[t+1][a][s] == sum_{a,s'} Y[t][a][s'] * P_a[s'][s],
      * sum_a Y[0][a][s] == init[s].

    Returns:
        y, a float array of shape (T, 2, n) with y[t, a, s] = Y[t][a][s].
        When ``print_res`` is True the optimal objective is also printed and
        ``(y, upper)`` is returned instead.
    """
    n_states = len(R0)
    kernels = [P0, P1]
    rewards = [R0, R1]
    actions = range(0, 2)
    states = range(0, n_states)
    stages = range(0, T)

    lp = LpProblem("LP1", LpMaximize)
    Y = LpVariable.dicts("Y", (stages, actions, states), lowBound=0., upBound=1.)

    # Budget: mass on action 1 is fixed at every stage.
    for t in stages:
        lp += lpSum([Y[t][1][s] for s in states]) == alpha_list[t]

    # Flow conservation between consecutive stages.
    for t in range(0, T - 1):
        for s in states:
            arriving = lpSum([Y[t][a][src] * kernels[a][src][s] for a in actions for src in states])
            lp += lpSum(Y[t + 1][a][s] for a in actions) == arriving

    # Initial occupation of every state.
    for s in states:
        lp += lpSum(Y[0][a][s] for a in actions) == init[s]

    # Objective (adding a bare expression sets the objective in PuLP).
    lp += lpSum([Y[t][a][s] * rewards[a][s] for t in stages for a in actions for s in states])

    lp.solve()

    y = np.zeros((T, 2, n_states))
    for t in stages:
        for a in actions:
            for s in states:
                y[t, a, s] = Y[t][a][s].varValue

    upper = value(lp.objective)
    if not print_res:
        return y
    print("Fluid LP upper bound is " + str(upper))
    return y, upper

def randomization_numbers(P0,P1,R0,R1,alpha_list,T,init):
    """Count, per stage, how many states are randomized in the fluid LP optimum.

    A state s is "randomized" at stage t when both actions carry positive
    mass, i.e. y[t, 0, s] > 1e-8 and y[t, 1, s] > 1e-8.

    The LP is exactly the one built by ``solve_lp_standard``, so it is reused
    here instead of duplicating the model construction.

    Returns:
        Integer array of length T with the randomized-state count of each stage.
    """
    y = solve_lp_standard(P0, P1, R0, R1, alpha_list, T, init)
    tol = 1e-8
    both_actions_used = (y[:, 0, :] > tol) & (y[:, 1, :] > tol)
    return both_actions_used.sum(axis=1).astype(int)


def compute_Sigma(P0, P1, t, y):
    """Covariance matrix of the one-step Gaussian noise at stage ``t``.

    For a transition from state s under action a with probability vector
    p = P_a[s], the multinomial covariance is diag(p) - p p^T. The result is
    the y-weighted sum of these covariances:

        Sigma = sum_{s, a} y[t, a, s] * (diag(P_a[s]) - P_a[s] P_a[s]^T)

    The original quadruple loop materialized an (n, 2, n, n) tensor; the same
    contraction is done here with ``einsum`` in O(n^3) time and O(n^2) memory.

    Args:
        P0, P1: (n, n) transition matrices for actions 0 and 1; row s is the
            next-state distribution from state s.
        t: stage index into ``y``.
        y: array indexed as y[t, a, s] (e.g. the (T, 2, n) output of
            ``solve_lp_standard``).

    Returns:
        (n, n) numpy array.
    """
    P = np.asarray([P0, P1], dtype=float)      # P[a, s, i]
    weights = np.asarray(y[t], dtype=float)    # weights[a, s]
    mean = np.einsum("as,asi->i", weights, P)
    second_moment = np.einsum("as,asi,asj->ij", weights, P, P)
    return np.diag(mean) - second_moment


def generate_randomness_efficient(P0, P1, T, h, y, nb_samples_per_step):
    """Build a scenario tree of Gaussian noise vectors for the first h-1 transitions.

    Level ``t`` (t = 0 .. h-2) draws vectors from N(0, compute_Sigma(P0, P1, t, y)).
    Every path is extended by ``nb_samples_per_step`` children per level, so
    the result holds nb_samples_per_step**(h-1) paths, each a list of h-1
    vectors. Children of the same prefix are appended consecutively, so paths
    sharing a prefix are contiguous in the returned list. ``T`` is not used.

    Returns:
        List of paths; ``[[]]`` when h <= 1.
    """
    zero_mean = np.zeros(len(P0[0]))
    paths = [[]]
    for level in range(h - 1):
        cov = compute_Sigma(P0, P1, level, y)
        extended = []
        for prefix in paths:
            for _ in range(nb_samples_per_step):
                draw = np.random.multivariate_normal(zero_mean, cov)
                extended.append(prefix + [draw])
        paths = extended
    return paths


def compute_const(P0, P1, t, y):
    """Noise scale of the next-state-0 coordinate at stage ``t``.

    Returns sqrt( sum_{s, a} y[t][a][s] * p * (1 - p) ) with p = P_a[s][0],
    i.e. the square root of the y-weighted Bernoulli variance of moving to
    state 0.
    """
    kernels = (P0, P1)
    total = 0.
    for s in range(len(P0[0])):
        for a in (0, 1):
            p0 = kernels[a][s][0]
            total += y[t][a][s] * p0 * (1 - p0)
    return np.sqrt(total)

def generate_randomness(T, nb_samples):
    """Draw i.i.d. standard normal noise for the scalar SAA model.

    Returns:
        Array of shape (nb_samples, T-1). ``solve_c_ssa`` uses entry [i, t]
        for the transition from stage t to t+1 of sample path i.

    A single vectorized draw replaces nb_samples*(T-1) scalar RNG calls.
    """
    return np.random.normal(loc=0., scale=1., size=(nb_samples, T - 1))



def generate_problem_data(
    T: int,
    S: int,
    time_homogenous: bool,
    sparse: bool,
    d: int
):
    """
    Generate a random T-stage, S-state, two-action problem instance.

    Returns (P, r, alpha_budget, init):
      P: list of length T; P[t] = [P_t_0, P_t_1], each an (S x S) row-stochastic
         matrix of normalized Exponential(1) draws. With ``sparse=True`` each
         row has exactly ``d`` non-zero entries (columns chosen uniformly
         without replacement); otherwise rows are dense. With
         ``time_homogenous=True`` the same two matrices are used at every t.
      r: list of length T; each r[t] has shape (2*S,) with Uniform(0, 1)
         entries (solve_lp_for_y reads the reward of (s, a) at index 2*s + a).
         With ``time_homogenous=True`` one vector is shared by all stages.
      alpha_budget: shape (T,), constant 0.4 action-1 budget per stage.
      init: shape (S,), normalized Exponential(1) draws (sums to 1).

    ``d`` is only used when ``sparse`` is True.

    Raises:
      ValueError: if ``sparse`` is True and ``d`` is not in [1, S]
        (d = 0 would give all-zero, non-stochastic rows).
    """
    if sparse and not 1 <= d <= S:
        raise ValueError(f"sparse kernels need 1 <= d <= S, got d={d}, S={S}")

    def normalize_rows(P_):
        for i in range(S):
            s_ = np.sum(P_[i])
            if s_ > 0:
                P_[i] /= s_
        return P_

    # Draws are scalar and in the same order as before so that a fixed
    # np.random seed reproduces the same instances.
    def dense_matrix():
        P_ = np.zeros((S, S))
        for i in range(S):
            for j in range(S):
                P_[i, j] = np.random.exponential()
        return normalize_rows(P_)

    def sparse_matrix():
        P_ = np.zeros((S, S))
        for i in range(S):
            cols = np.random.choice(S, d, replace=False)
            for j in cols:
                P_[i, j] = np.random.exponential()
        return normalize_rows(P_)

    draw_kernel = sparse_matrix if sparse else dense_matrix

    # Transition kernels (action 0 drawn before action 1 at every stage).
    if time_homogenous:
        P_all_0 = draw_kernel()
        P_all_1 = draw_kernel()
        P = [[P_all_0, P_all_1] for _ in range(T)]
    else:
        P = [[draw_kernel(), draw_kernel()] for _ in range(T)]

    # Reward vectors.
    if time_homogenous:
        r_all = np.random.uniform(0, 1, size=(2*S,))
        r = [r_all for _ in range(T)]
    else:
        r = [np.random.uniform(0, 1, size=(2*S,)) for _ in range(T)]

    # Action-1 budget constraints.
    alpha_budget = np.ones(T)*0.4

    # Initial distribution over the S states.
    init = np.random.exponential(size=(S,))
    total = init.sum()
    if total > 0:
        init /= total

    return P, r, alpha_budget, init


def solve_lp_for_y(
    P: np.ndarray,         # shape = (T, S, S, 2)
    r: np.ndarray,         # shape = (T, 2*S)
    alpha_budget: np.ndarray,  # shape = (T,)
    init: np.ndarray,      # shape = (S,)
    T: int,
    S: int
):
    """Solve the fluid LP over occupation variables y[t, s, a] with PuLP.

    Maximize sum_{t,s,a} r[t, 2*s + a] * y[t, s, a] subject to
      * sum_s y[t, s, 1] == alpha_budget[t] for every t,
      * sum_a y[t+1, s, a] == sum_{s', a} y[t, s', a] * P[t, s', s, a],
      * sum_a y[0, s, a] == init[s],
      * 0 <= y <= 1.

    Returns:
        (y, objective) where y is a (T, S, 2) float array.
    """
    from pulp import LpProblem, LpMaximize, LpVariable, lpSum, value

    actions = range(2)
    states = range(S)
    stages = range(T)

    lp = LpProblem("LP1", LpMaximize)
    Y = {}
    for t in stages:
        Y[t] = {}
        for s in states:
            Y[t][s] = {
                a: LpVariable(f"Y_{t}_{s}_{a}", lowBound=0.0, upBound=1.0)
                for a in actions
            }

    # Budget on action 1 at every stage.
    for t in stages:
        lp += lpSum(Y[t][s][1] for s in states) == alpha_budget[t]

    # Flow conservation between consecutive stages.
    for t in range(T - 1):
        for nxt in states:
            arriving = lpSum(
                Y[t][prev][a] * P[t, prev, nxt, a]
                for prev in states
                for a in actions
            )
            lp += lpSum(Y[t + 1][nxt][a] for a in actions) == arriving

    # Initial distribution.
    for s in states:
        lp += lpSum(Y[0][s][a] for a in actions) == init[s]

    # Objective.
    lp += lpSum(
        Y[t][s][a] * r[t, s*2 + a]
        for t in stages
        for s in states
        for a in actions
    )

    lp.solve()

    y = np.zeros((T, S, 2))
    for t in stages:
        for s in states:
            for a in actions:
                y[t, s, a] = Y[t][s][a].varValue

    return y, value(lp.objective)


def compute_Gamma(
    P: np.ndarray,         # shape = (T, S, S, 2)
    y: np.ndarray,         # shape = (T, S, 2)
    t: int,
    S: int
):
    """Covariance of the stage-``t`` transition noise (layout used by EDDP).

    Gamma = sum_{s, a} y[t, s, a] * (diag(p) - p p^T),  p = P[t, s, :, a]

    This is the same quantity as ``compute_Sigma`` for the (T, S, S, 2)
    kernel layout. It is computed with ``einsum`` instead of materializing an
    (S, S, S, 2) tensor in Python loops.

    Args:
        P: transition kernels indexed P[t, s, next_state, a].
        y: occupation measure indexed y[t, s, a].
        t: stage index.
        S: number of states (only the first S states/rows are used).

    Returns:
        (S, S) numpy array.
    """
    P_t = np.asarray(P[t], dtype=float)[:S, :S, :2]    # P_t[s, i, a]
    weights = np.asarray(y[t], dtype=float)[:S, :2]    # weights[s, a]
    mean = np.einsum("sa,sia->i", weights, P_t)
    second_moment = np.einsum("sa,sia,sja->ij", weights, P_t, P_t)
    return np.diag(mean) - second_moment


def generate_randomness_W(
    P: np.ndarray,         # shape = (T, S, S, 2)
    y: np.ndarray,         # shape = (T, S, 2)
    T: int,
    S: int,
    N: dict
):
    """Sample the Gaussian noise vectors for stages 2..T.

    Returns:
        Dict ``W_all`` with keys 2..T. W_all[t] is a list of N[t] length-S
        vectors drawn i.i.d. from N(0, compute_Gamma(P, y, t-2, S)). The
        (1-based) stage t therefore uses the covariance of the 0-based
        period t-2; e.g. for T=3 the keys are 2 and 3.
    """
    zero_mean = np.zeros(S)
    W_all = {}
    for stage in range(2, T + 1):
        # Shift by two: 1-based stage `stage` is fed by 0-based period stage-2.
        cov = compute_Gamma(P, y, stage - 2, S)
        W_all[stage] = [
            np.random.multivariate_normal(zero_mean, cov) for _ in range(N[stage])
        ]
    return W_all

###############################################################################
# (B) EDDP Functions
###############################################################################

def idx(s: int, a: int, S: int) -> int:
    """Flat position of the (state s, action a) pair in a length-2*S vector.

    The layout is state-major: [(0,0), (0,1), (1,0), (1,1), ...].
    ``S`` is not used by the computation.
    """
    return 2 * s + a

def build_subgradient_max(eq_duals, ineq_duals, linking_constr_data):
    """Subgradient of a stage value w.r.t. the previous decision x_prev.

    The linking constraints of ``solve_stage_problem`` read
        c(s,0) + c(s,1) = W[s] + sum_{s', a'} x_prev[idx(s', a')] * P_h[s', s, a'],
    so the derivative of the optimal value w.r.t. x_prev[idx(s', a')] is
        sum_s P_h[s', s, a'] * eq_duals[s].
    (The sign convention follows the constraints as written above; verify it
    if the constraint orientation changes.)

    Args:
        eq_duals: duals of the S linking constraints.
        ineq_duals: unused (kept for interface compatibility).
        linking_constr_data: dict with "S" and "P_h" (indexed [s_old, s_new, a]).

    Returns:
        Float vector of length 2*S laid out as idx(s, a, S) = 2*s + a.
    """
    S = linking_constr_data["S"]
    P_h = np.asarray(linking_constr_data["P_h"], dtype=float)[:S, :S, :2]
    duals = np.asarray(eq_duals, dtype=float)[:S]
    # grad[s_old, a_old] = sum_{s_new} P_h[s_old, s_new, a_old] * duals[s_new]
    grad = np.einsum("ija,j->ia", P_h, duals)
    # Row-major flattening of (S, 2) matches idx(s, a, S) = 2*s + a.
    return grad.reshape(2 * S)

def scenario_data_func(problem_data: dict, t: int, i: int):
    """Assemble the data of scenario ``i`` at (1-based) stage ``t``.

    ``problem_data`` holds "r", "P" and "y_star" keyed by 1-based stage,
    "W_all" keyed by stages 2..T (lists of noise vectors), and "S".
    Stage 1 carries no noise, so W is the zero vector there.

    Returns a dict with keys "W", "r_t", "P_h" (shape (S, S, 2),
    P_h[:, :, a] = problem_data["P"][t][a]), "y_star" and "S".

    NOTE(review): P_h is taken from problem_data["P"][t], and
    solve_stage_problem multiplies it with the stage-(t-1) decision. In
    compute_first_c / solve_lp_for_y, however, the transition into stage t
    uses P_list[t-2] (= problem_data["P"][t-1]), as does the noise covariance
    in generate_randomness_W. The two only agree for time-homogeneous kernels
    — confirm the intended convention.
    """
    n_states = problem_data["S"]

    noise = problem_data["W_all"][t][i] if t > 1 else np.zeros(n_states)
    reward = problem_data["r"][t]
    kernel_pair = problem_data["P"][t]

    stacked = np.zeros((n_states, n_states, 2))
    stacked[:, :, 0] = kernel_pair[0]
    stacked[:, :, 1] = kernel_pair[1]

    return {
        "W": noise,
        "r_t": reward,
        "P_h": stacked,
        "y_star": problem_data["y_star"][t],
        "S": n_states,
    }


def solve_stage_problem(
    t: int,
    x_prev: np.ndarray,
    scenario_data: dict,
    V_cuts: dict,
    state_dim: int,
    first_stage: bool
):
    """Build and solve the stage-``t`` LP of the EDDP recursion with Gurobi.

    Variables: c (length ``state_dim``, expected to be 2*S, indexed
    c[idx(s, a, S)] = c[2*s + a], each in [-10, 10]) and theta in
    [-1e5, 1e5], which bounds the value of stage t+1 through the cuts in
    ``V_cuts[t+1]``.

    Maximizes r_t^T c + theta subject to:
      * sum_s c(s,1) = 0;
      * c(s,a) >= 0 wherever y_star[s,a] <= 1e-6;
      * linking, one equality per state s:
          first stage:  c(s,0) + c(s,1) = 0   (x_prev is ignored),
          later stages: c(s,0) + c(s,1) = W[s] + sum_{s',a'} x_prev[idx(s',a')] * P_h[s', s, a'];
      * for each cut (alpha, beta): theta <= alpha if beta is None,
        else theta <= alpha + beta^T c.

    Args:
        t: stage number (cuts are read from V_cuts[t+1]).
        x_prev: previous-stage decision of length 2*S.
        scenario_data: dict with "S", "r_t", "y_star", "W", "P_h"
            (as produced by scenario_data_func).
        V_cuts: dict stage -> list of (alpha, beta) cuts.
        state_dim: length of c.
        first_stage: selects the first-stage form of the linking constraints.

    Returns:
        (c_opt, obj_val, dual_info) on OPTIMAL status, where dual_info is
        {"eq_duals": Pi of the S linking constraints,
         "linking_data": {"S": S, "P_h": P_h}}.
        (None, None, None) otherwise. Infeasible/unbounded models are also
        written to "infeasible.lp"/"unbounded.lp" in the working directory.

    NOTE(review): compute_first_c runs this through parallel workers; those
    debug .lp writes all target the same file names and may clobber each other.
    """
    S = scenario_data["S"]
    r_t = scenario_data["r_t"]
    y_star = scenario_data["y_star"]
    W      = scenario_data["W"]
    P_h    = scenario_data["P_h"]

    model = gp.Model(f"Stage_{t}")
    model.setParam('OutputFlag', 0)  # silence solver logging
    
    c_vars = model.addVars(state_dim, lb=-10., ub=10., vtype=GRB.CONTINUOUS, name="c")
    # theta is boxed so the LP stays bounded even before any cut exists for t+1.
    theta  = model.addVar(lb=-1e5, ub=1e5, vtype=GRB.CONTINUOUS, name="theta")

    # objective: max r_t^T c + theta
    obj_expr = gp.quicksum(r_t[j]*c_vars[j] for j in range(state_dim)) + theta
    model.setObjective(obj_expr, GRB.MAXIMIZE)

    # sum_{s} c(s,1) = 0  (total action-1 perturbation is zero)
    sum_expr = gp.quicksum(c_vars[idx(s,1,S)] for s in range(S))
    model.addConstr(sum_expr == 0, "Sum_s_c(s,1)=0")

    # c(s,a) >= 0 if y_star[s,a] ~ 0  (cannot decrease a pair that is unused in the fluid solution)
    for s_ in range(S):
        for a_ in [0,1]:
            if y_star[s_,a_] <= 1e-6:
                model.addConstr(c_vars[idx(s_,a_,S)] >= 0, name=f"NonNeg_{s_}_{a_}")

    # linking constraints (kept in state order: their duals are read back by index)
    eq_constrs = []
    for s_ in range(S):
        lhs_expr = c_vars[idx(s_,0,S)] + c_vars[idx(s_,1,S)]
        if first_stage:
            # stage 1 => c(s,0)+ c(s,1)=0
            c_ = model.addConstr(lhs_expr == 0, name=f"FirstStage_{s_}")
            eq_constrs.append(c_)
        else:
            # c(s,0)+c(s,1) = W[s_] + sum_{(s_old,a_old)} x_prev[s_old,a_old]* P_h[s_old, s_, a_old]
            rhs_val = W[s_]
            for s_old in range(S):
                for a_old in [0,1]:
                    rhs_val += x_prev[idx(s_old,a_old,S)] * P_h[s_old, s_, a_old]
            c_ = model.addConstr(lhs_expr == rhs_val, name=f"LaterStage_{s_}")
            eq_constrs.append(c_)

    # Value function cuts:  theta <= alpha + beta^T c   (for a max problem)
    # A cut with beta None is a constant bound (used for V_{T+1} = 0).
    if (t+1) in V_cuts:
        for cut_idx, (alpha, beta) in enumerate(V_cuts[t+1]):
            if beta is None:
                model.addConstr(theta <= alpha, name=f"Cut_{t+1}_{cut_idx}")
            else:
                expr = alpha + gp.quicksum(beta[j]*c_vars[j] for j in range(state_dim))
                model.addConstr(theta <= expr, name=f"Cut_{t+1}_{cut_idx}")

    model.optimize()

    # Check model status; any non-OPTIMAL outcome is reported as (None, None, None).
    if model.status == GRB.INFEASIBLE:
        #print("Stage problem infeasible. Writing infeasible.lp")
        model.write("infeasible.lp")
        return None, None, None
    elif model.status == GRB.UNBOUNDED:
        #print("Stage problem unbounded. Writing unbounded.lp")
        model.write("unbounded.lp")
        return None, None, None
    elif model.status not in [GRB.OPTIMAL]:
        #print(f"Stage problem not optimal. status={model.status}")
        return None, None, None

    c_opt = np.array([c_vars[j].X for j in range(state_dim)], dtype=float)
    obj_val = model.ObjVal

    # Duals (Pi) of the linking equalities, used by build_subgradient_max.
    eq_duals_array = np.array([eq_constrs[s_].Pi for s_ in range(S)], dtype=float)
    dual_info = {
        "eq_duals": eq_duals_array,
        "linking_data": {
            "S": S,
            "P_h": P_h
        }
    }
    return c_opt, obj_val, dual_info

def _distance_to_visited(points, c):
    """Min Euclidean distance from ``c`` to ``points``; ||c|| if ``points`` is empty."""
    if points:
        return min(np.linalg.norm(p - c) for p in points)
    return np.linalg.norm(c)


def EDDP_maximization_linear_gurobi(
    T: int,
    N_list: dict,
    S: int,
    scenario_data_func,
    delta: dict,    
    problem_data: dict,
    max_iterations: int = 100,
    printing: bool = False
):
    """Run EDDP forward/backward passes until the stage-1 gap is small or the iteration limit is hit.

    Forward pass: stage 1 is solved deterministically (c_1 = 0 on the first
    iteration). At each stage t >= 2, all N_list[t] scenarios are solved from
    the previous decision. The scenario whose solution lies farthest from the
    already-visited stage-t points is kept (gap 0 at t = T).

    Backward pass: for t = T..2, a stage-(t-1) point is marked visited when the
    stage-t point is within delta[t] of visited points (always at t = T). Then
    an averaged cut theta <= alpha + beta^T c for V_t is added around the
    stage-(t-1) decision.

    The loop stops when, after the first iteration, the distance of the
    stage-1 decision to the visited stage-1 points is <= delta[1].

    Args:
        T: number of stages.
        N_list: stage (2..T) -> number of scenarios.
        S: number of states (decision length is 2*S).
        scenario_data_func: callable(problem_data, t, i) -> stage data dict.
        delta: stage (1..T) -> gap tolerance.
        problem_data: passed through to ``scenario_data_func``.
        max_iterations: iteration limit.
        printing: print the stage-1 gap of each iteration after the first.

    Returns:
        The stage-1 decision (length 2*S) of the last forward pass. It is a zero
        vector when fewer than two states have both actions positive in y_star
        at stage 1 (early exit). It is None when a subproblem cannot be solved:
        the first-stage problem, every scenario of a forward stage, or every
        scenario of a backward stage.
    """
    state_dim = 2*S
    # Cuts for V_t, t = 2..T+1; V_{T+1} == 0 is the constant cut theta <= 0.
    V_cuts = {t: [] for t in range(2, T+2)}
    V_cuts[T+1].append((0.0, None))

    # visited[t]: stage-t decisions considered explored; used for the gap measure.
    visited = {t: [] for t in range(1, T+1)}

    iteration = 1
    final_first_stage_decision = np.zeros(state_dim)

    # Early exit: count states randomized (both actions > 1e-6) at stage 1.
    data_11 = scenario_data_func(problem_data, 1, 0)  # stage 1, scenario i=0
    y_first = data_11["y_star"]
    NOR = 0
    for s in range(S):
        if y_first[s,0] > 1e-6 and y_first[s,1] > 1e-6:
            NOR += 1
    if NOR < 2:
        print("Conjecture triggers early return. No SDDP/EDDP.")
        return final_first_stage_decision

    while iteration <= max_iterations:
        # ---------------------------
        # Forward Phase
        # ---------------------------
        c_sequence = []   # 0-based: c_sequence[t-1] is the stage-t decision
        for t in range(1, T+1):
            if t == 1 and iteration < 2:
                c_sequence.append(np.zeros(2*S))
            elif t == 1:
                data_ti = scenario_data_func(problem_data, t, 0)
                c_opt, val_opt, duals = solve_stage_problem(
                    t=t,
                    x_prev=np.zeros(2*S),   # ignored at the first stage
                    scenario_data=data_ti,
                    V_cuts=V_cuts,
                    state_dim=state_dim,
                    first_stage=True
                )
                if c_opt is None:
                    # A None decision would otherwise crash the gap computation below.
                    return None
                c_sequence.append(c_opt)
            else:
                x_prev = c_sequence[-1]
                cand_c = []
                cand_gap = []
                for i in range(N_list[t]):
                    data_ti = scenario_data_func(problem_data, t, i)
                    c_opt, val_opt, duals = solve_stage_problem(
                        t=t,
                        x_prev=x_prev,
                        scenario_data=data_ti,
                        V_cuts=V_cuts,
                        state_dim=state_dim,
                        first_stage=False
                    )
                    if c_opt is None:
                        # Failed scenarios get a gap that can never be selected.
                        cand_c.append(None)
                        cand_gap.append(-1e9)
                        continue
                    gap_val = _distance_to_visited(visited[t], c_opt) if t < T else 0.0
                    cand_c.append(c_opt)
                    cand_gap.append(gap_val)

                if all(x is None for x in cand_c):
                    return None

                # Keep the candidate with the largest gap.
                best_idx = max(range(len(cand_gap)), key=lambda ii: cand_gap[ii])
                c_sequence.append(cand_c[best_idx])

        final_first_stage_decision = c_sequence[0]

        gap_1 = _distance_to_visited(visited[1], c_sequence[0])
        if printing and iteration > 1:
            print(f"iteration={iteration}, gap_1 = {gap_1}")

        if iteration > 1 and gap_1 <= delta[1]:
            break

        # ---------------------------
        # Backward Phase
        # ---------------------------
        for t in range(T, 1, -1):  # t = T..2
            gap_t = _distance_to_visited(visited[t], c_sequence[t-1]) if t < T else 0.0
            if gap_t <= delta[t]:
                visited[t-1].append(c_sequence[t-2])

            # Cut for V_t around the stage-(t-1) decision (0-based index t-2).
            nu_vals = []
            subgrads = []
            x_ref = c_sequence[t-2]
            for i in range(N_list[t]):
                data_ti = scenario_data_func(problem_data, t, i)
                c_opt, val_opt, duals = solve_stage_problem(
                    t=t,
                    x_prev=x_ref,
                    scenario_data=data_ti,
                    V_cuts=V_cuts,
                    state_dim=state_dim,
                    first_stage=False
                )
                if c_opt is None:
                    continue
                nu_vals.append(val_opt)
                subgrads.append(
                    build_subgradient_max(duals["eq_duals"], None, duals["linking_data"])
                )

            if not nu_vals:
                # Every scenario failed: no valid cut (previously ZeroDivisionError).
                return None

            V_avg = sum(nu_vals)/len(nu_vals)
            avg_subgrad = sum(subgrads)/len(subgrads)
            alpha_cut = V_avg - avg_subgrad.dot(x_ref)
            V_cuts[t].append((alpha_cut, avg_subgrad))

        iteration += 1

    return final_first_stage_decision


def compute_first_c(
    P_list: np.ndarray,
    r_list: np.ndarray,
    alpha_budget: np.ndarray,
    init: np.ndarray,
    nb_sample_per_stage: int = 100,
    delta_value: float = 1e-3,
    printing: bool = False,
    repetition: int = 100,
    n_jobs: int = 20,
    max_iterations: int = 100,
    min_valid_runs: int = 10
):
    """Estimate the first-stage action-1 correction c(s, 1) by repeated EDDP runs.

    Steps:
      1. Solve the fluid LP (solve_lp_for_y) for y_star.
      2. Run ``repetition`` independent EDDP instances in parallel. Each draws
         fresh noise via generate_randomness_W with ``nb_sample_per_stage``
         scenarios for every stage 2..T.
      3. Average the successful (non-None) first-stage decisions.

    The last-stage kernel is replaced by the first one for the computation
    (as before). This is now done on a local copy, so the caller's ``P_list``
    is no longer modified.

    Args:
        P_list: length-T sequence; P_list[t] = [P_t_0, P_t_1], each (S, S).
        r_list: length-T sequence of (2*S,) reward vectors.
        alpha_budget: (T,) action-1 budgets.
        init: (S,) initial distribution.
        nb_sample_per_stage: scenarios per stage in each EDDP run.
        delta_value: EDDP gap tolerance used for every stage.
        printing: print diagnostics.
        repetition: number of independent EDDP runs.
        n_jobs: joblib worker count (default 20).
        max_iterations: EDDP iteration limit per run (default 100).
        min_valid_runs: if fewer successful runs remain, return zeros
            (default 10).

    Returns:
        Length-S array of averaged c(s, 1) (entries 2*s + 1 of the decisions).
    """
    S = len(init)
    T = len(alpha_budget)
    # Shallow copy: replace the last kernel locally instead of mutating the caller's list.
    P_list = list(P_list)
    P_list[-1] = P_list[0]

    # P_array[t, s, s_, a] layout expected by solve_lp_for_y.
    P_array = np.zeros((T, S, S, 2))
    for t in range(T):
        P_array[t, :, :, 0] = P_list[t][0]
        P_array[t, :, :, 1] = P_list[t][1]

    r_array = np.zeros((T, 2*S))
    for t in range(T):
        r_array[t, :] = r_list[t]

    y_star, _ = solve_lp_for_y(P_array, r_array, alpha_budget, init, T, S)

    # Stage-keyed (1-based) inputs for EDDP.
    N_list = {t: nb_sample_per_stage for t in range(2, T+1)}
    delta = {t: delta_value for t in range(1, T+1)}
    P_dict = {t: P_list[t-1] for t in range(1, T+1)}
    r_dict = {t: r_list[t-1] for t in range(1, T+1)}
    y_dict = {t: y_star[t-1] for t in range(1, T+1)}

    def compute_first_stage_decision(P_array, y_star, T, S, N_list, P_dict, r_dict, y_dict, delta, max_iterations):
        # One EDDP run on a fresh noise sample.
        W_all = generate_randomness_W(P_array, y_star, T, S, N_list)
        problem_data = {
            "W_all": W_all,   # keys 2..T
            "P": P_dict,      # stage t => [P_t_0, P_t_1]
            "r": r_dict,      # stage t => shape (2*S,)
            "y_star": y_dict, # stage t => shape (S, 2)
            "S": S
        }
        return EDDP_maximization_linear_gurobi(
            T=T,
            N_list=N_list,
            S=S,
            scenario_data_func=scenario_data_func,
            delta=delta,
            problem_data=problem_data,
            max_iterations=max_iterations,
            printing=False
        )

    all_first_c = Parallel(n_jobs=n_jobs)(
        delayed(compute_first_stage_decision)(
            P_array, y_star, T, S, N_list, P_dict, r_dict, y_dict, delta, max_iterations
        ) for _ in range(repetition)
    )

    # Failed runs return None.
    all_first_c = [decision for decision in all_first_c if decision is not None]

    if printing:
        print(f"effective number of first_c is {len(all_first_c)}")

    if len(all_first_c) < min_valid_runs:
        return np.zeros(S)

    average_first_c = np.mean(np.array(all_first_c), axis=0)
    # Action-1 entries sit at idx(s, 1, S) = 2*s + 1.
    c_first_pull = np.array(average_first_c[1::2], dtype=float)

    if printing:
        print(f"c_first_pull using EDDP = {c_first_pull}")

    return c_first_pull


###############################################################################
# (C) Brute-force method to solve the SAA
###############################################################################

def solve_c_ssa(P0, P1, R0, R1, T, y, randomness, nb_samples):
    """Solve the SAA of the scaled two-state problem as a single (large) LP.

    Only the next-state-0 coordinate is propagated. Its noise at transition
    t -> t+1 is compute_const(P0, P1, t, y) * randomness[i, t]. Each sample
    path i has its own decisions c[i][t] for stages 1..T-1; c_first is shared
    by all paths. The constraints reference states 0 and 1 explicitly, so the
    model is only meaningful for two states.

    Args:
        P0, P1: 2x2 transition matrices for actions 0 and 1.
        R0, R1: length-2 reward vectors for actions 0 and 1.
        T: horizon (>= 2).
        y: fluid solution indexed y[t, a, s].
        randomness: array indexed randomness[i, t] (e.g. generate_randomness(T, nb_samples)).
        nb_samples: number of sample paths.

    Returns:
        The optimal c_first for (action 0, state 0).

    Raises:
        ValueError: if the problem does not have exactly two states.
    """
    n = len(R0)
    if n != 2:
        # Constraints below hard-code states 0 and 1; other sizes give a wrong LP.
        raise ValueError(f"solve_c_ssa only supports two-state problems, got {n} states")
    P = [P0,P1]
    R = [R0,R1]
    action = range(0,2)
    state = range(0,n)
    prob = LpProblem("LP2", LpMaximize)
    c_first = LpVariable.dicts("c_first",(action, state))
    c = LpVariable.dicts("c",(range(0,nb_samples), range(1,T), action, state))

    # Per-action sums are zero; pairs unused by the fluid solution cannot decrease.
    for i in range(nb_samples):
        for t in range(1,T):
            prob += c[i][t][1][1] + c[i][t][1][0] == 0
            prob += c[i][t][0][1] + c[i][t][0][0] == 0
            for s in state:
                for a in action:
                    if y[t,a,s] <= 1e-6:
                        prob += c[i][t][a][s] >= 0

    prob += c_first[1][1] + c_first[1][0] == 0
    prob += c_first[0][1] + c_first[0][0] == 0
    for s in state:
        for a in action:
            if y[0,a,s] <= 1e-6:
                prob += c_first[a][s] >= 0
    # First stage: no net perturbation of state 0.
    prob += c_first[1][0] + c_first[0][0] == 0

    # Dynamics of the state-0 coordinate between stages 1..T-1.
    for t in range(1,T-1):
        const = compute_const(P0, P1, t, y)
        for i in range(nb_samples):
            prob += lpSum(c[i][t+1][a][0] for a in action) - lpSum([c[i][t][a][s]*P[a][s][0] for a in action for s in state]) == const*randomness[i,t]

    # Transition from the shared first stage to stage 1.
    const = compute_const(P0, P1, 0, y)
    for i in range(nb_samples):
        prob += lpSum(c[i][1][a][0] for a in action) - lpSum([c_first[a][s]*P[a][s][0] for a in action for s in state]) == const*randomness[i,0]

    # Objective: the first-stage reward is weighted by nb_samples to match the per-path sums.
    prob += lpSum([c[i][t][a][s]*R[a][s] for i in range(nb_samples) for t in range(1,T) for a in action for s in state]) + lpSum([c_first[a][s]*R[a][s]*nb_samples for a in action for s in state])

    prob.solve()

    c_value = np.zeros((2,n))
    for a in action:
        for s in state:
            c_value[a,s] = c_first[a][s].varValue

    return c_value[0,0]


def solve_c_ssa_efficient(N, P0, P1, R0, R1, T, h, y, randomness, nb_samples_per_step):
    """Solve the tree-structured SAA with noise only on the first h-1 transitions.

    ``randomness`` is a list of paths (as built by generate_randomness_efficient);
    randomness[i][t] is the noise vector of transition t -> t+1 on path i, for
    t < h-1. Transitions from t = h-1 onward are noise-free. Non-anticipativity
    is enforced for stages 1..h-2: paths sharing a node of the scenario tree
    (contiguous blocks of nb_samples_per_step**(h-1-t) paths) share c[.][t].
    The first-stage decision c_first satisfies c_first[a][s] + sqrt(N) * y[0,a,s] >= 0.

    Returns:
        (2, n) array of c_first values indexed [a, s] (entries with
        |value| < 1e-8 snapped to 0), or None if the LP is not solved to optimality.

    Raises:
        ValueError: if ``h`` is outside [2, T].
    """
    if not 2 <= h <= T:
        # h < 2 leaves no noise level (randomness[i][0] missing); h > T indexes past the horizon.
        raise ValueError(f"h must satisfy 2 <= h <= T, got h={h}, T={T}")
    nb_samples = len(randomness)
    n = len(R0)
    P = [P0,P1]
    R = [R0,R1]
    action = range(0,2)
    state = range(0,n)
    prob = LpProblem("LP2", LpMaximize)
    c_first = LpVariable.dicts("c_first",(action, state))
    c = LpVariable.dicts("c",(range(0,nb_samples), range(1,T), action, state))
    coef = np.sqrt(N)

    # Per-action sums are zero; pairs unused by the fluid solution cannot decrease.
    for i in range(nb_samples):
        for t in range(1,T):
            prob += lpSum(c[i][t][1][s] for s in state) == 0
            prob += lpSum(c[i][t][0][s] for s in state) == 0
            for s in state:
                for a in action:
                    if y[t,a,s] <= 1e-8:
                        prob += c[i][t][a][s] >= 0

    prob += lpSum(c_first[1][s] for s in state) == 0
    prob += lpSum(c_first[0][s] for s in state) == 0
    for s in state:
        prob += lpSum(c_first[a][s] for a in action) == 0
        for a in action:
            # Scaled occupation y*sqrt(N) plus the correction must stay non-negative.
            prob += c_first[a][s] + y[0,a,s]*coef >= 0

    # Noisy dynamics for transitions 1 .. h-2.
    for t in range(1,h-1):
        for i in range(nb_samples):
            for ss in state:
                prob += lpSum(c[i][t+1][a][ss] for a in action) - lpSum([c[i][t][a][s]*P[a][s][ss] for a in action for s in state]) == randomness[i][t][ss]

    # Noise-free dynamics afterwards.
    for t in range(h-1,T-1):
        for i in range(nb_samples):
            for ss in state:
                prob += lpSum(c[i][t+1][a][ss] for a in action) - lpSum([c[i][t][a][s]*P[a][s][ss] for a in action for s in state]) == 0

    # Transition from the shared first stage.
    for i in range(nb_samples):
        for ss in state:
            prob += lpSum(c[i][1][a][ss] for a in action) - lpSum([c_first[a][s]*P[a][s][ss] for a in action for s in state]) == randomness[i][0][ss]

    # Tree (non-anticipativity) constraints.
    for t in range(1,h-1):
        node_size = nb_samples_per_step**(h-1-t)
        for i in range(nb_samples):
            if i % node_size != 0:
                for s in state:
                    for a in action:
                        prob += c[i][t][a][s] == c[i-1][t][a][s]

    # Objective; nb_samples = nb_samples_per_step**(h-1) (number of tree leaves),
    # so the shared first stage is weighted by nb_samples.
    prob += lpSum([c[i][t][a][s]*R[a][s] for i in range(nb_samples) for t in range(1,T) for a in action for s in state]) + lpSum([c_first[a][s]*R[a][s]*nb_samples for a in action for s in state])

    prob.solve()

    if LpStatus[prob.status] != 'Optimal':
        return None

    c_value = np.zeros((2,n))
    for a in action:
        for s in state:
            V = c_first[a][s]
            if abs(V.varValue) < 1e-8:
                c_value[a,s] = 0.
            else:
                c_value[a,s] = V.varValue

    return c_value



###############################################################################
# (D) Rounding functions
###############################################################################


"""
We implement a randomized rounding algorithm that can be found in Section 5.2.3 of the following paper:

https://dl.acm.org/doi/pdf/10.1145/2964791.2901467

The main function is the "randomized_rounding" function.

"""


def sampling(proba_line):
    """Draw one index from the discrete distribution ``proba_line``.

    Returns the first position whose cumulative probability is >= a
    Uniform[0, 1) draw. If rounding makes the total fall just short of the
    draw, the last index is returned. The previous scan looped forever in
    that case.

    Raises:
        ValueError: if ``proba_line`` is empty.
    """
    probs = np.asarray(proba_line, dtype=float)
    if probs.size == 0:
        raise ValueError("proba_line must contain at least one probability")
    seed = np.random.uniform(0,1)
    # One cumulative sum instead of re-summing every prefix (O(n) vs O(n^2)).
    hits = np.flatnonzero(np.cumsum(probs) >= seed)
    return int(hits[0]) if hits.size else probs.size - 1

def randomized_rounding(probas, states):
    """
    Round the expected activation counts ``states[state] * probability`` to integers.

    Each entry of ``probas`` is indexed as a pair: ``entry[0]`` is a state
    (used to index ``states``) and ``entry[1]`` is the activation probability
    for that state. The integer parts of the products form a deterministic
    base vector; the fractional parts are rounded jointly by ``rounding``.

    Args:
    - probas: a sequence of ``[state, probability]`` pairs, probabilities between 0 and 1
    - states: a list of integer numbers encoding the number of bandits being in each state in set B

    Returns:
    - possible_acts: list of candidate activation vectors (base + rounded fractional part)
    - prob: list of probability weights, one per candidate in ``possible_acts``
    """
    base = np.zeros(len(probas), dtype=int)
    reduced = []
    frac_total = 0
    for idx, entry in enumerate(probas):
        expected = states[entry[0]] * entry[1]
        base[idx] = int(expected)
        fraction = expected - base[idx]
        reduced.append([fraction, 1])
        frac_total += fraction
    # Overwrite the last fractional part so that (when the target is at least
    # the sum of the other fractions) the fractions add up exactly to the
    # integer target handed to ``rounding``.
    others = frac_total - reduced[-1][0]
    target = int(round(frac_total))
    reduced[-1][0] = abs(target - others)
    candidates, weights = rounding(target, reduced)
    possible_acts = [list(base + candidate) for candidate in candidates]
    return possible_acts, weights


def transform(ans1, y):
    """Collapse a flat 0/1 vector back into per-group counts.

    ``y`` is a list of ``[value, multiplicity]`` pairs. Group ``i`` owns the
    next ``y[i][1]`` consecutive entries of ``ans1``, and its count is the sum
    of those entries.

    Returns a list with one count per group.
    """
    counts = np.zeros(len(y), dtype=int)
    offset = 0
    for idx, group in enumerate(y):
        width = group[1]
        counts[idx] = sum(ans1[offset:offset + width])
        offset += width
    return list(counts)

def find_ind(thres, s):
    """Binary-search the interval of ``s`` that contains ``thres``.

    ``s`` is expected to be sorted in ascending order. Returns the largest
    index ``i`` with ``s[i] <= thres``. If ``thres`` is above (or within 1e-8
    below) the last entry, the last index is returned. If ``thres`` is below
    ``s[0]``, index 0 is returned.
    """
    last = len(s) - 1
    if thres > s[last] - 1e-8:
        return last
    # Invariant while searching: s[hi] > thres.
    lo, hi = 0, last
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if s[mid] > thres:
            hi = mid
        else:
            lo = mid
    return lo

def placement(c, y):
    """Enumerate the systematic-rounding placements of ``c`` units over ``y``.

    The values of ``y`` are laid end to end on the real line: entry ``i``
    occupies ``[s_i, s_i + y[i])`` with ``s_i = y[0] + ... + y[i-1]``. For an
    offset ``u`` in ``[0, 1)``, the points ``u, 1 + u, ..., (c - 1) + u`` each
    select the entry whose interval contains them (via ``find_ind``). The
    selection only changes when ``u`` crosses the fractional part of an
    interval endpoint, so it suffices to evaluate the distinct sorted
    fractional parts ``Tau[0] < ... < Tau[K-1]``.

    Args:
    - c: requested number of units. It is replaced by ``int(round(sum(y)))``
      when it differs from ``round(sum(y))`` or when ``sum(y)`` is farther
      than 1e-7 from an integer.
    - y: sequence of fractional values.

    Returns:
    - a list of ``[x, weight]`` pairs, one per distinct offset. ``x`` is an
      int 0/1 array of length ``len(y)``. ``weight = Tau[k+1] - Tau[k]``, with
      ``Tau[K] = 1``, is the length of the offset range producing ``x``. The
      weights sum to ``1 - Tau[0]``, which is 1 when the total is an integer.
    """
    budget = sum(y)
    if c != round(budget) or abs(budget - round(budget)) > 1e-7:
        c = int(round(budget))
    N = len(y)
    # ends[i] = y[0] + ... + y[i]; starts[i] = ends[i-1] (0 for i == 0).
    ends = np.cumsum(np.asarray(y, dtype=float))
    starts = np.zeros(N)
    starts[1:] = ends[:-1]
    # Fractional part of each endpoint; near-integers snap to 0.
    tau = ends - np.trunc(ends)
    tau[np.abs(ends - np.round(ends)) < 1e-8] = 0.
    offsets = np.unique(tau)
    K = len(offsets)
    Tau = np.ones(N + 1)
    Tau[:K] = offsets
    nu = []
    # Only the K distinct offsets carry weight. Offsets k >= K would all have
    # Tau[k] == Tau[k+1] == 1, i.e. zero probability, and only waste work.
    for k in range(K):
        x = np.zeros(N, dtype=int)
        for l in range(c):
            x[find_ind(l + Tau[k], starts)] = 1
        nu.append([x, Tau[k + 1] - Tau[k]])
    return nu

def rounding(c,y):
    """Randomized rounding of grouped fractional values.

    Args:
    - c: number of units to place, forwarded to ``placement``.
    - y: list of ``[value, multiplicity]`` pairs. Each value is repeated
      ``multiplicity`` times before calling ``placement``.

    Returns:
    - possible_acts: the distinct per-group count vectors (as produced by
      ``transform``), in first-seen order.
    - probas: the summed ``placement`` weight of each vector in ``possible_acts``.
    """
    expanded = []
    for group in y:
        expanded.extend([group[0]] * group[1])
    position_of = {}
    possible_acts = []
    probas = []
    for indicator, weight in placement(c, expanded):
        act = transform(indicator, y)
        key = tuple(act)
        if key in position_of:
            probas[position_of[key]] += weight
        else:
            position_of[key] = len(possible_acts)
            possible_acts.append(act)
            probas.append(weight)
    return possible_acts, probas


###############################################################################
# (E) Simulation functions
###############################################################################


def _custom_multinomial(n, p_vector):
    """Multinomial draw that also accepts a negative number of trials.

    For ``n < 0`` the draw is made with ``-n`` trials and negated. This keeps
    the simulation running, instead of raising, when a state is assigned more
    pulls than it currently holds (``Xs - Ys < 0``).
    """
    if n >= 0:
        return np.random.multinomial(n, p_vector)
    return -np.random.multinomial(-n, p_vector)


def _first_pull_number(N, y_first, alpha):
    """Round the pull fractions ``N * y_first`` to integer pull counts.

    Each state first receives ``int(N * y_first[s])`` pulls. The remaining
    ``round(N * alpha) - sum(...)`` pulls are placed by randomized rounding
    over the fractional parts; one candidate is drawn with ``sampling``.
    """
    S = len(y_first)
    Y = np.zeros(S, dtype=int)
    M = int(round(N * alpha))
    reduced = []
    for s in range(S):
        Ns = N * y_first[s]
        Y[s] = int(Ns)
        reduced.append([Ns - Y[s], 1])
    pull_frac_total = M - sum(Y)
    if pull_frac_total == 0:
        return Y
    possible_acts, proba_list = rounding(pull_frac_total, reduced)
    # ``rounding`` returns a plain list. Convert it explicitly before
    # normalizing; in-place division of a list only works by accident when its
    # entries happen to be numpy scalars.
    proba_list = np.asarray(proba_list, dtype=float)
    proba_list /= sum(proba_list)
    position = sampling(proba_list)
    return Y + possible_acts[position]


def resolve_one_simulation(
    N: int,
    P_list: np.ndarray,
    r_list: np.ndarray,
    alpha_budget: np.ndarray,
    init: np.ndarray,
    T: int,
    S: int,
    use_diffusion: bool
):
    """Simulate one trajectory of N arms under the LP re-solving policy.

    At every step ``t`` the fluid LP over the remaining horizon is re-solved
    from the current empirical state distribution (``solve_lp_for_y``). Its
    first-step pull fractions, optionally shifted by the diffusion correction
    ``compute_first_c(...) / sqrt(N)``, are rounded to integer pull counts.
    The next state counts are then sampled with multinomial transitions.

    Args:
    - N: number of arms.
    - P_list: ``P_list[t][a]`` is the S x S transition matrix of action ``a``
      (0 = not pulled, 1 = pulled) at step ``t``; row ``s`` is the
      next-state distribution from state ``s``.
    - r_list: ``r_list[t]`` has ``2*S`` entries. Index ``2*s`` is the reward of
      a non-pulled arm in state ``s``; ``2*s+1`` is the reward of a pulled one.
    - alpha_budget: ``alpha_budget[t]`` is the fraction of arms pulled at step ``t``.
    - init: initial fraction of arms in each state.
    - T: horizon length.
    - S: number of states.
    - use_diffusion: whether to add the diffusion correction to the first-step pulls.

    Returns:
    - the total collected reward divided by N.
    """
    total_reward = 0.
    current_init = init

    # NOTE(review): this mutates numpy's global print options on every call.
    # It is kept for backward compatibility with callers that print arrays.
    np.set_printoptions(linewidth=np.inf, suppress=True)

    for t in range(T):
        current_horizon = T - t
        alpha = alpha_budget[t]
        P0 = P_list[t][0]
        P1 = P_list[t][1]
        R0 = np.array([r_list[t][2 * s] for s in range(S)], dtype=float)
        R1 = np.array([r_list[t][2 * s + 1] for s in range(S)], dtype=float)

        # LP data for the remaining steps t, ..., T-1.
        current_P_array = np.zeros((current_horizon, S, S, 2))
        current_r_array = np.zeros((current_horizon, 2 * S))
        for k in range(current_horizon):
            current_P_array[k, :, :, 0] = P_list[t + k][0]
            current_P_array[k, :, :, 1] = P_list[t + k][1]
            current_r_array[k, :] = r_list[t + k]
        current_alpha_budget = alpha_budget[t:]

        y_star, _ = solve_lp_for_y(current_P_array, current_r_array, current_alpha_budget, current_init, current_horizon, S)
        y_first_pull = y_star[0, :, 1]
        if use_diffusion:
            c_first_pull = compute_first_c(P_list[t:], r_list[t:], current_alpha_budget, current_init, printing=False, nb_sample_per_stage=100, repetition=100)
            y_first_pull = y_first_pull + c_first_pull / np.sqrt(N)

        # Fall back to the uncorrected LP solution if any pull fraction is negative.
        if np.any(y_first_pull < 0):
            y_first_pull = y_star[0, :, 1]

        pull_number = _first_pull_number(N, y_first_pull, alpha)

        X_next = np.zeros(S, dtype=int)
        gain = 0.
        for s in range(S):
            Xs = int(round(N * current_init[s]))
            Ys = pull_number[s]
            # Non-pulled arms move with P0, pulled arms with P1.
            X_next += _custom_multinomial(Xs - Ys, P0[s]) + _custom_multinomial(Ys, P1[s])
            gain += (Xs - Ys) * R0[s] + Ys * R1[s]
        total_reward += gain
        current_init = X_next / N

    return total_reward / N


def repeat_simulation(
    N: int,
    P_list: np.ndarray,
    r_list: np.ndarray,
    alpha_budget: np.ndarray,
    init: np.ndarray,
    T: int,
    S: int,
    use_diffusion: bool,
    num_repeats: int = 100,
    printing: bool = False
):
    """
    Repeat resolve_one_simulation sequentially and return the mean performance.

    :param N: System size (int)
    :param P_list: List of transition matrices
    :param r_list: List of rewards
    :param alpha_budget: Budget array
    :param init: Initial state distribution
    :param T: Time horizon
    :param S: Number of states
    :param use_diffusion: Whether to use diffusion correction
    :param num_repeats: Number of repetitions
    :param printing: If True, show a tqdm progress bar and print the mean and the
        2-sigma half-width ``2 * std / sqrt(num_repeats)``
    :return: Mean of the per-run results returned by resolve_one_simulation
    """
    runs = tqdm(range(num_repeats)) if printing else range(num_repeats)
    results = [
        resolve_one_simulation(
            N=N,
            P_list=P_list,
            r_list=r_list,
            alpha_budget=alpha_budget,
            init=init,
            T=T,
            S=S,
            use_diffusion=use_diffusion
        )
        for _ in runs
    ]

    mean_performance = np.mean(results)
    # The sample std (ddof=1) needs at least two runs. Return NaN explicitly
    # rather than triggering numpy's degrees-of-freedom warning.
    if len(results) > 1:
        std_performance = 2*np.std(results, ddof=1)/np.sqrt(len(results))
    else:
        std_performance = float("nan")

    if printing:
        print(f"Mean performance: {mean_performance}")
        print(f"2 sigma: {std_performance}")

    return mean_performance

def repeat_simulation_parallel(
    N: int,
    P_list: np.ndarray,
    r_list: np.ndarray,
    alpha_budget: np.ndarray,
    init: np.ndarray,
    T: int,
    S: int,
    use_diffusion: bool = False,
    num_repeats: int = 100,
    printing: bool = False
):
    """
    Run resolve_one_simulation ``num_repeats`` times in parallel and return the mean.

    The runs are dispatched with ``joblib.Parallel(n_jobs=-1)``. The arguments
    are forwarded unchanged to resolve_one_simulation.

    :param num_repeats: Number of repetitions
    :param printing: If True, print the mean and the 2-sigma half-width
        ``2 * std / sqrt(num_repeats)``
    :return: Mean of the per-run results returned by resolve_one_simulation
    """
    sim_kwargs = dict(
        N=N,
        P_list=P_list,
        r_list=r_list,
        alpha_budget=alpha_budget,
        init=init,
        T=T,
        S=S,
        use_diffusion=use_diffusion,
    )
    jobs = (delayed(resolve_one_simulation)(**sim_kwargs) for _ in range(num_repeats))
    results = Parallel(n_jobs=-1)(jobs)

    n_runs = len(results)
    mean_performance = np.mean(results)
    two_sigma = 2*np.std(results, ddof=1)/np.sqrt(n_runs)

    if printing:
        print(f"Mean performance: {mean_performance}")
        print(f"2 sigma: {two_sigma}")

    return mean_performance
