{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "8f2efff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import random\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "# Grid world: cells are (row, col) with 0 <= row < GRID_HEIGHT and 0 <= col < GRID_WIDTH.\n",
    "GRID_HEIGHT, GRID_WIDTH = 8, 10\n",
    "OBSTACLES = [(5, 2), (5, 5), (2, 4), (6, 7)]  # entering one ends the episode\n",
    "GOAL, START = (7, 9), (7, 0)\n",
    "\n",
    "# Moves as (d_row, d_col); list positions 0..3 are the action ids used below.\n",
    "ACTIONS = [(-1, 0), (1, 0), (0, -1), (0, 1)]\n",
    "NUM_ACTIONS = len(ACTIONS)\n",
    "\n",
    "# Integer return thresholds -150, -149, ..., 100, i.e. H[i] == i - 150 (H[150] == 0).\n",
    "H = np.linspace(-150, 100, 251)\n",
    "\n",
    "# NOTE(review): gamma and threshold are not referenced by any routine in this notebook;\n",
    "# all returns below are undiscounted sums.\n",
    "gamma = 0.99\n",
    "threshold = 1e-4\n",
    "\n",
    "# -1 per ordinary move (including bumping into a wall); step() draws terminal rewards\n",
    "# from normals centred on reward_goal / reward_obstacle.\n",
    "reward_default, reward_goal, reward_obstacle = -1, 50, -50\n",
    "\n",
    "# Probability that step() swaps the requested action for a different random one.\n",
    "random_probability = 0.3\n",
    "# CVaR level: the lower q-fraction of returns forms the evaluated tail.\n",
    "q = 0.1\n",
    "\n",
    "SEED = 0\n",
    "\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed the stdlib `random` module and NumPy's global RNG with the same value.\"\"\"\n",
    "    for seeder in (random.seed, np.random.seed):\n",
    "        seeder(seed)\n",
    "    \n",
    "def is_valid(pos):\n",
    "    \"\"\"Return True if the (row, col) position `pos` lies inside the grid.\"\"\"\n",
    "    row, col = pos\n",
    "    if not 0 <= row < GRID_HEIGHT:\n",
    "        return False\n",
    "    return 0 <= col < GRID_WIDTH\n",
    "\n",
    "def step(state, action_index):\n",
    "    \"\"\"Advance the stochastic grid world by one move.\n",
    "\n",
    "    With probability `random_probability` the requested action is replaced by one of the\n",
    "    other actions, chosen uniformly. Moving off the grid leaves the agent in place.\n",
    "    Entering an obstacle or the goal ends the episode with a reward drawn from\n",
    "    N(reward_obstacle, 1) or N(reward_goal, 1); every other move costs reward_default.\n",
    "\n",
    "    Returns:\n",
    "        (next_state, reward, done)\n",
    "    \"\"\"\n",
    "    if random.random() < random_probability:\n",
    "        alternatives = list(range(NUM_ACTIONS))\n",
    "        alternatives.remove(action_index)\n",
    "        action_index = random.choice(alternatives)\n",
    "\n",
    "    d_row, d_col = ACTIONS[action_index]\n",
    "    target = (state[0] + d_row, state[1] + d_col)\n",
    "\n",
    "    if not is_valid(target):\n",
    "        return state, reward_default, False\n",
    "    if target in OBSTACLES:\n",
    "        return target, np.random.normal(reward_obstacle, 1), True\n",
    "    if target == GOAL:\n",
    "        return target, np.random.normal(reward_goal, 1), True\n",
    "    return target, reward_default, False\n",
    "\n",
    "def choose_action(Q, state, epsilon):\n",
    "    \"\"\"Epsilon-greedy action from a table indexed Q[row, col, action]; greedy ties go to the lowest index.\"\"\"\n",
    "    explore = np.random.rand() < epsilon\n",
    "    if explore:\n",
    "        return random.randint(0, NUM_ACTIONS - 1)\n",
    "    row, col = state\n",
    "    return np.argmax(Q[row, col])\n",
    "\n",
    "def choose_action_PCVaR(Q_cvar, M, state, idx, epsilon):\n",
    "    \"\"\"Epsilon-greedy PCVaR-Q action at threshold index `idx`.\n",
    "\n",
    "    The greedy score per action is Q_cvar - H[idx] * M; ties among the best actions are\n",
    "    broken uniformly at random with NumPy's global RNG.\n",
    "    \"\"\"\n",
    "    if np.random.rand() >= epsilon:\n",
    "        row, col = state\n",
    "        scores = Q_cvar[row, col, idx, :] - H[idx] * M[row, col, idx, :]\n",
    "        best = np.flatnonzero(scores == scores.max())\n",
    "        return np.random.choice(best)\n",
    "    return random.randint(0, NUM_ACTIONS - 1)\n",
    "    \n",
    "def choose_action_CVaR(Q_cvar, state, idx, epsilon):\n",
    "    \"\"\"Epsilon-greedy CVaR-Q action at threshold index `idx`.\n",
    "\n",
    "    Greedy means the action with the smallest Q_cvar value; ties are broken uniformly at\n",
    "    random with NumPy's global RNG.\n",
    "    \"\"\"\n",
    "    if np.random.rand() >= epsilon:\n",
    "        row, col = state\n",
    "        scores = Q_cvar[row, col, idx, :]\n",
    "        best = np.flatnonzero(scores == scores.min())\n",
    "        return np.random.choice(best)\n",
    "    return random.randint(0, NUM_ACTIONS - 1)\n",
    "    \n",
    "\n",
    "def step_deterministic(state, action):\n",
    "    \"\"\"Noise-free counterpart of step(): no action slip and terminal rewards at their means.\"\"\"\n",
    "    d_row, d_col = ACTIONS[action]\n",
    "    moved = (state[0] + d_row, state[1] + d_col)\n",
    "    next_state = moved if is_valid(moved) else state\n",
    "    if next_state in OBSTACLES:\n",
    "        return next_state, reward_obstacle, True\n",
    "    if next_state == GOAL:\n",
    "        return next_state, reward_goal, True\n",
    "    return next_state, reward_default, False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d0ce7ccf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sarsa(episodes=100000, alpha=0.01, epsilon_min=0.0, max_steps=500, decay_episodes = 25000):\n",
    "    \"\"\"Tabular on-policy SARSA on the stochastic grid world (undiscounted returns).\n",
    "\n",
    "    Q starts optimistically at 10 for every (row, col, action). Exploration decays\n",
    "    linearly from 1.0 to `epsilon_min` over the first `decay_episodes` episodes and stays\n",
    "    at `epsilon_min` afterwards.\n",
    "\n",
    "    Args:\n",
    "        episodes: number of training episodes.\n",
    "        alpha: step size of the TD update.\n",
    "        epsilon_min: exploration floor reached at the end of the decay.\n",
    "        max_steps: per-episode step cap.\n",
    "        decay_episodes: length of the linear epsilon schedule (<= 0 means no decay).\n",
    "\n",
    "    Returns:\n",
    "        (Q, rewards): Q of shape (GRID_HEIGHT, GRID_WIDTH, NUM_ACTIONS) and exactly one\n",
    "        undiscounted return per episode.\n",
    "    \"\"\"\n",
    "    Q = np.zeros((GRID_HEIGHT, GRID_WIDTH, NUM_ACTIONS)) + 10\n",
    "    rewards = []\n",
    "    for episode in range(episodes):\n",
    "        # Linear decay from 1.0, floored at epsilon_min.\n",
    "        if decay_episodes > 0:\n",
    "            epsilon = max(epsilon_min, 1.0 - episode / decay_episodes)\n",
    "        else:\n",
    "            epsilon = epsilon_min\n",
    "\n",
    "        state = START\n",
    "        action = choose_action(Q, state, epsilon)\n",
    "        total_reward = 0\n",
    "\n",
    "        for _ in range(max_steps):\n",
    "            next_state, reward, done = step(state, action)\n",
    "            next_action = choose_action(Q, next_state, epsilon)\n",
    "\n",
    "            x, y = state\n",
    "            nx, ny = next_state\n",
    "            # Terminal transitions must not bootstrap: no action is ever taken from a\n",
    "            # terminal cell, so its Q entries would stay at the optimistic init forever.\n",
    "            target = reward if done else reward + Q[nx, ny, next_action]\n",
    "            Q[x, y, action] += alpha * (target - Q[x, y, action])\n",
    "\n",
    "            total_reward += reward\n",
    "            state = next_state\n",
    "            action = next_action\n",
    "            if done:\n",
    "                break\n",
    "        # One entry per episode (no per-step partial sums).\n",
    "        rewards.append(total_reward)\n",
    "\n",
    "    return Q, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7c9c82a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def PCVaR_Q_Pre_train(Q, num_simul=10000):\n",
    "    \"\"\"Monte-Carlo initialisation of the PCVaR-Q tables from rollouts guided by `Q`.\n",
    "\n",
    "    Each rollout starts in a uniformly random non-terminal cell, takes one epsilon-greedy\n",
    "    (epsilon=0.6) action w.r.t. `Q`, then acts greedily, and is cut off after 1001 steps.\n",
    "    With G the undiscounted return-to-go, every visited (state, action) and threshold\n",
    "    index idx accumulates the sample averages\n",
    "        M[s, idx, a]      ~ P(G <= H[idx])\n",
    "        Q_cvar[s, idx, a] ~ E[G * 1{G <= H[idx]}]\n",
    "    where idx is an initial threshold index shifted by the reward collected so far; only\n",
    "    shifted indices that stay inside H are recorded, each exactly once.\n",
    "\n",
    "    Args:\n",
    "        Q: tabular action values indexed Q[row, col] -> per-action values (e.g. from sarsa()).\n",
    "        num_simul: number of rollouts.\n",
    "\n",
    "    Returns:\n",
    "        (Q_cvar, M, rewards): arrays of shape (GRID_HEIGHT, GRID_WIDTH, len(H), NUM_ACTIONS)\n",
    "        and the list of rollout returns.\n",
    "    \"\"\"\n",
    "    n_h = len(H)\n",
    "    shape = (GRID_HEIGHT, GRID_WIDTH, n_h, NUM_ACTIONS)\n",
    "    Q_cvar_sum = np.zeros(shape)\n",
    "    M_sum = np.zeros(shape)\n",
    "    Count = np.zeros(shape)\n",
    "    rewards = []\n",
    "    for episode in range(num_simul):\n",
    "        if episode % 10000 == 0:\n",
    "            print(f\"Episode {episode}\")\n",
    "        while True:\n",
    "            state = (random.randint(0, GRID_HEIGHT - 1), random.randint(0, GRID_WIDTH - 1))\n",
    "            if state not in OBSTACLES and state != GOAL:\n",
    "                break\n",
    "        total_reward = 0\n",
    "        done = False\n",
    "        trajectory = []\n",
    "\n",
    "        time_step = 0\n",
    "        action = choose_action(Q, state, 0.6)\n",
    "        while not done:\n",
    "            next_state, reward, done = step(state, action)\n",
    "            trajectory.append((total_reward, state, action, reward))\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "            time_step += 1\n",
    "            if time_step > 1000:\n",
    "                break\n",
    "            action = choose_action(Q, state, 0.0)\n",
    "        rewards.append(total_reward)\n",
    "\n",
    "        # Undiscounted return-to-go for every step (built backwards, then flipped).\n",
    "        returns_to_go = []\n",
    "        G = 0\n",
    "        for _, _, _, r in reversed(trajectory):\n",
    "            G = r + G\n",
    "            returns_to_go.append(G)\n",
    "        returns_to_go.reverse()\n",
    "\n",
    "        for (sum_r, (s_x, s_y), a, _), Gt in zip(trajectory, returns_to_go):\n",
    "            # Initial index i maps to idx = i - shift after collecting sum_r; [lo, hi) is the\n",
    "            # contiguous run of those idx inside H, so each is updated exactly once.\n",
    "            shift = int(round(sum_r))\n",
    "            lo, hi = max(0, -shift), min(n_h, n_h - shift)\n",
    "            if lo >= hi:\n",
    "                continue\n",
    "            indicator = (Gt <= H[lo:hi]).astype(float)\n",
    "            Count[s_x, s_y, lo:hi, a] += 1\n",
    "            M_sum[s_x, s_y, lo:hi, a] += indicator\n",
    "            Q_cvar_sum[s_x, s_y, lo:hi, a] += Gt * indicator\n",
    "\n",
    "    M = np.zeros_like(M_sum)\n",
    "    Q_cvar = np.zeros_like(Q_cvar_sum)\n",
    "    valid = Count > 0\n",
    "    M[valid] = M_sum[valid] / Count[valid]\n",
    "    Q_cvar[valid] = Q_cvar_sum[valid] / Count[valid]\n",
    "\n",
    "    # Terminal cells have return-to-go 0: Q_cvar = 0 and M = 1{0 <= H[idx]}, which\n",
    "    # includes H == 0 because the indicator above uses <=.\n",
    "    at_or_above_zero = H >= 0\n",
    "    for row, col in OBSTACLES + [GOAL]:\n",
    "        Q_cvar[row, col, :, :] = 0\n",
    "        M[row, col, at_or_above_zero, :] = 1\n",
    "\n",
    "    return Q_cvar, M, rewards\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ff58a7dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def CVaR_Q_Pre_train(Q, num_simul=10000):\n",
    "    \"\"\"Monte-Carlo initialisation of the CVaR-Q table from rollouts guided by `Q`.\n",
    "\n",
    "    Rollouts are generated exactly as in PCVaR_Q_Pre_train: random non-terminal start,\n",
    "    one epsilon-greedy (epsilon=0.6) action, then greedy, cut off after 1001 steps.\n",
    "    With G the undiscounted return-to-go, every visited (state, action) and threshold\n",
    "    index idx accumulates the sample average\n",
    "        Q_cvar[s, idx, a] ~ E[(H[idx] - G) * 1{G <= H[idx]}]\n",
    "    where idx is an initial threshold index shifted by the reward collected so far; only\n",
    "    shifted indices that stay inside H are recorded, each exactly once.\n",
    "\n",
    "    Args:\n",
    "        Q: tabular action values indexed Q[row, col] -> per-action values (e.g. from sarsa()).\n",
    "        num_simul: number of rollouts.\n",
    "\n",
    "    Returns:\n",
    "        (Q_cvar, rewards): array of shape (GRID_HEIGHT, GRID_WIDTH, len(H), NUM_ACTIONS)\n",
    "        and the list of rollout returns.\n",
    "    \"\"\"\n",
    "    n_h = len(H)\n",
    "    shape = (GRID_HEIGHT, GRID_WIDTH, n_h, NUM_ACTIONS)\n",
    "    Q_cvar_sum = np.zeros(shape)\n",
    "    Count = np.zeros(shape)\n",
    "    rewards = []\n",
    "    for episode in range(num_simul):\n",
    "        if (episode+1) % 10000 == 0:\n",
    "            print(f\"Episode {(episode+1)}\")\n",
    "        while True:\n",
    "            state = (random.randint(0, GRID_HEIGHT - 1), random.randint(0, GRID_WIDTH - 1))\n",
    "            if state not in OBSTACLES and state != GOAL:\n",
    "                break\n",
    "        total_reward = 0\n",
    "        done = False\n",
    "        trajectory = []\n",
    "\n",
    "        time_step = 0\n",
    "        action = choose_action(Q, state, 0.6)\n",
    "        while not done:\n",
    "            next_state, reward, done = step(state, action)\n",
    "            trajectory.append((total_reward, state, action, reward))\n",
    "            state = next_state\n",
    "            total_reward += reward\n",
    "            time_step += 1\n",
    "            if time_step > 1000:\n",
    "                break\n",
    "            action = choose_action(Q, state, 0.0)\n",
    "        rewards.append(total_reward)\n",
    "\n",
    "        # Undiscounted return-to-go for every step (built backwards, then flipped).\n",
    "        returns_to_go = []\n",
    "        G = 0\n",
    "        for _, _, _, r in reversed(trajectory):\n",
    "            G = r + G\n",
    "            returns_to_go.append(G)\n",
    "        returns_to_go.reverse()\n",
    "\n",
    "        for (sum_r, (s_x, s_y), a, _), Gt in zip(trajectory, returns_to_go):\n",
    "            # Initial index i maps to idx = i - shift after collecting sum_r; [lo, hi) is the\n",
    "            # contiguous run of those idx inside H, so each is updated exactly once.\n",
    "            shift = int(round(sum_r))\n",
    "            lo, hi = max(0, -shift), min(n_h, n_h - shift)\n",
    "            if lo >= hi:\n",
    "                continue\n",
    "            indicator = (Gt <= H[lo:hi]).astype(float)\n",
    "            Count[s_x, s_y, lo:hi, a] += 1\n",
    "            Q_cvar_sum[s_x, s_y, lo:hi, a] += (H[lo:hi] - Gt) * indicator\n",
    "\n",
    "    valid = Count > 0\n",
    "    Q_cvar = np.zeros_like(Q_cvar_sum)\n",
    "    Q_cvar[valid] = Q_cvar_sum[valid] / Count[valid]\n",
    "\n",
    "    # Terminal cells have return-to-go 0, so E[(h - G)^+] = max(h, 0) for every action.\n",
    "    terminal_value = np.maximum(H, 0.0)[:, None]\n",
    "    for row, col in OBSTACLES + [GOAL]:\n",
    "        Q_cvar[row, col, :, :] = terminal_value\n",
    "    return Q_cvar, rewards"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ddfa036a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_CVaR(Q_cvar, H, lr1, trajectory):\n",
    "    \"\"\"One batched CVaR-Q update from a single episode; modifies and returns `Q_cvar`.\n",
    "\n",
    "    For each transition (sum_r, s, a, r, s_next) and each threshold index idx reachable\n",
    "    from an initial index i via idx = i - round(sum_r), the target is\n",
    "        Q_cvar[s_next, n, a*],  n = clip(idx - round(r)),  a* = argmin_a Q_cvar[s_next, n, a]\n",
    "    with ties broken at random. Targets are averaged per (s, idx, a) over the episode and\n",
    "    Q_cvar moves towards that average with step size `lr1`.\n",
    "\n",
    "    Args:\n",
    "        Q_cvar: array indexed [row, col, threshold_idx, action].\n",
    "        H: threshold grid (assumed unit-spaced, since indices are shifted by rounded rewards).\n",
    "        lr1: step size.\n",
    "        trajectory: sequence of [sum_r, state, action, reward, next_state], where sum_r is\n",
    "            the reward collected before this step.\n",
    "    \"\"\"\n",
    "    n_h = len(H)\n",
    "    Q_est = np.zeros_like(Q_cvar)\n",
    "    count = np.zeros_like(Q_cvar)\n",
    "    for sum_r, s, a, r, s_next in trajectory:\n",
    "        shift = int(round(sum_r))\n",
    "        r_shift = int(round(r))\n",
    "        # Only idx = i - shift (i in range(n_h)) that land inside H; an out-of-range idx is\n",
    "        # skipped rather than ending the loop, since other i may still be valid.\n",
    "        for idx in range(max(0, -shift), min(n_h, n_h - shift)):\n",
    "            count[s[0], s[1], idx, a] += 1\n",
    "            next_idx = np.clip(idx - r_shift, 0, n_h - 1)\n",
    "            q_values = Q_cvar[s_next[0], s_next[1], next_idx, :]\n",
    "            min_actions = np.flatnonzero(q_values == q_values.min())\n",
    "            a_next = np.random.choice(min_actions)\n",
    "            Q_est[s[0], s[1], idx, a] += Q_cvar[s_next[0], s_next[1], next_idx, a_next]\n",
    "    valid = count > 0\n",
    "    Q_cvar[valid] += lr1 * (Q_est[valid] / count[valid] - Q_cvar[valid])\n",
    "    return Q_cvar\n",
    "\n",
    "def _run_CVaR_episode(Q_cvar, eta_idx, epsilon):\n",
    "    \"\"\"Roll out one episode from START with the CVaR-Q policy at threshold index `eta_idx`.\n",
    "\n",
    "    After every step the threshold index is lowered by the rounded reward, so the policy\n",
    "    conditions on the remaining threshold. Episodes are cut off after 1001 steps.\n",
    "\n",
    "    Returns:\n",
    "        (total_reward, trajectory) with entries [sum_r, state, action, reward, next_state].\n",
    "    \"\"\"\n",
    "    trajectory = []\n",
    "    state = START\n",
    "    total_reward = 0\n",
    "    done = False\n",
    "    t = 0\n",
    "    while not done:\n",
    "        action = choose_action_CVaR(Q_cvar, state, eta_idx, epsilon)\n",
    "        next_state, reward, done = step(state, action)\n",
    "        trajectory.append([total_reward, state, action, reward, next_state])\n",
    "        total_reward += reward\n",
    "        eta_idx = np.clip(eta_idx - int(round(reward)), 0, len(H) - 1)\n",
    "        state = next_state\n",
    "        t += 1\n",
    "        if t > 1000:\n",
    "            done = True\n",
    "    return total_reward, trajectory\n",
    "\n",
    "\n",
    "def _best_CVaR_eta_index(Q_start):\n",
    "    \"\"\"First index i maximising max_a (H[i] - Q_start[i, a] / q), the CVaR estimate at START.\"\"\"\n",
    "    objective = (H[:, None] - Q_start / q).max(axis=1)\n",
    "    return int(np.argmax(objective))\n",
    "\n",
    "\n",
    "def CVaR_Q_learning(Q_cvar, decay_episodes=2000, n_episodes=None, eta_init=None):\n",
    "    \"\"\"Train CVaR-Q from START while searching over the initial threshold eta.\n",
    "\n",
    "    Each episode runs the epsilon-greedy policy (epsilon decays linearly to 0 over\n",
    "    `decay_episodes`) at a sampled threshold and applies update_CVaR. Every 1000\n",
    "    episodes:\n",
    "      - eta is re-estimated as the grid value maximising h - min_a Q_cvar / q at START;\n",
    "      - the sampling width shrinks to max(45 * (3 - (episode+1) // 2000), 0);\n",
    "      - the greedy policy at the estimated eta is evaluated on 10000 rollouts, and the\n",
    "        empirical CVaR_q of those returns is appended to the history.\n",
    "\n",
    "    Args:\n",
    "        Q_cvar: initial table indexed [row, col, threshold_idx, action]; updated in place.\n",
    "        decay_episodes: length of the linear epsilon schedule.\n",
    "        n_episodes: number of training episodes; defaults to the notebook global `num_episodes`.\n",
    "        eta_init: initial threshold; defaults to the notebook global `eta_RN`.\n",
    "\n",
    "    Returns:\n",
    "        (Q_cvar, cvar_hist)\n",
    "    \"\"\"\n",
    "    if n_episodes is None:\n",
    "        n_episodes = num_episodes\n",
    "    if eta_init is None:\n",
    "        eta_init = eta_RN\n",
    "    alpha_theta = 0.01\n",
    "    eta = eta_init\n",
    "    # H[i] == i - 150; clip so an extreme initial eta cannot index out of range.\n",
    "    eta_index = int(np.clip(int(eta) + 150, 0, len(H) - 1))\n",
    "    sigma = 45\n",
    "    cvar_hist = []\n",
    "    for episode in range(n_episodes):\n",
    "        epsilon_t = max(1.0 - (episode / decay_episodes), 0.0)\n",
    "        _, trajectory = _run_CVaR_episode(Q_cvar, eta_index, epsilon_t)\n",
    "        Q_cvar = update_CVaR(Q_cvar, H, alpha_theta, trajectory)\n",
    "\n",
    "        checkpoint = (episode + 1) % 1000 == 0\n",
    "        if checkpoint:\n",
    "            best_idx = _best_CVaR_eta_index(Q_cvar[START[0], START[1]])\n",
    "            eta = H[best_idx]\n",
    "            sigma = max(45 * (3 - ((episode + 1) // 2000)), 0)\n",
    "        # Exploration over eta: Gaussian around the current estimate, truncated at 2 sigma.\n",
    "        sample_eta = np.random.normal(loc=eta, scale=sigma)\n",
    "        sample_eta = np.clip(sample_eta, eta - 2 * sigma, eta + 2 * sigma)\n",
    "        eta_index = np.clip(int(round(sample_eta) + 150), 0, 250)\n",
    "\n",
    "        if checkpoint:\n",
    "            print(eta_index - 150, eta)\n",
    "            # Evaluate at the estimated eta (best_idx), not the exploratory sample, as\n",
    "            # PCVaR_Q_learning does, so the two learning curves are comparable.\n",
    "            rewards_test = np.array([\n",
    "                _run_CVaR_episode(Q_cvar, best_idx, 0.0)[0] for _ in range(10000)\n",
    "            ])\n",
    "            var_test = np.percentile(rewards_test, q * 100)\n",
    "            cvar_test = np.mean(rewards_test[rewards_test <= var_test])\n",
    "            cvar_hist.append(cvar_test)\n",
    "    return Q_cvar, cvar_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c50cc5bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def update_PCVaR(Q_cvar, M, H, lr1, lr2, trajectory):\n",
    "    \"\"\"One batched PCVaR-Q update from a single episode; modifies and returns (Q_cvar, M).\n",
    "\n",
    "    For each transition (sum_r, s, a, r, s_next) and each threshold index idx reachable\n",
    "    from an initial index i via idx = i - round(sum_r):\n",
    "      - the next index is n = clip(idx - round(r));\n",
    "      - the next action a* maximises Q_cvar - H[n] * M at (s_next, n), ties at random;\n",
    "      - the targets are Q_cvar[s_next, n, a*] + r * M[s_next, n, a*] for Q_cvar and\n",
    "        M[s_next, n, a*] for M.\n",
    "    Targets are averaged per (s, idx, a) over the episode; the tables move towards those\n",
    "    averages with step sizes `lr1` (Q_cvar) and `lr2` (M).\n",
    "\n",
    "    Args:\n",
    "        Q_cvar, M: arrays indexed [row, col, threshold_idx, action].\n",
    "        H: threshold grid (assumed unit-spaced, since indices are shifted by rounded rewards).\n",
    "        lr1, lr2: step sizes for Q_cvar and M.\n",
    "        trajectory: sequence of [sum_r, state, action, reward, next_state], where sum_r is\n",
    "            the reward collected before this step.\n",
    "    \"\"\"\n",
    "    n_h = len(H)\n",
    "    Q_est = np.zeros_like(Q_cvar)\n",
    "    M_est = np.zeros_like(M)\n",
    "    count = np.zeros_like(Q_cvar)\n",
    "\n",
    "    for sum_r, s, a, r, s_next in trajectory:\n",
    "        shift = int(round(sum_r))\n",
    "        r_shift = int(round(r))\n",
    "        # Skip (rather than stop at) out-of-range idx: other i may still land inside H.\n",
    "        for idx in range(max(0, -shift), min(n_h, n_h - shift)):\n",
    "            count[s[0], s[1], idx, a] += 1\n",
    "            next_idx = np.clip(idx - r_shift, 0, n_h - 1)\n",
    "            q_values = Q_cvar[s_next[0], s_next[1], next_idx, :] - H[next_idx] * M[s_next[0], s_next[1], next_idx, :]\n",
    "            max_actions = np.flatnonzero(q_values == q_values.max())\n",
    "            a_next = np.random.choice(max_actions)\n",
    "            m_next = M[s_next[0], s_next[1], next_idx, a_next]\n",
    "            Q_est[s[0], s[1], idx, a] += Q_cvar[s_next[0], s_next[1], next_idx, a_next] + m_next * r\n",
    "            M_est[s[0], s[1], idx, a] += m_next\n",
    "    valid = count > 0\n",
    "    Q_cvar[valid] += lr1 * (Q_est[valid] / count[valid] - Q_cvar[valid])\n",
    "    M[valid] += lr2 * (M_est[valid] / count[valid] - M[valid])\n",
    "    return Q_cvar, M\n",
    "\n",
    "def _run_PCVaR_episode(Q_cvar, M, eta_idx, epsilon):\n",
    "    \"\"\"Roll out one episode from START with the PCVaR-Q policy at threshold index `eta_idx`.\n",
    "\n",
    "    After every step the threshold index is lowered by the rounded reward, so the policy\n",
    "    conditions on the remaining threshold. Episodes are cut off after 1001 steps.\n",
    "\n",
    "    Returns:\n",
    "        (total_reward, trajectory) with entries [sum_r, state, action, reward, next_state].\n",
    "    \"\"\"\n",
    "    trajectory = []\n",
    "    state = START\n",
    "    total_reward = 0\n",
    "    done = False\n",
    "    t = 0\n",
    "    while not done:\n",
    "        action = choose_action_PCVaR(Q_cvar, M, state, eta_idx, epsilon)\n",
    "        next_state, reward, done = step(state, action)\n",
    "        trajectory.append([total_reward, state, action, reward, next_state])\n",
    "        total_reward += reward\n",
    "        eta_idx = np.clip(eta_idx - int(round(reward)), 0, len(H) - 1)\n",
    "        state = next_state\n",
    "        t += 1\n",
    "        if t > 1000:\n",
    "            done = True\n",
    "    return total_reward, trajectory\n",
    "\n",
    "\n",
    "def _best_PCVaR_eta_index(Q_start, M_start):\n",
    "    \"\"\"First index i maximising max_a (H[i] * (q - M_start[i, a]) + Q_start[i, a]) at START.\"\"\"\n",
    "    objective = (H[:, None] * (q - M_start) + Q_start).max(axis=1)\n",
    "    return int(np.argmax(objective))\n",
    "\n",
    "\n",
    "def PCVaR_Q_learning(Q_cvar, M, decay_episodes=2000, n_episodes=None, eta_init=None):\n",
    "    \"\"\"Train PCVaR-Q from START while searching over the initial threshold eta.\n",
    "\n",
    "    Each episode runs the epsilon-greedy policy (epsilon decays linearly to 0 over\n",
    "    `decay_episodes`) at a sampled threshold and applies update_PCVaR. Every 1000\n",
    "    episodes:\n",
    "      - eta is re-estimated as the grid value maximising max_a h * (q - M) + Q_cvar at START;\n",
    "      - the sampling width shrinks to max(45 * (3 - (episode+1) // 2000), 0);\n",
    "      - the greedy policy at the estimated eta is evaluated on 10000 rollouts, and the\n",
    "        empirical CVaR_q of those returns is appended to the history.\n",
    "\n",
    "    Args:\n",
    "        Q_cvar, M: initial tables indexed [row, col, threshold_idx, action]; updated in place.\n",
    "        decay_episodes: length of the linear epsilon schedule.\n",
    "        n_episodes: number of training episodes; defaults to the notebook global `num_episodes`.\n",
    "        eta_init: initial threshold; defaults to the notebook global `eta_RN`.\n",
    "\n",
    "    Returns:\n",
    "        (Q_cvar, M, cvar_hist)\n",
    "    \"\"\"\n",
    "    if n_episodes is None:\n",
    "        n_episodes = num_episodes\n",
    "    if eta_init is None:\n",
    "        eta_init = eta_RN\n",
    "    alpha_theta = 0.01\n",
    "    alpha_phi = 0.01\n",
    "    eta = eta_init\n",
    "    # H[i] == i - 150; clip so an extreme initial eta cannot index out of range.\n",
    "    eta_index = int(np.clip(int(eta) + 150, 0, len(H) - 1))\n",
    "    sigma = 45\n",
    "    cvar_hist = []\n",
    "    for episode in range(n_episodes):\n",
    "        epsilon_t = max(1.0 - (episode / decay_episodes), 0.0)\n",
    "        _, trajectory = _run_PCVaR_episode(Q_cvar, M, eta_index, epsilon_t)\n",
    "        Q_cvar, M = update_PCVaR(Q_cvar, M, H, alpha_theta, alpha_phi, trajectory)\n",
    "\n",
    "        checkpoint = (episode + 1) % 1000 == 0\n",
    "        if checkpoint:\n",
    "            best_idx = _best_PCVaR_eta_index(Q_cvar[START[0], START[1]], M[START[0], START[1]])\n",
    "            eta = H[best_idx]\n",
    "            sigma = max(45 * (3 - ((episode + 1) // 2000)), 0)\n",
    "        # Exploration over eta: Gaussian around the current estimate, truncated at 2 sigma.\n",
    "        sample_eta = np.random.normal(loc=eta, scale=sigma)\n",
    "        sample_eta = np.clip(sample_eta, eta - 2 * sigma, eta + 2 * sigma)\n",
    "        eta_index = np.clip(int(round(sample_eta) + 150), 0, 250)\n",
    "\n",
    "        if checkpoint:\n",
    "            print(eta_index - 150, eta)\n",
    "            # Greedy evaluation at the estimated eta (best_idx == int(eta) + 150).\n",
    "            rewards_test = np.array([\n",
    "                _run_PCVaR_episode(Q_cvar, M, best_idx, 0.0)[0] for _ in range(10000)\n",
    "            ])\n",
    "            var_test = np.percentile(rewards_test, q * 100)\n",
    "            cvar_test = np.mean(rewards_test[rewards_test <= var_test])\n",
    "            cvar_hist.append(cvar_test)\n",
    "    return Q_cvar, M, cvar_hist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "8577212c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Risk-neutral SARSA baseline; its Q table guides the CVaR pre-training rollouts.\n",
    "set_seed(seed=SEED)\n",
    "Q_sarsa, rewards_sarsa = sarsa()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd85a4ba",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "71f0f884",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 0\n",
      "Episode 10000\n",
      "Episode 20000\n",
      "Episode 30000\n",
      "Episode 40000\n",
      "-50.80753459062239\n",
      "Episode 10000\n",
      "Episode 20000\n",
      "Episode 30000\n",
      "Episode 40000\n",
      "Episode 50000\n",
      "-50.80753459062239\n"
     ]
    }
   ],
   "source": [
    "set_seed(SEED)\n",
    "Pre_PCVaR_Q_cvar, Pre_M, Pre_rewards = PCVaR_Q_Pre_train(Q_sarsa, 50000)\n",
    "eta_RN = np.quantile(Pre_rewards, q)  # risk-neutral q-quantile (VaR) of pre-training returns\n",
    "print(eta_RN)\n",
    "set_seed(SEED)\n",
    "Pre_CVaR_Q_cvar, Pre_rewards = CVaR_Q_Pre_train(Q_sarsa, 50000)\n",
    "# Recomputed from the CVaR pre-training rollouts; this is the value the learners read via\n",
    "# the global `eta_RN`. Same seed and rollout code, so it is expected to match the value\n",
    "# above (the saved output shows identical numbers).\n",
    "eta_RN = np.quantile(Pre_rewards, q)\n",
    "print(eta_RN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "fd175b4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-108 -56.0\n",
      "-32 -56.0\n",
      "-48 -56.0\n",
      "-48 -4.0\n",
      "23 -4.0\n",
      "-2 -2.0\n",
      "1 1.0\n",
      "1 1.0\n",
      "0 0.0\n",
      "2 2.0\n",
      "4 4.0\n",
      "7 7.0\n",
      "9 9.0\n",
      "10 10.0\n",
      "9 9.0\n",
      "95 -56.0\n",
      "-150 -56.0\n",
      "-150 -56.0\n",
      "27 -5.0\n",
      "85 -4.0\n",
      "-3 -3.0\n",
      "2 2.0\n",
      "1 1.0\n",
      "2 2.0\n",
      "5 5.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "10 10.0\n",
      "10 10.0\n",
      "100 -56.0\n",
      "-150 -56.0\n",
      "17 -56.0\n",
      "46 -4.0\n",
      "-36 -3.0\n",
      "-2 -2.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "1 1.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "1 1.0\n",
      "6 6.0\n",
      "8 8.0\n",
      "-67 -56.0\n",
      "-15 -56.0\n",
      "66 -56.0\n",
      "2 -4.0\n",
      "61 -3.0\n",
      "0 0.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "1 1.0\n",
      "4 4.0\n",
      "6 6.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "9 9.0\n",
      "60 -56.0\n",
      "7 -56.0\n",
      "-110 -56.0\n",
      "-4 -6.0\n",
      "27 -4.0\n",
      "-2 -2.0\n",
      "1 1.0\n",
      "1 1.0\n",
      "1 1.0\n",
      "0 0.0\n",
      "3 3.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "12 -56.0\n",
      "-22 -56.0\n",
      "75 -56.0\n",
      "-46 -5.0\n",
      "-30 -3.0\n",
      "-1 -1.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "3 3.0\n",
      "5 5.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "-150 -56.0\n",
      "-46 -56.0\n",
      "34 -56.0\n",
      "24 -6.0\n",
      "-20 -4.0\n",
      "-3 -3.0\n",
      "1 1.0\n",
      "2 2.0\n",
      "3 3.0\n",
      "5 5.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "-150 -56.0\n",
      "-60 -56.0\n",
      "-17 -56.0\n",
      "-63 -4.0\n",
      "-67 -4.0\n",
      "-1 -1.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "3 3.0\n",
      "3 3.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "8 8.0\n",
      "10 10.0\n",
      "10 10.0\n",
      "-149 -56.0\n",
      "-108 -56.0\n",
      "-133 -56.0\n",
      "-30 -5.0\n",
      "48 -4.0\n",
      "-1 -1.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "2 2.0\n",
      "5 5.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "9 9.0\n",
      "10 10.0\n",
      "-132 -56.0\n",
      "100 -56.0\n",
      "-44 -5.0\n",
      "-70 -3.0\n",
      "-32 -3.0\n",
      "0 0.0\n",
      "3 3.0\n",
      "2 2.0\n",
      "4 4.0\n",
      "7 7.0\n",
      "8 8.0\n",
      "9 9.0\n",
      "9 9.0\n",
      "9 9.0\n",
      "9 9.0\n"
     ]
    }
   ],
   "source": [
    "num_episodes = 15000  # read as a global by PCVaR_Q_learning / CVaR_Q_learning\n",
    "\n",
    "PCVaR_Q_list = []\n",
    "M_list = []\n",
    "PCVaR_Q_cvar_hist_list = []\n",
    "\n",
    "CVaR_Q_list = []\n",
    "CVaR_Q_cvar_hist_list = []\n",
    "\n",
    "# Separate loop variable: rebinding the global SEED here would silently change every\n",
    "# earlier cell that calls set_seed(SEED) if it is re-run afterwards.\n",
    "for run_seed in range(5):\n",
    "    set_seed(run_seed)\n",
    "    Q_cvar, M, cvar_hist = PCVaR_Q_learning(\n",
    "        np.copy(Pre_PCVaR_Q_cvar),\n",
    "        np.copy(Pre_M)\n",
    "    )\n",
    "    PCVaR_Q_list.append(Q_cvar)\n",
    "    M_list.append(M)\n",
    "    PCVaR_Q_cvar_hist_list.append(cvar_hist)\n",
    "\n",
    "    # Re-seed so the baseline starts from the same RNG state as PCVaR-Q.\n",
    "    set_seed(run_seed)\n",
    "    Q_cvar, cvar_hist = CVaR_Q_learning(\n",
    "        np.copy(Pre_CVaR_Q_cvar)\n",
    "    )\n",
    "    CVaR_Q_list.append(Q_cvar)\n",
    "    CVaR_Q_cvar_hist_list.append(cvar_hist)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2e80756c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-seed tables and CVaR learning curves, one pickle file per list.\n",
    "artifacts = [\n",
    "    ('PCVaR_Q_list.pickle', PCVaR_Q_list),\n",
    "    ('M_list.pickle', M_list),\n",
    "    ('PCVaR_Q_cvar_hist_list.pickle', PCVaR_Q_cvar_hist_list),\n",
    "    ('CVaR_Q_list.pickle', CVaR_Q_list),\n",
    "    ('CVaR_Q_cvar_hist_list.pickle', CVaR_Q_cvar_hist_list),\n",
    "]\n",
    "for path, obj in artifacts:\n",
    "    with open(path, 'wb') as f:\n",
    "        pickle.dump(obj, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c6c04e31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample returns of the greedy PCVaR-Q policy (seed-0 run) at a fixed initial threshold.\n",
    "set_seed(0)\n",
    "rewards_cvar = []\n",
    "eta = 9  # NOTE(review): hand-picked; close to the late-training eta estimates printed above\n",
    "for _ in range(50000):  # `_` rather than `iter`, so the builtin iter() is not shadowed\n",
    "    state = START\n",
    "    done = False\n",
    "    eta_t_idx = int(eta) + 150  # H[i] == i - 150\n",
    "    total_reward = 0\n",
    "    t = 0\n",
    "    while not done:\n",
    "        action = choose_action_PCVaR(PCVaR_Q_list[0], M_list[0], state, eta_t_idx, 0.0)\n",
    "        next_state, reward, done = step(state, action)\n",
    "        total_reward += reward\n",
    "        # Condition the next action on the remaining threshold.\n",
    "        eta_t_idx = np.clip(eta_t_idx - int(round(reward)), 0, len(H) - 1)\n",
    "        state = next_state\n",
    "        t += 1\n",
    "        if t > 1000:\n",
    "            done = True\n",
    "    rewards_cvar.append(total_reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "66f5f3da",
   "metadata": {},
   "outputs": [],
   "source": [
    "rewards_cvar = np.array(rewards_cvar)\n",
    "rewards_RN = np.array(Pre_rewards)\n",
    "\n",
    "# Empirical VaR_q (q-th percentile) and CVaR_q (mean of returns at or below it).\n",
    "var_RN = np.percentile(rewards_RN, q * 100)\n",
    "var_cvar = np.percentile(rewards_cvar, q * 100)\n",
    "cvar_RN = rewards_RN[rewards_RN <= var_RN].mean()\n",
    "cvar_cvar = rewards_cvar[rewards_cvar <= var_cvar].mean()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.hist(rewards_RN, bins=30, alpha=0.4, label='RN', color='royalblue', edgecolor='black', density=True)\n",
    "ax.hist(rewards_cvar, bins=30, alpha=0.6, label='PCVaR-Q', color='darkorange', edgecolor='black', density=True)\n",
    "ax.axvline(cvar_RN, color='blue', linestyle='--', linewidth=2, label='CVaR of RN ')\n",
    "ax.axvline(cvar_cvar, color='red', linestyle='--', linewidth=2, label='CVaR of PCVaR-Q')\n",
    "ax.set_xlabel('Reward')\n",
    "ax.set_ylabel('Density')\n",
    "ax.legend(fontsize=16)\n",
    "ax.grid(True, linestyle='--', alpha=0.5)\n",
    "fig.tight_layout()\n",
    "fig.savefig(\"Figure 3 (a).png\", dpi=300)\n",
    "plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8125d980",
   "metadata": {},
   "outputs": [],
   "source": [
    "# numpy (np) and matplotlib.pyplot (plt) are imported in the first cell.\n",
    "seeds = [0, 1, 2, 3, 4]  # seeds used by the training loop\n",
    "EVAL_INTERVAL = 1000      # training episodes between CVaR evaluations\n",
    "OPT_CVAR = -52.34         # NOTE(review): reference 'Opt' value; its derivation is not in this notebook\n",
    "\n",
    "def get_mean_and_std_from_list(data_list, length=None):\n",
    "    \"\"\"Stack per-run curves truncated to a common length; return (runs, mean, std) over runs.\n",
    "\n",
    "    `length` defaults to the shortest curve in `data_list`.\n",
    "    \"\"\"\n",
    "    if length is None:\n",
    "        length = min(len(d) for d in data_list)\n",
    "    data_array = np.array([d[:length] for d in data_list])\n",
    "    return data_array, data_array.mean(axis=0), data_array.std(axis=0)\n",
    "\n",
    "# Truncate both methods to the same number of checkpoints so they share one x-axis.\n",
    "n_points = min(len(d) for d in CVaR_Q_cvar_hist_list + PCVaR_Q_cvar_hist_list)\n",
    "cvar_runs, cvar_mean, cvar_std = get_mean_and_std_from_list(CVaR_Q_cvar_hist_list, n_points)\n",
    "pcvar_runs, pcvar_mean, pcvar_std = get_mean_and_std_from_list(PCVaR_Q_cvar_hist_list, n_points)\n",
    "\n",
    "x = EVAL_INTERVAL * np.arange(1, n_points + 1)\n",
    "\n",
    "plt.figure(figsize=(12, 6))\n",
    "\n",
    "# Individual runs (faint, dashed) behind the mean +/- std bands.\n",
    "for pcvar_run, cvar_run in zip(pcvar_runs, cvar_runs):\n",
    "    plt.plot(x, pcvar_run, linestyle='--', color='darkorange', alpha=0.3)\n",
    "    plt.plot(x, cvar_run, linestyle='--', color='royalblue', alpha=0.3)\n",
    "\n",
    "plt.plot(x, pcvar_mean, color='darkorange', marker='o', linewidth=2.5, label='PCVaR-Q')\n",
    "plt.fill_between(x, pcvar_mean - pcvar_std, pcvar_mean + pcvar_std,\n",
    "                 color='darkorange', alpha=0.2)\n",
    "\n",
    "plt.plot(x, cvar_mean, color='royalblue', marker='s', linewidth=2.5, label='CVaR-Q')\n",
    "plt.fill_between(x, cvar_mean - cvar_std, cvar_mean + cvar_std,\n",
    "                 color='royalblue', alpha=0.2)\n",
    "\n",
    "plt.axhline(y=OPT_CVAR, color='red', linestyle='--', linewidth=2, label='Opt')\n",
    "plt.xlabel(\"Training Iteration\", fontsize=12)\n",
    "plt.ylabel(\"Estimated CVaR\", fontsize=12)\n",
    "plt.grid(True, linestyle='--', alpha=0.5)\n",
    "plt.legend(fontsize=14, loc='lower right')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"Figure 3 (b).png\", dpi=300)\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04dcd952",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac12fd44",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "217d0b78",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
