{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "from environments import *\n",
    "\n",
    "from copy import deepcopy\n",
    "import json\n",
    "\n",
    "\n",
    "def change_config_value(config, key, value):\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        key (str): algorithms.gamma のような形式で与えられる\n",
    "        value (any): 変更する値\n",
    "    \"\"\"\n",
    "    config = deepcopy(config)\n",
    "    keys = key.split(\".\")\n",
    "    current = config\n",
    "    for k in keys[:-1]:\n",
    "        current = current[k]\n",
    "    if keys[-1] not in current:\n",
    "        raise ValueError(f\"{key} does not exist.\")\n",
    "    if isinstance(current[keys[-1]], dict):\n",
    "        raise ValueError(f\"{key} is not a leaf node.\")\n",
    "    current[keys[-1]] = value\n",
    "    return config\n",
    "\n",
    "\n",
    "def variate_config(config, dict_key_values):\n",
    "    \"\"\"辞書形式で与えられる変更の全ての組み合わせを生成する。\n",
    "    例えば、 dict_key_values = {\"algorithms.gamma\": [0.1, 0.2], \"algorithms.c1\": [5000, 10000]} とした場合、\n",
    "    {\"algorithms.gamma\": 0.1, \"algorithms.c1\": 5000}, {\"algorithms.gamma\": 0.1, \"algorithms.c1\": 10000},\n",
    "    {\"algorithms.gamma\": 0.2, \"algorithms.c1\": 5000}, {\"algorithms.gamma\": 0.2, \"algorithms.c1\": 10000}\n",
    "    のような組み合わせを生成する。\n",
    "\n",
    "    Args:\n",
    "        dict_key_values (dict): key が \"algorithms.gamma\" のような形式で与えられる。 values は変更する値のリスト。\n",
    "    \"\"\"\n",
    "    configs = [deepcopy(config)]\n",
    "    for key, values in dict_key_values.items():\n",
    "        new_configs = []\n",
    "        for value in values:\n",
    "            new_configs.extend([change_config_value(c, key, value) for c in configs])\n",
    "        configs = new_configs\n",
    "    return configs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "objective_name:\t MatrixFactorization\n",
      "algorithm_name:\t RSHTR\n",
      "====================================\n",
      "# of configs:\t 2\n",
      "config/config_20240930175809_MatrixFactorization_RSHTR_0.json\n",
      "config/config_20240930175809_MatrixFactorization_RSHTR_1.json\n"
     ]
    }
   ],
   "source": [
    "# objective_name = ROSENBROCK_RANKDEFICIENT\n",
    "objective_name = MATRIXFACTORIZATION\n",
    "# objective_name = MATRIXFACTORIZATION_COMPLETION\n",
    "# objective_name = LOGISTIC\n",
    "# objective_name = SOFTMAX\n",
    "# objective_name = MLPNET\n",
    "# objective_name = MLPNET_ELU\n",
    "\n",
    "\n",
    "# algorithm_name = GRADIENT_DESCENT\n",
    "# algorithm_name = TRUST_REGION\n",
    "# algorithm_name = CUBIC_REGULARIZED_NEWTON\n",
    "# algorithm_name = SUBSPACE_GRADIENT_DESCENT\n",
    "# algorithm_name = SUBSPACE_REGULARIZED_NEWTON\n",
    "algorithm_name = SUBSPACE_TRUST_REGION\n",
    "# algorithm_name = KRYLOV_CUBIC_REGULARIZED_NEWTON\n",
    "\n",
    "\n",
    "dict_key_values = {\n",
    "    \"iteration\": [100000],\n",
    "    \"log_interval\": [10],\n",
    "    # \"max_time\": [30],\n",
    "    \"max_time\": [600],\n",
    "    # \"max_time\": [4000],\n",
    "    # \"max_time\": [48000],\n",
    "}\n",
    "\n",
    "\n",
    "objective_dict_key_values = {\n",
    "    ROSENBROCK_RANKDEFICIENT: {\n",
    "        \"dim\": [10000],\n",
    "        # \"rank\": [50],\n",
    "        # \"rank\": [25, 100, 1000, 10000],\n",
    "        \"rank\": [25, 50, 100, 150],\n",
    "        # \"rank\": [150],\n",
    "        \"matrix_seed\": [1],\n",
    "    },\n",
    "    MATRIXFACTORIZATION: {\n",
    "        \"data_name\": [\"movie\"],\n",
    "        \"rank\": [50],\n",
    "    },\n",
    "    MATRIXFACTORIZATION_COMPLETION: {\n",
    "        \"data_name\": [\"movie\"],\n",
    "        \"rank\": [50],\n",
    "    },\n",
    "    LOGISTIC: {\n",
    "        # \"data_name\": [\"rcv1\", \"news20\", \"ad\"],\n",
    "        \"data_name\": [\"rcv1\", \"news20\"],\n",
    "        \"data_size\": [1000],\n",
    "        \"dim\": [10000],\n",
    "        \"bias\": [True],\n",
    "    },\n",
    "    SOFTMAX: {\n",
    "        \"objective_name\": [\"Softmax\"],\n",
    "        # \"data_name\": [\"scotus\", \"news20\", \"news20_tfidf\"],\n",
    "        # \"data_name\": [\"scotus\", \"news20_tfidf\"],\n",
    "        \"data_name\": [\"scotus\"],\n",
    "        \"data_size\": [None],\n",
    "        \"batch_size\": [1],\n",
    "        \"random_seed\": [0],\n",
    "        \"dim\": [10000],\n",
    "    },\n",
    "    MLPNET: {\n",
    "        \"objective_name\": [\"MLPNET\"],\n",
    "        # \"data_name\": [\"mnist\"],\n",
    "        \"data_name\": [\"cifar10\"],\n",
    "        \"data_size\": [\"large\"],\n",
    "        \"batch_size\": [1, 32, 1024],\n",
    "        \"random_seed\": [1, 2, 3, 4],\n",
    "        # \"random_seed\": [0],\n",
    "        \"activation\": [\"elu\"],\n",
    "        \"init_param\": [\"init\"],\n",
    "        \"param_noise_sigma\": [0.05],\n",
    "        \"param_noise_seed\": [0],\n",
    "        \"criterion\": [\"CrossEntropy\"],\n",
    "        \"layers_size\": [\n",
    "            # [784, 128, 64, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 10],\n",
    "            [3072, 128, 64, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 10],\n",
    "        ],\n",
    "        \"bias\": [True],\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "algorithm_dict_key_values = {\n",
    "    GRADIENT_DESCENT: {\n",
    "        \"eps\": [1e-8],\n",
    "        \"lr\": [0],\n",
    "        \"backward\": [\"DD\"],\n",
    "        \"linesearch\": [True],\n",
    "    },\n",
    "    TRUST_REGION: {\n",
    "        \"eig_solver\": [\"lanczos_hvp\"],\n",
    "        \"delta\": [1e-3],\n",
    "        \"Delta\": [1e-3],\n",
    "        \"nu\": [0.1],\n",
    "        \"eps\": [1e-8],\n",
    "        \"alpha\": [0.3],\n",
    "        \"beta\": [0.8],\n",
    "        \"backward\": [\"DD\"],\n",
    "        \"lanczos_precision\": [\"single\"],\n",
    "    },\n",
    "    CUBIC_REGULARIZED_NEWTON: {\n",
    "        \"reg_coef\": [1e-3],\n",
    "        \"solver_eps\": [1e-8],\n",
    "        \"eps\": [1e-6],\n",
    "        \"beta\": [0.5],\n",
    "        \"backward\": [\"DD\"],\n",
    "    },\n",
    "    SUBSPACE_GRADIENT_DESCENT: {\n",
    "        \"reduced_dim\": [100],\n",
    "        # \"reduced_dim\": [50, 200],\n",
    "        \"mode\": [\"random\"],\n",
    "        \"eps\": [1e-6],\n",
    "        \"lr\": [0],\n",
    "        \"backward\": [\"DD\"],\n",
    "        \"linesearch\": [True],\n",
    "        # \"random_matrix_seed\": [0, 1, 2, 3, 4, 5, 6, 7],\n",
    "        # \"random_matrix_seed\": [0],\n",
    "        \"random_matrix_seed\": [1, 2],\n",
    "    },\n",
    "    SUBSPACE_REGULARIZED_NEWTON: {\n",
    "        \"subspace_hessian_func\": [\"loop\"],\n",
    "        \"reduced_dim\": [100],\n",
    "        # \"reduced_dim\": [50, 200],\n",
    "        \"gamma\": [0.5],\n",
    "        \"c1\": [2],\n",
    "        \"c2\": [1],\n",
    "        \"eps\": [1e-6],\n",
    "        \"alpha\": [0.3],\n",
    "        \"beta\": [0.8],\n",
    "        \"backward\": [\"DD\"],\n",
    "        # \"random_matrix_seed\": [0, 1, 2, 3, 4, 5, 6, 7],\n",
    "        # \"random_matrix_seed\": [0],\n",
    "        \"random_matrix_seed\": [1, 2],\n",
    "    },\n",
    "    SUBSPACE_TRUST_REGION: {\n",
    "        \"subspace_hessian_func\": [\"loop\"],\n",
    "        \"eig_solver\": [\"lanczos_hvp\"],\n",
    "        \"reduced_dim\": [100],\n",
    "        # \"reduced_dim\": [50, 200],\n",
    "        \"delta\": [1e-3],\n",
    "        \"Delta\": [1e-3],\n",
    "        \"nu\": [0.1],\n",
    "        \"eps\": [1e-6],\n",
    "        \"alpha\": [0.3],\n",
    "        \"beta\": [0.8],\n",
    "        \"backward\": [\"DD\"],\n",
    "        \"lanczos_precision\": [\"single\"],\n",
    "        # \"random_matrix_seed\": [1, 2, 3, 4],\n",
    "        # \"random_matrix_seed\": [0],\n",
    "        \"random_matrix_seed\": [1, 2],\n",
    "    },\n",
    "    KRYLOV_CUBIC_REGULARIZED_NEWTON: {\n",
    "        \"reg_coef\": [1e-3],\n",
    "        \"reduced_dim\": [100],\n",
    "        \"solver_eps\": [1e-8],\n",
    "        \"eps\": [1e-6],\n",
    "        \"beta\": [0.5],\n",
    "        \"backward\": [\"DD\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "\n",
    "##########################\n",
    "# ここから先はいじらない #\n",
    "##########################\n",
    "\n",
    "dict_key_values.update({\n",
    "    \"objective.\" + k: v for k, v in objective_dict_key_values[objective_name].items()\n",
    "})\n",
    "dict_key_values.update({\n",
    "    \"algorithms.\" + k: v for k, v in algorithm_dict_key_values[algorithm_name].items()\n",
    "})\n",
    "\n",
    "config = {\n",
    "    \"objective\": {\n",
    "        \"objective_name\": objective_name,\n",
    "        **{\n",
    "            k: v[0] for k, v in objective_dict_key_values[objective_name].items()\n",
    "        }\n",
    "    },\n",
    "    \"algorithms\": {\n",
    "        \"solver_name\": algorithm_name,\n",
    "        **{\n",
    "            k: v[0] for k, v in algorithm_dict_key_values[algorithm_name].items()\n",
    "        }\n",
    "    },\n",
    "    \"constraints\": {\n",
    "        \"constraints_name\": \"NoConstraint\",\n",
    "        \"constraints_num\": 0\n",
    "    },\n",
    "    \"iteration\": 1000,\n",
    "    \"log_interval\": 10,\n",
    "    \"max_time\": 1800\n",
    "}\n",
    "\n",
    "\n",
    "configs = variate_config(config, dict_key_values)\n",
    "print(\"objective_name:\\t\", objective_name)\n",
    "print(\"algorithm_name:\\t\", algorithm_name)\n",
    "print(\"====================================\")\n",
    "print(\"# of configs:\\t\", len(configs))\n",
    "\n",
    "from pathlib import Path\n",
    "from datetime import datetime\n",
    "\n",
    "save_dir = Path(\"config/\")\n",
    "save_dir.mkdir(exist_ok=True)\n",
    "# for f in save_dir.glob(\"*\"):\n",
    "#     f.unlink()\n",
    "timestamp = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
    "\n",
    "\n",
    "for i, c in enumerate(configs):\n",
    "    save_path = str(save_dir / f\"config_{timestamp}_{objective_name}_{algorithm_name}_{i}.json\")\n",
    "    with open(save_path, \"w\") as f:\n",
    "        json.dump(c, f, indent=4)\n",
    "        print(save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
