{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": 2,
            "metadata": {},
            "outputs": [],
            "source": [
                "import numpy as np\n",
                "import pandas as pd\n",
                "import torch\n",
                "\n",
                "src_dir = \"../new_dir/DATA\"\n",
                "tgt_dir = \"../tgnnexplainer/xgraph/models/ext/tgat/processed\"\n",
                "dataset = \"MOOC\"\n",
                "\n",
                "# Load the raw inputs here: a later cell reads `edges` and `edge_features`,\n",
                "# so both names must exist after Restart Kernel -> Run All.\n",
                "edges = pd.read_csv(f\"{src_dir}/{dataset}/edges.csv\", index_col=0)\n",
                "# torch.load may unpickle arbitrary objects; only load files you trust.\n",
                "edge_features = torch.load(f\"{src_dir}/{dataset}/edge_features.pt\")\n",
                "\n",
                "# Optional shape check of the int_train.npz arrays (disabled).\n",
                "# NOTE(review): this path does not go through `src_dir` - confirm before enabling.\n",
                "# ext_full = np.load(f\"./new_dir/DATA/{dataset}/int_train.npz\")\n",
                "\n",
                "# for i in ['indptr', 'indices', 'ts', 'eid']:\n",
                "#     print(ext_full[i].shape)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 12,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Shorten the explain indices: keep every 5th row of each explain-index CSV.\n",
                "# WARNING: the files are rewritten in place, so re-running this cell shrinks\n",
                "# them again (to roughly a fifth of their current length each time).\n",
                "explain_index_dir = \"../tgnnexplainer/xgraph/dataset/explain_index\"\n",
                "datasets = (\"simulate_v1\", \"simulate_v2\", \"reddit\", \"wikipedia\")\n",
                "# The loop variable is intentionally NOT named `dataset`: that name holds the\n",
                "# dataset being converted in the other cells and must not be overwritten here.\n",
                "for explain_dataset in datasets:\n",
                "    index_path = f\"{explain_index_dir}/{explain_dataset}.csv\"\n",
                "    explain_index = pd.read_csv(index_path)\n",
                "    explain_index[::5].to_csv(index_path, index=False)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": 15,
            "metadata": {},
            "outputs": [
                {
                    "data": {
                        "text/plain": [
                            "array([[ 0.49277036, -0.69036051,  1.34358778,  1.54337125],\n",
                            "       [-0.47879322, -1.11126717, -0.46390797,  1.23420109],\n",
                            "       [ 0.55225519, -0.70049968,  1.12268004, -2.12257938],\n",
                            "       ...,\n",
                            "       [ 0.02733296,  0.98931901,  0.19821423, -1.06424055],\n",
                            "       [ 0.38313002,  1.13750433, -1.96857973,  0.77038375],\n",
                            "       [-0.12718611, -2.66271972,  0.30175013, -0.70065818]])"
                        ]
                    },
                    "metadata": {},
                    "output_type": "display_data"
                }
            ],
            "source": [
                "def prepare_data_mooc(df: pd.DataFrame):\n",
                "    \"\"\"Convert the raw MOOC edge table to ``u``/``i``/``ts``/``label`` columns plus row ids.\n",
                "\n",
                "    Renames ``src``/``dst``/``time``/``ext_roll`` to ``u``/``i``/``ts``/``label``,\n",
                "    subtracts 1 from ``label``, casts ``u`` and ``i`` to ``int``, appends the\n",
                "    0-based row counters ``idx`` and ``e_idx`` as the last two columns and\n",
                "    drops ``int_roll``.\n",
                "\n",
                "    Args:\n",
                "        df: Edge table containing ``src``, ``dst``, ``time``, ``ext_roll`` and\n",
                "            ``int_roll`` columns. The input frame is left untouched.\n",
                "\n",
                "    Returns:\n",
                "        A new DataFrame with the converted columns.\n",
                "    \"\"\"\n",
                "    column_map = {\"src\": \"u\", \"dst\": \"i\", \"time\": \"ts\", \"ext_roll\": \"label\"}\n",
                "    events = df.rename(columns=column_map)\n",
                "    events[\"label\"] = events[\"label\"] - 1\n",
                "\n",
                "    for endpoint in (\"u\", \"i\"):\n",
                "        events[endpoint] = events[endpoint].astype(int)\n",
                "\n",
                "    # NOTE(review): both id columns start at 0 - confirm that is what the consumer expects.\n",
                "    n_events = len(events)\n",
                "    for id_column in (\"idx\", \"e_idx\"):\n",
                "        # insert() (rather than plain assignment) raises if the column already exists.\n",
                "        events.insert(events.shape[1], id_column, range(n_events))\n",
                "\n",
                "    return events.drop(columns=\"int_roll\")\n",
                "\n",
                "\n",
                "# Convert the raw edge table and write it as CSV into the target directory.\n",
                "ml_mooc = prepare_data_mooc(edges)\n",
                "ml_mooc.to_csv(path_or_buf=f\"{tgt_dir}/ml_{dataset}.csv\", index=False)\n",
                "\n",
                "\n",
                "# Disabled alternative: random 4-column features instead of the loaded ones.\n",
                "# ml_features = np.random.randn(len(ml_mooc), 4)\n",
                "# np.save(f\"{tgt_dir}/ml_{dataset}_node.npy\", ml_features)\n",
                "\n",
                "# ml_features = np.random.randn(len(ml_mooc), 4)\n",
                "# np.save(f\"{tgt_dir}/ml_{dataset}.npy\", ml_features)\n",
                "# ml_features\n",
                "\n",
                "# Store the loaded edge features as a NumPy array.\n",
                "np.save(f\"{tgt_dir}/ml_{dataset}.npy\", edge_features.numpy())"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "mlkit",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.5"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
