;; Directory the rendered figure is written into.
(setv attach-dir ".")
;; Result directories, one per method compared in the plot. All four share
;; the "../robustness/" prefix (relative to the working directory).
(setv drop-dir "../robustness/dropout"
      wp-dir "../robustness/weight_penalty"
      lg-dir "../robustness/loss_gradient_parameter"
      mgs-dir "../robustness/mgs_trace")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; Expands to an expression that evaluates `x` and yields True, or yields
;; False when that evaluation raises NameError (e.g. `x` is an unbound
;; symbol). Any other exception propagates unchanged.
(defmacro bound? [x]
  `(try
     (do ~x True)
     (except [NameError] False)))

;; Expands to the value of `x`, or to `d` when evaluating `x` raises
;; NameError (e.g. `x` is an unbound symbol). `x` is evaluated at most once,
;; so side-effecting arguments are safe. `d` is evaluated only when the
;; fallback is taken.
(defmacro default [x d]
  `(try ~x
        (except [NameError] ~d)))

;; NOTE(review): within this file as shown, this `defmain` form is never
;; closed. The `)` at the end of the next line closes only the `setv`, and no
;; later line supplies the missing paren before end of file. Confirm whether
;; the rest of the script is meant to be the body of `main`, or whether a `)`
;; is missing here.
;; NOTE(review): `attach-dir` set here is presumably local to the generated
;; main function. If so, it does not affect the module-level `attach-dir`
;; used by `savefig` -- verify.
(defmain [[#** args]]
  (setv attach-dir ".")


(import [radar_factory [radar_factory]])

;; All tested settings
;; Every perturbation tested, as (spoke name, values swept, result-file name
;; template formatted with each value). Training-size values are fractions
;; of 60e3 (presumably the full MNIST training set -- confirm).
(setv settings
      [(, "Training size" (/ (np.array [10000 3000 1500 500]) 60e3) "robustness_training-size_{:.2f}.npy")
       (, "Label noise" (np.array [0.0 0.3 0.5 0.8]) "robustness_label-noise_{:.1f}.npy")
       (, "Batch size" (np.array [16 32 64 128]) "robustness_batch-size_{}.npy")
       (, "Learning rate" (np.array [5e-1 1e-1 1e-2 1e-3]) "robustness_learning-rate_{:.3f}.npy")
       (, "Epochs" (np.array [25 50 100 150]) "robustness_epochs_{}.npy")])

;; Methods as (legend label, matplotlib colour, results directory).
;; MGS is kept apart from the others because it is overlaid on every panel.
(setv mgs (, "MGS" "tab:red" mgs-dir))
(setv others [(, "Loss grad." "tab:green" lg-dir)
              (, "Weight" "tab:blue" wp-dir)
              (, "Dropout" "tab:brown" drop-dir)])

(defn calc-stats [adir vals fname [data 'test-acc] [metric 'np.mean] [window 5]]
  "Summarise one recorded metric across the result files of a sweep.

  For each value `v` in `vals`, open `(os.path.join adir (.format fname v))`
  and read six consecutive `np.load` records from it. They are bound, in
  order, to `train-loss`, `train-acc`, `test-loss`, `test-acc`, `trace` and
  `det`.

  `data` is a quoted form (default `'test-acc`) evaluated with `hy.eval`
  after loading, so it can name any of those six arrays. From that array the
  last `window` columns are taken and averaged along axis 1, giving one value
  per row.

  The per-file results are stacked with `np.array` and reduced by `metric`,
  a quoted form (default `'np.mean`) that `hy.eval` turns into a callable.
  The return value is whatever that callable returns.

  NOTE(review): the selected array must be at least 2-D (`fliplr`, `axis 1`).
  It is presumably (runs, epochs) -- confirm against the code that writes
  these files."
  (setv m (hy.eval metric))
  (-> (lfor v vals
            ;; `with` closes each result file once its six records are read.
            :do (with [f (open (os.path.join adir (.format fname v)) "rb")]
                  (setv train-loss (np.load f)
                        train-acc (np.load f)
                        test-loss (np.load f)
                        test-acc (np.load f)
                        trace (np.load f)
                        det (np.load f)))
            ;; `fliplr` + `take(arange(window))` selects the last `window`
            ;; columns (reversed, which does not affect the mean).
            (np.mean (np.take (np.fliplr (hy.eval data))
                              (np.arange window)
                              :axis 1)
                     :axis 1))
      (np.array)
      (m)))

(do
  ;; One radar panel per non-MGS method, five spokes (one per entry of
  ;; `settings`), with MGS overlaid on every panel for comparison.
  (setv theta (radar-factory 5 :frame "polygon")
        [fig axs] (plt.subplots 1 3 :figsize (, 20 7) :subplot-kw {"projection" "radar"}))

  (plt.subplots-adjust :hspace 0.2 :wspace 0.25 :left 0.03 :top 0.92 :bottom 0.03 :right 0.97)

  (defn _radar-stats [adir]
    "Return [median, upper quartile, lower quartile] for results in `adir`,
    each an array with one `calc-stats` value per entry of `settings`."
    (lfor metric ['np.median
                  '(fn [x] (np.quantile x 0.75))
                  '(fn [x] (np.quantile x 0.25))]
          (np.array (lfor [_ vals fname] settings
                          (calc-stats adir vals fname :metric metric)))))

  (defn _plot-quartiles [ax theta stats color]
    "Plot the median (solid) and both quartiles (dashed) from `stats` on `ax`."
    (setv [median upper lower] stats)
    (.plot ax theta median :color color :lw 3)
    (.plot ax theta upper :color color :lw 1.5 :ls "dashed")
    (.plot ax theta lower :color color :lw 1.5 :ls "dashed"))

  ;; MGS is identical on every panel, so its files are read only once.
  (setv c-mgs (get mgs 1)
        adir-mgs (get mgs 2)
        mgs-stats (_radar-stats adir-mgs))

  (for [[i [[m c adir] ax]] (enumerate (zip others (. axs flat)))]
    (.set-ylim ax [0.0 1.0])
    (.set-rgrids ax (, 0.2 0.4 0.6 0.8))
    (.set-yticklabels ax [])

    ;; Only the middle panel carries the spoke names and radial % labels.
    (when (= i 1)
      (.text ax (np.radians 15) 1.05 "Training size" :fontsize 20)
      (.text ax (np.radians 78) 1.6 "Label noise" :fontsize 20 :rotation -15)
      (.text ax (np.radians 127) 1.44 "Batch size" :fontsize 20)
      (.text ax (np.radians 216) 1.05 "Learning rate" :fontsize 20)
      (.text ax (np.radians 288) 1.0 "Epochs" :fontsize 20 :rotation 15)
      (for [[tick pos] (zip [20 40 60 80 100] [0.1 0.23 0.4 0.6 0.8])]
        (.text ax (np.radians (- (/ 180 5))) pos (if (= tick 100) f"{tick}% (test accuracy)"  f"{tick}%")
               :fontsize 17 :fontweight "semibold")))
    ;; Every panel gets an empty varlabel list (radar_factory API).
    (.set-varlabels ax [])

    (.set-title ax m :fontsize 30 :pad 40)
    (_plot-quartiles ax theta (_radar-stats adir) c)
    (_plot-quartiles ax theta mgs-stats c-mgs))

  ;; Legend proxies: NaN lines draw nothing but provide handles for
  ;; `fig.legend`. They are added to `ax`, the last panel from the loop above.
  ;; Handles come from the `plot` return values rather than `ax.lines`
  ;; positions.
  (setv method-handles (lfor [n c _] (+ others [mgs])
                             (get (.plot ax np.nan np.nan :c c :lw 4 :label n) 0))
        median-handle (get (.plot ax np.nan np.nan :c "black" :lw 4 :label "Median") 0)
        quartile-handle (get (.plot ax np.nan np.nan :c "black" :lw 4 :ls "dashed" :label "Quartile") 0))

  (.legend fig :handles method-handles :loc "lower center" :ncol (len method-handles) :fontsize 20)
  (.legend fig :handles [median-handle quartile-handle] :loc (, 0.63 0.25) :ncol 1 :fontsize 20)
  ;; NOTE(review): `tight_layout` recomputes subplot parameters and may
  ;; override the `subplots-adjust` call above -- confirm which is intended.
  (plt.tight-layout))

;; Write the finished figure to <attach-dir>/mnist-radar-plot.svg.
(-> (os.path.join attach-dir "mnist-radar-plot.svg")
    (plt.savefig))
