open Base
include Elbobo_typ
open Owl_parameters

module Matrix_column_prior = struct
  module P_prior = Owl_parameters.Make (Matrix_P.Prior)
  module P_posterior = Owl_parameters.Make (Matrix_P.Posterior)
  open Matrix_P.Prior
  open Matrix_P.Posterior

  type sample = AD.t

  (* Closed-form Gaussian KL( q || p ) where
       q = N(posterior_mean, diag posterior_vars)   (posterior_vars : N_out x N)
       p = N(0, prior_vars broadcast over rows)      (prior_vars     : 1 x N)
     i.e. each column has its own prior variance, shared by every row. *)
  let kl prior posterior =
    let sigma2_p = extract prior.prior_vars in
    let sigma2_q = extract posterior.posterior_vars in
    let m = extract posterior.posterior_mean in
    let rows = AD.Mat.row_num m
    and cols = AD.Mat.col_num m in
    let dim = AD.F Float.(of_int Int.(rows * cols)) in
    let open AD.Maths in
    (* the 1 x N prior log-variances are counted once per row *)
    let logdet_term =
      (F Float.(of_int rows) * sum' (log sigma2_p)) - sum' (log sigma2_q)
    in
    let trace_term = sum' (sigma2_q / sigma2_p) in
    let mahalanobis = l2norm_sqr' (m / sqrt sigma2_p) in
    F 0.5 * (logdet_term - dim + trace_term + mahalanobis)


  let sample _ = assert false
end

module Matrix_scalar_prior = struct
  module P_prior = Owl_parameters.Make (Matrix_P.Prior)
  module P_posterior = Owl_parameters.Make (Matrix_P.Posterior)
  open Matrix_P.Prior
  open Matrix_P.Posterior

  type sample = AD.t

  (* Closed-form Gaussian KL( q || p ) where
       q = N(posterior_mean, diag posterior_vars)   (posterior_vars : N_out x N)
       p = N(0, prior_vars * I)                      (prior_vars     : 1 x 1)
     i.e. a single prior variance shared by every matrix entry. *)
  let kl prior posterior =
    let sigma2_p = extract prior.prior_vars in
    let sigma2_q = extract posterior.posterior_vars in
    let m = extract posterior.posterior_mean in
    let numel = Int.(AD.Mat.row_num m * AD.Mat.col_num m) in
    let dim = AD.F (Float.of_int numel) in
    let open AD.Maths in
    (* the scalar prior log-variance is counted once per matrix entry *)
    let logdet_term = (dim * sum' (log sigma2_p)) - sum' (log sigma2_q) in
    let trace_term = sum' (sigma2_q / sigma2_p) in
    let mahalanobis = l2norm_sqr' (m / sqrt sigma2_p) in
    F 0.5 * (logdet_term - dim + trace_term + mahalanobis)


  let sample _ = assert false
end

module Pair (E1 : sig
  include ELBOBO_T

  val label : string
end) (E2 : sig
  include ELBOBO_T

  val label : string
end) =
struct
  (* The pair's prior parameters: E1's and E2's prior parameters, each tagged
     with its component label. *)
  module P_prior =
    Owl_parameters.Make
      (Pair_P.Make
         (struct include E1.P_prior let label = E1.label end)
         (struct include E2.P_prior let label = E2.label end))

  (* Same construction for the posterior parameters. *)
  module P_posterior =
    Owl_parameters.Make
      (Pair_P.Make
         (struct include E1.P_posterior let label = E1.label end)
         (struct include E2.P_posterior let label = E2.label end))

  open Pair_P

  type sample = (AD.t, AD.t) Pair_P.prm_

  (* The pair's KL is the sum of the two component KLs (the components are
     treated as independent). *)
  let kl prior posterior =
    let kl_fst = E1.kl prior.fst posterior.fst
    and kl_snd = E2.kl prior.snd posterior.snd in
    AD.Maths.(kl_fst + kl_snd)


  let sample _ = assert false
end

module ARD = struct
  module P_prior = Owl_parameters.Make (ARD_P.Prior)
  module P_posterior = Owl_parameters.Make (ARD_P.Posterior)
  open ARD_P.Prior
  open ARD_P.Posterior

  type sample = AD.t * AD.t Option.t * AD.t Array.t

  (* Gaussian KL( N(post_mean, diag post_vars) || N(0, prior) ) for a matrix
     whose prior variance depends on the column only.
     prior_vars : 1 x N  (broadcast across the Z rows)
     post_vars  : Z x N
     post_mean  : Z x N *)
  let kl_column prior_vars post_mean post_vars =
    let n_out = AD.Mat.row_num post_mean in
    let n = AD.Mat.col_num post_mean in
    let d = AD.F Float.(of_int Int.(n * n_out)) in
    (* each of the N prior log-variances appears once per row *)
    let logdet_term =
      AD.Maths.((F Float.(of_int n_out) * sum' (log prior_vars)) - sum' (log post_vars))
    in
    let trace_term = AD.Maths.(sum' (post_vars / prior_vars)) in
    let quadratic_term = AD.Maths.(l2norm_sqr' (post_mean / sqrt prior_vars)) in
    AD.Maths.(F 0.5 * (logdet_term - d + trace_term + quadratic_term))


  (* Gaussian KL( N(post_mean, diag post_vars) || N(0, prior) ) for a square
     matrix whose prior variance for entry (i, j) is s_i * s_j, where
     s = prior_stds.
     prior_stds : 1 x N
     post_vars  : N x N
     post_mean  : N x N *)
  let kl_square prior_stds post_mean post_vars =
    let n = AD.Mat.row_num post_mean in
    let d = AD.F Float.(of_int Int.(n * n)) in
    (* sum_ij log (s_i s_j) = 2 N sum_i log s_i *)
    let logdet_term =
      AD.Maths.((F Float.(2. * of_int n) * sum' (log prior_stds)) - sum' (log post_vars))
    in
    (* divide entry (i, j) by s_i (via the N x 1 transpose) and by s_j *)
    let trace_term =
      AD.Maths.(sum' (F 1. / transpose prior_stds * post_vars / prior_stds))
    in
    (* mu_ij^2 / (s_i s_j): square the mean BEFORE scaling, so that it is
       divided by the same prior variance as the trace term (scaling first and
       then squaring would give mu_ij^2 / (s_i s_j)^2) *)
    let quadratic_term =
      AD.Maths.(sum' (F 1. / transpose prior_stds * sqr post_mean / prior_stds))
    in
    AD.Maths.(F 0.5 * (logdet_term - d + trace_term + quadratic_term))


  (* Total KL: square matrix A (prior std sqrt(a) * ard_scales on both axes),
     plus the optional B and every C (prior variances scaled by ard_scales^2).
     Fails if B's prior/posterior options disagree. *)
  let kl prior posterior =
    let scales = extract prior.ard_scales in
    let kl_a =
      let prior_stds = AD.Maths.(sqrt (extract prior.a) * scales) in
      let post_mean = extract posterior.a_mean in
      let post_vars = extract posterior.a_var in
      kl_square prior_stds post_mean post_vars
    in
    let kl_b =
      match prior.b, posterior.b_mean, posterior.b_var with
      | Some p, Some pm, Some pv ->
        let prior_vars = AD.Maths.(extract p * sqr scales) in
        let post_mean = extract pm in
        let post_vars = extract pv in
        kl_column prior_vars post_mean post_vars
      | None, None, None -> AD.F 0.
      | _ -> failwith "inconsistent options for ELBOBO B"
    in
    (* c_mean, c_var and prior.c are indexed in lockstep *)
    let kl_c =
      Array.foldi posterior.c_mean ~init:(AD.F 0.) ~f:(fun i accu (post_mean, _) ->
          let post_vars = posterior.c_var.(i) in
          let prior_vars =
            let z, _ = prior.c.(i) in
            AD.Maths.(extract z * sqr scales)
          in
          AD.Maths.(accu + kl_column prior_vars (extract post_mean) (extract post_vars)))
    in
    AD.Maths.(kl_a + kl_b + kl_c)


  (* Draw (A, B, Cs) from the posterior as mu + eps * sqrt sigma2, with
     standard-normal noise eps from AD.Arr.gaussian. *)
  let sample prms =
    let sample mu sigma2 = AD.Maths.(mu + AD.Arr.(gaussian (shape mu) * sqrt sigma2)) in
    let a = sample (extract prms.a_mean) (extract prms.a_var) in
    let b =
      match prms.b_mean, prms.b_var with
      | Some mu, Some sigma2 -> Some (sample (extract mu) (extract sigma2))
      | None, None -> None
      | _ -> failwith "inconsistent option type for B"
    in
    let c =
      Array.mapi prms.c_mean ~f:(fun i (mu, _) ->
          sample (extract mu) (extract prms.c_var.(i)))
    in
    a, b, c
end

module NONE = struct
  (* No parameters at all: prior and posterior are both empty. *)
  module P_prior = Owl_parameters.Empty
  module P_posterior = Owl_parameters.Empty

  type sample = unit

  (* Nothing is inferred, so the divergence is identically zero. *)
  let kl _prior _posterior = AD.F 0.

  let sample _prms = ()
end

module ARD_GRU = struct
  module P_prior = Owl_parameters.Make (ARD_GRU_P.Prior)
  module P_posterior = Owl_parameters.Make (ARD_GRU_P.Posterior)
  open ARD_GRU_P.Prior
  open ARD_GRU_P.Posterior

  type sample = AD.t * AD.t * AD.t Option.t * AD.t Array.t

  (* Gaussian KL( N(post_mean, diag post_vars) || N(0, prior) ) for a matrix
     whose prior variance depends on the column only.
     prior_vars : 1 x N  (broadcast across the Z rows)
     post_vars  : Z x N
     post_mean  : Z x N *)
  let kl_column prior_vars post_mean post_vars =
    let n_out = AD.Mat.row_num post_mean in
    let n = AD.Mat.col_num post_mean in
    let d = AD.F Float.(of_int Int.(n * n_out)) in
    (* each of the N prior log-variances appears once per row *)
    let logdet_term =
      AD.Maths.((F Float.(of_int n_out) * sum' (log prior_vars)) - sum' (log post_vars))
    in
    let trace_term = AD.Maths.(sum' (post_vars / prior_vars)) in
    let quadratic_term = AD.Maths.(l2norm_sqr' (post_mean / sqrt prior_vars)) in
    AD.Maths.(F 0.5 * (logdet_term - d + trace_term + quadratic_term))


  (* Gaussian KL( N(post_mean, diag post_vars) || N(0, prior) ) for a square
     matrix whose prior variance for entry (i, j) is s_i * s_j, where
     s = prior_stds.
     prior_stds : 1 x N
     post_vars  : N x N
     post_mean  : N x N *)
  let kl_square prior_stds post_mean post_vars =
    let n = AD.Mat.row_num post_mean in
    let d = AD.F Float.(of_int Int.(n * n)) in
    (* sum_ij log (s_i s_j) = 2 N sum_i log s_i *)
    let logdet_term =
      AD.Maths.((F Float.(2. * of_int n) * sum' (log prior_stds)) - sum' (log post_vars))
    in
    (* divide entry (i, j) by s_i (via the N x 1 transpose) and by s_j *)
    let trace_term =
      AD.Maths.(sum' (F 1. / transpose prior_stds * post_vars / prior_stds))
    in
    (* mu_ij^2 / (s_i s_j): square the mean BEFORE scaling, so that it is
       divided by the same prior variance as the trace term (scaling first and
       then squaring would give mu_ij^2 / (s_i s_j)^2) *)
    let quadratic_term =
      AD.Maths.(sum' (F 1. / transpose prior_stds * sqr post_mean / prior_stds))
    in
    AD.Maths.(F 0.5 * (logdet_term - d + trace_term + quadratic_term))


  (* Total KL: the two square matrices uf and uh (prior std sqrt(u) * ard_scales
     on both axes), plus the optional b and every c (prior variances scaled by
     ard_scales^2). Fails if b's prior/posterior options disagree. *)
  let kl prior posterior =
    let scales = extract prior.ard_scales in
    let kl_uf =
      let prior_stds = AD.Maths.(sqrt (extract prior.uf) * scales) in
      let post_mean = extract posterior.uf_mean in
      let post_vars = extract posterior.uf_var in
      kl_square prior_stds post_mean post_vars
    in
    let kl_uh =
      let prior_stds = AD.Maths.(sqrt (extract prior.uh) * scales) in
      let post_mean = extract posterior.uh_mean in
      let post_vars = extract posterior.uh_var in
      kl_square prior_stds post_mean post_vars
    in
    let kl_b =
      match prior.b, posterior.b_mean, posterior.b_var with
      | Some p, Some pm, Some pv ->
        let prior_vars = AD.Maths.(extract p * sqr scales) in
        let post_mean = extract pm in
        let post_vars = extract pv in
        kl_column prior_vars post_mean post_vars
      | None, None, None -> AD.F 0.
      | _ -> failwith "inconsistent options for ELBOBO B"
    in
    (* c_mean, c_var and prior.c are indexed in lockstep *)
    let kl_c =
      Array.foldi posterior.c_mean ~init:(AD.F 0.) ~f:(fun i accu (post_mean, _) ->
          let post_vars = posterior.c_var.(i) in
          let prior_vars =
            let z, _ = prior.c.(i) in
            AD.Maths.(extract z * sqr scales)
          in
          AD.Maths.(accu + kl_column prior_vars (extract post_mean) (extract post_vars)))
    in
    AD.Maths.(kl_uf + kl_uh + kl_b + kl_c)


  (* Draw (uf, uh, b, cs) from the posterior as mu + eps * sqrt sigma2, with
     standard-normal noise eps from AD.Arr.gaussian. *)
  let sample prms =
    let sample mu sigma2 = AD.Maths.(mu + AD.Arr.(gaussian (shape mu) * sqrt sigma2)) in
    let uf = sample (extract prms.uf_mean) (extract prms.uf_var) in
    let uh = sample (extract prms.uh_mean) (extract prms.uh_var) in
    let b =
      match prms.b_mean, prms.b_var with
      | Some mu, Some sigma2 -> Some (sample (extract mu) (extract sigma2))
      | None, None -> None
      | _ -> failwith "inconsistent option type for B"
    in
    let c =
      Array.mapi prms.c_mean ~f:(fun i (mu, _) ->
          sample (extract mu) (extract prms.c_var.(i)))
    in
    uf, uh, b, c
end
