{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b422040c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from generate_imgs import render_from_csv_to_h5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2cb2e1f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _minmax_normalize(values):\n",
    "    \"\"\"Linearly rescale `values` onto [0, 1].\n",
    "\n",
    "    Empty input is returned unchanged and constant input maps to all zeros,\n",
    "    so degenerate cases (e.g. a single row) never hit a 0/0 division.\n",
    "    \"\"\"\n",
    "    values = np.asarray(values, dtype=float)\n",
    "    if values.size == 0:\n",
    "        return values\n",
    "    lowest = np.min(values)\n",
    "    span = np.max(values) - lowest\n",
    "    if span == 0:\n",
    "        return np.zeros_like(values)\n",
    "    return (values - lowest) / span\n",
    "\n",
    "\n",
    "def generate_independent_test_csv(N, seed=123, y_range=(0, 1), index_begin=0):\n",
    "    \"\"\"Build a synthetic test table in which x and y are sampled independently.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    N : int\n",
    "        Number of rows to generate.\n",
    "    seed : int\n",
    "        Seed passed to `np.random.default_rng`; the same seed gives the same table.\n",
    "    y_range : tuple of (low, high)\n",
    "        Bounds of the uniform draw for `y` (before noise and normalization).\n",
    "    index_begin : int\n",
    "        Value of the first entry of the `index` column.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        Columns `index`, `shape`, `x`, `y`, `scale`, `orientation` and their\n",
    "        `*_im` counterparts. Each numeric factor column is min-max normalized\n",
    "        to [0, 1] on its own.\n",
    "    \"\"\"\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # The order of the rng calls below defines the random stream; reordering\n",
    "    # them changes the table produced for a given seed.\n",
    "\n",
    "    # 1. Sample x independently, still shaped like sin distribution\n",
    "    noise_x = rng.uniform(0, np.pi / 2, N)\n",
    "    x = np.sin(noise_x)  # marginal similar to training\n",
    "    x += rng.normal(0, 0.01, N)\n",
    "\n",
    "    # 2. Sample y from same range but independently from x\n",
    "    y = rng.uniform(y_range[0], y_range[1], N)\n",
    "    noise_y = rng.normal(0, 0.15, N)\n",
    "    y = y + noise_y\n",
    "\n",
    "    # 3. Image domain versions (noisy x, exponentiated noisy y)\n",
    "    x_im = (x + rng.normal(0, 0.01, N))\n",
    "    y_im = np.exp(y + rng.normal(0, 0.1, N)) + rng.normal(0, 0.15, N)\n",
    "\n",
    "    # 4. Remaining factors\n",
    "    scale = rng.uniform(0.5, 0.7, N)\n",
    "    orientation = rng.uniform(0, 360, N)\n",
    "    scale_im = scale + rng.normal(0, 0.05, N)\n",
    "    orientation_im = orientation\n",
    "\n",
    "    shape_choices = [\"square\", \"ellipse\", \"heart\"]\n",
    "    shape = rng.choice(shape_choices, size=N)\n",
    "\n",
    "    # 5. Shape flipping: with probability p_flip the image-domain shape is\n",
    "    # replaced by one of the two other shapes.\n",
    "    shape_im = shape.copy()\n",
    "    p_flip = 0.05\n",
    "    flip_mask = rng.uniform(0, 1.0, N) < p_flip\n",
    "    shape_flip_map = {\n",
    "        \"square\": [\"ellipse\", \"heart\"],\n",
    "        \"ellipse\": [\"square\", \"heart\"],\n",
    "        \"heart\": [\"square\", \"ellipse\"]\n",
    "    }\n",
    "    to_flip = shape_im[flip_mask]\n",
    "    if to_flip.size:  # skip the assignment entirely when nothing was selected\n",
    "        shape_im[flip_mask] = [rng.choice(shape_flip_map[s]) for s in to_flip]\n",
    "\n",
    "    # 6. Normalize every numeric factor to [0, 1] while assembling the table\n",
    "    df = pd.DataFrame({\n",
    "        \"index\": np.arange(N) + index_begin,\n",
    "        \"shape\": shape,\n",
    "        \"x\": _minmax_normalize(x),\n",
    "        \"y\": _minmax_normalize(y),\n",
    "        \"scale\": _minmax_normalize(scale),\n",
    "        \"orientation\": _minmax_normalize(orientation),\n",
    "        \"x_im\": _minmax_normalize(x_im),\n",
    "        \"y_im\": _minmax_normalize(y_im),\n",
    "        \"scale_im\": _minmax_normalize(scale_im),\n",
    "        \"orientation_im\": _minmax_normalize(orientation_im),\n",
    "        \"shape_im\": shape_im\n",
    "    })\n",
    "\n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6c73d9d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base test table: 1000 independently sampled rows with IDs starting at 0.\n",
    "test_df = generate_independent_test_csv(N=1000, seed=34, index_begin=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "2283cd93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_fixed_and_variants(df, N_fixed=1000, seed=42, n_variants=10, index_start=1000):\n",
    "    \"\"\"Create counterfactual variants that resample only x, or only y, of fixed rows.\n",
    "\n",
    "    `N_fixed` rows are drawn from `df`. Each drawn row is repeated\n",
    "    `n_variants` times with a freshly sampled `x`/`x_im` pair (first table)\n",
    "    and, separately, with a freshly sampled `y`/`y_im` pair (second table).\n",
    "    All other columns are copied unchanged from the drawn row.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pandas.DataFrame\n",
    "        Source table; must contain the columns `x`, `y`, `x_im` and `y_im`.\n",
    "    N_fixed : int\n",
    "        Number of rows drawn from `df` (without replacement, via `df.sample`).\n",
    "    seed : int\n",
    "        Seed used both for the row draw and for the variant noise.\n",
    "    n_variants : int\n",
    "        Number of variants generated per drawn row in each table (must be >= 1).\n",
    "    index_start : int\n",
    "        First value of the `index` column of the x table; the y table\n",
    "        continues right after the last x ID.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (pandas.DataFrame, pandas.DataFrame)\n",
    "        `(df_sampled_x, df_sampled_y)`, each with `N_fixed * n_variants` rows,\n",
    "        an `orig_index` column and a new contiguous `index` column.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `n_variants` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if n_variants < 1:\n",
    "        raise ValueError(\"n_variants must be at least 1\")\n",
    "\n",
    "    rng = np.random.default_rng(seed)\n",
    "\n",
    "    # Take N_fixed fixed samples and keep their original index.\n",
    "    # NOTE: when `df` already has an \"index\" column, reset_index stores the old\n",
    "    # row labels in a column named \"level_0\" instead, and the rename below then\n",
    "    # applies to the pre-existing \"index\" column.\n",
    "    df_fixed = df.sample(n=N_fixed, random_state=seed).reset_index(drop=False)\n",
    "    df_fixed = df_fixed.rename(columns={\"index\": \"orig_index\"})  # preserve original IDs\n",
    "\n",
    "    # === 1. Resample x: n_variants variants per fixed row ===\n",
    "    # Draws stay scalar and interleaved (uniform, normal, normal) so that the\n",
    "    # random stream for a given seed is stable.\n",
    "    x_min, x_max = df_fixed[\"x\"].min(), df_fixed[\"x\"].max()\n",
    "    rows = []\n",
    "    for _, row in df_fixed.iterrows():\n",
    "        for _variant in range(n_variants):\n",
    "            new_row = row.copy()\n",
    "            new_x = rng.uniform(x_min, x_max) + rng.normal(0, 0.01)\n",
    "            new_row[\"x\"] = new_x\n",
    "            new_x_im = new_x + rng.normal(0, 0.01)\n",
    "            new_row[\"x_im\"] = new_x_im\n",
    "            rows.append(new_row)\n",
    "    df_sampled_x = pd.DataFrame(rows).reset_index(drop=True)\n",
    "    # x is rescaled with the range of the fixed rows; x_im with its own range\n",
    "    df_sampled_x[\"x\"] = (df_sampled_x[\"x\"] - x_min) / (x_max - x_min)\n",
    "    df_sampled_x[\"x_im\"] = (df_sampled_x[\"x_im\"] - df_sampled_x[\"x_im\"].min()) / (df_sampled_x[\"x_im\"].max() - df_sampled_x[\"x_im\"].min())\n",
    "\n",
    "    # === 2. Resample y: n_variants variants per fixed row ===\n",
    "    # The sampling range comes from \"y\" (mirroring the x section), not from the\n",
    "    # exponentiated image-domain column \"y_im\".\n",
    "    y_min, y_max = df_fixed[\"y\"].min(), df_fixed[\"y\"].max()\n",
    "    rows = []\n",
    "    for _, row in df_fixed.iterrows():\n",
    "        for _variant in range(n_variants):\n",
    "            new_row = row.copy()\n",
    "            new_y = rng.uniform(y_min, y_max) + rng.normal(0, 0.15)\n",
    "            new_row[\"y\"] = new_y\n",
    "            new_y_im = np.exp(new_y + rng.normal(0, 0.1)) + rng.normal(0, 0.15)\n",
    "            new_row[\"y_im\"] = new_y_im\n",
    "            rows.append(new_row)\n",
    "    df_sampled_y = pd.DataFrame(rows).reset_index(drop=True)\n",
    "    # Normalize y and y_im again, each with its own sampled range\n",
    "    df_sampled_y[\"y\"] = (df_sampled_y[\"y\"] - df_sampled_y[\"y\"].min()) / (df_sampled_y[\"y\"].max() - df_sampled_y[\"y\"].min())\n",
    "    df_sampled_y[\"y_im\"] = (df_sampled_y[\"y_im\"] - df_sampled_y[\"y_im\"].min()) / (df_sampled_y[\"y_im\"].max() - df_sampled_y[\"y_im\"].min())\n",
    "\n",
    "    # === Assign GLOBAL continuous IDs: x variants first, y variants right after ===\n",
    "    x_stop = index_start + len(df_sampled_x)\n",
    "    df_sampled_x[\"index\"] = np.arange(index_start, x_stop)\n",
    "    df_sampled_y[\"index\"] = np.arange(x_stop, x_stop + len(df_sampled_y))\n",
    "\n",
    "    return df_sampled_x, df_sampled_y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9c479ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build both counterfactual tables from every test row, then write each table to its CSV.\n",
    "df_sampled_x, df_sampled_y = generate_fixed_and_variants(test_df, N_fixed=1000, seed=31)\n",
    "csv_exports = [\n",
    "    (test_df, '/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp/test.csv'),\n",
    "    (df_sampled_x, '/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp/cf_x.csv'),\n",
    "    (df_sampled_y, '/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp/cf_y.csv'),\n",
    "]\n",
    "for frame, csv_path in csv_exports:\n",
    "    frame.to_csv(csv_path, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c509f60",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the base test rows and both counterfactual tables (row labels renumbered),\n",
    "# then write the combined table to a single CSV.\n",
    "frames_to_stack = [test_df, df_sampled_x, df_sampled_y]\n",
    "df_all = pd.concat(frames_to_stack, ignore_index=True)\n",
    "df_all.to_csv(\n",
    "    '/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp/cf_all.csv',\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1f9eaeb",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rendering & writing:   0%|          | 0/21000 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/***/miniconda3/envs/disco/lib/python3.13/site-packages/skimage/draw/draw.py:48: RuntimeWarning: invalid value encountered in divide\n",
      "  distances = ((r * cos_alpha + c * sin_alpha) / r_rad) ** 2 + (\n",
      "/home/***/miniconda3/envs/disco/lib/python3.13/site-packages/skimage/draw/draw.py:49: RuntimeWarning: invalid value encountered in divide\n",
      "  (r * sin_alpha - c * cos_alpha) / c_rad\n",
      "Rendering & writing: 100%|██████████| 21000/21000 [00:13<00:00, 1515.74it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved 21000 images and metadata to /home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Hand the combined CSV to the project renderer (imported from generate_imgs);\n",
    "# the first two arguments are the input CSV and the output directory.\n",
    "render_from_csv_to_h5(\n",
    "    \"/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp/cf_all.csv\",\n",
    "    \"/home/***/disco_v2/disco_v2/data/dSprites/utils/ctf_on_main_exp\",\n",
    "    file_name=\"cf_all.h5\",\n",
    "    resolution=64,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7cd4ecef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "disco",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
