{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cross_decomposition import CCA\n",
    "import pickle\n",
    "from pathlib import Path\n",
    "from omegaconf import OmegaConf\n",
    "import h5py\n",
    "import numpy as np\n",
    "import sys\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from ours.utils.dataset_utils import get_dataset\n",
    "from utils.utils import process_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_alignment_score(cca, source_observations):\n",
    "    target_observations_trans = source_observations[:, ::-1]\n",
    "    _, target_latents = cca.transform(source_observations[:2], target_observations_trans)\n",
    "    source_observations_recon = cca.inverse_transform(target_latents)\n",
    "\n",
    "    alignment_score = np.mean(np.square(source_observations - source_observations_recon))\n",
    "    return alignment_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6.509549592858533, 0.008293260827176978, 6.43685045035306, 0.010337459874083708, 4.34599011355219, 0.32325489679047625, 0.006225803074868486, 0.02498792984229289, 0.21831092730615298]\n",
      "1.987 \\pm 2.735\n"
     ]
    }
   ],
   "source": [
    "root_dirs = [\n",
    "    Path(\"results/eval_m2m_umaze_1\"),\n",
    "    Path(\"results/eval_m2m_umaze_4\"),\n",
    "    Path(\"results/eval_m2m_umaze_7\"),\n",
    "    Path(\"results/eval_m2m_medium_7\"),\n",
    "    Path(\"results/eval_m2m_medium_18\"),\n",
    "    Path(\"results/eval_m2m_medium_11\"),\n",
    "]\n",
    "\n",
    "config = OmegaConf.load(root_dirs[0] / \"config.yaml\")\n",
    "OmegaConf.resolve(config)\n",
    "args = process_args(\n",
    "    config,\n",
    "    phase=\"align\",\n",
    "    inference_task_ids=config.inference_task_ids,\n",
    ")\n",
    "source_dataset = get_dataset(\n",
    "    dataset_path=args.source_dataset,\n",
    "    task_ids=args.task_ids,\n",
    "    transform_observations=args.reverse_source_observations,\n",
    "    transform_actions=args.reverse_source_actions,\n",
    ")\n",
    "\n",
    "target_dataset = get_dataset(\n",
    "    dataset_path=args.target_dataset,\n",
    "    task_ids=args.task_ids,\n",
    "    transform_observations=args.reverse_target_observations,\n",
    "    transform_actions=args.reverse_target_actions,\n",
    ")\n",
    "\n",
    "source_observations = source_dataset[\"observations\"]\n",
    "target_observations = target_dataset[\"observations\"]\n",
    "\n",
    "alignment_scores = []\n",
    "for root_dir in sorted(root_dirs):\n",
    "    for logdir in sorted(root_dir.glob(\"*\")):\n",
    "        if logdir.is_dir():\n",
    "            cca = pickle.load(open(logdir / \"cca.pkl\", \"rb\"))\n",
    "            alignment_score = calc_alignment_score(cca, source_observations)\n",
    "            alignment_scores.append(alignment_score)\n",
    "\n",
    "mean = np.mean(alignment_scores)\n",
    "std = np.std(alignment_scores)\n",
    "\n",
    "print(alignment_scores)\n",
    "print(f\"{mean:.3f} \\pm {std:.3f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.7 64-bit ('3.9.7')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "07d7f74e6c90c4093dd1f2406fde2cbe7f63ecc267ad9bcb55318c81f59bb042"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
