{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3d42700d-9eeb-46d2-8b9a-5d5fd02fc4e1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pulp as pulp\n",
    "from pulp import *\n",
    "import gurobipy as gp\n",
    "from gurobipy import GRB\n",
    "import random\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "from joblib import Parallel, delayed\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "\n",
    "from EDDP_functions import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92ac9046-5cf5-4b8e-8c8d-bb71f500c573",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6ae460a2-641c-4f25-b4ff-8fbd450b501b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def _gap_to_visited(point, visited_points):\n",
    "    \"\"\"\n",
    "    Distance from `point` to the closest entry of `visited_points`.\n",
    "\n",
    "    When `visited_points` is empty, the Euclidean norm of `point` is returned instead.\n",
    "    \"\"\"\n",
    "    if visited_points:\n",
    "        return min(np.linalg.norm(s - point) for s in visited_points)\n",
    "    return np.linalg.norm(point)\n",
    "\n",
    "\n",
    "def EDDP_one_run(\n",
    "    T: int,\n",
    "    N_list: dict,\n",
    "    S: int,\n",
    "    scenario_data_func,\n",
    "    delta: dict,\n",
    "    problem_data: dict,\n",
    "    max_iterations: int = 100\n",
    "):\n",
    "    \"\"\"\n",
    "    Run forward/backward passes until the stage-1 gap is within delta[1]\n",
    "    or `max_iterations` iterations have been performed.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T : int\n",
    "        Number of stages; stages are indexed 1..T.\n",
    "    N_list : dict\n",
    "        N_list[t] is the number of scenarios solved at stage t (t = 2..T).\n",
    "    S : int\n",
    "        The decision vectors handled here have length 2*S.\n",
    "    scenario_data_func : callable\n",
    "        Called as scenario_data_func(problem_data, t, i); its result is handed\n",
    "        to solve_stage_problem as `scenario_data`.\n",
    "    delta : dict\n",
    "        delta[t] is the gap tolerance used at stage t (t = 1..T).\n",
    "    problem_data : dict\n",
    "        Forwarded unchanged to scenario_data_func.\n",
    "    max_iterations : int\n",
    "        Upper bound on the number of forward/backward iterations.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray or None\n",
    "        The first-stage decision of the last forward pass, or None when\n",
    "        solve_stage_problem reports no solution (c_opt is None) for the first\n",
    "        stage or for every scenario of a later stage, in either pass.\n",
    "    \"\"\"\n",
    "    state_dim = 2*S\n",
    "    # V_cuts[t] holds the cuts of stage t as (intercept, slope) pairs.\n",
    "    V_cuts = {t: [] for t in range(2, T+2)}\n",
    "    # V_{T+1} => 0\n",
    "    V_cuts[T+1].append((0.0, None))\n",
    "\n",
    "    visited = {t: [] for t in range(1, T+1)}\n",
    "\n",
    "    iteration = 1\n",
    "    final_first_stage_decision = np.zeros(state_dim)\n",
    "\n",
    "    while iteration <= max_iterations:\n",
    "        # ---------------------------\n",
    "        # Forward Phase\n",
    "        # ---------------------------\n",
    "        c_sequence = []\n",
    "        for t in range(1, T+1):\n",
    "            if t == 1:\n",
    "                if iteration < 2:\n",
    "                    # Very first pass: start from the zero decision.\n",
    "                    c_sequence.append(np.zeros(state_dim))\n",
    "                    continue\n",
    "\n",
    "                data_ti = scenario_data_func(problem_data, t, 0)\n",
    "                c_opt, val_opt, duals = solve_stage_problem(\n",
    "                    t=t,\n",
    "                    x_prev=np.zeros(state_dim),  # placeholder: noted by the original author as irrelevant at t == 1\n",
    "                    scenario_data=data_ti,\n",
    "                    V_cuts=V_cuts,\n",
    "                    state_dim=state_dim,\n",
    "                    first_stage=True\n",
    "                )\n",
    "                if c_opt is None:\n",
    "                    # Without a first-stage decision nothing downstream can be computed.\n",
    "                    return None\n",
    "                c_sequence.append(c_opt)\n",
    "                continue\n",
    "\n",
    "            x_prev = c_sequence[-1]\n",
    "            cand_c = []\n",
    "            cand_gap = []\n",
    "            for i in range(N_list[t]):\n",
    "                data_ti = scenario_data_func(problem_data, t, i)\n",
    "                c_opt, val_opt, duals = solve_stage_problem(\n",
    "                    t=t,\n",
    "                    x_prev=x_prev,\n",
    "                    scenario_data=data_ti,\n",
    "                    V_cuts=V_cuts,\n",
    "                    state_dim=state_dim,\n",
    "                    first_stage=False\n",
    "                )\n",
    "\n",
    "                if c_opt is None:\n",
    "                    # Keep the slot so indices stay aligned; the huge negative gap\n",
    "                    # guarantees this scenario is never selected below.\n",
    "                    cand_c.append(None)\n",
    "                    cand_gap.append(-1e9)\n",
    "                    continue\n",
    "\n",
    "                # gap measure (the last stage always gets gap 0)\n",
    "                gap_val = _gap_to_visited(c_opt, visited[t]) if t < T else 0.0\n",
    "\n",
    "                cand_c.append(c_opt)\n",
    "                cand_gap.append(gap_val)\n",
    "\n",
    "            # pick c_t^k with largest gap\n",
    "            if all(x is None for x in cand_c):\n",
    "                return None\n",
    "\n",
    "            best_idx = max(range(len(cand_gap)), key=lambda ii: cand_gap[ii])\n",
    "            c_sequence.append(cand_c[best_idx])\n",
    "\n",
    "        # The indexing of c_sequence is 0-based: c_sequence[t-1] is the stage-t decision.\n",
    "        final_first_stage_decision = c_sequence[0]\n",
    "\n",
    "        # Stop once the stage-1 decision is close enough to an already visited point.\n",
    "        gap_1 = _gap_to_visited(c_sequence[0], visited[1])\n",
    "        if iteration > 1 and gap_1 <= delta[1]:\n",
    "            break\n",
    "\n",
    "        # ---------------------------\n",
    "        # Backward Phase\n",
    "        # ---------------------------\n",
    "        for t in range(T, 1, -1):  # Here t = T...2\n",
    "            gap_t = _gap_to_visited(c_sequence[t-1], visited[t]) if t < T else 0.0\n",
    "\n",
    "            if gap_t <= delta[t]:\n",
    "                visited[t-1].append(c_sequence[t-2])\n",
    "\n",
    "            # solve subproblems for each scenario\n",
    "            nu_vals = []\n",
    "            subgrads = []\n",
    "            x_ref = c_sequence[t-2]   # stage (t-1) decision, hence the shift by 2\n",
    "            for i in range(N_list[t]):\n",
    "                data_ti = scenario_data_func(problem_data, t, i)\n",
    "                c_opt, val_opt, duals = solve_stage_problem(\n",
    "                    t=t,\n",
    "                    x_prev=x_ref,\n",
    "                    scenario_data=data_ti,\n",
    "                    V_cuts=V_cuts,\n",
    "                    state_dim=state_dim,\n",
    "                    first_stage=False\n",
    "                )\n",
    "                if c_opt is None:\n",
    "                    continue\n",
    "                nu_vals.append(val_opt)\n",
    "                subgrads.append(\n",
    "                    build_subgradient_max(duals[\"eq_duals\"], None, duals[\"linking_data\"])\n",
    "                )\n",
    "\n",
    "            if not nu_vals:\n",
    "                # No scenario could be solved at x_ref, so no cut can be built.\n",
    "                # Report failure the same way the forward phase does instead of\n",
    "                # dividing by zero.\n",
    "                return None\n",
    "\n",
    "            # Averages are taken over the scenarios that returned a solution.\n",
    "            V_avg = sum(nu_vals)/len(nu_vals)\n",
    "            avg_subgrad = sum(subgrads)/len(subgrads)\n",
    "\n",
    "            # Linearization around x_ref, stored as (intercept, slope).\n",
    "            alpha_cut = V_avg - avg_subgrad.dot(x_ref)\n",
    "            V_cuts[t].append((alpha_cut, avg_subgrad))\n",
    "\n",
    "        iteration += 1\n",
    "\n",
    "    return final_first_stage_decision\n",
    "\n",
    "\n",
    "def generate_first_NOR_greater_than_one_data(T, S):\n",
    "    \"\"\"\n",
    "    Draw random problem data until the LP solution y_star has NOR > 1, where NOR\n",
    "    is the number of indices s with both y_star[0, s, 0] and y_star[0, s, 1]\n",
    "    above 1e-6.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T, S : int\n",
    "        Passed to generate_problem_data and solve_lp_for_y; they also size the\n",
    "        arrays P_array (T, S, S, 2) and r_array (T, 2*S) built here.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple\n",
    "        (P_list, r_list, alpha_budget, init, y_star) of the first accepted draw.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If S < 2: NOR counts at most S indices, so NOR > 1 could never hold and\n",
    "        the sampling loop would never terminate.\n",
    "    \"\"\"\n",
    "    if S < 2:\n",
    "        raise ValueError(f\"S must be at least 2 to obtain NOR > 1 (got S={S})\")\n",
    "\n",
    "    while True:\n",
    "        P_list, r_list, alpha_budget, init = generate_problem_data(\n",
    "            T=T, S=S, time_homogenous=True, sparse=True, d=int(S/2)\n",
    "        )\n",
    "\n",
    "        # Stack the two matrices of each stage along the last axis of P_array.\n",
    "        P_array = np.zeros((T, S, S, 2))\n",
    "        for t in range(T):\n",
    "            P_array[t,:,:,0] = P_list[t][0]\n",
    "            P_array[t,:,:,1] = P_list[t][1]\n",
    "\n",
    "        # Convert r_list to a 2D array r_array[t, :]\n",
    "        r_array = np.zeros((T, 2*S))\n",
    "        for t in range(T):\n",
    "            r_array[t,:] = r_list[t]\n",
    "\n",
    "        # Solve an LP for y\n",
    "        y_star, upper_val = solve_lp_for_y(P_array, r_array, alpha_budget, init, T, S)\n",
    "\n",
    "        NOR = sum(\n",
    "            1 for s in range(S)\n",
    "            if y_star[0,s,0] > 1e-6 and y_star[0,s,1] > 1e-6\n",
    "        )\n",
    "        if NOR > 1:\n",
    "            return P_list, r_list, alpha_budget, init, y_star\n",
    "\n",
    "        \n",
    "def estimate_EDDP_runtime(T, S, nb_sample_per_stage, outer_num_repeats=100, inner_num_repeats=10, delta_value=1e-3):\n",
    "    \"\"\"\n",
    "    Estimate the average wall-clock time of one EDDP_one_run call on random instances.\n",
    "\n",
    "    Instances are drawn with generate_first_NOR_greater_than_one_data until\n",
    "    `outer_num_repeats` of them are effective, i.e. more than half of their\n",
    "    `inner_num_repeats` EDDP runs returned a first-stage decision. For an\n",
    "    effective instance the recorded runtime is the time spent in the whole inner\n",
    "    loop (sampling included, failed runs included) divided by the number of\n",
    "    successful runs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    T, S : int\n",
    "        Forwarded to the data generator and to EDDP_one_run.\n",
    "    nb_sample_per_stage : int\n",
    "        Number of scenarios used at every stage t = 2..T.\n",
    "    outer_num_repeats : int\n",
    "        Number of effective instances to collect.\n",
    "    inner_num_repeats : int\n",
    "        Number of EDDP runs per instance.\n",
    "    delta_value : float\n",
    "        Gap tolerance applied to every stage (delta[t] for t = 1..T).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of float\n",
    "        (average_runtime, std_runtime), where std_runtime is\n",
    "        2 * std(runtimes) / sqrt(len(runtimes)). Both values are also printed.\n",
    "    \"\"\"\n",
    "    runtimes = []\n",
    "    effective_count = 0\n",
    "\n",
    "    with tqdm(total=outer_num_repeats) as pbar:\n",
    "        while effective_count < outer_num_repeats:\n",
    "            # Generate data where the first stage NOR by solving the LP is > 1 !\n",
    "            P_list, r_list, alpha_budget, init, y_star = generate_first_NOR_greater_than_one_data(T, S)\n",
    "\n",
    "            # Form data to prepare the EDDP algorithm (stage-indexed dicts, 1-based).\n",
    "            N_list = {t: nb_sample_per_stage for t in range(2, T+1)}\n",
    "            delta = {t: delta_value for t in range(1, T+1)}\n",
    "            P_dict = {t: P_list[t-1] for t in range(1, T+1)}\n",
    "            r_dict = {t: r_list[t-1] for t in range(1, T+1)}\n",
    "            y_dict = {t: y_star[t-1] for t in range(1, T+1)}\n",
    "\n",
    "            P_array = np.zeros((T, S, S, 2))\n",
    "            for t in range(T):\n",
    "                P_array[t,:,:,0] = P_list[t][0]\n",
    "                P_array[t,:,:,1] = P_list[t][1]\n",
    "\n",
    "            # Start inner computation; perf_counter is monotonic, unlike time.time().\n",
    "            start_time = time.perf_counter()\n",
    "            inner_count = 0\n",
    "            for _ in range(inner_num_repeats):\n",
    "                W_all = generate_randomness_W(P_array, y_star, T, S, N_list)\n",
    "                problem_data = {\n",
    "                    \"W_all\": W_all,   # dictionary with keys = 2..T\n",
    "                    \"P\": P_dict,      # stage t => [P_t_0, P_t_1]\n",
    "                    \"r\": r_dict,      # stage t => shape(2*S,)\n",
    "                    \"y_star\": y_dict, # stage t => shape(S,2)\n",
    "                    \"S\": S\n",
    "                }\n",
    "                first_stage_decision = EDDP_one_run(T, N_list, S, scenario_data_func, delta, problem_data)\n",
    "                if first_stage_decision is not None:\n",
    "                    inner_count += 1\n",
    "            elapsed = time.perf_counter() - start_time\n",
    "\n",
    "            # Only keep instances on which more than half of the runs succeeded.\n",
    "            if inner_count > int(inner_num_repeats/2):\n",
    "                runtimes.append(elapsed / inner_count)\n",
    "                effective_count += 1\n",
    "                # Increment the progress bar\n",
    "                pbar.update(1)\n",
    "\n",
    "    average_runtime = np.mean(runtimes)\n",
    "    # Twice the standard error of the mean, not a plain standard deviation.\n",
    "    std_runtime = 2*np.std(runtimes)/np.sqrt(len(runtimes))\n",
    "    print(f\"Average runtime over {outer_num_repeats} runs with T={T} and S={S}: {average_runtime:.4f} seconds\")\n",
    "    print(f\"2 x standard error of the mean runtime: {std_runtime:.4f} seconds\")\n",
    "    return average_runtime, std_runtime"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7cfc9bd-9b80-4af7-9fa9-787a28d8c199",
   "metadata": {},
   "source": [
    "Fix $S = 5$ and vary the horizon $H$ (the variable `T` in the code below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2be8c1d-6c8f-4b52-b9f6-a6b44667c61c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime experiment: keep S fixed and sweep the number of stages T\n",
    "# (T is the quantity written H in the markdown and in the figures).\n",
    "# The computation time varies across machines\n",
    "# With its default arguments each estimate_EDDP_runtime call collects 100\n",
    "# instances and runs EDDP 10 times on each, so this cell is slow.\n",
    "\n",
    "S = 5\n",
    "nb_sample_per_stage = 100\n",
    "myT = [5, 10, 15, 20, 25, 30]\n",
    "\n",
    "# Each call prints its own runtime summary.\n",
    "for T in myT:\n",
    "    estimate_EDDP_runtime(T, S, nb_sample_per_stage)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a82cc939-28b1-422a-ace7-6389e3ff9552",
   "metadata": {},
   "source": [
    "Fix the horizon $H = 5$ (`T = 5` in the code below) and vary $S$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fcbb482-8f6c-4fc5-89e4-5b84d8bde978",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Runtime experiment: keep the number of stages T fixed and sweep S.\n",
    "# The computation time varies across machines\n",
    "# With its default arguments each estimate_EDDP_runtime call collects 100\n",
    "# instances and runs EDDP 10 times on each, so this cell is slow.\n",
    "\n",
    "T = 5\n",
    "nb_sample_per_stage = 100\n",
    "myS = [5, 10, 15, 20, 25, 30]\n",
    "\n",
    "# Each call prints its own runtime summary.\n",
    "for S in myS:\n",
    "    estimate_EDDP_runtime(T, S, nb_sample_per_stage)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc5a0b1a-36a9-4608-9b4e-db0bda8df51d",
   "metadata": {},
   "source": [
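    "Plot the measured runtimes against $H$, $S$ and $L$ with a polynomial regression fit, and save each figure as a PDF."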
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7082e139-a2e3-4d3b-8b80-92e850479ef8",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.preprocessing import PolynomialFeatures\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "# Given simulated data\n",
    "# One record per experiment, entered by hand. \"L\" is plotted below as the\n",
    "# number of samples per stage, and \"std\" is used as the error-bar half-length.\n",
    "# NOTE(review): presumably these numbers were copied from the printed output of\n",
    "# estimate_EDDP_runtime on one particular machine - confirm before reusing them.\n",
    "runtime_data = [\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 100, \"time in seconds\": 17.8202, \"std\": 3.4666},\n",
    "    {\"T\": 10, \"S\": 5, \"L\": 100, \"time in seconds\": 49.2953, \"std\": 8.4137},\n",
    "    {\"T\": 5, \"S\": 10, \"L\": 100, \"time in seconds\": 25.9583, \"std\": 5.2219},\n",
    "    {\"T\": 5, \"S\": 15, \"L\": 100, \"time in seconds\": 32.8601, \"std\": 5.78068},\n",
    "    {\"T\": 5, \"S\": 20, \"L\": 100, \"time in seconds\": 36.7304, \"std\": 7.8967},\n",
    "    {\"T\": 5, \"S\": 25, \"L\": 100, \"time in seconds\": 48.9043, \"std\": 8.969},\n",
    "    {\"T\": 15, \"S\": 5, \"L\": 100, \"time in seconds\": 88.8301, \"std\": 22.6099},\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 50, \"time in seconds\": 11.2787, \"std\": 2.1752},\n",
    "    {\"T\": 20, \"S\": 5, \"L\": 100, \"time in seconds\": 121.5906, \"std\": 31.6199},\n",
    "    {\"T\": 5, \"S\": 30, \"L\": 100, \"time in seconds\": 52.4803, \"std\": 10.2662},\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 200, \"time in seconds\": 49.2600, \"std\": 8.184825},\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 150, \"time in seconds\": 29.4818, \"std\": 6.0101},\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 300, \"time in seconds\": 64.2505, \"std\": 12.4117},\n",
    "    {\"T\": 5, \"S\": 5, \"L\": 400, \"time in seconds\": 126.8820, \"std\": 26.4638},  \n",
    "    {\"T\": 25, \"S\": 5, \"L\": 100, \"time in seconds\": 145.1373, \"std\": 35.0579},\n",
    "    {\"T\": 30, \"S\": 5, \"L\": 100, \"time in seconds\": 184.7827, \"std\": 44.3793},\n",
    "]\n",
    "\n",
    "# The record with T=5, S=5, L=100 satisfies all three filters below, so it\n",
    "# appears in every series. Records keep the order they have in runtime_data.\n",
    "\n",
    "# Separate data for fixed S and increasing T\n",
    "fixed_S_data = [d for d in runtime_data if d[\"S\"] == 5 and d[\"L\"] == 100]\n",
    "\n",
    "# Separate data for fixed T and increasing S\n",
    "fixed_T_data = [d for d in runtime_data if d[\"T\"] == 5 and d[\"L\"] == 100]\n",
    "\n",
    "# Separate data for fixed T and S and increasing L\n",
    "fixed_T_S_data = [d for d in runtime_data if d[\"T\"] == 5 and d[\"S\"] == 5]\n",
    "\n",
    "# Extract data for plotting (fixed S, increasing T)\n",
    "T_values = [d[\"T\"] for d in fixed_S_data]\n",
    "time_values_T = [d[\"time in seconds\"] for d in fixed_S_data]\n",
    "std_T = [d[\"std\"] for d in fixed_S_data]\n",
    "\n",
    "# Extract data for plotting (fixed T, increasing S)\n",
    "S_values = [d[\"S\"] for d in fixed_T_data]\n",
    "time_values_S = [d[\"time in seconds\"] for d in fixed_T_data]\n",
    "std_S = [d[\"std\"] for d in fixed_T_data]\n",
    "\n",
    "# Extract data for plotting (fixed T and S, increasing L)\n",
    "L_values = [d[\"L\"] for d in fixed_T_S_data]\n",
    "time_values_L = [d[\"time in seconds\"] for d in fixed_T_S_data]\n",
    "std_L = [d[\"std\"] for d in fixed_T_S_data]\n",
    "\n",
    "def fit_and_plot_degrees(x, y, yerr, xlabel, ylabel, title, color, degrees, filename):\n",
    "    \"\"\"\n",
    "    Draw the data with error bars, overlay one polynomial least-squares fit per\n",
    "    requested degree, and write the figure to `filename`.\n",
    "\n",
    "    - x: Independent variable (e.g., H, S, or L)\n",
    "    - y: Dependent variable (runtime)\n",
    "    - yerr: Lengths of the error bars\n",
    "    - xlabel, ylabel, title: Axis labels and figure title\n",
    "    - color: Color of the data points and their error bars\n",
    "    - degrees: Iterable of polynomial degrees; one dashed curve is drawn for each\n",
    "    - filename: Path the figure is saved to (the figure is closed afterwards)\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "\n",
    "    # Data points with thick error bars.\n",
    "    ax.errorbar(\n",
    "        x, y, yerr=yerr, fmt='o', color=color, capsize=8,\n",
    "        label='Data with 2-sigma error bars',\n",
    "        alpha=0.7, elinewidth=4, lw=4,\n",
    "    )\n",
    "\n",
    "    # Column vectors: the observed abscissae and a dense grid for smooth curves.\n",
    "    x_column = np.array(x).reshape(-1, 1)\n",
    "    x_dense = np.linspace(min(x), max(x), 500).reshape(-1, 1)\n",
    "\n",
    "    for degree in degrees:\n",
    "        features = PolynomialFeatures(degree)\n",
    "        design = features.fit_transform(x_column)\n",
    "\n",
    "        regression = LinearRegression()\n",
    "        regression.fit(design, y)\n",
    "\n",
    "        # R² is measured on the fitted points themselves.\n",
    "        r2 = r2_score(y, regression.predict(design))\n",
    "\n",
    "        fitted_curve = regression.predict(features.transform(x_dense))\n",
    "        ax.plot(x_dense, fitted_curve, \"--\", label=f'Degree {degree} regression (R² = {r2:.4f})')\n",
    "\n",
    "    ax.set_title(title, size=30)\n",
    "    ax.set_xlabel(xlabel, size=30)\n",
    "    ax.set_ylabel(ylabel, size=30)\n",
    "    # Larger tick labels on both axes\n",
    "    ax.tick_params(axis='x', labelsize=20)\n",
    "    ax.tick_params(axis='y', labelsize=20)\n",
    "    ax.grid(True)\n",
    "    ax.legend(fontsize=18)\n",
    "\n",
    "    fig.savefig(filename, bbox_inches='tight')\n",
    "    plt.close(fig)\n",
    "\n",
    "# Save separate plots\n",
    "# One PDF per series; fit_and_plot_degrees closes each figure after saving,\n",
    "# so nothing is displayed inline by this cell.\n",
    "\n",
    "# Runtime against the horizon (T_values), degree-1 fit, no title.\n",
    "fit_and_plot_degrees(\n",
    "    x=T_values,\n",
    "    y=time_values_T,\n",
    "    yerr=std_T,\n",
    "    xlabel=\"$H$\",\n",
    "    ylabel=\"Time (seconds)\",\n",
    "    #title=\"Vary $H$, fix $S$ = 5, $L$ = 100\",\n",
    "    title=\"\",\n",
    "    color=\"blue\",\n",
    "    degrees=[1],\n",
    "    filename=\"EDDP_runtime_vs_H.pdf\"\n",
    ")\n",
    "\n",
    "# Runtime against S, degree-1 fit, no title.\n",
    "fit_and_plot_degrees(\n",
    "    x=S_values,\n",
    "    y=time_values_S,\n",
    "    yerr=std_S,\n",
    "    xlabel=\"$S$\",\n",
    "    ylabel=\"Time (seconds)\",\n",
    "    #title=\"Vary $S$, fix $H$ = 5, $L$ = 100\",\n",
    "    title=\"\",\n",
    "    color=\"orange\",\n",
    "    degrees=[1],\n",
    "    filename=\"EDDP_runtime_vs_S.pdf\"\n",
    ")\n",
    "\n",
    "# Runtime against the number of samples per stage, degree-2 fit.\n",
    "# NOTE(review): unlike the two figures above, this one keeps a title - confirm\n",
    "# whether that difference is intended.\n",
    "fit_and_plot_degrees(\n",
    "    x=L_values,\n",
    "    y=time_values_L,\n",
    "    yerr=std_L,\n",
    "    xlabel=\"Samples per stage $L$\",\n",
    "    ylabel=\"Time (seconds)\",\n",
    "    title=\"Vary $L$, fix $H$ = 5, $S$ = 5\",\n",
    "    color=\"purple\",\n",
    "    degrees=[2],\n",
    "    filename=\"EDDP_runtime_vs_L.pdf\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3b323f9-e39e-43ca-af00-006880b9efa9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
