open Base
open Owl
open Vae
open Variational
open Owl_parameters
open Accessor.O

(* Command-line arguments.
   -data <dir>        : directory the raw recordings are loaded from
   -d <dir>           : directory every output is written to (and reloaded from)
   -n <int>, -m <int> : mandatory model sizes; [n] sizes e.g. the n x n
                        [a_mean], [m] the inputs (e.g. the m x n [b_var])
   -init_prms <file>  : optional complete parameter file to start from
   -reuse <file>      : optional; reuse setup/data saved in [-d] and the
                        generative parameters stored in <file> inside [-d] *)
let data_dir = Cmdargs.in_dir "-data"
let in_dir = Cmdargs.in_dir "-d"
let n = Cmdargs.get_int "-n" |> Cmdargs.force ~usage:"-n [n]"
let m = Cmdargs.get_int "-m" |> Cmdargs.force ~usage:"-m [m]"
let init_prms_file = Cmdargs.get_string "-init_prms"
let reuse = Cmdargs.get_string "-reuse"

(* Integer quotient n / m; passed to the prior, the dynamics and [X] as
   [n_beg] (presumably the number of initial bins used to encode an
   n-dimensional initial condition with m-dimensional inputs — TODO confirm). *)
let n_beg = n / m

(* Run-level dimensions and split sizes. A value of this type is written to
   [in_dir "setup.bin"] with [Misc.save_bin] and read back with
   [Misc.read_bin] under [-reuse]; changing the fields presumably invalidates
   existing setup.bin files (assuming a Marshal-style encoding — TODO confirm). *)
type setup =
  { n : int (* the [-n] argument; sizes e.g. the n x n [a_mean] *)
  ; m : int (* the [-m] argument; sizes e.g. the m x n [b_var] *)
  ; n_trials : int (* number of training trials *)
  ; n_val_trials : int (* number of validation trials *)
  ; n_test_trials : int (* number of test trials *)
  ; n_steps : int (* rows (time bins) per trial *)
  ; n_neural : int (* number of neural output channels *)
  ; n_hand : int (* number of hand output dimensions *)
  }

(* Time-bin width, handed to the Poisson likelihood as [dt] (presumably in
   seconds, i.e. 25 ms bins — TODO confirm). *)
let dt = 0.025

(* Run configuration, evaluated through [C.broadcast']. A fresh run builds it
   from [-n]/[-m] and the constants below and writes it to
   [in_dir "setup.bin"]; with [-reuse] that file is read back instead. *)
let setup =
  C.broadcast' (fun () ->
      if Option.is_some reuse
      then Misc.read_bin (in_dir "setup.bin")
      else (
        let fresh =
          { n
          ; m
          ; n_trials = 160
          ; n_val_trials = 1
          ; n_test_trials = 40
          ; n_steps = 160
          ; n_neural = 130
          ; n_hand = 2
          }
        in
        Misc.save_bin ~out:(in_dir "setup.bin") fresh;
        fresh))


(* ----------------------------------------- 
   -- Define model
   ----------------------------------------- *)

(* Gaussian prior over the inputs, configured with the global [n_beg]. *)
module U = Priors.Gaussian_Init (struct
  let n_beg = n_beg
end)

(* Poisson likelihood of the neural data with an exponential link, so the
   link and both of its derivatives are [exp]. *)
module L_neural = Likelihoods.Poisson (struct
  let label = "neural"
  let link_function = AD.Maths.exp
  let d_link_function = AD.Maths.exp
  let d2_link_function = AD.Maths.exp

  (* Alternative link kept for reference, f(x) = (x + sqrt (x^2 + 4)) / 2:
  let link_function x = AD.Maths.(F 0.5 * (x + sqrt (sqr x + F 4.)))
  let d_link_function x = AD.Maths.(F 0.5 * (F 1. + (x / sqrt (sqr x + F 4.))))
  let d2_link_function x = AD.Maths.(F 2. / ((sqr x + F 4.) * sqrt (sqr x + F 4.))) *)
  let dt = AD.F dt
  let normalize_c = false
end)

(* Gaussian likelihood of the (squashed) hand data. *)
module L_hand = Likelihoods.Gaussian (struct
  let label = "hand"
  let normalize_c = false
end)

