(import [numpy :as np]
        jax
        [jax.numpy :as jnp]
        [jax.experimental.optimizers :as optimizers]
        [neural_tangents :as nt]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [tqdm [tqdm]]
        [matplotlib.pyplot :as plt]
        matplotlib
        [fast_finite_width_ntk :as fast-ntk])

(require [hy.contrib.walk [let]]
         [hy.contrib.slicing [ncut "#:"]])

;; One-hot encode an integer label array, appending a class axis of length
;; `nb_classes` to the shape of `targets` (which must expose `.shape`).
(pys "def get_one_hot(targets, nb_classes):
          # Row i of the identity matrix is the one-hot vector for class i.
          flat_labels = np.array(targets).reshape(-1)
          encoded = np.eye(nb_classes)[flat_labels]
          out_shape = list(targets.shape) + [nb_classes]
          return encoded.reshape(out_shape)")

;; One-hot encode `array` using its own distinct values as the classes.
(pys "def one_hot_encode(array):
          # np.unique returns the distinct values plus, for every entry of
          # the input, the position of its value among them.
          classes, codes = np.unique(array, return_inverse=True)
          identity = np.eye(classes.shape[0])
          return identity[codes]")

(defn add-label-noise [y noise [np-rng-seed 3]]
  "Return a copy of `y` in which a fraction `noise` of the entries is redrawn.

  floor(noise * n) distinct positions along the first axis are chosen
  (n being the length of that axis) and overwritten with integers drawn
  from 0 up to and including max(y), so a redrawn label can coincide with
  the original one.  `y` itself is left untouched.  The generator is seeded
  with `np-rng-seed`, which makes the result reproducible."
  (setv rng (np.random.default-rng np-rng-seed)
        n-total (get (np.shape y) 0)
        n-flip (int (np.floor (* noise n-total)))
        ;; Positions first, replacement labels second: the draw order
        ;; determines the random stream.
        victims (.choice rng (range 0 n-total) n-flip :replace False)
        replacements (.integers rng 0 (+ (np.max y) 1) (get (np.shape victims) 0))
        noisy (np.copy y))
  (assoc noisy victims replacements)
  noisy)

(defn ce-loss-ind [predictions target]
  "Per-row sums of `predictions * target` along axis 1.

  The result is neither negated nor averaged; `ce-loss` does both.
  NOTE(review): presumably `predictions` holds log-probabilities and
  `target` one-hot rows -- confirm against callers."
  (jnp.sum (* predictions target) :axis 1))

(defn ce-loss [predictions target]
  "Negated mean over rows of the per-row sums of `predictions * target`.

  Both arguments need at least two axes because the product is reduced
  along axis 1 before the mean is taken."
  (- (jnp.mean (ce-loss-ind predictions target))))

(import [jax.nn [log_softmax]])

(defn ce-with-logits-loss [predictions target]
  "Pass `predictions` through `log-softmax`, then score them with `ce-loss`."
  (-> predictions
      (log-softmax)
      (ce-loss target)))

(defn ce-with-logits-loss-ind [predictions target]
  "Pass `predictions` through `log-softmax`, then apply `ce-loss-ind`."
  (-> predictions
      (log-softmax)
      (ce-loss-ind target)))

(defn bce-loss [predictions target]
  "Binary cross-entropy of `predictions` against `target`, as a squeezed array.

  `target` must be two-dimensional; its first dimension `n` supplies the
  -1/n factor.  Predictions are clipped to [1e-12, 1 - 1e-12] before the
  logarithms are taken.
  NOTE(review): in float32 the upper bound 1 - 1e-12 rounds to 1.0, so a
  prediction of exactly 1 is not kept away from log(0) -- confirm the
  dtype in use."
  (setv [n _] (. target shape)
        ;; Clip once and reuse; `jnp.clip` keeps the whole computation in JAX.
        p (jnp.clip predictions 1e-12 (- 1 1e-12))
        log-likelihood (+ (jnp.dot (jnp.transpose (jnp.log p)) target)
                          (jnp.dot (jnp.transpose (jnp.log (- 1 p))) (- 1 target))))
  (jnp.squeeze (* (/ -1.0 n) log-likelihood)))

(defn bce-loss-ind [predictions target]
  "Element-wise `target * log(p) + (1 - target) * log(1 - p)`.

  `p` is `predictions` clipped to [1e-12, 1 - 1e-12].  Unlike `bce-loss`
  the result is neither negated nor averaged."
  (setv p (jnp.clip predictions 1e-12 (- 1 1e-12)))
  (+ (* target (jnp.log p))
     (* (- 1 target) (jnp.log (- 1 p)))))

(import [jax.nn [sigmoid]])

(defn bce-with-logits-loss [predictions target]
  "Pass `predictions` through `sigmoid`, then score them with `bce-loss`."
  (-> predictions
      (sigmoid)
      (bce-loss target)))

(defn bce-with-logits-loss-ind [predictions target]
  "Pass `predictions` through `sigmoid`, then apply `bce-loss-ind`."
  (-> predictions
      (sigmoid)
      (bce-loss-ind target)))

(defn mse-loss [predictions target]
  "Mean of the squared element-wise difference between `target` and `predictions`."
  ;; A row-wise 2-norm variant used to sit here as a discarded form:
  ;;   (jnp.mean (jnp.linalg.norm (- target predictions) :ord 2 :axis 1))
  (-> (- target predictions)
      (jnp.square)
      (jnp.mean)))

(defn mse-loss-ind [predictions target]
  "Element-wise squared difference between `target` and `predictions` (no reduction)."
  ;; A row-wise 2-norm variant used to sit here as a discarded form:
  ;;   (jnp.mean (jnp.linalg.norm (- target predictions) :ord 2 :axis 1))
  (setv residual (- target predictions))
  (jnp.square residual))

(defn create-weight-penalty [penalty]
  "Build a penalty `(p x y) -> penalty * l2-norm(p)`.

  The returned function ignores `x` and `y`; they are accepted so that it
  has the same call shape as the other penalties in this file."
  (defn weight-penalty [p x y]
    (setv param-norm (optimizers.l2-norm p))
    (* penalty param-norm))
  weight-penalty)

(defn create-loss-gradient-penalty-param [calc-loss penalty]
  "Build a penalty on the 2-norm of the gradient of `calc-loss` w.r.t. its first argument.

  The returned function takes `(p x y)`, differentiates `calc-loss` at that
  point with respect to `p`, flattens the gradient pytree into one vector
  and returns `penalty` times its 2-norm."
  (setv grad-wrt-params (jax.grad calc-loss :argnums 0))
  (defn gradient-penalty [p x y]
    (setv [flat-grad _] (jax.flatten-util.ravel-pytree (grad-wrt-params p x y)))
    (* penalty (jnp.linalg.norm flat-grad :ord 2)))
  gradient-penalty)

(defn create-loss-gradient-penalty-data [calc-loss penalty]
  "Build a penalty on the 2-norm of the gradient of `calc-loss` w.r.t. its second argument.

  The returned function takes `(p x y)`, differentiates `calc-loss` at that
  point with respect to `x`, flattens the gradient into one vector and
  returns `penalty` times its 2-norm."
  (setv grad-wrt-inputs (jax.grad calc-loss :argnums 1))
  (defn gradient-penalty [p x y]
    (setv [flat-grad _] (jax.flatten-util.ravel-pytree (grad-wrt-inputs p x y)))
    (* penalty (jnp.linalg.norm flat-grad :ord 2)))
  gradient-penalty)

(defn create-mgs-penalty-trace [net-apply penalty]
  "Build a penalty `(p x y) -> penalty * kernel(x, None, p)`.

  The kernel function comes from `fast-ntk.empirical-ntk-fn` with
  trace-axes (0, 1), vmap-axes 0 and the STRUCTURED_DERIVATIVES
  implementation.  `y` is ignored."
  (setv kernel-fn (fast-ntk.empirical-ntk-fn
                    net-apply
                    :trace-axes (, 0 1)
                    :vmap-axes 0
                    :implementation fast-ntk.empirical.NtkImplementation.STRUCTURED_DERIVATIVES))
  (defn trace-penalty [p x y]
    (setv kernel-value (kernel-fn x None p))
    (* penalty kernel-value))
  trace-penalty)

(defn create-mgs-penalty-det [net-apply bound penalty]
  "Build a penalty on the log-determinant of a regularised empirical kernel.

  The returned function takes `(p x y)`, evaluates the kernel from
  `fast-ntk.empirical-ntk-fn` (trace-axes (1,), vmap-axes 0) at `x`, adds
  0.1 times the identity sized by the first dimension of `x`, and returns
  `penalty` times the larger of `bound` and the log-abs-determinant.
  `y` is ignored."
  (setv kernel-fn (fast-ntk.empirical-ntk-fn
                    net-apply
                    :trace-axes (, 1)
                    :vmap-axes 0
                    :implementation fast-ntk.empirical.NtkImplementation.STRUCTURED_DERIVATIVES))
  (defn det-penalty [p x y]
    (setv n-rows (get (jnp.shape x) 0)
          ridge (* 1e-1 (jnp.identity n-rows))
          kernel (+ (kernel-fn x None p) ridge)
          ;; slogdet returns (sign, log|det|); only the second entry is used.
          [_ log-abs-det] (jnp.linalg.slogdet kernel))
    (* penalty (jnp.max (jnp.array [bound log-abs-det]))))
  det-penalty)

(defn create-rkhs-norm-penalty [net-apply init-params penalty [n 10]]
  "Build a penalty `(p x y) -> penalty * sum(f^T K^-1 f)` on the first `n` rows of `x`.

  `f` is the squeezed network output on those rows and `K` the kernel
  returned by `fast-ntk.empirical-ntk-fn` (trace-axes (1,), vmap-axes 0)
  for the same rows.  `init-params` is only used to obtain the function
  that restores the parameter pytree from a flat vector.  `y` is ignored.

  `n` defaults to 10, the sample size that used to be hard-coded.
  NOTE(review): `x` is assumed to have at least `n` rows -- confirm at the
  call sites."
  (setv calc-ntk (fast-ntk.empirical-ntk-fn net-apply
                   :trace-axes (, 1)
                   :vmap-axes 0
                   :implementation fast-ntk.empirical.NtkImplementation.STRUCTURED_DERIVATIVES)
        [_ unravel] (jax.flatten-util.ravel-pytree init-params)
        f (fn [p x] (net-apply (unravel p) x)))
  (fn [p x y]
    (setv sample (jnp.take x (np.arange n) :axis 0)
          ntk (calc-ntk sample None p)
          [flat-p _] (jax.flatten-util.ravel-pytree p)
          f-out (jnp.squeeze (f flat-p sample)))
    (* penalty (jnp.sum (jnp.squeeze (jnp.matmul (jnp.matmul (jnp.transpose f-out)
                                                             (jnp.linalg.inv ntk))
                                                 f-out))))))

(defn create-rkhs-norm-penalty-2 [net-apply penalty [n 32]]
  "Build a penalty from the Jacobian of `net-apply` w.r.t. its second argument.

  The returned function takes `(p x y)`, keeps the first `n` rows of `x`,
  computes the Jacobian there, swaps its axes 1 and 2, takes the diagonal
  over the first two axes, reduces with a 2-norm along axis 1, sums along
  axis 0 and returns `penalty` times the maximum.  `y` is ignored."
  (setv input-jacobian (jax.jacobian net-apply :argnums 1))
  (defn jacobian-penalty [p x y]
    (setv head (jnp.take x (jnp.arange n) :axis 0)
          diag-blocks (-> (input-jacobian p head)
                          (jnp.swapaxes 1 2)
                          (jnp.diagonal))
          norms (jnp.linalg.norm diag-blocks :axis 1 :ord 2))
    (* penalty (jnp.max (jnp.sum norms :axis 0))))
  jacobian-penalty)

(defn setup-loss-tracker [calc-loss x-test y-test opt-get n-samples]
  "Build a metric tracker that samples train and test loss about `n-samples` times.

  Returns a dict with three callables:
    init       (n-iters) -> state dict whose `metric` entry is an array
               with one [train test] row per sample, filled with inf.
    should-run (i n-iters) -> True on every stride-th iteration.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and a status dict with the keys
               train-loss, test-loss and min-test-loss.

  The stride is `n-iters // n-samples`, clamped to at least 1 so that
  asking for more samples than there are iterations samples every
  iteration instead of dividing by zero."
  (defn stride [n-iters]
    (max 1 (// n-iters n-samples)))
  (dict :init (fn [n-iters]
                (dict :metric (np.full (, (ceil (/ n-iters (stride n-iters))) 2) np.inf)))
        :should-run (fn [i n-iters] (= (% i (stride n-iters)) 0))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv params (opt-get opt-state)
                     loss (np.array [(calc-loss params x y :rng rng)
                                     (calc-loss params x-test y-test :rng rng)])
                     history (:metric metric-state))
               ;; The history array is updated in place and handed back.
               (assoc history (int (// i (stride n-iters))) loss)
               (, (dict :metric history)
                  {"train-loss" (get loss 0)
                   "test-loss" (get loss 1)
                   "min-test-loss" (np.min (np.take history 1 :axis 1))}))))

(defn setup-loss-tracker-run-last [calc-loss x-test y-test opt-get run-last [rng None] [debug False]]
  "Build a metric tracker that records train and test loss on the last `run-last` iterations.

  Returns a dict with three callables:
    init       (n-iters) -> state dict whose `metric` entry is a zero
               array of shape (run-last, 2).
    should-run (i n-iters) -> True for the final `run-last` iterations.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and None (no status line).

  Rows are stored newest first: row 0 holds the final iteration
  (i = n-iters - 1) and row run-last - 1 the oldest recorded one.

  `rng` given here is the default key handed to `calc-loss`; a non-None
  `rng` passed to `run` takes precedence.  `run` accepts that keyword
  because the training loop in this file always supplies it."
  ;; Keep the outer key under another name: the `rng` parameter of `run`
  ;; shadows it inside the closure.
  (setv default-rng rng)
  (dict :init (fn [n-iters] (dict :metric (np.zeros (, run-last 2))))
        :should-run (fn [i n-iters] (>= i (- n-iters run-last)))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv key (if (is rng None) default-rng rng)
                     params (opt-get opt-state)
                     history (:metric metric-state)
                     ;; i runs up to n-iters - 1, so this stays in 0 .. run-last - 1.
                     row (int (- n-iters i 1)))
               (assoc history row
                      (np.array [(calc-loss params x y :rng key)
                                 (calc-loss params x-test y-test :rng key)]))
               (when debug
                 (print i (get history row)))
               (, (dict :metric history)
                  None))))


(defn create-accuracy-metric [net-apply multi-class]
  "Return a jitted `(params x y [rng]) -> accuracy` function.

  With `multi-class` true, the prediction is the argmax along axis 1 of the
  log-softmax of the network output and is compared with the argmax of `y`
  along axis 1.  Otherwise the prediction is `sigmoid(output) > 0.5` and is
  compared element-wise with `y`.  In both cases the mean of the matches is
  returned.  `rng` is forwarded to `net-apply` as a keyword argument."
  (jax.jit (if multi-class
               (fn [params x y [rng None]]
                 (-> (net-apply params x :rng rng)
                     (jax.nn.log-softmax)
                     (jnp.argmax :axis 1)
                     (= (jnp.argmax y :axis 1))
                     (jnp.mean)))
               (fn [params x y [rng None]]
                 ;; Thread-first puts the probability on the LEFT of the
                 ;; comparison, so `(> 0.5)` reads as `probability > 0.5`.
                 ;; The previous `(< 0.5)` predicted the positive class for
                 ;; probabilities below one half.
                 (-> (net-apply params x :rng rng)
                     (jax.nn.sigmoid)
                     (> 0.5)
                     (= y)
                     (jnp.mean))))))

(defn setup-accuracy-tracker [x-test y-test net-apply opt-get multi-class n-samples [plot False]]
  "Build a metric tracker that samples train and test accuracy about `n-samples` times.

  Returns a dict with three callables:
    init       (n-iters) -> state dict with a zero `metric` array holding
               one [train test] row per sample; with `plot` true it also
               holds `fig`, `ax` and `line` and opens a non-blocking window.
    should-run (i n-iters) -> True on every stride-th iteration.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and a status dict with the keys
               train-acc, test-acc and max-test-acc.

  The stride is `n-iters // n-samples`, clamped to at least 1 so that
  asking for more samples than there are iterations does not divide by
  zero.  When plotting, the line is redrawn after the new sample has been
  stored, so the figure shows the current value."
  (setv calc-accuracy (create-accuracy-metric net-apply multi-class))
  (defn stride [n-iters]
    (max 1 (// n-iters n-samples)))
  (dict :init (fn [n-iters]
                (setv d (dict :metric (np.zeros (, (ceil (/ n-iters (stride n-iters))) 2))
                              :fig (when plot (plt.figure))
                              :ax (when plot (plt.axes))))
                (when plot
                  (.set-ylim (:ax d) 0 1)
                  ;; The plotted line tracks the test-accuracy column.
                  (assoc d (. :line name) (get (.plot (:ax d) (np.take (:metric d) 1 :axis 1)) 0))
                  (plt.show :block False))
                d)
        :should-run (fn [i n-iters] (= (% i (stride n-iters)) 0))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv params (opt-get opt-state)
                     acc (np.array [(calc-accuracy params x y :rng rng)
                                    (calc-accuracy params x-test y-test :rng rng)])
                     history (:metric metric-state))
               (assoc history (int (// i (stride n-iters))) acc)
               (when plot
                 (.set-ydata (:line metric-state) (np.take history 1 :axis 1))
                 (.draw (. (:fig metric-state) canvas))
                 (.flush-events (. (:fig metric-state) canvas)))
               (, (dict :metric history)
                  {"train-acc" (get acc 0)
                   "test-acc" (get acc 1)
                   "max-test-acc" (np.max (np.take history 1 :axis 1))}))))

(defn setup-accuracy-tracker-run-last [x-test y-test net-apply opt-get multi-class run-last [rng None] [plot False]]
  "Build a metric tracker that records train and test accuracy on the last `run-last` iterations.

  Returns a dict with three callables:
    init       (n-iters) -> state dict whose `metric` entry is a zero
               array of shape (run-last, 2).
    should-run (i n-iters) -> True for the final `run-last` iterations.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and None (no status line).

  Rows are stored newest first: row 0 holds the final iteration
  (i = n-iters - 1) and row run-last - 1 the oldest recorded one.

  `rng` given here is the default key forwarded to the accuracy function;
  a non-None `rng` passed to `run` takes precedence.  `plot` is accepted
  but not used."
  ;; `create-accuracy-metric` takes no `rng` argument; the key is supplied
  ;; per call instead.  Keep the outer key under another name because the
  ;; `rng` parameter of `run` shadows it inside the closure.
  (setv calc-accuracy (create-accuracy-metric net-apply multi-class)
        default-rng rng)
  (dict :init (fn [n-iters]
                (dict :metric (np.zeros (, run-last 2))))
        :should-run (fn [i n-iters] (>= i (- n-iters run-last)))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv key (if (is rng None) default-rng rng)
                     params (opt-get opt-state)
                     acc (np.array [(calc-accuracy params x y :rng key)
                                    (calc-accuracy params x-test y-test :rng key)])
                     history (:metric metric-state))
               ;; i runs up to n-iters - 1, so the row stays in 0 .. run-last - 1.
               (assoc history (int (- n-iters i 1)) acc)
               (, (dict :metric history)
                  None))))

(defn create-trace-metric [net-apply opt-get]
  "Return a jitted `(opt-state x [rng])` function evaluating the empirical kernel.

  The kernel function comes from `fast-ntk.empirical-ntk-fn` with
  trace-axes (0, 1) and vmap-axes 0; it is evaluated at `x` with the
  parameters `opt-get` extracts from `opt-state`, and `rng` is forwarded as
  a keyword argument."
  (setv kernel-fn (fast-ntk.empirical-ntk-fn
                    net-apply
                    :trace-axes (, 0 1)
                    :vmap-axes 0
                    :implementation fast-ntk.empirical.NtkImplementation.STRUCTURED_DERIVATIVES))
  (defn trace-metric [opt-state x [rng None]]
    (setv params (opt-get opt-state))
    (kernel-fn x None params :rng rng))
  (jax.jit trace-metric))

(defn setup-trace-tracker [net-apply opt-get n-samples [rng None]]
  "Build a metric tracker that samples the value of `create-trace-metric` about `n-samples` times.

  Returns a dict with three callables:
    init       (n-iters) -> state dict whose `metric` entry is a zero
               vector with one slot per sample.
    should-run (i n-iters) -> True on every stride-th iteration.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and a status dict with the key trace.

  The stride is `n-iters // n-samples`, clamped to at least 1 so that
  asking for more samples than there are iterations does not divide by
  zero.  The outer `rng` argument is accepted but not used; the key passed
  to `run` is the one forwarded to the metric."
  (setv calc-trace (create-trace-metric net-apply opt-get))
  (defn stride [n-iters]
    (max 1 (// n-iters n-samples)))
  (dict :init (fn [n-iters]
                (dict :metric (np.zeros (ceil (/ n-iters (stride n-iters))))))
        :should-run (fn [i n-iters] (= (% i (stride n-iters)) 0))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv history (:metric metric-state)
                     slot (int (// i (stride n-iters))))
               (assoc history slot (calc-trace opt-state x :rng rng))
               (, (dict :metric history)
                  {"trace" (get history slot)}))))

(defn create-determinant-metric [net-apply opt-get]
  "Return a jitted `(opt-state x [rng])` function giving the kernel's log-abs-determinant.

  The kernel function comes from `fast-ntk.empirical-ntk-fn` with
  trace-axes (1,) and vmap-axes 0; it is evaluated at `x` with the
  parameters `opt-get` extracts from `opt-state`, and `rng` is forwarded as
  a keyword argument."
  (setv kernel-fn (fast-ntk.empirical-ntk-fn
                    net-apply
                    :trace-axes (, 1)
                    :vmap-axes 0
                    :implementation fast-ntk.empirical.NtkImplementation.STRUCTURED_DERIVATIVES))
  (defn determinant-metric [opt-state x [rng None]]
    (setv kernel (kernel-fn x None (opt-get opt-state) :rng rng)
          ;; slogdet returns (sign, log|det|); only the second entry is used.
          [_ log-abs-det] (jnp.linalg.slogdet kernel))
    log-abs-det)
  (jax.jit determinant-metric))

(defn setup-determinant-tracker [net-apply opt-get n-samples]
  "Build a metric tracker that samples the value of `create-determinant-metric` about `n-samples` times.

  Returns a dict with three callables:
    init       (n-iters) -> state dict whose `metric` entry is a zero
               vector with one slot per sample.
    should-run (i n-iters) -> True on every stride-th iteration.
    run        (i n-iters metric-state opt-state x y [rng]) -> a pair of
               the updated state and a status dict with the key det.

  The stride is `n-iters // n-samples`, clamped to at least 1 so that
  asking for more samples than there are iterations does not divide by
  zero."
  (setv calc-det (create-determinant-metric net-apply opt-get))
  (defn stride [n-iters]
    (max 1 (// n-iters n-samples)))
  (dict :init (fn [n-iters]
                (dict :metric (np.zeros (ceil (/ n-iters (stride n-iters))))))
        :should-run (fn [i n-iters] (= (% i (stride n-iters)) 0))
        :run (fn [i n-iters metric-state opt-state x y [rng None]]
               (setv history (:metric metric-state)
                     slot (int (// i (stride n-iters))))
               (assoc history slot (calc-det opt-state x :rng rng))
               (, (dict :metric history)
                  {"det" (get history slot)}))))

(defn create-decision-grid [X [h 100] [bounds None]]
  "Build an `h` by `h` evaluation grid over the first two columns of `X`.

  `bounds`, when given and non-empty, is read as
  [x1-min x1-max x2-min x2-max]; otherwise the limits are the column
  minima and maxima of `X` widened by 0.01 on each side.

  Returns a tuple `(x1-grid x2-grid pred-grid)`: the two meshgrid arrays
  and an (h*h, 2) array with one grid point per row."
  (setv X1 (jnp.take X 0 :axis 1)
        X2 (jnp.take X 1 :axis 1)
        ;; Test for None and emptiness explicitly so that array-valued
        ;; bounds do not hit an ambiguous truth-value check.
        has-bounds (and (is-not bounds None) (> (len bounds) 0))
        x1-min (if has-bounds (get bounds 0) (- (jnp.min X1) 0.01))
        x1-max (if has-bounds (get bounds 1) (+ (jnp.max X1) 0.01))
        x2-min (if has-bounds (get bounds 2) (- (jnp.min X2) 0.01))
        x2-max (if has-bounds (get bounds 3) (+ (jnp.max X2) 0.01))
        [x1-grid x2-grid] (jnp.meshgrid (jnp.linspace x1-min x1-max :num h)
                                        (jnp.linspace x2-min x2-max :num h))
        pred-grid (jnp.hstack [(jnp.reshape (jnp.ravel x1-grid) (, -1 1))
                               (jnp.reshape (jnp.ravel x2-grid) (, -1 1))]))
  (, x1-grid x2-grid pred-grid))

(defn plot-decision-boundary [pred-fun x [h 100] [norm None] [colourbar False] [bounds None] [alpha 1.0]]
  "Draw filled contours of `pred-fun` over a grid built by `create-decision-grid`.

  `pred-fun` is called once with all grid points and its result is reshaped
  to the grid.  The colours are the viridis table with the last column
  (the alpha channel) overwritten by `alpha`.  Returns the colour bar when
  `colourbar` is true, otherwise None."
  (setv [grid-x1 grid-x2 grid-points] (create-decision-grid x h :bounds bounds)
        surface (np.reshape (pred-fun grid-points) (np.shape grid-x1))
        rgba (plt.cm.viridis (np.arange (. plt.cm.viridis N))))
  ;; Same effect as the Python statement  rgba[:, -1] = alpha
  (assoc rgba (, (slice None) -1) alpha)
  (plt.contourf grid-x1 grid-x2 surface
                :norm norm
                :cmap (matplotlib.colors.ListedColormap rgba))
  (when colourbar
    (plt.colorbar)))

(defn create-opt-step [calc-loss calc-penalty opt-update opt-get]
  "Return a jitted `(i opt-state x y [rng]) -> new opt-state` update step.

  The objective is `calc-loss(p, x, y, rng=rng) + calc-penalty(p, x, y)`.
  Its gradient with respect to `p` is evaluated at the parameters that
  `opt-get` extracts from `opt-state` and handed to `opt-update` together
  with the step index and the current state."
  (defn objective [p x y [rng None]]
    (setv data-term (calc-loss p x y :rng rng)
          penalty-term (calc-penalty p x y))
    (+ data-term penalty-term))
  (setv objective-grad (jax.grad objective))
  (defn step [i opt-state x y [rng None]]
    (setv grads (objective-grad (opt-get opt-state) x y :rng rng))
    (opt-update i grads opt-state))
  (jax.jit step))


(require [hy.contrib.slicing ["#:" ncut]])

(defn batch-data [x y batch-size epochs [jax-rng None]]
  "Prepare shuffled mini-batches of `x` and `y` for `epochs` passes.

  Returns a tuple `(n-iters create-batches)`: the total number of batches
  and a zero-argument generator function yielding `(x-batch y-batch)`
  tuples.  Every epoch uses its own permutation of the row indices, drawn
  from a key split off `jax-rng` (PRNGKey 0 when `jax-rng` is None).

  Every batch has exactly `batch-size` rows.  When the row count is not a
  multiple of `batch-size`, the last batch of an epoch is topped up with
  rows from the start of that epoch's permutation."
  (setv jax-rng (if (is jax-rng None) (jax.random.PRNGKey 0) jax-rng)
        subrng (jax.random.split jax-rng epochs)
        n (get (jnp.shape x) 0)
        [num-batches leftover] (divmod n batch-size)
        num-batches (+ num-batches (bool leftover)))
  (defn create-batches []
    (for [e (range epochs)]
      (setv perm (jax.random.permutation (get subrng e) n))
      (for [i (range num-batches)]
        ;; Reduce the positions modulo n so the final batch wraps around
        ;; to the start of the permutation instead of indexing past its
        ;; end, where the result depends on jnp.take's out-of-bounds mode.
        (setv positions (% (jnp.arange (* i batch-size) (* (inc i) batch-size)) n)
              idx (jnp.take perm positions))
        (yield (, (jnp.take x idx :axis 0) (jnp.take y idx :axis 0))))))
  (, (* epochs num-batches) create-batches))

;; Run the optimisation loop for `epochs` passes over the training data.
;;
;; opt-state / opt-step : initial optimiser state and the step function,
;;                        called as (opt-step i opt-state x y :rng key).
;; batch-size           : values <= 0 select full-batch training.
;; metrics              : optional list of tracker dicts with the entries
;;                        init, should-run and run; each dict is mutated in
;;                        place to carry its `state` entry.
;; show-progress / progress-pos : forwarded to the tqdm progress bar.
;; jax-rng              : optional key; when given it is split into one
;;                        sub-key per iteration for opt-step and the metrics.
;;
;; Returns the tuple (opt-state metrics).
;;
;; NOTE: `jax-rng` is not forwarded to `batch-data`, so the shuffling there
;; always uses that function's default key.
(defn train-model [epochs opt-state opt-step
                   x-train y-train
                   [batch-size -1]
                   [metrics None]
                   [show-progress True]
                   [progress-pos 0]
                   [jax-rng None]]
  ;; A non-positive batch size means: use the whole training set per step.
  (when (<= batch-size 0)
    (setv batch-size (get (jnp.shape x-train) 0)))
  (setv [n-iters create-batches] (batch-data x-train y-train batch-size epochs)
        batches (create-batches)
        subrng (if (is jax-rng None) None (jax.random.split jax-rng n-iters)))
  ;; Give every tracker its initial state, stored under the key "state".
  (when (not (is metrics None))
    (for [m metrics]
      (assoc m (. :state name) ((:init m) n-iters))))
  (with [t (tqdm (range n-iters) :disable (not show-progress) :leave True :position progress-pos)]
    (for [i t]
      ;; Draw the next batch, then take one optimiser step on it.
      (setv [x y] (next batches) #_[(jnp.take x-train (jnp.arange 32) :axis 0) (jnp.take y-train (jnp.arange 32) :axis 0)]
            opt-state (opt-step i opt-state x y :rng (when (not (is subrng None)) (get subrng i))))
      (unless (is metrics None)
        ;; Collect the status entries of every tracker due on this iteration.
        (setv status (dict))
        (for [m metrics]
          (when ((:should-run m) i n-iters)
            ;; Trackers are evaluated on the post-step state and the current
            ;; batch, with the same per-iteration key the step received.
            (setv [m-update s] ((:run m) i n-iters (:state m) opt-state x y
                                :rng (when (not (is subrng None)) (get subrng i)))
                  status (merge status (if (is s None) (dict) s)))
            (assoc m (. :state name) (merge (:state m) m-update))))
        ;; The form below is discarded by the reader (#_); it is an earlier
        ;; list-comprehension version of the loop above.
        #_(setv updates (lfor m metrics
                            :if ((:should-run m) i n-iters)
                            :setv [m-update s] ((:run m) i n-iters (:state m) opt-state x y)
                            :do (assoc m (. :state name) (merge (:state m) m-update))
                            (, m s))
              metrics (lfor u updates (get u 0))
              status (dfor u updates (.iteritems (get u 1))))
        ;; Only touch the progress bar when some tracker reported something.
        (when (> (len status) 0) (.set-postfix t status)))))
  (, opt-state metrics))

(import [sklearn.model_selection [ParameterGrid KFold]])

(defn grid-search-cv [param-space epochs create-opt-step new-opt-state
                      batch-size metric create-iterator]
  "Evaluate every point of `param-space` with the splits from `create-iterator`.

  For each parameter combination produced by `ParameterGrid`, an optimiser
  step is built with `(create-opt-step params)`; then, for every
  `[x-train x-test y-train y-test]` item yielded by `(create-iterator)`,
  a model is trained from `(new-opt-state rng)` with `train-model` and
  scored with `(metric opt-state x-test y-test)`.

  Returns a list of `[params mean-score]` pairs, one per grid point.  The
  PRNG key of each training run is built from a counter shared by the
  whole search, so every run gets a distinct seed."
  ;; `itertools` is not imported at module level; bring it in here so the
  ;; seed counter below resolves.
  (import itertools)
  (setv grid (ParameterGrid param-space)
        rng-seed (itertools.count))
  (with [t (tqdm grid)]
    (lfor params t
          [params (let [opt-step (create-opt-step params)
                        it (create-iterator)]
                    (-> (lfor [x-train x-test y-train y-test] it
                              :do (setv rng (jax.random.PRNGKey (next rng-seed))
                                    [opt-state _] (train-model epochs
                                                                   (new-opt-state rng)
                                                                   opt-step
                                                                   x-train y-train
                                                                   :batch-size batch-size
                                                                   :show-progress True
                                                                   :jax-rng rng))
                              (metric opt-state x-test y-test))
                        (np.array)
                        (np.mean)))])))

