{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "\n",
    "class SyntheticDatasetGenerator:\n",
    "    \"\"\"Generate a synthetic multivariate time-series classification dataset.\n",
    "\n",
    "    Each sample is an (n_points, n_feature) signal. Every feature is a base sine\n",
    "    wave of amplitude 0.5 (frequency drawn from config[\"f_base\"]) with a support\n",
    "    pattern added on a random window of n_support points. The first two features\n",
    "    of a random column permutation get a sine pattern; the others get a square or\n",
    "    flat (\"line\") pattern. The binary target is 1 when the sum of the two sine\n",
    "    frequencies exceeds the config[\"quantile_class\"][0] quantile of that sum's\n",
    "    distribution, else 0.\n",
    "\n",
    "    Args:\n",
    "        config_path: path to the JSON config file.\n",
    "        data_root: directory under which the config's \"save_name\" folder is created.\n",
    "        seed: optional RNG seed; when None the global np.random state is used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config_path, data_root=\"data\", seed=None):\n",
    "        self.config = self.load_config(config_path)\n",
    "        self._validate_config()\n",
    "        # Without a seed keep the global legacy RNG, so np.random.seed(...) still applies.\n",
    "        self.rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "        # The class boundary depends only on the config: compute it once, not per sample.\n",
    "        self.class_threshold = self._compute_class_threshold()\n",
    "        self.save_path = os.path.join(data_root, self.config[\"save_name\"])\n",
    "        os.makedirs(self.save_path, exist_ok=True)\n",
    "\n",
    "    def load_config(self, config_path):\n",
    "        \"\"\"Load and return the JSON configuration stored at config_path.\"\"\"\n",
    "        with open(config_path, 'r') as f:\n",
    "            return json.load(f)\n",
    "\n",
    "    def _validate_config(self):\n",
    "        \"\"\"Raise ValueError for configs that cannot produce meaningful samples.\"\"\"\n",
    "        if not 0 <= self.config[\"n_support\"] <= self.config[\"n_points\"]:\n",
    "            raise ValueError(\"n_support must be in [0, n_points]\")\n",
    "        if self.config[\"n_feature\"] < 2:\n",
    "            raise ValueError(\"n_feature must be >= 2: the target depends on two sine features\")\n",
    "\n",
    "    def _compute_class_threshold(self):\n",
    "        \"\"\"Exact quantile of f1 + f2 for f1, f2 i.i.d. uniform on the f_sin integer range.\n",
    "\n",
    "        Replaces a per-sample 10000-draw Monte-Carlo estimate, which was slow and gave\n",
    "        every sample a slightly different class boundary.\n",
    "        \"\"\"\n",
    "        f_sin_range = self.config[\"f_sin\"]\n",
    "        freqs = np.arange(f_sin_range[\"min\"], f_sin_range[\"max\"] + 1)\n",
    "        # Every (f1, f2) pair is equally likely, so the full grid is the exact distribution.\n",
    "        pair_sums = np.add.outer(freqs, freqs).ravel()\n",
    "        return float(np.quantile(pair_sums, self.config[\"quantile_class\"][0]))\n",
    "\n",
    "    def generate_feature(self, wave_type, f_support, f_base):\n",
    "        \"\"\"Build one feature: a base sine plus a support pattern on a random window.\n",
    "\n",
    "        Args:\n",
    "            wave_type: \"sine\", \"square\" or \"line\" (line adds nothing: base wave only).\n",
    "            f_support: number of support-pattern periods over the full time axis.\n",
    "            f_base: number of base-sine periods over the full time axis.\n",
    "\n",
    "        Returns:\n",
    "            (x_feature, start_idx): an (n_points, 1) array and the window start index.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if wave_type is not one of the supported values.\n",
    "        \"\"\"\n",
    "        n_points = self.config[\"n_points\"]\n",
    "        n_support = self.config[\"n_support\"]\n",
    "        x_feature = 0.5 * np.sin(np.linspace(0, 2 * np.pi * f_base, n_points)).reshape(-1, 1)\n",
    "        # +1: randint's upper bound is exclusive, so the window may now end on the last point.\n",
    "        start_idx = self.rng.randint(0, n_points - n_support + 1)\n",
    "        window = slice(start_idx, start_idx + n_support)\n",
    "\n",
    "        if wave_type == \"sine\":\n",
    "            support = np.sin(np.linspace(0, 2 * np.pi * f_support, n_points))\n",
    "        elif wave_type == \"square\":\n",
    "            support = np.sign(np.sin(np.linspace(0, 2 * np.pi * f_support, n_points)))\n",
    "        elif wave_type == \"line\":\n",
    "            support = np.zeros(n_points)\n",
    "        else:\n",
    "            raise ValueError(\"wave must be one of sine, square, line\")\n",
    "\n",
    "        x_feature[window, 0] += support[window]\n",
    "        return x_feature, start_idx\n",
    "\n",
    "    def generate_sample(self):\n",
    "        \"\"\"Draw one sample.\n",
    "\n",
    "        Returns:\n",
    "            dict with \"signal\" (an (n_points, n_feature) array) and \"target\" (0-d int\n",
    "            array: 0 if the two sine frequencies sum to <= the class threshold, else 1).\n",
    "        \"\"\"\n",
    "        f_base_range = self.config[\"f_base\"]\n",
    "        f_sin_range = self.config[\"f_sin\"]\n",
    "        signal = np.zeros((self.config[\"n_points\"], self.config[\"n_feature\"]))\n",
    "        f_sine_sum = 0\n",
    "\n",
    "        # Shuffle feature columns so the two sine features land at random positions.\n",
    "        idx_features = self.rng.permutation(self.config[\"n_feature\"])\n",
    "        for rank, idx_feature in enumerate(idx_features):\n",
    "            f_base = self.rng.randint(f_base_range[\"min\"], f_base_range[\"max\"] + 1)\n",
    "            f_support = self.rng.randint(f_sin_range[\"min\"], f_sin_range[\"max\"] + 1)\n",
    "            if rank < 2:\n",
    "                f_sine_sum += f_support\n",
    "                wave_type = \"sine\"\n",
    "            else:\n",
    "                wave_type = self.rng.choice([\"line\", \"square\"])\n",
    "            x_tmp, _ = self.generate_feature(wave_type, f_support, f_base)\n",
    "            signal[:, idx_feature] = x_tmp[:, 0]\n",
    "\n",
    "        target = np.where(f_sine_sum <= self.class_threshold, 0, 1)\n",
    "        return {\"signal\": signal, \"target\": target}\n",
    "\n",
    "    def generate_dataset(self):\n",
    "        \"\"\"Write config[\"nb_simulation\"] samples to save_path as sample_{i}.npy.\n",
    "\n",
    "        Each file holds the sample dict inside a 0-d object array; read it back with\n",
    "        np.load(path, allow_pickle=True).item().\n",
    "        \"\"\"\n",
    "        for i in range(self.config[\"nb_simulation\"]):\n",
    "            sample = self.generate_sample()\n",
    "            np.save(os.path.join(self.save_path, f\"sample_{i}.npy\"), sample)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the dataset described by the small config with the class defined above.\n",
    "config_file = \"config_synthetic_small.json\"\n",
    "generator = SyntheticDatasetGenerator(config_file)\n",
    "generator.generate_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "import os\n",
    "\n",
    "\n",
    "class SyntheticDatasetGenerator:\n",
    "    \"\"\"Generate a synthetic time-series classification dataset with per-sample metadata.\n",
    "\n",
    "    Each sample is an (n_points, n_feature) signal. Every feature is a base sine\n",
    "    wave of amplitude 0.5 (frequency drawn from config[\"f_base\"]) with a support\n",
    "    pattern added on a random window of n_support points. The first two features\n",
    "    of a random column permutation get a sine pattern; the others get a square or\n",
    "    flat (\"line\") pattern. The binary target is 1 when the sum of the two sine\n",
    "    frequencies exceeds the config[\"quantile_class\"][0] quantile of that sum's\n",
    "    distribution, else 0. The sine columns, their window starts and frequencies,\n",
    "    and the threshold are recorded as metadata.\n",
    "\n",
    "    Args:\n",
    "        config_path: path to the JSON config file.\n",
    "        data_root: directory under which the config's \"save_name\" folder is created.\n",
    "            NOTE(review): defaults to \"data_small\", yet the notebook runs this class\n",
    "            with config_synthetic.json - confirm the intended output folder.\n",
    "        seed: optional RNG seed; when None the global np.random state is used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config_path, data_root=\"data_small\", seed=None):\n",
    "        self.config = self.load_config(config_path)\n",
    "        self._validate_config()\n",
    "        # Without a seed keep the global legacy RNG, so np.random.seed(...) still applies.\n",
    "        self.rng = np.random if seed is None else np.random.RandomState(seed)\n",
    "        # The class boundary depends only on the config: compute it once, not per sample.\n",
    "        self.class_threshold = self._compute_class_threshold()\n",
    "        self.save_path = os.path.join(data_root, self.config[\"save_name\"])\n",
    "        os.makedirs(self.save_path, exist_ok=True)\n",
    "\n",
    "    def load_config(self, config_path):\n",
    "        \"\"\"Load and return the JSON configuration stored at config_path.\"\"\"\n",
    "        with open(config_path, 'r') as f:\n",
    "            return json.load(f)\n",
    "\n",
    "    def _validate_config(self):\n",
    "        \"\"\"Raise ValueError for configs that cannot produce meaningful samples.\"\"\"\n",
    "        if not 0 <= self.config[\"n_support\"] <= self.config[\"n_points\"]:\n",
    "            raise ValueError(\"n_support must be in [0, n_points]\")\n",
    "        if self.config[\"n_feature\"] < 2:\n",
    "            raise ValueError(\"n_feature must be >= 2: the target depends on two sine features\")\n",
    "\n",
    "    def _compute_class_threshold(self):\n",
    "        \"\"\"Exact quantile of f1 + f2 for f1, f2 i.i.d. uniform on the f_sin integer range.\n",
    "\n",
    "        Replaces a per-sample 10000-draw Monte-Carlo estimate, which was slow and gave\n",
    "        every sample (and its saved metadata) a slightly different class boundary.\n",
    "        \"\"\"\n",
    "        f_sin_range = self.config[\"f_sin\"]\n",
    "        freqs = np.arange(f_sin_range[\"min\"], f_sin_range[\"max\"] + 1)\n",
    "        # Every (f1, f2) pair is equally likely, so the full grid is the exact distribution.\n",
    "        pair_sums = np.add.outer(freqs, freqs).ravel()\n",
    "        return float(np.quantile(pair_sums, self.config[\"quantile_class\"][0]))\n",
    "\n",
    "    def generate_feature(self, wave_type, f_support, f_base):\n",
    "        \"\"\"Build one feature: a base sine plus a support pattern on a random window.\n",
    "\n",
    "        Args:\n",
    "            wave_type: \"sine\", \"square\" or \"line\" (line adds nothing: base wave only).\n",
    "            f_support: number of support-pattern periods over the full time axis.\n",
    "            f_base: number of base-sine periods over the full time axis.\n",
    "\n",
    "        Returns:\n",
    "            (x_feature, start_idx): an (n_points, 1) array and the window start index.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if wave_type is not one of the supported values.\n",
    "        \"\"\"\n",
    "        n_points = self.config[\"n_points\"]\n",
    "        n_support = self.config[\"n_support\"]\n",
    "        x_feature = 0.5 * np.sin(np.linspace(0, 2 * np.pi * f_base, n_points)).reshape(-1, 1)\n",
    "        # +1: randint's upper bound is exclusive, so the window may now end on the last point.\n",
    "        start_idx = self.rng.randint(0, n_points - n_support + 1)\n",
    "        window = slice(start_idx, start_idx + n_support)\n",
    "\n",
    "        if wave_type == \"sine\":\n",
    "            support = np.sin(np.linspace(0, 2 * np.pi * f_support, n_points))\n",
    "        elif wave_type == \"square\":\n",
    "            support = np.sign(np.sin(np.linspace(0, 2 * np.pi * f_support, n_points)))\n",
    "        elif wave_type == \"line\":\n",
    "            support = np.zeros(n_points)\n",
    "        else:\n",
    "            raise ValueError(\"wave must be one of sine, square, line\")\n",
    "\n",
    "        x_feature[window, 0] += support[window]\n",
    "        return x_feature, start_idx\n",
    "\n",
    "    def generate_sample(self):\n",
    "        \"\"\"Draw one sample.\n",
    "\n",
    "        Returns:\n",
    "            dict with \"signal\" (an (n_points, n_feature) array), \"metadata\" (sine column\n",
    "            indices, their window start indices and frequencies, and the class\n",
    "            threshold, all as plain Python numbers) and \"target\" (0-d int array: 0 if\n",
    "            the two sine frequencies sum to <= the threshold, else 1).\n",
    "        \"\"\"\n",
    "        f_base_range = self.config[\"f_base\"]\n",
    "        f_sin_range = self.config[\"f_sin\"]\n",
    "        signal = np.zeros((self.config[\"n_points\"], self.config[\"n_feature\"]))\n",
    "        metadata = {\n",
    "            \"sine_feature_indices\": [],\n",
    "            \"start_indices\": [],\n",
    "            \"frequencies\": [],\n",
    "            \"threshold\": self.class_threshold,\n",
    "        }\n",
    "        f_sine_sum = 0\n",
    "\n",
    "        # Shuffle feature columns so the two sine features land at random positions.\n",
    "        idx_features = self.rng.permutation(self.config[\"n_feature\"])\n",
    "        for rank, idx_feature in enumerate(idx_features):\n",
    "            f_base = self.rng.randint(f_base_range[\"min\"], f_base_range[\"max\"] + 1)\n",
    "            f_support = self.rng.randint(f_sin_range[\"min\"], f_sin_range[\"max\"] + 1)\n",
    "            if rank < 2:\n",
    "                f_sine_sum += f_support\n",
    "                x_tmp, start_idx = self.generate_feature(\"sine\", f_support, f_base)\n",
    "                # Cast numpy integers to int so the metadata is directly JSON-serialisable.\n",
    "                metadata[\"sine_feature_indices\"].append(int(idx_feature))\n",
    "                metadata[\"start_indices\"].append(int(start_idx))\n",
    "                metadata[\"frequencies\"].append(int(f_support))\n",
    "            else:\n",
    "                wave_type = self.rng.choice([\"line\", \"square\"])\n",
    "                x_tmp, _ = self.generate_feature(wave_type, f_support, f_base)\n",
    "            signal[:, idx_feature] = x_tmp[:, 0]\n",
    "\n",
    "        target = np.where(f_sine_sum <= self.class_threshold, 0, 1)\n",
    "        return {\"signal\": signal, \"metadata\": metadata, \"target\": target}\n",
    "\n",
    "    def generate_dataset(self):\n",
    "        \"\"\"Write config[\"nb_simulation\"] samples into save_path.\n",
    "\n",
    "        For each index i: sample_{i}.npy holds the signal array, target_{i}.npy the\n",
    "        0-d target array and metadata_{i}.json the sample metadata.\n",
    "        \"\"\"\n",
    "        for i in range(self.config[\"nb_simulation\"]):\n",
    "            sample = self.generate_sample()\n",
    "            np.save(os.path.join(self.save_path, f\"sample_{i}.npy\"), sample[\"signal\"])\n",
    "            np.save(os.path.join(self.save_path, f\"target_{i}.npy\"), sample[\"target\"])\n",
    "            with open(os.path.join(self.save_path, f\"metadata_{i}.json\"), 'w') as f:\n",
    "                json.dump(sample[\"metadata\"], f)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the full dataset with the (metadata-saving) class defined above.\n",
    "# NOTE(review): that class writes under data_small/ even for this full config - confirm.\n",
    "config_file = \"config_synthetic.json\"\n",
    "generator = SyntheticDatasetGenerator(config_file)\n",
    "generator.generate_dataset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