(* Joint likelihood over (neural, hand) observation pairs. *)
module L = Likelihoods.Pair (L_neural) (L_hand)

(* GRU-style latent dynamics driven by the inputs.
   [phi] is a shifted requad non-linearity and [sigma] the logistic gate; the
   [d_*] functions are their derivatives. *)
module D = Dynamics.Mini_GRU_IO (struct
  let phi x = AD.Maths.(AD.requad x - F 1.)
  let d_phi x = AD.d_requad x
  let sigma x = AD.Maths.sigmoid x

  (* sigma'(x) = sigma(x) * (1 - sigma(x)). The former expression
     exp(-x) / (1 + exp(-x))^2 evaluates to inf / inf = NaN once exp(-x)
     overflows (x below about -709); this form stays finite everywhere. *)
  let d_sigma x =
    let s = AD.Maths.sigmoid x in
    AD.Maths.(s * (F 1. - s))

  let n_beg = Some n_beg
end)

(* ARD-type ("elbobo") prior/posterior over the model weights; its parameter
   records are [E.P_prior.p] and [E.P_posterior.p]. *)
module E = Elbobo.ARD

(* Sizes shared by every model variant, read from [setup]. *)
module X = struct
  let { n; m; n_steps; _ } = setup
  let diag_time_cov = false
  let n_beg = n_beg
end

(* Full model: neural and hand observations together. *)
module Model = VAE (U) (D) (L) (E) (X)

(* Single-channel variants with the same prior, dynamics and sizes, used to
   infer inputs from one modality and predict the other during evaluation. *)
module Model_neural = VAE (U) (D) (L_neural) (E) (X)
module Model_hand = VAE (U) (D) (L_hand) (E) (X)

(* ----------------------------------------- 
   -- Fetch the data and process slightly
   ----------------------------------------- *)

(* Symmetric log compression, elementwise:
   squash x = sign(x) * log(1 + |x|), and [unsquash] is its exact inverse.
   log1p / expm1 keep full relative precision for entries near zero, where
   log (1 + |x|) and exp |x| - 1 lose digits to cancellation. *)
let squash x = Mat.(signum x * log1p (abs x))
let unsquash x = Mat.(signum x * expm1 (abs x))

