{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc1e8a7a-129a-4093-8162-d8de2ca1d6bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"../loader\")\n",
    "sys.path.append(\"../config\")\n",
    "sys.path.append(\"../methods/emmix\")\n",
    "sys.path.append(\"real_data/\")\n",
    "sys.path.append(\"../\")\n",
    "import dataset_info\n",
    "import config_path\n",
    "import reader\n",
    "import sp_tensor\n",
    "import utils as ut\n",
    "import utils_exp as ue\n",
    "\n",
    "import importlib\n",
    "importlib.reload(dataset_info)\n",
    "\n",
    "def select_k_numbers_unique(n, k):\n",
    "    \"\"\"Pick up to k roughly evenly spaced integers from 1..n (inclusive).\n",
    "\n",
    "    The k points 1, 1 + step, ..., n with step = (n - 1) / (k - 1) are\n",
    "    rounded with round() (ties to even) and de-duplicated, so fewer than k\n",
    "    values come back when k > n (for n >= 1). k == 1 gives [1]; k <= 0 gives [].\n",
    "\n",
    "    Returns:\n",
    "        Sorted list of distinct ints.\n",
    "    \"\"\"\n",
    "    # One pick (or none) needs no spacing; this also avoids dividing by zero.\n",
    "    if k > 1:\n",
    "        spacing = (n - 1) / (k - 1)\n",
    "    else:\n",
    "        spacing = 0\n",
    "    picks = {int(round(1 + idx * spacing)) for idx in range(k)}\n",
    "    return sorted(picks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e42756a-44d1-4a90-884d-6c75aa8ece4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from the notebook's parent directory so that the relative\n",
    "# \"config/ranks/...\" paths used by run() resolve. The notebook directory is\n",
    "# captured on the first run only, so re-running this cell no longer climbs\n",
    "# one more directory level each time.\n",
    "if \"NOTEBOOK_DIR\" not in globals():\n",
    "    NOTEBOOK_DIR = os.getcwd()\n",
    "print(NOTEBOOK_DIR)\n",
    "os.chdir(os.path.join(NOTEBOOK_DIR, \"..\"))\n",
    "print(os.getcwd())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac825d50-dc97-4842-8fca-b57b0606dd8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Datasets for which run() below computes rank grids.\n",
    "dataset_info.real_datasets_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "113892ab-18a6-4936-91c6-22c203dc8da9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_ranks(dataset_name, strct, div_number=2, n_rank=8):\n",
    "    \"\"\"Return (ranks, params_list) for one dataset and low-rank structure.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: key into dataset_info, or \"Jesture<N>\" for a Jesture split.\n",
    "        strct: \"Tucker\", \"CP\" or \"train\".\n",
    "        div_number, n_rank: forwarded to get_cp_or_train_rank; unused for \"Tucker\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if strct is not one of the supported structures.\n",
    "    \"\"\"\n",
    "    if strct == \"Tucker\":\n",
    "        return get_tucker_rank(dataset_name)\n",
    "    if strct in (\"CP\", \"train\"):\n",
    "        return get_cp_or_train_rank(dataset_name, strct, div_number=div_number, n_rank=n_rank)\n",
    "    # Strings are not exceptions in Python 3 (raising one is a TypeError).\n",
    "    raise ValueError(f\"Invalid low-rank structure: {strct!r}\")\n",
    "\n",
    "\n",
    "def get_tucker_rank(dataset_name):\n",
    "    \"\"\"Return (R_tucker, params_list) for the Tucker model of `dataset_name`.\n",
    "\n",
    "    The candidate ranks depend on the number of tensor modes D: [2, 3, 4] if\n",
    "    D < 6, [2, 3] if D < 8, [2] if D < 10. For D >= 10 no Tucker decomposition\n",
    "    is run and [0] is returned as a placeholder rank list. Each candidate r is\n",
    "    used for all D modes, and params_list[i] is ut.tucker_n_params for\n",
    "    R_tucker[i].\n",
    "    \"\"\"\n",
    "    if \"Jesture\" in dataset_name:\n",
    "        shape = dataset_info.tensor_sizes_rec[\"Jesture\"]\n",
    "        n_modes = len(shape)\n",
    "    else:\n",
    "        n_modes = dataset_info.tensor_dims[dataset_name]\n",
    "        shape = dataset_info.tensor_sizes[dataset_name]\n",
    "\n",
    "    # (exclusive upper bound on D, candidate ranks); the first match wins.\n",
    "    for mode_limit, candidates in ((6, [2, 3, 4]), (8, [2, 3]), (10, [2])):\n",
    "        if n_modes < mode_limit:\n",
    "            R_tucker = candidates\n",
    "            break\n",
    "    else:\n",
    "        # Too many modes: Tucker is skipped and 0 marks the missing rank.\n",
    "        R_tucker = [0]\n",
    "\n",
    "    params_list = [ut.tucker_n_params(shape, [r] * n_modes) for r in R_tucker]\n",
    "    return R_tucker, params_list\n",
    "\n",
    "def _cp_or_train_n_params(strct, tensor_size, rank):\n",
    "    \"\"\"Parameter count of a \"CP\" or \"train\" model that uses `rank` everywhere.\"\"\"\n",
    "    if strct == \"CP\":\n",
    "        return ut.cp_n_params(tensor_size, rank)\n",
    "    # The train model is given len(tensor_size) - 1 ranks, all equal to `rank`.\n",
    "    return ut.train_n_params(tensor_size, [rank] * (len(tensor_size) - 1))\n",
    "\n",
    "\n",
    "def get_cp_or_train_rank(dataset_name, strct, div_number=2, n_rank=8):\n",
    "    \"\"\"Return (ranks, params_list): up to `n_rank` candidate ranks for \"CP\" or \"train\".\n",
    "\n",
    "    The parameter budget is int(n_samples / div_number), where n_samples is the\n",
    "    number of rows in the loaded data's `.values` (the tensor's nonzeros). With\n",
    "    div_number == 1 the largest model has about as many parameters as nonzeros;\n",
    "    with 2, about half as many. The rank grows from 1 until the parameter count\n",
    "    reaches the budget, then select_k_numbers_unique picks up to `n_rank` roughly\n",
    "    evenly spaced ranks in 1..max_rank. params_list[i] is the parameter count\n",
    "    for ranks[i].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if strct is not \"CP\" or \"train\".\n",
    "    \"\"\"\n",
    "    if strct not in (\"CP\", \"train\"):\n",
    "        # Validate before loading any data; raising a bare string is a TypeError.\n",
    "        raise ValueError(f\"Invalid low-rank structure: {strct!r}\")\n",
    "\n",
    "    if \"Jesture\" in dataset_name:\n",
    "        N = dataset_name.split(\"Jesture\")[1]\n",
    "        tensor_size = dataset_info.tensor_sizes_rec[\"Jesture\"]\n",
    "        n_samples = len(reader.load_data_jes(N=N).values)\n",
    "    else:\n",
    "        tensor_size = dataset_info.tensor_sizes[dataset_name]\n",
    "        n_samples = len(reader.load_data_real(dataset_name).values)\n",
    "\n",
    "    max_n_params = int(n_samples / div_number)\n",
    "\n",
    "    max_rank = 1\n",
    "    n_param = 0\n",
    "    while n_param < max_n_params:\n",
    "        n_param = _cp_or_train_n_params(strct, tensor_size, max_rank)\n",
    "        max_rank += 1\n",
    "    # NOTE(review): max_rank is incremented after the budget check, so it ends one\n",
    "    # above the first rank whose parameter count reaches the budget (and stays 1\n",
    "    # when the budget is 0). Kept unchanged so the saved rank grids do not shift;\n",
    "    # confirm this is intended.\n",
    "\n",
    "    ranks = select_k_numbers_unique(max_rank, n_rank)\n",
    "    params_list = [_cp_or_train_n_params(strct, tensor_size, r) for r in ranks]\n",
    "    return ranks, params_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4cc5614d-49b6-4c9d-8528-e15e629c9b79",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Slot order of the structures inside every rank vector: [r_cp, r_tucker, r_train].\n",
    "_STRUCTS = (\"CP\", \"Tucker\", \"train\")\n",
    "\n",
    "\n",
    "def _combine(combo, ranks, params, datasets_list, rank_limit=None):\n",
    "    \"\"\"Build the {\"params\": ..., \"ranks\": ...} dict for one pure or mixture model.\n",
    "\n",
    "    For each dataset, every combination of candidate ranks of the structures in\n",
    "    `combo` is enumerated (the first `rank_limit` candidates per structure, or all\n",
    "    of them when rank_limit is None). `combo` follows _STRUCTS order, so its first\n",
    "    structure varies slowest. Each combination becomes a [r_cp, r_tucker, r_train]\n",
    "    vector with 0 for structures not in `combo`; its parameter count is the sum of\n",
    "    the per-structure counts, so the \"ranks\" and \"params\" lists stay aligned.\n",
    "    \"\"\"\n",
    "    import itertools\n",
    "\n",
    "    ranks_out, params_out = {}, {}\n",
    "    for name in datasets_list:\n",
    "        rank_lists = [ranks[s][name][:rank_limit] for s in combo]\n",
    "        param_lists = [params[s][name][:rank_limit] for s in combo]\n",
    "        ranks_out[name] = [\n",
    "            [dict(zip(combo, picked)).get(s, 0) for s in _STRUCTS]\n",
    "            for picked in itertools.product(*rank_lists)\n",
    "        ]\n",
    "        params_out[name] = [sum(picked) for picked in itertools.product(*param_lists)]\n",
    "    return {\"params\": params_out, \"ranks\": ranks_out}\n",
    "\n",
    "\n",
    "def run(include_jesture=False):\n",
    "    \"\"\"Compute candidate ranks for every dataset and pickle them under config/ranks/.\n",
    "\n",
    "    Seven pickles are written: CP, Tucker, train and their four mixtures. Each\n",
    "    holds {\"params\": {dataset: [n_params, ...]},\n",
    "           \"ranks\": {dataset: [[r_cp, r_tucker, r_train], ...]}}\n",
    "    with the two per-dataset lists aligned element-wise. Pure models use every\n",
    "    candidate rank; mixtures combine the first `rank_limit` (5) candidates of\n",
    "    each component. Paths are relative to the current working directory.\n",
    "\n",
    "    Args:\n",
    "        include_jesture: also process the Jesture datasets matched by\n",
    "            glob(config_path.data_repo_jes + \"*0\"). Off by default.\n",
    "\n",
    "    Read a result back with ue.pickle_load(\"config/ranks/ranks_cp.pkl\").\n",
    "    \"\"\"\n",
    "    # dict.fromkeys drops duplicate names but keeps their order.\n",
    "    datasets_list = list(dict.fromkeys(dataset_info.real_datasets_list))\n",
    "    if include_jesture:\n",
    "        jes_paths = glob.glob(config_path.data_repo_jes + \"*0\")\n",
    "        datasets_list += [f\"Jesture{os.path.basename(path)}\" for path in jes_paths]\n",
    "\n",
    "    # get_ranks loads the dataset for CP/train, so call it once per\n",
    "    # (dataset, structure) instead of separately for ranks and parameter counts.\n",
    "    ranks = {s: {} for s in _STRUCTS}\n",
    "    params = {s: {} for s in _STRUCTS}\n",
    "    for name in datasets_list:\n",
    "        for s in _STRUCTS:\n",
    "            ranks[s][name], params[s][name] = get_ranks(name, s)\n",
    "\n",
    "    rank_limit = 5  # candidates per component used by the mixture models\n",
    "\n",
    "    # (output file, structures in the model, per-structure candidate limit).\n",
    "    # _combine derives each model's params from the same structures as its ranks;\n",
    "    # the previous hand-written version paired Tucker+train ranks with CP+Tucker\n",
    "    # params and summed Tucker ranks (not params) in the three-way mixture.\n",
    "    families = [\n",
    "        (\"config/ranks/ranks_cp.pkl\", (\"CP\",), None),\n",
    "        (\"config/ranks/ranks_tucker.pkl\", (\"Tucker\",), None),\n",
    "        (\"config/ranks/ranks_train.pkl\", (\"train\",), None),\n",
    "        (\"config/ranks/ranks_cptrain.pkl\", (\"CP\", \"train\"), rank_limit),\n",
    "        (\"config/ranks/ranks_cptucker.pkl\", (\"CP\", \"Tucker\"), rank_limit),\n",
    "        (\"config/ranks/ranks_tuckertrain.pkl\", (\"Tucker\", \"train\"), rank_limit),\n",
    "        (\"config/ranks/ranks_cptuckertrain.pkl\", (\"CP\", \"Tucker\", \"train\"), rank_limit),\n",
    "    ]\n",
    "    for path, combo, limit in families:\n",
    "        ue.pickle_dump(_combine(combo, ranks, params, datasets_list, rank_limit=limit), path)\n",
    "\n",
    "    print(\"all ranks are saved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3587b01-a8c1-4a0e-b5a7-1c656c1c76ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writes the seven rank/parameter pickles to config/ranks/, relative to the current working directory.\n",
    "run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c81eb2b4-8cd4-4d49-b5be-f1992ae8461b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "## Example: read a saved rank grid back. \"ranks\" maps each dataset name to a list of\n",
    "## [r_cp, r_tucker, r_train] vectors (CP-only here, so the last two entries are 0).\n",
    "## NOTE(review): ue.pickle_load presumably unpickles; only load files this notebook wrote.\n",
    "loaded = ue.pickle_load(\"config/ranks/ranks_cp.pkl\")\n",
    "loaded[\"ranks\"]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
