{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loss `F` vs. Monte-Carlo estimate\n",
    "\n",
    "Sanity check for `F`: build random student/teacher `EquiBlocks` networks with `erf` activations and compare `F(student.create_W(), teacher.create_W())` against an empirical average of the squared output difference over standard-Gaussian inputs.\n",
    "\n",
    "**Assumption (TODO confirm):** `F` is normalised as $\\tfrac{1}{2}\\,\\mathbb{E}_x[(f_s(x) - f_t(x))^2]$. The factor $\\tfrac{1}{2}$ is inferred from the `F / (MSE / 2)` ratio check at the end of this notebook, not read from the source of `F`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "\n",
    "# Make the parent of the current working directory (typically the notebook's folder) importable so `src` resolves.\n",
    "# Guarded so re-running this cell does not keep appending duplicate entries to sys.path.\n",
    "PROJECT_ROOT = os.path.dirname(os.getcwd())\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.append(PROJECT_ROOT)\n",
    "\n",
    "# NOTE(review): star imports hide where names come from; this notebook uses `EquiBlocks` and `F`.\n",
    "# TODO replace with explicit imports once the defining modules are confirmed.\n",
    "from src.models import *\n",
    "from src.loss_Erf import *\n",
    "from src.equi_test import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch's global RNG so the Gaussian samples (and, presumably, the random network\n",
    "# initialisation inside EquiBlocks) are reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two independently initialised networks with identical architecture: the first is the student,\n",
    "# the second the teacher. The leading 8 is presumably the input dimension (inputs below are torch.randn(8, N)).\n",
    "# NOTE(review): meaning of `[(3, 2)]` and `3` depends on the EquiBlocks signature — TODO document.\n",
    "student, teacher = (EquiBlocks(8, [(3, 2)], 3, activation='erf') for _ in range(2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Monte-Carlo estimate vs. `F`\n",
    "\n",
    "Half the empirical mean squared output difference on `N_train` Gaussian inputs, shown next to `F`. The halving matches the normalisation used in the `F / (MSE / 2)` ratio check. An earlier run without it gave 2.3655 vs. 1.1861, i.e. off by roughly a factor of 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_train = 500000\n",
    "train = torch.randn(8, N_train)\n",
    "\n",
    "# Pure evaluation: no_grad avoids building an autograd graph over all N_train samples.\n",
    "with torch.no_grad():\n",
    "    # 1/2 * (summed squared difference) / N_train — same normalisation as the ratio check.\n",
    "    loss_mc = torch.sum((student(train) - teacher(train)) ** 2) / N_train / 2\n",
    "    loss_F = F(student.create_W(), teacher.create_W())\n",
    "\n",
    "loss_mc, loss_F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ratio `F / (MSE / 2)` on a fresh random pair\n",
    "\n",
    "Draws a new student/teacher pair and reports the ratio. The pair and inputs use separate names so the earlier `student`, `teacher`, `N_train` and `train` are not overwritten. A value close to 1 means `F` agrees with half the empirical mean squared difference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_fresh = EquiBlocks(8, [(3, 2)], 3, activation='erf')\n",
    "teacher_fresh = EquiBlocks(8, [(3, 2)], 3, activation='erf')\n",
    "\n",
    "N_ratio = 50000\n",
    "x_ratio = torch.randn(8, N_ratio)\n",
    "\n",
    "# Evaluation only — skip autograd bookkeeping.\n",
    "with torch.no_grad():\n",
    "    mse = torch.sum((student_fresh(x_ratio) - teacher_fresh(x_ratio)) ** 2) / N_ratio\n",
    "    ratio = F(student_fresh.create_W(), teacher_fresh.create_W()) / (mse / 2)\n",
    "\n",
    "ratio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A previous (unseeded) run of this check with 50,000 samples gave a ratio of about 0.9946."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "e3nn",
   "language": "python",
   "name": "e3nn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
