{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a737d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf \n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "from functools import partial\n",
    "%matplotlib inline\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d8c8768",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy.linalg as LA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ffe3946d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import Input, Dense, Concatenate, Add\n",
    "from tensorflow.keras import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17bfd292",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras import initializers\n",
    "\n",
    "# List visible GPUs. This is the last expression so the cell actually displays the\n",
    "# result (as the first line of the cell its value was silently discarded).\n",
    "tf.config.list_physical_devices('GPU')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "351b1474",
   "metadata": {},
   "outputs": [],
   "source": [
    "def cart(x, y):\n",
    "    \"\"\"Row-wise Cartesian product of two 2-D arrays.\n",
    "\n",
    "    Returns a float64 array of shape (len(x) * len(y), x.shape[1] + y.shape[1]).\n",
    "    Row i is x[i % len(x)] followed by y[i // len(x)]: the rows of x cycle\n",
    "    fastest and the rows of y slowest.\n",
    "\n",
    "    Raises ValueError if either input is not 2-D.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    y = np.asarray(y)\n",
    "    if x.ndim != 2 or y.ndim != 2:\n",
    "        raise ValueError('cart expects two 2-D arrays')\n",
    "    n_x, n_y = x.shape[0], y.shape[0]\n",
    "    # Vectorised instead of a Python row loop: this notebook calls cart on\n",
    "    # 625 x 625 inputs (390 625 output rows).\n",
    "    x_part = np.tile(x, (n_y, 1))        # row i -> x[i % n_x]\n",
    "    y_part = np.repeat(y, n_x, axis=0)   # row i -> y[i // n_x]\n",
    "    # float64, like the np.zeros buffer the original loop filled\n",
    "    return np.hstack((x_part, y_part)).astype(np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61690b60",
   "metadata": {},
   "outputs": [],
   "source": [
    "n1 = 25  # points per grid axis\n",
    "# (n1, 1) column vectors of evenly spaced grid points\n",
    "x1 = np.linspace(-0.5, 0.5, n1)[:, np.newaxis]\n",
    "x2 = np.linspace(-0.5, 0.5, n1)[:, np.newaxis]\n",
    "x2s = np.linspace(-0.05, 0.05, n1)[:, np.newaxis]  # narrower range than x1/x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcc308f4-f4da-42a7-a63f-a06f36346a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "n2 = 30  # number of u grid points\n",
    "u = np.linspace(-0.5, 0.5, n2)[:, np.newaxis]  # (n2, 1) column on [-0.5, 0.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28a67069-b002-4459-895a-fb7c2269ef32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the u grid\n",
    "u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51ba7f6d-a066-4bd8-9f59-76253183737b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the per-axis grids\n",
    "x1,x2,x2s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b19f71d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Product grid: cart(x1, x2) and cart(x2s, x2s) each have 2 columns and n1**2 rows,\n",
    "# so xt has 4 columns and n1**4 rows.\n",
    "# NOTE(review): later cells read columns 4-6 of xt/xts and V is built with\n",
    "# input_shape=(6,); with this 4-column layout those cells fail -- confirm intent.\n",
    "x = cart(x1,x2)\n",
    "x3 = cart(x2s,x2s)\n",
    "xt = cart(x,x3)\n",
    "#xt2=cart(x1,x2)\n",
    "#xt = cart(xt,xt2)\n",
    "#xt = cart(xt,u)\n",
    "\n",
    "# x1 and x2 are deleted here; later cells reassign them\n",
    "del x,x1,x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30fd8deb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Size of the full product grid (rows, columns)\n",
    "xt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243d47d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep rows where columns 1 and 5 differ by at most eps. eps is also read by the\n",
    "# training cell (eta = eps * lipc and its stopping test).\n",
    "# NOTE(review): xt built above has 4 columns (0-3), so xt[:,5] raises IndexError on a\n",
    "# fresh kernel; this filter looks written for an older, wider grid layout.\n",
    "eps= 0.2\n",
    "xts= xt[np.where (np.abs(xt[:,1]-xt[:,5])<=eps ) ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35257e9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "xts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f077cea-a569-4eac-a795-af57807771de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Additionally keep rows where columns 0 and 4 differ by at most eps.\n",
    "# NOTE(review): needs a column 4; the xt grid built above only has columns 0-3.\n",
    "xts = xts[np.where(np.abs(xts[:,0]-xts[:,4])<=eps )]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "91c538ba-d19f-4260-b9d9-2a1f654e8db2",
   "metadata": {},
   "outputs": [],
   "source": [
    "xts.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "911cb417",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): hard-coded row index; raises IndexError if fewer than 12144 rows survive the filters\n",
    "xts[12143]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e74950e9-2167-44fc-ab96-3aa70dfd43d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the grid as a single batch: shape (1, rows, cols).\n",
    "# NOTE(review): this splits xt, the unfiltered grid, so the eps filtering of xts in the\n",
    "# cells above is discarded (the commented-out variant used xts) -- confirm intent.\n",
    "#xts = np.stack(np.array_split(xts,5))\n",
    "xts = np.stack(np.array_split(xt,1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99916c31-922b-4ce2-aac6-96550646a59e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): displays the whole array (n1**4 rows); xts.shape or xts[0, :5] is usually enough\n",
    "xts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe40a1b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One label of 1.0 per grid point in the batch: shape (rows, 1), float64\n",
    "y_label = np.ones(shape=(xts.shape[1], 1), dtype=np.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22b50abc",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_label.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfe8d05e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# V: 6 -> 10 -> 20 -> 20 (ReLU) -> 1 (sigmoid). Every layer has an L1(0.001) kernel\n",
    "# regulariser; hidden layers use N(0, 0.1) kernels and zero biases, and the output\n",
    "# kernel is constrained to be non-negative.\n",
    "V = tf.keras.Sequential()\n",
    "for _width, _extra in ((10, {'input_shape': (6,)}), (20, {}), (20, {})):\n",
    "    V.add(layers.Dense(\n",
    "        _width,\n",
    "        use_bias=True,\n",
    "        activation='relu',\n",
    "        kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "        kernel_initializer=initializers.RandomNormal(stddev=0.1),\n",
    "        bias_initializer=initializers.Zeros(),\n",
    "        **_extra,\n",
    "    ))\n",
    "del _width, _extra\n",
    "V.add(layers.Dense(\n",
    "    1,\n",
    "    use_bias=True,\n",
    "    activation='sigmoid',\n",
    "    kernel_regularizer=tf.keras.regularizers.L1(0.001),\n",
    "    kernel_constraint=tf.keras.constraints.NonNeg(),\n",
    "))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3d4f554",
   "metadata": {},
   "outputs": [],
   "source": [
    "# C: 4 -> 10 -> 30 -> 300 -> 300 (ReLU) -> 2 (linear), L1(0.001) kernel regulariser\n",
    "# on every layer, default initialisers.\n",
    "C = tf.keras.Sequential()\n",
    "for _width, _extra in ((10, {'input_shape': (4,)}), (30, {}), (300, {}), (300, {})):\n",
    "    C.add(layers.Dense(_width, use_bias=True, activation='relu',\n",
    "                       kernel_regularizer=tf.keras.regularizers.L1(0.001), **_extra))\n",
    "del _width, _extra\n",
    "C.add(layers.Dense(2, use_bias=True, activation='linear',\n",
    "                   kernel_regularizer=tf.keras.regularizers.L1(0.001)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7fd81c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test V on the first batch.\n",
    "# NOTE(review): V was built with input_shape=(6,); if xts has only the 4 columns built\n",
    "# above, [:6] still yields 4 features and this call fails.\n",
    "V(xts[0,:,:6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8e637ec-a634-4467-8860-b34623402cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test C on the first batch: first output column\n",
    "C(xts[0,:,:])[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26b3d63f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): identical to the previous cell\n",
    "C(xts[0,:,:])[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d921b4c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One step x_next = x + tau * f(x) of the 4-state polynomial model, without control inputs.\n",
    "# NOTE(review): tau is first assigned in a later cell, and this cell indexes xts as 2-D\n",
    "# (rows, cols) while an earlier cell made it 3-D (1, rows, cols) -- fails on a fresh kernel.\n",
    "x1 = xts[:,0].reshape(xts.shape[0],1)\n",
    "x2 = xts[:,1].reshape(xts.shape[0],1)\n",
    "x3 = xts[:,2].reshape(xts.shape[0],1)\n",
    "x4 = xts[:,3].reshape(xts.shape[0],1)\n",
    "\n",
    "x1n = 1.0*x1 + tau* x2\n",
    "#x2n = x2 + tau *(-3.447*np.power(x1,3)+2.35*np.power(x1,2)*x3+1.303*np.power(x3,2)*x1+3.939*np.power(x3,3)+21.52*x1 -5*x3 + 8*u1-31.2*u2)\n",
    "x2n = 1.0*x2 + tau *(-3.447*np.power(1.0*x1,3)+2.35*np.power(1.0*x1,2)*x3+1.303*np.power(1.0*x3,2)*x1+3.939*np.power(1.0*x3,3)+21.52*x1 -5.0*x3 )\n",
    "x3n = 1.0*x3+ tau*x4\n",
    "x4n = 1.0*x4 + tau*(4.023*np.power(1.0*x1,3)-36.551*np.power(1.0*x1,2)*x3-4.131*np.power(1.0*x2,2)*x3-27.06*np.power(1.0*x3,3)-25.115*x1+77.7*x3)\n",
    "#x4n = x4 + tau*(4.023*np.power(x1,3)-36.551*np.power(x1,2)*x3-4.131*np.power(x2,2)*x3-27.06*np.power(x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a00e923",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the stepped first state\n",
    "x1n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "162afc42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the second state column\n",
    "x2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eb37a5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One step x_next = x + tau * f(x, uab) of the 2-state system stored in columns 4-5.\n",
    "# NOTE(review): uab is not defined in any earlier cell and tau is first set in a later\n",
    "# cell; both must exist before this runs. This cell indexes xts as 2-D (rows, cols).\n",
    "x1ab = xts[:,4]\n",
    "x2ab = xts[:,5]\n",
    "\n",
    "x1nab = x1ab + tau*x2ab\n",
    "# The linear term uses x1ab, this system's own state. It previously used x1, the\n",
    "# (rows, 1) column from another cell, which broadcast against the (rows,) x1ab terms\n",
    "# into a (rows, rows) matrix.\n",
    "x2nab = x2ab + tau*(-5.131*np.power(x1ab,3)+32.1*x1ab+9.1*uab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b989813",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam optimisers with learning rates 5e-5 (opt) and 5e-4 (opt2)\n",
    "opt = tf.keras.optimizers.Adam(learning_rate=5e-5)\n",
    "opt2 = tf.keras.optimizers.Adam(learning_rate=5e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b3e6ae4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_cond(v,c,x,tau,eta,delta=0.5):\n",
    "    \"\"\"Count selected points whose one-step successor scores below delta + eta under v.\n",
    "\n",
    "    Rows of x with v(x[:, :6]) >= delta are selected. Each selected row is advanced by\n",
    "    one step x_next = x + tau * f(x):\n",
    "      * columns 0-3 follow the 4-state polynomial model driven by c's two outputs (u1, u2);\n",
    "      * columns 4-5 (s1, s2) follow s1 + tau*s2 and s2 + tau*(9.8*sin(s1) + u),\n",
    "        with u taken from column 6 of x.\n",
    "    The 6-column successors are then scored with v again.\n",
    "\n",
    "    Args:\n",
    "        v: scoring model (V in this notebook), called on 6-column inputs.\n",
    "        c: model (C in this notebook) called on the 4 state columns; its first two\n",
    "           output columns are used as u1 and u2.\n",
    "        x: 2-D array or tensor with at least 7 columns.\n",
    "        tau: step size.\n",
    "        eta: margin added to delta in the successor test.\n",
    "        delta: selection threshold (default 0.5, the value previously hard-coded).\n",
    "\n",
    "    Returns:\n",
    "        int: number of successor scores < delta + eta (0 if nothing is selected).\n",
    "    \"\"\"\n",
    "    # tf tensors (as passed by the training loop) support neither fancy indexing nor\n",
    "    # .reshape, so work on a numpy view of x\n",
    "    x = np.asarray(x)\n",
    "    # Use the v/c arguments, not the globals V/C, so callers control which models are checked\n",
    "    rows = np.nonzero(np.asarray(v(x[:, :6])) >= delta)[0]\n",
    "    if rows.size == 0:\n",
    "        return 0\n",
    "    sel = x[rows]\n",
    "\n",
    "    # c takes the 4 state columns, matching C's input_shape=(4,) and the training loop;\n",
    "    # evaluate it once instead of once per output column\n",
    "    ctrl = np.asarray(c(sel[:, 0:4]))\n",
    "    u1, u2 = ctrl[:, 0:1], ctrl[:, 1:2]\n",
    "    # [:, k:k+1] keeps each column as an (n, 1) array\n",
    "    x1, x2, x3, x4 = sel[:, 0:1], sel[:, 1:2], sel[:, 2:3], sel[:, 3:4]\n",
    "    s1, s2, u = sel[:, 4:5], sel[:, 5:6], sel[:, 6:7]\n",
    "\n",
    "    x1n = x1 + tau*x2\n",
    "    x2n = x2 + tau*(-3.447*np.power(x1,3) + 2.35*np.power(x1,2)*x3 + 1.303*np.power(x3,2)*x1\n",
    "                    + 3.939*np.power(x3,3) + 21.52*x1 - 5.0*x3 + 8*u1 - 31.2*u2)\n",
    "    x3n = x3 + tau*x4\n",
    "    # u2 coefficient is 391.2, as in the training loop's copy of this model (was 392 here)\n",
    "    x4n = x4 + tau*(4.023*np.power(x1,3) - 36.551*np.power(x1,2)*x3 - 4.131*np.power(x2,2)*x3\n",
    "                    - 27.06*np.power(x3,3) - 25.115*x1 + 77.7*x3 - 31.2*u1 + 391.2*u2)\n",
    "\n",
    "    s1n = s1 + tau*s2\n",
    "    s2n = s2 + tau*(9.8*np.sin(s1) + u)\n",
    "\n",
    "    successors = np.concatenate((x1n, x2n, x3n, x4n, s1n, s2n), axis=1)\n",
    "    return int(np.count_nonzero(np.asarray(v(successors)) < delta + eta))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1d76acb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_lip2(m):\n",
    "    \"\"\"Kernel-norm bound of a dense network (presumably a Lipschitz-constant estimate).\n",
    "\n",
    "    Kernels are read as m.weights[0], m.weights[2], ... (every other entry), which\n",
    "    assumes each layer stores exactly [kernel, bias]. With W_0..W_{L-1} those kernels\n",
    "    and ||.|| the spectral norm (largest singular value), it returns\n",
    "\n",
    "        (sum_{k=1}^{L-2} ||W_0...W_k|| * ||W_k...W_{L-1}||  +  ||W_0...W_{L-1}||) / 2**(L-1)\n",
    "\n",
    "    and ||W_0|| for a single-layer model (the empty sum previously returned 0).\n",
    "    Activations are not inspected.\n",
    "\n",
    "    Raises ValueError if m has no weights.\n",
    "    \"\"\"\n",
    "    n_layers = len(m.weights) // 2\n",
    "    if n_layers == 0:\n",
    "        raise ValueError('model has no weights')\n",
    "    kernels = [m.weights[2 * k].numpy() for k in range(n_layers)]\n",
    "\n",
    "    def spec(a):\n",
    "        # Spectral norm via SVD. Equals sqrt(max eig(a.T @ a)) without the complex or\n",
    "        # negative round-off that LA.eig can produce on that symmetric matrix.\n",
    "        return LA.norm(a, 2)\n",
    "\n",
    "    if n_layers == 1:\n",
    "        return spec(kernels[0])\n",
    "\n",
    "    # suffix[k] = W_k @ ... @ W_{L-1}, built once right-to-left instead of once per term\n",
    "    suffix = [None] * n_layers\n",
    "    suffix[-1] = kernels[-1]\n",
    "    for k in range(n_layers - 2, -1, -1):\n",
    "        suffix[k] = kernels[k] @ suffix[k + 1]\n",
    "\n",
    "    total = 0.0\n",
    "    prefix = kernels[0]\n",
    "    for k in range(1, n_layers):\n",
    "        prefix = prefix @ kernels[k]  # W_0 @ ... @ W_k\n",
    "        # the last term (k = L-1) has no second factor\n",
    "        total += spec(prefix) * (spec(suffix[k]) if k < n_layers - 1 else 1.0)\n",
    "    return total / 2 ** (n_layers - 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57e7723b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Binary cross-entropy on probabilities (from_logits=False is the Keras default)\n",
    "cce = tf.keras.losses.BinaryCrossentropy(from_logits=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f4b198",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step size\n",
    "tau = 0.01\n",
    "# NOTE(review): eta is first assigned in the training cell further down, and xts built\n",
    "# with np.array_split(xt, 1) has a single batch, so xts[1] is out of range on a fresh kernel.\n",
    "check_cond(V,C,xts[1,:,:],tau,eta)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbcdf1fa-5985-4cd4-9e82-b5322d3bd3df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (stray expression `te` removed: no cell defines `te`, so it raised NameError and stopped Run All)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1673568-32f7-40ba-9cef-4e9e72607f18",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Overrides tau = 0.01 from an earlier cell with a float32 tf constant of 5e-4\n",
    "tau = tf.constant(0.0005)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bea2de6f-b21d-4705-af8f-1fc2a256a824",
   "metadata": {},
   "outputs": [],
   "source": [
    "tau"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e22250b-e753-4e99-87d5-d46e877372b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check of tf.math.pow\n",
    "tf.math.pow(3.0,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8db9365e-e11e-4255-adc8-2983294525fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): ind1 is a local of check_cond; no cell assigns it at top level -> NameError\n",
    "ind1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6863135f-10e0-49f2-8866-0a134fa219bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): y_pred2 is a local of check_cond; no earlier cell assigns it at top level -> NameError\n",
    "y_pred2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21d200e1-1c7e-41ad-b8d8-695e134fb27e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shapes of the stepped states from the numpy step cell above\n",
    "x1n.shape , x2n.shape, x3n.shape,x4n.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03ef00ce-afe7-4fae-9b24-bffc015ebbe4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): x1sn and y_pred2 are not assigned at top level before this cell -> NameError on a fresh kernel\n",
    "tf.concat( (x1n,x2n,x3n,x4n,x1sn,y_pred2),axis=1).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecff4b27-fdd8-4a68-aee6-9f17dcac7fad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): x1s is first assigned inside the training cell below -> NameError on a fresh kernel\n",
    "np.sin(x1s).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c8c3cf3-59a2-4e1a-bcef-9e7b4fa22ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch check: a (1, 2) float torch tensor\n",
    "torch.Tensor(np.array([0.0,0.0]).reshape(1,2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd7b766-46d0-442e-8d64-8698565772c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): ctrl_nn is not defined in any earlier cell -> NameError on a fresh kernel\n",
    "ctrl_nn(torch.Tensor(np.array([0,0]).reshape(1,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36cf5523-e20d-446a-8f59-5cd403b33b36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): ctrl_nn and j are not defined in any earlier cell, and .numpy() needs xts to be\n",
    "# a tf tensor (it is converted only in the training cell below)\n",
    "ctrl_nn(torch.tensor(xts[j,:,0:2].numpy(),dtype=torch.float64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36c17871-0edb-4f9d-b6ef-9637ef219db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): j is not defined in any earlier cell -> NameError on a fresh kernel\n",
    "xts[j,:,0:2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f90c0916",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 50000\n",
    "xts = tf.convert_to_tensor(xts,dtype=tf.float32)\n",
    "\n",
    "nmin=-0.2\n",
    "nmax=0.2\n",
    "er=1\n",
    "delta = 0.5\n",
    "eta = 0.1\n",
    "eta2 = 0.005\n",
    "tau = tf.constant(0.01,dtype=tf.float32)\n",
    "alpha = tf.constant(1,dtype=tf.float32)\n",
    "switch = True\n",
    "k = 0\n",
    "t = 0\n",
    "K=0.5\n",
    "\n",
    "h1c = 1\n",
    "h2c =1\n",
    "no =1 \n",
    "gamma = 0.1\n",
    "mult = True\n",
    "for i in range(epochs):\n",
    "    lipc = calculate_lip2(C)\n",
    "    eta = eps * lipc\n",
    "    if mult:\n",
    "        no = 0\n",
    "        for j in range(xts.shape[0]):\n",
    "            pass\n",
    "            #no += check_cond(V,C,xts[j,:,:],tau.numpy())\n",
    "        if no==0 and h1c <=eps  and ind2tot==0:\n",
    "            break\n",
    "        \n",
    "        if i%50 ==0:\n",
    "            #no = check_cond(V,C,xts,0.01)\n",
    "            print(er,i,no, h1c,h2c,C(np.array([0,0,0,0]).reshape(1,4)))\n",
    "            #switch=not(switch)\n",
    "            \n",
    "            \n",
    "            if switch:\n",
    "                \n",
    "                print('C')\n",
    "            else :\n",
    "                \n",
    "                print('V')\n",
    "        \n",
    "    else:\n",
    "        no = check_cond(V,C,xts,0.01,eta)\n",
    "        if no==0 and h1c <=h2c and h2c<=eps and ind2tot==0:\n",
    "            break\n",
    "        \n",
    "        if i%500 ==0:\n",
    "            no = check_cond(V,C,xts,0.01,eta)\n",
    "            print(er,i,no, h1,ind2tot)\n",
    "            switch=not(switch)\n",
    "            \n",
    "            \n",
    "            if switch:\n",
    "                \n",
    "                print('C')\n",
    "            else :\n",
    "                \n",
    "                print('V')\n",
    "        elif i%100 ==0:\n",
    "            no = check_cond(V,C,xts,0.01,eta)\n",
    "            print(er,i,no,h1,h2)\n",
    "        \n",
    "    if switch:\n",
    "        with tf.GradientTape()  as C1_tape,  tf.GradientTape() as C2_tape:\n",
    "            \n",
    "            if mult:\n",
    "                H1=[]\n",
    "                H2=[]\n",
    "                ind2tot=0\n",
    "                loss_c=tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "                for j in range(xts.shape[0]):\n",
    "                    \n",
    "                    #loss_c=tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "                    #ind = np.where(V(xts[j,:,:6])>=delta)\n",
    "                    ind = [np.arange(0,xts.shape[1])]\n",
    "                    #x_1 = xts[ind[0],0:4]\n",
    "                    #x_2 = xts[ind[0],4:6]\n",
    "                    u1 = C(xts[j,:,0:4],training=True)\n",
    "    \n",
    "                    #u2 = tf.reshape(C(xts[j,ind[0]],training=True)[:,1],(ind[0].shape[0],1))\n",
    "                    #u2 = tf.reshape(C(xts[j,:,0:6],training=True)[:,1],(ind[0].shape[0],1))\n",
    "                    #u = tf.reshape(xts[j,:,6],(xts.shape[1],1))\n",
    "                    u=ctrl_nn(torch.tensor(xts[j,:,0:2].numpy(),dtype=torch.float64)).detach().numpy()\n",
    "                    x1 = tf.reshape(xts[j,:,0],(xts.shape[1],1))\n",
    "                    x2 = tf.reshape(xts[j,:,1],(xts.shape[1],1))\n",
    "                    x3 = tf.reshape(xts[j,:,2],(xts.shape[1],1))\n",
    "                    x4 = tf.reshape(xts[j,:,3],(xts.shape[1],1))\n",
    "                    x1s = tf.reshape(xts[j,:,0],(xts.shape[1],1))\n",
    "                    x2s = tf.reshape(xts[j,:,1],(xts.shape[1],1))\n",
    "                    #u1= xts[:,6].reshape(xts.shape[0],1)\n",
    "                    #u2 = xts[:,7].reshape(xts.shape[0],1)\n",
    "                    \n",
    "                                        \n",
    "                    \n",
    "                    x1n = alpha*x1 + tau* x2\n",
    "                    x2n = alpha*x2 + tau *(   2*9.8*tf.math.sin(x1) - tf.math.sin(x1-x3)*x2*x4 + tf.cast((tf.reshape(u1[:,0],[xts.shape[1],1])),tf.float32))\n",
    "                    #x2n = 1.0*x2 + tau *(-3.447*tf.math.pow(alpha*x1,3)+2.35*tf.math.pow(alpha*x1,2)*x3+1.303*tf.math.pow(alpha*x3,2)*x1+3.939*tf.math.pow(alpha*x3,3)+21.52*x1 -5.0*x3+ u1)\n",
    "                    #x2n = 1.0* x2+ tau*(9.8*np.sin(x1)+u1)\n",
    "                    x3n = alpha*x3+ tau*x4\n",
    "                    x4n = alpha*x4 + tau*(9.8*tf.math.sin(x3)+tf.math.sin(x1-x3)*x2*x4 +tf.cast ((tf.reshape(u1[:,1],[xts.shape[1],1])),tf.float32 ))\n",
    "                    #x4n = x4 + tau*(4.023*np.power(x1,3)-36.551*np.power(x1,2)*x3-4.131*np.power(x2,2)*x3-27.06*np.power(x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n",
    "                    \n",
    "                    y_pred11 = x2n\n",
    "                    y_pred12 = x1n \n",
    "                    \n",
    "                    \n",
    "                    x1sn =x1s +  tau*x2s\n",
    "                    y_pred21 = x2s + tau*(9.8*tf.math.sin(x1s)+u)\n",
    "                    y_pred22 = x1sn\n",
    "                    \n",
    "                    #y_pred1  = tf.concat([y_pred11,y_pred21],axis=1)\n",
    "                    #y_pred2 = tf.concat([y_pred12,y_pred22],axis=1)\n",
    "                    h01 = tf.cast(tf.math.abs(tf.reshape(x2,(x2.shape[0],1)).numpy()-x2s),dtype=tf.float32)\n",
    "                    \n",
    "                    h02 = tf.math.abs(tf.reshape(x1,(x1.shape[0],1)).numpy()-x1s)\n",
    "                    #H1.append(h01.numpy())\n",
    "                    #H2.append(h02.numpy())\n",
    "                    \n",
    "                    h1 = tf.math.abs(tf.math.abs(tf.reshape(y_pred11,(y_pred11.shape[0],1)).numpy()-y_pred21[:].numpy()))\n",
    "                    h12 = tf.math.abs(tf.math.abs(tf.reshape(y_pred11,(y_pred11.shape[0],1)).numpy()-y_pred21[:].numpy()))\n",
    "                    h13 = tf.math.abs(tf.math.abs(tf.reshape(y_pred11,(y_pred11.shape[0],1)).numpy()-y_pred21[:].numpy()))\n",
    "                    h2 = tf.math.abs(tf.math.abs(tf.reshape(y_pred12,(y_pred12.shape[0],1)).numpy()-y_pred22[:].numpy()))\n",
    "                    ind2 = np.where(h13-h01>0)\n",
    "                    ind2tot+=ind2[0].shape[0]\n",
    "                    #u2_pred = np.clip(-K*(y_pred2-set_p),nmin,nmax)\n",
    "                    #y_pred = tf.concat([x1n,x2n,x3n,x4n,x1sn,y_pred2],axis=1)\n",
    "                    #y_predtot = V(y_pred) \n",
    "                    #ind1 = np.where(V(y_pred) < delta)\n",
    "                    #ind4= np.where(h<eps)\n",
    "                    h1 = np.max(h1)\n",
    "                    h2 = np.max(h2)\n",
    "                    H1.append(h1)\n",
    "                    H2.append(h2)\n",
    "                    #wh=np.argmax(h)\n",
    "                    #n1 = tf.gather(y_pred1, wh)\n",
    "                    #n2 = tf.gather(y_pred2,wh)\n",
    "                    \n",
    "                    \n",
    "                    #print(ind4[0])\n",
    "                    loss_c1=  tf.reduce_mean (tf.square (tf.subtract(tf.reshape(y_pred21[:,0],y_pred21.shape[0]), tf.reshape(y_pred11,y_pred11.shape[0]))))+tf.math.reduce_max((tf.abs(tf.subtract(y_pred11,y_pred21))))\n",
    "                    #loss_c+= tf.reduce_mean(tf.square(tf.subtract(tf.gather(y_pred21,ind2[0]), tf.gather(y_pred11,ind2[0]))))\n",
    "                    #loss_c+= tf.keras.losses.mean_absolute_error(tf.reshape(y_pred21[:,0],y_pred21.shape[0]), tf.reshape(y_pred11,y_pred11.shape[0]))\n",
    "                    #loss_c+= tf.reduce_mean(tf.square(tf.subtract(y_pred12,y_pred22)))\n",
    "                    loss_c1+=  tf.reduce_mean(tf.square(tf.subtract(x4n,0)))\n",
    "                    #ave = np.copy(loss_c)\n",
    "                    \n",
    "                    #loss_c += 0.071*tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,1],y_pred1.shape[0]),0.0)\n",
    "                    #print(loss_c)\n",
    "            #             if ind1[0].size!=0:\n",
    "                        \n",
    "            #                 loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind1[0]),ind1[0].shape),delta-gamma)\n",
    "                        \n",
    "            #             #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,0],y_pred1.shape[0]), tf.reshape(y_pred2,y_pred2.shape[0]))\n",
    "            #             #loss_c += tf.keras.losses.mean_squared_error(y_pred[:,0]-tf.reshape(y_pred2,y_pred2.shape[0]),0.0)\n",
    "            #             elif ind4[0].size!=0:\n",
    "                        \n",
    "            #                 #loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind4[0])-h[ind4[0]]-gamma,ind4[0].shape[0]),0.0)\n",
    "            #             else: \n",
    "            #                 pass\n",
    "                    \n",
    "                if loss_c1.numpy()==0 or tf.math.is_nan(loss_c1).numpy():\n",
    "                    #switch = not(switch)\n",
    "                    print(f\"switch {i} to V\")\n",
    "                else:\n",
    "                    er = np.copy(loss_c1)\n",
    "                    #er2 = np.copy(loss_c2)\n",
    "                    gradsc = C1_tape.gradient(loss_c1,C.trainable_variables)\n",
    "                    #gradsc2 = C2_tape.gradient(loss_c2,C2.trainable_variables)\n",
    "                    #capped_gradsc =  [(tf.clip_by_norm(grad, 1)) for grad in gradsc]\n",
    "                    #print(gradsc)\n",
    "                    opt.apply_gradients(zip(gradsc,C.trainable_variables))\n",
    "                    #opt.apply_gradients(zip(gradsc2,C2.trainable_variables))\n",
    "                    #del y_pred1, y_pred2\n",
    "                    k+=1\n",
    "                h1c,h2c = np.max(H1),np.max(H2)\n",
    "                \n",
    "                \n",
    "                    \n",
    "\n",
    "                    \n",
    "            else:\n",
    "                #xinp=np.hstack((xts,np.clip(-K*(xts[:,2].reshape(xts.shape[0],1)-set_p),nmin,nmax)))\n",
    "                #xinp2 = np.copy(xt)\n",
    "                ind = np.where(V(xts[:,:6])>=0)\n",
    "                #x_1 = xts[ind[0],0:4]\n",
    "                #x_2 = xts[ind[0],4:6]\n",
    "                u1 = tf.reshape(C(xts[ind[0]],training=True)[:,0],(ind[0].shape[0],1))\n",
    "                \n",
    "                u2 = tf.reshape(C(xts[ind[0]],training=True)[:,1],(ind[0].shape[0],1))\n",
    "                #u2 = C(xts[ind[0]])[:,1].numpy().reshape(ind[0].shape[0],1)\n",
    "                #u = torch.tensor(xts[ind[0],6].reshape(ind[0].shape[0],1))\n",
    "                u = ctrl_nn(torch.tensor(xts[ind[0],4:6])).detach().numpy()\n",
    "                x1 = xts[ind[0],0].reshape(ind[0].shape[0],1)\n",
    "                x2 = xts[ind[0],1].reshape(ind[0].shape[0],1)\n",
    "                x3 = xts[ind[0],2].reshape(ind[0].shape[0],1)\n",
    "                x4 = xts[ind[0],3].reshape(ind[0].shape[0],1)\n",
    "                x1s = xts[ind[0],4].reshape(ind[0].shape[0],1)\n",
    "                x2s = xts[ind[0],5].reshape(ind[0].shape[0],1)\n",
    "                #u1= xts[:,6].reshape(xts.shape[0],1)\n",
    "                #u2 = xts[:,7].reshape(xts.shape[0],1)\n",
    "        \n",
    "                \n",
    "                \n",
    "                x1n = 1.0*x1 + tau* x2\n",
    "                #x2n = x2 + tau *(-3.447*np.power(x1,3)+2.35*np.power(x1,2)*x3+1.303*np.power(x3,2)*x1+3.939*np.power(x3,3)+21.52*x1 -5*x3 + 8*u1-31.2*u2)\n",
    "                x2n = 1.0*x2 + tau *(-3.447*tf.math.pow(alpha*x1,3)+2.35*tf.math.pow(alpha*x1,2)*x3+1.303*tf.math.pow(alpha*x3,2)*x1+3.939*tf.math.pow(alpha*x3,3)+21.52*x1 -5.0*x3+ 8.0*u1-31.2*u2)\n",
    "                x3n = 1.0*x3+ tau*x4\n",
    "                x4n = 1.0*x4 + tau*(4.023*tf.math.pow(alpha*x1,3)-36.551*tf.math.pow(alpha*x1,2)*x3-4.131*tf.math.pow(alpha*x2,2)*x3-27.06*tf.math.pow(alpha*x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n",
    "                #x4n = x4 + tau*(4.023*np.power(x1,3)-36.551*np.power(x1,2)*x3-4.131*np.power(x2,2)*x3-27.06*np.power(x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n",
    "        \n",
    "                y_pred11 = x2n\n",
    "                y_pred12 = x1n \n",
    "        \n",
    "                \n",
    "                x1sn =x1s +  tau*x2s\n",
    "                y_pred21 = x2s + tau*(9.8*np.sin(x1s)+u)\n",
    "                y_pred22 = x1sn\n",
    "                \n",
    "                #y_pred1  = tf.concat([y_pred11,y_pred21],axis=1)\n",
    "                #y_pred2 = tf.concat([y_pred12,y_pred22],axis=1)\n",
    "                h01 = tf.math.abs(tf.reshape(x2,(x2.shape[0],1)).numpy()-x2s)\n",
    "                h02 = tf.math.abs(tf.reshape(x1,(x1.shape[0],1)).numpy()-x1s)\n",
    "                h1 = tf.math.abs(tf.math.abs(tf.reshape(y_pred11,(y_pred11.shape[0],1)).numpy()-y_pred21[:].numpy()))\n",
    "                h2 = tf.math.abs(tf.math.abs(tf.reshape(y_pred12,(y_pred12.shape[0],1)).numpy()-y_pred22[:].numpy()))\n",
    "                #u2_pred = np.clip(-K*(y_pred2-set_p),nmin,nmax)\n",
    "                #y_pred = tf.concat([x1n,x2n,x3n,x4n,x1sn,y_pred2],axis=1)\n",
    "                #y_predtot = V(y_pred) \n",
    "                #ind1 = np.where(V(y_pred) < delta)\n",
    "                #ind4= np.where(h<eps)\n",
    "                h1 = np.max(h1)\n",
    "                h2 = np.max(h2)\n",
    "                \n",
    "                #wh=np.argmax(h)\n",
    "                #n1 = tf.gather(y_pred1, wh)\n",
    "                #n2 = tf.gather(y_pred2,wh)\n",
    "                \n",
    "                loss_c=tf.reduce_mean(tf.square(tf.subtract(0.0,0.0)))\n",
    "                #print(ind4[0])\n",
    "                \n",
    "                loss_c+= tf.reduce_mean(tf.square(tf.subtract(y_pred11,y_pred21))) + tf.math.reduce_max((tf.abs(tf.subtract(y_pred11,y_pred21))))\n",
    "                loss_c+= tf.reduce_mean(tf.square(tf.subtract(y_pred12,y_pred22))) + tf.math.reduce_max((tf.abs(tf.subtract(y_pred12,y_pred22))))\n",
    "                loss_c += (tf.reduce_mean(tf.square(tf.subtract(x2n,0)))+tf.reduce_mean(tf.square(tf.subtract(x4n,0)))+tf.reduce_mean(tf.square(tf.subtract(x1n,0)))+tf.reduce_mean(tf.square(tf.subtract(x3n,0))))\n",
    "                ave = np.copy(loss_c)\n",
    "                if i%10 ==0:\n",
    "                    print(ave,h1,h2)\n",
    "                #loss_c += 0.071*tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,1],y_pred1.shape[0]),0.0)\n",
    "                #print(loss_c)\n",
    "        #             if ind1[0].size!=0:\n",
    "                    \n",
    "        #                 loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind1[0]),ind1[0].shape),delta-gamma)\n",
    "                    \n",
    "        #             #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(y_pred1[:,0],y_pred1.shape[0]), tf.reshape(y_pred2,y_pred2.shape[0]))\n",
    "        #             #loss_c += tf.keras.losses.mean_squared_error(y_pred[:,0]-tf.reshape(y_pred2,y_pred2.shape[0]),0.0)\n",
    "        #             elif ind4[0].size!=0:\n",
    "                    \n",
    "        #                 #loss_c += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind4[0])-h[ind4[0]]-gamma,ind4[0].shape[0]),0.0)\n",
    "        #             else: \n",
    "        #                 pass\n",
    "                \n",
    "                if loss_c.numpy()==0 or tf.math.is_nan(loss_c).numpy():\n",
    "                    switch = not(switch)\n",
    "                    print(f\"switch {i} to V\")\n",
    "                else:\n",
    "                    er = np.copy(loss_c)\n",
    "        \n",
    "                    gradsc = C_tape.gradient(loss_c,C.trainable_variables)\n",
    "                    #capped_gradsc =  [(tf.clip_by_norm(grad, 1)) for grad in gradsc]\n",
    "                    #print(gradsc)\n",
    "                    opt.apply_gradients(zip(gradsc,C.trainable_variables))\n",
    "                    #del y_pred1, y_pred2\n",
    "                    k+=1\n",
    "\n",
    "            \n",
    "            \n",
    "            \n",
    "    else:\n",
    "        \n",
    "        with tf.GradientTape() as V_tape:\n",
    "            \n",
    "            \n",
    "            if mult:\n",
    "                no = 0\n",
    "                for j in range(xts.shape[0]):\n",
    "                    #no += check_cond(V,C,xts[j,:,:],0.01)\n",
    "                    ind = np.where(V(xts[j,:,:6])>=delta)\n",
    "                    #x_1 = xts[ind[0],0:4]\n",
    "                    #x_2 = xts[ind[0],4:6]\n",
    "                    u1 = tf.reshape(C(xts[j,ind[0]],training=True)[:,0],(ind[0].shape[0],1))\n",
    "                    \n",
    "                    u2 = tf.reshape(C(xts[j,ind[0]],training=True)[:,1],(ind[0].shape[0],1))\n",
    "                    #u2 = C(xts[ind[0]])[:,1].numpy().reshape(ind[0].shape[0],1)\n",
    "                    #u = xts[j,ind[0],6].reshape(ind[0].shape[0],1)\n",
    "                    u=ctrl_nn(torch.tensor(xts[j,ind[0],4:6].reshape(ind[0].shape[0],2))).detach().numpy()\n",
    "                    x1 = xts[j,ind[0],0].reshape(ind[0].shape[0],1)\n",
    "                    x2 = xts[j,ind[0],1].reshape(ind[0].shape[0],1)\n",
    "                    x3 = xts[j,ind[0],2].reshape(ind[0].shape[0],1)\n",
    "                    x4 = xts[j,ind[0],3].reshape(ind[0].shape[0],1)\n",
    "                    x1s = xts[j,ind[0],0].reshape(ind[0].shape[0],1)\n",
    "                    x2s = xts[j,ind[0],1].reshape(ind[0].shape[0],1)\n",
    "                    #u1= xts[:,6].reshape(xts.shape[0],1)\n",
    "                    #u2 = xts[:,7].reshape(xts.shape[0],1)\n",
    "                    x1n = 1.0*x1 + tau* x2\n",
    "                    #x2n = x2 + tau *(-3.447*np.power(x1,3)+2.35*np.power(x1,2)*x3+1.303*np.power(x3,2)*x1+3.939*np.power(x3,3)+21.52*x1 -5*x3 + 8*u1-31.2*u2)\n",
    "                    #x2n = 1.0*x2 + tau *(-3.447*np.power(1.0*x1,3)+2.35*np.power(1.0*x1,2)*x3+1.303*np.power(1.0*x3,2)*x1+3.939*np.power(1.0*x3,3)+21.52*x1 -5.0*x3+ u1)\n",
    "                    x2n = alpha*x2 + tau *(   2*9.8*tf.math.sin(x1) - tf.math.sin(x1-x3)*x2*x2 + u1 ) \n",
    "                    x3n = 1.0*x3+ tau*x4\n",
    "                    #x4n = 1.0*x4 + tau*(4.023*np.power(1.0*x1,3)-36.551*np.power(1.0*x1,2)*x3-4.131*np.power(1.0*x2,2)*x3-27.06*np.power(1.0*x3,3)-25.115*x1+77.7*x3+u2)\n",
    "                    x4n = alpha*x4 + tau*(9.8*tf.math.sin(x3)+tf.math.sin(x1-x3)*x4*x4 +u2 )\n",
    "                    #x4n = x4 + tau*(4.023*np.power(x1,3)-36.551*np.power(x1,2)*x3-4.131*np.power(x2,2)*x3-27.06*np.power(x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n",
    "        \n",
    "                    y_pred1 = x4n\n",
    "        \n",
    "                    x1sn =x1s +  tau*x2s\n",
    "                    \n",
    "                    y_pred2 = x2s + tau*(9.8*np.sin(x1s)+u)\n",
    "                    \n",
    "     \n",
    "                                    \n",
    "                    y_pred = tf.concat([x1n,x2n,x3n,x4n,x1sn,y_pred2],axis=1)\n",
    "                    y_predtot = V(y_pred) \n",
    "                    \n",
    "                    \n",
    "                    \n",
    "                    #loss = tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "                    loss = cce(y_predtot,y_label[ind[0]])\n",
    "                    \n",
    "                er = np.copy(loss)\n",
    "                if no==0:\n",
    "                    switch = not(switch)\n",
    "                    print('Switch to C')\n",
    "                else:\n",
    "                    grads = V_tape.gradient(loss,V.trainable_variables)\n",
    "\n",
    "                    #capped_grads = [(tf.clip_by_norm(grad, 1)) for  grad in grads]\n",
    "\n",
    "                    opt2.apply_gradients(zip(grads,V.trainable_variables))\n",
    "                #del y_pred1, y_pred2\n",
    "                t+=1\n",
    "            else:\n",
    "                ind = np.where(V(xts[:,:6])>=delta)\n",
    "                #x_1 = xts[ind[0],0:4]\n",
    "                #x_2 = xts[ind[0],4:6]\n",
    "                u1 = tf.reshape(C(xts[ind[0]])[:,0],(ind[0].shape[0],1))\n",
    "                u2 = tf.reshape(C(xts[ind[0]])[:,1],(ind[0].shape[0],1))\n",
    "            #u2 = C(xts[ind[0]])[:,1].numpy().reshape(ind[0].shape[0],1)\n",
    "                #u = xts[ind[0],6].reshape(ind[0].shape[0],1)\n",
    "                u = ctrl_nn(torch.tensor(xts[:,4:6])).detach().numpy()\n",
    "                x1 = xts[ind[0],0].reshape(ind[0].shape[0],1)\n",
    "                x2 = xts[ind[0],1].reshape(ind[0].shape[0],1)\n",
    "                x3 = xts[ind[0],2].reshape(ind[0].shape[0],1)\n",
    "                x4 = xts[ind[0],3].reshape(ind[0].shape[0],1)\n",
    "                x1s = xts[ind[0],4].reshape(ind[0].shape[0],1)\n",
    "                x2s = xts[ind[0],5].reshape(ind[0].shape[0],1)\n",
    "                #u1= xts[:,6].reshape(xts.shape[0],1)\n",
    "                #u2 = xts[:,7].reshape(xts.shape[0],1)\n",
    "    \n",
    "                \n",
    "                \n",
    "                x1n = 1.0*x1 + tau* x2\n",
    "                #x2n = x2 + tau *(-3.447*np.power(x1,3)+2.35*np.power(x1,2)*x3+1.303*np.power(x3,2)*x1+3.939*np.power(x3,3)+21.52*x1 -5*x3 + 8*u1-31.2*u2)\n",
    "                x2n = 1.0*x2 + tau *(-3.447*np.power(1.0*x1,3)+2.35*np.power(1.0*x1,2)*x3+1.303*np.power(1.0*x3,2)*x1+3.939*np.power(1.0*x3,3)+21.52*x1 -5.0*x3+ 8.0*u1)\n",
    "                x3n = 1.0*x3+ tau*x4\n",
    "                x4n = 1.0*x4 + tau*(4.023*np.power(1.0*x1,3)-36.551*np.power(1.0*x1,2)*x3-4.131*np.power(1.0*x2,2)*x3-27.06*np.power(1.0*x3,3)-25.115*x1+77.7*x3+39.2*u2)\n",
    "                #x4n = x4 + tau*(4.023*np.power(x1,3)-36.551*np.power(x1,2)*x3-4.131*np.power(x2,2)*x3-27.06*np.power(x3,3)-25.115*x1+77.7*x3-31.2*u1+391.2*u2)\n",
    "    \n",
    "                y_pred1 = x4n\n",
    "    \n",
    "                x1sn =x1s +  tau*x2s\n",
    "                \n",
    "                y_pred2 = x2s + tau*(9.8*np.sin(x1s)+u)\n",
    "                \n",
    " \n",
    "                                \n",
    "                y_pred = tf.concat([x1n,x2n,x3n,x4n,x1sn,y_pred2],axis=1)\n",
    "                y_predtot = V(y_pred) \n",
    "                \n",
    "                \n",
    "                \n",
    "                #loss = tf.keras.losses.mean_squared_error(tf.reshape(tf.constant(1.0),1),tf.reshape(tf.constant(1.0),1))\n",
    "                loss = cce(y_predtot,y_label[ind[0]])\n",
    "               # if ind3[0].size!=0:\n",
    "                    \n",
    "                #loss = tf.keras.losses.mean_squared_error(tf.reshape(v,v.shape[0])-tf.reshape(y_predtot,v.shape[0])-gamma,0.0)\n",
    "                    #loss += tf.keras.losses.mean_absolute_error(tf.reshape(y_predtot2,y_predtot2.shape[0]),0.0)\n",
    "#                 elif ind4[0].size!=0:\n",
    "                    \n",
    "#                     loss += tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(y_predtot,ind4[0])-h[ind4[0]]-gamma,ind4[0].shape[0]),0.0)\n",
    "                #loss += tf.keras.losses.mean_squared_error(tf.reshape(v,v.shape[0]),0.0)\n",
    "                #h = np.abs(x_1[:,0]-x_2[:]).reshape(x_1.shape[0],1)\n",
    "#                 ind5= np.where(v - h <0)\n",
    "#                 if ind5[0].size!=0:\n",
    "#                     loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v,ind5[0])-h[ind5[0]]-gamma,ind5[0].shape[0]),0.0)\n",
    "#                 if indbad[0].size!=0:\n",
    "#                     pass\n",
    "# #                     try:\n",
    "# #                         batch = np.random.choice(indbad[0].shape[0], ind[0].shape[0])\n",
    "# #                         #loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v2,batch)-gamma,batch.shape[0]),delta)\n",
    "# #                     except:\n",
    "# #                         #loss+= tf.keras.losses.mean_absolute_error(tf.reshape(tf.gather(v2,indbad[0])-gamma,indbad[0].shape[0]),delta)\n",
    "                \n",
    "                er = np.copy(loss)\n",
    "                if no==0:\n",
    "                    switch = not(switch)\n",
    "                    print('Switch to C')\n",
    "                else:\n",
    "                    grads = V_tape.gradient(loss,V.trainable_variables)\n",
    "\n",
    "                    #capped_grads = [(tf.clip_by_norm(grad, 1)) for  grad in grads]\n",
    "\n",
    "                    opt2.apply_gradients(zip(grads,V.trainable_variables))\n",
    "                #del y_pred1, y_pred2\n",
    "                t+=1\n",
    "                #ind4= np.where(tf.reshape(V(xt[ind[0]]),ind[0].shape)- h <0)\n",
    "                #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]],0.0)\n",
    "                #ind3 = np.where(V(y_pred)>=delta)\n",
    "                #loss_c += tf.keras.losses.mean_squared_error(tf.reshape(tf.gather(y_predtot,ind3[0]),ind3[0].shape),0.0)\n",
    "                \n",
    "                #u = C(xt,training = False)\n",
    "                #x_1 = xt[:,0:2]\n",
    "                #x_2 = xt[:,2]\n",
    "                #y_pred1 = tf.matmul(tf.constant(x_1,dtype=tf.float32),tf.transpose(a)) + tf.reshape(((tau**2)/2)*u[:,0],(xt.shape[0],1) )\n",
    "                #y_pred2 = tf.reshape(tf.constant(x_2,dtype=tf.float32) + tau * u[:,1],(xt.shape[0],1))\n",
    "                #y_pred = np.hstack([y_pred1,y_pred2])\n",
    "                #y_predtot = V(y_pred) \n",
    "                #ind = np.where (y_predtot >= delta)\n",
    "                #ind2 = np.where (y_predtot < delta)\n",
    "#                 if ind[0].shape[0]==0 :\n",
    "#                     switch=not(switch)\n",
    "#                     print(f\"switch {i} to C\")\n",
    "#                 else:\n",
    "#                     #loss_c = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]],0.0)\n",
    "#                     #loss = tf.reduce_mean(tf.sqrt(tf.square(tf.subtract(V(xt[ind[0]],training = True),delta))))\n",
    "#                     loss = tf.keras.losses.mean_squared_error(tf.reshape(V(xt[ind4[0]]),ind4[0].shape)-h[ind4[0]]-gamma,0.0)\n",
    "#                     if tf.math.is_nan(loss).numpy():\n",
    "#                         print(f\"switch {i}\")\n",
    "#                         switch = not(switch)\n",
    "#                     else:\n",
    "#                         "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
