{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa7db2c7-beb3-483f-a3c3-7e36d7bb9256",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from syn_dataset import SynGraphDataset\n",
    "from spmotif_dataset import *\n",
    "import torch_geometric.transforms as T\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import GINConv, global_mean_pool, global_max_pool, global_add_pool\n",
    "from utils import *\n",
    "from sklearn.model_selection import train_test_split\n",
    "import shutil\n",
    "import glob\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "import pandas as pd\n",
    "import argparse\n",
    "import pickle\n",
    "import json\n",
    "import io\n",
    "from model import GIN, GINTELL\n",
    "import sys\n",
    "from train_logic import test_epoch\n",
    "\n",
    "import torch\n",
    "import io\n",
    "\n",
    "\n",
    "\n",
    "def _best_run_dir(pattern, verbose=False):\n",
    "    \"\"\"Pick the best `nogumbel=False` run among the `results.json` files matching `pattern`.\n",
    "\n",
    "    Every matching JSON file becomes one row of a DataFrame. Runs are ranked by highest\n",
    "    `val_acc_mean`, ties broken by lowest `val_acc_std`, then by lowest `test_acc_std`.\n",
    "    Only runs whose path contains the text 'nogumbel=False' are eligible.\n",
    "\n",
    "    Args:\n",
    "        pattern (str): glob pattern selecting the `results.json` files.\n",
    "        verbose (bool): when True, print the tail of the ranked table.\n",
    "\n",
    "    Returns:\n",
    "        str or None: path of the winning `results.json` with the trailing\n",
    "        '/results.json' removed, or None when no eligible run exists.\n",
    "    \"\"\"\n",
    "    # Sorted so that exact ties do not depend on the order the filesystem lists the files in.\n",
    "    paths = sorted(glob.glob(pattern))\n",
    "    if not paths:\n",
    "        return None\n",
    "    records = []\n",
    "    for path in paths:\n",
    "        # Context manager: the original left every results file open.\n",
    "        with open(path) as fh:\n",
    "            records.append(json.load(fh))\n",
    "    df = pd.DataFrame(records)\n",
    "    df['fname'] = paths\n",
    "    df = df.sort_values(by=['val_acc_mean', 'val_acc_std', 'test_acc_std'], ascending=[True, False, False])\n",
    "    df = df[df.fname.str.contains('nogumbel=False', regex=False)]\n",
    "    if verbose:\n",
    "        print(df.tail())\n",
    "    # The filter can remove every row; return None instead of failing on iloc[-1].\n",
    "    if df.empty:\n",
    "        return None\n",
    "    return df.iloc[-1]['fname'].replace('/results.json', '')\n",
    "\n",
    "\n",
    "def get_best_baseline_path(dataset_name):\n",
    "    \"\"\"Return the best baseline run directory under `results/<dataset_name>/`, or None if there is none.\"\"\"\n",
    "    return _best_run_dir(f'results/{dataset_name}/*/results.json')\n",
    "\n",
    "\n",
    "def get_best_path(dataset_name):\n",
    "    \"\"\"Return the best run directory under `results_logic/<dataset_name>/*/`, or None if there is none.\n",
    "\n",
    "    Also prints the last rows of the ranked table of eligible runs.\n",
    "    \"\"\"\n",
    "    return _best_run_dir(f'results_logic/{dataset_name}/*/*/results.json', verbose=True)\n",
    "\n",
    "import torch\n",
    "\n",
    "def inverse_sigmoid(x):\n",
    "    \"\"\"Logit of `x`, i.e. log(x / (1 - x)), the inverse of the logistic sigmoid.\"\"\"\n",
    "    odds = x / (1 - x)\n",
    "    return torch.log(odds)\n",
    "\n",
    "@torch.no_grad()\n",
    "def find_logic_rules(w, t_in, t_out, activations=None, max_rule_len=10, max_rules=100, min_support=5):\n",
    "    \"\"\"Enumerate sets of inputs whose summed weights reach the output threshold.\n",
    "\n",
    "    Inputs are visited in order of decreasing weight. Only inputs with weight > 1e-5\n",
    "    (and, when `activations` is given, whose activation column sums to at least\n",
    "    `min_support`) are candidates. A depth-first search collects every combination of\n",
    "    candidates whose weights add up to at least `t_out`; a combination that already\n",
    "    satisfies the threshold is recorded and not extended further.\n",
    "\n",
    "    Args:\n",
    "        w: 1-D tensor of input weights.\n",
    "        t_in: kept for interface compatibility; the search does not use it.\n",
    "        t_out: threshold the summed weights must reach (0-dim tensor or plain number).\n",
    "        activations: optional tensor indexed as `activations[:, input]`. A combination is\n",
    "            skipped when fewer than `min_support` rows have all of its inputs active.\n",
    "        max_rule_len: largest number of inputs allowed in one rule.\n",
    "        max_rules: the search stops once this many rules have been collected.\n",
    "        min_support: minimum number of supporting rows, see `activations`.\n",
    "\n",
    "    Returns:\n",
    "        set of tuples; each tuple holds the sorted positions (indices into `w`) of one rule.\n",
    "    \"\"\"\n",
    "    threshold = float(t_out)\n",
    "    order = torch.argsort(w, 0, descending=True)\n",
    "    keep = w > 1e-5\n",
    "    if activations is not None:\n",
    "        keep = keep & (activations.sum(0) >= min_support)\n",
    "\n",
    "    # Candidate inputs, still sorted by decreasing weight.\n",
    "    candidates = order[keep[order]]\n",
    "    found = set()\n",
    "    if candidates.numel() == 0:\n",
    "        return found\n",
    "\n",
    "    cand_weights = w[candidates]\n",
    "    chosen = []  # positions into `candidates` of the combination being built\n",
    "\n",
    "    def search(start, partial_sum):\n",
    "        if len(found) >= max_rules:\n",
    "            return\n",
    "        if len(chosen) > max_rule_len:\n",
    "            return\n",
    "\n",
    "        # The combination reaches the threshold: record it and do not extend it.\n",
    "        if partial_sum >= threshold:\n",
    "            found.add(tuple(sorted(candidates[chosen].cpu().detach().tolist())))\n",
    "            return\n",
    "\n",
    "        # Candidate weights are all positive, so this is the largest sum still reachable.\n",
    "        if partial_sum + cand_weights[start:].sum() < threshold:\n",
    "            return\n",
    "\n",
    "        for pos in range(start, candidates.shape[0]):\n",
    "            if len(found) >= max_rules:\n",
    "                break  # every further call would return immediately\n",
    "            # Skip extensions that are jointly active in too few rows.\n",
    "            if activations is not None and len(chosen) > 0 and activations[:, candidates[chosen + [pos]]].all(-1).sum().item() < min_support:\n",
    "                continue\n",
    "\n",
    "            chosen.append(pos)\n",
    "            search(pos + 1, partial_sum + cand_weights[pos])\n",
    "            chosen.pop()\n",
    "\n",
    "    search(0, 0)\n",
    "    return found\n",
    "\n",
    "\n",
    "def extract_rules(self, feature=None, activations=None, max_rule_len=float('inf'), max_rules=5, min_support=10, out_threshold=0.5):\n",
    "    \"\"\"Run `find_logic_rules` for one output feature of layer `self`, or for all of them.\n",
    "\n",
    "    The per-output threshold is `-self.b + logit(out_threshold)`. Weights and thresholds\n",
    "    are moved to the CPU before the search.\n",
    "\n",
    "    Args:\n",
    "        feature: index of a single output feature; None means every index in\n",
    "            `range(self.out_features)`.\n",
    "        activations, max_rule_len, max_rules, min_support: forwarded to `find_logic_rules`.\n",
    "        out_threshold: output level that is converted to a pre-activation threshold.\n",
    "\n",
    "    Returns:\n",
    "        list with one rule set per requested output feature.\n",
    "    \"\"\"\n",
    "    weights = self.weight\n",
    "    in_thresholds = self.phi_in.t\n",
    "    out_thresholds = -self.b + inverse_sigmoid(torch.tensor(out_threshold))\n",
    "\n",
    "    targets = range(self.out_features) if feature is None else [feature]\n",
    "    return [\n",
    "        find_logic_rules(\n",
    "            weights[j].to('cpu'),\n",
    "            in_thresholds.to('cpu'),\n",
    "            out_thresholds[j].to('cpu'),\n",
    "            activations,\n",
    "            max_rule_len,\n",
    "            max_rules,\n",
    "            min_support,\n",
    "        )\n",
    "        for j in targets\n",
    "    ]\n",
    "\n",
    "\n",
    "def hoyer_sparsity_loss(weights, lambda_=1.0, epsilon=1e-12):\n",
    "    \"\"\"\n",
    "    Hoyer-based sparsity penalty, computed per vector along the last axis and averaged.\n",
    "\n",
    "    For each vector v of length n the Hoyer sparsity is\n",
    "    (sqrt(n) - |v|_1 / |v|_2) / (sqrt(n) - 1): 0 for a uniform vector, 1 for a one-hot\n",
    "    vector. The loss is `lambda_ * (1 - sparsity)`, so minimising it favours sparse rows.\n",
    "\n",
    "    Args:\n",
    "        weights (torch.Tensor): The weights to regularize.\n",
    "        lambda_ (float): Regularization strength.\n",
    "        epsilon (float): Small value to prevent division by zero.\n",
    "\n",
    "    Returns:\n",
    "        torch.Tensor: scalar, the mean penalty over all vectors.\n",
    "    \"\"\"\n",
    "    l1_norm = torch.sum(torch.abs(weights), -1)\n",
    "    l2_norm = torch.sqrt(torch.sum(weights**2, -1) + epsilon)\n",
    "    # n is the length of the vectors the norms were taken over (the last axis), not the\n",
    "    # total element count: with numel() the measure is wrong for multi-row matrices.\n",
    "    n = weights.shape[-1] if weights.dim() > 0 else 1\n",
    "    sqrt_n = n ** 0.5\n",
    "    hoyer = (sqrt_n - l1_norm / l2_norm) / (sqrt_n - 1 + epsilon)\n",
    "    loss = lambda_ * (1 - hoyer)\n",
    "    return loss.mean()\n",
    "\n",
    "\n",
    "def train_sparsity_epoch(model_tell, loader, device, optimizer, num_classes, conv_reg=1, fc_reg=1):\n",
    "    \"\"\"Run one training pass over `loader` with sparsity regularisation on the logic layers.\n",
    "\n",
    "    Per batch the loss is binary cross-entropy on the flattened outputs plus NLL on their\n",
    "    log-softmax, plus a square-root weight penalty and input entropy for every conv layer\n",
    "    (scaled by `conv_reg`), plus the Hoyer penalty, `reg_loss`, entropy and an\n",
    "    `fc_reg`-scaled square-root penalty for the final layer. Batches with no labels or\n",
    "    with NaNs are skipped. Any exception raised while processing a batch is printed and\n",
    "    that batch is dropped.\n",
    "\n",
    "    Returns:\n",
    "        (float, float): accumulated loss and accuracy, both normalised by `len(loader.dataset)`.\n",
    "    \"\"\"\n",
    "    model_tell.train()\n",
    "\n",
    "    epoch_loss = 0\n",
    "    epoch_acc = 0\n",
    "\n",
    "    for batch in loader:\n",
    "        try:\n",
    "            if batch.x is None:\n",
    "                # No node features: fall back to constant all-ones features.\n",
    "                batch.x = torch.ones((batch.num_nodes, model_tell.num_features))\n",
    "            if batch.y.numel() == 0 or batch.x.isnan().any() or batch.y.isnan().any():\n",
    "                continue\n",
    "            targets = batch.y.reshape(-1).to(device).long()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            model_tell.fc.phi_in.tau = 10\n",
    "            out = model_tell(batch.x.float().to(device), batch.edge_index.to(device), batch.batch.to(device))\n",
    "            pred = out.argmax(-1)\n",
    "\n",
    "            # Classification terms.\n",
    "            one_hot = F.one_hot(targets, num_classes=num_classes).float().reshape(-1)\n",
    "            loss = 0\n",
    "            loss += F.binary_cross_entropy(out.reshape(-1), one_hot) + F.nll_loss(F.log_softmax(out, dim=-1), targets.long())\n",
    "\n",
    "            # Sparsity terms of the first layer inside every conv.\n",
    "            for conv in model_tell.convs:\n",
    "                tell = conv.nn[0]\n",
    "                loss += conv_reg * (torch.sqrt(torch.clamp(tell.weight, min=1e-5)).sum(-1).mean() + tell.phi_in.entropy)\n",
    "\n",
    "            # Sparsity terms of the final layer: an unscaled part and an fc_reg-scaled part.\n",
    "            fc = model_tell.fc\n",
    "            fc_fixed = hoyer_sparsity_loss(torch.clamp(fc.weight, min=1e-5)) + fc.reg_loss + fc.phi_in.entropy\n",
    "            fc_scaled = fc_reg * (torch.sqrt(torch.clamp(fc.weight, min=1e-5)).sum(-1).mean() + fc.phi_in.entropy)\n",
    "            loss += fc_fixed + fc_scaled\n",
    "\n",
    "            loss.backward()\n",
    "            zero_nan_gradients(model_tell)\n",
    "            optimizer.step()\n",
    "\n",
    "            n_total = len(loader.dataset)\n",
    "            epoch_loss += loss.item() * batch.num_graphs / n_total\n",
    "            epoch_acc += pred.eq(targets).sum().item() / n_total\n",
    "        except Exception as err:\n",
    "            print(err)\n",
    "\n",
    "    return epoch_loss, epoch_acc\n",
    "\n",
    "import torch\n",
    "import torch_scatter\n",
    "\n",
    "import torch\n",
    "import torch_scatter\n",
    "def scatter_sum(x, edge_index):\n",
    "    \"\"\"Add to every node's row of `x` the rows of the nodes that send it an edge.\n",
    "\n",
    "    `edge_index[0]` holds the sending nodes and `edge_index[1]` the receiving nodes.\n",
    "    \"\"\"\n",
    "    senders, receivers = edge_index[0], edge_index[1]\n",
    "    incoming = torch_scatter.scatter_add(x[senders], receivers, dim=0, dim_size=x.size(0))\n",
    "    return incoming + x\n",
    "@torch.no_grad()\n",
    "def forward_with_activations(self, x, edge_index, batch, *args, **kwargs):\n",
    "    \"\"\"Forward pass of model `self` that also records every layer's inputs and outputs.\n",
    "\n",
    "    Each conv layer receives `[x, 1 - x]`; the outputs of all conv layers are concatenated,\n",
    "    pooled per graph (mean, max and sum), normalised and passed as `[x, 1 - x]` to `self.fc`.\n",
    "\n",
    "    Returns:\n",
    "        (output, records): `records` holds one dict per conv layer with keys 'x' (layer\n",
    "        input), 'x_sum' (input summed over each node and its in-neighbours), 'x_bin'\n",
    "        (`phi_in(x_sum) >= 0.5`), 'y' (layer output) and 'y_bin' (`y >= 0.5`), followed by\n",
    "        one dict for the final layer with keys 'x', 'x_bin', 'y' and 'y_bin'.\n",
    "    \"\"\"\n",
    "    records = []\n",
    "    x = self.input_bnorm(x)\n",
    "    layer_outputs = []\n",
    "    for conv in self.convs:\n",
    "        # Build the concatenated input once; it is both recorded and fed to the layer.\n",
    "        layer_in = torch.hstack([x, 1 - x])\n",
    "        neighbourhood = scatter_sum(layer_in, edge_index)\n",
    "        in_bin = conv.nn[0].phi_in(neighbourhood) >= 0.5\n",
    "        x = conv(layer_in, edge_index)\n",
    "        layer_outputs.append(x)\n",
    "        records.append({'x': layer_in, 'x_sum': neighbourhood, 'x_bin': in_bin, 'y': x, 'y_bin': x >= 0.5})\n",
    "\n",
    "    # Concatenate the per-layer outputs once and reuse them for all three poolings.\n",
    "    stacked = torch.hstack(layer_outputs)\n",
    "    x_mean = global_mean_pool(stacked, batch)\n",
    "    x_max = global_max_pool(stacked, batch)\n",
    "    x_sum = global_add_pool(stacked, batch)\n",
    "    pooled = self.output_bnorm(torch.hstack([x_mean, x_max, x_sum]))\n",
    "\n",
    "    fc_in = torch.hstack([pooled, 1 - pooled])\n",
    "    fc_bin = self.fc.phi_in(fc_in) >= 0.5\n",
    "    out = self.fc(fc_in)\n",
    "    records.append({'x': fc_in, 'x_bin': fc_bin, 'y': out, 'y_bin': out >= 0.5})\n",
    "    return out, records\n",
    "\n",
    "def sigmoid(x, tau=10):\n",
    "    \"\"\"Logistic sigmoid of `tau * x`; a larger `tau` gives a steeper transition.\"\"\"\n",
    "    denominator = 1 + torch.exp(-tau * x)\n",
    "    return 1 / denominator\n",
    "    \n",
    "def forward_tell(self, tau):\n",
    "    \"\"\"Build a forward function for layer `self` whose output sigmoid has steepness `tau`.\n",
    "\n",
    "    The returned function applies `self.phi_in`, then `sigmoid(x @ W.T + b, tau)`. As side\n",
    "    effects it stores on the layer: `max_in` (column-wise maximum of the transformed\n",
    "    input), `reg_loss` (row sums of the clamped weights, averaged) and `entropy_loss`\n",
    "    (input entropy, if any, plus the mean binary entropy of the output).\n",
    "    \"\"\"\n",
    "    def fw(x):\n",
    "        x = self.phi_in(x)\n",
    "        self.max_in = x.max(0)[0]\n",
    "\n",
    "        # Penalise `weight_s` when the layer uses it, otherwise the plain weights.\n",
    "        penalised = self.weight_s if self.use_weight_sigma else self.weight\n",
    "        self.reg_loss = 0 + torch.clamp(penalised, min=1e-5).sum(-1).mean()\n",
    "\n",
    "        input_entropy = 0\n",
    "        if self.phi_in.entropy is not None:\n",
    "            input_entropy = input_entropy + self.phi_in.entropy\n",
    "\n",
    "        pre_activation = x @ self.weight.t() + self.b\n",
    "        o = sigmoid(pre_activation, tau=tau)\n",
    "\n",
    "        # Mean binary entropy of the outputs; 1e-8 keeps the logarithms finite.\n",
    "        output_entropy = -(o * torch.log(o + 1e-8) + (1 - o) * torch.log(1 - o + 1e-8)).mean()\n",
    "        self.entropy_loss = input_entropy + output_entropy\n",
    "        return o\n",
    "    return fw\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "704f6ed7-e516-4988-8f99-18bf66b2d503",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def clone_model1(model):\n",
    "    \"\"\"Return an independent copy of `model` made by a `torch.save` / `torch.load` round trip.\n",
    "\n",
    "    The copy is serialised into an in-memory buffer, so nothing is written to disk.\n",
    "    `model` must be picklable as it is.\n",
    "    \"\"\"\n",
    "    buffer = io.BytesIO()\n",
    "    torch.save(model, buffer)\n",
    "\n",
    "    # Rewind so that loading starts at the beginning of the serialised data.\n",
    "    buffer.seek(0)\n",
    "    cloned_model = torch.load(buffer)\n",
    "    return cloned_model\n",
    "\n",
    "def clone_model2(model):\n",
    "    \"\"\"Copy `model` through an in-memory `torch.save` / `torch.load` round trip.\n",
    "\n",
    "    `model.dropout` is set to None while saving (presumably because the identity lambda\n",
    "    used as dropout cannot be pickled - TODO confirm) and replaced by an identity\n",
    "    function afterwards.\n",
    "\n",
    "    Side effect: `model.dropout` of the INPUT model is overwritten with an identity\n",
    "    function, whatever it was before. The clone gets an identity `dropout` as well.\n",
    "    \"\"\"\n",
    "    buffer = io.BytesIO()\n",
    "\n",
    "    model.dropout = None\n",
    "    try:\n",
    "        torch.save(model, buffer)\n",
    "    finally:\n",
    "        # Also runs when saving fails, so the caller's model never keeps dropout = None.\n",
    "        model.dropout = lambda x: x\n",
    "\n",
    "    # Rewind so that loading starts at the beginning of the serialised data.\n",
    "    buffer.seek(0)\n",
    "    cloned_model = torch.load(buffer)\n",
    "\n",
    "    cloned_model.dropout = lambda x: x\n",
    "    return cloned_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81734a6f-ae1d-46a9-964d-f8a6b170a846",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer-by-layer sparsification of the best saved model, for every dataset and seed.\n",
    "# For each seed: load the best checkpoint, then fine-tune one logic layer at a time with\n",
    "# the sparsity loss, keeping the snapshot with the best (validation accuracy, fewest\n",
    "# weights above 1e-4) trade-off. Results are collected in `final_res`.\n",
    "\n",
    "# `os` and `tqdm` are used below but are not imported explicitly anywhere else in this notebook.\n",
    "import os\n",
    "import tqdm\n",
    "\n",
    "final_res = {}\n",
    "\n",
    "for dataset_name in ['PROTEINS']:\n",
    "    print(dataset_name)\n",
    "    final_res[dataset_name] = []\n",
    "    for seed in range(5):\n",
    "        results_path = os.path.join(get_best_path(dataset_name), str(seed))\n",
    "        # NOTE: pickle.load and torch.load execute code stored in the file; only point this\n",
    "        # at result directories you trust.\n",
    "        with open(os.path.join(results_path, 'data.pkl'), 'rb') as fh:\n",
    "            data = pickle.load(fh)\n",
    "        with open(os.path.join(results_path, 'args.json'), 'r') as fh:\n",
    "            args = json.load(fh)\n",
    "        # is_available must be CALLED: the bare function object is always truthy, which\n",
    "        # selected CUDA even on machines without a GPU.\n",
    "        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "        dataset = get_dataset(dataset_name)\n",
    "        num_classes = dataset.num_classes\n",
    "        num_features = dataset.num_features\n",
    "\n",
    "        # Fixed 80/10/10 stratified split (random_state=1 for every seed).\n",
    "        indices = list(range(len(dataset)))\n",
    "        train_indices, val_test_indices = train_test_split(indices, test_size=0.2,\n",
    "        shuffle=True, stratify=dataset.data.y, random_state=1)\n",
    "\n",
    "        val_indices = val_test_indices[:len(val_test_indices)//2]\n",
    "        test_indices = val_test_indices[len(val_test_indices)//2:]\n",
    "\n",
    "        train_dataset = dataset[train_indices]\n",
    "        val_dataset = dataset[val_indices]\n",
    "        test_dataset = dataset[test_indices]\n",
    "        train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)\n",
    "        test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)\n",
    "\n",
    "        model_tell = torch.load(os.path.join(results_path, 'best.pt'), map_location=device)\n",
    "        model_tell = model_tell.to(device)\n",
    "\n",
    "        if dataset_name in ['Mutagenicity']:\n",
    "            clone_model = clone_model1\n",
    "        else:\n",
    "            # clone_model2 handles models whose `dropout` is replaced by an identity lambda.\n",
    "            clone_model = clone_model2\n",
    "            model_tell.dropout = lambda x: x\n",
    "\n",
    "        initial_acc = test_acc = test_epoch(model_tell, test_loader, device)\n",
    "\n",
    "        best_weights = clone_model(model_tell)\n",
    "        # Final layer first, then the first layer of each conv in reverse order.\n",
    "        layers = [model_tell.fc, *[l.nn[0] for l in model_tell.convs[::-1]]]\n",
    "\n",
    "        initial_weights = [(layer.weight>1e-4).sum().item() for layer in layers]\n",
    "        # Walk the list from its end to its start, fine-tuning one layer per step.\n",
    "        for l in reversed(range(len(layers))):\n",
    "            print('Layer', l)\n",
    "            layer = layers[l]\n",
    "            optimizer = torch.optim.Adam([layers[l].weight_sigma, layers[l].weight_exp], lr=0.005)\n",
    "            val_acc = test_epoch(model_tell, val_loader, device)\n",
    "            test_acc = test_epoch(model_tell, test_loader, device)\n",
    "            n_w =  (layer.weight>1e-4).sum().item()\n",
    "            # Compared as a tuple: higher accuracy first, then fewer active weights.\n",
    "            best_situation = (val_acc, -n_w)\n",
    "            print(best_situation)\n",
    "            patience = max_patience = 150\n",
    "            for i in tqdm.tqdm(range(1000)):\n",
    "                # model_tell.fc.forward = forward_tell(model_tell.fc, 10)\n",
    "                train_loss, train_acc = train_sparsity_epoch(model_tell, train_loader, device, optimizer, num_classes, conv_reg=0.1, fc_reg=0.01)\n",
    "                val_acc = test_epoch(model_tell, val_loader, device)\n",
    "                test_acc = test_epoch(model_tell, test_loader, device)\n",
    "                n_w =  (layers[l].weight>1e-4).sum().item()\n",
    "                # Accept a strictly better situation, or a sparser layer that keeps at\n",
    "                # least 99% of the best validation accuracy.\n",
    "                if (val_acc, -n_w) > best_situation or (-n_w > best_situation[1] and val_acc >= 0.99*best_situation[0]):\n",
    "                    best_weights = clone_model(model_tell)\n",
    "                    patience = max_patience\n",
    "                    best_situation = (val_acc, -n_w)\n",
    "                patience -= 1\n",
    "                if patience == 0:\n",
    "                    break\n",
    "            # Continue from the best snapshot; the layer list must point into the new copy.\n",
    "            model_tell = clone_model(best_weights)\n",
    "            layers = [model_tell.fc, *[l.nn[0] for l in model_tell.convs[::-1]]]\n",
    "\n",
    "        final_weights = [(layer.weight>1e-4).sum().item() for layer in layers]\n",
    "        final_acc = test_epoch(model_tell, test_loader, device)\n",
    "        final_res[dataset_name].append([initial_acc, initial_weights, final_acc, final_weights])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cea21b74-5f77-4d3f-a756-3e0888312e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "final_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8c5f5bb-7b46-44e2-9072-0b754f9a99fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_acc, initial_weights, final_acc, final_weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86810e2e-8046-484c-82fb-6db29d05ec2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "{'Ba2Motifs': [[1.0,\n",
    "   [192, 1088, 870, 1008, 1000, 270],\n",
    "   1.0,\n",
    "   [2, 1043, 12, 16, 375, 10]],\n",
    "  [1.0, [162, 1065, 862, 959, 1087, 324], 1.0, [2, 65, 209, 604, 28, 18]],\n",
    "  [1.0, [218, 1171, 908, 1026, 980, 266], 1.0, [2, 33, 18, 955, 531, 37]],\n",
    "  [1.0, [90, 850, 759, 931, 845, 223], 1.0, [2, 38, 52, 44, 38, 37]],\n",
    "  [1.0, [27, 865, 817, 867, 774, 179], 1.0, [2, 38, 47, 53, 34, 56]]],\n",
    " 'BaMultiShapes': [[1.0,\n",
    "   [160, 1192, 883, 1008, 1175, 276],\n",
    "   0.97,\n",
    "   [12, 1192, 31, 971, 1041, 210]],\n",
    "  [1.0,\n",
    "   [147, 1052, 986, 1028, 1012, 262],\n",
    "   0.98,\n",
    "   [147, 357, 986, 1028, 982, 147]],\n",
    "  [1.0, [203, 1094, 922, 980, 1125, 278], 1.0, [6, 1094, 18, 455, 1125, 184]],\n",
    "  [1.0, [145, 939, 926, 996, 1044, 296], 1.0, [14, 425, 926, 996, 726, 186]],\n",
    "  [1.0,\n",
    "   [255, 1080, 824, 1053, 1098, 292],\n",
    "   1.0,\n",
    "   [7, 1080, 792, 464, 924, 253]]],\n",
    " 'MUTAG': [[0.8421052631578947,\n",
    "   [110, 543, 714, 542, 696, 203],\n",
    "   0.8947368421052632,\n",
    "   [26, 1, 54, 46, 165, 69]],\n",
    "  [0.8421052631578947,\n",
    "   [78, 359, 386, 320, 587, 224],\n",
    "   0.8421052631578947,\n",
    "   [28, 1, 386, 320, 238, 224]],\n",
    "  [0.9473684210526315,\n",
    "   [69, 563, 498, 702, 543, 248],\n",
    "   0.7894736842105263,\n",
    "   [50, 32, 44, 555, 543, 57]],\n",
    "  [0.8421052631578947,\n",
    "   [47, 206, 377, 268, 787, 210],\n",
    "   0.8421052631578947,\n",
    "   [34, 3, 31, 6, 100, 67]],\n",
    "  [0.8947368421052632,\n",
    "   [65, 278, 518, 390, 687, 257],\n",
    "   0.8947368421052632,\n",
    "   [64, 5, 21, 3, 134, 251]]],\n",
    " 'BBBP': [[0.8780487804878049,\n",
    "   [94, 763, 792, 723, 824, 305],\n",
    "   0.848780487804878,\n",
    "   [14, 0, 5, 723, 824, 90]],\n",
    "  [0.8682926829268293,\n",
    "   [152, 7, 701, 873, 910, 332],\n",
    "   0.8585365853658536,\n",
    "   [152, 0, 13, 9, 910, 332]],\n",
    "  [0.9024390243902439,\n",
    "   [122, 1, 755, 894, 869, 325],\n",
    "   0.8439024390243902,\n",
    "   [122, 0, 755, 894, 5, 325]],\n",
    "  [0.8926829268292683,\n",
    "   [113, 4, 686, 372, 935, 301],\n",
    "   0.8829268292682927,\n",
    "   [113, 0, 0, 206, 935, 301]],\n",
    "  [0.8829268292682927,\n",
    "   [100, 2, 3, 577, 1028, 313],\n",
    "   0.8780487804878049,\n",
    "   [14, 0, 0, 577, 1028, 313]]]}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35a91cee-8c9e-4fd9-a92e-5c0d731d4b66",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-imp]",
   "language": "python",
   "name": "conda-env-.conda-imp-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
