open Base
open Owl
open Vae
open Variational
open Owl_parameters

(* Output directory, read from the mandatory [-d] command-line flag. *)
let dir = Cmdargs.(get_string "-d" |> force ~usage:"-d [dir]")

(* [in_dir name] is the path [dir/name]. *)
let in_dir = Printf.sprintf "%s/%s" dir

(* Data directory, read from the mandatory [-data] flag. The usage string
   names [-data] (and not [-d]) so that it points at the right flag. *)
let data_dir = Cmdargs.(get_string "-data" |> force ~usage:"-data [dir]")

(* [in_data_dir name] is the path [data_dir/name]. *)
let in_data_dir = Printf.sprintf "%s/%s" data_dir

(* [n] is handed to the model as its [n] and is the column width used when
   reshaping [z]; [m] plays the same role for [u]. *)
let n = 3
let m = n

(* Time constants; only their ratio [dt_over_tau] is passed on (to [D.init]). *)
let dt = 1E-3
let tau = 20E-3

(* Noise standard deviation and autoregressive coefficient used by [ar_input]. *)
let sigma = 0.3
let ar_coeff = 0.9
let dt_over_tau = Float.(dt / tau)

(* Number of trials and of time steps per trial, both mandatory flags. *)
let n_trials = Cmdargs.(get_int "-n_trials" |> force ~usage:"-n_trials")
let n_steps = Cmdargs.(get_int "-n_steps" |> force ~usage:"-n_steps")

(* Column width of the observations [o]. *)
let n_output = 10

(* Used below to slice posterior means from row [n_beg - 1] onwards. *)
let n_beg = 1

