{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "63237658-7728-4244-a53a-6998a259613e",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "806c1a86-038c-485e-8760-d393f1ac485f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard-library imports first: `os` has to be imported before its first use below.\n",
    "import json\n",
    "import os\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "# Make the project root and its `forge` sub-directory importable.\n",
    "base = \"xx/research/2025_mip/\"\n",
    "sys.path.append(base)\n",
    "sys.path.append(os.path.join(base, 'forge'))\n",
    "\n",
    "from forge.forge import Forge\n",
    "# NOTE(review): later cells use np, pkl, plt, tqdm, torch, device and gp without importing\n",
    "# them; presumably they are re-exported by this star import -- confirm and import explicitly.\n",
    "from forge.utils import *\n",
    "\n",
    "try:\n",
    "    from gurobi_onboarder import init_gurobi\n",
    "    gurobi_venv, GUROBI_FOUND = init_gurobi.initialize_gurobi()\n",
    "except Exception as err:\n",
    "    # Report the failure instead of swallowing it silently, then fall back to an empty\n",
    "    # Gurobi environment. GUROBI_FOUND is left undefined on this path, as before.\n",
    "    print(f'gurobi_onboarder initialisation failed ({err!r}); using gp.Env(empty=True)')\n",
    "    gurobi_venv = gp.Env(empty=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "519bd66a-5843-4d4b-b1a5-06f717432b4e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run the Forge LP-gap model on every tab1 instance and cache the prediction next to the\n",
    "# reference values read from the instance's JSON file. Results are keyed by\n",
    "# 'tab1/<category>/<json file>' and checkpointed to ./res.pkl after each category.\n",
    "base = '2025_mip/data/gap_data_example/'\n",
    "mps_base = 'OptiGuide/milp-evolve/src/multi_class_learning/instances/tab1_compressed/'\n",
    "model = Forge(prob_head=True, cut_head=True)\n",
    "model.load_model(os.path.join('research/2025_mip/', 'models/lp_gap_model_pq.pkl'), model_type='lp_gap')\n",
    "\n",
    "# Resume from the checkpoint when it exists; start from scratch on a first run.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load checkpoints written by this notebook.\n",
    "res = {}\n",
    "if os.path.exists('./res.pkl'):\n",
    "    with open('./res.pkl', 'rb') as file:\n",
    "        res = pkl.load(file)\n",
    "\n",
    "# os.listdir returns entries in arbitrary order, so sort before slicing to make the set of\n",
    "# skipped categories (the first 100) deterministic across runs.\n",
    "for cat in tqdm(sorted(os.listdir(os.path.join(base, 'tab1')))[100:]):\n",
    "    for prob in tqdm(os.listdir(os.path.join(base, 'tab1', cat))):\n",
    "        if 'json' in prob:\n",
    "            key = os.path.join('tab1', cat, prob)\n",
    "\n",
    "            # Extract the category archive the first time one of its instances is needed.\n",
    "            # Checking the directory directly also works when ./tab1 does not exist yet.\n",
    "            # Assumes the archive unpacks into tab1/<cat>/ -- TODO confirm.\n",
    "            if not os.path.exists(os.path.join('tab1', cat)):\n",
    "                # Argument list instead of a shell string: directory names are never shell-parsed.\n",
    "                subprocess.run(['tar', '-xvf', mps_base + cat + '.tar.gz'])\n",
    "            try:\n",
    "                mps_path = 'tab1/' + cat + '/' + prob.split('.')[0].split('_')[1] + '.mps.gz'\n",
    "                size = os.path.getsize(mps_path)\n",
    "                # Skip instances that are already cached or whose compressed file exceeds 800000 bytes.\n",
    "                if key not in res and size <= 800000:\n",
    "                    _, cut_ratio, _ = model.mip_to_lp_cut(mip_instance_path=mps_path,\n",
    "                                                          prob_type='SC',\n",
    "                                                          return_metadata=True,\n",
    "                                                          threads=1)\n",
    "\n",
    "                    with open(os.path.join(base, 'tab1', cat, prob), 'r') as file:\n",
    "                        d = json.load(file)\n",
    "\n",
    "                    # Build the record in one step so a missing JSON field never leaves a\n",
    "                    # half-filled entry in the cache.\n",
    "                    res[key] = {\n",
    "                        'forge': cut_ratio,\n",
    "                        'lp_ip_gap': d['lp_ip_gap'],\n",
    "                        'lp_ip_gap_v2': d['lp_ip_gap_v2'],\n",
    "                        'lp': d['lp_value'],\n",
    "                        'ip': d['obj_value'],\n",
    "                    }\n",
    "            except Exception as e:\n",
    "                # Best effort: report the failure and move on to the next instance.\n",
    "                print (e)\n",
    "        print (\"Num Instances: \", len(res))\n",
    "\n",
    "    # Checkpoint after every category so an interrupted run can resume.\n",
    "    with open('./res.pkl', 'wb') as file:\n",
    "        pkl.dump(res, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22e01f8-716c-403f-a048-f0c36bf01792",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace `res` with the results stored under data/intermediate_files.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load trusted files.\n",
    "with open('../../data/intermediate_files/res.pkl', 'rb') as handle:\n",
    "    res = pkl.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78564f83-b1a1-4d63-956f-67331fba7faa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Absolute error between the Forge prediction and the reference ratio\n",
    "# min(|ip|, |lp|) / max(|ip|, |lp|), collected overall (`mae`) and per type (`res_dict`).\n",
    "mae = []\n",
    "res_dict = {}\n",
    "for t, x in res.items():\n",
    "    ip = abs(x['ip'])\n",
    "    lp = abs(x['lp'])\n",
    "    larger = max(ip, lp)\n",
    "    if larger == 0:\n",
    "        # Both objective values are zero: the ratio is undefined, skip the instance.\n",
    "        continue\n",
    "    rat = min(ip, lp) / larger\n",
    "    err = abs(x['forge'] - rat)\n",
    "    mae.append(err)\n",
    "    # Group by the text after the first '_' of the category directory name.\n",
    "    # setdefault replaces a bare try/except that would also have hidden unrelated errors.\n",
    "    res_dict.setdefault(t.split('/')[1].split('_')[1], []).append(err)\n",
    "\n",
    "type_means = np.array([np.mean(errs) for errs in res_dict.values()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c403690-a21d-4dec-8af0-a58e3983451f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank the types by mean absolute error, worst first.\n",
    "type_mae_pairs = sorted(\n",
    "    ((name, float(np.mean(errs))) for name, errs in res_dict.items()),\n",
    "    key=lambda pair: pair[1],\n",
    "    reverse=True,\n",
    ")\n",
    "# dtype=object keeps the MAE values as floats instead of coercing the mixed (str, float)\n",
    "# rows to strings; reshape(-1, 2) keeps the column indexing valid when res_dict is empty.\n",
    "sorted_types = np.array(type_mae_pairs, dtype=object).reshape(-1, 2)\n",
    "t_names = sorted_types[:, 0].astype(str)\n",
    "t_mae = sorted_types[:, 1].astype(float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca637a29-6319-49a9-b211-059dcadf887b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Bar chart of the per-type MAE, leaving out the first 100 entries of the ranking.\n",
    "tail_mae = t_mae[100:].astype(float)\n",
    "tail_names = t_names[100:]\n",
    "bar_positions = range(len(tail_mae))\n",
    "\n",
    "plt.figure(figsize=(15, 2))\n",
    "plt.bar(x=bar_positions, height=tail_mae)\n",
    "plt.xticks(bar_positions, tail_names, rotation=90, fontsize=3)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa888e06-8ad5-4683-b0ed-9c6095e780f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean of the per-type MAE values (every type weighs the same).\n",
    "t_mae.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c21775-bfe2-454b-a321-eed3c55ed92f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how many type names share the same label after the first '-'.\n",
    "type_count = {}\n",
    "for x in t_names:\n",
    "    label = x.split('-')[1]\n",
    "    # dict.get replaces the bare try/except counter, which also hid unrelated errors.\n",
    "    type_count[label] = type_count.get(label, 0) + 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05228ff1-092f-4992-a840-63aa3e7247c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Total over all label counts.\n",
    "sum(count for count in type_count.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9474d74f-6079-448e-84f7-8dfa5439b7ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of distinct prefixes (the text before the first '-') among the type names.\n",
    "distinct_prefixes = {name.split('-')[0] for name in t_names}\n",
    "len(distinct_prefixes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baa7606a-97f2-4c05-9730-d7b0b6afd559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add up the number of per-instance errors recorded for every type.\n",
    "inst = sum(len(res_dict[type_name]) for type_name in t_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "287348f4-e040-4d5d-9a46-0aa72fd7495f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the total number of instances summed over all types in the previous cell.\n",
    "inst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "adaa5172-b4f5-4ac9-a591-76ba28acddd3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Unique labels (the text after the first '-') together with how often each occurs.\n",
    "suffix_labels = [name.split('-')[1] for name in t_names]\n",
    "np.unique(suffix_labels, return_counts=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09843c90-131c-4b6e-95b5-2ff821f4cfe2",
   "metadata": {},
   "source": [
    "# Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e5890db-3ab8-482c-8a0a-20b05e2c9ace",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Embed up to 10 directory entries per tab1 category with Forge and cache the vectors in\n",
    "# ./res_viz.pkl, keyed by 'tab1/<category>/<json file>'.\n",
    "base = 'research/2025_mip/data/gap_data_example/'\n",
    "mps_base = 'OptiGuide/milp-evolve/src/multi_class_learning/instances/tab1_compressed/'\n",
    "model = Forge(prob_head=True, cut_head=True)\n",
    "# NOTE(review): this path is absolute ('/research/...') while the gap-prediction cell loads\n",
    "# the same model from the relative 'research/...' -- confirm which one is intended.\n",
    "model.load_model(os.path.join('/research/2025_mip/', 'models/lp_gap_model_pq.pkl'), model_type='lp_gap')\n",
    "\n",
    "# Resume from the cache when it exists; start empty on a first run.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load caches written by this notebook.\n",
    "res = {}\n",
    "if os.path.exists('./res_viz.pkl'):\n",
    "    with open('./res_viz.pkl', 'rb') as file:\n",
    "        res = pkl.load(file)\n",
    "\n",
    "for cat in tqdm(os.listdir(os.path.join(base, 'tab1'))):\n",
    "    # os.listdir order is arbitrary: sort so the same 10 entries are picked on every run.\n",
    "    for prob in tqdm(sorted(os.listdir(os.path.join(base, 'tab1', cat)))[:10]):\n",
    "        if 'json' in prob:\n",
    "            # Only categories already extracted under ./tab1 are embedded. Checking the path\n",
    "            # directly avoids re-listing ./tab1 per file and a crash when it is missing.\n",
    "            if not os.path.exists(os.path.join('tab1', cat)):\n",
    "                continue\n",
    "            key = os.path.join('tab1', cat, prob)\n",
    "            try:\n",
    "                mps_path = 'tab1/' + cat + '/' + prob.split('.')[0].split('_')[1] + '.mps.gz'\n",
    "                size = os.path.getsize(mps_path)\n",
    "                # Skip cached instances and compressed files above 800000 bytes.\n",
    "                if key not in res and size <= 800000:\n",
    "                    mip_vec, _, _ = model.mip_to_vector(mip_instance=mps_path, gnn_model_path=None)\n",
    "                    res[key] = {'mip_vec': mip_vec}\n",
    "            except Exception as e:\n",
    "                # Best effort: report the failure and continue with the next instance.\n",
    "                print (e)\n",
    "\n",
    "            print (\"Num Instances: \", len(res))\n",
    "\n",
    "    # Merge with whatever is on disk (entries on disk win on conflict) and keep the merged\n",
    "    # mapping in memory, so the 'not in res' check above also sees entries added by other runs.\n",
    "    res_ = {}\n",
    "    if os.path.exists('./res_viz.pkl'):\n",
    "        with open('./res_viz.pkl', 'rb') as file:\n",
    "            res_ = pkl.load(file)\n",
    "    res = res | res_\n",
    "\n",
    "    with open('./res_viz.pkl', 'wb') as file:\n",
    "        pkl.dump(res, file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f433338-e160-4ac6-8ab7-5c6bf558cff6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-processed instance mapping used by the sampling cell below.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load trusted files.\n",
    "with open('MIPEmbed/data/large_files/d_mip_processed_no_graph.pkl', 'rb') as handle:\n",
    "    mip_to_dgl = pkl.load(handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d202841-a960-414c-a589-da84b402aed6",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf79420f-bbb2-4264-9bf4-77571183cd68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one code-usage histogram per sampled instance of `mip_to_dgl`.\n",
    "dmip_mat = []\n",
    "\n",
    "# Seeded sampling WITHOUT replacement: np.random.choice samples with replacement by default,\n",
    "# which can put the same instance into the plot several times. Capped at the number of\n",
    "# available instances so small mappings do not raise.\n",
    "rng = np.random.default_rng(0)\n",
    "instance_keys = list(mip_to_dgl.keys())\n",
    "sampled = rng.choice(instance_keys, size=min(100, len(instance_keys)), replace=False)\n",
    "# Colour key: the instance name without its last '-'-separated component.\n",
    "color_vec = [\"-\".join(x.split('-')[:-1]) for x in sampled]\n",
    "\n",
    "for inst in tqdm(sampled):\n",
    "\n",
    "    # Pre-processed tuple stored for this instance.\n",
    "    g, features, num_cons, num_vars = mip_to_dgl[inst]\n",
    "\n",
    "    # Forward pass through the GNN; no gradients are needed for inference.\n",
    "    with torch.no_grad():\n",
    "        h_list, logits, loss, distances, codebook_ = model.forward(g.to(device), features.to(device), num_cons, num_vars)\n",
    "\n",
    "    # Each row of `distances` is assigned to its closest code; the instance vector is the\n",
    "    # count of rows per code. Presumably rows are constraints and variables -- TODO confirm.\n",
    "    assigned_codes = torch.argmin(distances, dim=1).detach().cpu().numpy()\n",
    "    mip_vec = np.bincount(assigned_codes, minlength=model.codebook_size).astype(float)\n",
    "\n",
    "    dmip_mat.append(mip_vec)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8e2df02-216c-4192-9b87-f0814ad55c9b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df8b33e7-f1e4-4132-8506-d4c9398e44ad",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Ordered (substring, colour) rules: the first tag contained in the name wins, exactly as\n",
    "# in an if/elif chain.\n",
    "TYPE_COLOR_RULES = [\n",
    "    ('SC', 'purple'),\n",
    "    ('GISP', 'black'),\n",
    "    ('CA', 'yellow'),\n",
    "    ('MIS', 'blue'),\n",
    "    ('MVC', 'grey'),\n",
    "    ('IP', 'maroon'),\n",
    "    ('MIRP', 'magenta'),\n",
    "]\n",
    "# Fallback for names matching no rule. Without it such names get no colour at all, type_col\n",
    "# ends up shorter than color_vec and the scatter plot fails on the length mismatch.\n",
    "DEFAULT_TYPE_COLOR = 'lightgrey'\n",
    "\n",
    "type_col = [\n",
    "    next((color for tag, color in TYPE_COLOR_RULES if tag in i), DEFAULT_TYPE_COLOR)\n",
    "    for i in color_vec\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0531101a-aaa0-4bb4-b4c0-ab1232d46768",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Derive per-instance labels from the cache keys ('tab1/<category>/<file>').\n",
    "# The lists follow the iteration order of `res`, so they only line up with plots built from\n",
    "# the same `res` contents.\n",
    "names = [x.split('/')[1] for x in res.keys()]\n",
    "difficulties = [x.split('-')[1] for x in names]\n",
    "# removeprefix drops the literal 'milp_' prefix; str.strip('milp_') strips any of the\n",
    "# characters m, i, l, p and _ from both ends and only worked by accident.\n",
    "types = [int(x.split('-')[0].removeprefix('milp_')) for x in names]\n",
    "\n",
    "# Colour by difficulty: easy -> green, medium -> orange, anything else -> red.\n",
    "diff_col = [\n",
    "    'green' if 'easy' in i else 'orange' if 'medium' in i else 'red'\n",
    "    for i in names\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d4c021d-116e-472f-af56-d0d0cda9e6cd",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf5ccea9-76b5-4d3e-b1b9-4e3c383d339d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pacmap\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# Reload the cached vectors and stack the sampled `dmip_mat` rows underneath them.\n",
    "with open('./res_viz.pkl', 'rb') as handle:\n",
    "    res = pkl.load(handle)\n",
    "print (len(res))\n",
    "cached_vecs = np.array([entry['mip_vec'] for entry in res.values()])\n",
    "mip_mat = np.vstack([cached_vecs, dmip_mat])\n",
    "\n",
    "# Alternative projections that were tried:\n",
    "# p = PCA(n_components = 2).fit_transform(mip_mat)\n",
    "# p = pacmap.PaCMAP(n_components = 2, n_neighbors=2, MN_ratio=0.5, FP_ratio=2.0).fit_transform(mip_mat, init = 'pca')\n",
    "projector = TSNE(n_components=2, init='pca')\n",
    "p = projector.fit_transform(mip_mat)\n",
    "\n",
    "# Rescale with the global minimum and range, so both axes share one scale.\n",
    "global_min = np.min(p)\n",
    "p = (p - global_min) / np.ptp(p)\n",
    "\n",
    "# Cached instances come first and get small markers; the stacked rows get larger ones.\n",
    "n_cached = len(res)\n",
    "marker_sizes = [10 if idx >= n_cached else 1 for idx in range(len(mip_mat))]\n",
    "point_colors = diff_col + type_col\n",
    "\n",
    "plt.scatter(p[:, 0], p[:, 1], s=marker_sizes, c=point_colors)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fb47ced-94b0-4ab8-a253-d4a086277e2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# t-SNE projection of the cached vectors only, coloured by the integer ids in `types`.\n",
    "with open('./res_viz.pkl', 'rb') as file:\n",
    "    res = pkl.load(file)\n",
    "print (len(res))\n",
    "mip_mat = np.array([x['mip_vec'] for x in res.values()])\n",
    "\n",
    "# Alternative projections that were tried:\n",
    "# p = PCA(n_components = 2).fit_transform(mip_mat)\n",
    "# p = pacmap.PaCMAP(n_components = 2, n_neighbors=2, MN_ratio=0.5, FP_ratio=2.0).fit_transform(mip_mat, init = 'pca')\n",
    "p = TSNE(n_components = 2, init = 'pca').fit_transform(mip_mat)\n",
    "\n",
    "# Rescale with the global minimum and range, so both axes share one scale.\n",
    "p = (p - np.min(p)) / np.ptp(p)\n",
    "\n",
    "cmap = plt.get_cmap('gist_rainbow')\n",
    "\n",
    "# Pass the ids through `c=` with `cmap=` so scatter normalises them to the colormap range.\n",
    "# Calling cmap(types) with integers indexes the colormap's lookup table directly, so every\n",
    "# id at or above the table size is drawn in the same colour.\n",
    "plt.scatter(p[:, 0], p[:, 1], s = 1, c = types, cmap = cmap)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e551131-953e-42d7-b423-b0bc283de8bd",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gpu",
   "language": "python",
   "name": "gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
