;; Directory that result files are written into; it is handed to `mnist-eval`
;; as `:attach-dir`, which joins it with the output file name.
(setv attach-dir ".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

(defmacro bound? [x]
  "Expand to an expression that is True when evaluating `x` succeeds and
  False when evaluating it raises NameError.

  `x` is actually evaluated by the check, so any side effects it has occur."
  `(try
     (do ~x True)
     (except [NameError] False)))

(defmacro default [x d]
  "Expand to the value of `x`, or to `d` when evaluating `x` raises NameError.

  `x` is evaluated exactly once, so side effects in `x` are not duplicated.
  `d` is only evaluated on the fallback path. The expansion is self-contained
  and does not require any other macro to be in scope at the use site."
  `(try ~x
        (except [NameError] ~d)))

(import [sklearn.preprocessing [normalize]])

(defn partial-flatten [x]
  "Collapse every axis of `x` after the first into one axis.

  The first axis keeps its length; NumPy infers the length of the second."
  (setv leading (get (np.shape x) 0))
  (np.reshape x (, leading -1)))

(defn mnist-data [[train-set "vanilla"] [test-set "vanilla"] [conv False]]
  "Load image and label arrays from the `../../mnist_c/<set>/` directories.

  train-set / test-set name the sub-directories the training and test files
  are read from. Images are squeezed, given an extra axis at position 3, and
  divided by float32 255. Unless `conv` is true they are also passed through
  `partial-flatten`. Labels are returned exactly as loaded.

  Returns the tuple (train-images, test-images, train-labels, test-labels)."
  (defn read-npy [path]
    ;; Open in binary mode and let NumPy parse the file contents.
    (with [handle (open path "rb")]
      (np.load handle)))
  (defn with-extra-axis [images]
    ;; Drop singleton axes, then insert a new axis at index 3.
    (np.expand-dims (np.squeeze images) 3))
  (defn rescale [images]
    (/ (if conv images (partial-flatten images))
       (np.float32 255)))
  ;; Files are read in the order: train images, train labels, test images,
  ;; test labels.
  (setv x-train (with-extra-axis (read-npy f"../../mnist_c/{train-set}/train_images.npy")))
  (setv y-train (read-npy f"../../mnist_c/{train-set}/train_labels.npy"))
  (setv x-test (with-extra-axis (read-npy f"../../mnist_c/{test-set}/test_images.npy")))
  (setv y-test (read-npy f"../../mnist_c/{test-set}/test_labels.npy"))
  (setv x-train (rescale x-train))
  (setv x-test (rescale x-test))
  (, x-train x-test y-train y-test))

(defn mnist-train-net [train-net input-shape [conv False] [optimizer None]]
  "Assemble the pieces needed to train a stax network.

  train-net   -- iterable of stax layers, combined with `stax.serial`.
  input-shape -- integer used to build the shape handed to the network's
                 init function: (-1, input-shape, input-shape, 1) when `conv`
                 is true, otherwise (-1, input-shape).
  optimizer   -- None for SGD with an exponentially decaying step size, or a
                 Hy form that is run through `hy.eval` and must yield an
                 (init, update, get-params) triple.

  Returns the tuple (net-apply, calc-loss, opt-update, opt-get,
  new-opt-state). net-apply and calc-loss are jitted; new-opt-state takes a
  PRNG key and returns a freshly initialised optimizer state."
  (setv [net-init net-apply] (stax.serial #* train-net))
  (setv net-apply (jax.jit net-apply))
  ;; NOTE: `hy.eval` executes whatever form it is given, so only trusted
  ;; forms may be passed as `optimizer`. It is called directly in this
  ;; function's frame so the form resolves names exactly as before.
  (setv [opt-init opt-update opt-get]
        (if (is-not optimizer None)
            (hy.eval optimizer)
            (optimizers.sgd
              :step-size (optimizers.exponential-decay
                           :step-size 1e-1
                           :decay-rate 0.99995
                           :decay-steps 1))))
  ;; Shape passed to the stax init function when creating parameters.
  (setv init-shape (if conv
                       (, -1 input-shape input-shape 1)
                       (, -1 input-shape)))
  (defn loss-fn [p x y [rng None]]
    ;; Cross-entropy between the network's logits for `x` and targets `y`.
    (nn-utils.ce-with-logits-loss (net-apply p x :rng rng) y))
  (defn fresh-opt-state [rng]
    ;; The init function's result is indexed at 1 to obtain the parameters.
    (setv init-result (net-init rng init-shape))
    (opt-init (get init-result 1)))
  (, net-apply (jax.jit loss-fn) opt-update opt-get fresh-opt-state))

(defn mnist-test-net [test-net]
  "Build the jitted forward pass and loss for evaluating a stax network.

  test-net is an iterable of stax layers combined with `stax.serial`; only
  the apply function of the combined network is used here.

  Returns the tuple (net-apply, calc-loss), both wrapped in `jax.jit`."
  (setv serial-net (stax.serial #* test-net))
  (setv net-apply (jax.jit (get serial-net 1)))
  (defn loss-fn [p x y [rng None]]
    ;; Cross-entropy between the network's logits for `x` and targets `y`.
    (nn-utils.ce-with-logits-loss (net-apply p x :rng rng) y))
  (, net-apply (jax.jit loss-fn)))
(defn mnist-eval [train-images test-images train-labels test-labels train-apply test-apply
                  calc-loss-test epochs opt-get opt-step new-opt-state jax-rng batch-size attach-dir
                  [fname "perf.npy"] [splitter None] [label-noise 0.0]
                  [show-progress True]]
  "Train once per split of the training data, print a summary, save metrics.

  train-images / train-labels -- data the splits are drawn from; labels are
      passed through `nn-utils.add-label-noise` with `label-noise` and then
      one-hot encoded with 10 classes.
  test-images / test-labels   -- data handed to the loss and accuracy
      trackers (labels one-hot encoded with 10 classes).
  train-apply / test-apply / calc-loss-test / opt-get -- callables forwarded
      to the `nn-utils` tracker constructors.
  epochs, opt-step, batch-size, show-progress -- forwarded to
      `nn-utils.train-model`.
  new-opt-state -- called with a PRNG key to create each split's initial
      optimizer state.
  jax-rng  -- PRNG key; split into one sub-key per split (the first key of
      the split result is discarded).
  attach-dir, fname -- directory and file name of the output file.
  splitter -- required despite the None default; must provide
      `get_n_splits()` and `split(X=..., y=...)`.

  Six arrays are written back-to-back into the output file, in this order:
  train-loss, train-acc, test-loss, test-acc, trace, det.

  Raises ValueError when `splitter` is None. Returns None."
  ;; Fail early with a clear message instead of an AttributeError on None
  ;; after trackers have already been set up.
  (when (is splitter None)
    (raise (ValueError "mnist-eval requires a splitter providing get_n_splits and split")))
  (setv train-labels-one-hot (jax.nn.one-hot (nn-utils.add-label-noise train-labels label-noise) 10)
        test-labels-one-hot (jax.nn.one-hot test-labels 10)
        metrics [(nn-utils.setup-loss-tracker
                   calc-loss-test test-images test-labels-one-hot
                   opt-get 100)
                 (nn-utils.setup-accuracy-tracker
                   test-images test-labels-one-hot test-apply
                   opt-get True 100)
                 (nn-utils.setup-trace-tracker
                   train-apply opt-get
                   100)
                 (nn-utils.setup-determinant-tracker
                   train-apply opt-get
                   100)])
  ;; One sub-key per split; the leading key from the split is not used.
  (setv [_ #* subrng] (jax.random.split jax-rng (inc (.get-n-splits splitter)))
        perf (lfor [i data] (enumerate (.split splitter :X train-images :y train-labels-one-hot))
                   (let [[train test] data
                         ;; Only the training indices are used; the held-out
                         ;; indices in `test` are ignored.
                         x-train (jnp.take train-images train :axis 0)
                         y-train (jnp.take train-labels-one-hot train :axis 0)
                         ;; NOTE(review): the same key seeds both parameter
                         ;; initialisation and training; JAX advises against
                         ;; reusing keys. Left unchanged to keep results
                         ;; identical -- confirm whether this is intended.
                         [_ fold-metrics]
                         (nn-utils.train-model
                           epochs (new-opt-state (get subrng i))
                           opt-step x-train y-train
                           :batch-size batch-size
                           :metrics metrics
                           :show-progress show-progress
                           :progress-pos 2
                           :jax-rng (get subrng i))]
                     ;; Same order as the tracker list: loss, accuracy,
                     ;; trace, determinant.
                     [(:metric (:state (get fold-metrics 0)))
                      (:metric (:state (get fold-metrics 1)))
                      (:metric (:state (get fold-metrics 2)))
                      (:metric (:state (get fold-metrics 3)))]))
        loss (np.array (lfor p perf (get p 0)))
        acc (np.array (lfor p perf (get p 1)))
        trace (np.array (lfor p perf (get p 2)))
        det (np.array (lfor p perf (get p 3)))
        ;; Presumably index 0 / 1 along axis 2 hold the train / test values
        ;; recorded by the trackers -- verify against nn-utils.
        train-loss (np.take loss 0 :axis 2)
        train-acc (np.take acc 0 :axis 2)
        test-loss (np.take loss 1 :axis 2)
        test-acc (np.take acc 1 :axis 2)
        ;; Mean and standard deviation across splits (axis 0).
        test-loss-mean (np.mean test-loss :axis 0)
        test-loss-std (np.std test-loss :axis 0)
        test-acc-mean (np.mean test-acc :axis 0)
        test-acc-std (np.std test-acc :axis 0)
        perf-file (os.path.join attach-dir fname))
  ;; The backslash is escaped explicitly; the printed text is unchanged.
  (print f"Test loss {(get test-loss-mean -1) :.4f} \\pm ({(get test-loss-std -1) :.4f})")
  (print f"Test accuracy {(get test-acc-mean -1) :.4f} \\pm ({(get test-acc-std -1)  :.4f})")
  (with [f (open perf-file "wb")]
    (np.save f train-loss)
    (np.save f train-acc)
    (np.save f test-loss)
    (np.save f test-acc)
    (np.save f trace)
    (np.save f det)))

(import [sklearn.model_selection [StratifiedShuffleSplit]])

(defn _mnist-conv-layers [n-out]
  "Return the stax layer list for the convolutional classifier.

  Two 3x3 conv + ReLU + 2x2 max-pool stages, three 2x2 conv + ReLU stages,
  a flatten, two 256-unit dense + ReLU layers, and a final dense layer with
  `n-out` units."
  (defn conv [chan size]
    (stax.Conv :out-chan chan :filter-shape (, size size) :padding "SAME"))
  (defn pool []
    (stax.MaxPool :window-shape (, 2 2) :strides (, 2 2) :padding "VALID"))
  [(conv 64 3) stax.Relu (pool)
   (conv 128 3) stax.Relu (pool)
   (conv 256 2) stax.Relu
   (conv 128 2) stax.Relu
   (conv 64 2) stax.Relu
   stax.Flatten
   (stax.Dense 256) stax.Relu
   (stax.Dense 256) stax.Relu
   (stax.Dense n-out)])

;; Number of units in the network's final dense layer.
(setv num-outputs 10)
;; Layer list shared by the training and evaluation network builders.
(setv net (_mnist-conv-layers num-outputs))

;; Experiment setup: train on the "motion_blur" training set, evaluate on the
;; default test set, keeping images unflattened (`:conv True`).
(setv [train-images test-images train-labels test-labels] (mnist-data :train-set "motion_blur" :conv True)
      ;; Length of axis 1 of the training images; `mnist-train-net` uses it
      ;; for both spatial dimensions of the init shape.
      input-shape (get (np.shape train-images) 1)
      [train-apply calc-loss-train opt-update opt-get new-opt-state] (mnist-train-net net input-shape :conv True)
      [test-apply calc-loss-test] (mnist-test-net net)
      ;; Penalty term that is always 0.0.
      penalty (constantly 0.0)
      opt-step (nn-utils.create-opt-step calc-loss-train penalty opt-update opt-get)
      ;; Five stratified shuffle splits; the training fraction is chosen so
      ;; that it corresponds to 6000 samples of the loaded training labels.
      ;; NOTE(review): no `random_state` is passed, so the splits presumably
      ;; differ between runs even though the JAX key below is fixed -- confirm
      ;; whether that is intended.
      splitter (StratifiedShuffleSplit :n-splits 5 :train-size (/ 6000 (np.size train-labels)))
      epochs 30)

;; Run the evaluation with 30% label noise; results are written to
;; `attach-dir` under the default file name of `mnist-eval`.
(mnist-eval train-images test-images train-labels test-labels
            train-apply test-apply calc-loss-test epochs
            opt-get opt-step new-opt-state
            :splitter splitter
            :jax-rng (jax.random.PRNGKey 0)
            :batch-size 32
            :attach-dir attach-dir
            :label-noise 0.3
            :show-progress False)
