{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a737d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf \n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "from functools import partial\n",
    "%matplotlib inline\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dc79807",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy.linalg as LA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe3946d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Input, Dense, Concatenate, Add\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17bfd292",
   "metadata": {},
   "outputs": [],
   "source": [
    "tf.config.list_physical_devices('GPU')\n",
    "from tensorflow.keras import initializers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "351b1474",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cart(x,y):\n",
    "    temp = np.zeros([x.shape[0]*y.shape[0],x.shape[1]+y.shape[1]])\n",
    "    #num1 = y.shape[1]\n",
    "    j=0\n",
    "    num2 = x.shape[0]\n",
    "    num1 = x.shape[1]\n",
    "    for i in range(temp.shape[0]):\n",
    "        if i!=0 and i%num2 == 0:\n",
    "            j+=1\n",
    "        temp[i,0:num1] = x[i%num2]\n",
    "        temp[i,num1::] = y[j]\n",
    "    return temp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61690b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "n1=100\n",
    "n2=50\n",
    "x1 = np.linspace(0,4,n1).reshape(n1,1)\n",
    "x2 = np.linspace(-0.03,0.03,n2).reshape(n2,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19f71d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "x = cart(x1,x2)\n",
    "xt = cart(x,x1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f677cac7",
   "metadata": {},
   "outputs": [],
   "source": [
    "del x2\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb8a68e",
   "metadata": {},
   "outputs": [],
   "source": [
    "xt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "151860f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "indt = np.where (np.abs(xt[:,0]-xt[:,2])<0.05)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed94209f",
   "metadata": {},
   "outputs": [],
   "source": [
    "indbad = np.where (np.abs(xt[:,0]-xt[:,2])>=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243d47d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "xts= xt[indt[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7689481",
   "metadata": {},
   "outputs": [],
   "source": [
    "xtb= xt[indbad[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f8de09",
   "metadata": {},
   "outputs": [],
   "source": [
    "del xt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35257e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "xts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe40a1b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_label = np.ones((xts.shape[0],1)) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22b50abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5f946d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "xtb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe8d05e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "V = tf.keras.Sequential()\n",
    "V.add(layers.Dense(10, use_bias=True, activation = 'relu',input_shape=(4,),kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "    kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "    bias_initializer=initializers.Zeros()))\n",
    "V.add(layers.Dense(20, use_bias=True, activation = 'relu',kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "    kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "    bias_initializer=initializers.Zeros()))\n",
    "V.add(layers.Dense(20, use_bias=True, activation = 'relu',kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "    kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "    bias_initializer=initializers.Zeros()))\n",
    "# V.add(layers.Dense(20, use_bias=True, activation = 'relu',kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "#     kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "#     bias_initializer=initializers.Zeros()))\n",
    "# V.add(layers.Dense(20, use_bias=True, activation = 'relu',kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "#     kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "#     bias_initializer=initializers.Zeros()))\n",
    "#model.add(layers.Dense(1024, use_bias=True, activation = 'tanh'))\n",
    "V.add(layers.Dense(1,use_bias=True, activation='sigmoid',kernel_regularizer=tf.keras.regularizers.L1(0.001),kernel_constraint=tf.keras.constraints.NonNeg()))\n",
    "#model.add(tf.keras.layers.Lambda(lambda x: x * 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3d4f554",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "C = tf.keras.Sequential()\n",
    "C.add(layers.Dense(100, use_bias=True, activation = 'relu',input_shape=(4,),kernel_regularizer=tf.keras.regularizers.L1(0.001)))\n",
    "C.add(layers.Dense(200, use_bias=True, activation = 'relu',kernel_regularizer=tf.keras.regularizers.L1(0.001)))\n",
    "C.add(layers.Dense(200, use_bias=True, activation = 'relu'))\n",
    "C.add(layers.Dense(200, use_bias=True, activation = 'relu'))\n",
    "#model.add(layers.Dense(1024, use_bias=True, activation = 'tanh'))\n",
    "C.add(layers.Dense(1,use_bias=True, activation='linear',kernel_regularizer=tf.keras.regularizers.L1(0.0001))\n",
    ")\n",
    "#model.add(tf.keras.layers.Lambda(lambda x: x * 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b3d63f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tau = 0.1\n",
    "a = tf.constant([[1 ,tau], [0 ,1]])\n",
    "b = tf.constant([[(tau**2)/2],[tau]])\n",
    "set_p = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b989813",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = tf.keras.optimizers.Adam(5e-4)\n",
    "opt2 = tf.keras.optimizers.Adam(5e-5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3e6ae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_cond(v,c,x,xb,a,tau,gamma):\n",
    "    set_p=3\n",
    "    K=0.5\n",
    "    nmin=-0.2\n",
    "    nmax=0.2\n",
    "    \n",
    "    \n",
    "    xinp1=np.hstack((x,np.clip(-K*(x[:,2].reshape(x.shape[0],1)-set_p),nmin,nmax)))\n",
    "    xinp2=np.hstack((xb,np.clip(-K*(xb[:,2].reshape(xb.shape[0],1)-set_p),nmin,nmax)))\n",
    "    #xinp2 = np.copy(xt)\n",
    "    ind = np.where(v(xinp1,training = False)>=0.5)\n",
    "    x_1 = x[ind[0],0:2]\n",
    "    x_2 = x[ind[0],2]\n",
    "    u = c(xinp1[ind[0]], training = False)\n",
    "    u2 = np.clip(-K*(x_2-set_p),nmin,nmax)\n",
    "    h = np.abs(x_1[:,0]-x_2[:])\n",
    "    #ind2 = np.where((tf.reshape(v(xinp1[ind[0]]),ind[0].shape)-h)<0.5 )\n",
    "    y_pred1 = tf.matmul(tf.constant(x_1,dtype=tf.float32),tf.transpose(a)) + tf.reshape(tf.stack(([u*((tau**2)/2),u*tau]),axis=1),(x_1.shape[0], 2))\n",
    "    y_pred2 = tf.cast(tf.reshape(x_2 + tau * u2,(ind[0].shape[0],1)),tf.float32)\n",
    "    y_pred  = tf.concat([y_pred1,y_pred2],axis=1)\n",
    "    u2_pred = np.clip(-K*(y_pred2-set_p),nmin,nmax) \n",
    "    y_pred = tf.concat([y_pred,u2_pred],axis=1)\n",
    "    y_predtot = V(y_pred) \n",
    "    h2 = np.abs(tf.reshape(y_pred1[:,0],(y_pred1[:,0].shape[0],1)).numpy()-y_pred2[:].numpy())\n",
    "    h = np.abs(x_1[:,0]-x_2[:])\n",
    "    ind1 = np.where(y_predtot < 0.8)\n",
    "    \n",
    "    \n",
    "    return ind1[0].shape[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0deff667",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_lip2(m):\n",
    "    test = m.weights[0].numpy()\n",
    "    \n",
    "    nm = int(len(m.weights)/2)\n",
    "    \n",
    "    lip = 0\n",
    "    for i in range(2,nm*2,2):\n",
    "\n",
    "        \n",
    "       \n",
    "        \n",
    "        test = np.matmul(test,m.weights[i].numpy())\n",
    "        \n",
    "        if i!= nm*2-2:\n",
    "            test2 = m.weights[i].numpy()\n",
    "            for j in range(i+2,len(m.weights),2):\n",
    "                \n",
    "                test2 = np.matmul(test2,m.weights[j].numpy())\n",
    "\n",
    "           \n",
    "            eigenvalues2, _ = LA.eig(np.matmul(test2.T,test2))\n",
    "        else:\n",
    "            eigenvalues2=1\n",
    "        eigenvalues1, _ = LA.eig(np.matmul(test.T,test))\n",
    "        \n",
    "        \n",
    "        lip+= np.sqrt(np.abs(np.max(eigenvalues1 ))) * np.sqrt(np.abs(np.max(eigenvalues2 )))\n",
    "    \n",
    "    return lip / (np.power(2,nm-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e541bb2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "xinp1=np.hstack((xts,np.clip(-K*(xts[:,2].reshape(xts.shape[0],1)-set_p),nmin,nmax)))\n",
    "np.min(V(xinp1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e7723b",
   "metadata": {},
   "outputs": [],
   "source": [
    "cce = tf.keras.losses.BinaryCrossentropy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "568cf814",
   "metadata": {},
   "outputs": [],
   "source": [
    "cce(V(xinp1),y_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f4b198",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta=0.1\n",
    "check_cond(V,C,xts,xtb,a,0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a907d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "C(xinp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90c0916",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 50000\n",
    "nmin=-0.2\n",
    "nmax=0.2\n",
    "er=1\n",
    "delta = 0.5\n",
    "\n",
    "switch = False\n",
    "k = 0\n",
    "t = 0\n",
    "K=0.5\n",
    "eps=0.05 \n",
    "ma = 1\n",
    "for i in range(epochs):\n",
    "    lipc = calculate_lip2(C)\n",
    "    lipv = calculate_lip2(V)\n",
    "    gamma = eps * lipc\n",
    "    no = check_cond(V,C,xts,xtb,a,0.1,gamma)\n",
    "    if no==0 and ma <eps :\n",
    "        break\n",
    "    \n",
    "    if i%500 ==0:\n",
    "        no = check_cond(V,C,xts,xtb,a,0.1,gamma)\n",
    "        print(er,i,no, ma)\n",
    "        switch=not(switch)\n",
    "        \n",
    "        \n",
    "        if switch:\n",
    "            \n",
    "            print('C')\n",
    "        else :\n",
    "            \n",
    "            print('V')\n",
    "    elif i%100 ==0:\n",
    "        no = check_cond(V,C,xts,xtb,a,0.1,gamma)\n",
    "        print(er,i,no,ma)\n",
    "        \n",
    "    if switch:\n",
    "        with tf.GradientTape() as C_tape:\n",
    "            \n",
    "            \n",
    "            xinp=np.hstack((xts,np.clip(-K*(xts[:,2].reshape(xts.shape[0],1)-set_p),nmin,nmax)))\n",
    "            #xinp2 = np.copy(xt)\n",
    "            ind = np.where(V(xinp)>=delta)\n",
    "            x_1 = xts[ind[0],0:2]\n",
    "            x_2 = xts[ind[0],2]\n",
    "            u = C(xinp[ind[0]],training=True)\n",
    "            u2 = tf.constant(-K*(x_2-set_p),tf.float32)\n",
    "#             x_1 = xt[:,0:2]\n",
    "#             x_2 = xt[:,2]\n",
    "#             u = C(xinp)\n",
    "            u2 = tf.constant(np.clip(-K*(x_2-set_p),nmin,nmax),tf.float32)\n",
    "            \n",
    "            y_pred1 = tf.matmul(tf.constant(x_1,dtype=tf.float32),tf.transpose(a)) + tf.reshape(tf.stack(([u*((tau**2)/2),u*tau]),axis=1),(x_1.shape[0], 2))\n",
    "            y_pred2 = tf.cast(tf.reshape(x_2 + tau * u2,(x_2.shape[0],1)),tf.float32)\n",
    "            #y_pred1 = tf.clip_by_value(y_pred1, 0, 4)\n",
    "            #y_pred2 = tf.clip_by_value(y_pred2, -0.3,0.3)\n",
    "            \n",
    "            y_pred  = tf.concat([y_pred1,y_pred2],axis=1)\n",
    "            h = np.abs(tf.reshape(y_pred1[:,0],(y_pred1[:,0].shape[0],1)).numpy()-y_pred2[:].numpy())\n",
    "            u2_pred = np.clip(-K*(y_pred2-set_p),nmin,nmax)\n",
    "            y_pred = tf.concat([y_pred,u2_pred],axis=1)\n",
    "            y_predtot = V(y_pred) \n",
    "            ind1 = np.where(V(y_pred) < delta)\n",
    "            ind4= np.where(h<eps)\n",
    "            ma = np.max(h)\n",
    "            wh=np.argmax(h)\n",
    "            n1 = tf.gather(y_pred1, wh)\n",
    "            n2 = tf.gather(y_pred2,wh)\n",
    "            \n",
    "            loss_c=tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "            #print(ind4[0])\n",
    "            \n",
    "            loss_c+= tf.keras.losses.mean_absolute_error(tf.reshape(y_pred1[:,0],y_pred1.shape[0]), tf.reshape(y_pred2,y_pred2.shape[0]))\n",
    "            ave = np.copy(loss_c)\n",
    "            loss_c += 0.071*tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,1],y_pred1.shape[0]),0.0)\n",
    "            #print(loss_c)\n",
    "#             if ind1[0].size!=0:\n",
    "                \n",
    "#                 loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind1[0]),ind1[0].shape),delta-gamma)\n",
    "                \n",
    "#             #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,0],y_pred1.shape[0]), tf.reshape(y_pred2,y_pred2.shape[0]))\n",
    "#             #loss_c += tf.keras.losses.mean_squared_error(y_pred[:,0]-tf.reshape(y_pred2,y_pred2.shape[0]),0.0)\n",
    "#             elif ind4[0].size!=0:\n",
    "                \n",
    "#                 #loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind4[0])-h[ind4[0]]-gamma,ind4[0].shape[0]),0.0)\n",
    "#             else: \n",
    "#                 pass\n",
    "            \n",
    "            if loss_c.numpy()==0 or tf.math.is_nan(loss_c).numpy():\n",
    "                switch = not(switch)\n",
    "                print(f\"switch {i} to V\")\n",
    "            else:\n",
    "                er = np.copy(loss_c)\n",
    "\n",
    "                gradsc = C_tape.gradient(loss_c,C.trainable_variables)\n",
    "                #capped_gradsc =  [(tf.clip_by_norm(grad, 1)) for grad in gradsc]\n",
    "                #print(capped_gradsc)\n",
    "                opt.apply_gradients(zip(gradsc,C.trainable_variables))\n",
    "                #del y_pred1, y_pred2\n",
    "                k+=1\n",
    "\n",
    "            \n",
    "            \n",
    "            \n",
    "    else:\n",
    "            with tf.GradientTape() as V_tape:\n",
    "                no = check_cond(V,C,xts,xtb,a,0.1,gamma)\n",
    "                \n",
    "                xinp1=np.hstack((xts,np.clip(-K*(xts[:,2].reshape(xts.shape[0],1)-set_p),nmin,nmax)))\n",
    "                xinp2=np.hstack((xtb,np.clip(-K*(xtb[:,2].reshape(xtb.shape[0],1)-set_p),nmin,nmax)))\n",
    "                if i>=0:\n",
    "                    ind = np.where(V(xinp1)>=delta)\n",
    "                    indbad = np.where(V(xinp2)<delta)\n",
    "                    v = V(xinp1[ind[0]],training = True)\n",
    "                    v2 = V(xinp2[indbad[0]],training = True)\n",
    "                    x_1 = xts[ind[0],0:2]\n",
    "                    x_2 = xts[ind[0],2]\n",
    "                    u = C(xinp1[ind[0]],training=True)\n",
    "                    u2 = tf.constant(np.clip(-K*(x_2-set_p),nmin,nmax),tf.float32)\n",
    "                    y_true = y_label[ind[0]]\n",
    "               \n",
    "#                 ind = np.where(V(xinp)<delta)\n",
    "#                 v = V(xinp[ind[0]],training = True)\n",
    "#                 x_1 = xt[ind[0],0:2]\n",
    "#                 x_2 = xt[ind[0],2]\n",
    "#                 u = C(xinp[ind[0]],training=True)\n",
    "#                 u2 = tf.constant(-K*(x_2-set_p),tf.float32)\n",
    "                else:\n",
    "                    v = V(xinp,training = True)\n",
    "                    x_1 = xt[:,0:2]\n",
    "                    x_2 = xt[:,2]\n",
    "                    u = C(xinp,training=False)\n",
    "                    u2 = tf.constant(np.clip(-K*(x_2-set_p),nmin,nmax),tf.float32)\n",
    "#                 x_1 = xt[:,0:2]\n",
    "#                 x_2 = xt[:,2]\n",
    "#                 u = C(xt[:])\n",
    "            \n",
    "                \n",
    "                y_pred1 = tf.matmul(tf.constant(x_1,dtype=tf.float32),tf.transpose(a)) + tf.reshape(tf.stack(([u*((tau**2)/2),u*tau]),axis=1),(x_1.shape[0], 2))\n",
    "                y_pred2 =  tf.cast(tf.reshape(x_2 + tau * u2,(x_2.shape[0],1)),tf.float32)\n",
    "                #y_pred1 = tf.clip_by_value(y_pred1, 0, 4)\n",
    "                #y_pred2 = tf.clip_by_value(y_pred2, -0.3,0.3)\n",
    "                \n",
    "                y_pred  = tf.concat([y_pred1,y_pred2],axis=1)\n",
    "                u2_pred = np.clip(-K*(y_pred2-set_p),nmin,nmax)\n",
    "                y_pred = tf.concat([y_pred,u2_pred],axis=1)\n",
    "                y_predtot = V(y_pred, training = True) \n",
    "                \n",
    "                \n",
    "                \n",
    "                #loss = tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "                loss = cce(y_predtot,y_true)\n",
    "               # if ind3[0].size!=0:\n",
    "                    \n",
    "                #loss = tf.keras.losses.mean_squared_error(tf.reshape(v,v.shape[0])-tf.reshape(y_predtot,v.shape[0])-gamma,0.0)\n",
    "                    #loss += tf.keras.losses.mean_absolute_error(tf.reshape(y_predtot2,y_predtot2.shape[0]),0.0)\n",
    "#                 elif ind4[0].size!=0:\n",
    "                    \n",
    "#                     loss += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind4[0])-h[ind4[0]]-gamma,ind4[0].shape[0]),0.0)\n",
    "                #loss += tf.keras.losses.mean_squared_error(tf.reshape(v,v.shape[0]),0.0)\n",
    "                h = np.abs(x_1[:,0]-x_2[:]).reshape(x_1.shape[0],1)\n",
    "#                 ind5= np.where(v - h <0)\n",
    "#                 if ind5[0].size!=0:\n",
    "#                     loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v,ind5[0])-h[ind5[0]]-gamma,ind5[0].shape[0]),0.0)\n",
    "#                 if indbad[0].size!=0:\n",
    "#                     pass\n",
    "# #                     try:\n",
    "# #                         batch = np.random.choice(indbad[0].shape[0], ind[0].shape[0])\n",
    "# #                         #loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v2,batch)-gamma,batch.shape[0]),delta)\n",
    "# #                     except:\n",
    "# #                         #loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v2,indbad[0])-gamma,indbad[0].shape[0]),delta)\n",
    "                \n",
    "                er = np.copy(loss)\n",
    "                if loss==0:\n",
    "                    switch = not(switch)\n",
    "                else:\n",
    "                    grads = V_tape.gradient(loss,V.trainable_variables)\n",
    "\n",
    "                    #capped_grads = [(tf.clip_by_norm(grad, 1)) for  grad in grads]\n",
    "\n",
    "                    opt2.apply_gradients(zip(grads,V.trainable_variables))\n",
    "                #del y_pred1, y_pred2\n",
    "                t+=1\n",
    "                #ind4= np.where(tf.reshape(V(xt[ind[0]]),ind[0].shape)- h <0)\n",
    "                #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]],0.0)\n",
    "                #ind3 = np.where(V(y_pred)>=delta)\n",
    "                #loss_c += tf.keras.losses.mean_squared_error(tf.reshape(tf.gather(y_predtot,ind3[0]),ind3[0].shape),0.0)\n",
    "                \n",
    "                #u = C(xt,training = False)\n",
    "                #x_1 = xt[:,0:2]\n",
    "                #x_2 = xt[:,2]\n",
    "                #y_pred1 = tf.matmul(tf.constant(x_1,dtype=tf.float32),tf.transpose(a)) + tf.reshape(((tau**2)/2)*u[:,0],(xt.shape[0],1) )\n",
    "                #y_pred2 = tf.reshape(tf.constant(x_2,dtype=tf.float32) + tau * u[:,1],(xt.shape[0],1))\n",
    "                #y_pred = np.hstack([y_pred1,y_pred2])\n",
    "                #y_predtot = V(y_pred) \n",
    "                #ind = np.where (y_predtot >= delta)\n",
    "                #ind2 = np.where (y_predtot < delta)\n",
    "#                 if ind[0].shape[0]==0 :\n",
    "#                     switch=not(switch)\n",
    "#                     print(f\"switch {i} to C\")\n",
    "#                 else:\n",
    "#                     #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]],0.0)\n",
    "#                     #loss = tf.reduce_mean(tf.sqrt(tf.square(tf.subtract(V(xt[ind[0]],training = True),delta))))\n",
    "#                     loss = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]]-gamma,0.0)\n",
    "#                     if tf.math.is_nan(loss).numpy():\n",
    "#                         print(f\"switch {i}\")\n",
    "#                         switch = not(switch)\n",
    "#                     else:\n",
    "#                         "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb327206-296e-42c5-9756-08823b59400c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
