open Base
open Owl_parameters

(** Interface of a likelihood (observation) model over latent variables.
    Parameters have type [p] from [P] (brought into scope by [open P]);
    latents and results are [AD.t] values.
    NOTE(review): per-value semantics below are inferred from names and types
    only — confirm against the implementations. *)
module type Likelihood_T = sig
  module P : Owl_parameters.T
  open P

  (** What [data_slice ~k] extracts from a [data] value. *)
  type datum

  (** A complete observation set. *)
  type data

  (** Flag presumably telling the optimiser whether a line search is needed. *)
  val requires_linesearch : bool

  (** Identifying name of this likelihood (presumably used in logs/file names). *)
  val label : string

  (** Side-effecting save of [data] (returns [unit]); [?prefix] presumably
      prefixes the output file names — confirm. *)
  val save_data : ?prefix:string -> data -> unit

  (** [data_slice ~k d] returns the [datum] at index [k] of [d]. *)
  val data_slice : k:int -> data -> datum

  (** Named [AD.t] matrices extracted from [data]. *)
  val to_mat_list : data -> (string * AD.t) list

  (** An integer size determined by [prms] (presumably the observation
      dimension — confirm). *)
  val size : prms:p -> int

  (** Builds [data] from latent [z].
      NOTE(review): how this differs from [sample] (e.g. noise-free vs.
      noisy) is not visible here — confirm. *)
  val pre_sample : prms:p -> z:AD.t -> data

  (** Produces [data] from latent [z] (presumably by sampling the model). *)
  val sample : prms:p -> z:AD.t -> data

  (** Negative log-likelihood term for index [k]: observation [data_t] given
      latent [z_t] (per the name — confirm). *)
  val neg_logp_t : prms:p -> data_t:datum -> k:int -> z_t:AD.t -> AD.t

  (** Optional hand-written first derivative (Jacobian, per the name) of
      [neg_logp_t]; [None] when an implementation does not supply one. *)
  val neg_jac_t : (prms:p -> data_t:datum -> k:int -> z_t:AD.t -> AD.t) option

  (** Optional hand-written second derivative (Hessian, per the name) of
      [neg_logp_t]; [None] when an implementation does not supply one. *)
  val neg_hess_t : (prms:p -> data_t:datum -> k:int -> z_t:AD.t -> AD.t) option

  (** Log-likelihood of all of [data] given latent [z] (per the name). *)
  val logp : prms:p -> data:data -> z:AD.t -> AD.t
end

(** Variant of [Likelihood_T] in which most operations take an extra integer
    index ([i:int], or the [int] in [int * datum] / [int * data]).
    NOTE(review): the index presumably selects one of several components that
    share parameters [p] — confirm against the implementations. *)
module type Likelihood_Shared_T = sig
  module P : Owl_parameters.T
  open P

  (** What [data_slice ~k] extracts from a [data] value. *)
  type datum

  (** A complete observation set for one index. *)
  type data

  (** Flag presumably telling the optimiser whether a line search is needed. *)
  val requires_linesearch : bool

  (** Identifying name of this likelihood (presumably used in logs/file names). *)
  val label : string

  (** Side-effecting save of [data] (returns [unit]); [?prefix] presumably
      prefixes the output file names — confirm. *)
  val save_data : ?prefix:string -> data -> unit

  (** [data_slice ~k (tag, d)] slices [d] at index [k]; by its polymorphic
      type, the first component [tag] is carried through untouched. *)
  val data_slice : k:int -> 'a * data -> 'a * datum

  (** Named [AD.t] matrices extracted from [data]. *)
  val to_mat_list : data -> (string * AD.t) list

  (** An integer size for index [i] determined by [prms]. *)
  val size : i:int -> prms:p -> int

  (** Builds [data] for index [i] from latent [z].
      NOTE(review): [i:] comes first here but last in [sample]; the order is
      part of the type, so it cannot be unified without breaking
      implementations. *)
  val pre_sample : i:int -> prms:p -> z:AD.t -> data

  (** Produces [data] for index [i] from latent [z] (presumably by sampling). *)
  val sample : prms:p -> z:AD.t -> i:int -> data

  (** Negative log-likelihood term for time index [k] of the indexed datum
      [data_t] given latent [z_t] (per the name — confirm). *)
  val neg_logp_t : prms:p -> data_t:int * datum -> k:int -> z_t:AD.t -> AD.t

  (** Optional hand-written first derivative of [neg_logp_t]; [None] when an
      implementation does not supply one. *)
  val neg_jac_t : (prms:p -> data_t:int * datum -> k:int -> z_t:AD.t -> AD.t) option

  (** Optional hand-written second derivative of [neg_logp_t]; [None] when an
      implementation does not supply one. *)
  val neg_hess_t : (prms:p -> data_t:int * datum -> k:int -> z_t:AD.t -> AD.t) option

  (** Log-likelihood of the indexed [data] given latent [z] (per the name). *)
  val logp : prms:p -> data:int * data -> z:AD.t -> AD.t
end