(* Ridge regression (no intercept) of [x'] on [x], returning the in-sample
   coefficient of determination and the fitted coefficients.

   Rows of [x] and [x'] are observations; [x] has one column per regressor and
   [lambda] is the L2 penalty. Returns [(r2, c)] where
   c = argmin ||x c - x'||^2 + lambda ||c||^2 and
   r2 = 1 - SS_res / SS_tot, with SS_tot taken around the column means of [x']. *)
let compute_R2 ~lambda x x' =
  let nx = Mat.col_num x in
  let xt = Mat.transpose x in
  (* Solve (X^T X + lambda I) c = X^T Y directly rather than forming the inverse
     and multiplying: cheaper and numerically better behaved. *)
  let c = Linalg.D.linsolve Mat.((xt *@ x) + (lambda $* eye nx)) Mat.(xt *@ x') in
  let ss_res = Mat.(l2norm_sqr' ((x *@ c) - x')) in
  let ss_tot = Mat.(l2norm_sqr' (x' - mean ~axis:0 x')) in
  1. -. (ss_res /. ss_tot), c


(* generate training and test data right away *)
(* Train / test / validation trials, evaluated through [C.broadcast'].

   With [-reuse] the three splits are read back from the binary files of a
   previous run. Otherwise:
   - "neurons.bin" and "velocity.bin" are loaded from [-data] and truncated to
     their first 35000 rows (the recording ends with a silent episode);
   - every hand column is z-scored and then compressed with [squash];
   - consecutive, non-overlapping chunks of [setup.n_steps] rows form the
     trials, which are randomly permuted and split into
     n_trials / n_test_trials / n_val_trials;
   - the permutation and the splits are saved in binary form, and every trial
     is also dumped as text ("<split>_data_<kind>_<i>"). *)
let train_data, test_data, val_data =
  C.broadcast' (fun () ->
      match reuse with
      | Some _ ->
        let load name = Misc.read_bin (in_dir name) in
        load "train_data.bin", load "test_data.bin", load "val_data.bin"
      | None ->
        let n_total = setup.n_trials + setup.n_test_trials + setup.n_val_trials in
        let trials =
          let first_rows mat = Mat.get_slice [ [ 0; 35000 - 1 ] ] mat in
          let neurons = first_rows (Mat.load (data_dir "neurons.bin")) in
          let hand = first_rows (Mat.load (data_dir "velocity.bin")) in
          (* let hand = Mat.((hand - mean ~axis:0 hand) / std ~axis:0 hand) in *)
          let hand = squash Mat.((hand - mean ~axis:0 hand) / std ~axis:0 hand) in
          Array.init n_total ~f:(fun trial ->
              let start = trial * setup.n_steps in
              let rows = [ [ start; start + setup.n_steps - 1 ] ] in
              let o : L.data =
                { fst = AD.pack_arr (Mat.get_slice rows neurons)
                ; snd = AD.pack_arr (Mat.get_slice rows hand)
                }
              in
              { u = None; z = None; o })
        in
        let ids = Stats.shuffle (Array.init n_total ~f:Fn.id) in
        let pick ~pos ~len =
          Array.map (Array.sub ids ~pos ~len) ~f:(fun id -> trials.(id))
        in
        let train_data = pick ~pos:0 ~len:setup.n_trials in
        let test_data = pick ~pos:setup.n_trials ~len:setup.n_test_trials in
        let val_data =
          pick ~pos:(setup.n_trials + setup.n_test_trials) ~len:setup.n_val_trials
        in
        let save name value = Misc.save_bin ~out:(in_dir name) value in
        save "random_ids.bin" ids;
        save "train_data.bin" train_data;
        save "test_data.bin" test_data;
        save "val_data.bin" val_data;
        let dump_split label split =
          Array.iteri split ~f:(fun i trial ->
              let file kind = in_dir (Printf.sprintf "%s_data_%s_%i" label kind i) in
              Option.iter trial.z ~f:(fun z ->
                  Mat.save_txt ~out:(file "latent") (AD.unpack_arr z));
              L.save_data ~prefix:(file "o") trial.o)
        in
        List.iter
          [ "train", train_data; "test", test_data; "val", val_data ]
          ~f:(fun (label, split) -> dump_split label split);
        train_data, test_data, val_data)


let _ = "Data generated and broadcast." |> C.print_endline

(* ----------------------------------------- 
   -- Initialise parameters and train
   ----------------------------------------- *)

(* Initial parameters of the full model, evaluated through [C.broadcast'].

   - With [-init_prms <file>] the whole parameter set is read from <file>.
   - Otherwise the generative parameters are either taken from the parameter
     file named by [-reuse] (inside [-d]) or freshly initialised. The ARD
     posterior parameters are always freshly initialised around them, and
     [Model.init ~tie:true] assembles everything. *)
let (init_prms : Model.P.p) =
  C.broadcast' (fun () ->
      match init_prms_file with
      | Some path -> Misc.read_bin path
      | None ->
        (* learned parameter kept above 1e-6 *)
        let positive x = learned ~above:1E-6 x in
        let generative_prms =
          match reuse with
          | Some f ->
            let (previous : Model.P.p) = Misc.read_bin (in_dir f) in
            previous.generative
          | None ->
            let n = setup.n in
            let m = setup.m in
            (* let prior = Priors.Student.init ~spatial_std:1.0 ~nu:20. ~m learned in *)
            let prior = Priors.Gaussian.init ~spatial_std:1.0 ~first_bin:1. ~m learned in
            let dynamics = D.init ~radius:0.1 ~n ~m learned in
            let likelihood : L.P.p =
              { fst = L_neural.init ~n ~n_output:setup.n_neural learned
              ; snd = L_hand.init ~sigma2:0.5 ~n ~n_output:setup.n_hand learned
              }
            in
            let elbobo_prior : E.P_prior.p =
              { ard_scales = positive (AD.Mat.ones 1 n)
              ; a = positive (AD.F 0.3)
              ; b = Option.map dynamics.b ~f:(fun _ -> positive (AD.F 0.3))
              ; c = [| positive (AD.F 0.3), "neural"; positive (AD.F 0.3), "hand" |]
              }
            in
            Generative_P.{ prior; dynamics; likelihood; elbobo_prior }
        in
        (* rows x cols posterior variance, initialised to 0.001 *)
        let small_var rows cols = positive (AD.Mat.create rows cols 0.001) in
        let elbobo_posterior : E.P_posterior.p =
          { a_mean = learned (AD.Mat.zeros setup.n setup.n)
          ; b_mean = generative_prms.dynamics.b
          ; c_mean =
              [| learned (AD.Mat.zeros setup.n_neural setup.n), "neural"
               ; learned (AD.Mat.gaussian ~sigma:0.1 setup.n_hand setup.n), "hand"
              |]
          ; a_var = small_var setup.n setup.n
          ; b_var =
              Option.map generative_prms.dynamics.b ~f:(fun _ -> small_var setup.m setup.n)
          ; c_var = [| small_var setup.n_neural setup.n; small_var setup.n_hand setup.n |]
          }
        in
        Model.init ~tie:true generative_prms elbobo_posterior learned)


let _ = "prms initialized" |> C.print

(* Parameter transforms, both currently the identity. [elbobo_sample] is
   handed to [Model.train]; [elbobo_exploit] is applied to parameters before
   they are saved or evaluated. *)
let elbobo_sample : Model.P.p -> Model.P.p = Fn.id
let elbobo_exploit : Model.P.p -> Model.P.p = Fn.id

(* Save [prms] and the model's predictions for the trials in [data].

   On the root, the parameters are written to "<prefix>.params.bin" and via
   [Model.P.save_to_files ~prefix]. Trials are distributed round-robin over
   the nodes (trial [i] goes to rank [i mod n_nodes]). For each trial, the
   posterior mean of the inputs is computed, and the mean and variance over
   axis 2 (presumably the sample axis) of 100 predictive samples of u, z and
   every observation channel are written to "<prefix>.predicted_<label>_<i>",
   reshaped to [setup.n_steps] rows.

   [u_init], when given, is indexed like [data] and holds an optional warm
   start per trial (also saved as "<prefix>.u_init_<i>"). A trial that raises
   is reported on stdout, with its exception, and skipped. *)
let save_results ?u_init prefix prms data =
  let prms = C.broadcast prms |> elbobo_exploit in
  let file s = prefix ^ "." ^ s in
  C.root_perform (fun () ->
      Misc.save_bin ~out:(file "params.bin") prms;
      Model.P.save_to_files ~prefix ~prms);
  Array.iteri data ~f:(fun i dat_trial ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        try
          let u_init = Option.bind u_init ~f:(fun u -> u.(i)) in
          Option.iter u_init ~f:(fun u ->
              Owl.Mat.save_txt ~out:(file (Printf.sprintf "u_init_%i" i)) u);
          let _ = C.print_endline "more saving 0" in
          let mu = Model.posterior_mean ~conv_threshold:1E-7 ~u_init ~prms dat_trial in
          (* Owl.Mat.save_txt
          ~out:(file (Printf.sprintf "posterior_u_%i" i))
          (AD.unpack_arr mu); *)
          let _ = C.print_endline "more saving 1" in
          let us, zs, os = Model.predictions ~n_samples:100 ~prms mu in
          let process label a =
            let a = AD.unpack_arr a in
            Owl.Arr.(mean ~axis:2 a @|| var ~axis:2 a)
            |> (fun z -> Owl.Arr.reshape z [| setup.n_steps; -1 |])
            |> Mat.save_txt ~out:(file (Printf.sprintf "predicted_%s_%i" label i))
          in
          process "u" us;
          process "z" zs;
          Array.iter ~f:(fun (label, x) -> process label x) os
        with
        | e ->
          (* best effort: keep going with the other trials, but say why this
             one failed and terminate the line so messages do not run together *)
          Stdio.printf
            "Trial %i failed with some exception in save_results: %s\n%!"
            i
            (Exn.to_string e)))


(* Hand read-out R2 obtained by inferring the inputs from the neural data only.

   For each trial of [data] handled by this node (round-robin over ranks):
   - the posterior mean of the inputs is computed with [Model_neural], using
     the neural likelihood only and with the recognition model's generative
     part dropped;
   - it is saved as "<label>_posterior_u_<i>";
   - it is pushed through [Model_hand] with 1000 samples, and mean/variance of
     z and of the hand predictions go to "<label>_predicted_<kind>_<i>".

   On the root, the true and predicted hand data are concatenated across
   trials and saved both squashed and unsquashed. The row [r2; r2_squashed] is
   then appended to "<label>_R2".

   [u_init], if given, is indexed like [data] and provides optional warm
   starts. Failing trials are reported, with their exception, and skipped. If
   no trial succeeds, the R2 file is left untouched instead of crashing the
   root in [Mat.concatenate]. *)
let update_R2 ?u_init (prms : Model.P.p) label data =
  let in_dir' s = in_dir Printf.(sprintf "%s_%s" label s) in
  let prms = elbobo_exploit prms in
  let sub_prms f =
    VAE_P.
      { generative = prms.generative |> Accessor.map Generative_P.A.likelihood ~f
      ; recognition =
          prms.recognition |> Accessor.map Recognition_P.A.generative ~f:(fun _ -> None)
      }
  in
  let prms_neural = C.broadcast' (fun _ -> sub_prms (fun x -> x.fst)) in
  let prms_hand = C.broadcast' (fun _ -> sub_prms (fun x -> x.snd)) in
  Array.foldi data ~init:[] ~f:(fun i accu data ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        try
          let u_init = Option.bind u_init ~f:(fun u -> u.(i)) in
          let open Likelihoods.Pair_P in
          let data_neural = { u = None; z = None; o = data.o.fst } in
          let data_hand = { u = None; z = None; o = data.o.snd } in
          let mu =
            Model_neural.posterior_mean
              ~conv_threshold:1E-6
              ~u_init
              ~prms:prms_neural
              data_neural
          in
          Owl.Mat.save_txt
            ~out:(in_dir' (Printf.sprintf "posterior_u_%i" i))
            (AD.unpack_arr mu);
          let _, zs, os = Model_hand.predictions ~n_samples:1000 ~prms:prms_hand mu in
          let process label a =
            let a = AD.unpack_arr a in
            Owl.Arr.(mean ~axis:2 a @|| var ~axis:2 a)
            |> (fun z -> Owl.Arr.reshape z [| setup.n_steps; -1 |])
            |> Mat.save_txt ~out:(in_dir' (Printf.sprintf "predicted_%s_%i" label i))
          in
          process "z" zs;
          Array.iter ~f:(fun (label, x) -> process label x) os;
          let pred_hand =
            let _, os = os.(0) in
            let os = AD.unpack_arr os in
            Owl.Arr.(mean ~axis:2 os |> squeeze)
          in
          (AD.unpack_arr data_hand.o, pred_hand) :: accu
        with
        | e ->
          Stdio.printf
            "Trial %i failed with some exception in update_R2 (%s): %s\n%!"
            i
            label
            (Exn.to_string e);
          accu)
      else accu)
  |> C.gather
  |> fun v ->
  C.root_perform (fun () ->
      (* v is an array of lists *)
      let v = v |> Array.to_list |> List.concat |> Array.of_list in
      if Array.is_empty v
      then Stdio.printf "update_R2 (%s): no trial succeeded, R2 not updated.\n%!" label
      else (
        let true_hand = Array.map v ~f:fst |> Mat.concatenate ~axis:0 in
        let pred_hand = Array.map v ~f:snd |> Mat.concatenate ~axis:0 in
        Mat.save_txt ~out:(in_dir' "pred_hand_squashed") pred_hand;
        Mat.save_txt ~out:(in_dir' "true_hand_squashed") true_hand;
        Mat.save_txt ~out:(in_dir' "pred_hand") (unsquash pred_hand);
        Mat.save_txt ~out:(in_dir' "true_hand") (unsquash true_hand);
        let r2_squashed =
          1. -. (Mat.(mean' (sqr (true_hand - pred_hand))) /. Mat.var' true_hand)
        in
        let r2 =
          1.
          -. (Mat.(mean' (sqr (unsquash true_hand - unsquash pred_hand)))
             /. Mat.var' (unsquash true_hand))
        in
        Mat.(
          save_txt ~append:true ~out:(in_dir' "R2") (of_array [| r2; r2_squashed |] 1 2))))


(* Validation score for early stopping.

   Computes the R2, in squashed hand space, of the hand read-out obtained by:
   - inferring the inputs from the neural data only ([Model_neural], no warm
     start);
   - predicting the hand with [Model_hand] (1000 samples).

   Trials are spread round-robin over the nodes, and the final value is
   evaluated through [C.broadcast']. Failing trials are reported, with their
   exception, and skipped. If no trial succeeds, [Float.nan] is returned
   instead of crashing in [Mat.concatenate]. *)
let get_r2_val (prms : Model.P.p) data =
  let prms = elbobo_exploit prms in
  let sub_prms f =
    VAE_P.
      { generative = prms.generative |> Accessor.map Generative_P.A.likelihood ~f
      ; recognition =
          prms.recognition |> Accessor.map Recognition_P.A.generative ~f:(fun _ -> None)
      }
  in
  let prms_neural = C.broadcast' (fun _ -> sub_prms (fun x -> x.fst)) in
  let prms_hand = C.broadcast' (fun _ -> sub_prms (fun x -> x.snd)) in
  Array.foldi data ~init:[] ~f:(fun i accu data ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        try
          let open Likelihoods.Pair_P in
          let data_neural = { u = None; z = None; o = data.o.fst } in
          let data_hand = { u = None; z = None; o = data.o.snd } in
          let mu =
            Model_neural.posterior_mean
              ~conv_threshold:1E-6
              ~u_init:None
              ~prms:prms_neural
              data_neural
          in
          let _, _, os = Model_hand.predictions ~n_samples:1000 ~prms:prms_hand mu in
          let pred_hand =
            let _, os = os.(0) in
            let os = AD.unpack_arr os in
            Owl.Arr.(mean ~axis:2 os |> squeeze)
          in
          (AD.unpack_arr data_hand.o, pred_hand) :: accu
        with
        | e ->
          Stdio.printf
            "Trial %i failed with some exception in get_r2_val: %s\n%!"
            i
            (Exn.to_string e);
          accu)
      else accu)
  |> C.gather
  |> fun v ->
  C.broadcast' (fun () ->
      (* v is an array of lists *)
      let v = v |> Array.to_list |> List.concat |> Array.of_list in
      if Array.is_empty v
      then (
        Stdio.printf "get_r2_val: no trial succeeded, returning nan.\n%!";
        Float.nan)
      else (
        let true_hand = Array.map v ~f:fst |> Mat.concatenate ~axis:0 in
        let pred_hand = Array.map v ~f:snd |> Mat.concatenate ~axis:0 in
        1. -. (Mat.(mean' (sqr (true_hand - pred_hand))) /. Mat.var' true_hand)))


(* Linear read-out of the hand from latents inferred from neural data only.

   1. On the root, [prms] is saved ("<prefix>.params.bin" and
      [Model.P.save_to_files]).
   2. For every training trial handled by this node (round-robin), the inputs
      are inferred with [Model_neural]. The full model's predictions are saved
      as "<prefix>.predicted_<kind>_<i>", and the sample-mean latent z is kept
      together with the squashed true hand data.
   3. On the root, these are saved as "pred_neural" / "true_train_hand", and a
      ridge regression (penalty [lambda]) from latents to unsquashed hand data
      is fitted with [compute_R2].
   4. The inference is repeated on [test_data]. Predictions are saved as
      "<prefix>.predicted_test_<kind>_<i>"; they previously used the training
      names and overwrote the training-trial files. The read-out is then
      applied, and the test R2 is appended to "regressed_R2_<lambda>".

   Failing test trials are reported, with their exception, and skipped. *)
let save_reg_r2
    ~lambda
    prefix
    prms
    (data : L.data data array)
    (test_data : L.data data array)
  =
  let prms = C.broadcast prms |> elbobo_exploit in
  let sub_prms f =
    VAE_P.
      { generative = prms.generative |> Accessor.map Generative_P.A.likelihood ~f
      ; recognition =
          prms.recognition |> Accessor.map Recognition_P.A.generative ~f:(fun _ -> None)
      }
  in
  let prms_neural = C.broadcast' (fun _ -> sub_prms (fun x -> x.fst)) in
  let file s = prefix ^ "." ^ s in
  C.root_perform (fun () ->
      Misc.save_bin ~out:(file "params.bin") prms;
      Model.P.save_to_files ~prefix ~prms);
  (* Infer one trial's inputs from its neural data alone, save the full
     model's predictions as "<prefix>.predicted_<tag><kind>_<i>", and return
     (true squashed hand data, sample-mean latent z). *)
  let infer ?saving_iter ~tag i (dat_trial : L.data data) =
    let open Likelihoods.Pair_P in
    let (data_neural : L_neural.data data) =
      { u = None; z = None; o = dat_trial.o.fst }
    in
    let (data_hand : L_hand.data data) = { u = None; z = None; o = dat_trial.o.snd } in
    let mu =
      Model_neural.posterior_mean
        ?saving_iter
        ~conv_threshold:1E-4
        ~u_init:None
        ~prms:prms_neural
        data_neural
    in
    let us, zs, os = Model.predictions ~n_samples:100 ~prms mu in
    let process label a =
      let a = AD.unpack_arr a in
      Owl.Arr.(mean ~axis:2 a @|| var ~axis:2 a)
      |> (fun z -> Owl.Arr.reshape z [| setup.n_steps; -1 |])
      |> Mat.save_txt ~out:(file (Printf.sprintf "predicted_%s%s_%i" tag label i))
    in
    process "u" us;
    process "z" zs;
    Array.iter ~f:(fun (label, x) -> process label x) os;
    let mean_z =
      let z = AD.unpack_arr zs in
      Owl.Arr.(mean ~axis:2 z) |> fun z -> Owl.Arr.reshape z [| setup.n_steps; -1 |]
    in
    AD.unpack_arr data_hand.o, mean_z
  in
  (* --- training trials: fit the read-out --- *)
  let train_pairs =
    Array.foldi data ~init:[] ~f:(fun i accu dat_trial ->
        if Int.(i % C.n_nodes = C.rank)
        then
          infer ~saving_iter:(file (Printf.sprintf "iter_%i" i)) ~tag:"" i dat_trial
          :: accu
        else accu)
    |> C.gather
  in
  C.root_perform (fun () ->
      (* train_pairs is an array of lists *)
      let v = train_pairs |> Array.to_list |> List.concat |> Array.of_list in
      let true_hand = Array.map v ~f:fst |> Mat.concatenate ~axis:0 in
      let latent_neural = Array.map v ~f:snd |> Mat.concatenate ~axis:0 in
      Mat.save_txt ~out:(in_dir "pred_neural") latent_neural;
      Mat.save_txt ~out:(in_dir "true_train_hand") true_hand);
  let _, c =
    C.broadcast' (fun () ->
        let x = Mat.load_txt (in_dir "pred_neural") in
        let y = unsquash (Mat.load_txt (in_dir "true_train_hand")) in
        compute_R2 ~lambda x y)
  in
  (* let train_r2, _ =
      compute_R2
        (Mat.load_txt (in_dir "pred_neural"))
        (unsquash (Mat.load_txt (in_dir "true_train_hand")))
    in *)
  (* --- test trials: evaluate the read-out --- *)
  let test_pairs =
    Array.foldi test_data ~init:[] ~f:(fun i accu dat_trial ->
        if Int.(i % C.n_nodes = C.rank)
        then (
          try infer ~tag:"test_" i dat_trial :: accu with
          | e ->
            Stdio.printf
              "Trial %i failed with some exception in save_reg_R2_test: %s\n%!"
              i
              (Exn.to_string e);
            accu)
        else accu)
    |> C.gather
  in
  C.root_perform (fun () ->
      (* test_pairs is an array of lists *)
      let v = test_pairs |> Array.to_list |> List.concat |> Array.of_list in
      if Array.is_empty v
      then Stdio.printf "save_reg_r2: no test trial succeeded, R2 not saved.\n%!"
      else (
        let true_hand = Array.map v ~f:fst |> Mat.concatenate ~axis:0 |> unsquash in
        let latent_neural = Array.map v ~f:snd |> Mat.concatenate ~axis:0 in
        let pred_hand = Mat.(latent_neural *@ c) in
        let r2 =
          1. -. (Mat.(mean' (sqr (true_hand - pred_hand))) /. Mat.var' true_hand)
        in
        Mat.(
          save_txt
            ~append:true
            ~out:(in_dir (Printf.sprintf "regressed_R2_%f" lambda))
            (of_array [| r2 |] 1 1))))


(* Mat.(
        save_txt
          ~append:true
          ~out:(in_dir "train_regressed_R2")
          (of_array [| train_r2 |] 1 1))) *)

(* Save parameters and predictions of the untrained model. *)
let _ = "pre init" |> C.print_endline
let () = save_results (in_dir "init") init_prms train_data
let _ = "post init" |> C.print_endline

(* Validation-R2 bookkeeping: slot 0 holds the last validation R2 (initially
   0.), slot 1 its last change (initially 1., i.e. positive). *)
let r2_val_ = C.broadcast' (fun () -> [| 0.; 1. |])

(* Early-stopping test at training iteration [k].

   Computes the validation R2 of [prms] on [val_data]. On the root it records
   that R2 in [r2_val_.(0)] and its change since the previous call in
   [r2_val_.(1)]; the previous values are read back through [C.broadcast']
   before being overwritten. Returns [true] when both the current and the
   previous change are negative and [k > 500]. *)
let check_early_stop k prms =
  let r2_now = get_r2_val prms val_data in
  let r2_prev = C.broadcast' (fun () -> r2_val_.(0)) in
  C.root_perform (fun () -> r2_val_.(0) <- r2_now);
  let prev_change = C.broadcast' (fun () -> r2_val_.(1)) in
  let change = r2_now -. r2_prev in
  C.root_perform (fun () -> r2_val_.(1) <- change);
  Float.(change < 0. && prev_change < 0.) && k > 500


(* Single-slot flag holding the latest early-stopping verdict. *)
let early_stop = C.broadcast' (fun () -> Array.create ~len:1 false)

(* L2 regulariser on the dynamics weights [wh], [uh] and [uf]. Each term is
   the squared L2 norm scaled by 3e-5 / (setup.n)^2; [wf] is left out. *)
let reg ~(prms : Model.P.p) =
  let dyn = prms.generative.dynamics in
  let z = Float.(3E-5 / of_int Int.(setup.n * setup.n)) in
  let penalty w = AD.Maths.(F z * l2norm_sqr' (extract w)) in
  (* let part2 = penalty dyn.wf in *)
  AD.Maths.(penalty dyn.wh + penalty dyn.uh + penalty dyn.uf)


(* Train the full model on [train_data], starting from [init_prms].

   - Every 1000 iterations the early-stopping test runs and its verdict is
     stored in [early_stop]. That flag is not consulted, because
     [~early_stop] is commented out below.
   - Every 100 iterations, predictions and R2 diagnostics are written to
     [-d]: train/test R2, plus regressed R2 for several ridge penalties. *)
let final_prms =
  let every period k = Int.(k % period = 0) in
  let in_each_iteration ~u_init ~prms k =
    if every 1000 k
    then (
      let stop = check_early_stop k prms in
      C.root_perform (fun () -> early_stop.(0) <- stop));
    if every 100 k
    then (
      save_results (in_dir "final") prms train_data;
      (* update_R2 prms ~u_init "R2_train_uinit" train_data; *)
      update_R2 prms "R2_train" train_data;
      (* update_R2 prms "R2_val" val_data; *)
      List.iter [ 1E-2; 1E-3; 1E-4; 1E-5; 1E-6; 1E-7 ] ~f:(fun lambda ->
          save_reg_r2 ~lambda (in_dir "reg_R2") prms train_data test_data);
      update_R2 prms "R2_test" test_data)
  in
  Model.train
  (* ~early_stop:(fun k ->
      if Int.(k % 30 = 1)
      then (
        let es = C.broadcast' (fun () -> early_stop.(0)) in
        es)
      else false) *)
    ~n_samples:(fun _ -> 100)
    ~max_iter:Cmdargs.(get_int "-max_iter" |> default 30000)
    ~conv_threshold:(fun _ -> 1E-6)
      (* ~conv_threshold:(fun k -> if k < 1000 then Float.(1E-6 / sqrt (of_int k)) else 1E-8) *)
    ~save_progress_to:(2, 300, in_dir "progress")
    ~in_each_iteration
    ~recycle_u:(fun _ -> false)
    ~eta:(`of_iter (fun k -> Float.(0.04 / (1. + sqrt (of_int k / 24.)))))
    ~reg
    ~init_prms
    ~elbobo_sample
    train_data


(* Save parameters and predictions of the trained model. *)
let () = save_results (in_dir "final") final_prms train_data
