{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import cv2\n",
    "import torch.nn.functional as F\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "outputs": [],
   "source": [
    "# Category .npz archives to scan; get_data reads them from ./dataset\n",
    "category = [name + '.npz' for name in ('cat', 'fish', 'angel', 'car', 'flower')]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# allow_pickle=True is presumably needed because each split is stored as an\n",
    "# object array of variable-length sketches; only enable it for trusted files.\n",
    "cats = np.load('./dataset/cat.npz',allow_pickle=True,encoding='latin1')"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# First sketch of the cat training split\n",
    "cat1 = cats['train'][0]"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[  -3,   -8,    0],\n       [   2,  -38,    0],\n       [  14,    8,    0],\n       [  20,   26,    0],\n       [  22,  -14,    0],\n       [  10,   -3,    0],\n       [  51,    0,    0],\n       [  31,   -5,    0],\n       [  12,    3,    0],\n       [   5,   -6,    0],\n       [   7,  -21,    0],\n       [   5,   -3,    0],\n       [   9,   31,    0],\n       [   0,   91,    0],\n       [  -2,    8,    0],\n       [ -12,   11,    0],\n       [ -44,   22,    0],\n       [ -51,    2,    0],\n       [ -17,   -4,    0],\n       [ -11,   -5,    0],\n       [ -42,  -31,    0],\n       [ -11,  -13,    0],\n       [  -5,  -15,    0],\n       [   0,  -15,    0],\n       [   3,  -14,    0],\n       [   9,  -18,    1],\n       [  44,   41,    0],\n       [   8,   24,    1],\n       [  83,  -49,    0],\n       [   2,   36,    0],\n       [   5,    9,    1],\n       [ -74,   15,    0],\n       [  43,   -3,    1],\n       [ -18,    0,    0],\n       [  -1,   36,    1],\n       [ -42, -181,    0],\n       [   1,  -36,    0],\n       [  13,    1,    0],\n       [  22,   14,    0],\n       [  13,    3,    0],\n       [  -3,  -34,    0],\n       [   3,   25,    0],\n       [   9,   14,    0],\n       [  10,    6,    1],\n       [   5,  -28,    0],\n       [  16,  -13,    0],\n       [   3,   -8,    0],\n       [  -9,    0,    0],\n       [  -9,    9,    0],\n       [  -3,    8,    0],\n       [  10,   12,    0],\n       [  25,    1,    0],\n       [   5,   -3,    1],\n       [   2,  -31,    0],\n       [  -7,   12,    0],\n       [   4,   13,    0],\n       [   9,    0,    0],\n       [   8,   -6,    0],\n       [   6,  -12,    0],\n       [  -4,  -11,    0],\n       [  -9,   -3,    1],\n       [  28,   -9,    0],\n       [  12,   31,    0],\n       [   7,  -34,    0],\n       [  22,   25,    0],\n       [   2,  -46,    1],\n       [-208,  273,    0],\n       [ -13,   45,    0],\n       [  -3,   41,    0],\n       [   5,   27,    0],\n       [  20,   
38,    1],\n       [ 122, -180,    0],\n       [   2,   35,    0],\n       [  12,   76,    0],\n       [   2,   50,    1],\n       [-162,  -83,    0],\n       [ -34, -126,    0],\n       [  -6,   -6,    0],\n       [ -16,    4,    0],\n       [ -16,   29,    0],\n       [  -1,   41,    0],\n       [  18,   61,    0],\n       [  16,   31,    0],\n       [  33,   44,    1],\n       [ -70, -102,    0],\n       [  12,    1,    0],\n       [  27,  -13,    0],\n       [  11,   -9,    0],\n       [   9,  -18,    1],\n       [  65,   63,    0],\n       [   2,   55,    0],\n       [   7,   29,    1],\n       [  49, -106,    0],\n       [  -3,   12,    0],\n       [   0,   66,    0],\n       [   3,   21,    1]], dtype=int16)"
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw sketch rows; the first two columns are presumably per-point offsets\n",
    "# (they are cumulatively summed into absolute coordinates further down).\n",
    "cat1"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "def sketch_normalize(sketch):\n",
    "    '''\n",
    "    Scale absolute sketch coordinates into [-1, 1], keeping the aspect ratio.\n",
    "\n",
    "    :param sketch: array-like of shape (points, C >= 2); columns 0 and 1 are\n",
    "                   absolute x and y, any further column (e.g. the pen state)\n",
    "                   is copied unchanged\n",
    "    :return: new float32 array in which x and y are shifted so their minimum\n",
    "             maps to -1 and both are divided by the larger of the two extents,\n",
    "             so the longer axis spans exactly [-1, 1]\n",
    "    '''\n",
    "    # np.array always copies, so the caller's array is never modified.\n",
    "    n_sketch = np.array(sketch, dtype=np.float32)\n",
    "    w_min, w_max = n_sketch[:, 0].min(), n_sketch[:, 0].max()\n",
    "    h_min, h_max = n_sketch[:, 1].min(), n_sketch[:, 1].max()\n",
    "\n",
    "    scale = max(w_max - w_min, h_max - h_min)\n",
    "    # A degenerate sketch (one point, or all points identical) has zero extent;\n",
    "    # without this guard the division below fills the result with NaN.\n",
    "    if scale == 0:\n",
    "        scale = 1.0\n",
    "\n",
    "    n_sketch[:, 0] = (n_sketch[:, 0] - w_min) / scale * 2 - 1\n",
    "    n_sketch[:, 1] = (n_sketch[:, 1] - h_min) / scale * 2 - 1\n",
    "    return n_sketch"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[-0.75      , -0.1745283 ,  0.        ],\n       [-0.740566  , -0.3537736 ,  0.        ],\n       [-0.6745283 , -0.3160377 ,  0.        ],\n       [-0.5801887 , -0.19339621,  0.        ],\n       [-0.4764151 , -0.25943398,  0.        ],\n       [-0.4292453 , -0.2735849 ,  0.        ],\n       [-0.18867922, -0.2735849 ,  0.        ],\n       [-0.04245281, -0.2971698 ,  0.        ],\n       [ 0.01415098, -0.2830189 ,  0.        ],\n       [ 0.03773582, -0.31132078,  0.        ],\n       [ 0.07075477, -0.41037738,  0.        ],\n       [ 0.09433961, -0.4245283 ,  0.        ],\n       [ 0.13679242, -0.2783019 ,  0.        ],\n       [ 0.13679242,  0.1509434 ,  0.        ],\n       [ 0.12735844,  0.18867922,  0.        ],\n       [ 0.07075477,  0.24056602,  0.        ],\n       [-0.13679248,  0.3443396 ,  0.        ],\n       [-0.3773585 ,  0.3537736 ,  0.        ],\n       [-0.4575472 ,  0.33490562,  0.        ],\n       [-0.509434  ,  0.31132078,  0.        ],\n       [-0.7075472 ,  0.16509438,  0.        ],\n       [-0.759434  ,  0.10377359,  0.        ],\n       [-0.7830189 ,  0.03301883,  0.        ],\n       [-0.7830189 , -0.03773582,  0.        ],\n       [-0.7688679 , -0.10377359,  0.        ],\n       [-0.7264151 , -0.18867922,  1.        ],\n       [-0.5188679 ,  0.00471699,  0.        ],\n       [-0.4811321 ,  0.11792457,  1.        ],\n       [-0.08962262, -0.11320752,  0.        ],\n       [-0.08018869,  0.05660379,  0.        ],\n       [-0.05660379,  0.0990566 ,  1.        ],\n       [-0.4056604 ,  0.16981137,  0.        ],\n       [-0.2028302 ,  0.15566039,  1.        ],\n       [-0.28773582,  0.15566039,  0.        ],\n       [-0.2924528 ,  0.32547164,  1.        ],\n       [-0.49056602, -0.5283019 ,  0.        ],\n       [-0.48584908, -0.6981132 ,  0.        ],\n       [-0.4245283 , -0.6933962 ,  0.        ],\n       [-0.3207547 , -0.6273585 ,  0.        ],\n       [-0.25943398, -0.6132076 ,  0.        
],\n       [-0.2735849 , -0.7735849 ,  0.        ],\n       [-0.25943398, -0.6556604 ,  0.        ],\n       [-0.21698111, -0.5896226 ,  0.        ],\n       [-0.16981131, -0.5613208 ,  1.        ],\n       [-0.1462264 , -0.6933962 ,  0.        ],\n       [-0.07075471, -0.754717  ,  0.        ],\n       [-0.05660379, -0.7924528 ,  0.        ],\n       [-0.0990566 , -0.7924528 ,  0.        ],\n       [-0.14150941, -0.75      ,  0.        ],\n       [-0.15566039, -0.7122642 ,  0.        ],\n       [-0.10849059, -0.6556604 ,  0.        ],\n       [ 0.00943398, -0.6509434 ,  0.        ],\n       [ 0.03301883, -0.6650944 ,  1.        ],\n       [ 0.04245281, -0.8113208 ,  0.        ],\n       [ 0.00943398, -0.754717  ,  0.        ],\n       [ 0.02830184, -0.6933962 ,  0.        ],\n       [ 0.07075477, -0.6933962 ,  0.        ],\n       [ 0.10849059, -0.7216981 ,  0.        ],\n       [ 0.13679242, -0.7783019 ,  0.        ],\n       [ 0.11792457, -0.8301887 ,  0.        ],\n       [ 0.07547164, -0.8443396 ,  1.        ],\n       [ 0.20754719, -0.8867924 ,  0.        ],\n       [ 0.26415098, -0.740566  ,  0.        ],\n       [ 0.2971698 , -0.9009434 ,  0.        ],\n       [ 0.4009434 , -0.7830189 ,  0.        ],\n       [ 0.41037738, -1.        ,  1.        ],\n       [-0.5707547 ,  0.28773582,  0.        ],\n       [-0.6320754 ,  0.5       ,  0.        ],\n       [-0.6462264 ,  0.6933962 ,  0.        ],\n       [-0.6226415 ,  0.82075477,  0.        ],\n       [-0.5283019 ,  1.        ,  1.        ],\n       [ 0.0471698 ,  0.1509434 ,  0.        ],\n       [ 0.05660379,  0.31603777,  0.        ],\n       [ 0.11320758,  0.67452836,  0.        ],\n       [ 0.12264156,  0.9103774 ,  1.        ],\n       [-0.6415094 ,  0.51886797,  0.        ],\n       [-0.8018868 , -0.0754717 ,  0.        ],\n       [-0.8301887 , -0.10377359,  0.        ],\n       [-0.9056604 , -0.08490568,  0.        ],\n       [-0.9811321 ,  0.0518868 ,  0.        
],\n       [-0.9858491 ,  0.24528301,  0.        ],\n       [-0.9009434 ,  0.5330188 ,  0.        ],\n       [-0.8254717 ,  0.67924523,  0.        ],\n       [-0.6698113 ,  0.8867924 ,  1.        ],\n       [-1.        ,  0.4056604 ,  0.        ],\n       [-0.9433962 ,  0.41037738,  0.        ],\n       [-0.8160377 ,  0.3490566 ,  0.        ],\n       [-0.764151  ,  0.3066038 ,  0.        ],\n       [-0.7216981 ,  0.22169816,  1.        ],\n       [-0.41509432,  0.51886797,  0.        ],\n       [-0.4056604 ,  0.77830184,  0.        ],\n       [-0.3726415 ,  0.9150944 ,  1.        ],\n       [-0.14150941,  0.41509438,  0.        ],\n       [-0.15566039,  0.47169816,  0.        ],\n       [-0.15566039,  0.7830188 ,  0.        ],\n       [-0.14150941,  0.8820754 ,  1.        ]], dtype=float32)"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Offsets -> absolute coordinates, then scale into [-1, 1]\n",
    "cat_abs = cat1.copy()\n",
    "cat_abs[:, :2] = cat1[:, :2].cumsum(axis=0)\n",
    "n_cat1 = sketch_normalize(cat_abs)\n",
    "n_cat1"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "outputs": [],
   "source": [
    "def draw_sketch(sketch, size=256, thickness=6, color=(1, 1, 1)):\n",
    "    '''\n",
    "    Rasterize a normalized sketch onto a black canvas.\n",
    "\n",
    "    :param sketch: array of shape (points, 3) with columns (x, y, state), x and y\n",
    "                   in [-1, 1]. A state of 1 ends a stroke (no segment is drawn\n",
    "                   to the following point); any state other than 0 or 1, such as\n",
    "                   the -1 padding written by points2stroke, stops drawing.\n",
    "    :param size: canvas width and height in pixels\n",
    "    :param thickness: line thickness in pixels\n",
    "    :param color: line color passed to cv2.line; the canvas is float64, so the\n",
    "                  default (1, 1, 1) displays as white with cv2.imshow\n",
    "    :return: float64 canvas of shape (size, size, 3)\n",
    "    '''\n",
    "    canvas = np.zeros([size, size, 3])\n",
    "    if len(sketch) == 0:\n",
    "        return canvas\n",
    "\n",
    "    pts = np.array(sketch, dtype=np.float32)\n",
    "    # Map [-1, 1] onto pixel indices [0, size - 1]; scaling by size would place\n",
    "    # points at +1 one pixel outside the canvas.\n",
    "    pts[:, :2] = (pts[:, :2] + 1) / 2 * (size - 1)\n",
    "\n",
    "    x1, y1, prev_state = pts[0]\n",
    "    for x2, y2, state in pts[1:]:\n",
    "        if state != 0 and state != 1:\n",
    "            break\n",
    "        # Connect consecutive points only while the pen is down.\n",
    "        if prev_state != 1:\n",
    "            cv2.line(canvas, (int(x1), int(y1)), (int(x2), int(y2)),\n",
    "                     color=color, thickness=thickness)\n",
    "        x1, y1, prev_state = x2, y2, state\n",
    "\n",
    "    return canvas"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "outputs": [
    {
     "data": {
      "text/plain": "-1"
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "img = draw_sketch(n_cat1)\n",
    "cv2.imshow(\"test\", img)\n",
    "# waitKey(0) blocks until a key is pressed; close the window afterwards so it\n",
    "# does not linger attached to the kernel.\n",
    "key = cv2.waitKey(0)\n",
    "cv2.destroyAllWindows()\n",
    "key"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "outputs": [],
   "source": [
    "# Default limits, kept for callers that want them; the functions below take\n",
    "# their limits as explicit arguments.\n",
    "points_per_stroke_th  = 64\n",
    "strokes_per_sketch_th  = 48\n",
    "\n",
    "\n",
    "def points2stroke(sketch, points_ps_num, stroke_num):\n",
    "    '''\n",
    "    Split a point sequence into a fixed-size stroke matrix padded with -1.\n",
    "\n",
    "    :param sketch: array of shape (points, 3) with columns (x, y, state); a\n",
    "                   state > 0 marks the last point of a stroke\n",
    "    :param points_ps_num: max points kept per stroke (longer strokes are truncated)\n",
    "    :param stroke_num: max strokes kept per sketch (further strokes are dropped)\n",
    "    :return: (n_sketch, stroke_length). n_sketch is float32 with shape\n",
    "             (stroke_num, points_ps_num*3); row i holds stroke i flattened as\n",
    "             x0, y0, s0, x1, y1, s1, ... followed by -1 padding.\n",
    "             stroke_length[i] is the number of points stored in row i and has\n",
    "             one entry per stored stroke.\n",
    "    '''\n",
    "    n_sketch = -np.ones([stroke_num, points_ps_num * 3], dtype=np.float32)\n",
    "\n",
    "    # Index of the last point of every stroke. Points after the final pen-up\n",
    "    # (or a sketch with no pen-up at all) still form a stroke of their own.\n",
    "    ends = list(np.where(sketch[:, 2] > 0)[0])\n",
    "    if len(sketch) and (not ends or ends[-1] != len(sketch) - 1):\n",
    "        ends.append(len(sketch) - 1)\n",
    "    ends = ends[:stroke_num]\n",
    "\n",
    "    stroke_length = np.zeros(len(ends), dtype=np.intp)\n",
    "    start = 0\n",
    "    for i, end in enumerate(ends):\n",
    "        length = min(end + 1 - start, points_ps_num)\n",
    "        n_sketch[i, :length * 3] = sketch[start:start + length].reshape(-1)\n",
    "        stroke_length[i] = length\n",
    "        start = end + 1\n",
    "\n",
    "    return n_sketch, stroke_length\n",
    "\n",
    "\n",
    "def stroke2points(sketch, stroke_len):\n",
    "    '''\n",
    "    Inverse of points2stroke: concatenate the valid points of each stroke row.\n",
    "\n",
    "    :param sketch: stroke matrix of shape (strokes, points_per_stroke*3)\n",
    "    :param stroke_len: number of valid points in each row; entries beyond the\n",
    "                       number of rows are ignored\n",
    "    :return: array of shape (total points, 3)\n",
    "    '''\n",
    "    parts = [row[:length * 3].reshape(-1, 3)\n",
    "             for row, length in zip(sketch, stroke_len)]\n",
    "    if not parts:\n",
    "        return np.empty((0, 3), dtype=sketch.dtype)\n",
    "    return np.concatenate(parts, axis=0)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "outputs": [],
   "source": [
    "# Pack the normalized cat into a (50, 64*3) stroke matrix for a round-trip check\n",
    "test_stroke, test_len = points2stroke(n_cat1, 64, 50)"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[-0.75      , -0.1745283 ,  0.        , ..., -1.        ,\n        -1.        , -1.        ],\n       [-0.5188679 ,  0.00471699,  0.        , ..., -1.        ,\n        -1.        , -1.        ],\n       [-0.08962262, -0.11320752,  0.        , ..., -1.        ,\n        -1.        , -1.        ],\n       ...,\n       [-1.        , -1.        , -1.        , ..., -1.        ,\n        -1.        , -1.        ],\n       [-1.        , -1.        , -1.        , ..., -1.        ,\n        -1.        , -1.        ],\n       [-1.        , -1.        , -1.        , ..., -1.        ,\n        -1.        , -1.        ]], dtype=float32)"
     },
     "execution_count": 103,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One row per stroke, padded with -1\n",
    "test_stroke"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "outputs": [
    {
     "data": {
      "text/plain": "array([[-0.75      , -0.1745283 ,  0.        ],\n       [-0.740566  , -0.3537736 ,  0.        ],\n       [-0.6745283 , -0.3160377 ,  0.        ],\n       [-0.5801887 , -0.19339621,  0.        ],\n       [-0.4764151 , -0.25943398,  0.        ],\n       [-0.4292453 , -0.2735849 ,  0.        ],\n       [-0.18867922, -0.2735849 ,  0.        ],\n       [-0.04245281, -0.2971698 ,  0.        ],\n       [ 0.01415098, -0.2830189 ,  0.        ],\n       [ 0.03773582, -0.31132078,  0.        ],\n       [ 0.07075477, -0.41037738,  0.        ],\n       [ 0.09433961, -0.4245283 ,  0.        ],\n       [ 0.13679242, -0.2783019 ,  0.        ],\n       [ 0.13679242,  0.1509434 ,  0.        ],\n       [ 0.12735844,  0.18867922,  0.        ],\n       [ 0.07075477,  0.24056602,  0.        ],\n       [-0.13679248,  0.3443396 ,  0.        ],\n       [-0.3773585 ,  0.3537736 ,  0.        ],\n       [-0.4575472 ,  0.33490562,  0.        ],\n       [-0.509434  ,  0.31132078,  0.        ],\n       [-0.7075472 ,  0.16509438,  0.        ],\n       [-0.759434  ,  0.10377359,  0.        ],\n       [-0.7830189 ,  0.03301883,  0.        ],\n       [-0.7830189 , -0.03773582,  0.        ],\n       [-0.7688679 , -0.10377359,  0.        ],\n       [-0.7264151 , -0.18867922,  1.        ],\n       [-0.5188679 ,  0.00471699,  0.        ],\n       [-0.4811321 ,  0.11792457,  1.        ],\n       [-0.08962262, -0.11320752,  0.        ],\n       [-0.08018869,  0.05660379,  0.        ],\n       [-0.05660379,  0.0990566 ,  1.        ],\n       [-0.4056604 ,  0.16981137,  0.        ],\n       [-0.2028302 ,  0.15566039,  1.        ],\n       [-0.28773582,  0.15566039,  0.        ],\n       [-0.2924528 ,  0.32547164,  1.        ],\n       [-0.49056602, -0.5283019 ,  0.        ],\n       [-0.48584908, -0.6981132 ,  0.        ],\n       [-0.4245283 , -0.6933962 ,  0.        ],\n       [-0.3207547 , -0.6273585 ,  0.        ],\n       [-0.25943398, -0.6132076 ,  0.        
],\n       [-0.2735849 , -0.7735849 ,  0.        ],\n       [-0.25943398, -0.6556604 ,  0.        ],\n       [-0.21698111, -0.5896226 ,  0.        ],\n       [-0.16981131, -0.5613208 ,  1.        ],\n       [-0.1462264 , -0.6933962 ,  0.        ],\n       [-0.07075471, -0.754717  ,  0.        ],\n       [-0.05660379, -0.7924528 ,  0.        ],\n       [-0.0990566 , -0.7924528 ,  0.        ],\n       [-0.14150941, -0.75      ,  0.        ],\n       [-0.15566039, -0.7122642 ,  0.        ],\n       [-0.10849059, -0.6556604 ,  0.        ],\n       [ 0.00943398, -0.6509434 ,  0.        ],\n       [ 0.03301883, -0.6650944 ,  1.        ],\n       [ 0.04245281, -0.8113208 ,  0.        ],\n       [ 0.00943398, -0.754717  ,  0.        ],\n       [ 0.02830184, -0.6933962 ,  0.        ],\n       [ 0.07075477, -0.6933962 ,  0.        ],\n       [ 0.10849059, -0.7216981 ,  0.        ],\n       [ 0.13679242, -0.7783019 ,  0.        ],\n       [ 0.11792457, -0.8301887 ,  0.        ],\n       [ 0.07547164, -0.8443396 ,  1.        ],\n       [ 0.20754719, -0.8867924 ,  0.        ],\n       [ 0.26415098, -0.740566  ,  0.        ],\n       [ 0.2971698 , -0.9009434 ,  0.        ],\n       [ 0.4009434 , -0.7830189 ,  0.        ],\n       [ 0.41037738, -1.        ,  1.        ],\n       [-0.5707547 ,  0.28773582,  0.        ],\n       [-0.6320754 ,  0.5       ,  0.        ],\n       [-0.6462264 ,  0.6933962 ,  0.        ],\n       [-0.6226415 ,  0.82075477,  0.        ],\n       [-0.5283019 ,  1.        ,  1.        ],\n       [ 0.0471698 ,  0.1509434 ,  0.        ],\n       [ 0.05660379,  0.31603777,  0.        ],\n       [ 0.11320758,  0.67452836,  0.        ],\n       [ 0.12264156,  0.9103774 ,  1.        ],\n       [-0.6415094 ,  0.51886797,  0.        ],\n       [-0.8018868 , -0.0754717 ,  0.        ],\n       [-0.8301887 , -0.10377359,  0.        ],\n       [-0.9056604 , -0.08490568,  0.        ],\n       [-0.9811321 ,  0.0518868 ,  0.        
],\n       [-0.9858491 ,  0.24528301,  0.        ],\n       [-0.9009434 ,  0.5330188 ,  0.        ],\n       [-0.8254717 ,  0.67924523,  0.        ],\n       [-0.6698113 ,  0.8867924 ,  1.        ],\n       [-1.        ,  0.4056604 ,  0.        ],\n       [-0.9433962 ,  0.41037738,  0.        ],\n       [-0.8160377 ,  0.3490566 ,  0.        ],\n       [-0.764151  ,  0.3066038 ,  0.        ],\n       [-0.7216981 ,  0.22169816,  1.        ],\n       [-0.41509432,  0.51886797,  0.        ],\n       [-0.4056604 ,  0.77830184,  0.        ],\n       [-0.3726415 ,  0.9150944 ,  1.        ],\n       [-0.14150941,  0.41509438,  0.        ],\n       [-0.15566039,  0.47169816,  0.        ],\n       [-0.15566039,  0.7830188 ,  0.        ],\n       [-0.14150941,  0.8820754 ,  1.        ]], dtype=float32)"
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "restored = stroke2points(test_stroke, test_len)\n",
    "# The round trip reproduces the normalized sketch exactly as long as no stroke\n",
    "# was truncated by the per-stroke / per-sketch limits used above.\n",
    "assert np.array_equal(restored, n_cat1)\n",
    "restored"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "outputs": [],
   "source": [
    "def _stroke_stats(sketch):\n",
    "    '''\n",
    "    Return (points in the longest stroke, number of strokes) for one sketch.\n",
    "\n",
    "    A stroke ends at every point whose state column (index 2) is > 0;\n",
    "    a sketch without any such point yields (0, 0).\n",
    "    '''\n",
    "    stroke_idx = np.where(sketch[:, 2] > 0)[0]\n",
    "    if len(stroke_idx) == 0:\n",
    "        return 0, 0\n",
    "    stroke_length = stroke_idx - np.concatenate([[-1], stroke_idx[:-1]], axis=0)\n",
    "    return int(stroke_length.max()), len(stroke_length)\n",
    "\n",
    "\n",
    "def get_data(category, path, splits=('train', 'test', 'valid')):\n",
    "    '''\n",
    "    Scan category .npz files and report stroke statistics.\n",
    "\n",
    "    :param category: iterable of .npz file names, e.g. 'cat.npz'\n",
    "    :param path: directory containing the files\n",
    "    :param splits: keys of each .npz archive to scan\n",
    "    :return: (stroke_max, stroke_len): the most points in any single stroke\n",
    "             and the most strokes in any single sketch. Each file name and\n",
    "             both values are also printed.\n",
    "    '''\n",
    "    stroke_max = 0\n",
    "    stroke_len = 0\n",
    "    for c in category:\n",
    "        print(c)\n",
    "        # allow_pickle is presumably required because each split is an object\n",
    "        # array of variable-length sketches; only load trusted files.\n",
    "        # The with-block closes the underlying archive once it is scanned.\n",
    "        with np.load(os.path.join(path, c), allow_pickle=True, encoding='latin1') as npz_file:\n",
    "            for split in splits:\n",
    "                for sketch in npz_file[split]:\n",
    "                    longest, count = _stroke_stats(sketch)\n",
    "                    stroke_max = max(stroke_max, longest)\n",
    "                    stroke_len = max(stroke_len, count)\n",
    "    print(stroke_max)\n",
    "    print(stroke_len)\n",
    "    # TODO: collecting the sketches/labels per split (train/test/valid sets)\n",
    "    # is not implemented; only the statistics above are gathered.\n",
    "    return stroke_max, stroke_len"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cat.npz\n",
      "fish.npz\n",
      "angel.npz\n",
      "car.npz\n",
      "146\n",
      "48\n"
     ]
    }
   ],
   "source": [
    "# Scan the selected categories and print the longest stroke (in points) and\n",
    "# the largest stroke count found in any single sketch.\n",
    "#category_all = os.listdir(\"D:\\\\SketchHealer\\\\quickdraw\\\\quickdraw\\\\\")\n",
    "get_data(category, \"./dataset\")"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# NOTE(review): these maxima differ from the printed get_data statistics\n",
    "# (146 points / 48 strokes); presumably measured over a different category\n",
    "# set -- confirm.\n",
    "\"max: 321 points per stroke\"\n",
    "\"max: 178 stroke per sketch\""
   ],
   "metadata": {
    "collapsed": false
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
