{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe8f7da7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ac02d9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import sys\n",
    "# sys.path.append('super_resolution/')\n",
    "# model = torch.load('super_resolution/model_epoch_1.pth')\n",
    "# model\n",
    "# !pip install --upgrade kaleido"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e90c4f06",
   "metadata": {},
   "outputs": [],
   "source": [
    "def CoReLU(x, eps=1e-7):\n",
    "    \"\"\"Cone-shaped ReLU applied row-wise.\n",
    "\n",
    "    Column 0 of ``x`` (the \"head\") is rectified with ReLU. The remaining\n",
    "    columns (the \"tail\") are multiplied by ``clamp(head / (||tail||_2 + eps), 0, 1)``,\n",
    "    so rows with a non-positive head collapse to zero, rows whose tail norm\n",
    "    exceeds the head are shrunk until ||tail|| ~= head, and rows whose tail\n",
    "    norm is below the head keep their tail (up to the ``eps`` slack).\n",
    "\n",
    "    Args:\n",
    "        x: tensor indexed as ``x[:, j]``; assumed (batch, features) -- TODO confirm.\n",
    "        eps: added to the tail norm so an all-zero tail does not divide by zero.\n",
    "\n",
    "    Returns:\n",
    "        Tensor with the same shape as ``x``.\n",
    "    \"\"\"\n",
    "    head = x[:, :1].relu()\n",
    "    tail = x[:, 1:]\n",
    "    tail_norm = torch.linalg.vector_norm(tail, ord=2, dim=1, keepdim=True)\n",
    "    gate = torch.clamp(head / (tail_norm + eps), min=0, max=1)\n",
    "    return torch.cat((head, gate * tail), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bca6ccc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "\n",
    "# 11 x 11 grid over [-1, 1]^2. linspace pins both endpoints exactly, whereas a\n",
    "# float-step arange(-1, 1.1, .2) accumulates rounding error in the step.\n",
    "omega = torch.linspace(-1, 1, 11)\n",
    "# indexing='ij' reproduces the old implicit default and silences torch's meshgrid warning.\n",
    "Omega = torch.meshgrid(omega, omega, indexing='ij')\n",
    "# x: (121, 2), one grid point per row.\n",
    "x = torch.stack([Omega[0].flatten(), Omega[1].flatten()], dim=1)\n",
    "y = CoReLU(x)      # CoReLU image of the grid\n",
    "z = torch.relu(x)  # element-wise ReLU image of the grid\n",
    "\n",
    "fig = go.Figure()\n",
    "# Grey square: the non-negative quadrant, i.e. the range of element-wise ReLU.\n",
    "fig.add_trace(go.Scatter(x=[0, 2, 2, 0, 0], y=[0, 0, 2, 2, 0],\n",
    "                         fill='toself',\n",
    "                         hoveron='points+fills',  # select where hover is active\n",
    "                         line_color='grey',\n",
    "                         text=\"Fills\",\n",
    "                         hoverinfo='text+x+y',\n",
    "                         showlegend=False))\n",
    "# Green triangle: the cone {(a, b): a >= |b|}, which contains every 2-D CoReLU output.\n",
    "fig.add_trace(go.Scatter(x=[0, 2, 2, 0], y=[0, -2, 2, 0],\n",
    "                         fill='toself',\n",
    "                         hoveron='points+fills',\n",
    "                         line_color='green',\n",
    "                         text=\"Fills\",\n",
    "                         hoverinfo='text+x+y',\n",
    "                         showlegend=False))\n",
    "fig.add_trace(go.Scatter(x=x[:, 0].numpy(), y=x[:, 1].numpy(),\n",
    "                         mode='markers',\n",
    "                         marker=dict(size=4),\n",
    "                         name='input'))\n",
    "fig.add_trace(go.Scatter(x=z[:, 0].numpy(), y=z[:, 1].numpy(),\n",
    "                         mode='markers',\n",
    "                         marker=dict(size=8, symbol=\"cross-thin\"),\n",
    "                         marker_line_width=2, marker_line_color='grey',\n",
    "                         name='ReLU'))\n",
    "fig.add_trace(go.Scatter(x=y[:, 0].numpy(), y=y[:, 1].numpy(),\n",
    "                         mode='markers',\n",
    "                         marker=dict(size=8, symbol=\"x-thin\"),\n",
    "                         marker_line_width=2, marker_line_color='green',\n",
    "                         name='CoLU'))\n",
    "\n",
    "fig.update_layout(xaxis_range=[-1.1, 1.1], yaxis_range=[-1.1, 1.1],\n",
    "                  xaxis_side=\"top\", height=400, width=500,\n",
    "                  legend=dict(yanchor=\"bottom\", y=0.01, xanchor=\"left\", x=0.01),\n",
    "                  template=\"plotly_white\",\n",
    "                  font=dict(family=\"Times New Roman\", size=25),\n",
    "                  margin=dict(l=0, r=0, t=0, b=0))\n",
    "fig.update_xaxes(showticklabels=False)\n",
    "fig.update_yaxes(showticklabels=False)\n",
    "# To export: fig.write_image('fig.pdf'), called twice -- presumably a workaround\n",
    "# for kaleido drawing a MathJax loading banner on the first PDF export (confirm).\n",
    "fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759e5f8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Metrics CSV of a latent-diffusion FFHQ training run (presumably written by a\n",
    "# Lightning CSVLogger, judging by the csv/version_0 layout -- confirm).\n",
    "METRICS_CSV = \"../latent-diffusion/logs/2023-01-13T13-57-09_new-ffhq-ldm-vq-4/csv/version_0/metrics.csv\"\n",
    "df = pd.read_csv(METRICS_CSV)\n",
    "# Only a cell's last expression is auto-displayed, so show the metric column\n",
    "# names explicitly, then a preview instead of the whole frame.\n",
    "display(df.columns.tolist())\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30a9620d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Training-loss column from the metrics CSV. It is NaN on rows where it was not\n",
    "# logged (presumably rows written for other metrics), so drop those entries.\n",
    "curve = np.array(df['train/loss_simple_step'])\n",
    "loss_curve = curve[~np.isnan(curve)]\n",
    "\n",
    "fig_loss, ax_loss = plt.subplots()\n",
    "ax_loss.plot(loss_curve)\n",
    "ax_loss.set(title='Training loss', xlabel='logged step (NaN rows dropped)',\n",
    "            ylabel='train/loss_simple_step')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c13b3634",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rows of a shuffled identity form a permutation matrix P; P @ A reorders the rows of A.\n",
    "torch.eye(10)[torch.randperm(10)].matmul(torch.randn(10, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f4c1977",
   "metadata": {},
   "outputs": [],
   "source": [
    "# w has shape (100, 3, 3, 32, 64); SVD acts on the last two dims, i.e. on a\n",
    "# batch of 100 * 3 * 3 matrices of size 32 x 64.\n",
    "w = torch.randn([100, 32, 64, 3, 3]).permute(0, 3, 4, 1, 2)\n",
    "# torch.svd is deprecated. torch.linalg.svd(full_matrices=False) is the same reduced\n",
    "# SVD torch.svd(some=True) computed, but it returns V^H, so convert back to V.\n",
    "U, S, Vh = torch.linalg.svd(w, full_matrices=False)\n",
    "V = Vh.mH\n",
    "U.shape, S.shape, V.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26000e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild w from its SVD factors, U @ diag(S) @ V^T, batched over the leading dims.\n",
    "v = torch.matmul(torch.matmul(U, torch.diag_embed(S)), V.transpose(-2, -1))\n",
    "v.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe770b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frobenius norm of the reconstruction error; it should be close to zero.\n",
    "(w - v).norm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6102985d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# w1: 32 random 1-channel 3x3 filters, laid out as (3, 3, 32, 1).\n",
    "w1 = torch.randn(32, 1, 3, 3).permute(2, 3, 0, 1)\n",
    "# w2: the same filters with dim 2 reordered by a random permutation matrix.\n",
    "w2 = torch.matmul(torch.eye(32)[torch.randperm(32)], w1)\n",
    "# C[O, o]: inner product of filter O of w1 with filter o of w2, summed over\n",
    "# the two spatial dims and the trailing channel dim.\n",
    "C = torch.einsum('klOi,kloi->Oo', w1, w2)\n",
    "C.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26d67018",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the submodule explicitly: a bare `import scipy` only exposes\n",
    "# scipy.optimize on SciPy releases with lazy submodule loading.\n",
    "import scipy.optimize\n",
    "\n",
    "# Channel matching (row idr[k] <-> column idc[k]) that maximizes the total\n",
    "# similarity sum_k C[idr[k], idc[k]].\n",
    "idr, idc = scipy.optimize.linear_sum_assignment(C.numpy(), maximize=True)\n",
    "len(idc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3297ccf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Row and column indices of the optimal assignment.\n",
    "(idr, idc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03f71b42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total similarity of the optimal matching. Tensor.sum reduces in one vectorized\n",
    "# call instead of Python's element-by-element builtin sum over a tensor.\n",
    "C[idr, idc].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e41d56",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseline: total similarity of a uniformly random matching, for comparison with\n",
    "# C[idr, idc].sum(). Sized from C instead of a hard-coded 32.\n",
    "C[torch.arange(C.shape[0]), torch.randperm(C.shape[0])].sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1df7d29d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the checkpoint's state-dict keys. weights_only=True limits unpickling to\n",
    "# tensors and plain containers (a full pickle load can execute arbitrary code), and\n",
    "# map_location='cpu' lets a checkpoint saved on GPU load on any machine.\n",
    "sd = torch.load(\"mnist_pure_cnn_seed_1.pt\", map_location='cpu', weights_only=True)\n",
    "sd.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4a0a833",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List every entry of the state dict with its tensor shape.\n",
    "for key, value in sd.items():\n",
    "    print(key, value.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7703270b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mnist import test, Net\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# MNIST test split, normalised with mean 0.1307 / std 0.3081 and resized to 32 px.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "    transforms.Resize(32),\n",
    "])\n",
    "dataset2 = datasets.MNIST('../data', train=False, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(dataset2, batch_size=1000)\n",
    "\n",
    "# Pick the device once and use it everywhere, so the cell also runs without a GPU.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = Net().to(device)\n",
    "sd = torch.load(\"mnist_pure_cnn_seed_1.pt\", map_location=device, weights_only=True)\n",
    "\n",
    "# Permute conv1's output channels and, consistently, conv2's input channels with\n",
    "# the assignment idc from linear_sum_assignment.\n",
    "# NOTE(review): the network is only unchanged if conv1 feeds conv2 without other\n",
    "# per-channel parameters in between (e.g. BatchNorm) and conv1 has len(idc)\n",
    "# output channels -- confirm against mnist.Net.\n",
    "p = torch.as_tensor(idc, dtype=torch.long, device=device)\n",
    "sd['conv1.weight'] = sd['conv1.weight'][p]\n",
    "sd['conv1.bias'] = sd['conv1.bias'][p]\n",
    "sd['conv2.weight'] = sd['conv2.weight'][:, p]\n",
    "\n",
    "model.load_state_dict(sd)\n",
    "test(model, device, test_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63fe17a1-2182-42e0-924f-49e637b9dc0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.express as px\n",
    "\n",
    "# Stand-alone plotly box-plot example on the built-in \"tips\" dataset. Kept in its\n",
    "# own name so it does not overwrite `df`, the training-metrics frame read earlier.\n",
    "tips_df = px.data.tips()\n",
    "fig = px.box(tips_df, x=\"sex\", y=\"total_bill\")\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b14f8059-75c4-4467-8c8b-c992eeb0a4f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def read_images_in_folder(folder_path, num_images=11):\n",
    "    \"\"\"Load ``0.png``, ``1.png``, ... ``{num_images - 1}.png`` from ``folder_path``.\n",
    "\n",
    "    Files are visited in numeric order (not ``os.listdir`` order), so when every\n",
    "    file exists, list index ``i`` holds ``i.png``. Missing files are skipped,\n",
    "    which shifts the index of every later image; callers that depend on the\n",
    "    index/filename correspondence should check ``len(result) == num_images``.\n",
    "\n",
    "    Args:\n",
    "        folder_path: directory containing the numbered PNG files.\n",
    "        num_images: number of consecutive indices to try (default 11).\n",
    "\n",
    "    Returns:\n",
    "        List of numpy arrays (``np.array`` of each PIL image), in index order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``folder_path`` is not a directory.\n",
    "    \"\"\"\n",
    "    if not os.path.isdir(folder_path):\n",
    "        raise ValueError(\"The provided path is not a directory.\")\n",
    "\n",
    "    image_list = []\n",
    "    for i in range(num_images):\n",
    "        file_path = os.path.join(folder_path, f\"{i}.png\")\n",
    "        if not os.path.isfile(file_path):\n",
    "            continue  # best-effort: tolerate missing frames\n",
    "        # Convert inside the `with` so PIL decodes the pixels before the file closes.\n",
    "        with Image.open(file_path) as img:\n",
    "            image_list.append(np.array(img))\n",
    "    return image_list\n",
    "\n",
    "\n",
    "folder_path = 'super_resolution/cat_unaligned/'\n",
    "images = read_images_in_folder(folder_path)\n",
    "len(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72ad992c-6a35-4e1b-ab6c-80d85c459751",
   "metadata": {},
   "outputs": [],
   "source": [
    "import plotly.graph_objects as go\n",
    "import pandas as pd\n",
    "import plotly.express as px\n",
    "\n",
    "\n",
    "def plot_violin(folder_path, stride=10000):\n",
    "    \"\"\"Violin plot of pixel intensities for the numbered images in ``folder_path``.\n",
    "\n",
    "    Each image is flattened, subsampled to every ``stride``-th value and mapped\n",
    "    from [0, 255] to [-1, 1] (assumes 8-bit images -- confirm). Image ``i`` is\n",
    "    plotted in the column labelled ``i / 10``. All images must have the same\n",
    "    size, since their samples become equal-length DataFrame columns.\n",
    "\n",
    "    Args:\n",
    "        folder_path: directory passed to ``read_images_in_folder``.\n",
    "        stride: subsampling step over the flattened pixels (default 10000).\n",
    "\n",
    "    Returns:\n",
    "        The plotly Figure, which is also shown.\n",
    "    \"\"\"\n",
    "    images = read_images_in_folder(folder_path)\n",
    "    # Iterate over the images actually loaded, so a missing file does not raise\n",
    "    # IndexError. NOTE(review): a missing file still shifts later i/10 labels.\n",
    "    sample = pd.DataFrame({i / 10: img.flatten()[::stride] / 255 * 2 - 1\n",
    "                           for i, img in enumerate(images)})\n",
    "\n",
    "    fig = px.violin(sample, width=1000, height=400)\n",
    "    fig.update_layout(xaxis=dict(showgrid=False),\n",
    "                      yaxis=dict(showgrid=True),\n",
    "                      xaxis_title=\"\", yaxis_title=\"\",\n",
    "                      paper_bgcolor='rgba(0,0,0,0)',\n",
    "                      plot_bgcolor='rgba(0,0,0,0)',\n",
    "                      margin=dict(l=0, r=0, t=0, b=0))\n",
    "    fig.show()\n",
    "    return fig\n",
    "\n",
    "\n",
    "def _write_pdf(fig, path):\n",
    "    # Written twice on purpose: presumably works around kaleido drawing a\n",
    "    # MathJax loading banner on the first PDF export -- confirm before removing.\n",
    "    fig.write_image(path)\n",
    "    fig.write_image(path)\n",
    "\n",
    "\n",
    "fig1 = plot_violin('super_resolution/cat_unaligned/')\n",
    "_write_pdf(fig1, \"cat_unaligned.pdf\")\n",
    "fig2 = plot_violin('super_resolution/cat_aligned/')\n",
    "_write_pdf(fig2, \"cat_aligned.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a220f1f-a9de-4354-a67d-32ed76dfaea9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long-format table: one row per pixel, with the image index ('idx') and the pixel\n",
    "# value mapped from [0, 255] to [-1, 1] ('value', assuming 8-bit images), e.g. for\n",
    "# px.violin(df, x='idx', y='value').\n",
    "# NOTE(review): the column names and the use of every image reconstruct the intent\n",
    "# of a dict literal that was not valid Python and used an undefined `i` -- confirm.\n",
    "df = pd.concat(\n",
    "    [pd.DataFrame({'idx': i, 'value': np.ravel(img / 255 * 2 - 1)})\n",
    "     for i, img in enumerate(images)],\n",
    "    ignore_index=True,\n",
    ")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0897db-889d-4c7b-84d2-fbbbd8caa6e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformers\n",
    "import inspect\n",
    "\n",
    "# Path of the installed source file that defines transformers.FlaxViTModel.\n",
    "inspect.getfile(transformers.FlaxViTModel)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a208836c-87f3-44a2-832b-efbc76071809",
   "metadata": {},
   "outputs": [],
   "source": [
    "import flax.linen as nn\n",
    "import jax.numpy as jnp\n",
    "import jax\n",
    "\n",
    "\n",
    "def cosilu(x, eps=1e-7):\n",
    "    \"\"\"Sigmoid-gated cone activation (JAX).\n",
    "\n",
    "    Column 0 (the \"head\") passes through unchanged -- unlike the torch CoReLU it\n",
    "    is not rectified. Columns 1.. (the \"tail\") are scaled by\n",
    "    ``sigmoid(head / (||tail||_2 + eps))``.\n",
    "\n",
    "    Args:\n",
    "        x: array indexed as ``x[:, j]``; assumed (batch, features) -- TODO confirm.\n",
    "        eps: added to the tail norm so an all-zero tail does not divide by zero.\n",
    "\n",
    "    Returns:\n",
    "        Array with the same shape as ``x``.\n",
    "    \"\"\"\n",
    "    head, tail = x[:, :1], x[:, 1:]\n",
    "    gate = jax.nn.sigmoid(head / (jnp.linalg.norm(tail, axis=1, keepdims=True) + eps))\n",
    "    return jnp.concatenate((head, gate * tail), axis=1)\n",
    "\n",
    "\n",
    "key = jax.random.PRNGKey(0)\n",
    "x = jax.random.normal(key, (10, 3))\n",
    "cosilu(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83ae1971-a956-465a-8b15-b115d0edd972",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
