{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OpenVINO™ Device Placement Optimization with Reinforcement Learning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openvino as ov\n",
    "import torch\n",
    "\n",
    "def set_seed(seed):\n",
    "    \"\"\"Seed torch's CPU generator and, if CUDA is available, every GPU generator.\"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    if not torch.cuda.is_available():\n",
    "        return\n",
    "    torch.cuda.manual_seed_all(seed)  # one call seeds all visible GPUs\n",
    "\n",
    "\n",
    "SEED = 42\n",
    "set_seed(SEED)\n",
    "\n",
    "# OpenVINO runtime handle used by the later cells.\n",
    "core = ov.Core()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_available_device(core):\n",
    "    devices = core.available_devices\n",
    "    number_of_device = 0\n",
    "    for device in devices:\n",
    "        device_name = core.get_property(device, \"FULL_DEVICE_NAME\")\n",
    "        print(f\"{device}: {device_name}\")\n",
    "        number_of_device += 1\n",
    "        \n",
    "    return devices, number_of_device\n",
    "\n",
    "devices, number_of_device = get_available_device(core)\n",
    "print(devices)\n",
    "# The policy picks between two placements: devices[0] and devices[1].\n",
    "# Presumably the last enumerated device (e.g. a discrete GPU) is the wanted\n",
    "# accelerator, so it becomes option 1.\n",
    "if number_of_device < 2:\n",
    "    raise RuntimeError(\n",
    "        f\"Device placement needs at least two OpenVINO devices, found {devices}\"\n",
    "    )\n",
    "devices[1] = devices[-1]\n",
    "print(devices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Build a NetworkX graph of an OpenVINO model: save it as IR, then parse the layers and edges of the resulting XML file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import xml.etree.ElementTree as ET\n",
    "import networkx as nx\n",
    "import os\n",
    "from pathlib import Path\n",
    "import shutil\n",
    "\n",
    "\n",
    "def xml2graph(xml, model_dir=\"./Model/\", model_filename=\"resnet50_running.xml\",\n",
    "              verbose=False, compress_to_fp16=True):\n",
    "    \"\"\"Save an OpenVINO model as IR and parse the IR XML into a networkx DiGraph.\n",
    "\n",
    "    ``model_dir`` is deleted (if it exists) and recreated on every call; the\n",
    "    model is written there as ``model_filename`` and then read back.\n",
    "\n",
    "    Args:\n",
    "        xml: the ``ov.Model`` to serialize (despite the name, not a file path).\n",
    "        model_dir: scratch directory for the IR. WARNING: wiped on every call.\n",
    "        model_filename: name of the ``.xml`` file written inside ``model_dir``.\n",
    "        verbose: print each layer's name and ``<data shape=...>`` while parsing.\n",
    "        compress_to_fp16: forwarded to ``ov.save_model`` (True matches the\n",
    "            previous behaviour). NOTE(review): FP16 weight compression may add\n",
    "            decompression layers to the IR that the in-memory model lacks --\n",
    "            confirm which graph is wanted.\n",
    "\n",
    "    Returns:\n",
    "        nx.DiGraph with one node per ``<layer>``, keyed by its ``id`` string and\n",
    "        carrying ``name``, ``type`` and ``shape`` (the ``shape`` attribute of the\n",
    "        layer's ``<data>`` element, or None), and one edge per ``<edge>`` from\n",
    "        ``from-layer`` to ``to-layer``.\n",
    "    \"\"\"\n",
    "    print(model_dir)\n",
    "    if os.path.isdir(model_dir):\n",
    "        print(f\"{model_dir} exists already. Deleting the folder\")\n",
    "        shutil.rmtree(model_dir)\n",
    "    os.makedirs(model_dir)\n",
    "\n",
    "    xml_path = os.path.join(model_dir, model_filename)\n",
    "    ov.save_model(xml, xml_path, compress_to_fp16=compress_to_fp16)\n",
    "\n",
    "    root = ET.parse(xml_path).getroot()\n",
    "    graph = nx.DiGraph()\n",
    "\n",
    "    for layer in root.find('layers'):\n",
    "        layer_name = layer.get('name')\n",
    "        data_element = layer.find('data')\n",
    "        # Keep the shape string; the <data> Element itself used to be stored.\n",
    "        data_shape = None if data_element is None else data_element.get('shape')\n",
    "        if verbose:\n",
    "            print(layer_name)\n",
    "            if data_element is not None:\n",
    "                print(data_shape)\n",
    "        graph.add_node(layer.get('id'), name=layer_name,\n",
    "                       type=layer.get('type'), shape=data_shape)\n",
    "\n",
    "    for edge in root.find('edges'):\n",
    "        graph.add_edge(edge.get('from-layer'), edge.get('to-layer'))\n",
    "\n",
    "    print(graph.number_of_nodes())\n",
    "    return graph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the ResNet-50 (or Inception v3) IR into the OpenVINO core, run one test inference, and build the model's computation graph."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openvino as ov\n",
    "import torch\n",
    "from torchvision.models import resnet50\n",
    "\n",
    "\n",
    "# Random 1x3x224x224 input for the test inference below.\n",
    "input_data = torch.rand(1, 3, 224, 224)\n",
    "# Placeholder path: point this at the IR (.xml) of ResNet-50 or Inception v3.\n",
    "xml_file = '/your/resnet50orinceptionv3.xml'\n",
    "\n",
    "ov_model = core.read_model(model=xml_file)\n",
    "input_layer = ov_model.input()\n",
    "\n",
    "# Optional: persist the model as OpenVINO IR for later use.\n",
    "# ov.save_model(ov_model, 'resnet50.xml')\n",
    "\n",
    "# Compile for HETERO (priority: GPU.1, then CPU) and run one inference.\n",
    "compiled_model = ov.compile_model(ov_model, device_name=\"HETERO:GPU.1,CPU\")\n",
    "infer_request = compiled_model.create_infer_request()\n",
    "infer_request.infer(inputs={input_layer.any_name: input_data})\n",
    "result = infer_request.results\n",
    "cm = infer_request.get_compiled_model()\n",
    "runtime_model = cm.get_runtime_model()\n",
    "\n",
    "# Ops of the original (not the compiled runtime) model in topological order,\n",
    "# plus the layer graph parsed from its IR.\n",
    "ops = ov_model.get_ordered_ops()\n",
    "Computation_G = xml2graph(ov_model)\n",
    "print(len(ops))\n",
    "ops_ov = ov_model.get_ordered_ops()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Record every op's type: op_info maps friendly name -> type name and\n",
    "# unique_list holds the distinct type names. unique_list is sorted because\n",
    "# the iteration order of a set of str changes between interpreter runs (hash\n",
    "# randomisation), and EmbeddingModel draws one random vector per type in\n",
    "# unique_list order -- without sorting, set_seed would not make it reproducible.\n",
    "op_type_name = []\n",
    "op_info = {}\n",
    "for op in ops_ov:\n",
    "    friendly_name = op.get_friendly_name()\n",
    "    type_name = op.get_type_name()\n",
    "    op_type_name.append(type_name)\n",
    "    op_info[friendly_name] = type_name\n",
    "    print(f\"{friendly_name}: {type_name}\")\n",
    "unique_list = sorted(set(op_type_name))\n",
    "print(unique_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# Toy DAG for checking merge_operations by hand; (X, Y) means X's output is\n",
    "# consumed by Y. The real run below uses Computation_G.\n",
    "edges = [('A', 'C'), ('B', 'C'), ('C', 'D'), ('E', 'F'), ('G', 'E'), ('F', 'I'), ('D', 'H'), ('H', 'F')]\n",
    "G = nx.DiGraph()\n",
    "G.add_edges_from(edges)\n",
    "\n",
    "def merge_operations(graph):\n",
    "    \"\"\"Partition the nodes of a DAG into co-location groups.\n",
    "\n",
    "    Nodes are visited in topological order. A node with exactly one\n",
    "    predecessor joins that predecessor's group (the group is moved to the end\n",
    "    of the result list); any other node starts a new group. A node being\n",
    "    placed also pulls in its single successor when it is that successor's\n",
    "    only predecessor.\n",
    "\n",
    "    Args:\n",
    "        graph: a networkx ``DiGraph`` (must be acyclic).\n",
    "\n",
    "    Returns:\n",
    "        list of lists of nodes; the list position of a group becomes its\n",
    "        ``Group_<index>`` name in ``merge_nodes``.\n",
    "    \"\"\"\n",
    "    co_location_groups = []\n",
    "    group_of = {}  # node -> the group (list object) that currently holds it\n",
    "    merged = set()\n",
    "\n",
    "    for node in nx.topological_sort(graph):\n",
    "        successors = list(graph.successors(node))\n",
    "        predecessors = list(graph.predecessors(node))\n",
    "\n",
    "        current_group = []\n",
    "        if len(predecessors) == 1 and predecessors[0] in merged:\n",
    "            owner = group_of.get(predecessors[0])\n",
    "            if owner is not None:\n",
    "                # Take the predecessor's group out of the list by identity\n",
    "                # (instead of removing while iterating); re-appended below.\n",
    "                position = next(i for i, group in enumerate(co_location_groups) if group is owner)\n",
    "                del co_location_groups[position]\n",
    "                current_group = owner\n",
    "\n",
    "        if node not in current_group and node not in merged:\n",
    "            current_group.append(node)\n",
    "            group_of[node] = current_group\n",
    "            if len(successors) == 1:\n",
    "                successor = successors[0]\n",
    "                successor_preds = list(graph.predecessors(successor))\n",
    "                if successor_preds == [node] and successor not in merged:\n",
    "                    current_group.append(successor)\n",
    "                    group_of[successor] = current_group\n",
    "                    merged.add(successor)\n",
    "\n",
    "        co_location_groups.append(current_group)\n",
    "        merged.add(node)\n",
    "\n",
    "    return co_location_groups\n",
    "\n",
    "# Run the heuristic on the real computation graph (pass the toy `G` instead\n",
    "# to check it by hand).\n",
    "co_location_groups = merge_operations(Computation_G)\n",
    "print(\"Co-location groups:\", co_location_groups)\n",
    "print(\"Number of groups: \", len(co_location_groups))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx\n",
    "\n",
    "# The same hand-checkable toy DAG as in the previous cell; the real run below\n",
    "# uses Computation_G instead.\n",
    "edges = [('A', 'C'), ('B', 'C'), ('C', 'D'), ('E', 'F'), ('G', 'E'), ('F', 'I'), ('D', 'H'), ('H', 'F')]\n",
    "G = nx.DiGraph()\n",
    "G.add_edges_from(edges)\n",
    "\n",
    "# Example co-location groups for this toy graph (illustrative only):\n",
    "# co_location_groups = [['A'], ['B'], ['G', 'E'], ['C', 'D', 'H'], ['F', 'I']]\n",
    "\n",
    "def merge_nodes(graph, groups):\n",
    "    \"\"\"Contract every co-location group into a single node.\n",
    "\n",
    "    Args:\n",
    "        graph: the original DiGraph.\n",
    "        groups: list of node lists; group ``i`` becomes node ``\"Group_i\"``.\n",
    "\n",
    "    Returns:\n",
    "        (new_graph, group_map): ``new_graph`` has one node per non-empty group\n",
    "        and an edge wherever an original edge crosses between two different\n",
    "        groups (edges inside a group are dropped); ``group_map`` maps each\n",
    "        original node to its group name. Nodes absent from ``groups`` are\n",
    "        added under their own id when a crossing edge touches them.\n",
    "    \"\"\"\n",
    "    new_graph = nx.DiGraph()\n",
    "    group_map = {}\n",
    "\n",
    "    for idx, members in enumerate(groups):\n",
    "        label = f\"Group_{idx}\"\n",
    "        if members:\n",
    "            new_graph.add_node(label)\n",
    "        group_map.update((member, label) for member in members)\n",
    "\n",
    "    for src, dst in graph.edges():\n",
    "        src_group = group_map.get(src, src)\n",
    "        dst_group = group_map.get(dst, dst)\n",
    "        if src_group == dst_group:\n",
    "            continue\n",
    "        new_graph.add_edge(src_group, dst_group)\n",
    "\n",
    "    return new_graph, group_map\n",
    "\n",
    "# Contract each co-location group of the real computation graph into one\n",
    "# \"Group_<i>\" node.\n",
    "new_G, group_map = merge_nodes(Computation_G, co_location_groups)\n",
    "\n",
    "print(\"Nodes in new graph:\")\n",
    "print(new_G.nodes())\n",
    "print(\"Edges in new graph:\")\n",
    "print(new_G.edges())\n",
    "print(group_map)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class EmbeddingModel(nn.Module):\n",
    "    def __init__(self, unique_list, embedding_size):\n",
    "        super(EmbeddingModel, self).__init__()\n",
    "        # Create a ParameterDict to store the parameters\n",
    "        self.embedding_size = embedding_size\n",
    "        self.embeddings = nn.ParameterDict({\n",
    "            op_type: nn.Parameter(torch.randn(embedding_size))\n",
    "            for op_type in unique_list\n",
    "        })\n",
    "        \n",
    "    def forward(self, op_info,new_G,group_map,ops):\n",
    "        # Example forward pass that aggregates embeddings based on input types\n",
    "        # print(self.embeddings['Add'])\n",
    "        aggregated_embedding = self.pre_process(op_info,new_G,group_map,ops)\n",
    "        return aggregated_embedding\n",
    "    \n",
    "    def pre_process(self,op_info,new_G,group_map,ops):\n",
    "        op_info_name = op_info.keys()\n",
    "        # print(op_info_name)\n",
    "        max_size = 4\n",
    "        # num_nodes = Computation_G.number_of_nodes()\n",
    "        num_nodes = new_G.number_of_nodes()\n",
    "        group_embedding_length = self.embedding_size+num_nodes+4\n",
    "        group_type_embeddings = {node: torch.zeros(self.embedding_size) for node in new_G.nodes()}  \n",
    "        group_op_map = {}\n",
    "        for key, value in group_map.items():\n",
    "            if value in group_op_map:\n",
    "                group_op_map[value].append(key)\n",
    "            else:\n",
    "                group_op_map[value] = [key]\n",
    "        # print(group_op_map)\n",
    "        op_embeddings = torch.zeros([num_nodes,group_embedding_length])\n",
    "        # print(op_embeddings)\n",
    "\n",
    "        # one_hot_adjacency = {}\n",
    "        # for node in Computation_G.nodes(data=True):\n",
    "        #     one_hot_adjacency[node[1].get(\"name\")] = np.zeros(num_nodes, dtype=int)\n",
    "        #     # print(type(node[1].get(\"name\")))\n",
    "        #     for neighbor in Computation_G.neighbors(node[0]):\n",
    "        #         one_hot_adjacency[node[1].get(\"name\")][int(neighbor) - 1] = 1 \n",
    "                \n",
    "        one_hot_adjacency = {}\n",
    "        for node in new_G.nodes(data=True):\n",
    "            one_hot_adjacency[node[0]] = np.zeros(num_nodes, dtype=int)\n",
    "            # print(type(node[1].get(\"name\")))\n",
    "            for neighbor in new_G.neighbors(node[0]):\n",
    "                # print(neighbor[5:])\n",
    "                index_node = int(neighbor[6:])\n",
    "                # print(node[0])\n",
    "                one_hot_adjacency[node[0]][index_node] = 1 \n",
    "                \n",
    "\n",
    "        for op in ops:\n",
    "            name = op.get_friendly_name()\n",
    "            type_embedding = torch.zeros(self.embedding_size)\n",
    "            # print(name)\n",
    "            if name in op_info_name:\n",
    "                type_embedding = self.embeddings[op_info[op.get_friendly_name()]]\n",
    "                # print(type_embedding)\n",
    "                try:\n",
    "                    # print(op_info[op.get_friendly_name()])\n",
    "                    for i in range(op.get_output_size()):\n",
    "                        # print(op.get_output_shape(i))\n",
    "                        shape = list(op.get_output_shape(i))\n",
    "                        padded_shape = shape + [0] * (max_size - len(shape))\n",
    "                        # print(padded_shape)\n",
    "                except Exception as e1:\n",
    "                    if str(e1) == \"get_shape was called on a descriptor::Tensor with dynamic shape\":\n",
    "                        padded_shape = [1,100,100,100]\n",
    "                        # print(e1)\n",
    "            elif name[:8] == \"Constant\":\n",
    "                # print(\"Run time model node:\")\n",
    "                # print(name)\n",
    "                # print(op.get_type_name())\n",
    "                type_embedding = self.embeddings['Constant']\n",
    "                for i in range(op.get_output_size()):\n",
    "                    # print(op.get_output_shape(i))\n",
    "                    shape = list(op.get_output_shape(i))\n",
    "                    padded_shape = shape + [0] * (max_size - len(shape))\n",
    "                    # print(padded_shape)\n",
    "            else:\n",
    "                try:\n",
    "                    shape = list(op.get_output_shape(i))\n",
    "                    padded_shape = shape + [0] * (max_size - len(shape))\n",
    "                    # print(padded_shape)\n",
    "                except:\n",
    "                    padded_shape = [1,100,100,100]\n",
    "                \n",
    "            padded_shape_tensor = torch.tensor(padded_shape, dtype=type_embedding.dtype).detach()\n",
    "            node_with_feature = None\n",
    "            for node in Computation_G.nodes(data=True):\n",
    "                # print(node[1].get(\"name\"))\n",
    "                # print(name)\n",
    "                if node[1].get(\"name\") == name:\n",
    "                    # print(node[1].get(\"name\"))\n",
    "                    node_with_feature = node[0]\n",
    "                    # print(\"Find\")\n",
    "                    break\n",
    "                \n",
    "            # if node_with_feature == None:\n",
    "            #     print(\"name not detected\")\n",
    "            #     return\n",
    "            group_name = group_map[node_with_feature]\n",
    "            # print(group_name)\n",
    "            \n",
    "            # if node_with_feature in group_op_map[group_name]:\n",
    "            group_type_embeddings[group_name] = group_type_embeddings[group_name] + type_embedding\n",
    "            # print(group_type_embeddings[group_name])\n",
    "            \n",
    "            if node_with_feature == group_op_map[group_name][-1]:\n",
    "                # print(name)\n",
    "                adj = one_hot_adjacency[group_name]\n",
    "                # print(type_embedding.shape)\n",
    "                # print(padded_shape_tensor.shape)\n",
    "                op_embedding = torch.cat((group_type_embeddings[group_name], padded_shape_tensor), 0)\n",
    "                # print(op_embedding.shape)\n",
    "                adj = torch.tensor(adj, dtype=op_embedding.dtype).detach()\n",
    "                op_embedding = torch.cat((op_embedding, adj), 0)\n",
    "                # print(op_embedding)\n",
    "                group_index = int(group_name[6:])\n",
    "                # print(group_index)\n",
    "                op_embeddings[group_index] += op_embedding\n",
    "            # break\n",
    "        # print(op_embeddings)\n",
    "        return op_embeddings,group_op_map\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-type embedding width (ReinforcementModel also reads this global).\n",
    "embedding_size = 64\n",
    "em = EmbeddingModel(unique_list, embedding_size)\n",
    "# Smoke test: build the per-group feature matrix once; shown as cell output.\n",
    "em(op_info, new_G, group_map, ops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class EncoderRNN(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, dropout_p=0.1):\n",
    "        super(EncoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "\n",
    "        self.embedding = nn.Linear(input_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "\n",
    "    def forward(self, input):\n",
    "        # print(\"input \", input)\n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "        output, hidden = self.gru(embedded)\n",
    "        return output, hidden\n",
    "    \n",
    "class BahdanauAttention(nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        # print(\"scores\",scores.shape)\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "        \n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        context = torch.bmm(weights, keys)\n",
    "\n",
    "        return context, weights\n",
    "\n",
    "class AttnDecoderRNN(nn.Module):\n",
    "    \"\"\"GRU decoder with Bahdanau attention that emits one device index per step.\n",
    "\n",
    "    The token vocabulary is the ``output_size`` device indices plus a\n",
    "    start-of-sequence token with index ``output_size`` -- hence the\n",
    "    ``output_size + 1`` embedding rows.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size, output_size, dropout_p=0.1):\n",
    "        super(AttnDecoderRNN, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.output_size = output_size\n",
    "        self.embedding = nn.Embedding(self.output_size+1, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
    "        self.out = nn.Linear(hidden_size, output_size)\n",
    "        self.dropout = nn.Dropout(dropout_p)\n",
    "\n",
    "    def forward(self, encoder_outputs, encoder_hidden, st, target_tensor=None, random_sampling=True, temperature=1.5):\n",
    "        \"\"\"Decode one device index for every encoder time step.\n",
    "\n",
    "        Args:\n",
    "            encoder_outputs: ``(B, T, hidden_size)`` batch-first encoder outputs.\n",
    "            encoder_hidden: ``(1, B, hidden_size)`` initial GRU hidden state.\n",
    "            st: training step. While ``st < 20`` the logits used for sampling\n",
    "                are divided by 5 (flatter distribution, more exploration); the\n",
    "                returned log-probabilities always use the unscaled logits.\n",
    "            target_tensor: optional ``(B, T)`` indices fed back as the next\n",
    "                inputs (teacher forcing); only used when ``random_sampling`` is False.\n",
    "            random_sampling: sample every step from the softmax (True) or\n",
    "                decode greedily / with teacher forcing (False).\n",
    "            temperature: currently unused; kept for interface compatibility.\n",
    "\n",
    "        Returns:\n",
    "            ``(decoder_outputs, decoder_hidden, attentions, displacement_log,\n",
    "            displacement)``: raw logits ``(B, T, output_size)``, the final hidden\n",
    "            state, attention weights ``(B, T, T)``, the log-probability of each\n",
    "            chosen index ``(B, T)`` and the chosen indices ``(B, T)``.\n",
    "        \"\"\"\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # The start token is the extra, last embedding row (index output_size),\n",
    "        # the same value the notebook-level SOS_token is set to.\n",
    "        decoder_input = torch.full((batch_size, 1), self.output_size, dtype=torch.long,\n",
    "                                   device=encoder_outputs.device)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "        sampled = []\n",
    "\n",
    "        for i in range(encoder_outputs.shape[1]):\n",
    "            decoder_output, decoder_hidden, attn_weights = self.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if random_sampling:\n",
    "                logits = decoder_output / 5 if st < 20 else decoder_output\n",
    "                # (B, 1, output_size) -> (B, output_size); multinomial -> (B, 1)\n",
    "                probs = F.softmax(logits, dim=-1).squeeze(1)\n",
    "                decoder_input = torch.multinomial(probs, num_samples=1).detach()\n",
    "                sampled.append(decoder_input)\n",
    "            elif target_tensor is not None:\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)  # Teacher forcing\n",
    "            else:\n",
    "                # Greedy: feed back the most likely index.\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "\n",
    "        if random_sampling:\n",
    "            displacement = torch.cat(sampled, dim=1)\n",
    "            # Log-probability (under the unscaled policy) of each sampled index.\n",
    "            displacement_log = outputs.gather(2, displacement.unsqueeze(-1)).squeeze(-1)\n",
    "        else:\n",
    "            # Argmax of the log-probabilities (previously raw logits were\n",
    "            # returned here, and torch.cat of an empty list crashed this path).\n",
    "            displacement_log, displacement = torch.max(outputs, dim=2)\n",
    "\n",
    "        return decoder_outputs, decoder_hidden, attentions, displacement_log, displacement\n",
    "\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        \"\"\"One step: embed the previous index, attend over the encoder, run the GRU.\n",
    "\n",
    "        Returns ``(logits, hidden, attn_weights)`` with logits ``(B, 1, output_size)``.\n",
    "        \"\"\"\n",
    "        embedded = self.dropout(self.embedding(input))\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embedded, context), dim=2)\n",
    "\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "\n",
    "        return output, hidden, attn_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder, device,input_size,hidden_size,output_size):\n",
    "        super().__init__()\n",
    "        self.encoder = encoder(input_size, hidden_size)\n",
    "        self.decoder = decoder(hidden_size,output_size)\n",
    "        self.device = device\n",
    "        \n",
    "    def forward(self, op_embeddings, st,sample=True):\n",
    "        # Implementation here (omitted for brevity)\n",
    "        input_op = op_embeddings.unsqueeze(0)\n",
    "        encoder_outputs, encoder_hidden= self.encoder(input_op)\n",
    "        # print(\"encoder_outputs\", encoder_outputs)\n",
    "        # print(\"encoder_hidden\", encoder_hidden)\n",
    "\n",
    "        _, _, _, displacement_log_prob, displacement = self.decoder(encoder_outputs,encoder_hidden,st)\n",
    "        # if sample:\n",
    "        #     pass\n",
    "        # else:\n",
    "        #     displacement_log_prob, displacement = torch.max(decoder_outputs, dim=2)\n",
    "        \n",
    "        return displacement_log_prob,displacement\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def measure_device_placement(displacement,devices,runtime_model,group_op_map,Computation_G):\n",
    "    \n",
    "    def calculate_average(values):\n",
    "        if not values:\n",
    "            return 0  # Return 0 for an empty list\n",
    "        total = sum(values)  # Summing up all the elements in the list\n",
    "        count = len(values)  # Getting the number of elements in the list\n",
    "        average = total / count  # Calculating the average\n",
    "        return average\n",
    "    \n",
    "    def group_to_operation(displacement,group_op_map):\n",
    "        operation_displacement = torch.zeros(len(runtime_model.get_ordered_ops()),dtype=int)\n",
    "        # print(operation_displacement.shape)\n",
    "        for index, group_displacement in enumerate(displacement):\n",
    "            group_name = 'Group_' + str(index)\n",
    "            # print(group_op_map[group_name])\n",
    "            # print(group_displacement)\n",
    "            ops = runtime_model.get_ordered_ops()\n",
    "            # print(group_op_map)\n",
    "            for op_index in group_op_map[group_name]:\n",
    "                # print(Computation_G.nodes[op_index]['name'])\n",
    "                op_name = Computation_G.nodes[op_index]['name']\n",
    "                op_index = int(op_index)\n",
    "                for i,op in enumerate(ops):\n",
    "                    if op.get_friendly_name() == op_name:\n",
    "                        operation_displacement[i] += group_displacement\n",
    "                        break\n",
    "        return operation_displacement\n",
    "                \n",
    "            \n",
    "        \n",
    "    displacement = displacement.squeeze()\n",
    "    op_displacement = group_to_operation(displacement,group_op_map)\n",
    "    Error = False\n",
    "    latencies = []\n",
    "    # print(op_displacement.shape)\n",
    "    try:\n",
    "        # op_displacement = torch.zeros(len(runtime_model.get_ordered_ops()),dtype=int)\n",
    "        # op_displacement = op_displacement.squeeze()\n",
    "        # print(displacement)\n",
    "        # print(len(runtime_model.get_ordered_ops()))\n",
    "        for i,op in enumerate(runtime_model.get_ordered_ops()):\n",
    "            rt_info = op.get_rt_info()\n",
    "            # print(devices[displacement[i]])\n",
    "            rt_info[\"affinity\"] = devices[op_displacement[i]]\n",
    "            # rt_info[\"affinity\"] = \"GPU.0\"\n",
    "\n",
    "            # print(rt_info[\"affinity\"] )\n",
    "            \n",
    "        input_data = torch.rand(1, 3, 224, 224)\n",
    "        compiled_model = ov.compile_model(runtime_model,\"HETERO:GPU,CPU\")\n",
    "        infer_request = compiled_model.create_infer_request()\n",
    "        for _ in range(10):\n",
    "            infer_request.wait()\n",
    "            infer_request.infer(inputs={input_layer.any_name: input_data})\n",
    "            infer_request.wait()\n",
    "            latency = infer_request.latency\n",
    "            latencies.append(latency)\n",
    "        latencies = latencies[5:]\n",
    "        latency = calculate_average(latencies)\n",
    "    except Exception as e:\n",
    "        latency = 10000.0\n",
    "        Error = str(e)\n",
    "        # print(e)\n",
    "    latency = math.sqrt(latency)\n",
    "    # print(f\"Is all on CPU: {is_all_zeros} Is all on GPU.0: {is_all_ones} Latency: {latency}\")\n",
    "    return latency, Error\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import random\n",
    "\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)  # For all GPUs\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    # If using other libraries that use randomness, set their seed here\n",
    "\n",
    "# Set a seed\n",
    "set_seed(42)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from tqdm import tqdm\n",
    "import json\n",
    "import datetime\n",
    "\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)  # For all GPUs\n",
    "\n",
    "\n",
    "class ReinforcementModel(nn.Module):\n",
    "    \"\"\"Policy network: op-type embeddings -> Seq2Seq -> one device index per group.\n",
    "\n",
    "    The attribute names ``embedding_class`` and ``seq2seq`` are kept because\n",
    "    they prefix the keys of the saved ``state_dict``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, encoder_class, decoder_class, embedding_class, device, input_size, hidden_size, output_size, unique_list, embedding_size=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            encoder_class, decoder_class: classes handed to ``Seq2Seq``.\n",
    "            embedding_class: built as ``embedding_class(unique_list, embedding_size)``.\n",
    "            device: stored on the model; not used to move tensors here.\n",
    "            input_size: width of each embedding row fed to the encoder; must\n",
    "                match the width produced by ``embedding_class``.\n",
    "            hidden_size, output_size: Seq2Seq sizes (``output_size`` is the\n",
    "                number of candidate devices).\n",
    "            unique_list: op type names, one embedding each.\n",
    "            embedding_size: per-type embedding width. When None, the\n",
    "                notebook-level ``embedding_size`` is used (the previous,\n",
    "                implicit behaviour).\n",
    "        \"\"\"\n",
    "        super(ReinforcementModel, self).__init__()\n",
    "        if embedding_size is None:\n",
    "            embedding_size = globals()[\"embedding_size\"]\n",
    "        self.embedding_class = embedding_class(unique_list, embedding_size)\n",
    "        self.seq2seq = Seq2Seq(encoder_class, decoder_class, device, input_size, hidden_size, output_size)\n",
    "        self.device = device\n",
    "\n",
    "    def forward(self, op_info, new_G, group_map, ops, st):\n",
    "        \"\"\"Return ``(displacement_log_prob, displacement, group_op_map)`` for step ``st``.\"\"\"\n",
    "        op_embeddings, group_op_map = self.embedding_class(op_info, new_G, group_map, ops)\n",
    "        displacement_log_prob, displacement = self.seq2seq(op_embeddings, st)\n",
    "        return displacement_log_prob, displacement, group_op_map\n",
    "\n",
    "def reinforce_train(model, optimizer, n_episodes, op_info, new_G, group_map, ops, ov_devices, ov_model, Computation_G,\n",
    "                    results_path='training_results.json'):\n",
    "    \"\"\"Train ``model`` with REINFORCE, using the measured latency as a cost.\n",
    "\n",
    "    Each episode samples a placement, scores it with\n",
    "    ``measure_device_placement`` (sqrt of the mean latency, or 100.0 on\n",
    "    failure) and takes one optimizer step on ``loss = sum(log_prob) * cost``.\n",
    "    Minimising this is gradient descent on the expected cost, so slow\n",
    "    placements become less likely. (The former ``-(log_prob * cost)`` had the\n",
    "    sign flipped and made slow placements more likely.)\n",
    "\n",
    "    After every episode the whole history is rewritten to ``results_path`` as\n",
    "    JSON and the best (lowest) successful mean latency so far is printed.\n",
    "\n",
    "    Returns:\n",
    "        list of per-episode dicts with keys ``episode``, ``loss``, ``reward``\n",
    "        (mean latency, i.e. cost squared), ``displacement`` and ``Error``.\n",
    "    \"\"\"\n",
    "    results = []\n",
    "    best_reward = 10000\n",
    "    for episode in tqdm(range(n_episodes), desc=\"Training Episodes\"):\n",
    "        model.zero_grad()\n",
    "        displacement_log_prob_out, displacement, group_op_map = model(op_info, new_G, group_map, ops, episode)\n",
    "        # cost = sqrt(mean latency); lower is better.\n",
    "        cost, Error = measure_device_placement(displacement, ov_devices, ov_model, group_op_map, Computation_G)\n",
    "\n",
    "        # Score-function (REINFORCE) estimator for a cost to be minimised.\n",
    "        loss = (displacement_log_prob_out.squeeze() * cost).sum()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        mean_latency = cost**2\n",
    "        results.append({'episode': episode, 'loss': loss.item(), 'reward': mean_latency,\n",
    "                        'displacement': displacement.reshape(-1).tolist(), 'Error' : Error})\n",
    "\n",
    "        # Only successful measurements count towards the best latency.\n",
    "        if Error is False and mean_latency < best_reward:\n",
    "            best_reward = mean_latency\n",
    "        print(best_reward)\n",
    "        with open(results_path, 'w') as f:\n",
    "            json.dump(results, f)\n",
    "    return results\n",
    " \n",
    "import os\n",
    "\n",
    "# Parameters for the model\n",
    "# Each encoder input row is [type embedding | 4 shape values | one-hot successor\n",
    "# row over all groups] (see EmbeddingModel.pre_process), so its width follows\n",
    "# from the group graph instead of being hard-coded.\n",
    "input_size = embedding_size + 4 + new_G.number_of_nodes()\n",
    "hidden_size = 5*input_size  # Number of features in the hidden state\n",
    "output_size = 2  # Number of candidate devices (indices into `devices`)\n",
    "SOS_token = output_size  # decoder start-of-sequence token (extra embedding row)\n",
    "# Set a seed\n",
    "set_seed(42)\n",
    "model = ReinforcementModel(EncoderRNN, AttnDecoderRNN, EmbeddingModel, \"cpu\", input_size, hidden_size, output_size, unique_list)\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "n_episodes = 100\n",
    "print(op_info)\n",
    "reinforce_train(model, optimizer, n_episodes, op_info, new_G, group_map, ops, devices, ov_model, Computation_G)\n",
    "\n",
    "today = datetime.datetime.now()\n",
    "date_string = today.strftime('%Y%m%d')\n",
    "\n",
    "checkpoint_dir = \"./baseline_DPO\"\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create directories\n",
    "model_parameter_path = f\"{checkpoint_dir}/model_parameters_{date_string}.pth\"\n",
    "torch.save({\n",
    "    'model_state_dict': model.state_dict(),\n",
    "    'optimizer_state_dict': optimizer.state_dict(),\n",
    "    'epoch': n_episodes\n",
    "}, model_parameter_path)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openvino",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
