{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import h5py\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write a reduced copy of the 'train' group holding the data_num samples with the largest mu.\n",
    "data_num = 100\n",
    "input_file = \"../data/ns_contextual/ns_random_forces.h5\"\n",
    "output_file = f\"../data/ns_contextual/ns_random_forces_top{data_num}_mu.h5\"\n",
    "\n",
    "with h5py.File(input_file, 'r') as f_in:\n",
    "    train = f_in['train']\n",
    "    mus = np.array(train['mu'])\n",
    "    # Validate before the output file is opened with 'w' (which truncates it);\n",
    "    # an oversized data_num would otherwise fail with IndexError mid-copy.\n",
    "    if not 0 <= data_num <= len(mus):\n",
    "        raise ValueError(f'data_num must be between 0 and {len(mus)}, got {data_num}')\n",
    "\n",
    "    # Positions of the data_num largest mu values. Slicing from an explicit start\n",
    "    # (rather than [-data_num:]) keeps data_num == 0 from selecting everything.\n",
    "    mu_order = np.argsort(mus)\n",
    "    # Sorted ascending so the kept samples retain their order from the input file.\n",
    "    top_mu_ids = sorted(mu_order[len(mus) - data_num:])\n",
    "\n",
    "    with h5py.File(output_file, 'w') as f_out:\n",
    "        # 'a' and 'u' are merged into a single dataset below; everything else\n",
    "        # is copied as-is, restricted to the selected rows.\n",
    "        for data_name in train:\n",
    "            if data_name in ('a', 'u'):\n",
    "                continue\n",
    "            ds_in = train[data_name]\n",
    "            ds_out = f_out.create_dataset(\n",
    "                data_name, shape=(data_num,) + ds_in.shape[1:], dtype=ds_in.dtype\n",
    "            )\n",
    "            # Copy one sample per step instead of reading the whole dataset at once.\n",
    "            for out_idx, src_idx in enumerate(top_mu_ids):\n",
    "                ds_out[out_idx, ...] = ds_in[src_idx, ...]\n",
    "\n",
    "        # Output 'u' gets one extra slot on its last axis: index 0 holds 'a',\n",
    "        # indices 1 onward hold the original 'u'.\n",
    "        ds_u = train['u']\n",
    "        ds_a = train['a']\n",
    "        u_shape = list(ds_u.shape)\n",
    "        u_shape[0] = data_num\n",
    "        u_shape[-1] += 1\n",
    "        ds_out = f_out.create_dataset('u', shape=u_shape, dtype=ds_u.dtype)\n",
    "        for out_idx, src_idx in enumerate(top_mu_ids):\n",
    "            ds_out[out_idx, ..., 0] = ds_a[src_idx, ...]\n",
    "            ds_out[out_idx, ..., 1:] = ds_u[src_idx, ...]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write a reduced copy of the 'train' group holding the data_num samples with the smallest mu.\n",
    "data_num = 100\n",
    "input_file = \"../data/ns_contextual/ns_random_forces.h5\"\n",
    "output_file = f\"../data/ns_contextual/ns_random_forces_bottom{data_num}_mu.h5\"\n",
    "\n",
    "with h5py.File(input_file, 'r') as f_in:\n",
    "    train = f_in['train']\n",
    "    mus = np.array(train['mu'])\n",
    "    # Validate before the output file is opened with 'w' (which truncates it);\n",
    "    # an oversized data_num would otherwise fail with IndexError mid-copy.\n",
    "    if not 0 <= data_num <= len(mus):\n",
    "        raise ValueError(f'data_num must be between 0 and {len(mus)}, got {data_num}')\n",
    "\n",
    "    # Positions of the data_num smallest mu values, then sorted ascending so the\n",
    "    # kept samples retain their order from the input file.\n",
    "    mu_order = np.argsort(mus)\n",
    "    bottom_mu_ids = sorted(mu_order[:data_num])\n",
    "\n",
    "    with h5py.File(output_file, 'w') as f_out:\n",
    "        # 'a' and 'u' are merged into a single dataset below; everything else\n",
    "        # is copied as-is, restricted to the selected rows.\n",
    "        for data_name in train:\n",
    "            if data_name in ('a', 'u'):\n",
    "                continue\n",
    "            ds_in = train[data_name]\n",
    "            ds_out = f_out.create_dataset(\n",
    "                data_name, shape=(data_num,) + ds_in.shape[1:], dtype=ds_in.dtype\n",
    "            )\n",
    "            # Copy one sample per step instead of reading the whole dataset at once.\n",
    "            for out_idx, src_idx in enumerate(bottom_mu_ids):\n",
    "                ds_out[out_idx, ...] = ds_in[src_idx, ...]\n",
    "\n",
    "        # Output 'u' gets one extra slot on its last axis: index 0 holds 'a',\n",
    "        # indices 1 onward hold the original 'u'.\n",
    "        ds_u = train['u']\n",
    "        ds_a = train['a']\n",
    "        u_shape = list(ds_u.shape)\n",
    "        u_shape[0] = data_num\n",
    "        u_shape[-1] += 1\n",
    "        ds_out = f_out.create_dataset('u', shape=u_shape, dtype=ds_u.dtype)\n",
    "        for out_idx, src_idx in enumerate(bottom_mu_ids):\n",
    "            ds_out[out_idx, ..., 0] = ds_a[src_idx, ...]\n",
    "            ds_out[out_idx, ..., 1:] = ds_u[src_idx, ...]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f (100, 256, 256)\n",
      "mu (100,)\n",
      "u (100, 256, 256, 201)\n"
     ]
    }
   ],
   "source": [
    "# Print the name and shape of each top-level entry in the file at output_file.\n",
    "with h5py.File(output_file, 'r') as h5_check:\n",
    "    for entry_name, entry in h5_check.items():\n",
    "        print(entry_name, entry.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "f (200, 256, 256)\n",
      "mu (200,)\n",
      "u (200, 256, 256, 201)\n"
     ]
    }
   ],
   "source": [
    "# Print the name and shape of each top-level entry in this HDF5 file.\n",
    "with h5py.File('data/ns_random_forces_1.h5', 'r') as h5_ref:\n",
    "    for key in h5_ref.keys():\n",
    "        print(key, h5_ref[key].shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
