{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c044b1af",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17f81912",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the synthetic MDP: states, actions and feature dimension.\n",
    "n_states, n_actions, d = 1000, 500, 7\n",
    "# Discount factor.\n",
    "gamma = 0.99\n",
    "\n",
    "# Seed every RNG used in this notebook (Python, NumPy, torch CPU and CUDA).\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "torch.manual_seed(0)\n",
    "torch.cuda.manual_seed(0)\n",
    "\n",
    "def softmax(v, axis):\n",
    "    \"\"\"Numerically stable softmax of `v` along `axis`.\n",
    "\n",
    "    Every slice along `axis` is shifted by its own maximum before\n",
    "    exponentiating, so slices far below the global maximum do not\n",
    "    underflow to an all-zero denominator (0/0 = nan). The input is not\n",
    "    modified, so callers such as `softmax(Q, axis=1)` keep their array intact.\n",
    "\n",
    "    Args:\n",
    "        v: array-like of real numbers.\n",
    "        axis: axis along which the result sums to one; any axis,\n",
    "            including negative ones, is supported.\n",
    "\n",
    "    Returns:\n",
    "        Array with the same shape as `v`, non-negative and summing to 1\n",
    "        along `axis`.\n",
    "    \"\"\"\n",
    "    v = np.asarray(v)\n",
    "    # keepdims lets the per-slice max and sum broadcast back against v for any axis.\n",
    "    shifted = v - np.max(v, axis=axis, keepdims=True)\n",
    "    exp_v = np.exp(shifted)\n",
    "    return exp_v / np.sum(exp_v, axis=axis, keepdims=True)\n",
    "\n",
    "def logsumexp(v, axis):\n",
    "    \"\"\"Numerically stable log(sum(exp(v))) reduced over `axis`.\n",
    "\n",
    "    The shift is the maximum of each slice along `axis` (not of the whole\n",
    "    array), so slices far below the global maximum do not collapse to\n",
    "    log(0) = -inf. `v` is not modified in place.\n",
    "\n",
    "    Args:\n",
    "        v: array-like of real numbers.\n",
    "        axis: axis to reduce over.\n",
    "\n",
    "    Returns:\n",
    "        Array shaped like `v` with `axis` removed.\n",
    "    \"\"\"\n",
    "    v = np.asarray(v)\n",
    "    max_v = np.max(v, axis=axis, keepdims=True)\n",
    "    # A slice whose maximum is +/-inf would give inf - inf = nan; shift it by 0.\n",
    "    max_v = np.where(np.isfinite(max_v), max_v, 0.0)\n",
    "    out = np.log(np.sum(np.exp(v - max_v), axis=axis, keepdims=True)) + max_v\n",
    "    return np.squeeze(out, axis=axis)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "872d8a76",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Environment:\n",
    "    \"\"\"Random linear MDP whose features are softmax-normalised.\n",
    "\n",
    "    Attributes:\n",
    "        features: (n_states, n_actions, d) array; each feature vector is a\n",
    "            softmax over its d entries, so it is non-negative and sums to one.\n",
    "        reward_weights: (d,) standard-normal reward parameter.\n",
    "        transition_weights: (d, n_states) array; each row is a softmax\n",
    "            distribution over next states.\n",
    "        reward: (n_states, n_actions) array, features @ reward_weights.\n",
    "        T: (n_states, n_actions, n_states) transition kernel,\n",
    "            features @ transition_weights, so every T[s, a] sums to one.\n",
    "        gamma: discount factor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_states, n_actions, d, gamma):\n",
    "        # Keep the three np.random draws in this order: together with the\n",
    "        # global seed, the order determines which MDP is generated.\n",
    "        phi = softmax(np.random.randn(n_states, n_actions, d), axis=2)\n",
    "        w_reward = np.random.randn(d)\n",
    "        w_transition = softmax(np.random.randn(d, n_states), axis=1)\n",
    "\n",
    "        self.features = phi\n",
    "        self.reward_weights = w_reward\n",
    "        self.transition_weights = w_transition\n",
    "        self.reward = np.matmul(phi, w_reward)\n",
    "        self.T = np.matmul(phi, w_transition)\n",
    "        self.gamma = gamma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1eac31f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the MDP; this consumes the global NumPy RNG seeded above.\n",
    "env = Environment(n_states=n_states, n_actions=n_actions, d=d, gamma=gamma)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "687e1245",
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_value_iteration(env, max_iter=1000, tol=1e-1, verbose=True):\n",
    "    \"\"\"Soft (log-sum-exp) value iteration on a linear MDP.\n",
    "\n",
    "    Repeats\n",
    "        theta = reward_weights + gamma * transition_weights @ V\n",
    "        Q     = features @ theta\n",
    "        V     = logsumexp(Q, axis=-1)\n",
    "    until the Euclidean change in V is below `tol` or `max_iter` sweeps\n",
    "    have run.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing `features` (S, A, d), `reward_weights` (d,),\n",
    "            `transition_weights` (d, S) and `gamma`, e.g. `Environment`.\n",
    "        max_iter: maximum number of sweeps (default 1000).\n",
    "        tol: stopping threshold on ||V - V_new||_2 (default 1e-1).\n",
    "        verbose: if True, print the change in V after every sweep.\n",
    "\n",
    "    Returns:\n",
    "        (S, A) array of soft Q-values computed in the last sweep.\n",
    "    \"\"\"\n",
    "    # Sizes come from env instead of the notebook-level n_states/n_actions.\n",
    "    Q = np.zeros(env.features.shape[:2])\n",
    "    V = np.zeros(env.features.shape[0])\n",
    "    for _ in range(max_iter):\n",
    "        theta = env.reward_weights + env.gamma * env.transition_weights @ V\n",
    "        Q = env.features @ theta\n",
    "        # Pass a copy: a logsumexp that shifts its argument in place would\n",
    "        # otherwise corrupt the Q-values returned below.\n",
    "        V_new = logsumexp(Q.copy(), -1)\n",
    "        change = np.linalg.norm(V - V_new)\n",
    "        if verbose:\n",
    "            print(change)\n",
    "        if change < tol:\n",
    "            break\n",
    "        V = V_new\n",
    "    return Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae8fb2a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "Q = soft_value_iteration(env)\n",
    "# Pass a copy: a softmax that shifts its argument in place would otherwise\n",
    "# alter Q before it is pickled below.\n",
    "expert = softmax(Q.copy(), axis=1)\n",
    "\n",
    "to_save = {\"Q\": Q, \"policy\": expert}\n",
    "\n",
    "# Save the soft Q-values and the induced softmax policy.\n",
    "with open('simple_expert.pkl', 'wb') as file:\n",
    "    pickle.dump(to_save, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86375a43",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the tabular expert (soft Q-values and softmax policy) pickled above.\n",
    "# Only unpickle files this notebook produced: pickle.load can run arbitrary code.\n",
    "with open('simple_expert.pkl', 'rb') as file:\n",
    "    to_save = pickle.load(file)\n",
    "Q, expert = to_save[\"Q\"], to_save[\"policy\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "411e33db",
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "def evaluate_policy(policy, env, max_iter=1000, tol=1e-1):\n",
    "    \"\"\"Iterative policy evaluation on a linear MDP.\n",
    "\n",
    "    Repeats V <- features_pi @ (reward_weights + gamma * transition_weights @ V),\n",
    "    where features_pi[s] = sum_a policy[s, a] * features[s, a], until the\n",
    "    Euclidean change in V is below `tol` or `max_iter` sweeps have run.\n",
    "\n",
    "    Args:\n",
    "        policy: (S, A) array of action probabilities, one row per state.\n",
    "        env: environment exposing `features` (S, A, d), `reward_weights` (d,),\n",
    "            `transition_weights` (d, S) and `gamma`.\n",
    "        max_iter: maximum number of sweeps (default 1000).\n",
    "        tol: stopping threshold on ||V - V_new||_2 (default 1e-1).\n",
    "\n",
    "    Returns:\n",
    "        (S,) array of state values. When the tolerance is met, this is the\n",
    "        iterate from before the final sweep.\n",
    "    \"\"\"\n",
    "    # Policy-weighted features for every state in one call, with the number\n",
    "    # of states taken from env rather than the notebook-level n_states.\n",
    "    features_pi = np.einsum(\"sa,sad->sd\", policy, env.features)\n",
    "    V = np.zeros(env.features.shape[0])\n",
    "    for _ in range(max_iter):\n",
    "        theta = env.reward_weights + env.gamma * env.transition_weights @ V\n",
    "        V_new = features_pi @ theta\n",
    "        if np.linalg.norm(V - V_new) < tol:\n",
    "            break\n",
    "        # V_new is freshly allocated each sweep, so no copy is needed.\n",
    "        V = V_new\n",
    "    return V"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "545014a8",
   "metadata": {},
   "source": [
    "# Instantiate and train the complex (neural-network) expert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b6ba4b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ThreeLayerMLP(nn.Module):\n",
    "    \"\"\"Three linear layers (two ReLU hidden layers) ending in a softmax.\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of each input row.\n",
    "        hidden_dim1: width of the first hidden layer.\n",
    "        hidden_dim2: width of the second hidden layer.\n",
    "        output_dim: number of output probabilities per row.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim1=1000, hidden_dim2=1000, output_dim=10):\n",
    "        super().__init__()\n",
    "        # Creation order fixes which torch random draws each layer's init uses.\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim1)\n",
    "        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)\n",
    "        self.fc3 = nn.Linear(hidden_dim2, output_dim)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return probabilities normalised over dim 1 (expects (batch, input_dim) input).\"\"\"\n",
    "        hidden = torch.relu(self.fc1(x))\n",
    "        hidden = torch.relu(self.fc2(hidden))\n",
    "        logits = self.fc3(hidden)\n",
    "        return self.softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "258f360d",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32\n",
    "epochs = 1000\n",
    "learning_rate = 0.001\n",
    "\n",
    "# Instantiate the model\n",
    "expert_NN = ThreeLayerMLP(input_dim=n_states, output_dim=n_actions)\n",
    "\n",
    "# KL-divergence loss; KLDivLoss expects log-probabilities as its input\n",
    "loss_function = nn.KLDivLoss(reduction='batchmean')\n",
    "\n",
    "# Optimizer (Adam)\n",
    "optimizer = optim.Adam(expert_NN.parameters(), lr=learning_rate)\n",
    "\n",
    "# Training data: one one-hot input per state; the target for state s is the\n",
    "# tabular expert's action distribution expert[s].\n",
    "num_samples = n_states\n",
    "X_train = torch.eye(n_states).to(torch.float32)  # one-hot state encodings\n",
    "target_distribution = torch.from_numpy(expert).to(torch.float32)\n",
    "\n",
    "# The model returns probabilities. Clamp them before the log so a probability\n",
    "# that underflows to 0 gives a large finite log-prob rather than -inf, which\n",
    "# would make the KL loss and its gradients nan/inf.\n",
    "min_prob = torch.finfo(torch.float32).tiny\n",
    "\n",
    "# Training loop\n",
    "for epoch in range(epochs):\n",
    "    perm = torch.randperm(num_samples)\n",
    "    X_train_batch = X_train[perm]\n",
    "    Y_train_batch = target_distribution[perm]\n",
    "    epoch_loss = 0.0\n",
    "    for i in range(0, num_samples, batch_size):\n",
    "        X_batch = X_train_batch[i:i+batch_size]\n",
    "        y_batch = Y_train_batch[i:i+batch_size]\n",
    "\n",
    "        optimizer.zero_grad()  # Zero gradients\n",
    "\n",
    "        output = expert_NN(X_batch)  # Forward pass\n",
    "        loss = loss_function(output.clamp_min(min_prob).log(), y_batch)  # Compute loss\n",
    "\n",
    "        loss.backward()  # Backpropagation\n",
    "        optimizer.step()  # Update weights\n",
    "        # batchmean loss times batch size = summed loss over this batch\n",
    "        epoch_loss += loss.item() * X_batch.shape[0]\n",
    "\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        # Mean loss per sample over the whole epoch, not just the last mini-batch.\n",
    "        print(f\"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / num_samples:.10f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43d40c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean absolute per-action error between the network policy and the tabular\n",
    "# expert over all states: one batched forward pass under no_grad instead of\n",
    "# n_states single-row passes that each recorded an autograd graph.\n",
    "with torch.no_grad():\n",
    "    nn_policy = expert_NN(X_train).double()\n",
    "torch.mean(torch.abs(nn_policy - torch.from_numpy(expert)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a456d0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist only the trained weights; the next cell rebuilds the module around them.\n",
    "torch.save(expert_NN.state_dict(), f=\"expert.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d60d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the architecture, then restore the weights saved by the previous cell.\n",
    "expert_NN = ThreeLayerMLP(output_dim=n_actions, input_dim=n_states)\n",
    "saved_state = torch.load(\"expert.pth\")\n",
    "expert_NN.load_state_dict(saved_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "838b4fa9",
   "metadata": {},
   "source": [
    "# Generate expert dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577ad1f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start-state distribution: uniform over all states.\n",
    "initial = np.full(n_states, 1.0 / n_states)\n",
    "\n",
    "# Row-wise softmax module for turning network logits into probabilities.\n",
    "torch_softmax = nn.Softmax(dim=1)\n",
    "def to_policy(policy_torch, apply_softmax=True, n_inputs=None):\n",
    "    \"\"\"Tabulate a torch policy network as a NumPy array, one row per state.\n",
    "\n",
    "    Each state is fed to the network as a one-hot row of the identity matrix,\n",
    "    so row s of the result is the network's output for state s.\n",
    "\n",
    "    Args:\n",
    "        policy_torch: module mapping (batch, n_inputs) float32 tensors to\n",
    "            (batch, n_actions) outputs.\n",
    "        apply_softmax: True when the network returns logits (e.g. BC_net),\n",
    "            False when it already returns probabilities (e.g. ThreeLayerMLP).\n",
    "        n_inputs: number of one-hot states; defaults to the global n_states.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array of shape (n_inputs, n_actions), float32 for the float32\n",
    "        networks used in this notebook.\n",
    "    \"\"\"\n",
    "    if n_inputs is None:\n",
    "        n_inputs = n_states\n",
    "    one_hot_states = torch.eye(n_inputs, dtype=torch.float32)\n",
    "    # Inference only: no_grad avoids recording an autograd graph for a\n",
    "    # forward pass over every state.\n",
    "    with torch.no_grad():\n",
    "        out = policy_torch(one_hot_states)\n",
    "        if apply_softmax:\n",
    "            out = torch.softmax(out, dim=1)\n",
    "    return out.numpy()\n",
    "\n",
    "\n",
    "def sample_from_occupancy_measure(env, policy):\n",
    "    \"\"\"Draw one state from the discounted state-occupancy measure of `policy`.\n",
    "\n",
    "    The walk starts at s ~ `initial` with a ~ policy[s]. Before each\n",
    "    transition it stops with probability 1 - gamma; otherwise it moves to\n",
    "    s' ~ env.T[s, a] and draws a' ~ policy[s']. The state where it stops is\n",
    "    returned, so it has taken t transitions with probability\n",
    "    (1 - gamma) * gamma**t.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing the transition tensor `T` and `gamma`.\n",
    "        policy: NumPy array of action probabilities indexed as policy[state].\n",
    "\n",
    "    Returns:\n",
    "        Index of the sampled state.\n",
    "    \"\"\"\n",
    "    state = np.random.choice(np.arange(n_states), p=initial)\n",
    "    action = np.random.choice(np.arange(n_actions), p=policy[state])\n",
    "    # Geometric stopping: a Bernoulli(1 - gamma) draw of 0 means \"continue\".\n",
    "    while not np.random.binomial(n=1, p=1 - env.gamma):\n",
    "        state = np.random.choice(np.arange(n_states), p=env.T[state, action])\n",
    "        action = np.random.choice(np.arange(n_actions), p=policy[state])\n",
    "    return state\n",
    "\n",
    "\n",
    "def sample_dataset(n_expert_samples, env, policy):\n",
    "    \"\"\"Sample `n_expert_samples` (state, action) pairs under `policy`.\n",
    "\n",
    "    Each state comes from `sample_from_occupancy_measure`; the paired action\n",
    "    is then drawn afresh from policy[state].\n",
    "\n",
    "    Returns:\n",
    "        (states, actions): two lists of length n_expert_samples.\n",
    "    \"\"\"\n",
    "    states, actions = [], []\n",
    "    for _ in range(n_expert_samples):\n",
    "        state = sample_from_occupancy_measure(env, policy)\n",
    "        states.append(state)\n",
    "        actions.append(np.random.choice(np.arange(n_actions), p=policy[state]))\n",
    "    return states, actions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c49254a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 100 (state, action) pairs from the tabular soft-optimal expert.\n",
    "expert_states, expert_actions = sample_dataset(n_expert_samples=100, env=env, policy=expert)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83aeb6a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1000 (state, action) pairs from the neural-network expert, whose forward\n",
    "# pass already returns probabilities (hence apply_softmax=False).\n",
    "expert_states_NN, expert_actions_NN = sample_dataset(\n",
    "    n_expert_samples=1000, env=env, policy=to_policy(expert_NN, apply_softmax=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "13ebb76c",
   "metadata": {},
   "source": [
    "# Behavioural Cloning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6141bfc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class BC_net(nn.Module):\n",
    "    \"\"\"Three linear layers (two ReLU hidden layers) that output raw logits.\n",
    "\n",
    "    Used as the behavioural-cloning policy; callers apply the softmax\n",
    "    themselves (e.g. through CrossEntropyLoss during training).\n",
    "\n",
    "    Args:\n",
    "        input_dim: size of each input row.\n",
    "        hidden_dim1: width of the first hidden layer.\n",
    "        hidden_dim2: width of the second hidden layer.\n",
    "        output_dim: number of logits per row.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, hidden_dim1=1000, hidden_dim2=1000, output_dim=10):\n",
    "        super().__init__()\n",
    "        # Creation order fixes which torch random draws each layer's init uses.\n",
    "        self.fc1 = nn.Linear(input_dim, hidden_dim1)\n",
    "        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)\n",
    "        self.fc3 = nn.Linear(hidden_dim2, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return unnormalised action scores (logits) for each input row.\"\"\"\n",
    "        for layer in (self.fc1, self.fc2):\n",
    "            x = torch.relu(layer(x))\n",
    "        return self.fc3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98020421",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def run_bc(states_dataset, actions_dataset, batch_size=32, epochs=1000,\n",
    "           learning_rate=0.001, eval_env=None):\n",
    "    \"\"\"Behavioural cloning: fit a BC_net to expert (state, action) pairs.\n",
    "\n",
    "    The network is trained with cross-entropy on one-hot states. After every\n",
    "    epoch its softmax policy is tabulated with `to_policy`, evaluated with\n",
    "    `evaluate_policy`, and the mean state value is recorded.\n",
    "\n",
    "    Args:\n",
    "        states_dataset: sequence of state indices.\n",
    "        actions_dataset: sequence of action indices aligned with states_dataset.\n",
    "        batch_size: mini-batch size (default 32).\n",
    "        epochs: number of passes over the data (default 1000).\n",
    "        learning_rate: Adam learning rate (default 1e-3).\n",
    "        eval_env: environment used for evaluation; defaults to the global env.\n",
    "\n",
    "    Returns:\n",
    "        {\"net\": trained BC_net, \"values\": one mean state value per epoch}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the dataset is empty.\n",
    "    \"\"\"\n",
    "    num_samples = len(states_dataset)\n",
    "    if num_samples == 0:\n",
    "        raise ValueError(\"run_bc needs at least one (state, action) pair\")\n",
    "    target_env = env if eval_env is None else eval_env\n",
    "    values = []\n",
    "    bc_net = BC_net(input_dim=n_states, output_dim=n_actions)\n",
    "    loss_function = nn.CrossEntropyLoss()\n",
    "    optimizer_bc_net = optim.Adam(bc_net.parameters(), lr=learning_rate)\n",
    "\n",
    "    # One identity-matrix lookup builds all one-hot states (instead of a new\n",
    "    # n_states x n_states identity per sample). Actions are class indices;\n",
    "    # CrossEntropyLoss gives the same loss as with one-hot targets.\n",
    "    X_train = torch.eye(n_states)[torch.as_tensor(np.asarray(states_dataset), dtype=torch.long)]\n",
    "    Y_train = torch.as_tensor(np.asarray(actions_dataset), dtype=torch.long)\n",
    "\n",
    "    for epoch in range(epochs):\n",
    "        perm = torch.randperm(num_samples)\n",
    "        X_shuffled = X_train[perm]\n",
    "        Y_shuffled = Y_train[perm]\n",
    "        epoch_loss = 0.0\n",
    "        for i in range(0, num_samples, batch_size):\n",
    "            X_batch = X_shuffled[i:i+batch_size]\n",
    "            y_batch = Y_shuffled[i:i+batch_size]\n",
    "\n",
    "            optimizer_bc_net.zero_grad()\n",
    "            loss = loss_function(bc_net(X_batch), y_batch)\n",
    "            loss.backward()\n",
    "            optimizer_bc_net.step()\n",
    "            epoch_loss += loss.item() * X_batch.shape[0]\n",
    "        values.append(np.mean(evaluate_policy(to_policy(bc_net), target_env)))\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            # Mean loss per sample over the epoch, not just the last mini-batch.\n",
    "            print(f\"Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / num_samples:.4f}\")\n",
    "    return {\"net\": bc_net, \"values\": values}\n",
    "\n",
    "def run_lin_bc(expert_states, expert_actions, env, K=1000, eta=0.8):\n",
    "    \"\"\"Behavioural cloning with a log-linear policy.\n",
    "\n",
    "    The policy is softmax over actions of eta * features[s, a] @ w. Each of\n",
    "    the K steps adds eta * (expert feature mean - policy feature mean) to w,\n",
    "    i.e. the gradient of the mean log-likelihood of the expert pairs, where\n",
    "    both means are taken over the expert states.\n",
    "\n",
    "    Args:\n",
    "        expert_states: sequence of state indices.\n",
    "        expert_actions: sequence of action indices aligned with expert_states.\n",
    "        env: environment exposing `features` (S, A, d) plus what evaluate_policy reads.\n",
    "        K: number of gradient steps.\n",
    "        eta: inverse temperature of the policy (also scales each step).\n",
    "\n",
    "    Returns:\n",
    "        {\"net\": (S, A) policy used in the last step, \"values\": K mean state values}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no expert pairs.\n",
    "    \"\"\"\n",
    "    states = np.asarray(expert_states, dtype=int)\n",
    "    actions = np.asarray(expert_actions, dtype=int)\n",
    "    if states.size == 0:\n",
    "        raise ValueError(\"run_lin_bc needs at least one expert (state, action) pair\")\n",
    "    # The expert term does not depend on the weights: compute it once.\n",
    "    expert_feature_mean = env.features[states, actions].mean(axis=0)\n",
    "    state_features = env.features[states]  # (N, A, d)\n",
    "\n",
    "    policy_weights = np.zeros(env.features.shape[-1])\n",
    "    values = []\n",
    "    for _ in range(K):\n",
    "        policy = softmax(eta * env.features @ policy_weights, 1)\n",
    "        # Mean over expert states of sum_a policy[s, a] * features[s, a].\n",
    "        policy_feature_mean = np.einsum(\"na,nad->d\", policy[states], state_features) / states.size\n",
    "        policy_weights += eta * (expert_feature_mean - policy_feature_mean)\n",
    "        values.append(np.mean(evaluate_policy(policy, env)))\n",
    "    return {\"net\": policy, \"values\": values}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63f7da23",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_PrimalDualOffIL(expert_states, expert_actions, env, K=1000, eta=0.08, ThetaMAX=1):\n",
    "    \"\"\"Primal-dual offline imitation learning (plotted as SPOIL) with a log-linear policy.\n",
    "\n",
    "    The policy is softmax over actions of eta * features[s, a] @ q_weights.\n",
    "    Each of the K steps moves q_weights by ThetaMAX along the unit vector of\n",
    "    the gap between the expert feature mean and the policy's expected\n",
    "    features, both averaged over the expert states.\n",
    "\n",
    "    Args:\n",
    "        expert_states: sequence of state indices.\n",
    "        expert_actions: sequence of action indices aligned with expert_states.\n",
    "        env: environment exposing `features` (S, A, d) plus what evaluate_policy reads.\n",
    "        K: number of iterations.\n",
    "        eta: inverse temperature of the softmax policy.\n",
    "        ThetaMAX: length of each step taken by q_weights.\n",
    "\n",
    "    Returns:\n",
    "        {\"net\": (S, A) policy used in the last step, \"values\": K mean state values}.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are no expert pairs.\n",
    "    \"\"\"\n",
    "    states = np.asarray(expert_states, dtype=int)\n",
    "    actions = np.asarray(expert_actions, dtype=int)\n",
    "    if states.size == 0:\n",
    "        raise ValueError(\"run_PrimalDualOffIL needs at least one expert (state, action) pair\")\n",
    "    # The expert term does not depend on q_weights: compute it once.\n",
    "    expert_feature_mean = env.features[states, actions].mean(axis=0)\n",
    "    state_features = env.features[states]  # (N, A, d)\n",
    "\n",
    "    q_weights = np.zeros(env.features.shape[-1])\n",
    "    values = []\n",
    "    for _ in range(K):\n",
    "        policy = softmax(eta * env.features @ q_weights, 1)\n",
    "        gain = expert_feature_mean - np.einsum(\"na,nad->d\", policy[states], state_features) / states.size\n",
    "        gain_norm = np.linalg.norm(gain)\n",
    "        # A zero gap would make gain / gain_norm nan and poison q_weights for\n",
    "        # every later step; the policy already matches, so skip the update.\n",
    "        if gain_norm > 0:\n",
    "            q_weights += gain / gain_norm * ThetaMAX\n",
    "        values.append(np.mean(evaluate_policy(policy, env)))\n",
    "    return {\"net\": policy, \"values\": values}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9bada89c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run results are pickled under Pickles/; create it on a fresh checkout.\n",
    "os.makedirs(\"Pickles\", exist_ok=True)\n",
    "\n",
    "for n_data in [100]:\n",
    "    for seed in range(10):\n",
    "        # Re-seed every RNG so each (n_data, seed) run is reproducible on its own.\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        random.seed(seed)\n",
    "\n",
    "        expert_states, expert_actions = sample_dataset(n_data, env, expert)\n",
    "\n",
    "        pdoIL_dict = run_PrimalDualOffIL(expert_states, expert_actions, env)\n",
    "        lin_bc_dict = run_lin_bc(expert_states, expert_actions, env)\n",
    "\n",
    "        results = {\"BC\": lin_bc_dict, \"PDOIL\": pdoIL_dict}\n",
    "\n",
    "        # Save to a pickle file\n",
    "        with open(f'Pickles/easy_expert_results{n_data}_{seed}.pkl', 'wb') as file:\n",
    "            pickle.dump(results, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94e748c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-run results are pickled under Pickles/; create it on a fresh checkout.\n",
    "os.makedirs(\"Pickles\", exist_ok=True)\n",
    "\n",
    "# The tabulated NN-expert policy does not depend on the seed (to_policy makes\n",
    "# no random draws), so compute it once rather than once per seed.\n",
    "nn_expert_policy = to_policy(expert_NN, apply_softmax=False)\n",
    "\n",
    "for n_data in [100]:\n",
    "    for seed in range(10):\n",
    "        # Re-seed every RNG so each (n_data, seed) run is reproducible on its own.\n",
    "        np.random.seed(seed)\n",
    "        torch.manual_seed(seed)\n",
    "        torch.cuda.manual_seed(seed)\n",
    "        random.seed(seed)\n",
    "\n",
    "        expert_states_NN, expert_actions_NN = sample_dataset(n_data, env, nn_expert_policy)\n",
    "\n",
    "        pdoIL_dict_NN = run_PrimalDualOffIL(expert_states_NN, expert_actions_NN, env)\n",
    "        bc_dict_NN = run_bc(expert_states_NN, expert_actions_NN)\n",
    "        results = {\"BC\": bc_dict_NN, \"PDOIL\": pdoIL_dict_NN}\n",
    "\n",
    "        # Save to a pickle file\n",
    "        with open(f'Pickles/complex_expert_results{n_data}_{seed}.pkl', 'wb') as file:\n",
    "            pickle.dump(results, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a657314f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.ticker as ticker\n",
    "import seaborn as sns\n",
    "#sns.set_style(\"darkgrid\")\n",
    "\n",
    "def plot(file_load, file_out, is_complex=True, n_seeds=10):\n",
    "    \"\"\"Plot running-average returns of SPOIL (PDOIL) and BC against the expert.\n",
    "\n",
    "    Loads f\"{file_load}{seed}.pkl\" for every seed in range(n_seeds); each\n",
    "    pickle holds {\"PDOIL\": {\"values\": ...}, \"BC\": {\"values\": ...}}. Every\n",
    "    per-epoch value curve is turned into its running mean, averaged over\n",
    "    seeds and drawn with a +/- 1 std band, next to the expert's mean value.\n",
    "\n",
    "    Args:\n",
    "        file_load: path prefix of the per-seed result pickles.\n",
    "        file_out: file the figure is saved to.\n",
    "        is_complex: compare against the neural-network expert `expert_NN`\n",
    "            if True, otherwise against the tabular `expert`.\n",
    "        n_seeds: number of seed pickles to average (default 10).\n",
    "    \"\"\"\n",
    "    # Legend labels are LaTeX strings, so usetex needs a LaTeX installation.\n",
    "    plt.rcParams.update({\"text.usetex\": True, \"font.family\": \"serif\"})\n",
    "    to_plot_pdoil = []\n",
    "    to_plot_bc = []\n",
    "    for seed in range(n_seeds):\n",
    "        with open(f'{file_load}{seed}.pkl', 'rb') as file:\n",
    "            results = pickle.load(file)\n",
    "        for curves, key in ((to_plot_pdoil, \"PDOIL\"), (to_plot_bc, \"BC\")):\n",
    "            values = np.asarray(results[key][\"values\"], dtype=float)\n",
    "            # Running mean: position t holds the average of the first t + 1 values.\n",
    "            curves.append(np.cumsum(values) / np.arange(1, values.size + 1))\n",
    "\n",
    "    mean_pdoil = np.mean(to_plot_pdoil, axis=0)\n",
    "    mean_bc = np.mean(to_plot_bc, axis=0)\n",
    "    std_pdoil = np.std(to_plot_pdoil, axis=0)\n",
    "    std_bc = np.std(to_plot_bc, axis=0)\n",
    "\n",
    "    if is_complex:\n",
    "        expert_value = np.mean(evaluate_policy(to_policy(expert_NN, apply_softmax=False), env))\n",
    "    else:\n",
    "        expert_value = np.mean(evaluate_policy(expert, env))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.plot(mean_pdoil, label=r\"\\texttt{SPOIL}\", color=\"red\")\n",
    "    ax.plot(mean_bc, label=r\"\\texttt{BC}\", color=\"blue\")\n",
    "    ax.plot(expert_value * np.ones_like(mean_pdoil), \"--\", color=\"gray\", label=\"Expert\")\n",
    "\n",
    "    ax.fill_between(np.arange(mean_pdoil.shape[0]), mean_pdoil - std_pdoil, mean_pdoil + std_pdoil, color=\"red\", alpha=0.05)\n",
    "    ax.fill_between(np.arange(mean_bc.shape[0]), mean_bc - std_bc, mean_bc + std_bc, color=\"blue\", alpha=0.05)\n",
    "    ax.legend(fontsize=15)\n",
    "    ax.xaxis.set_major_locator(ticker.MaxNLocator(nbins=5))\n",
    "    ax.yaxis.set_major_locator(ticker.MaxNLocator(nbins=5))\n",
    "    ax.set_ylabel(\"Return\", fontsize=30)\n",
    "    ax.set_xlabel(\"Epochs\", fontsize=30)\n",
    "    ax.tick_params(axis=\"both\", labelsize=30)\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(file_out)\n",
    "    # Identify the run by its file prefix instead of reading the caller's\n",
    "    # loop variable n_data from the notebook globals.\n",
    "    print(file_load, np.max(mean_pdoil))\n",
    "    plt.show()\n",
    "    # Release the figure; this function is called in a loop.\n",
    "    plt.close(fig)\n",
    "def _results_exist(prefix, n_seeds=10):\n",
    "    \"\"\"True when all per-seed pickles f\"{prefix}{seed}.pkl\" are present.\"\"\"\n",
    "    return all(os.path.exists(f\"{prefix}{seed}.pkl\") for seed in range(n_seeds))\n",
    "\n",
    "\n",
    "# The experiment cells only write pickles for the n_data values they loop\n",
    "# over (currently 100). Skip sizes without results instead of stopping the\n",
    "# whole cell with FileNotFoundError.\n",
    "n_data_to_plot = [1, 2, 5, 10, 20, 50, 100]\n",
    "\n",
    "print(\"Easy data\")\n",
    "for n_data in n_data_to_plot:\n",
    "    prefix = f\"Pickles/easy_expert_results{n_data}_\"\n",
    "    if not _results_exist(prefix):\n",
    "        print(f\"Skipping n_data={n_data}: no results at {prefix}*.pkl\")\n",
    "        continue\n",
    "    plot(prefix, f\"Easy_expert{n_data}.pdf\", is_complex=False)\n",
    "print(\"Complex data\")\n",
    "for n_data in n_data_to_plot:\n",
    "    prefix = f\"Pickles/complex_expert_results{n_data}_\"\n",
    "    if not _results_exist(prefix):\n",
    "        print(f\"Skipping n_data={n_data}: no results at {prefix}*.pkl\")\n",
    "        continue\n",
    "    plot(prefix, f\"Complex_expert{n_data}.pdf\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
