{
 "nbformat": 4,
 "nbformat_minor": 4,
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Yandex DataSphere Kernel",
   "language": "python"
  },
  "language_info": {
   "file_extension": ".py",
   "pygments_lexer": "ipython3",
   "mimetype": "text/x-python",
   "name": "python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "nbconvert_exporter": "python",
   "version": "3.7.7"
  },
  "notebookId": "69b2b8bb-7d53-4bca-9869-12bdd5b9d61a"
 },
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Score evaluation and bootstrapping on CIFAR\n",
    "\n",
    "Loads a WideResNet checkpoint, scores the first two loaders returned by `code.get_data_cifar()` with five methods\n",
    "(`msp`, `ml`, `s_part`, `g_part`, `sg_part`) and passes each in/out score pair to `code.save_and_bootstrap`.\n",
    "\n",
    "The `#!L` / `#!g1.1` first line of each code cell is a DataSphere cell directive (it appears to select the compute configuration for that cell) and must stay on the first line."
   ]
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellId": "7j7c6jno5jgfdom5jxq5st"
   },
   "source": [
    "#!L\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim\n",
    "import torch.nn.functional as F\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.models as models\n",
    "\n",
    "import numpy as np"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellId": "1qim2v1qdoam3c35w8v2oo"
   },
   "source": [
    "#!L\n",
    "# Project-local modules.\n",
    "# NOTE(review): `code` has the same name as the stdlib `code` module. If the stdlib module is\n",
    "# already in sys.modules, `import code` returns it, and it is the reload below that picks up\n",
    "# the local code.py (assuming the notebook directory precedes the stdlib on sys.path).\n",
    "# Renaming the local module would remove this ambiguity.\n",
    "import code\n",
    "import wideresnet\n",
    "\n",
    "import importlib\n",
    "importlib.reload(code)\n",
    "importlib.reload(wideresnet);"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellId": "fv8p2lc43mp4o0u39qeel"
   },
   "source": [
    "#!g1.1\n",
    "def _score_and_save(path, name, score_fn, in_loader, out_loader, bootstrap_num):\n",
    "    \"\"\"Score the in/out loaders with `score_fn`, pass both to save_and_bootstrap, and report.\n",
    "\n",
    "    The in-loader is always scored before the out-loader, keeping a fixed call order\n",
    "    (and hence a fixed order of any RNG use inside the scorers).\n",
    "    \"\"\"\n",
    "    in_scores = score_fn(in_loader)\n",
    "    out_scores = score_fn(out_loader)\n",
    "    code.save_and_bootstrap(path, name, in_scores, out_scores, bootstrap_num=bootstrap_num)\n",
    "    print(f'{name} finished!')\n",
    "\n",
    "\n",
    "def evaluate_and_bootstrap(method_params, path, model_path='0.pth.tar', seed=228, bootstrap_num=100):\n",
    "    \"\"\"Load a WideResNet checkpoint, score the CIFAR loaders with five methods, and bootstrap each.\n",
    "\n",
    "    Methods, in order: 'msp' (code.baseline), 'ml' (code.max_logit), and three scores\n",
    "    derived from code.get_s_and_g: 's_part' (negated first output), 'g_part' (second\n",
    "    output) and 'sg_part' (element-wise product of both outputs).\n",
    "\n",
    "    Args:\n",
    "        method_params: dict with keys 'eps_s', 't_s', 'eps_g', 't_g', 'eps_sg', 't_sg';\n",
    "            each (eps, t) pair is forwarded to code.get_s_and_g for the matching method.\n",
    "        path: forwarded as the first argument of code.save_and_bootstrap for every method.\n",
    "        model_path: checkpoint file to load; it must contain a 'state_dict' entry.\n",
    "            Defaults to '0.pth.tar'.\n",
    "        seed: seed for the torch (CPU and current CUDA device) and numpy global RNGs.\n",
    "        bootstrap_num: forwarded to code.save_and_bootstrap.\n",
    "\n",
    "    Returns:\n",
    "        None. Results are handled by code.save_and_bootstrap (presumably written under\n",
    "        `path` -- TODO confirm); one progress line is printed per method.\n",
    "    \"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "\n",
    "    eps_s, t_s = method_params['eps_s'], method_params['t_s']\n",
    "    eps_g, t_g = method_params['eps_g'], method_params['t_g']\n",
    "    eps_sg, t_sg = method_params['eps_sg'], method_params['t_sg']\n",
    "\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "    net = wideresnet.WideResNet(28, 10, widen_factor=10)\n",
    "    # Load onto the CPU first: without map_location a checkpoint saved from a GPU cannot be\n",
    "    # loaded on a CPU-only machine, even though `device` falls back to 'cpu' above.\n",
    "    checkpoint = torch.load(model_path, map_location='cpu')\n",
    "    net.load_state_dict(checkpoint['state_dict'])\n",
    "    net = net.to(device)\n",
    "    net.eval()\n",
    "\n",
    "    # get_data_cifar() yields four loaders; the first pair is evaluated here. To evaluate\n",
    "    # the last pair instead, unpack as `_, _, in_loader_big_test, out_loader_big_test`.\n",
    "    in_loader_big_test, out_loader_big_test, _, _ = code.get_data_cifar()\n",
    "\n",
    "    def s_and_g(eps, t, loader):\n",
    "        return code.get_s_and_g(eps, t, loader, net, device, silent=True, mode='cifar')\n",
    "\n",
    "    def s_part(loader):\n",
    "        s, _ = s_and_g(eps_s, t_s, loader)\n",
    "        # Negated before saving (presumably so the score direction matches the other\n",
    "        # methods -- TODO confirm).\n",
    "        return -s\n",
    "\n",
    "    def g_part(loader):\n",
    "        _, g = s_and_g(eps_g, t_g, loader)\n",
    "        return g\n",
    "\n",
    "    def sg_part(loader):\n",
    "        s, g = s_and_g(eps_sg, t_sg, loader)\n",
    "        return s * g\n",
    "\n",
    "    methods = [\n",
    "        ('msp', lambda loader: code.baseline(loader, net, device, silent=True)),\n",
    "        ('ml', lambda loader: code.max_logit(loader, net, device, silent=True)),\n",
    "        ('s_part', s_part),\n",
    "        ('g_part', g_part),\n",
    "        ('sg_part', sg_part),\n",
    "    ]\n",
    "    for name, score_fn in methods:\n",
    "        _score_and_save(path, name, score_fn, in_loader_big_test, out_loader_big_test, bootstrap_num)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "metadata": {
    "cellId": "1tittf0ukogvooe1pkkqtc"
   },
   "source": [
    "#!g1.1\n",
    "# (eps, t) hyper-parameters for the s-, g- and s*g-based scores.\n",
    "params = {'eps_s': 0.0006, 't_s': 200.0, 'eps_g': 0.002, 't_g': 5.0, 'eps_sg': 0.002, 't_sg': 5.0}\n",
    "evaluate_and_bootstrap(params, 'cifar', model_path='0.pth.tar')"
   ],
   "outputs": [],
   "execution_count": null
  }
 ]
}