{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def apply_perm(perm, x, axis):\n",
    "#     assert perm.shape[0] == perm.shape[1]\n",
    "#     assert x.shape[axis] == perm.shape[0]\n",
    "\n",
    "#     # Bring the specified axis to the front\n",
    "#     x = x.moveaxis(axis, 0)\n",
    "\n",
    "#     # Store the original shape and reshape for matrix multiplication\n",
    "#     original_shape = x.shape\n",
    "#     x = x.reshape(x.shape[0], -1)\n",
    "\n",
    "#     # Apply the permutation\n",
    "#     x_permuted = perm @ x\n",
    "\n",
    "#     # Reshape back to the expanded original shape\n",
    "#     x_permuted = x_permuted.reshape(original_shape)\n",
    "\n",
    "#     # Move the axis back to its original position\n",
    "#     x_permuted = x_permuted.moveaxis(0, axis)\n",
    "\n",
    "#     return x_permuted\n",
    "\n",
    "\n",
    "def apply_perm(X, P, axis):\n",
    "    \"\"\"\n",
    "    Reorder the entries of a tensor along one axis.\n",
    "\n",
    "    For a 1-D index tensor ``P``, entry ``i`` of the result along ``axis`` is\n",
    "    entry ``P[i]`` of ``X`` along that axis; all other axes are left untouched.\n",
    "\n",
    "    Parameters:\n",
    "    X (torch.Tensor): The input tensor (at least 1-D).\n",
    "    P (list or torch.Tensor): The permutation to be applied, given as indices.\n",
    "        Anything that is not already a tensor is converted with ``torch.tensor``.\n",
    "    axis (int): The axis along which to permute. Negative values count from\n",
    "        the last axis.\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: The permuted tensor.\n",
    "\n",
    "    Raises:\n",
    "    ValueError: If ``axis`` is outside ``[-X.dim(), X.dim())``.\n",
    "    \"\"\"\n",
    "    # Ensure P is a torch.Tensor\n",
    "    if not isinstance(P, torch.Tensor):\n",
    "        P = torch.tensor(P)\n",
    "\n",
    "    # Check if the axis is valid for the tensor dimensions\n",
    "    ndim = X.dim()\n",
    "    if not -ndim <= axis < ndim:\n",
    "        raise ValueError(\"Axis is out of bounds for the tensor dimensions\")\n",
    "\n",
    "    # Full slices on every axis except the requested one, which receives the\n",
    "    # permutation (list indexing resolves a negative axis on its own).\n",
    "    idx = [slice(None)] * ndim\n",
    "    idx[axis] = P\n",
    "\n",
    "    # Index with a tuple: passing a list for multidimensional indexing is\n",
    "    # deprecated and may be interpreted differently in the future.\n",
    "    return X[tuple(idx)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ccmm.matching.utils import perm_indices_to_perm_matrix\n",
    "\n",
    "\n",
    "# Fix the RNG seed so torch.randperm yields the same permutation on every\n",
    "# Restart & Run All.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "n = 3\n",
    "d = 5\n",
    "\n",
    "# 2-D test tensor holding the values 0 .. n*d - 1 as floats.\n",
    "X = torch.arange(n * d).reshape(n, d).float()\n",
    "\n",
    "# Random permutation of the d column indices, converted by the project helper\n",
    "# (presumably into a permutation matrix -- see ccmm.matching.utils).\n",
    "P_indices = torch.randperm(d)\n",
    "P = perm_indices_to_perm_matrix(P_indices).float()\n",
    "\n",
    "# Reference result obtained with a plain matrix product.\n",
    "gt_res = X @ P.T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: X was built with reshape(n, d), so right after the setup cell\n",
    "# this displays torch.Size([3, 5]).\n",
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the matrix returned by perm_indices_to_perm_matrix; the product\n",
    "# X @ P.T requires P to have d columns.\n",
    "P.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference result computed with an explicit matrix product; the\n",
    "# indexing-based apply_perm output is compared against it further down.\n",
    "gt_res = X @ P.T\n",
    "gt_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# P_indices comes from torch.randperm, i.e. it is already a 1-D int64 tensor,\n",
    "# so it is passed as-is: `.T` on a non-2-D tensor is deprecated in PyTorch and\n",
    "# `.long()` was a no-op here.\n",
    "func_res = apply_perm(X=X, P=P_indices, axis=1)\n",
    "func_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.equal also requires identical shapes, whereas `==` would silently\n",
    "# broadcast and could hide a shape mismatch.\n",
    "assert torch.equal(gt_res, func_res), \"apply_perm disagrees with the matrix-product reference\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def perm_rows(x, perm):\n",
    "    \"\"\"\n",
    "    Left-multiply ``x`` by ``perm`` along its first axis.\n",
    "\n",
    "    Computes ``out[i, ...] = sum_j perm[i, j] * x[j, ...]``. When ``perm`` is a\n",
    "    permutation matrix this reorders the slices of ``x`` along axis 0.\n",
    "\n",
    "    x ~ (n,) or (n, d0) or (n, d0, d1, ...) with any number of trailing axes\n",
    "    perm ~ (n, n)\n",
    "\n",
    "    Returns a tensor with the same shape as ``x``.\n",
    "    Raises AssertionError if ``perm`` is not square or its size differs from\n",
    "    ``x.shape[0]``.\n",
    "    \"\"\"\n",
    "    assert x.shape[0] == perm.shape[0]\n",
    "    assert perm.shape[0] == perm.shape[1]\n",
    "\n",
    "    # The ellipsis stands for every trailing axis of x, so the same equation\n",
    "    # serves tensors of any rank instead of a subscript string capped at 4-D.\n",
    "    return torch.einsum(\"ij,j...->i...\", perm, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes of the three axes of the test tensor.\n",
    "n, d, e = 3, 5, 4\n",
    "\n",
    "# Values 0 .. n*d*e - 1 as floats, laid out as an (n, d, e) tensor.\n",
    "X = torch.arange(n * d * e).float().reshape(n, d, e)\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Permutation of the leading axis of X: its length is taken from X instead of\n",
    "# a hard-coded 3, and the indices keep their own name rather than reusing P.\n",
    "row_perm_indices = torch.randperm(X.shape[0])\n",
    "print(row_perm_indices)\n",
    "# Cast to float like the earlier 2-D setup cell, so the matrix dtype matches X\n",
    "# when both are fed to torch.einsum in perm_rows.\n",
    "P = perm_indices_to_perm_matrix(row_perm_indices).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply P along the leading axis of the 3-D tensor X; perm_rows asserts that\n",
    "# P is square with side X.shape[0].\n",
    "perm_rows(X, P)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ccmm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