(* [compute_R2 x x'] is the coefficient of determination obtained when [x']
   is linearly regressed on [x] (least squares, no intercept term):
   [1 - ||x c - x'||^2 / ||x' - mean x'||^2], where [c] minimises the
   residual and the mean is taken over rows (axis 0).

   [x] and [x'] must have the same number of rows. The result is not finite
   when [x'] has zero variance around its column means. *)
let compute_R2 x x' =
  let xt = Mat.transpose x in
  (* Solve the normal equations (x^T x) c = x^T x' directly instead of first
     building the explicit inverse of x^T x and multiplying by it. *)
  let c = Linalg.D.linsolve Mat.(xt *@ x) Mat.(xt *@ x') in
  let residuals = Mat.((x *@ c) - x') |> Mat.l2norm_sqr' in
  let sstot = Mat.(x' - mean ~axis:0 x') |> Mat.l2norm_sqr' in
  1. -. (residuals /. sstot)


(* Prior over inputs: [Priors.Student] instantiated with [n_beg = Some 1]. *)
module U = Priors.Student (struct
  let n_beg = Some 1
end)

(* Latent dynamics: [Dynamics.Linear] instantiated with [n_beg = Some 1]. *)
module D = Dynamics.Linear (struct
  let n_beg = Some 1
end)

(* Observation model: [Likelihoods.Gaussian], labelled "o", with
   [normalize_c] switched off. *)
module L = Likelihoods.Gaussian (struct
  let label = "o"
  let normalize_c = false
end)

(* [Elbobo.Matrix_column_prior] re-exported with its [label] overridden to
   "output". *)
module E = struct
  include Elbobo.Matrix_column_prior

  let label = "output"
end

(* The full model: the [VAE] functor applied to the prior [U], dynamics [D],
   likelihood [L] and [E], plus the dimensions defined at the top of the file.
   Note that [m] is set equal to [n] here, and that [n_beg] is written as the
   literal 1 (the same value as the top-level [n_beg]). *)
module Model =
  VAE (U) (D) (L) (E)
    (struct
      let n = n
      let m = n
      let n_steps = n_steps
      let diag_time_cov = false
      let n_beg = 1
    end)

(* [generative_prms ~noise_var ~nu set] assembles the generative parameters
   (prior, dynamics, likelihood and elbobo prior). [set] is the setter handed
   to the three component initialisers; [noise_var] is the likelihood variance
   and [nu] is forwarded to the prior. *)
let generative_prms ~noise_var ~nu (set : setter) =
  (* we pin the prior scale here, because everything is linear *)
  let prior =
    (* make sure we get the correct ground truth variance for true_prms *)
    let spatial_std =
      let p = 0.03 in
      Float.(sqrt ((p * (1. - p)) + (p * p)))
    in
    U.init ~pin_std:true ~spatial_std ~nu ~m set
  in
  let dynamics = D.init ~dt_over_tau ~alpha:0.4 ~beta:4. set n m in
  let likelihood = L.init ~sigma2:noise_var ~n ~n_output set in
  (* the elbobo prior variances are always learned, whatever [set] is *)
  let prior_vars = learned ~above:1E-6 (AD.Mat.create 1 n 0.1) in
  let elbobo_prior : E.P_prior.p = { prior_vars } in
  Generative_P.{ prior; dynamics; likelihood; elbobo_prior }


let generative_prms = C.broadcast generative_prms

(* Reference parameters: generative part built with the [pinned] setter
   ([noise_var = 0.1], [nu = 3.]), elbobo posterior with zero mean and
   variance 0.001, then shared through [C.broadcast]. *)
let true_prms =
  let gen = generative_prms ~noise_var:0.1 ~nu:3. pinned in
  let posterior_mean = learned (AD.Mat.zeros n_output n) in
  let posterior_vars = learned ~above:1E-6 (AD.Mat.create n_output n 0.001) in
  let elbobo_posterior : E.P_posterior.p = { posterior_mean; posterior_vars } in
  Model.init ~tie:true gen elbobo_posterior learned


let true_prms = C.broadcast true_prms

(* Starting point for training: generative part built with the [learned]
   setter ([noise_var = 0.1], [nu = 10.]), elbobo posterior mean drawn from a
   small Gaussian and variance 0.001, then shared through [C.broadcast]. *)
let init_prms =
  (* the generative parameters are built before the random posterior mean is
     drawn, so the order of random draws is unchanged *)
  let gen = generative_prms ~noise_var:0.1 ~nu:10. learned in
  let posterior_mean = learned (AD.Mat.gaussian ~sigma:0.001 n_output n) in
  let posterior_vars = learned ~above:1E-6 (AD.Mat.create n_output n 0.001) in
  let elbobo_posterior : E.P_posterior.p = { posterior_mean; posterior_vars } in
  Model.init ~tie:true gen elbobo_posterior learned


let init_prms = C.broadcast init_prms

(* let init_prms = Misc.read_bin (in_dir "final") *)
(* [ar_input ~us] generates [n_steps] rows of a first-order autoregressive
   process, row_k = ar_coeff * row_(k-1) + Gaussian noise of std [sigma],
   seeded with [us], and stacks them along axis 0. The seed [us] itself is
   not part of the output. *)
let ar_input ~us =
  let rows = ref [] in
  let prev = ref us in
  let k = ref 0 in
  while !k <> n_steps do
    let prev_ipt = !prev in
    let new_ipt = Mat.(ar_coeff $* prev_ipt + gaussian ~sigma 1 m) in
    rows := new_ipt :: !rows;
    prev := new_ipt;
    k := !k + 1
  done;
  (* rows were accumulated newest-first *)
  List.rev !rows |> List.to_array |> Mat.concatenate ~axis:0


(* 
let data =
  let module I = Dynamics.Integrate (D) in
  Array.init n_trials ~f:(fun _ ->
      let u = ar_input ~us:Mat.(gaussian ~sigma 1 m) |> AD.pack_arr in
      let z = I.integrate ~prms:true_prms.generative.dynamics ~n ~u:(AD.expand0 u) in
      let z = AD.squeeze0 z in
      let o = L.sample ~prms:true_prms.generative.likelihood ~z in
      { u = Some u; z = Some z; o })
  |> C.broadcast


let test_data =
  let module I = Dynamics.Integrate (D) in
  Array.init n_trials ~f:(fun _ ->
      let u = ar_input ~us:Mat.(gaussian ~sigma 1 m) |> AD.pack_arr in
      let z = I.integrate ~prms:true_prms.generative.dynamics ~n ~u:(AD.expand0 u) in
      let z = AD.squeeze0 z in
      let o = L.sample ~prms:true_prms.generative.likelihood ~z in
      { u = Some u; z = Some z; o })
  |> C.broadcast


let _ =
  C.root_perform (fun () -> Misc.save_bin ~out:(in_data_dir "data") data);
  C.root_perform (fun () -> Misc.save_bin ~out:(in_data_dir "test_data") test_data) *)

(* Training and test trials are read back from binary files in the data
   directory (the generation code above is commented out) and shared through
   [C.broadcast']. *)
let data = C.broadcast' (fun () -> Misc.read_bin (in_data_dir "data"))
let test_data = C.broadcast' (fun () -> Misc.read_bin (in_data_dir "test_data"))
let _ = C.print "data data done"

(* Exports the loaded trials to the data directory, each file written inside
   its own [C.root_perform]:
   - [train_*.npy] / [test_*.npy]: observations, inputs and latents of all
     trials stacked along axis 0;
   - [train_u0], [test_u0], [train_z0], [test_z0]: inputs and latents of the
     first trial as text matrices. *)
let _ =
  (* one trial as a numeric [1 x _ x width] array *)
  let slab width a = AD.unpack_arr (AD.Maths.reshape a [| 1; -1; width |]) in
  (* same, with a [1 x 1] zero placeholder when the field is absent *)
  let slab_opt width = function
    | Some a -> slab width a
    | None -> Mat.zeros 1 1
  in
  (* stack one slab per trial along axis 0 and write it as .npy *)
  let export_npy trials name ~f =
    C.root_perform (fun () ->
        let stacked = Array.map trials ~f |> Arr.concatenate ~axis:0 in
        Arr.save_npy ~out:(in_data_dir name) stacked)
  in
  (* write an optional field of the first trial as a [_ x width] text matrix *)
  let export_first_txt trials name width ~f =
    C.root_perform (fun () ->
        match f trials.(0) with
        | None -> ()
        | Some a ->
          let flat = AD.unpack_arr (AD.Maths.reshape a [| -1; width |]) in
          Mat.save_txt ~out:(in_data_dir name) flat)
  in
  export_npy data "train_output.npy" ~f:(fun dat -> slab n_output dat.o);
  export_npy data "train_u.npy" ~f:(fun dat -> slab_opt m dat.u);
  export_npy data "train_z.npy" ~f:(fun dat -> slab_opt n dat.z);
  export_npy test_data "test_output.npy" ~f:(fun dat -> slab n_output dat.o);
  export_npy test_data "test_u.npy" ~f:(fun dat -> slab_opt m dat.u);
  export_npy test_data "test_z.npy" ~f:(fun dat -> slab_opt n dat.z);
  export_first_txt test_data "test_u0" m ~f:(fun dat -> dat.u);
  export_first_txt data "train_u0" m ~f:(fun dat -> dat.u);
  export_first_txt test_data "test_z0" n ~f:(fun dat -> dat.z);
  export_first_txt data "train_z0" n ~f:(fun dat -> dat.z)


let _ = C.print "|| saved test and train data \n"

(* Inputs [u] of all training trials, one [1 x _ x m] slab per trial, stacked
   along axis 0. A trial without [u] contributes a [1 x 1] zero matrix. *)
let train_u =
  let slab = function
    | None -> Mat.zeros 1 1
    | Some u -> AD.unpack_arr (AD.Maths.reshape u [| 1; -1; m |])
  in
  C.broadcast' (fun () ->
      data |> Array.map ~f:(fun dat -> slab dat.u) |> Arr.concatenate ~axis:0)


(* Inputs [u] of all test trials, one [1 x _ x m] slab per trial, stacked
   along axis 0. A trial without [u] contributes a [1 x 1] zero matrix. *)
let test_u =
  let slab = function
    | None -> Mat.zeros 1 1
    | Some u -> AD.unpack_arr (AD.Maths.reshape u [| 1; -1; m |])
  in
  let stacked =
    test_data |> Array.map ~f:(fun dat -> slab dat.u) |> Arr.concatenate ~axis:0
  in
  C.broadcast stacked


(* Latents [z] of all training trials, one [1 x _ x n] slab per trial, stacked
   along axis 0. A trial without [z] contributes a [1 x 1] zero matrix. *)
let train_z =
  let slab = function
    | None -> Mat.zeros 1 1
    | Some z -> AD.unpack_arr (AD.Maths.reshape z [| 1; -1; n |])
  in
  let stacked =
    data |> Array.map ~f:(fun dat -> slab dat.z) |> Arr.concatenate ~axis:0
  in
  C.broadcast stacked


(* Latents [z] of all test trials, one [1 x _ x n] slab per trial, stacked
   along axis 0. A trial without [z] contributes a [1 x 1] zero matrix. *)
let test_z =
  let slab = function
    | None -> Mat.zeros 1 1
    | Some z -> AD.unpack_arr (AD.Maths.reshape z [| 1; -1; n |])
  in
  let stacked =
    test_data |> Array.map ~f:(fun dat -> slab dat.z) |> Arr.concatenate ~axis:0
  in
  C.broadcast stacked


(* On the root node, dumps the first ten training trials to the output
   directory: row norms of [u], the latent [z], the likelihood pre-sample
   computed from [z] under [true_prms], and the observation [o]. *)
let _ =
  C.root_perform (fun () ->
      let path_for trial label =
        in_dir (Printf.sprintf "train_data_%s_%i" label trial)
      in
      Array.iteri data ~f:(fun idx trial ->
          if idx >= 10
          then C.print "not doing anyhting"
          else (
            let file = path_for idx in
            (match trial.u with
            | None -> ()
            | Some u ->
              let norms = Mat.l2norm ~axis:1 (AD.unpack_arr u) in
              Mat.save_txt ~out:(file "u") norms);
            (match trial.z with
            | None -> ()
            | Some z ->
              Mat.save_txt ~out:(file "z") (AD.unpack_arr z);
              let pre_o = L.pre_sample ~prms:true_prms.generative.likelihood ~z in
              L.save_data ~prefix:(file "pre_o") pre_o);
            L.save_data ~prefix:(file "o") trial.o)))


(* On the root node, dumps the first ten test trials to the output directory:
   row norms of [u], the latent [z], the likelihood pre-sample computed from
   [z] under [true_prms], and the observation [o].

   The files are named [test_data_*], so the trials must come from
   [test_data]; iterating over [data] would write the training trials a
   second time under test names. *)
let _ =
  C.root_perform (fun () ->
      let file trial label = in_dir (Printf.sprintf "test_data_%s_%i" label trial) in
      Array.iteri test_data ~f:(fun i data ->
          if i < 10
          then (
            let file = file i in
            Option.iter data.u ~f:(fun u ->
                Mat.(save_txt ~out:(file "u") (l2norm ~axis:1 (AD.unpack_arr u))));
            Option.iter data.z ~f:(fun z ->
                Mat.save_txt ~out:(file "z") (AD.unpack_arr z);
                let pre_o = L.pre_sample ~prms:true_prms.generative.likelihood ~z in
                L.save_data ~prefix:(file "pre_o") pre_o);
            L.save_data ~prefix:(file "o") data.o)
          else C.print "not doing anyhting"))


let _ = C.print "first saving done "

(* [save_results ?u_init prefix prms data] writes diagnostics for the training
   set:
   - on the root node, the parameters (via [Model.P.save_to_files]) and the
     eigenvalues of the dynamics matrix, in [prefix.a_eig];
   - for every trial handled by this node (trial index modulo [C.n_nodes]
     equal to [C.rank]), the posterior mean input and the mean and variance
     of 100 sampled predictions, in [prefix.posterior_u_<i>] and
     [prefix.predicted_<label>_<i>];
   - on the root node, one appended row [r2_z; r2_u] in [train_r2] under the
     output directory, comparing the posterior latents and inputs with the
     ground truth in [train_z] and [train_u].

   [u_init] is accepted for interface compatibility but is not forwarded:
   every call to [Model.posterior_mean] passes [?u_init:None]. *)
let save_results ?u_init prefix prms data =
  let _ = C.print "saving results..." in
  let prms = C.broadcast prms in
  let file s = prefix ^ "." ^ s in
  C.root_perform (fun () ->
      Model.P.save_to_files ~prefix ~prms;
      let a =
        D.unpack_a ~prms:prms.VAE_P.generative.Generative_P.dynamics |> AD.unpack_arr
      in
      let e = Linalg.D.eigvals a in
      let er, ei = Dense.Matrix.Z.(re e, im e) in
      Mat.(save_txt ~out:(file "a_eig") (transpose (er @= ei))));
  Array.iteri data ~f:(fun i dat_trial ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        (* looked up but deliberately unused, see the header comment *)
        let _u_init =
          match u_init with
          | None -> None
          | Some u -> u.(i)
        in
        let mu =
          Model.posterior_mean
            ?u_init:None
            ~saving_iter:(in_dir (Printf.sprintf "iters_%i" i))
            ~prms
            dat_trial
        in
        Owl.Mat.save_txt
          ~out:(file (Printf.sprintf "posterior_u_%i" i))
          (AD.unpack_arr mu);
        let us, zs, os = Model.predictions ~n_samples:100 ~prms mu in
        (* mean and variance over axis 2, side by side, one row per step *)
        let process label a =
          let a = AD.unpack_arr a in
          Owl.Arr.(mean ~axis:2 a @|| var ~axis:2 a)
          |> (fun z -> Owl.Arr.reshape z [| n_steps; -1 |])
          |> Mat.save_txt ~out:(file (Printf.sprintf "predicted_%s_%i" label i))
        in
        process "u" us;
        process "z" zs;
        Array.iter ~f:(fun (label, x) -> process label x) os));
  (* ground truth flattened to one row per (trial, step), in trial order *)
  let train_us =
    C.broadcast' (fun () ->
        let uss = Arr.split (Array.init n_trials ~f:(fun _ -> 1)) train_u in
        Array.map ~f:(fun x -> Arr.reshape x [| -1; m |]) uss |> Mat.concatenate ~axis:0)
  in
  let train_zs =
    C.broadcast' (fun () ->
        let zss = Arr.split (Array.init n_trials ~f:(fun _ -> 1)) train_z in
        Array.map ~f:(fun x -> Arr.reshape x [| -1; n |]) zss |> Mat.concatenate ~axis:0)
  in
  Array.foldi data ~init:[] ~f:(fun i accu dat ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        let mu = Model.posterior_mean ?u_init:None ~prms dat in
        let _, zs, _ = Model.predictions ~n_samples:100 ~prms mu in
        let z = AD.unpack_arr zs in
        let u = Arr.get_slice [ [ n_beg - 1; -1 ] ] (AD.unpack_arr mu) in
        let z_mean =
          Owl.Arr.(mean ~axis:2 z) |> fun z -> Owl.Arr.reshape z [| n_steps; -1 |]
        in
        (* keep the trial index so the root can restore trial order *)
        (i, z_mean, u) :: accu)
      else accu)
  |> C.gather
  |> fun v ->
  C.root_perform (fun () ->
      (* v is an array of lists, one per node. Each list is in reverse trial
         order and trials are interleaved across nodes, so sort by trial index
         before concatenating; otherwise the rows would not line up with
         [train_us] / [train_zs]. *)
      let by_trial =
        v
        |> Array.to_list
        |> List.concat
        |> List.sort ~compare:(fun (i, _, _) (j, _, _) -> Int.compare i j)
        |> Array.of_list
      in
      let posterior_z =
        Array.map by_trial ~f:(fun (_, z, _) -> z) |> Mat.concatenate ~axis:0
      in
      let posterior_us =
        Array.map by_trial ~f:(fun (_, _, u) -> u) |> Mat.concatenate ~axis:0
      in
      let r2_u = compute_R2 posterior_us train_us in
      let r2_z = compute_R2 posterior_z train_zs in
      Mat.save_txt
        ~append:true
        ~out:(in_dir "train_r2")
        (Mat.of_array [| r2_z; r2_u |] 1 2))


(* [save_test ?u_init prefix prms data] writes the same diagnostics as the
   training-set routine, for the test set:
   - on the root node, the parameters (via [Model.P.save_to_files]) and the
     eigenvalues of the dynamics matrix, in [prefix.a_eig];
   - for every trial handled by this node (trial index modulo [C.n_nodes]
     equal to [C.rank]), the posterior mean input and the mean and variance
     of 100 sampled predictions;
   - on the root node, one appended row [r2_z; r2_u] in [test_r2] under the
     output directory, comparing the posterior latents and inputs with the
     ground truth in [test_z] and [test_u].

   [u_init] is accepted for interface compatibility but is not used. *)
let save_test ?u_init prefix prms data =
  let prms = C.broadcast prms in
  let file s = prefix ^ "." ^ s in
  C.root_perform (fun () ->
      Model.P.save_to_files ~prefix ~prms;
      let a =
        D.unpack_a ~prms:prms.VAE_P.generative.Generative_P.dynamics |> AD.unpack_arr
      in
      let e = Linalg.D.eigvals a in
      let er, ei = Dense.Matrix.Z.(re e, im e) in
      Mat.(save_txt ~out:(file "a_eig") (transpose (er @= ei))));
  Array.iteri data ~f:(fun i dat_trial ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        let mu = Model.posterior_mean ?u_init:None ~prms dat_trial in
        Owl.Mat.save_txt
          ~out:(file (Printf.sprintf "posterior_u_%i" i))
          (AD.unpack_arr mu);
        let us, zs, os = Model.predictions ~n_samples:100 ~prms mu in
        (* mean and variance over axis 2, side by side, one row per step *)
        let process label a =
          let a = AD.unpack_arr a in
          Owl.Arr.(mean ~axis:2 a @|| var ~axis:2 a)
          |> (fun z -> Owl.Arr.reshape z [| n_steps; -1 |])
          |> Mat.save_txt ~out:(file (Printf.sprintf "predicted_%s_%i" label i))
        in
        process "u" us;
        process "z" zs;
        Array.iter ~f:(fun (label, x) -> process label x) os));
  (* ground truth flattened to one row per (trial, step), in trial order *)
  let test_us =
    let uss = Arr.split (Array.init n_trials ~f:(fun _ -> 1)) test_u in
    Array.map ~f:(fun x -> Arr.reshape x [| -1; m |]) uss |> Mat.concatenate ~axis:0
  in
  let test_zs =
    let zss = Arr.split (Array.init n_trials ~f:(fun _ -> 1)) test_z in
    Array.map ~f:(fun x -> Arr.reshape x [| -1; n |]) zss |> Mat.concatenate ~axis:0
  in
  Array.foldi data ~init:[] ~f:(fun i accu dat ->
      if Int.(i % C.n_nodes = C.rank)
      then (
        let mu = Model.posterior_mean ?u_init:None ~prms dat in
        let _, zs, _ = Model.predictions ~n_samples:100 ~prms mu in
        let z = AD.unpack_arr zs in
        let u = Arr.get_slice [ [ n_beg - 1; -1 ] ] (AD.unpack_arr mu) in
        let z_mean =
          Owl.Arr.(mean ~axis:2 z) |> fun z -> Owl.Arr.reshape z [| n_steps; -1 |]
        in
        (* keep the trial index so the root can restore trial order *)
        (i, z_mean, u) :: accu)
      else accu)
  |> C.gather
  |> fun v ->
  C.root_perform (fun () ->
      (* v is an array of lists, one per node. Each list is in reverse trial
         order and trials are interleaved across nodes, so sort by trial index
         before concatenating; otherwise the rows would not line up with
         [test_us] / [test_zs]. *)
      let by_trial =
        v
        |> Array.to_list
        |> List.concat
        |> List.sort ~compare:(fun (i, _, _) (j, _, _) -> Int.compare i j)
        |> Array.of_list
      in
      let posterior_z =
        Array.map by_trial ~f:(fun (_, z, _) -> z) |> Mat.concatenate ~axis:0
      in
      let posterior_us =
        Array.map by_trial ~f:(fun (_, _, u) -> u) |> Mat.concatenate ~axis:0
      in
      let r2_u = compute_R2 posterior_us test_us in
      let r2_z = compute_R2 posterior_z test_zs in
      Mat.save_txt
        ~append:true
        ~out:(in_dir "test_r2")
        (Mat.of_array [| r2_z; r2_u |] 1 2))


let _ = C.print "saving init"

(* let _ = save_results (in_dir "init") init_prms data *)
(* [elbobo_sample] is the identity on model parameters; it is the value
   passed to [Model.train] as [~elbobo_sample]. *)
let elbobo_sample (p : Model.P.p) = p
(* [elbobo_exploit] is likewise the identity; it is not referenced anywhere
   in the visible part of this file. *)
let elbobo_exploit (p : Model.P.p) : Model.P.p = p

(* let _ = Model.check_grad ~prms:init_prms data `all (in_dir "check_grad") *)

(* Trains the model from [init_prms] on [data]. During training, diagnostics
   for the training and test sets are written every 50 iterations and the
   current parameters are saved every 10 iterations (root node only). Once
   training returns, the training-set diagnostics are written one last time
   for the final parameters. *)
let final_prms =
  let checkpoint ~u_init ~prms k =
    let every period = Int.(k % period = 0) in
    if every 50
    then (
      save_results ~u_init (in_dir "final") prms data;
      save_test ~u_init (in_dir "test") prms test_data);
    if every 10
    then C.root_perform (fun () -> Misc.save_bin ~out:(in_dir "final") prms)
  in
  (* learning rate decays as 0.02 / (1 + sqrt (k / 50)) *)
  let eta = `of_iter (fun k -> Float.(0.02 / (1. + sqrt (of_int k / 50.)))) in
  Model.train
    ~n_samples:(fun _ -> 100)
    ~max_iter:Cmdargs.(get_int "-max_iter" |> default 2000)
    ~save_progress_to:(1, 100, in_dir "progress")
    ~conv_threshold:(fun _ -> 1E-6)
    ~in_each_iteration:checkpoint
    ~eta
    ~init_prms
    ~elbobo_sample
    data


let _ = save_results (in_dir "final") final_prms data
