open Base
open Owl
open Vae
open Variational
open Owl_parameters

(* Command-line configuration.
   - [dir]: value of the [-d] option; every file this script reads or writes
     is addressed through [in_dir].
   - [in_dir file]: [file] prefixed with [dir] and a "/" separator.
   - [reuse]: whether the [-reuse] flag was given; when set, the setup, data
     and parameters are read back from [dir] instead of being regenerated.
   - [soc]: value of the [-soc] option, a path loaded with [Mat.load_txt] as
     the weight matrix of the ground-truth model.
   NOTE(review): [Cmdargs.force] presumably aborts with the usage string when
   the option is missing -- confirm against [Cmdargs]. *)
let dir = Cmdargs.force ~usage:"-d [dir]" (Cmdargs.get_string "-d")
let in_dir file = Printf.sprintf "%s/%s" dir file
let reuse = Cmdargs.check "-reuse"
let soc = Cmdargs.force ~usage:"-soc [path]" (Cmdargs.get_string "-soc")

(* Sizes of the experiment. A value of this type is saved to and read back
   from "setup.bin", so the field list should stay in sync with any file
   already on disk. *)
type setup =
  { n : int (* passed as [X.n] and as [~n] to the dynamics / likelihood constructors *)
  ; m : int (* passed as [X.m]; number of columns of the simulated input [u] *)
  ; n_output : int (* passed as [~n_output] to [L.init]; width of the variances row *)
  ; n_trials : int (* size of the training set; as many test trials are also generated *)
  ; n_steps : int (* passed as [X.n_steps]; number of rows of the simulated input [u] *)
  }