(** Parameters of a Gaussian likelihood, polymorphic in the leaf type ['a].
    [c_mask] is not a parameter leaf: [map] copies it unchanged and [fold]
    does not visit it. *)
module Gaussian_P = struct
  type 'a prm =
    { c : 'a
    ; c_mask : AD.t option
    ; bias : 'a
    ; variances : 'a (* 1 x space *)
    }
  [@@deriving accessors ~submodule:A]

  (** [map ~f x] applies [f] to each leaf ([c], [bias], [variances]). *)
  let map ~f x =
    let { c; c_mask; bias; variances } = x in
    { c = f c; c_mask; bias = f bias; variances = f variances }

  (** [fold ?prefix ~init ~f x] threads an accumulator through the leaves in
      the order [c], [bias], [variances]; each leaf is handed to [f] paired
      with its name as transformed by [with_prefix ?prefix]. *)
  let fold ?prefix ~init ~f x =
    List.fold
      [ x.c, "c"; x.bias, "bias"; x.variances, "variances" ]
      ~init
      ~f:(fun acc (leaf, name) -> f acc (leaf, with_prefix ?prefix name))
end

(** Parameters of the nonlinear Gaussian likelihood: the [Gaussian_P] leaves
    plus an extra leaf [alpha] (its role is not visible here — presumably a
    nonlinearity parameter; confirm). [c_mask] is not a parameter leaf: [map]
    copies it unchanged and [fold] does not visit it. *)
module Gaussian_nonlinear_P = struct
  type 'a prm =
    { alpha : 'a
    ; c : 'a
    ; c_mask : AD.t option
    ; bias : 'a
    ; variances : 'a (* 1 x space *)
    }
  [@@deriving accessors ~submodule:A]

  (** [map ~f x] applies [f] to each leaf ([alpha], [c], [bias], [variances]). *)
  let map ~f x =
    let { alpha; c; c_mask; bias; variances } = x in
    { alpha = f alpha; c = f c; c_mask; bias = f bias; variances = f variances }

  (** [fold ?prefix ~init ~f x] threads an accumulator through the leaves in
      the order [alpha], [c], [bias], [variances]; each leaf is handed to [f]
      paired with its name as transformed by [with_prefix ?prefix]. *)
  let fold ?prefix ~init ~f x =
    List.fold
      [ x.alpha, "alpha"; x.c, "c"; x.bias, "bias"; x.variances, "variances" ]
      ~init
      ~f:(fun acc (leaf, name) -> f acc (leaf, with_prefix ?prefix name))
end

(** Parameters of the target likelihood: two leaves, [c] and [variances]. *)
module Target_P = struct
  type 'a prm =
    { c : 'a
    ; variances : 'a (* 1 x space *)
    }
  [@@deriving accessors ~submodule:A]

  (** [map ~f x] applies [f] to both leaves. *)
  let map ~f x =
    let { c; variances } = x in
    { c = f c; variances = f variances }

  (** [fold ?prefix ~init ~f x] threads an accumulator through [c] then
      [variances]; each leaf is handed to [f] paired with its name as
      transformed by [with_prefix ?prefix]. *)
  let fold ?prefix ~init ~f x =
    List.fold
      [ x.c, "c"; x.variances, "variances" ]
      ~init
      ~f:(fun acc (leaf, name) -> f acc (leaf, with_prefix ?prefix name))
end

(** Parameters of a Poisson likelihood, polymorphic in the leaf type ['a].
    [c_mask] is not a parameter leaf: [map] copies it unchanged and [fold]
    does not visit it. *)
module Poisson_P = struct
  type 'a prm =
    { c : 'a
    ; c_mask : AD.t option
    ; bias : 'a
    ; gain : 'a
    }
  [@@deriving accessors ~submodule:A]

  (** [map ~f x] applies [f] to each leaf ([c], [bias], [gain]). *)
  let map ~f x =
    let { c; c_mask; bias; gain } = x in
    { c = f c; c_mask; bias = f bias; gain = f gain }

  (** [fold ?prefix ~init ~f x] threads an accumulator through the leaves in
      the order [c], [bias], [gain]; each leaf is handed to [f] paired with
      its name as transformed by [with_prefix ?prefix]. *)
  let fold ?prefix ~init ~f x =
    List.fold
      [ x.c, "c"; x.bias, "bias"; x.gain, "gain" ]
      ~init
      ~f:(fun acc (leaf, name) -> f acc (leaf, with_prefix ?prefix name))
end

(* Re-export every item of [Pair_typ] as part of this module (and bring those
   items into scope for the rest of the file). *)
include Pair_typ

(** Gaussian likelihood parameters stored as arrays: every field holds one
    entry per index. [c_mask] is not a parameter leaf: [map] copies it
    unchanged and [fold] does not visit it. *)
module Gaussian_Shared_P = struct
  type 'a prm =
    { c : 'a array
    ; c_mask : AD.t option array
    ; bias : 'a array
    ; variances : 'a array (* 1 x space *)
    }
  [@@deriving accessors ~submodule:A]

  (** [map ~f x] applies [f] to every element of the [c], [bias] and
      [variances] arrays, producing fresh arrays. *)
  let map ~f x =
    let { c; c_mask; bias; variances } = x in
    { c = Array.map c ~f
    ; c_mask
    ; bias = Array.map bias ~f
    ; variances = Array.map variances ~f
    }

  (** [fold ?prefix ~init ~f x] visits every [c] element, then every [bias]
      element, then every [variances] element, each array left to right.
      Element [i] of field [name] is labelled ["name.i"], passed through
      [with_prefix ?prefix], and handed to [f] with the element. *)
  let fold ?prefix ~init ~f x =
    (* Fold one array, labelling element [i] with [label i]. *)
    let visit label leaves acc =
      Array.foldi leaves ~init:acc ~f:(fun i acc leaf ->
          f acc (leaf, with_prefix ?prefix (label i)))
    in
    init
    |> visit (Printf.sprintf "c.%i") x.c
    |> visit (Printf.sprintf "bias.%i") x.bias
    |> visit (Printf.sprintf "variances.%i") x.variances
end
