{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "111c502f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0,'..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e6b59ce3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "from train import train\n",
    "from train import Losses\n",
    "import priors\n",
    "import encoders\n",
    "import positional_encodings\n",
    "import utils\n",
    "import bar_distribution\n",
    "import transformer\n",
    "\n",
    "from samlib.utils import chunker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "acf7423d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings shared by every run of the sweep below; forwarded to `train` as **kwargs.\n",
    "kwargs = dict(nlayers=6, dropout=0.0, steps_per_epoch=100)\n",
    "    \n",
    "    \n",
    "def train_and_compare_fast_gp_mix(*args, num_evals=10000, **kwargs):\n",
    "    \"\"\"Report the fast_gp_mix baseline on a freshly sampled batch, then call `train`.\n",
    "\n",
    "    args/kwargs are passed to `train` unchanged. args[0] is the prior DataLoader class (its\n",
    "    get_batch_method samples the evaluation batch) and args[1] is the criterion, matching the\n",
    "    positional order of the train(...) calls in this notebook; the baseline reports MSE when\n",
    "    the criterion is Losses.mse. kwargs must contain 'bptt' and an 'extra_prior_kwargs_dict'\n",
    "    with 'hyperparameters' and 'num_features'.\n",
    "    num_evals: number of sequences sampled for the baseline evaluation (default 10000).\n",
    "    Returns (result of train, baseline result).\n",
    "    \"\"\"\n",
    "    prior_kwargs = kwargs['extra_prior_kwargs_dict']\n",
    "    hps = prior_kwargs['hyperparameters']\n",
    "    num_features = prior_kwargs['num_features']\n",
    "    eval_batch = args[0].get_batch_method(num_evals, kwargs['bptt'], num_features, hyperparameters=hps)\n",
    "    baseline_res = priors.fast_gp_mix.evaluate(\n",
    "        *eval_batch,\n",
    "        hyperparameters=hps,\n",
    "        # the criterion is train's 2nd positional argument; args[2] is the encoder generator\n",
    "        use_mse=Losses.mse == args[1])\n",
    "    print(baseline_res, 'with fast_gp_mix')\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n",
    "\n",
    "def train_and_compare_fast_gp(*args, num_evals=1000, **kwargs):\n",
    "    \"\"\"Report the fast_gp baseline (sampled and evaluated on CPU), then call `train`.\n",
    "\n",
    "    args/kwargs are passed to `train` unchanged. args[0] is the prior DataLoader class and\n",
    "    args[1] the criterion (same positional order as the train(...) calls in this notebook);\n",
    "    the baseline reports MSE when the criterion is Losses.mse. kwargs must contain 'bptt'\n",
    "    and an 'extra_prior_kwargs_dict' with 'hyperparameters' and 'num_features'.\n",
    "    num_evals: number of sequences sampled for the baseline evaluation.\n",
    "    Returns (result of train, baseline result).\n",
    "    \"\"\"\n",
    "    prior_kwargs = kwargs['extra_prior_kwargs_dict']\n",
    "    hps = prior_kwargs['hyperparameters']\n",
    "    num_features = prior_kwargs['num_features']\n",
    "    eval_batch = args[0].get_batch_method(num_evals, kwargs['bptt'], num_features, hyperparameters=hps, device='cpu')\n",
    "    baseline_res = priors.fast_gp.evaluate(\n",
    "        *eval_batch,\n",
    "        hyperparameters=hps,\n",
    "        # the criterion is train's 2nd positional argument; args[2] is the encoder generator\n",
    "        use_mse=Losses.mse == args[1], device='cpu')\n",
    "    print(baseline_res, 'with fast_gp')\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n",
    "\n",
    "def train_and_compare_gp(*args, num_evals=10000, **kwargs):\n",
    "    \"\"\"Report the priors.gp baseline on a freshly sampled batch, then call `train`.\n",
    "\n",
    "    args/kwargs are passed to `train` unchanged. args[0] is the prior DataLoader class and\n",
    "    args[1] the criterion (same positional order as the train(...) calls in this notebook);\n",
    "    the baseline reports MSE when the criterion is Losses.mse. kwargs must contain 'bptt'\n",
    "    and an 'extra_prior_kwargs_dict' with 'num_features'.\n",
    "    num_evals: number of sequences sampled for the baseline evaluation.\n",
    "    Returns (result of train, baseline result).\n",
    "    \"\"\"\n",
    "    num_features = kwargs['extra_prior_kwargs_dict']['num_features']\n",
    "    eval_batch = args[0].get_batch_method(num_evals, kwargs['bptt'], num_features)\n",
    "    baseline_res = priors.gp.evaluate(\n",
    "        *eval_batch,\n",
    "        # the criterion is train's 2nd positional argument; args[2] is the encoder generator\n",
    "        use_mse=Losses.mse == args[1])\n",
    "    # label fixed: this is the priors.gp baseline, not fast_gp\n",
    "    print(baseline_res, 'with gp')\n",
    "\n",
    "    res = train(*args, **kwargs)\n",
    "    return res, baseline_res\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "da083e24",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gpytorch\n",
    "# GP prior hyperparameters; 'fast_computations' presumably holds gpytorch's three fast_computations flags.\n",
    "hps = dict(noise=1e-4, outputscale=1., lengthscale=.6, fast_computations=(False, False, False))\n",
    "\n",
    "import numpy as np, scipy.stats as st\n",
    "\n",
    "def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
    "    \"\"\"Mean and Student-t confidence-interval half-width along the last axis.\n",
    "\n",
    "    Returns (m, h) such that m +/- h is the `confidence` interval of the mean, computed\n",
    "    independently for every leading index (a 1-D input yields scalars).\n",
    "    \"\"\"\n",
    "    accuracies = np.array(accuracies)\n",
    "    # Degrees of freedom must come from the axis being reduced (the last one);\n",
    "    # len() would return the size of the first axis for 2-D inputs.\n",
    "    n = accuracies.shape[-1]\n",
    "    m, se = np.mean(accuracies, -1), st.sem(accuracies, -1)\n",
    "    h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return m, h\n",
    "\n",
    "\n",
    "def bl(hps,bptt, num_evals=100, num_features=1, step_size=1, evals_per_batch=None, speedups=(False,False,False,False), confidence=.95):\n",
    "    \"\"\"Evaluate the priors.fast_gp baseline on freshly sampled GP data.\n",
    "\n",
    "    Samples `num_evals` sequences of length `bptt` in chunks of `evals_per_batch` (data is\n",
    "    always generated with exact gpytorch computations), evaluates each chunk with\n",
    "    fast_pred_var(speedups[0]) and fast_computations(*speedups[1:]), and returns\n",
    "    (mean, half_width): the mean over all sequences along the last axis of the concatenated\n",
    "    results and the half-width of its Student-t `confidence` interval.\n",
    "    \"\"\"\n",
    "    if evals_per_batch is None:\n",
    "        evals_per_batch = num_evals\n",
    "    else:\n",
    "        assert num_evals % evals_per_batch == 0, 'num_evals must be a multiple of evals_per_batch'\n",
    "    results = []\n",
    "    for _ in range(num_evals // evals_per_batch):\n",
    "        with gpytorch.settings.fast_computations(False, False, False):\n",
    "            batch = priors.fast_gp.get_batch(evals_per_batch, bptt, num_features, hyperparameters=hps)\n",
    "        with gpytorch.settings.fast_pred_var(speedups[0]), gpytorch.settings.fast_computations(*speedups[1:]):\n",
    "            all_res, baseline_res, _ = priors.fast_gp.evaluate(*batch, hyperparameters=hps, step_size=step_size)\n",
    "        print(baseline_res, 'with fast_gp')\n",
    "        results.append(all_res)\n",
    "    # concatenated along dim 1; assumed seq x batch_size (TODO confirm against evaluate)\n",
    "    all_results = torch.cat(results, 1).detach().cpu().numpy()\n",
    "    # Computed inline rather than via compute_mean_and_conf_interval: that name is redefined\n",
    "    # later in this notebook, and a version without the axis argument would silently\n",
    "    # collapse these per-position statistics into scalars.\n",
    "    n = all_results.shape[-1]\n",
    "    mean = all_results.mean(-1)\n",
    "    half_width = st.sem(all_results, -1) * st.t.ppf((1 + confidence) / 2., n - 1)\n",
    "    return mean, half_width\n",
    "    \n",
    "    \n",
    "#settings = [{'num_evals':n,} for n in [100,1000]]\n",
    "    \n",
    "#js = [ex.submit(bl, hps, 2000, step_size=100, evals_per_batch=2, num_features=5, **kwargs) for kwargs in settings]\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8088aa12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Below, replace priors.fast_gp with priors.fast_gp_mix to run the same experiments over mixtures of GPs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "165e683c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep over training length, learning rate and number of bar-distribution buckets.\n",
    "# Every train(...) call is executed eagerly, one after another, while the list is built.\n",
    "num_features = 5\n",
    "hps = {'noise': 1e-4, 'outputscale': 1., 'lengthscale': .6, 'fast_computations': (False,False,False)}\n",
    "# Large prior sample of targets, passed to get_bucket_limits to choose the bucket borders.\n",
    "ys = priors.fast_gp.get_batch(100000,20,num_features, hyperparameters=hps)[1]\n",
    "fivefeature_jobs = [\n",
    "    train(priors.fast_gp.DataLoader, bar_distribution.FullSupportBarDistribution(bar_distribution.get_bucket_limits(num_borders, ys=ys)), enc,\n",
    "          emsize=emsize, nhead=nhead, nhid=nhid,  # nhid is a sweep dimension, so it must be passed explicitly\n",
    "          warmup_epochs=warmup_epochs, y_encoder_generator=y_enc, pos_encoder_generator=pos_enc,\n",
    "          batch_size=batch_size, scheduler=decay, extra_prior_kwargs_dict={'num_features': num_features, 'fuse_x_y': False, 'hyperparameters': hps},\n",
    "          epochs=epochs, lr=lr, input_normalization=input_norm, bptt=2010, single_eval_pos_gen=single_eval_pos, aggregate_k_gradients=step_every, **kwargs)\n",
    "    for enc in [encoders.Linear] for y_enc in [encoders.Linear] for emsize in [512] for nhead in [4] for nhid in [emsize*2]\n",
    "    for epochs in [50*25,100*25,200*25,400*25]\n",
    "    for warmup_epochs in [epochs//4] for input_norm in [False]\n",
    "    for batch_size in [4] for step_every in [100//batch_size] for lr in [.0001,.0003,.001]\n",
    "    for decay in [utils.get_cosine_schedule_with_warmup] for num_borders in [1000,10000]\n",
    "    for single_eval_pos in [utils.get_weighted_single_eval_pos_sampler(2000)]\n",
    "    # positional encoding is only used when no single-eval-position sampler is given\n",
    "    for pos_enc in [positional_encodings.PositionalEncoding if single_eval_pos is None else positional_encodings.NoPositionalEncoding]\n",
    "    for redo in range(1)\n",
    "]\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "15d01f3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np, scipy.stats as st\n",
    "\n",
    "def compute_mean_and_conf_interval(accuracies, confidence=.95):\n",
    "    \"\"\"Mean and Student-t confidence-interval half-width along the last axis.\n",
    "\n",
    "    Returns (m, h) such that m +/- h is the `confidence` interval of the mean; a 1-D input\n",
    "    (as used by run_test) yields scalars. Reducing over the last axis, like the earlier\n",
    "    definition in this notebook, keeps this redefinition from collapsing bl's per-position\n",
    "    statistics into scalars.\n",
    "    \"\"\"\n",
    "    accuracies = np.array(accuracies)\n",
    "    # degrees of freedom come from the reduced (last) axis, not len() of the first axis\n",
    "    n = accuracies.shape[-1]\n",
    "    m, se = np.mean(accuracies, -1), st.sem(accuracies, -1)\n",
    "    h = se * st.t.ppf((1 + confidence) / 2., n-1)\n",
    "    return m, h\n",
    "# GP prior hyperparameters read by run_test when sampling evaluation data.\n",
    "hps = dict(noise=1e-4, outputscale=1., lengthscale=.6, fast_computations=(False, False, False))\n",
    "\n",
    "@torch.inference_mode()\n",
    "def run_test(model,device='cuda:0',step_size=100, start_pos=1, batch_size=1000, sub_batch_size=10, seq_len=2000, num_features=5):\n",
    "    \"\"\"Evaluate `model` on fresh priors.fast_gp samples at several single-eval positions.\n",
    "\n",
    "    For each eval_pos in range(start_pos, seq_len, step_size), draws `batch_size` sequences of\n",
    "    length eval_pos+1 (in chunks of `sub_batch_size`) and scores the prediction at eval_pos.\n",
    "    num_features: input dimensionality of the sampled data (default 5).\n",
    "    Returns (eval_positions, mean_mses, max_mses, nlls, nll_confidences); all but the first are\n",
    "    CPU tensors with one entry per eval position. Both MSEs are 0 for a GaussianNLLLoss\n",
    "    criterion; nll_confidences are half-widths of the 95% t-interval of the NLL.\n",
    "    Side effect: moves `model` to `device` and puts it in eval mode.\n",
    "    \"\"\"\n",
    "    assert batch_size % sub_batch_size == 0, 'batch_size must be a multiple of sub_batch_size'\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "    nlls = []\n",
    "    nll_confidences = []\n",
    "    mses = []\n",
    "    max_mses = []\n",
    "    eval_positions = []\n",
    "\n",
    "    def get_metrics(model, eval_pos, batch_size):\n",
    "        x, y, target_y = priors.fast_gp.get_batch(batch_size=batch_size, seq_len=eval_pos+1, num_features=num_features, hyperparameters=hps, device=device)\n",
    "        logits = model((x, y), single_eval_pos=eval_pos)\n",
    "        if isinstance(model.criterion, nn.GaussianNLLLoss):\n",
    "            # logits[..., 0] is passed as the mean and |logits[..., 1]| as the variance\n",
    "            nll = model.criterion(logits[0][..., 0], target_y[eval_pos], var=logits[0][..., 1].abs())\n",
    "            return nll, 0., 0.\n",
    "        means = model.criterion.mean(logits)  # num_evals x batch_size\n",
    "        # midpoint between borders[i] and borders[i+1] for the arg-max bucket i\n",
    "        best = logits.argmax(-1)\n",
    "        maxs = (model.criterion.borders[best] + model.criterion.borders[best+1])/2\n",
    "        mse = nn.MSELoss()\n",
    "        nll = model.criterion(logits[0], target_y[eval_pos])\n",
    "        return nll, mse(means[0], target_y[eval_pos]), mse(maxs[0], target_y[eval_pos])\n",
    "\n",
    "    for eval_pos in range(start_pos, seq_len, step_size):\n",
    "        eval_positions.append(eval_pos)\n",
    "        print(eval_pos)\n",
    "\n",
    "        nll = []\n",
    "        mean_mse = []\n",
    "        max_mse = []\n",
    "        for _ in range(batch_size//sub_batch_size):\n",
    "            batch_nll, batch_mean_mse, batch_max_mse = get_metrics(model, eval_pos, sub_batch_size)\n",
    "            nll.append(batch_nll)\n",
    "            mean_mse.append(batch_mean_mse)\n",
    "            max_mse.append(batch_max_mse)\n",
    "\n",
    "        # reshape(-1) lets 0-dim (already reduced) losses be concatenated as well\n",
    "        nll = torch.cat([chunk.reshape(-1) for chunk in nll])\n",
    "        mean_mse = torch.tensor(mean_mse).mean()\n",
    "        max_mse = torch.tensor(max_mse).mean()\n",
    "\n",
    "        mses.append(mean_mse)\n",
    "        max_mses.append(max_mse)\n",
    "        nlls.append(nll.mean())\n",
    "        nll_confidences.append(compute_mean_and_conf_interval(nll.to('cpu'))[1])\n",
    "    return eval_positions, torch.stack(mses).to('cpu'), torch.stack(max_mses).to('cpu'), torch.stack(nlls).to('cpu'), torch.tensor(nll_confidences).to('cpu')\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "755e88e4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