(* Experiment sizes. With [-reuse] they are read back from "setup.bin" in
   [dir]; otherwise the defaults below are used and written to that file so
   that a later [-reuse] run sees the same values.
   NOTE(review): [C.broadcast'] presumably evaluates the thunk on one node and
   shares the result with the others -- confirm against [C]. *)
let setup =
  let load_or_create () =
    if reuse
    then Misc.read_bin (in_dir "setup.bin")
    else (
      let fresh = { n = 8; m = 3; n_output = 10; n_trials = 56; n_steps = 100 } in
      Misc.save_bin ~out:(in_dir "setup.bin") fresh;
      fresh)
  in
  C.broadcast' load_or_create


(* Time constants of the ground-truth dynamics and the observation noise.
   [dt_over_tau] scales the exponent of the matrix exponential that defines
   the ground-truth transition matrix; [noise_std] is squared to give the
   [~sigma2] of the ground-truth likelihood.
   NOTE(review): [dt] and [tau] are presumably in seconds -- confirm. *)
let dt = 1e-3
let tau = 20e-3
let dt_over_tau = dt /. tau
let noise_std = 0.2

(* ----------------------------------------- 
   -- Define model
   ----------------------------------------- *)

(* Prior module: [Priors.Student], given to [VAE] as its first argument. *)
module U = Priors.Student

(* Likelihood module: [Likelihoods.Gaussian] instantiated with the label "o"
   and [normalize_c] switched off (the meaning of both settings is defined by
   that functor). *)
module L = Likelihoods.Gaussian (struct
  let normalize_c = false
  let label = "o"
end)

(* Dynamics module: [Dynamics.Nonlinear] instantiated with the identity as
   [phi] and an all-ones array (same shape as the argument) as [d_phi], i.e.
   the derivative of the identity. *)
module D = Dynamics.Nonlinear (struct
  let phi = Fn.id
  let d_phi x = AD.Arr.ones (AD.Arr.shape x)
end)

(* [Elbobo.ARD], given to [VAE] as its fourth argument; it also provides the
   [P_prior] and [P_posterior] parameter types used for the ground truth. *)
module E = Elbobo.ARD

(* Size settings handed to [VAE]: [n], [m] and [n_steps] are taken from
   [setup]; [diag_time_cov] is switched off. *)
module X = struct
  let ({ n; m; n_steps; _ } : setup) = setup
  let diag_time_cov = false
end

(* The model used throughout this script: [VAE] applied to the prior [U],
   dynamics [D], likelihood [L], [E] and the size settings [X]. *)
module Model = VAE (U) (D) (L) (E) (X)

(* ----------------------------------------- 
   -- Generate ground truth parameters
   ----------------------------------------- *)

(* Ground-truth parameters, built inside [C.broadcast'].
   The weight matrix [w] is loaded from the [-soc] file. The generative part
   is either read back from "true.params.bin" ([-reuse]) or built from [w];
   the [elbobo_posterior] part is always built here. All generative
   parameters built here are tagged [pinned], and [Model.init] is also given
   [pinned].
   NOTE(review): this block reads [setup.n_neural] and [setup.n_hand], which
   are not fields of the [setup] record declared in this file (its fields are
   [n], [m], [n_output], [n_trials], [n_steps]). The "neural" / "hand" labels
   also differ from the likelihood label "o" used by [L]. This looks carried
   over from a two-likelihood script -- confirm the intended sizes and labels. *)
let true_prms =
  C.broadcast' (fun () ->
      let w = Mat.load_txt soc in
      (* [b] is only used when the generative parameters are built from
         scratch (the [false] branch below). *)
      let b =
        (* find observability Gramian to use inputs that elicit amplification *)
        let q =
          Linalg.D.lyapunov Mat.(transpose w - eye setup.n) Mat.(neg (eye setup.n))
        in
        let u, _, _ = Linalg.D.svd q in
        (* keep columns 0 .. m-1 of the first SVD factor, then transpose *)
        Mat.get_slice [ []; [ 0; setup.m - 1 ] ] u |> Mat.transpose
      in
      let generative_prms =
        match reuse with
        | true ->
          let (prms : Model.P.p) = Misc.read_bin (in_dir "true.params.bin") in
          prms.generative
        | false ->
          let n = setup.n
          and m = setup.m
          and n_output = setup.n_output in
          let prior = Priors.Student.init ~spatial_std:1.0 ~nu:10. ~m pinned in
          let dynamics =
            (* [$*] binds less tightly than [-], so the argument of [expm] is
               [dt_over_tau $* (w - eye n)] *)
            let a = Linalg.D.expm Mat.(dt_over_tau $* w - eye n) in
            Dynamics.Nonlinear_P.
              { a = pinned (AD.pack_arr a)
              ; bias = pinned (AD.Mat.zeros 1 n)
              ; b = Some (pinned (AD.pack_arr b))
              }
          in
          (* every scalar of the prior starts at 0.3 and is constrained with
             [~above:1E-6]; [b] is present only when [dynamics.b] is *)
          let elbobo_prior : E.P_prior.p =
            { ard_scales = learned ~above:1E-6 (AD.Mat.ones 1 n)
            ; a = learned ~above:1E-6 (AD.F 0.3)
            ; b = Option.map dynamics.b ~f:(fun _ -> learned ~above:1E-6 (AD.F 0.3))
            ; c =
                [| learned ~above:1E-6 (AD.F 0.3), "neural"
                 ; learned ~above:1E-6 (AD.F 0.3), "hand"
                |]
            }
          in
          let likelihood = L.init ~sigma2:Float.(square noise_std) ~n ~n_output pinned in
          Generative_P.{ prior; dynamics; likelihood; elbobo_prior }
      in
      (* NOTE(review): [setup.n_neural] / [setup.n_hand] below are not declared
         in [setup] -- see the note at the top of this definition. *)
      let elbobo_posterior : E.P_posterior.p =
        { a_mean = learned (AD.Mat.zeros setup.n setup.n)
        ; b_mean = generative_prms.dynamics.b
        ; c_mean =
            [| learned (AD.Mat.zeros setup.n_neural setup.n), "neural"
             ; learned (AD.Mat.gaussian ~sigma:0.1 setup.n_hand setup.n), "hand"
            |]
        ; a_var = learned ~above:1E-6 (AD.Mat.create setup.n setup.n 0.001)
        ; b_var =
            Option.map generative_prms.dynamics.b ~f:(fun _ ->
                learned ~above:1E-6 (AD.Mat.create setup.m setup.n 0.001))
        ; c_var =
            [| learned ~above:1E-6 (AD.Mat.create setup.n_neural setup.n 0.001)
             ; learned ~above:1E-6 (AD.Mat.create setup.n_hand setup.n 0.001)
            |]
        }
      in
      Model.init ~tie:true generative_prms elbobo_posterior pinned)


(* Both functions return their argument unchanged; neither is referenced
   elsewhere in this file. *)
let elbobo_sample = Fn.id
let elbobo_exploit = Fn.id

(* ----------------------------------------- 
   -- Generate data from ground truth model
   ----------------------------------------- *)

(* generate training and test data right away *)
(* Training data, built inside [C.broadcast'].
   With [-reuse] the training set is read back from "train_data.bin".
   Otherwise 2 * [setup.n_trials] trials are simulated from [true_prms]: the
   first half becomes the training set (the value of [data]), the second half
   the test set. Both halves are written as binary files and as per-trial
   text files in [dir]. *)
let data =
  C.broadcast' (fun () ->
      if reuse
      then Misc.read_bin (in_dir "train_data.bin")
      else (
        let module Integrator = Dynamics.Integrate (D) in
        (* One trial: the input has a random Gaussian first row followed by
           [n_steps - 1] rows of zeros; it is integrated through the
           ground-truth dynamics and observed through the ground-truth
           likelihood. The input itself is not stored ([u = None]). *)
        let simulate_trial (_ : int) =
          let first_row = Mat.gaussian 1 setup.m in
          let zero_rows = Mat.zeros (setup.n_steps - 1) setup.m in
          let u = AD.pack_arr Mat.(first_row @= zero_rows) in
          let z =
            AD.squeeze0
              (Integrator.integrate
                 ~prms:true_prms.generative.dynamics
                 ~n:setup.n
                 ~u:(AD.expand0 u))
          in
          let o = L.sample ~prms:true_prms.generative.likelihood ~z in
          { u = None; z = Some z; o }
        in
        let n_trials = setup.n_trials in
        let all_trials = Array.init (2 * n_trials) ~f:simulate_trial in
        let train_data = Array.sub all_trials ~pos:0 ~len:n_trials in
        let test_data = Array.sub all_trials ~pos:n_trials ~len:n_trials in
        Misc.save_bin ~out:(in_dir "train_data.bin") train_data;
        Misc.save_bin ~out:(in_dir "test_data.bin") test_data;
        (* text export: one "z" file (when present) and the likelihood's own
           output per trial *)
        let export_txt label trials =
          let export_trial i trial =
            let file kind = in_dir (Printf.sprintf "%s_data_%s_%i" label kind i) in
            Option.iter trial.z ~f:(fun z ->
                Mat.save_txt ~out:(file "z") (AD.unpack_arr z));
            L.save_data ~prefix:(file "o") trial.o
          in
          Array.iteri trials ~f:export_trial
        in
        export_txt "train" train_data;
        export_txt "test" test_data;
        train_data))


(* ----------------------------------------- 
   -- Saving results
   ----------------------------------------- *)

(* [save_results ?u_init prefix prms data] writes the parameters and, for
   every trial of [data], the posterior mean and model predictions. All file
   names are [prefix], a ".", then a suffix.
   - [prms] is first passed through [C.broadcast]; inside [C.root_perform]
     the parameters are written to "params.bin" and through
     [Model.P.save_to_files].
   - Trial [i] is handled by the node whose [C.rank] equals [i] modulo
     [C.n_nodes].
   - [u_init], when given, is indexed by trial; an entry that is [Some] is
     saved as "u_init_<i>" and every entry is forwarded to
     [Model.posterior_mean].
   - Predictions use 100 samples; for each predicted quantity the mean and
     variance over axis 2 are joined with [@||] and reshaped to
     [setup.n_steps] rows before being written as "predicted_<label>_<i>". *)
let save_results ?u_init prefix prms data =
  let prms = C.broadcast prms in
  let file s = prefix ^ "." ^ s in
  let save_prediction i label samples =
    let xs = AD.unpack_arr samples in
    let stats = Owl.Arr.(mean ~axis:2 xs @|| var ~axis:2 xs) in
    let table = Owl.Arr.reshape stats [| setup.n_steps; -1 |] in
    Mat.save_txt ~out:(file (Printf.sprintf "predicted_%s_%i" label i)) table
  in
  let save_trial i dat_trial =
    let u_init = Option.bind u_init ~f:(fun u -> u.(i)) in
    Option.iter u_init ~f:(fun u ->
        Owl.Mat.save_txt ~out:(file (Printf.sprintf "u_init_%i" i)) u);
    let mu = Model.posterior_mean ~u_init ~prms dat_trial in
    Owl.Mat.save_txt ~out:(file (Printf.sprintf "posterior_u_%i" i)) (AD.unpack_arr mu);
    let us, zs, os = Model.predictions ~n_samples:100 ~prms mu in
    save_prediction i "u" us;
    save_prediction i "z" zs;
    Array.iter os ~f:(fun (label, x) -> save_prediction i label x)
  in
  C.root_perform (fun () ->
      Misc.save_bin ~out:(file "params.bin") prms;
      Model.P.save_to_files ~prefix ~prms);
  Array.iteri data ~f:(fun i dat_trial ->
      if Int.(i % C.n_nodes = C.rank) then save_trial i dat_trial)


(* Save the ground-truth parameters and their per-trial outputs under the
   "true" prefix. [let ()] (rather than [let _]) makes the compiler check
   that the call is fully applied and returns unit. *)
let () = save_results (in_dir "true") true_prms data

(* ----------------------------------------- 
   -- Initialise parameters and train
   ----------------------------------------- *)

(* Initial parameters for training, built inside [C.broadcast'].
   With [-reuse] the generative part is read back from "final.params.bin".
   Otherwise the prior and dynamics are freshly initialised as [learned],
   while the likelihood starts from the ground-truth one: its [c] is kept
   [pinned], its bias is [learned], and the variances start at 1.
   NOTE(review): unlike [true_prms], the record built here has no
   [elbobo_prior] field and [Model.init] is called without an
   [elbobo_posterior] argument. The two call shapes appear inconsistent --
   confirm which one matches the current [Generative_P] / [Model.init]. *)
let init_prms =
  C.broadcast' (fun () ->
      let generative_prms =
        if reuse
        then (
          let (saved : Model.P.p) = Misc.read_bin (in_dir "final.params.bin") in
          saved.generative)
        else (
          let prior = Priors.Student.init ~spatial_std:1.0 ~nu:10. ~m:setup.m learned in
          let dynamics = D.init ~radius:0.1 ~n:setup.n ~m:setup.m learned in
          let likelihood =
            (* we fix the C matrix to fix the scale of the problem *)
            let truth = true_prms.generative.likelihood in
            Likelihoods.Gaussian_P.
              { c_mask = None
              ; c = pinned (extract truth.c)
              ; bias = learned (extract truth.bias)
              ; variances =
                  learned ~above:0.001 (AD.pack_arr (Mat.create 1 setup.n_output 1.))
              }
          in
          Generative_P.{ prior; dynamics; likelihood })
      in
      Model.init ~tie:true generative_prms learned)


(* Save the initial parameters and their per-trial outputs under the "init"
   prefix. [let ()] (rather than [let _]) makes the compiler check that the
   call is fully applied and returns unit. *)
let () = save_results (in_dir "init") init_prms data

(* Train the model on [data], starting from [init_prms].
   - The iteration cap comes from [-max_iter] (default 20000).
   - The step size at iteration [k] is 0.004 / (1 + sqrt (k / 10)).
   - Whenever [k] is a multiple of 200, the current parameters and the
     [u_init] supplied by the trainer are saved under the "final" prefix. *)
let final_prms =
  let max_iter = Cmdargs.(get_int "-max_iter" |> default 20000) in
  let learning_rate k = Float.(0.004 / (1. + sqrt (of_int k / 10.))) in
  let in_each_iteration ~u_init ~prms k =
    match Int.(k % 200) with
    | 0 -> save_results ~u_init (in_dir "final") prms data
    | _ -> ()
  in
  Model.train
    ~n_samples:100
    ~max_iter
    ~save_progress_to:(10, 200, in_dir "progress")
    ~in_each_iteration
    ~eta:(`of_iter learning_rate)
    ~init_prms
    data


(* Save the trained parameters and their per-trial outputs under the "final"
   prefix. [let ()] (rather than [let _]) makes the compiler check that the
   call is fully applied and returns unit. *)
let () = save_results (in_dir "final") final_prms data
