{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VbD-MdRHTR1T"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "# NOTE(review): functorch is the pre-PyTorch-2.0 home of make_functional (now in\n",
        "# torch.func); pin a functorch version that matches the installed torch.\n",
        "%pip install functorch\n",
        "%pip install munch\n",
        "%pip install wandb\n",
        "!wandb login"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "bfdlnT8rL2kJ"
      },
      "outputs": [],
      "source": [
        "# Execution flags.\n",
        "# run_tests: run the inline sanity checks that follow the function definitions.\n",
        "# wandb_agent_mode: presumably True when executed by a wandb sweep agent.\n",
        "run_tests = False\n",
        "wandb_agent_mode = False"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "JGh5XZ_zsNLE"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import time\n",
        "import copy\n",
        "torch.set_default_dtype(torch.float64)\n",
        "from munch import Munch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uav6Pc8Jv95L"
      },
      "source": [
        "# Weights & Biases (wandb) setup and run parameters"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "-bds1iRYu7hA"
      },
      "outputs": [],
      "source": [
        "import wandb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yh1J2oB9wRG5"
      },
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "# Currently this is as in the sweep for the main experiment in the paper\n",
        "args = {\n",
        "  \"RANDOM_SEED\" : 3,\n",
        "  \"INP_DIM\" : 10,\n",
        "  \"OUTP_DIM\" : 3,\n",
        "  \"N_OF_INPS\" : 50,  # number of different inputs used both for computing diff and for\n",
        "                     # calculating expected utilities\n",
        "  \"G\" : 5,\n",
        "  \"MANUAL_fCD\" : True,\n",
        "  \"fCD_HIDDEN_DIM_LIST\" : [4],\n",
        "  \"fCD_HIDDEN_TYPE_LIST\" : [\"LeakyReLU\", \"LeakyReLU\"],\n",
        "  \"MODEL_HIDDEN_DIM_LIST\" : [100, 50, 50],\n",
        "  \"MODEL_HIDDEN_TYPE_LIST\" : [\"LeakyReLU\",\"LeakyReLU\",\"LeakyReLU\",\"LeakyReLU\"],\n",
        "  \"NOISE_TYPE\" : \"uniform\",\n",
        "  \"NOISE_SIZE\" : 0.1,\n",
        "  \"TEST_DIFFS_RANGE\" : 0.2,\n",
        "  \"STEP_2_SELF_PLAY_PROB\" : 0.5,\n",
        "  \"STEP_2_LR\" : 0.02,\n",
        "  \"STEP_2_OPTIMIZER\" : \"Adam\",\n",
        "  \"N_OF_STEPS_STEP_2\" : 200,\n",
        "  \"NO_NOISE_STEP_2\" : True,\n",
        "  \"LOLA_LR\" : 0.0001,\n",
        "  \"LOLA_LA\" : 0.001,\n",
        "  \"N_OF_LOLA_STEPS\" : 0,\n",
        "  \"LOLA_EARLY_TERMINATION\" : True,\n",
        "  \"LOLA_EARLY_TERMINATION_SHIELD\" : 20000,\n",
        "  \"LOLA_LINEAR_LA_DECAY\" : True,\n",
        "  \"LOG_BEHAVIOR_GRAPH_EVERY_N_STEPS\" : 1000,\n",
        "  \"TAYLOR_LOLA\" : False,\n",
        "  \"N_OF_STEPS_PER_TURN_STEP_3\" : 3000,\n",
        "  \"STEP_3_LR\" : 0.00005,\n",
        "  \"N_OF_TURNS_STEP_3\" : 1000,\n",
        "  \"STEP_3_OPTIMIZER\": \"SGD\",\n",
        "  \"STEP_3_IMPROVEMENTS_ONLY\": True,\n",
        "  \"STEP_3_LR_EXPONENT\": 0\n",
        "}\n",
        "args = Munch.fromDict(args)\n",
        "\n",
        "# Track parameters in wandb:\n",
        "wandb_run = wandb.init(project=\"SR-unified\", entity=\"casparoesterheld\",\n",
        "                       config=args, group=\"v0\")\n",
        "\n",
        "# For sweeps: read the parameters back from wandb.config, which holds the values\n",
        "# actually used by this run (a sweep agent overrides the defaults above).\n",
        "args = wandb.config\n",
        "\n",
        "# Seed every RNG the notebook draws from: torch (inputs, noise, model init) and\n",
        "# Python's `random`, which random_sum_sin uses (randint) to build f_C and f_D.\n",
        "# Seeding only torch left f_C/f_D different on every run with the same seed.\n",
        "torch.manual_seed(args.RANDOM_SEED)\n",
        "random.seed(args.RANDOM_SEED)\n",
        "np.random.seed(args.RANDOM_SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ORGfZR_jyWBm"
      },
      "source": [
        "# Neural Nets\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "vtijwLksyVSv"
      },
      "outputs": [],
      "source": [
        "class NeuralNet(nn.Module):\n",
        "  \"\"\"Fully connected feed-forward network.\n",
        "\n",
        "  For each hidden width in hidden_dim_list, a Linear layer is followed by the\n",
        "  activation named at the same position in hidden_type_list; a final Linear\n",
        "  layer maps to outp_dim. Entries of hidden_type_list beyond len(hidden_dim_list)\n",
        "  are ignored.\n",
        "\n",
        "  Args:\n",
        "    inp_dim: size of the input vector.\n",
        "    hidden_dim_list: widths of the hidden layers (list or tuple, may be empty).\n",
        "    hidden_type_list: activation names, each \"LeakyReLU\" or \"LogSigmoid\".\n",
        "    outp_dim: size of the output vector.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if hidden_type_list is shorter than hidden_dim_list, or names an\n",
        "      unsupported activation.\n",
        "  \"\"\"\n",
        "\n",
        "  # Activation names accepted in hidden_type_list -> module constructors.\n",
        "  _ACTIVATIONS = {\"LeakyReLU\": nn.LeakyReLU, \"LogSigmoid\": nn.LogSigmoid}\n",
        "\n",
        "  def __init__(self, inp_dim, hidden_dim_list, hidden_type_list, outp_dim):\n",
        "    super().__init__()\n",
        "    self.inp_dim = inp_dim\n",
        "    self.outp_dim = outp_dim\n",
        "    # list(...) so that tuples work as well as lists.\n",
        "    full_dims_list = [inp_dim] + list(hidden_dim_list)\n",
        "    n_hidden = len(full_dims_list) - 1\n",
        "    if len(hidden_type_list) < n_hidden:\n",
        "      raise ValueError(f\"hidden_type_list has {len(hidden_type_list)} entries but {n_hidden} are needed\")\n",
        "    self.layers = nn.ModuleList()\n",
        "    for i in range(n_hidden):\n",
        "      activation_name = hidden_type_list[i]\n",
        "      if activation_name not in self._ACTIVATIONS:\n",
        "        raise ValueError(f\"Unsupported activation: {activation_name!r}\")\n",
        "      self.layers.append(nn.Linear(full_dims_list[i], full_dims_list[i+1]))\n",
        "      self.layers.append(self._ACTIVATIONS[activation_name]())\n",
        "    self.layers.append(nn.Linear(full_dims_list[-1], outp_dim))\n",
        "\n",
        "  def forward(self, x):\n",
        "    result = x\n",
        "    for layer in self.layers:\n",
        "      result = layer(result)\n",
        "    return result"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rxDARvB2vHWc"
      },
      "source": [
        "# Reimplementation of LOLA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "9eZ2HYiFvKeb"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import copy\n",
        "from functorch import make_functional, grad_and_value\n",
        "from functools import partial\n",
        "\n",
        "def __get_gradient(value, params):\n",
        "  \"\"\"Gradient of the scalar `value` w.r.t. each tensor in `params` (returned as a tuple).\n",
        "\n",
        "  create_graph=True keeps the result differentiable, which LOLA needs in order to\n",
        "  differentiate through the opponent's naive update.\n",
        "  \"\"\"\n",
        "  return torch.autograd.grad(value, params, create_graph=True)\n",
        "\n",
        "def lola_update(loss_fn, models, alpha, beta, algo='exact_lola'):\n",
        "  \"\"\"Perform one LOLA step for two players, updating both models in place.\n",
        "\n",
        "  Args:\n",
        "    loss_fn: takes a list of 2 model callables and returns a list with each player's\n",
        "      loss. The callables are functional versions of the models, so loss_fn must\n",
        "      call model(x) and CANNOT call model.forward(x).\n",
        "    models: list of the two players' nn.Modules (same number of parameter tensors).\n",
        "    alpha: learning rate of the actual update.\n",
        "    beta: opponent-shaping / lookahead rate, i.e. the step size of the naive\n",
        "      gradient step each player assumes its opponent takes.\n",
        "    algo: 'exact_lola' differentiates through the opponent's actual naive step;\n",
        "      'taylor_lola' uses the first-order Taylor expansion of that lookahead.\n",
        "\n",
        "  Returns:\n",
        "    The list of both players' losses at the parameters before this update.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if algo is neither 'exact_lola' nor 'taylor_lola'.\n",
        "  \"\"\"\n",
        "  assert len(models) == 2, 'more than 2 players not implemented yet'\n",
        "  if algo not in ('exact_lola', 'taylor_lola'):\n",
        "    # Validate before doing any (expensive) functionalization / gradient work.\n",
        "    raise ValueError(f'invalid algo specification: {algo}')\n",
        "\n",
        "  for model in models:\n",
        "    model.zero_grad()\n",
        "\n",
        "  n = 2\n",
        "\n",
        "  # start by making models functional\n",
        "  funcs_params = [make_functional(model) for model in models]\n",
        "  funcs = [func_model[0] for func_model in funcs_params]\n",
        "  params = [func_model[1] for func_model in funcs_params]\n",
        "\n",
        "  assert len(params[0]) == len(params[1]), 'different number of params for different players not implemented yet'\n",
        "\n",
        "  n_params = len(params[0])\n",
        "\n",
        "  def _Ls(params):\n",
        "    # Evaluate loss_fn with each player's function bound to the given parameters.\n",
        "    _models = [partial(funcs[i],params[i]) for i in range(n)]\n",
        "    return loss_fn(_models)\n",
        "\n",
        "  losses = _Ls(params)\n",
        "\n",
        "  if algo == 'exact_lola':\n",
        "    opponent_steps = [__get_gradient(losses[i], params[i]) for i in range(n)]\n",
        "\n",
        "    # update opponent's parameters with naive learning\n",
        "    updated_params = [tuple(params[i][j] - beta*opponent_steps[i][j] for j in range(n_params)) for i in range(n)]\n",
        "\n",
        "    # calculate loss for players again, using the opponent's updated params; since\n",
        "    # those depend on this player's params, the gradient below includes the\n",
        "    # opponent-shaping term\n",
        "    lookahead_losses = [_Ls([params[0],updated_params[1]])[0],\n",
        "                        _Ls([updated_params[0],params[1]])[1]]\n",
        "\n",
        "    grads = [__get_gradient(lookahead_losses[i], params[i]) for i in range(n)]\n",
        "\n",
        "  else:  # algo == 'taylor_lola'\n",
        "    # grad_L[i][j] = gradient of player j's loss w.r.t. player i's params\n",
        "    grad_L = [[__get_gradient(losses[j], params[i]) for j in range(n)] for i in range(n)]\n",
        "\n",
        "    # terms[i] = sum over opponents j of <dL_i/dtheta_j, dL_j/dtheta_j>\n",
        "    terms = [sum([torch.dot(grad_L[j][i][v].flatten(), grad_L[j][j][v].flatten())\n",
        "                  for j in range(n) if j != i for v in range(n_params)]) for i in range(n)]\n",
        "\n",
        "    second_order_grads = [__get_gradient(terms[i], params[i]) for i in range(n)]\n",
        "    grads = [tuple(grad_L[i][i][j]-beta*second_order_grads[i][j] for j in range(n_params)) for i in range(n)]\n",
        "\n",
        "  # update all players' params for real\n",
        "  with torch.no_grad():\n",
        "    for i in range(n):\n",
        "      for j, param in enumerate(models[i].parameters()):\n",
        "        param.sub_(alpha*grads[i][j])\n",
        "\n",
        "  return losses"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S8RruMx2v6Wq"
      },
      "source": [
        "# Definition of the diff game"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "Vt1YvCTOVq_t"
      },
      "outputs": [],
      "source": [
        "from random import randint\n",
        "\n",
        "def random_sum_sin(inps):\n",
        "  \"\"\"Random feature used to generate f_C and f_D.\n",
        "\n",
        "  Draws a random 0/1 mask over the INP_DIM input coordinates, sums the selected\n",
        "  coordinates of every input row and returns sin of that sum as an\n",
        "  (N_OF_INPS, 1) column.\n",
        "  \"\"\"\n",
        "  expected_shape = (args.N_OF_INPS, args.INP_DIM)\n",
        "  assert inps.shape == expected_shape\n",
        "  mask = torch.Tensor([randint(0, 1) for _ in range(args.INP_DIM)])\n",
        "  return torch.sin(inps @ mask).reshape(args.N_OF_INPS, 1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FaoPt-lAx8_6"
      },
      "outputs": [],
      "source": [
        "def generate_noise():\n",
        "  \"\"\"Return an (N_OF_INPS, 1) noise column according to args.NOISE_TYPE.\n",
        "\n",
        "  \"uniform\": values in [0, NOISE_SIZE) (note: not centered on 0).\n",
        "  \"normal\": values from a normal distribution with mean 0 and std NOISE_SIZE.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: for any other NOISE_TYPE.\n",
        "  \"\"\"\n",
        "  if args.NOISE_TYPE == \"uniform\":\n",
        "    return args.NOISE_SIZE * torch.rand(args.N_OF_INPS, 1)\n",
        "  if args.NOISE_TYPE == \"normal\":\n",
        "    return torch.normal(mean=0.0, std=args.NOISE_SIZE, size=(args.N_OF_INPS, 1))\n",
        "  raise ValueError(f\"Unsupported NOISE_TYPE: {args.NOISE_TYPE!r}\")\n",
        "\n",
        "def distance(t1, t2):\n",
        "  \"\"\"Mean over rows of the Euclidean distance between corresponding rows of t1 and t2.\n",
        "\n",
        "  Both tensors must have OUTP_DIM columns.\n",
        "  \"\"\"\n",
        "  for t in (t1, t2):\n",
        "    assert t.shape[1] == args.OUTP_DIM\n",
        "  row_dists = torch.linalg.norm(t1 - t2, ord=2, dim=1)\n",
        "  return row_dists.mean()\n",
        "\n",
        "# `inps` are the game's inputs; `test_inps` prepend a random diff column to them\n",
        "# and are the inputs on which diff() compares two models.\n",
        "inps = torch.rand(args.N_OF_INPS, args.INP_DIM)\n",
        "test_diffs = args.TEST_DIFFS_RANGE * torch.rand(args.N_OF_INPS, 1)\n",
        "\n",
        "if args.NO_NOISE_STEP_2:\n",
        "  test_diff_noises = torch.zeros(args.N_OF_INPS, 1)\n",
        "  diff_noise_vals_pl1 = torch.zeros(args.N_OF_INPS, 1)\n",
        "  diff_noise_vals_pl2 = torch.zeros(args.N_OF_INPS, 1)\n",
        "else:\n",
        "  test_diff_noises, diff_noise_vals_pl1, diff_noise_vals_pl2 = generate_noise(), generate_noise(), generate_noise()\n",
        "\n",
        "test_diffs += test_diff_noises\n",
        "test_inps = torch.cat((test_diffs, inps), dim=1)\n",
        "\n",
        "\n",
        "if args.MANUAL_fCD:\n",
        "  fD_vals = torch.cat([random_sum_sin(inps) for _ in range(args.OUTP_DIM)], dim=1)\n",
        "  fC_vals = torch.cat([random_sum_sin(inps) for _ in range(args.OUTP_DIM)], dim=1)\n",
        "  print(\"fD_vals\",fD_vals)\n",
        "else:\n",
        "  # f_D and f_C are fixed targets, so evaluate the random networks without autograd.\n",
        "  # Otherwise fD_vals/fC_vals would carry a graph back into these discarded networks:\n",
        "  # every later backward pass through a loss would also accumulate gradients into\n",
        "  # them, and a second backward pass without retain_graph=True would raise an error.\n",
        "  with torch.no_grad():\n",
        "    fD_model = NeuralNet(args.INP_DIM, args.fCD_HIDDEN_DIM_LIST, args.fCD_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "    fD_vals = fD_model.forward(inps)\n",
        "    print(fD_vals)\n",
        "    del fD_model\n",
        "    fC_model = NeuralNet(args.INP_DIM, args.fCD_HIDDEN_DIM_LIST, args.fCD_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "    fC_vals = fC_model.forward(inps)\n",
        "    print(fC_vals)\n",
        "    del fC_model\n",
        "\n",
        "CD_diff = distance(fC_vals,fD_vals).item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "IWF7cIn7yRJp"
      },
      "outputs": [],
      "source": [
        "def diff(model1, model2):\n",
        "  \"\"\"Dissimilarity of two models: mean distance of their outputs on test_inps,\n",
        "  normalized by CD_diff (the distance between f_C and f_D).\n",
        "\n",
        "  Both models must take INP_DIM+1 inputs (a diff value followed by an input).\n",
        "  \"\"\"\n",
        "  assert model1.inp_dim == model2.inp_dim == args.INP_DIM+1\n",
        "  outputs1, outputs2 = model1.forward(test_inps), model2.forward(test_inps)\n",
        "  return distance(outputs1, outputs2) / CD_diff\n",
        "\n",
        "#BEGIN TEST\n",
        "# Sanity check: a model's diff to itself should be (numerically) zero.\n",
        "if run_tests:\n",
        "  model = NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "  assert abs(diff(model, model).item()) < 1e-10\n",
        "  del model\n",
        "#END Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "yMxYEdESzyQQ"
      },
      "outputs": [],
      "source": [
        "def loss(models, false_diff = None):\n",
        "  \"\"\"Per-player losses of the diff game for models = [model1, model2].\n",
        "\n",
        "  Each model receives the diff between the two models (plus that player's diff\n",
        "  noise) prepended to every row of inps. A player's loss is the distance of its own\n",
        "  outputs to f_D plus G times the distance of the opponent's outputs to f_C, all\n",
        "  divided by CD_diff. If false_diff is given, it replaces the true diff as the\n",
        "  value the models observe.\n",
        "\n",
        "  Returns [loss1, loss2], each a tensor of shape (1,).\n",
        "  \"\"\"\n",
        "  model1, model2 = models[0], models[1]\n",
        "  d = diff(model1, model2)\n",
        "  if false_diff is not None:\n",
        "    d = torch.Tensor([false_diff]).repeat(args.N_OF_INPS, 1)\n",
        "  y1 = model1.forward(torch.cat((d + diff_noise_vals_pl1, inps), dim=1))\n",
        "  y2 = model2.forward(torch.cat((d + diff_noise_vals_pl2, inps), dim=1))\n",
        "  loss1 = (distance(y1, fD_vals) + args.G * distance(y2, fC_vals)) / CD_diff\n",
        "  loss2 = (distance(y2, fD_vals) + args.G * distance(y1, fC_vals)) / CD_diff\n",
        "  return [loss1.reshape(1), loss2.reshape(1)]\n",
        "\n",
        "#BEGIN TEST\n",
        "# Sanity check: in self-play both players should get (nearly) the same loss.\n",
        "if run_tests:\n",
        "  model = NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "  assert abs(loss([model, model])[0].item() - loss([model, model])[1].item()) < 0.005\n",
        "  del model\n",
        "#END Test"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "3PCjo8azT1kW"
      },
      "outputs": [],
      "source": [
        "# Now we implement a copy of the original loss function, except that it doesn't\n",
        "# use model.forward(inps). Instead, it just uses model(inps).\n",
        "# This is necessary for our implementations of exact and Taylor LOLA, which pass\n",
        "# the models in as functional callables that have no .forward method.\n",
        "# We still need the original function for everything else.\n",
        "def diff_without_forward(model1, model2):\n",
        "  \"\"\"Like diff() (minus its inp_dim check), but calls model(x) so model1/model2\n",
        "  may be plain callables. Returns the normalized mean output distance on test_inps.\"\"\"\n",
        "  y1 = model1(test_inps)\n",
        "  y2 = model2(test_inps)\n",
        "  return distance(y1, y2)/CD_diff\n",
        "\n",
        "#BEGIN TEST\n",
        "# Sanity check for diff_without_forward itself: a model's diff to itself is ~0.\n",
        "if run_tests:\n",
        "  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)\n",
        "  assert -1e-10< diff_without_forward(model,model).item() < 1e-10\n",
        "  del model\n",
        "#END Test\n",
        "\n",
        "def loss_without_forward(models, false_diff = None):\n",
        "  \"\"\"Same as loss(), but calls model(x) instead of model.forward(x), so the entries\n",
        "  of models may be plain callables (e.g. the functional models lola_update passes\n",
        "  to its loss_fn).\n",
        "\n",
        "  Returns [loss1, loss2], each a tensor of shape (1,).\n",
        "  \"\"\"\n",
        "  model1 = models[0]\n",
        "  model2 = models[1]\n",
        "  d = diff_without_forward(model1,model2)\n",
        "  if false_diff is not None:\n",
        "    d = torch.Tensor([false_diff]).repeat(args.N_OF_INPS,1)\n",
        "  d1 = d + diff_noise_vals_pl1\n",
        "  d2 = d + diff_noise_vals_pl2\n",
        "  inps1 = torch.cat((d1,inps), dim=1)\n",
        "  inps2 = torch.cat((d2,inps), dim=1)\n",
        "  y1 = model1(inps1)\n",
        "  y2 = model2(inps2)\n",
        "  loss1 = torch.reshape((distance(y1,fD_vals) + args.G * distance(y2,fC_vals))/CD_diff,(1,))\n",
        "  loss2 = torch.reshape((distance(y2,fD_vals) + args.G * distance(y1,fC_vals))/CD_diff,(1,))\n",
        "  return [loss1,loss2]\n",
        "\n",
        "#BEGIN TEST\n",
        "# Gated on run_tests like the other inline tests. Building the test model consumes\n",
        "# torch RNG draws, so running it unconditionally would shift all later model\n",
        "# initialisations depending on an unrelated flag.\n",
        "if run_tests:\n",
        "  model = NeuralNet(args.INP_DIM+1,args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST,args.OUTP_DIM)\n",
        "  assert -0.005< loss_without_forward([model,model])[0].item() - loss_without_forward([model,model])[1].item() < 0.005\n",
        "  del model\n",
        "#END Test"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ccLDRUi-Axrd"
      },
      "source": [
        "\n",
        "# Analysis and logging functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "bdPOs7gyIsSs"
      },
      "outputs": [],
      "source": [
        "import copy as copy\n",
        "\n",
        "def best_response_test(model, opponent_model, n_of_perturbations=10000, epsilon=0.0001):\n",
        "  \"\"\"Estimate whether `model` is a (local) best response to `opponent_model`.\n",
        "\n",
        "  Draws n_of_perturbations random perturbations of model's parameters (each entry\n",
        "  shifted by a uniform value in [-epsilon/2, epsilon/2)) and counts how many of them\n",
        "  lower model's loss (player 0 in loss([model, opponent_model])).\n",
        "\n",
        "  Returns the fraction of perturbations that did NOT lower the loss; 1.0 means no\n",
        "  sampled perturbation was an improvement.\n",
        "  \"\"\"\n",
        "  # Only loss values are compared, so don't build autograd graphs for the (many)\n",
        "  # loss evaluations.\n",
        "  with torch.no_grad():\n",
        "    old_loss = loss([model, opponent_model])[0].item()\n",
        "    n_of_improvements = 0\n",
        "    for _ in range(n_of_perturbations):\n",
        "      alt_model = copy.deepcopy(model)\n",
        "      # state_dict() tensors share storage with alt_model's parameters, so copy_\n",
        "      # perturbs alt_model in place.\n",
        "      for param in alt_model.state_dict().values():\n",
        "        param.copy_(param + epsilon*torch.rand(param.shape) - epsilon/2)\n",
        "      if loss([alt_model, opponent_model])[0].item() < old_loss:\n",
        "        n_of_improvements += 1\n",
        "  return 1-(n_of_improvements / n_of_perturbations)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "bGtMoBAcAwWa"
      },
      "outputs": [],
      "source": [
        "# This asks: if both agents artificially received a slightly higher diff input,\n",
        "# how would this affect `model`'s loss? If the models successfully set incentives\n",
        "# on each other, this should be positive.\n",
        "def loss_diff_rate(model, opponent, epsilon = 0.001):\n",
        "  \"\"\"Finite-difference estimate of d(model's loss)/d(observed diff).\n",
        "\n",
        "  Compares model's loss (player 0) under the true diff with its loss when both\n",
        "  players instead observe diff + epsilon.\n",
        "  \"\"\"\n",
        "  true_loss = loss([model, opponent])[0].item()\n",
        "  # Both terms must be the same player's loss (index 0); comparing player 0's true\n",
        "  # loss with player 1's perturbed loss does not measure a rate of change.\n",
        "  incr_loss = loss([model, opponent], diff(model,opponent).item()+epsilon)[0].item()\n",
        "  return (incr_loss-true_loss)/epsilon\n",
        "\n",
        "def model_behavior_stats(model, n_of_points = 300):\n",
        "  \"\"\"Sweep the diff input over [0, 1) and record how model behaves.\n",
        "\n",
        "  For each of n_of_points evenly spaced values d = i/n_of_points, feeds (d, inps)\n",
        "  to model and measures the mean distance of its outputs to f_C and to f_D.\n",
        "\n",
        "  Returns (test_diffs, diffs_to_fC, diffs_to_fD): the swept values as 1-element\n",
        "  tensors and the two lists of distances as floats.\n",
        "  \"\"\"\n",
        "  test_diffs = []\n",
        "  diffs_to_fC = []\n",
        "  diffs_to_fD = []\n",
        "  for i in range(n_of_points):\n",
        "    test_diff = torch.tensor([i/n_of_points])\n",
        "    test_diffs.append(test_diff)\n",
        "    diff_column = test_diff.repeat(args.N_OF_INPS).reshape(args.N_OF_INPS, 1)\n",
        "    outps = model.forward(torch.cat((diff_column, inps), dim=1))\n",
        "    diffs_to_fC.append(distance(outps, fC_vals).item())\n",
        "    diffs_to_fD.append(distance(outps, fD_vals).item())\n",
        "  return test_diffs, diffs_to_fC, diffs_to_fD\n",
        "\n",
        "# Deprecated -- use log_behavior_graph_to_wandb instead\n",
        "def behavior_graph(model, n_of_points = 300):\n",
        "  \"\"\"Plot (inline) model's distance to f_C and f_D as a function of the diff input.\"\"\"\n",
        "  test_diffs, diffs_to_fC, diffs_to_fD = model_behavior_stats(model, n_of_points)\n",
        "  dist_sum = [to_D + to_C for to_D, to_C in zip(diffs_to_fD, diffs_to_fC)]\n",
        "  curves = [(diffs_to_fD, \"dist to D\"),\n",
        "            (diffs_to_fC, \"dist to C\"),\n",
        "            (dist_sum, \"dist to C + dist to D\"),\n",
        "            ([CD_diff]*n_of_points, \"dist C to D\")]\n",
        "  for ys, label in curves:\n",
        "    plt.plot(test_diffs, ys, label=label)\n",
        "  plt.legend()\n",
        "  plt.xlabel(\"ag_diff\")\n",
        "  plt.ylabel(\"dist to C/D\")\n",
        "  plt.ylim(bottom=0)\n",
        "  plt.show()\n",
        "\n",
        "def log_behavior_graph_to_wandb(model, id, n_of_points=300):\n",
        "  \"\"\"Log model's behavior sweep (see model_behavior_stats) to wandb as a line-series\n",
        "  chart; `id` is used both as the wandb log key and as the chart title.\"\"\"\n",
        "  test_diffs, diffs_to_fC, diffs_to_fD = model_behavior_stats(model, n_of_points)\n",
        "  xs = [d.item() for d in test_diffs]\n",
        "  ys = [diffs_to_fD,\n",
        "        diffs_to_fC,\n",
        "        [to_D + to_C for to_D, to_C in zip(diffs_to_fD, diffs_to_fC)],\n",
        "        [CD_diff] * len(test_diffs)]\n",
        "  chart = wandb.plot.line_series(xs=xs, ys=ys,\n",
        "                                 keys=[\"dist to D\", \"dist to C\", \"dist to C + dist to D\", \"dist C to D\"],\n",
        "                                 title=id, xname=\"ag_diff\")\n",
        "  wandb.log({id: chart})\n",
        "\n",
        "def behavior_change_graph(old_model, new_model, diff_val=None, n_of_points = 300):\n",
        "  \"\"\"Plot how an update step changed a model's behavior: new minus old distance to\n",
        "  f_D and to f_C across the diff sweep. If diff_val is given, it is marked with a\n",
        "  vertical line.\"\"\"\n",
        "  _, old_to_C, old_to_D = model_behavior_stats(old_model, n_of_points)\n",
        "  test_diffs, new_to_C, new_to_D = model_behavior_stats(new_model, n_of_points)\n",
        "  plt.plot(test_diffs, [new - old for new, old in zip(new_to_D, old_to_D)], label=\"increase in dist to D\")\n",
        "  plt.plot(test_diffs, [new - old for new, old in zip(new_to_C, old_to_C)], label=\"increase in dist to C\")\n",
        "  plt.axhline(y=0)\n",
        "  if diff_val is not None:\n",
        "    plt.axvline(x=diff_val, label=\"current diff between models[0] and models[1]\")\n",
        "  plt.legend(bbox_to_anchor=(1.1, 1.05))\n",
        "  plt.xlabel(\"ag_diff\")\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def print_diffs_to_random(model):\n",
        "  \"\"\"Print the sorted diffs between model and 30 freshly initialized random models.\n",
        "\n",
        "  NOTE(review): calls a `truncate(x, digits)` helper that is not defined in the\n",
        "  cells above -- confirm it is defined before this function is called.\n",
        "  \"\"\"\n",
        "  opponents = [NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "               for _ in range(30)]\n",
        "  trunc_diffs = sorted(truncate(diff(model, opponent).item(), 4) for opponent in opponents)\n",
        "  print(\"diffs to random:\")\n",
        "  print(trunc_diffs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KCqxSOFK2DIS"
      },
      "source": [
        "# Step 2 pre-training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "jUyakNIQ2H2Z"
      },
      "outputs": [],
      "source": [
        "def step2_loss(model, opponents_lst, self_play_prob=1/2):\n",
        "  \"\"\"Step-2 pre-training objective for model (as player 0).\n",
        "\n",
        "  Weighted combination of model's self-play loss (weight self_play_prob) and its\n",
        "  mean loss against the opponents in opponents_lst (total weight 1-self_play_prob).\n",
        "  Returns a tensor of shape (1,).\n",
        "  \"\"\"\n",
        "  total = torch.zeros(1)\n",
        "  total += self_play_prob * loss([model, model])[0]\n",
        "  n_opponents = len(opponents_lst)\n",
        "  for opponent in opponents_lst:\n",
        "    total += (1-self_play_prob) * (1/n_opponents) * loss([model, opponent])[0]\n",
        "  return total\n",
        "\n",
        "def step2(model, n_of_steps=100, self_play_prob=1/2, n_of_opponents_per_epoch=100,\n",
        "            print_progress=True, wandb_track = True):\n",
        "  \"\"\"Step 2: pre-train model in place by minimizing step2_loss.\n",
        "\n",
        "  Each of the n_of_steps optimizer steps samples n_of_opponents_per_epoch fresh,\n",
        "  randomly initialized opponents. Uses args.STEP_2_OPTIMIZER (\"Adam\" or \"SGD\")\n",
        "  with learning rate args.STEP_2_LR. Optionally logs metrics to wandb and prints\n",
        "  per-step progress.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if args.STEP_2_OPTIMIZER is neither \"Adam\" nor \"SGD\".\n",
        "  \"\"\"\n",
        "  if args.STEP_2_OPTIMIZER == \"Adam\":\n",
        "    optimizer = optim.Adam(model.parameters(), lr=args.STEP_2_LR)\n",
        "  elif args.STEP_2_OPTIMIZER == \"SGD\":\n",
        "    optimizer = optim.SGD(model.parameters(), lr=args.STEP_2_LR)\n",
        "  else:\n",
        "    raise ValueError(f\"Unsupported STEP_2_OPTIMIZER: {args.STEP_2_OPTIMIZER!r}\")\n",
        "  for step_no in range(n_of_steps):\n",
        "    optimizer.zero_grad()\n",
        "    opponents_lst = [NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "                     for _ in range(n_of_opponents_per_epoch)]\n",
        "    loss_val = step2_loss(model, opponents_lst, self_play_prob)\n",
        "    # retain_graph=True is kept in case the targets fD_vals/fC_vals carry an autograd\n",
        "    # graph (they do if they were computed with gradient tracking enabled).\n",
        "    loss_val.backward(retain_graph = True)\n",
        "\n",
        "    # For logging: mean absolute gradient entry of each parameter tensor, averaged\n",
        "    # over tensors (an average of per-tensor averages, not over all entries).\n",
        "    per_param_grad_means = [torch.mean(torch.abs(param.grad)).item() for param in model.parameters()]\n",
        "    gradient_avg = sum(per_param_grad_means)/len(per_param_grad_means)\n",
        "\n",
        "    optimizer.step()\n",
        "\n",
        "    if wandb_track:\n",
        "      metrics = {\n",
        "        'step_2_loss': loss_val[0].item(),\n",
        "        'phase': \"Step 2\",\n",
        "        'loss of (D,D)': args.G,\n",
        "        'loss of (C,C)': 1,\n",
        "        'gradient avg step 2': gradient_avg\n",
        "      }\n",
        "      wandb.log(metrics)\n",
        "\n",
        "    if print_progress:\n",
        "      print('Step 2 Epoch {}, Loss {}'.format(step_no, loss_val.item()))\n",
        "\n",
        "\n",
        "### BEGIN TEST (of Step 2)\n",
        "if run_tests:\n",
        "  # These tests depend on args and don't work for all args.\n",
        "\n",
        "  # 1) Pure self-play pre-training: the self-play loss is expected in [1, 1.1), i.e.\n",
        "  #    close to the loss of (C,C), which is 1.\n",
        "  model = NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "  step2(model, n_of_steps=5000, self_play_prob=1,\n",
        "        print_progress=False, wandb_track=False)\n",
        "  assert 1 <= loss([model, model])[0].item() < 1.1\n",
        "  del model\n",
        "\n",
        "  # 2) Pre-training only against random opponents: at an observed diff of 0.7 the\n",
        "  #    self-play loss is expected within 0.3 of G, the loss of (D,D).\n",
        "  model = NeuralNet(args.INP_DIM+1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "  step2(model, n_of_steps=5000, n_of_opponents_per_epoch=10, self_play_prob=0,\n",
        "        print_progress=False, wandb_track=False)\n",
        "  print(\"loss:\", loss([model, model]))\n",
        "  print(\"CD_diff:\", CD_diff)\n",
        "  assert args.G - 0.3 <= loss([model, model], false_diff=0.7)[0].item() < args.G + 0.3\n",
        "  del model\n",
        "### END TEST"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9j3s3BWyun4M"
      },
      "source": [
        "# Step 3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "id": "vJ-6HAlbup9z"
      },
      "outputs": [],
      "source": [
        "import copy as copy\n",
        "import math as math\n",
        "\n",
        "def step3(model1, model2=None, n_of_turns=100, print_progress = False,\n",
        "          unilateral = False, self_play=False, n_of_steps_per_turn=1,\n",
        "          print_turn_numbers=True, optimizer_type=args.STEP_3_OPTIMIZER,\n",
        "          early_termination_at_mutual_D=True):\n",
        "  \"\"\"Step 3: alternating gradient training of the two models on the diff game.\n",
        "\n",
        "  In each of n_of_turns turns, every agent being updated (only model1 if unilateral\n",
        "  or self_play, otherwise both) takes n_of_steps_per_turn steps on its own loss with\n",
        "  a fresh optimizer (optimizer_type \"Adam\" or \"SGD\"). Each step uses the random\n",
        "  step size args.STEP_3_LR * U[0,1) * (1 + global step) ** args.STEP_3_LR_EXPONENT.\n",
        "  With args.STEP_3_IMPROVEMENTS_ONLY, a step that does not lower the agent's loss\n",
        "  (including one that makes it NaN) is reverted. With self_play, model2 is replaced\n",
        "  by a copy of model1 at the start of every turn.\n",
        "\n",
        "  If early_termination_at_mutual_D and either loss reaches args.G (the loss of\n",
        "  (D,D)) or any loss is NaN, the wandb run is tagged \"Fail\" and the function\n",
        "  returns early. Per-turn metrics are logged to wandb.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if optimizer_type is neither \"Adam\" nor \"SGD\".\n",
        "  \"\"\"\n",
        "  models = [model1, model2]\n",
        "  agents_to_update = [0]\n",
        "  if self_play:\n",
        "    unilateral = True\n",
        "  if not unilateral:\n",
        "    agents_to_update.append(1)\n",
        "  optimizers = [None] * len(agents_to_update)\n",
        "\n",
        "  # variables used for analysis\n",
        "  n_of_successful_steps = [0] * len(agents_to_update)\n",
        "\n",
        "  pre_turn_losses = loss(models)\n",
        "  pre_turn_losses = (pre_turn_losses[0].item(), pre_turn_losses[1].item())\n",
        "\n",
        "  for turn_no in range(n_of_turns):\n",
        "    if print_turn_numbers:\n",
        "      print (\"Turn \",turn_no+1,\" of \",n_of_turns, \"Losses: \", pre_turn_losses[0],\", \",pre_turn_losses[1])\n",
        "    if self_play:\n",
        "      models[1] = copy.deepcopy(models[0])\n",
        "    for i in agents_to_update:\n",
        "      # A fresh optimizer every turn, so optimizer state (e.g. Adam moments) does not\n",
        "      # carry over between turns.\n",
        "      if optimizer_type == \"Adam\":\n",
        "        optimizers[i] = optim.Adam(models[i].parameters(), lr=args.STEP_3_LR)\n",
        "      elif optimizer_type == \"SGD\":\n",
        "        optimizers[i] = optim.SGD(models[i].parameters(), lr=args.STEP_3_LR)\n",
        "      else:\n",
        "        raise ValueError(f\"Unsupported optimizer_type: {optimizer_type!r}\")\n",
        "\n",
        "      for step_no in range(n_of_steps_per_turn):\n",
        "        # Random step size, scaled by (1 + global step index) ** STEP_3_LR_EXPONENT\n",
        "        # (a decay when the exponent is negative).\n",
        "        step_size = ((1+turn_no*n_of_steps_per_turn+step_no) ** args.STEP_3_LR_EXPONENT) * args.STEP_3_LR * torch.rand(1).item()\n",
        "        for g in optimizers[i].param_groups:\n",
        "          g['lr'] = step_size\n",
        "        old_model = copy.deepcopy(models[i])\n",
        "\n",
        "        models[i].zero_grad()\n",
        "\n",
        "        loss_vals = loss(models)\n",
        "        loss_vals[i].backward(retain_graph = True)\n",
        "\n",
        "        # for analysis/logging: average over parameter tensors of the mean absolute\n",
        "        # gradient entry and of the mean absolute parameter entry\n",
        "        gradient_avg = 0\n",
        "        param_avg = 0\n",
        "        count = 0\n",
        "        for param in models[i].parameters():\n",
        "          count += 1\n",
        "          gradient_avg += torch.mean(torch.abs(param.grad)).item()\n",
        "          param_avg += torch.mean(torch.abs(param)).item()\n",
        "        gradient_avg = gradient_avg/count\n",
        "        param_avg = param_avg/count\n",
        "\n",
        "        optimizers[i].step()\n",
        "\n",
        "        new_losses = loss(models)\n",
        "\n",
        "        # If they defect, or any loss (before or after the step) is NaN, give up (to\n",
        "        # save compute). The post-step losses must be NaN-checked too: NaN >= G is\n",
        "        # False, so a diverging step would otherwise pass this check unnoticed.\n",
        "        all_loss_values = [t.item() for t in loss_vals + new_losses]\n",
        "        if early_termination_at_mutual_D and (new_losses[0].item() >= args.G\n",
        "                                              or new_losses[1].item() >= args.G\n",
        "                                              or any(math.isnan(v) for v in all_loss_values)):\n",
        "          wandb_run.tags = wandb_run.tags + (\"Fail\",)\n",
        "          return\n",
        "\n",
        "        # If gradient step increased loss, then revert. `not (x >= 0.0)` instead of\n",
        "        # `x < 0.0` so that a step producing a NaN loss is reverted as well.\n",
        "        loss_decrease = loss_vals[i].item() - new_losses[i].item()\n",
        "        if args.STEP_3_IMPROVEMENTS_ONLY and not (loss_decrease >= 0.0):\n",
        "          models[i].load_state_dict(old_model.state_dict())\n",
        "        else:\n",
        "          n_of_successful_steps[i] += 1\n",
        "          if print_progress:\n",
        "            print(\"Agent \",i,\" just made an update.\")\n",
        "            print(\"diff between agents: \",diff(models[0],models[1]))\n",
        "            print(\"Current rates of increasing diff:\", loss_diff_rate(models[0], models[1]))\n",
        "            print('Step 3 Ag {} Turn {}, Losses [0, 1]: {}, {}'.format(i, turn_no,\n",
        "                                                              loss(models)[0].item(),\n",
        "                                                              loss(models)[1].item()))\n",
        "\n",
        "      if print_progress:\n",
        "          behavior_graph(models[i])\n",
        "          behavior_change_graph(old_model=old_model, new_model=models[i],\n",
        "                                diff_val=diff(models[0],models[1]).item())\n",
        "\n",
        "    # log change in loss throughout the turn\n",
        "    new_losses = loss(models)\n",
        "    new_losses = (new_losses[0].item(), new_losses[1].item())\n",
        "    wandb.log({'loss_0': new_losses[0],\n",
        "               'loss_1': new_losses[1],\n",
        "               'turn_no': turn_no,\n",
        "               'phase': \"Step 3\",\n",
        "               'loss of (D,D)': args.G,\n",
        "               'loss of (C,C)': 1,\n",
        "               'agent_diffs': diff(models[0], models[1]).item(),\n",
        "               'turn_loss_decrease_0' : pre_turn_losses[0]-new_losses[0],\n",
        "               'turn_loss_decrease_1' : pre_turn_losses[1]-new_losses[1],\n",
        "               'loss_diff': abs(new_losses[0]-new_losses[1])})\n",
        "    pre_turn_losses = new_losses\n",
        "\n",
        "  print(\"Number of successful steps:\", n_of_successful_steps)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wZ_s9uoQ1EeE"
      },
      "source": [
        "# Running the experiments\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-qkHJOef1QSa"
      },
      "source": [
        "## Initialization"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "id": "fxk5BGj41Os5"
      },
      "outputs": [],
      "source": [
        "list_of_agents_to_create = [0, 1]\n",
        "# Build one freshly initialised NeuralNet per agent id, in list order.\n",
        "models = [\n",
        "    NeuralNet(args.INP_DIM + 1, args.MODEL_HIDDEN_DIM_LIST, args.MODEL_HIDDEN_TYPE_LIST, args.OUTP_DIM)\n",
        "    for _agent in list_of_agents_to_create\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "isblR0P11-Xb"
      },
      "source": [
        "## Pre-train (\"Step 2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HPedXUdvhGon"
      },
      "outputs": [],
      "source": [
        "# Step 2 pre-training, run separately for every agent that was created.\n",
        "for i in list_of_agents_to_create:\n",
        "  step2(models[i],\n",
        "        n_of_steps=args.N_OF_STEPS_STEP_2,\n",
        "        self_play_prob=args.STEP_2_SELF_PLAY_PROB)\n",
        "\n",
        "# With exactly two agents, report diff(models[0], models[1]) after pre-training.\n",
        "if len(list_of_agents_to_create) == 2:\n",
        "  print(\"diff after Step 2: \", diff(models[0], models[1]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "id": "VDvWwTicHOo-"
      },
      "outputs": [],
      "source": [
        "# Log both agents' behaviour graphs as they stand after Step 2.\n",
        "for agent_idx in (0, 1):\n",
        "  log_behavior_graph_to_wandb(models[agent_idx], id=\"Model \" + str(agent_idx) + \" after Step 2\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qBtkbssaNTxS",
        "outputId": "5efeaebf-9691-4857-8bfb-388a07ca7074"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0.493\n",
            "0.522\n"
          ]
        }
      ],
      "source": [
        "# Run best_response_test for the ordered pairs (0, 1) and (1, 0), 1000 perturbations each.\n",
        "for first, second in ((0, 1), (1, 0)):\n",
        "  print(best_response_test(models[first], models[second], n_of_perturbations=1000))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "id": "531offnt3keU"
      },
      "outputs": [],
      "source": [
        "# When NO_NOISE_STEP_2 is set, Step 2 ran without noise, so draw fresh\n",
        "# (non-zero) noise values before the later phases. Call order is unchanged.\n",
        "if args.NO_NOISE_STEP_2:\n",
        "    test_diff_noises = generate_noise()\n",
        "    diff_noise_vals_pl1 = generate_noise()\n",
        "    diff_noise_vals_pl2 = generate_noise()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l7yNLN8r3lmy"
      },
      "source": [
        "## LOLA"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "id": "CuH0_FgZ3mcj"
      },
      "outputs": [],
      "source": [
        "def is_successful(loss_hist, threshold_factor=0.8, backward_horizon=100):\n",
        "  \"\"\"Check whether agent 0 kept a low loss over the most recent steps.\n",
        "\n",
        "  Args:\n",
        "    loss_hist: list of (loss_0, loss_1) pairs, one per LOLA step.\n",
        "    threshold_factor: agent 0's loss must not exceed threshold_factor * args.G.\n",
        "    backward_horizon: number of most recent entries to inspect.\n",
        "\n",
        "  Returns:\n",
        "    True if every one of the last `backward_horizon` entries has\n",
        "    loss_0 <= threshold_factor * args.G; False otherwise, including when\n",
        "    fewer than `backward_horizon` entries are available.\n",
        "  \"\"\"\n",
        "  # NOTE(review): only agent 0's loss is checked -- confirm this is intended.\n",
        "  if len(loss_hist) < backward_horizon:\n",
        "    return False\n",
        "  threshold = threshold_factor * args.G\n",
        "  recent = loss_hist[len(loss_hist) - backward_horizon:]\n",
        "  return all(entry[0] <= threshold for entry in recent)\n",
        "\n",
        "\n",
        "import math\n",
        "\n",
        "# LOLA phase: update both agents jointly with lola_update for up to\n",
        "# N_OF_LOLA_STEPS steps. The run stops early (and is tagged \"Fail\") when a\n",
        "# loss becomes NaN/inf, or -- if LOLA_EARLY_TERMINATION is set -- when both\n",
        "# losses exceed args.G (the value logged as the (D,D) loss), provided Step 2\n",
        "# ran more than one step or step_no is past LOLA_EARLY_TERMINATION_SHIELD.\n",
        "lola_algo = 'taylor_lola' if args.TAYLOR_LOLA else 'exact_lola'\n",
        "early_termination = False\n",
        "old_loss_vals = None  # losses of the previous step; None before the first step\n",
        "loss_hist = []\n",
        "for step_no in range(args.N_OF_LOLA_STEPS):\n",
        "  if args.LOLA_LINEAR_LA_DECAY:\n",
        "    # Linearly anneal the look-ahead coefficient from LOLA_LA towards 0.\n",
        "    la = args.LOLA_LA * (1- step_no/args.N_OF_LOLA_STEPS)\n",
        "  else:\n",
        "    la = args.LOLA_LA\n",
        "  loss_vals = lola_update(loss_without_forward,models,beta=la,alpha=args.LOLA_LR, algo=lola_algo)\n",
        "  cur_loss_vals = (loss_vals[0].item(), loss_vals[1].item())\n",
        "  loss_hist.append(cur_loss_vals)\n",
        "  log_entry = {\"loss_0\": cur_loss_vals[0],\n",
        "               \"loss_1\": cur_loss_vals[1],\n",
        "               'loss of (D,D)': args.G,\n",
        "               'loss of (C,C)': 1,\n",
        "               'agent_diffs': diff(models[0], models[1]).item()}\n",
        "  # Only log a per-step decrease once a previous step exists; a (0, 0)\n",
        "  # baseline would report -loss as the first 'decrease'.\n",
        "  if old_loss_vals is not None:\n",
        "    log_entry['turn_loss_decrease_0'] = old_loss_vals[0] - cur_loss_vals[0]\n",
        "    log_entry['turn_loss_decrease_1'] = old_loss_vals[1] - cur_loss_vals[1]\n",
        "  wandb.log(log_entry)\n",
        "  old_loss_vals = cur_loss_vals\n",
        "  if step_no % args.LOG_BEHAVIOR_GRAPH_EVERY_N_STEPS == 0:\n",
        "    log_behavior_graph_to_wandb(models[0], id=\"Model 0 after \"+str(step_no)+\" steps of LOLA\")\n",
        "    log_behavior_graph_to_wandb(models[1], id=\"Model 1 after \"+str(step_no)+\" steps of LOLA\")\n",
        "  print(\"Turn\", step_no, \". Current loss:\", cur_loss_vals[0], \",\", cur_loss_vals[1])\n",
        "  both_above_dd = (cur_loss_vals[0] > args.G and cur_loss_vals[1] > args.G and args.LOLA_EARLY_TERMINATION\n",
        "                   and (args.N_OF_STEPS_STEP_2 > 1 or step_no > args.LOLA_EARLY_TERMINATION_SHIELD))\n",
        "  diverged = not (math.isfinite(cur_loss_vals[0]) and math.isfinite(cur_loss_vals[1]))\n",
        "  if both_above_dd or diverged:\n",
        "    early_termination = True\n",
        "    wandb_run.tags = wandb_run.tags + (\"Fail\",)\n",
        "    wandb.log({\n",
        "        \"n_of_lola_steps_before_failure\" : step_no\n",
        "    })\n",
        "    # Nothing else happens after a failure, so stop the loop outright.\n",
        "    break\n",
        "if not early_termination:\n",
        "  wandb.log({\n",
        "          \"n_of_lola_steps_before_failure\" : args.N_OF_LOLA_STEPS\n",
        "  })\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {
        "id": "fHAeSaw5fGyg"
      },
      "outputs": [],
      "source": [
        "# Log both agents' behaviour graphs as they stand after the LOLA phase.\n",
        "for agent_idx in (0, 1):\n",
        "  log_behavior_graph_to_wandb(models[agent_idx], id=\"Model \" + str(agent_idx) + \" after LOLA\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "n9DWt1foUXVf"
      },
      "outputs": [],
      "source": [
        "# Run best_response_test (default settings) for the ordered pairs (0, 1) and (1, 0).\n",
        "for first, second in ((0, 1), (1, 0)):\n",
        "  print(best_response_test(models[first], models[second]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oLWC2yPz21Ub"
      },
      "source": [
        "## Mutual best response learning (\"Step 3\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wRCOWw3125P3"
      },
      "outputs": [],
      "source": [
        "# Step 3: mutual best-response learning between the two agents.\n",
        "step3(\n",
        "    models[0], models[1],\n",
        "    n_of_turns=args.N_OF_TURNS_STEP_3,\n",
        "    n_of_steps_per_turn=args.N_OF_STEPS_PER_TURN_STEP_3,\n",
        "    print_progress=False,\n",
        "    self_play=False,\n",
        "    early_termination_at_mutual_D=True,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 26,
      "metadata": {
        "id": "heHNkwQaHGuG"
      },
      "outputs": [],
      "source": [
        "# Log both agents' behaviour graphs as they stand after Step 3.\n",
        "for agent_idx in (0, 1):\n",
        "  log_behavior_graph_to_wandb(models[agent_idx], id=\"Model \" + str(agent_idx) + \" after Step 3\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QvLwiCMZ28hm"
      },
      "outputs": [],
      "source": [
        "# Run best_response_test (default settings) for the ordered pairs (0, 1) and (1, 0).\n",
        "for first, second in ((0, 1), (1, 0)):\n",
        "  print(best_response_test(models[first], models[second]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DbB6S0og6sWU"
      },
      "outputs": [],
      "source": [
        "# Finish the wandb run used for all logging in this notebook.\n",
        "wandb.finish()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "toc_visible": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}