open Base
open Owl
open Owl_parameters
open Vae
open Variational
open Accessor.O

(* Path helpers derived from command-line flags. Both are partial
   applications of [Cmdargs.in_dir], evaluated once at module initialisation.
   - [data_dir f]: file [f] under the directory given by -data; neurons.bin and
     velocity.bin are loaded from here.
   - [in_dir f]: file [f] under the directory given by -d; setup.bin and
     final.params.bin are read from here and every output of this script is
     written here.
   TODO confirm the exact path semantics of Cmdargs.in_dir. *)
let data_dir = Cmdargs.in_dir "-data"
let in_dir = Cmdargs.in_dir "-d"

(** [compute_R2 ~lambda x x'] fits the ridge regression [x' ~ x c] (penalty
    [lambda], no intercept term) and returns [(r2, c)].

    Rows of [x] and [x'] are observations. [c] has shape
    [col_num x] x [col_num x'].

    [r2] is the in-sample coefficient of determination
    [1 - ||x c - x'||^2 / ||x' - mean_col x'||^2]. The total sum of squares is
    taken around the per-column mean of [x']. If [x'] is constant, the
    denominator is zero and [r2] is not finite. *)
let compute_R2 ~lambda x x' =
  let nx = Mat.col_num x in
  let xt = Mat.transpose x in
  (* Solve (X^T X + lambda I) c = X^T Y directly instead of building the
     explicit inverse and multiplying by it: cheaper and numerically more
     stable. *)
  let c =
    Linalg.D.linsolve Mat.((xt *@ x) + (lambda $* eye nx)) Mat.(xt *@ x')
  in
  let residuals = Mat.((x *@ c) - x') |> Mat.l2norm_sqr' in
  let sstot = Mat.(x' - mean ~axis:0 x') |> Mat.l2norm_sqr' in
  1. -. (residuals /. sstot), c


(* Experiment/model dimensions, read from setup.bin below.
   NOTE(review): values of this type are deserialised with [Misc.read_bin];
   if that is Marshal-based, field order and types must stay in sync with the
   file, so do not reorder fields. *)
type setup =
  { n : int (* model [n], also passed to [I.integrate ~n]; presumably the latent state dimension *)
  ; m : int (* model [m]; [n / m] gives [n_beg]; presumably the input dimension *)
  ; n_trials : int (* number of trials; forced to 1 for the single long chunk *)
  ; n_steps : int (* time steps per trial (per chunk once [setup] is rebuilt) *)
  ; n_neural : int (* unused here; presumably the number of recorded neurons *)
  ; n_hand : int (* unused here; presumably the hand-velocity dimensionality *)
  }

(* Debug progress marker. [let ()] asserts the side effect returns unit. *)
let () = Stdio.printf "test0 %!"

(* Setup record read from setup.bin in the -d directory via [C.broadcast']. *)
let setup_prev : setup = C.broadcast' (fun () -> Misc.read_bin (in_dir "setup.bin"))

(* Multiplier applied to the per-trial length to get the long chunk length. *)
let n_eff_trials = Cmdargs.(get_int "-n_eff_trials" |> force ~usage:"-n_eff_trials [int]")

(* Here we consider a single long chunk: one "trial" whose length is
   [n_eff_trials] times the original trial length. *)
let setup = { setup_prev with n_trials = 1; n_steps = n_eff_trials * setup_prev.n_steps }

(* [n / m]. Passed as [n_beg] to the prior, the dynamics and the VAE; the first
   [n_beg - 1] latent steps are dropped before decoding. Presumably the number
   of input steps used to set the initial state; confirm in Vae/Dynamics. *)
let n_beg = setup.n / setup.m

(* Time step handed to the Poisson likelihood; presumably seconds per bin. *)
let dt = 0.025

(* ----------------------------------------- 
   -- Define model
   ----------------------------------------- *)

(* Prior over the inputs, parameterised by [n_beg]. *)
module U = Priors.Gaussian_Init (struct
  let n_beg = n_beg
end)

(* Poisson likelihood for the neural data (neurons.bin) with an exponential
   link. The file-level [dt] is forwarded as an AD scalar. *)
module L_neural = Likelihoods.Poisson (struct
  let label = "neural"
  let dt = AD.F dt
  let normalize_c = false

  (* exp is its own derivative, so the link and both derivatives coincide. *)
  let link_function = AD.Maths.exp
  let d_link_function = link_function
  let d2_link_function = link_function

  (* Alternative (unused) smooth-rectifier link:
     let link_function x = AD.Maths.(F 0.5 * (x + sqrt (sqr x + F 4.)))
     let d_link_function x = AD.Maths.(F 0.5 * (F 1. + (x / sqrt (sqr x + F 4.))))
     let d2_link_function x = AD.Maths.(F 2. / ((sqr x + F 4.) * sqrt (sqr x + F 4.))) *)
end)

(* Gaussian likelihood for the hand velocity. *)
module L_hand = Likelihoods.Gaussian (struct
  let normalize_c = false
  let label = "hand"
end)

(* Joint likelihood over the (neural, hand) pair. *)
module L = Likelihoods.Pair (L_neural) (L_hand)

(* Latent dynamics built with [Dynamics.Mini_GRU_IO]. [phi] and [sigma] are
   the two nonlinearities it takes (presumably the candidate-state and gate
   nonlinearities; confirm in Dynamics), and [d_phi]/[d_sigma] are their
   derivatives. *)
module D = Dynamics.Mini_GRU_IO (struct
  let phi x = AD.Maths.(AD.requad x - F 1.)
  let d_phi x = AD.d_requad x
  let sigma x = AD.Maths.sigmoid x

  (* sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)). This is written in terms of
     sigmoid rather than exp(-x) / (1 + exp(-x))^2, which overflows to
     inf / inf = nan for very negative x. *)
  let d_sigma x =
    let s = AD.Maths.sigmoid x in
    AD.Maths.(s * (F 1. - s))

  let n_beg = Some n_beg
end)

(* ELBO module handed to the VAE functor as its [E] argument. *)
module E = Elbobo.ARD

(* Model sizes. The whole recording is treated as one chunk of [n_steps]. *)
module X = struct
  let { n; m; n_steps; _ } = setup
  let n_beg = n_beg
  let diag_time_cov = false
end

(* Full model with the joint (neural, hand) likelihood. *)
module Model = VAE (U) (D) (L) (E) (X)

(* Amputated models that each see a single likelihood, for testing purposes.
   [Model_neural] is used below to smooth from the neural data alone. *)
module Model_neural = VAE (U) (D) (L_neural) (E) (X)
module Model_hand = VAE (U) (D) (L_hand) (E) (X)

(* ----------------------------------------- 
   --  Fetch parameters
   ----------------------------------------- *)

(* Trained parameters of the full model, read from final.params.bin in the
   -d directory via [C.broadcast']. *)
let prms : Model.P.p =
  C.broadcast' (fun () -> Misc.read_bin (in_dir "final.params.bin"))

(* Hook for post-processing the loaded parameters; currently the identity. *)
let elbobo_exploit : Model.P.p -> Model.P.p = fun p -> p
let prms = elbobo_exploit prms

(* [sub_prms f g h] derives parameters for a single-likelihood model from the
   global [prms]:
   - the generative likelihood parameters are replaced by [f] applied to them
     (e.g. selecting one component of the paired likelihood);
   - the recognition model's [generative] field is set to [None].
   [g] and [h] are accepted but unused. They are kept only so that existing
   three-argument call sites still type-check, and are underscore-prefixed to
   silence the unused-variable warning. *)
let sub_prms f _g _h =
  VAE_P.
    { generative = prms.generative |> Accessor.map Generative_P.A.likelihood ~f
    ; recognition =
        prms.recognition |> Accessor.map Recognition_P.A.generative ~f:(fun _ -> None)
    }


(* Parameters of the neural-only model: keep the [fst] component of the paired
   likelihood parameters (computed via [C.broadcast']). *)
let prms_neural =
  C.broadcast' (fun () ->
      let pick_fst pair = pair.Likelihoods.Pair_P.fst in
      sub_prms pick_fst pick_fst pick_fst)

(* Parameters of the hand-only model: keep the [snd] component of the paired
   likelihood parameters (computed via [C.broadcast']). *)
let prms_hand =
  C.broadcast' (fun () ->
      let pick_snd pair = pair.Likelihoods.Pair_P.snd in
      sub_prms pick_snd pick_snd pick_snd)


(* ----------------------------------------- 
   --  Fetch data
   ----------------------------------------- *)

(* Symmetric log compression and its inverse:
     squash x   = sign(x) * log(1 + |x|)
     unsquash y = sign(y) * (exp |y| - 1)
   [log1p]/[expm1] replace the hand-rolled [log (1 + .)]/[exp . - 1] so that
   small magnitudes keep full relative precision and the two functions stay
   exact inverses. *)
let squash x = Mat.(signum x * log1p (abs x))
let unsquash x = Mat.(signum x * expm1 (abs x))
let () = (* debug progress marker *) Stdio.printf "test3 %!"

(* Offline smoothing and hand decoding on one long chunk (all inside
   [C.broadcast']):
   1. load neurons.bin and velocity.bin from the -data directory and cut
      disjoint train rows [0, n_steps) and test rows [n_steps, 2 n_steps);
   2. infer the posterior mean of the inputs on each chunk with the
      neural-only model;
   3. integrate those inputs through the dynamics, drop the first [n_beg - 1]
      latent steps, and compute predicted rates with [L_neural];
   4. fit a ridge regression from train rates to train hand velocity
      ([compute_R2]) and apply it to the test rates;
   5. write latents, predictions and train/test R^2 as text files in the
      -d directory. *)
let () =
  let module I = Dynamics.Integrate (D) in
  C.broadcast' (fun () ->
      let x = Mat.load (data_dir "neurons.bin") in
      let y = Mat.load (data_dir "velocity.bin") in
      (* Owl's [get_slice] bounds are inclusive, so [rows a b] returns rows
         a .. b-1. This keeps the train and test chunks disjoint, with exactly
         [n_steps] rows each. *)
      let rows a b mat = Mat.get_slice [ [ a; b - 1 ] ] mat in
      let n_steps = setup.n_steps in
      (* remove the silent episode at the end *)
      let y = rows 0 57000 y in
      let x_train = rows 0 n_steps x in
      let x_test = rows n_steps (2 * n_steps) x in
      let y_train = rows 0 n_steps y in
      let y_test = rows n_steps (2 * n_steps) y in
      (* Debug output: data dimensions and the first 1000 rows of the
         neural data. *)
      Stdio.printf "%i %i %i %!" (Mat.row_num x) (Mat.col_num x) n_steps;
      let _ = C.root_perform (fun () -> Mat.print (rows 0 1000 x)) in
      (* Posterior mean of the inputs given one chunk of neural data, from the
         neural-only model with no input initialisation. *)
      let posterior_u chunk =
        Model_neural.posterior_mean
          ~conv_threshold:1E-6
          ~u_init:None
          ~prms:prms_neural
          { u = None; z = None; o = AD.pack_arr chunk }
        |> AD.unpack_arr
      in
      let mu_train = posterior_u x_train in
      let mu_test = posterior_u x_test in
      Mat.save_txt ~out:(in_dir "long_smoothing_posterior_u") mu_train;
      (* Integrate inputs through the dynamics and drop the first [n_beg - 1]
         latent steps. *)
      let latents mu =
        I.integrate
          ~prms:prms_hand.generative.dynamics
          ~n:setup.n
          ~u:AD.(expand0 (pack_arr mu))
        |> AD.squeeze0
        |> AD.Maths.get_slice [ [ n_beg - 1; -1 ] ]
      in
      let z_train = latents mu_train in
      let z_test = latents mu_test in
      let rates z =
        L_neural.pre_sample ~prms:prms_neural.generative.likelihood ~z |> AD.unpack_arr
      in
      let psths_train = rates z_train in
      let psths_test = rates z_test in
      let pred_hand =
        L_hand.pre_sample ~prms:prms_hand.generative.likelihood ~z:z_train
        |> AD.unpack_arr
      in
      (* Hand velocity aligned with the retained latent steps of each chunk.
         Each chunk is sized from its OWN posterior. *)
      let hand_targets mu y_chunk = rows 0 (Mat.row_num mu - n_beg + 1) y_chunk in
      let true_hand_train = hand_targets mu_train y_train in
      let true_hand_test = hand_targets mu_test y_test in
      let r2, c = compute_R2 ~lambda:1E-5 psths_train true_hand_train in
      let pred_hand_test = Mat.(psths_test *@ c) in
      (* Test R^2 uses the same definition as [compute_R2]: the total sum of
         squares is taken around the per-column mean, so train and test values
         are directly comparable. *)
      let r2_test =
        let ss_res = Mat.(l2norm_sqr' (true_hand_test - pred_hand_test)) in
        let ss_tot =
          Mat.(l2norm_sqr' (true_hand_test - mean ~axis:0 true_hand_test))
        in
        1. -. (ss_res /. ss_tot)
      in
      Mat.save_txt ~out:(in_dir "long_smoothing_pred_hand") pred_hand;
      Mat.save_txt ~out:(in_dir "long_smoothing_latent_test") (AD.unpack_arr z_test);
      Mat.save_txt ~out:(in_dir "long_smoothing_latent_train") (AD.unpack_arr z_train);
      Mat.save_txt ~out:(in_dir "R2_smoothing_train") (Mat.create 1 1 r2);
      Mat.save_txt ~out:(in_dir "R2_smoothing_test") (Mat.create 1 1 r2_test);
      Mat.save_txt ~out:(in_dir "true_hand_test") true_hand_test;
      Mat.save_txt ~out:(in_dir "pred_hand_test") pred_hand_test)
