{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "356a31de-452c-4919-8f5f-18609115e3fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import pprint\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import xmltodict\n",
    "\n",
    "os.makedirs(\"walker_benchmark_configs\", exist_ok=True)\n",
    "\n",
    "def iterativeOverride(datadict, override_key, override_value, mult=True):\n",
    "    \"\"\"Return a deep copy of `datadict` with every `override_key` entry overridden.\n",
    "\n",
    "    Walks the nested dicts and lists produced by `xmltodict.parse`. For every\n",
    "    string value stored under `override_key`:\n",
    "      * mult=True:  each whitespace-separated number is multiplied by\n",
    "                    `override_value` (so \"0.9\" and \"1 0.005 0.0001\" both work);\n",
    "      * mult=False: the value is replaced by `override_value`.\n",
    "    The input structure itself is never modified.\n",
    "    \"\"\"\n",
    "    def _scaled(raw):\n",
    "        # MuJoCo attributes such as size/friction may hold several space-separated numbers\n",
    "        return \" \".join(str(float(part) * override_value) for part in raw.split())\n",
    "\n",
    "    def _visit(node):\n",
    "        # `node` belongs to our private deep copy, so editing it in place is safe\n",
    "        if isinstance(node, dict):\n",
    "            for key, value in node.items():\n",
    "                if isinstance(value, (dict, list)):\n",
    "                    # repeated XML elements become lists, so recurse into both\n",
    "                    _visit(value)\n",
    "                elif isinstance(value, str) and key == override_key:\n",
    "                    node[key] = _scaled(value) if mult else override_value\n",
    "        elif isinstance(node, list):\n",
    "            for item in node:\n",
    "                _visit(item)\n",
    "\n",
    "    newdict = deepcopy(datadict)\n",
    "    _visit(newdict)\n",
    "    return newdict\n",
    "\n",
    "with open('default_walker_mujoco.xml', 'r') as in_xml:\n",
    "    default_xml = in_xml.read()\n",
    "\n",
    "n = 5\n",
    "SEED = None  # set to an int to make the generated configs reproducible\n",
    "random.seed(SEED)\n",
    "\n",
    "# NOTE(review): n+1 points are generated but only the first n are used, so the\n",
    "# final -4 value is never reached -- confirm this is intended.\n",
    "gravity_z_values = np.linspace(-0.5, -4, n+1)\n",
    "\n",
    "for i in range(n):\n",
    "    # Global config values\n",
    "    gravity_x = random.uniform(0, 0)  # placeholder range: currently always 0\n",
    "    gravity_z = gravity_z_values[i]\n",
    "    wind = \" \".join([f\"{random.uniform(0, 1.0):.2f}\" for _ in range(3)])\n",
    "    density = random.uniform(0, 500)\n",
    "\n",
    "    # Walker specific config values\n",
    "    left_body_friction = random.uniform(0.01, 50)\n",
    "    right_body_friction = random.uniform(0.01, 50)\n",
    "\n",
    "    left_body_scaling = random.uniform(1.5, 2.5)\n",
    "    right_body_scaling = random.uniform(1.5, 2.5)\n",
    "\n",
    "    config_file = xmltodict.parse(default_xml)\n",
    "    config_file[\"mujoco\"][\"option\"][\"@gravity\"] = f\"{gravity_x:.2f} 0 {gravity_z:.2f}\"\n",
    "    config_file[\"mujoco\"][\"option\"][\"@wind\"] = wind\n",
    "    config_file[\"mujoco\"][\"option\"][\"@density\"] = f\"{density:.2f}\"\n",
    "\n",
    "    # Children of the root body, edited in place through this alias.\n",
    "    # NOTE(review): assumes child 0 is the right leg and child 1 the left leg\n",
    "    # in default_walker_mujoco.xml -- confirm.\n",
    "    child_bodies = config_file[\"mujoco\"][\"worldbody\"][\"body\"][\"body\"]\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@friction\", right_body_friction, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@friction\", left_body_friction, mult=True)\n",
    "\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@size\", right_body_scaling, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@size\", left_body_scaling, mult=True)\n",
    "\n",
    "    with open(os.path.join(\"walker_benchmark_configs\", \"adapted_config_\"+str(i)+\".xml\"), \"w\") as f:\n",
    "        f.write(xmltodict.unparse(config_file, pretty=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cccfd8c-7f82-4735-a635-d6ec83756e28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import pprint\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import xmltodict\n",
    "\n",
    "os.makedirs(\"ant_benchmark_configs\", exist_ok=True)\n",
    "\n",
    "def iterativeOverride(datadict, override_key, override_value, mult=True):\n",
    "    \"\"\"Return a deep copy of `datadict` with every `override_key` entry overridden.\n",
    "\n",
    "    Walks the nested dicts and lists produced by `xmltodict.parse`. For every\n",
    "    string value stored under `override_key`:\n",
    "      * mult=True:  each whitespace-separated number is multiplied by\n",
    "                    `override_value` (so \"0.9\" and \"1 0.005 0.0001\" both work);\n",
    "      * mult=False: the value is replaced by `override_value`.\n",
    "    The input structure itself is never modified.\n",
    "    \"\"\"\n",
    "    def _scaled(raw):\n",
    "        # MuJoCo attributes such as size/friction may hold several space-separated numbers\n",
    "        return \" \".join(str(float(part) * override_value) for part in raw.split())\n",
    "\n",
    "    def _visit(node):\n",
    "        # `node` belongs to our private deep copy, so editing it in place is safe\n",
    "        if isinstance(node, dict):\n",
    "            for key, value in node.items():\n",
    "                if isinstance(value, (dict, list)):\n",
    "                    # repeated XML elements become lists, so recurse into both\n",
    "                    _visit(value)\n",
    "                elif isinstance(value, str) and key == override_key:\n",
    "                    node[key] = _scaled(value) if mult else override_value\n",
    "        elif isinstance(node, list):\n",
    "            for item in node:\n",
    "                _visit(item)\n",
    "\n",
    "    newdict = deepcopy(datadict)\n",
    "    _visit(newdict)\n",
    "    return newdict\n",
    "\n",
    "with open('default_ant_mujoco.xml', 'r') as in_xml:\n",
    "    default_xml = in_xml.read()\n",
    "\n",
    "n = 5\n",
    "SEED = None  # set to an int to make the generated configs reproducible\n",
    "random.seed(SEED)\n",
    "\n",
    "# NOTE(review): n+1 points are generated but only the first n are used, so the\n",
    "# final -4 value is never reached -- confirm this is intended.\n",
    "gravity_z_values = np.linspace(-0.5, -4, n+1)\n",
    "\n",
    "for i in range(n):\n",
    "    # Global config values\n",
    "    gravity_x = random.uniform(0, 0)  # placeholder range: currently always 0\n",
    "    gravity_z = gravity_z_values[i]\n",
    "    wind = \" \".join([f\"{random.uniform(0, 1.0):.2f}\" for _ in range(3)])\n",
    "    density = random.uniform(0, 500)\n",
    "\n",
    "    # Ant specific config values\n",
    "    left_body_friction = random.uniform(0.01, 50)\n",
    "    right_body_friction = random.uniform(0.01, 50)\n",
    "\n",
    "    left_front_body_scaling = random.uniform(1.25, 1.75)\n",
    "    right_front_body_scaling = random.uniform(1.25, 1.75)\n",
    "    left_back_body_scaling = random.uniform(1.25, 1.75)\n",
    "    right_back_body_scaling = random.uniform(1.25, 1.75)\n",
    "\n",
    "    config_file = xmltodict.parse(default_xml)\n",
    "    config_file[\"mujoco\"][\"option\"][\"@gravity\"] = f\"{gravity_x:.2f} 0 {gravity_z:.2f}\"\n",
    "    config_file[\"mujoco\"][\"option\"][\"@wind\"] = wind\n",
    "    config_file[\"mujoco\"][\"option\"][\"@density\"] = f\"{density:.2f}\"\n",
    "\n",
    "    # Children of the root body (four legs), edited in place through this alias.\n",
    "    # NOTE(review): the left/right/front/back labels assume the child order in\n",
    "    # default_ant_mujoco.xml -- confirm. Only children 0 and 1 receive a\n",
    "    # friction override; confirm the back legs should keep their default.\n",
    "    child_bodies = config_file[\"mujoco\"][\"worldbody\"][\"body\"][\"body\"]\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@friction\", right_body_friction, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@friction\", left_body_friction, mult=True)\n",
    "\n",
    "    # every leg gets its own independently sampled scaling factor\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@size\", right_front_body_scaling, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@size\", left_front_body_scaling, mult=True)\n",
    "    child_bodies[2] = iterativeOverride(child_bodies[2], \"@size\", right_back_body_scaling, mult=True)\n",
    "    child_bodies[3] = iterativeOverride(child_bodies[3], \"@size\", left_back_body_scaling, mult=True)\n",
    "\n",
    "    with open(os.path.join(\"ant_benchmark_configs\", \"adapted_config_\"+str(i)+\".xml\"), \"w\") as f:\n",
    "        f.write(xmltodict.unparse(config_file, pretty=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b24dcd-9726-429b-91ad-7df7c500804c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import random\n",
    "import pprint\n",
    "from copy import deepcopy\n",
    "\n",
    "import numpy as np\n",
    "import xmltodict\n",
    "\n",
    "os.makedirs(\"cheetah_benchmark_configs\", exist_ok=True)\n",
    "\n",
    "def iterativeOverride(datadict, override_key, override_value, mult=True):\n",
    "    \"\"\"Return a deep copy of `datadict` with every `override_key` entry overridden.\n",
    "\n",
    "    Walks the nested dicts and lists produced by `xmltodict.parse`. For every\n",
    "    string value stored under `override_key`:\n",
    "      * mult=True:  each whitespace-separated number is multiplied by\n",
    "                    `override_value` (so \"0.9\" and \"1 0.005 0.0001\" both work);\n",
    "      * mult=False: the value is replaced by `override_value`.\n",
    "    The input structure itself is never modified.\n",
    "    \"\"\"\n",
    "    def _scaled(raw):\n",
    "        # MuJoCo attributes such as size/friction may hold several space-separated numbers\n",
    "        return \" \".join(str(float(part) * override_value) for part in raw.split())\n",
    "\n",
    "    def _visit(node):\n",
    "        # `node` belongs to our private deep copy, so editing it in place is safe\n",
    "        if isinstance(node, dict):\n",
    "            for key, value in node.items():\n",
    "                if isinstance(value, (dict, list)):\n",
    "                    # repeated XML elements become lists, so recurse into both\n",
    "                    _visit(value)\n",
    "                elif isinstance(value, str) and key == override_key:\n",
    "                    node[key] = _scaled(value) if mult else override_value\n",
    "        elif isinstance(node, list):\n",
    "            for item in node:\n",
    "                _visit(item)\n",
    "\n",
    "    newdict = deepcopy(datadict)\n",
    "    _visit(newdict)\n",
    "    return newdict\n",
    "\n",
    "with open('default_cheetah_mujoco.xml', 'r') as in_xml:\n",
    "    default_xml = in_xml.read()\n",
    "\n",
    "n = 5\n",
    "SEED = None  # set to an int to make the generated configs reproducible\n",
    "random.seed(SEED)\n",
    "\n",
    "# NOTE(review): n+1 points are generated but only the first n are used, so the\n",
    "# final -4 value is never reached -- confirm this is intended.\n",
    "gravity_z_values = np.linspace(-0.5, -4, n+1)\n",
    "\n",
    "for i in range(n):\n",
    "    # Global config values\n",
    "    gravity_x = random.uniform(0, 0)  # placeholder range: currently always 0\n",
    "    gravity_z = gravity_z_values[i]\n",
    "    wind = \" \".join([f\"{random.uniform(0, 1.0):.2f}\" for _ in range(3)])\n",
    "    density = random.uniform(0, 500)\n",
    "\n",
    "    # Cheetah specific config values\n",
    "    left_body_friction = random.uniform(0.01, 50)\n",
    "    right_body_friction = random.uniform(0.01, 50)\n",
    "\n",
    "    top_body_scaling = random.uniform(1.5, 2.5)\n",
    "    back_body_scaling = random.uniform(1.5, 2.5)\n",
    "\n",
    "    config_file = xmltodict.parse(default_xml)\n",
    "    config_file[\"mujoco\"][\"option\"][\"@gravity\"] = f\"{gravity_x:.2f} 0 {gravity_z:.2f}\"\n",
    "    config_file[\"mujoco\"][\"option\"][\"@wind\"] = wind\n",
    "    config_file[\"mujoco\"][\"option\"][\"@density\"] = f\"{density:.2f}\"\n",
    "\n",
    "    # Children of the root body, edited in place through this alias.\n",
    "    # NOTE(review): friction labels say right/left while scaling labels say\n",
    "    # top/back for the same two children -- confirm which parts of\n",
    "    # default_cheetah_mujoco.xml child 0 and child 1 actually are.\n",
    "    child_bodies = config_file[\"mujoco\"][\"worldbody\"][\"body\"][\"body\"]\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@friction\", right_body_friction, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@friction\", left_body_friction, mult=True)\n",
    "\n",
    "    child_bodies[0] = iterativeOverride(child_bodies[0], \"@size\", top_body_scaling, mult=True)\n",
    "    child_bodies[1] = iterativeOverride(child_bodies[1], \"@size\", back_body_scaling, mult=True)\n",
    "\n",
    "    with open(os.path.join(\"cheetah_benchmark_configs\", \"adapted_config_\"+str(i)+\".xml\"), \"w\") as f:\n",
    "        f.write(xmltodict.unparse(config_file, pretty=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2707676-98d1-47ef-b616-72e8b280c3c5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (va-for-rl)",
   "language": "python",
   "name": "venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
