{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Synthetic binary task: 20 samples with 5 features, labelled by the sign of a\n",
    "# random linear score X @ w.\n",
    "# A seeded Generator makes Restart & Run All reproduce the same data every time.\n",
    "rng = np.random.default_rng(0)\n",
    "X = rng.standard_normal((20, 5))\n",
    "w = rng.standard_normal(5) / np.sqrt(5)\n",
    "y = np.sign(X @ w)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cast the NumPy data to float32 JAX arrays, rebinding the same three names.\n",
    "X, w, y = (jnp.array(arr, dtype=jnp.float32) for arr in (X, w, y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows 0-15 form the training split; rows 16-19 are held out for validation.\n",
    "X_train, X_val = X[:16, :], X[16:20, :]\n",
    "y_train, y_val = y[:16], y[16:20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax.random as jrn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# JAX PRNG key built from the integer seed 0.\n",
    "seed = jrn.PRNGKey(seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from two_lay_cvx_classifier import Two_Lay_CVX_Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the classifier from the training split only, so the held-out rows in\n",
    "# X_val / y_val stay unseen and the validation metrics are not contaminated.\n",
    "# NOTE(review): 40 and 'CReLU' are positional; presumably hidden width and\n",
    "# activation -- confirm against two_lay_cvx_classifier.\n",
    "clf = Two_Lay_CVX_Classifier(X_train, y_train, 40, 'CReLU', beta=10**-3, seed=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Beginning model fit\n",
      "Train Loss: 0.0014320282498374581\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.7600662708282471\n",
      "\n",
      "Validation Accuracy: 75.0\n",
      "\n",
      "Train Loss: 0.001328754355199635\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.08974400907754898\n",
      "\n",
      "Validation Accuracy: 100.0\n",
      "\n",
      "Train Loss: 0.001331092556938529\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.0431893914937973\n",
      "\n",
      "Validation Accuracy: 100.0\n",
      "\n",
      "Train Loss: 0.0012467781780287623\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.037912558764219284\n",
      "\n",
      "Validation Accuracy: 100.0\n",
      "\n",
      "Train Loss: 0.001269118976779282\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.033396437764167786\n",
      "\n",
      "Validation Accuracy: 100.0\n",
      "\n",
      "Train Loss: 0.0012457756092771888\n",
      "\n",
      "Train Accuracy: 100.0\n",
      "\n",
      "Validation Loss: 0.03339178115129471\n",
      "\n",
      "Validation Accuracy: 100.0\n",
      "\n",
      "Validation accuracy is flat or increasing. CRONOS will now terminate\n",
      "Finished fitting model\n"
     ]
    }
   ],
   "source": [
    "# Fit the model, reporting train/validation loss and accuracy as it goes.\n",
    "# NOTE(review): the four leading positional values are hyperparameters whose\n",
    "# meaning is defined in two_lay_cvx_classifier and is not visible in this notebook.\n",
    "clf.fit(\n",
    "    0.1, 100, 10, 20,\n",
    "    Xval=X_val,\n",
    "    yval=y_val,\n",
    "    check_opt=False,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Array(100., dtype=float32)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the fitted model on all 20 rows (training and validation rows together).\n",
    "clf.score(X, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of correctly classified rows of X.\n",
    "# The earlier version read `y_hat`, which no cell defines (NameError). The count\n",
    "# is instead recovered from clf.score, which the saved runs show is a percentage\n",
    "# (100. on the full data, 75.0 on the 4 validation rows).\n",
    "n_correct = jnp.round(clf.score(X, y) / 100 * y.shape[0]).astype(jnp.int32)\n",
    "n_correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch list; only the following indexing cell reads it.\n",
    "L = list(range(1, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Last element of the scratch list.\n",
    "L[len(L) - 1]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
