{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cd ./PDOA/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "from lib.utilities.MORL_utils import generate_w_batch_test\n",
    "\n",
    "num_demo_per_pref = 10\n",
    "pref_interval = 0.01\n",
    "pref_set = generate_w_batch_test(None, pref_interval, 2)\n",
    "\n",
    "\n",
    "env_names=['HalfCheetah', 'Hopper', 'Ant', 'Swimmer', 'Walker2d']\n",
    "dataset_types=['expert_uniform', 'amateur_uniform']\n",
    "for env_name in env_names:\n",
    "    for dt in dataset_types:\n",
    "        dataset_path = f\"./data/d4morl/MO-{env_name}-v2_50000_{dt}.pkl\"\n",
    "        with open(dataset_path, 'rb') as f:\n",
    "            dataset = pickle.load(f)\n",
    "        dataset_prefs=np.array([x['preference'][0] for x in dataset])\n",
    "        min_pref_0, max_pref_0 = np.min(dataset_prefs[:, 0]), np.max(dataset_prefs[:, 0])\n",
    "        print(f'preference range: {[min_pref_0, 1-min_pref_0]} - {[max_pref_0, 1-max_pref_0]}')\n",
    "        \n",
    "        test_dataset = []\n",
    "        for test_pref in pref_set:\n",
    "            if test_pref[0]<min_pref_0 or test_pref[0]>max_pref_0:\n",
    "                continue\n",
    "            expert_demo = []\n",
    "            pref_distance = np.abs(dataset_prefs[:, 0]-test_pref[0])\n",
    "            idx = np.argsort(pref_distance)\n",
    "            expert_demo = [dataset[x] for x in idx[:num_demo_per_pref]]\n",
    "            test_dataset.append({'preference': test_pref, 'cost_limit': -1.0 * np.ones_like(test_pref), 'demo': expert_demo, 'selected_idx': idx[:num_demo_per_pref]})\n",
    "\n",
    "            expected_pref = np.mean(np.array([x['preference'][0] for x in expert_demo]), axis=0)\n",
    "            print(f'test pref: {test_pref}, demo: {expected_pref}, error: {np.mean(np.abs(test_pref-expected_pref))}')\n",
    "        pickle.dump(test_dataset, open(f\"./data/d4morl/test/MO-{env_name}-v2_{dt}.pkl\", 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "from lib.utilities.MORL_utils import generate_w_batch_test\n",
    "from lib.utilities.MORL_utils import MOOfflineEnv\n",
    "import copy\n",
    "\n",
    "num_threshold=5\n",
    "num_demo_per_thres = 2\n",
    "default_min_threshold = 5\n",
    "min_threshold_dict = {\n",
    "    'OfflineAntVelocityGymnasium-v1': 20, 'OfflineHalfCheetahVelocityGymnasium-v1': 10, 'OfflineHopperVelocityGymnasium-v1': 10, \n",
    "    'OfflineSwimmerVelocityGymnasium-v1': 10, 'OfflineWalker2dVelocityGymnasium-v1': 10,\n",
    "}\n",
    "\n",
    "#env_names=['OfflineBallCircle-v0', 'OfflineCarCircle-v0', 'OfflineDroneCircle-v0', 'OfflinePointGoal1Gymnasium-v0', 'OfflineCarGoal1Gymnasium-v0']\n",
    "#env_names=['OfflineAntVelocityGymnasium-v1', 'OfflineHalfCheetahVelocityGymnasium-v1', 'OfflineHopperVelocityGymnasium-v1', 'OfflineSwimmerVelocityGymnasium-v1', 'OfflineWalker2dVelocityGymnasium-v1']\n",
    "#env_names = ['OfflineBallRun-v0', 'OfflineCarRun-v0', 'OfflineDroneRun-v0', 'OfflineAntRun-v0']\n",
    "env_names=['OfflineAntCircle-v0']\n",
    "for env_name in env_names:\n",
    "    env=MOOfflineEnv(env_name, dataset_class='safe')\n",
    "    dataset = env.get_dataset('expert_uniform')\n",
    "\n",
    "    test_dataset = []\n",
    "    \n",
    "    min_threshold = min_threshold_dict.get(env.spec.id, default_min_threshold)\n",
    "    last_thres = 0.0\n",
    "    for idx in range(num_threshold+1):\n",
    "        thres = min_threshold+(env.max_episode_cost-min_threshold)*idx/num_threshold\n",
    "        cost_return = np.array([np.sum(x['raw_rewards'][:, 1]) for x in dataset])\n",
    "        reward_return = np.array([np.sum(x['raw_rewards'][:, 0]) for x in dataset])\n",
    "\n",
    "        expert_demo = []\n",
    "        safe_traj_idx, = np.where(np.logical_and(cost_return<=thres, cost_return>last_thres))\n",
    "        assert len(safe_traj_idx)\n",
    "        high_reward_idx = np.argsort(-reward_return[safe_traj_idx])[:num_demo_per_thres]\n",
    "\n",
    "        selected_idx = safe_traj_idx[list(high_reward_idx)]\n",
    "        expert_demo = [copy.deepcopy(dataset[x]) for x in selected_idx]\n",
    "\n",
    "        test_dataset.append({'preference': np.array([1.0, 0]), 'cost_limit': np.array([-1.0, thres]), 'demo': expert_demo, 'selected_idx': selected_idx})\n",
    "        print(f'thres: {thres}, demo_cost: {np.mean(cost_return[selected_idx])}, demo_reward: {np.mean(reward_return[selected_idx])}, selected_idx: {selected_idx}')\n",
    "        last_thres = thres\n",
    "    pickle.dump(test_dataset, open(f\"./data/safe/test/{env_name}.pkl\", 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import pickle\n",
    "import numpy as np\n",
    "from lib.utilities.MORL_utils import generate_w_batch_test\n",
    "from lib.utilities.MORL_utils import MOOfflineEnv\n",
    "import copy\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "num_threshold=5\n",
    "num_demo_per_combo = 2\n",
    "default_min_threshold = 5\n",
    "min_threshold_dict = {\n",
    "    'CMO-Swimmer-v2': 20, 'CMO-Ant-v2': 10, 'CMO-HalfCheetah-v2': 10, 'CMO-Hopper-v2': 10, 'CMO-Walker2d-v2': 10, \n",
    "}\n",
    "\n",
    "#env_names=['OfflineBallCircle-v0', 'OfflineCarCircle-v0', 'OfflineDroneCircle-v0', 'OfflinePointGoal1Gymnasium-v0', 'OfflineCarGoal1Gymnasium-v0']\n",
    "env_names=[ 'CMO-Ant-v2', 'CMO-HalfCheetah-v2', 'CMO-Hopper-v2', 'CMO-Swimmer-v2', 'CMO-Walker2d-v2' ]#   \n",
    "#env_names = ['OfflineBallRun-v0', 'OfflineCarRun-v0', 'OfflineDroneRun-v0', 'OfflineAntRun-v0']\n",
    "\n",
    "fig=plt.figure(figsize=(18*(num_threshold+1), 16*len(env_names)))\n",
    "\n",
    "for env_id, env_name in enumerate(env_names):\n",
    "    env=MOOfflineEnv(env_name, dataset_class='cmo', num_objective=3)\n",
    "    dataset = env.get_dataset(None)\n",
    "    test_dataset = []\n",
    "    \n",
    "    min_threshold = min_threshold_dict.get(env.spec.id, default_min_threshold)\n",
    "    last_thres = -1.0\n",
    "    all_pref = np.unique(np.round([x['preference'][0] for x in dataset], 2), axis=0)\n",
    "    cost_return = np.array([np.sum(x['raw_rewards'][:, -1]) for x in dataset])\n",
    "    reward_return = np.concatenate([np.sum(x['raw_rewards'][:, 0:2], axis=0, keepdims=True) for x in dataset], axis=0)\n",
    "    traj_len = np.array([len(x['raw_rewards'][:, -1]) for x in dataset])\n",
    "    cost_limit = np.array([x['cost_limit'][0] for x in dataset])\n",
    "    preference = np.round(np.array([x['preference'][0] for x in dataset]), 2)\n",
    "\n",
    "\n",
    "    # plt.subplot(len(env_names), num_threshold+1, env_id*(num_threshold+1)+1)\n",
    "    # sc=plt.scatter(reward_return[:,0], reward_return[:,1], c=preference[:,0], cmap='tab20c', s=cost_limit/np.max(cost_limit)*50, alpha=0.3)\n",
    "    # plt.colorbar(sc)\n",
    "    # continue\n",
    "\n",
    "    # ax1 = fig.add_subplot(len(env_names), num_threshold+1, env_id*(num_threshold+1)+1,projection = \"3d\")\n",
    "    # surf = ax1.scatter(reward_return[:,0], reward_return[:,1], cost_return[:], c=preference[:,0], s=2, alpha=1.0)\n",
    "    # ax1.view_init(45, 215)\n",
    "    # continue\n",
    "    \n",
    "    approx_pref = reward_return / np.linalg.norm(reward_return, ord=1, axis=-1, keepdims=True)\n",
    "    sub_trajs, sub_traj_idx = [], []\n",
    "    for idx in range(num_threshold+1):\n",
    "        thres = min_threshold+(env.max_episode_cost-min_threshold)*idx/num_threshold\n",
    "        safe_traj_idx = np.logical_and(cost_return<=thres, cost_return>last_thres)\n",
    "\n",
    "    # all_cost_limit = np.unique(np.round([x['cost_limit'][0] for x in dataset], 2), axis=0)\n",
    "    # for idx in range(num_threshold+1):\n",
    "    #     thres = all_cost_limit[idx] \n",
    "    #     safe_traj_idx = np.array([np.round(x['cost_limit'][0], 2)==all_cost_limit[idx] for x in dataset])\n",
    "\n",
    "        utility = np.sum(reward_return[safe_traj_idx]*preference[safe_traj_idx], axis=1)\n",
    "        print(np.mean(utility))\n",
    "\n",
    "        plt.subplot(len(env_names), num_threshold+1, env_id*(num_threshold+1)+idx+1)\n",
    "        # sc=plt.scatter(reward_return[safe_traj_idx,0], reward_return[safe_traj_idx,1], c=preference[safe_traj_idx,0], s=1)\n",
    "        # plt.colorbar(sc)\n",
    "        plt.title(f'thres: {thres}, num: {np.sum(safe_traj_idx)}, avg_cost: {np.mean(cost_return[safe_traj_idx])}, violated_num: {np.sum(cost_return[safe_traj_idx]>thres)}')\n",
    "        #plt.scatter(cost_limit, cost_return)\n",
    "        #plt.scatter(preference[:, 0], approx_pref[:,0])\n",
    "\n",
    "        for pref in all_pref:\n",
    "            expert_demo = []\n",
    "            safe_pref_traj_idx = np.logical_and(safe_traj_idx, (preference==pref).all(axis=1))\n",
    "\n",
    "            rew_mean = np.mean(reward_return[safe_pref_traj_idx], axis=0, keepdims=True)\n",
    "            distance = np.argsort(np.sum((reward_return[safe_pref_traj_idx]-rew_mean)**2, axis=1))\n",
    "            plot_idx = safe_pref_traj_idx #np.where(safe_pref_traj_idx)[0][distance[:int(len(distance)*0.5)]]\n",
    "\n",
    "            #safe_pref_traj_idx = np.where(np.logical_and(safe_pref_traj_idx, (cost_return<=all_cost_limit[idx])))[0]\n",
    "            safe_pref_traj_idx = np.where(safe_pref_traj_idx)[0]\n",
    "            if len(safe_pref_traj_idx)<num_demo_per_combo:\n",
    "                print(f'no demo for thres: {thres}, preference: {pref}')\n",
    "                continue\n",
    "\n",
    "            sc = plt.scatter(reward_return[plot_idx,0], reward_return[plot_idx,1], c=preference[plot_idx,0], label=f'{pref[0]}, cost: {np.mean(cost_return[plot_idx])}', cmap='tab20c')\n",
    "            plt.clim(0.4, 1.0)\n",
    "            plt.legend()\n",
    "\n",
    "            # utility = np.sum(reward_return[safe_pref_traj_idx]*pref.reshape(1, -1), axis=1)\n",
    "            # high_reward_idx = np.argsort(-utility)[:num_demo_per_combo]\n",
    "            \n",
    "            high_reward_idx = np.random.randint(0, len(safe_pref_traj_idx), size=2)\n",
    "\n",
    "            selected_idx = safe_pref_traj_idx[list(high_reward_idx)]\n",
    "            expert_demo = [copy.deepcopy(dataset[x]) for x in selected_idx]\n",
    "\n",
    "            test_dataset.append({'preference': np.array([pref[0], pref[1], 0]), 'cost_limit': np.array([0, 0, thres]), 'demo': expert_demo, 'selected_idx': selected_idx})\n",
    "            print(f'thres: {thres}, preference: {np.array([pref[0], pref[1], 0])}, \\\n",
    "                  demo_cost: {np.mean(cost_return[selected_idx])}, demo_utility: {np.mean(np.sum(reward_return[selected_idx]*pref.reshape(1, -1), axis=1))}, selected_idx: {selected_idx}')\n",
    "        last_thres = thres\n",
    "        plt.colorbar(sc)\n",
    "        print('-'*200)\n",
    "    print('*'*200)    \n",
    "    if not os.path.exists(f\"./data/cmo/test\"):\n",
    "        os.makedirs(f\"./data/cmo/test\")\n",
    "\n",
    "plt.savefig('../pareto.jpg')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PRMORL",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
