open Base
open Owl_parameters

(** Signature bundling a prior parameter module and a posterior parameter
    module with a [kl] term and a [sample] function over them. *)
module type ELBOBO_T = sig
  (* Abstract type of the values returned by [sample]. *)
  type sample

  (* Parameter modules for the prior and for the posterior; both are
     constrained to the [Owl_parameters.T] signature. *)
  module P_prior : Owl_parameters.T
  module P_posterior : Owl_parameters.T

  (* [kl prior posterior] combines both parameter sets into a single [AD.t].
     NOTE(review): presumably a KL divergence between posterior and prior;
     the direction and sign convention are not visible in this signature --
     confirm against the implementations. *)
  val kl : P_prior.p -> P_posterior.p -> AD.t

  (* [sample posterior] builds a [sample] from posterior parameters.
     NOTE(review): whether this draws randomly is not visible here -- confirm
     against the implementations. *)
  val sample : P_posterior.p -> sample
end

(** Parameter records grouped under [ARD_P]: a [Prior] record and a
    [Posterior] record, each equipped with a [map] over its payloads and a
    [fold] that visits every payload together with a prefixed name. *)
module ARD_P = struct
  module Prior = struct
    (** Prior parameters, polymorphic in the payload type ['a]. The field [b]
        is optional and each entry of [c] is paired with a string label. *)
    type 'a prm =
      { ard_scales : 'a
      ; a : 'a
      ; b : 'a Option.t
      ; c : ('a * String.t) Array.t
      }
    [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to every payload of [p]. The labels attached to
        the entries of [c] are carried over unchanged. *)
    let map ~f { ard_scales; a; b; c } =
      { ard_scales = f ard_scales
      ; a = f a
      ; b = Option.map b ~f
      ; c = Array.map c ~f:(fun (payload, label) -> f payload, label)
      }

    (** [fold ?prefix ~init ~f p] threads an accumulator through the payloads
        of [p] in the order [ard_scales], [a], [b] (only when present), then
        the entries of [c]. [f] receives the accumulator and a
        [(payload, name)] pair, where the name is built with
        [with_prefix ?prefix]; entries of [c] are named ["c.<label>"]. *)
    let fold ?prefix ~init ~f x =
      let w = with_prefix ?prefix in
      (* One step of the fold: hand [f] the payload with its prefixed name. *)
      let visit name acc payload = f acc (payload, w name) in
      let acc = visit "ard_scales" init x.ard_scales in
      let acc = visit "a" acc x.a in
      let acc = Option.fold x.b ~init:acc ~f:(visit "b") in
      Array.fold x.c ~init:acc ~f:(fun acc (payload, label) ->
          visit (Printf.sprintf "c.%s" label) acc payload)
  end

  module Posterior = struct
    (** Posterior parameters: for each of [a], [b] and [c] there is a [_mean]
        and a [_var] field. Only [c_mean] stores labels; [c_var] is a plain
        array that [fold] reads position by position against [c_mean]. *)
    type 'a prm =
      { a_mean : 'a
      ; b_mean : 'a Option.t
      ; c_mean : ('a * String.t) Array.t
      ; a_var : 'a
      ; b_var : 'a Option.t
      ; c_var : 'a Array.t
      }
    [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to every payload of [p], leaving the labels
        stored in [c_mean] unchanged. *)
    let map ~f { a_mean; b_mean; c_mean; a_var; b_var; c_var } =
      { a_mean = f a_mean
      ; b_mean = Option.map b_mean ~f
      ; c_mean = Array.map c_mean ~f:(fun (payload, label) -> f payload, label)
      ; a_var = f a_var
      ; b_var = Option.map b_var ~f
      ; c_var = Array.map c_var ~f
      }

    (** [fold ?prefix ~init ~f p] visits, in order: [a_mean], [b_mean] (when
        present), the entries of [c_mean], [a_var], [b_var] (when present),
        and the entries of [c_var]. Names are built with
        [with_prefix ?prefix]; array entries are named ["c_mean.<label>"] and
        ["c_var.<label>"].

        The label for [c_var.(i)] is read from [c_mean.(i)], so [c_var] must
        not be longer than [c_mean]; otherwise the array access raises. *)
    let fold ?prefix ~init ~f x =
      let w = with_prefix ?prefix in
      (* One step of the fold: hand [f] the payload with its prefixed name. *)
      let visit name acc payload = f acc (payload, w name) in
      let acc = visit "a_mean" init x.a_mean in
      let acc = Option.fold x.b_mean ~init:acc ~f:(visit "b_mean") in
      let acc =
        Array.fold x.c_mean ~init:acc ~f:(fun acc (payload, label) ->
            visit (Printf.sprintf "c_mean.%s" label) acc payload)
      in
      let acc = visit "a_var" acc x.a_var in
      let acc = Option.fold x.b_var ~init:acc ~f:(visit "b_var") in
      Array.foldi x.c_var ~init:acc ~f:(fun idx acc payload ->
          (* [c_var] has no labels of its own: reuse the one at the same
             position in [c_mean]. *)
          let _, label = x.c_mean.(idx) in
          visit (Printf.sprintf "c_var.%s" label) acc payload)
  end
end

(** Parameter records grouped under [Matrix_P]: a single-field [Prior] and a
    two-field [Posterior], each with [map] and a name-carrying [fold]. *)
module Matrix_P = struct
  module Prior = struct
    (** Prior parameters: one payload, [prior_vars]. *)
    type 'a prm = { prior_vars : 'a } [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to the only payload of [p]. *)
    let map ~f { prior_vars } = { prior_vars = f prior_vars }

    (** [fold ?prefix ~init ~f p] calls [f] once, with [init] and the pair made
        of [prior_vars] and the name [with_prefix ?prefix "prior_vars"]. *)
    let fold ?prefix ~init ~f { prior_vars } =
      let name = with_prefix ?prefix "prior_vars" in
      f init (prior_vars, name)
  end

  module Posterior = struct
    (** Posterior parameters: two payloads, [posterior_mean] and
        [posterior_vars]. *)
    type 'a prm =
      { posterior_mean : 'a
      ; posterior_vars : 'a
      }
    [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to both payloads of [p]. *)
    let map ~f { posterior_mean; posterior_vars } =
      { posterior_mean = f posterior_mean; posterior_vars = f posterior_vars }

    (** [fold ?prefix ~init ~f p] visits [posterior_mean] and then
        [posterior_vars], passing each to [f] with its name built by
        [with_prefix ?prefix]. *)
    let fold ?prefix ~init ~f x =
      let w = with_prefix ?prefix in
      let acc = f init (x.posterior_mean, w "posterior_mean") in
      f acc (x.posterior_vars, w "posterior_vars")
  end
end

include Pair_typ

(** Parameter records grouped under [ARD_GRU_P]: a [Prior] record and a
    [Posterior] record, each equipped with a [map] over its payloads and a
    [fold] that visits every payload together with a prefixed name. *)
module ARD_GRU_P = struct
  module Prior = struct
    (** Prior parameters, polymorphic in the payload type ['a]. The field [b]
        is optional and each entry of [c] is paired with a string label. *)
    type 'a prm =
      { ard_scales : 'a
      ; uf : 'a
      ; uh : 'a
      ; b : 'a Option.t
      ; c : ('a * String.t) Array.t
      }
    [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to every payload of [p]. The labels attached to
        the entries of [c] are carried over unchanged. *)
    let map ~f { ard_scales; uf; uh; b; c } =
      { ard_scales = f ard_scales
      ; uf = f uf
      ; uh = f uh
      ; b = Option.map b ~f
      ; c = Array.map c ~f:(fun (payload, label) -> f payload, label)
      }

    (** [fold ?prefix ~init ~f p] threads an accumulator through the payloads
        of [p] in the order [ard_scales], [uf], [uh], [b] (only when present),
        then the entries of [c]. [f] receives the accumulator and a
        [(payload, name)] pair, where the name is built with
        [with_prefix ?prefix]; entries of [c] are named ["c.<label>"]. *)
    let fold ?prefix ~init ~f x =
      let w = with_prefix ?prefix in
      (* One step of the fold: hand [f] the payload with its prefixed name. *)
      let visit name acc payload = f acc (payload, w name) in
      let acc = visit "ard_scales" init x.ard_scales in
      let acc = visit "uf" acc x.uf in
      let acc = visit "uh" acc x.uh in
      let acc = Option.fold x.b ~init:acc ~f:(visit "b") in
      Array.fold x.c ~init:acc ~f:(fun acc (payload, label) ->
          visit (Printf.sprintf "c.%s" label) acc payload)
  end

  module Posterior = struct
    (** Posterior parameters: for each of [uf], [uh], [b] and [c] there is a
        [_mean] and a [_var] field. Only [c_mean] stores labels; [c_var] is a
        plain array that [fold] reads position by position against [c_mean]. *)
    type 'a prm =
      { uf_mean : 'a
      ; uh_mean : 'a
      ; b_mean : 'a Option.t
      ; c_mean : ('a * String.t) Array.t
      ; uf_var : 'a
      ; uh_var : 'a
      ; b_var : 'a Option.t
      ; c_var : 'a Array.t
      }
    [@@deriving accessors ~submodule:A]

    (** [map ~f p] applies [f] to every payload of [p], leaving the labels
        stored in [c_mean] unchanged. *)
    let map ~f { uf_mean; uh_mean; b_mean; c_mean; uf_var; uh_var; b_var; c_var } =
      { uf_mean = f uf_mean
      ; uh_mean = f uh_mean
      ; b_mean = Option.map b_mean ~f
      ; c_mean = Array.map c_mean ~f:(fun (payload, label) -> f payload, label)
      ; uf_var = f uf_var
      ; uh_var = f uh_var
      ; b_var = Option.map b_var ~f
      ; c_var = Array.map c_var ~f
      }

    (** [fold ?prefix ~init ~f p] visits, in order: [uf_mean], [uh_mean],
        [b_mean] (when present), the entries of [c_mean], [uf_var], [uh_var],
        [b_var] (when present), and the entries of [c_var]. Names are built
        with [with_prefix ?prefix]; array entries are named
        ["c_mean.<label>"] and ["c_var.<label>"].

        The label for [c_var.(i)] is read from [c_mean.(i)], so [c_var] must
        not be longer than [c_mean]; otherwise the array access raises. *)
    let fold ?prefix ~init ~f x =
      let w = with_prefix ?prefix in
      (* One step of the fold: hand [f] the payload with its prefixed name. *)
      let visit name acc payload = f acc (payload, w name) in
      let acc = visit "uf_mean" init x.uf_mean in
      let acc = visit "uh_mean" acc x.uh_mean in
      let acc = Option.fold x.b_mean ~init:acc ~f:(visit "b_mean") in
      let acc =
        Array.fold x.c_mean ~init:acc ~f:(fun acc (payload, label) ->
            visit (Printf.sprintf "c_mean.%s" label) acc payload)
      in
      let acc = visit "uf_var" acc x.uf_var in
      let acc = visit "uh_var" acc x.uh_var in
      let acc = Option.fold x.b_var ~init:acc ~f:(visit "b_var") in
      Array.foldi x.c_var ~init:acc ~f:(fun idx acc payload ->
          (* [c_var] has no labels of its own: reuse the one at the same
             position in [c_mean]. *)
          let _, label = x.c_mean.(idx) in
          visit (Printf.sprintf "c_var.%s" label) acc payload)
  end
end
