open Base
open Owl
include Dynamics_typ

module Integrate (D : Dynamics_T) = struct
  (* [integrate ~prms ~n ~u] unrolls [D.dyn ~theta:prms] over the inputs [u],
     starting from an all-zero state with [n] columns.
     [u] must have rank 3 (asserted) and is read as n_samples x n_steps x m.
     Step [k], counted from 0, receives the k-th time slice of [u] reshaped to
     n_samples rows, together with the state returned by the previous step.
     The states returned by the successive steps are stacked and handed back
     as n_samples x n_steps x n. *)
  let integrate ~prms =
    (* apply [~theta] once so the same closure serves every step *)
    let step = D.dyn ~theta:prms in
    fun ~n ~u ->
      assert (Array.length (AD.shape u) = 3);
      (* swap the first two axes so that time comes first: T x K x M *)
      let u_tkm = AD.Maths.transpose ~axis:[| 1; 0; 2 |] u in
      let dims = AD.shape u_tkm in
      let n_steps = dims.(0) in
      let n_samples = dims.(1) in
      (* one input matrix with n_samples rows per time step *)
      let inputs =
        AD.Maths.reshape u_tkm [| n_steps; -1 |]
        |> AD.Maths.split ~axis:0 (Array.create ~len:n_steps 1)
        |> Array.map ~f:(fun row -> AD.Maths.reshape row [| n_samples; -1 |])
      in
      (* accumulator: next step index, current state, visited states (newest first) *)
      let _, _, rev_states =
        Array.fold
          inputs
          ~init:(0, AD.Mat.zeros n_samples n, [])
          ~f:(fun (k, x, acc) u ->
            let x_next = step ~k ~x ~u in
            k + 1, x_next, x_next :: acc)
      in
      List.rev rev_states
      |> Array.of_list
      |> Array.map ~f:(fun x -> AD.Maths.reshape x [| 1; n_samples; n |])
      |> AD.Maths.concatenate ~axis:0 (* T x K x N *)
      |> AD.Maths.transpose ~axis:[| 1; 0; 2 |]

  (* result KxTxN *)
end

module Linear (X : sig
  val n_beg : int Option.t
end) =
struct
  module P = Owl_parameters.Make (Linear_P)
  open Linear_P

  let requires_linesearch = false

  (* Build an initial parameter record for an n-dimensional state and an
     m-dimensional input.
     alpha is the spectral abscissa of the equivalent continuous-time system
     beta is the spectral radius of the random S
     The field [b] is left to [None] when n = m and is otherwise a random
     m x n matrix with standard deviation 1 / sqrt m. *)
  let init ~dt_over_tau ~alpha ~beta (set : Owl_parameters.setter) n m =
    (* exp (dt_over_tau * (W-I))
       where W = alpha*I + S *)
    let d =
      let tmp = Float.(exp (-2. * dt_over_tau * (1.0 - alpha))) in
      Mat.init_2d 1 n (fun _ _ -> Float.(tmp / (1. - tmp)))
    in
    let u = AD.Mat.eye n in
    let q =
      let s = Mat.(Float.(beta * dt_over_tau / sqrt (2. * of_int n)) $* gaussian n n) in
      Linalg.D.expm Mat.(s - transpose s)
    in
    let b =
      if n = m
      then None
      else Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    in
    { d = set ~above:1E-5 (AD.pack_arr d); u = set u; q = set (AD.pack_arr q); b }


  (* Assemble the state-transition matrix from the stored factors.
     [q] and [u] are each replaced by the Q factor of their QR decomposition,
     with every column multiplied by the sign of the matching diagonal entry
     of R.  With d the stored row vector, the result is
     (u * sqrt d) *@ (q * (1 / sqrt (1 + d))) *@ transpose u. *)
  let unpack_a ~prms =
    (* shared by [q] and [u]: QR factor with the sign convention above *)
    let sign_fixed_q p =
      let q, r = AD.Linalg.qr (Owl_parameters.extract p) in
      let r = AD.Maths.diag r in
      AD.Maths.(q * signum r)
    in
    let q = sign_fixed_q prms.q in
    let u = sign_fixed_q prms.u in
    let d = Owl_parameters.extract prms.d in
    let dp1_sqrt_inv = AD.Maths.(F 1. / sqrt (F 1. + d)) in
    let d_sqrt = AD.Maths.(sqrt d) in
    AD.Maths.(u * d_sqrt *@ (q * dp1_sqrt_inv) *@ transpose u)


  (* Input matrices used instead of [b] during the first [n_beg] steps.
     - [X.n_beg = None]: a single 1 x 1 zero placeholder.
     - [X.n_beg = Some nb]: [nb] matrices; with nr = n / nb (asserted equal to
       [m]) the i-th one is the transpose of
       [zeros (i * nr) m; eye nr; zeros (n - (i + 1) * nr) m] stacked along
       axis 0. *)
  let generate_bs ~n ~m =
    match X.n_beg with
    | None -> Array.init 1 ~f:(fun _ -> AD.Mat.zeros 1 1)
    | Some nb ->
      let nr = n / nb in
      assert (nr = m);
      Array.init nb ~f:(fun i ->
          let inr = i * nr in
          let rnr = n - ((i + 1) * nr) in
          AD.Maths.(
            transpose
              (concatenate
                 ~axis:0
                 [| AD.Mat.zeros inr m; AD.Mat.eye nr; AD.Mat.zeros rnr m |])))


  (* [true] iff [X.n_beg = Some nb] and step [k] is one of the first [nb] steps *)
  let in_init_phase k =
    match X.n_beg with
    | None -> false
    | Some nb -> k < nb


  (* One step of the dynamics.
     During the initial phase the state is updated as x + u *@ beg_bs.(k);
     afterwards as (x *@ a) + (u *@ b), where [b] is the identity when the
     parameter record carries no input matrix.  The regular update is only
     evaluated on the steps that actually use it. *)
  let dyn ~theta =
    let a = unpack_a ~prms:theta in
    let n = AD.Mat.row_num a in
    let b =
      match theta.b with
      | None -> AD.Mat.(eye n)
      | Some t -> Owl_parameters.extract t
    in
    let m = AD.Mat.row_num b in
    let beg_bs = generate_bs ~n ~m in
    fun ~k ~x ~u ->
      if in_init_phase k
      then AD.Maths.(x + (u *@ beg_bs.(k)))
      else AD.Maths.((x *@ a) + (u *@ b))


  (* Jacobian of [dyn] with respect to the state: the identity during the
     initial phase and [a] afterwards. *)
  let dyn_x =
    (* Marine to check this *)
    let dyn_x ~theta =
      let a = unpack_a ~prms:theta in
      let n = AD.Mat.row_num a in
      (* the input matrices play no role here, so they are not generated *)
      fun ~k ~x:_ ~u:_ -> if in_init_phase k then AD.Mat.eye n else a
    in
    Some dyn_x


  (* Jacobian of [dyn] with respect to the input: [beg_bs.(k)] during the
     initial phase and [b] (or the identity when absent) afterwards.  The state
     size is read off the stored [q] factor. *)
  let dyn_u =
    (* Marine to check this *)
    let dyn_u ~theta =
      let q = Owl_parameters.extract theta.q in
      let n = AD.Mat.row_num q in
      let b =
        match theta.b with
        | None -> AD.Mat.(eye n)
        | Some t -> Owl_parameters.extract t
      in
      let m = AD.Mat.row_num b in
      let beg_bs = generate_bs ~n ~m in
      fun ~k ~x:_ ~u:_ -> if in_init_phase k then beg_bs.(k) else b
    in
    Some dyn_u
end

module Nonlinear (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
end) =
struct
  module P = Owl_parameters.Make (Nonlinear_P)
  open Nonlinear_P
  open X

  let requires_linesearch = true

  (* Random initial parameters for an n-dimensional state and m-dimensional
     input: [a] is n x n with standard deviation radius / sqrt n, [bias] is a
     1 x n row of zeros and [b] is m x n with standard deviation 1 / sqrt m. *)
  let init ?(radius = 0.1) ~n ~m (set : Owl_parameters.setter) =
    let sigma = Float.(radius / sqrt (of_int n)) in
    { a = set (AD.Mat.gaussian ~sigma n n)
    ; bias = set (AD.Mat.zeros 1 n)
    ; b = Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    }


  (* grow a network to size n; A and bias are padded with zeros, B is padded with gaussian *)
  let enlarge ~(prms : P.p) n =
    let open Owl_parameters in
    let n_prev = AD.Mat.row_num (extract prms.a) in
    let n_pad = n - n_prev in
    assert (Int.(n_pad >= 0));
    (* lift a transformation on plain matrices to one on AD values *)
    let on_arr f v = AD.pack_arr (f (AD.unpack_arr v)) in
    let pad_a x = Mat.((x @|| zeros n_prev n_pad) @= zeros n_pad n) in
    let pad_bias x = Mat.(x @|| zeros 1 n_pad) in
    let pad_b x =
      let m = Mat.row_num x in
      Mat.(x @|| gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n_pad)
    in
    let a = map (on_arr pad_a) prms.a in
    let bias = map (on_arr pad_bias) prms.bias in
    let b = Option.map prms.b ~f:(map (on_arr pad_b)) in
    { a; bias; b }


  (* [b], when present, divided by the square root of the sum of its squared
     entries taken along axis 0. *)
  let b_rescaled ~prms =
    let rescale p =
      let b = Owl_parameters.extract p in
      AD.Maths.(b / sqrt (sum ~axis:0 (sqr b)))
    in
    Option.map prms.b ~f:rescale


  (* Map from raw input to the term added to the state update: right
     multiplication by the rescaled [b], or the identity map when [b] is
     absent. *)
  let u_eff ~prms =
    Option.value_map
      (b_rescaled ~prms)
      ~default:(fun u -> u)
      ~f:(fun b u -> AD.Maths.(u *@ b))


  (* One step of the dynamics: (phi x *@ a) + u_eff u + bias.  The step index
     is ignored. *)
  let dyn ~theta =
    let weights = Owl_parameters.extract theta.a in
    let offset = Owl_parameters.extract theta.bias in
    let drive = u_eff ~prms:theta in
    fun ~k:_ ~x ~u ->
      let recurrent = AD.Maths.(phi x *@ weights) in
      AD.Maths.(recurrent + drive u + offset)


  (* Jacobian of [dyn] with respect to the state: the transpose of [d_phi x]
     multiplied element-wise (with broadcasting) into [a]. *)
  let dyn_x =
    let jac ~theta =
      let weights = Owl_parameters.extract theta.a in
      fun ~k:_ ~x ~u:_ -> AD.Maths.(transpose (d_phi x) * weights)
    in
    Some jac


  (* Jacobian of [dyn] with respect to the input: the rescaled [b], or an
     identity sized by the columns of [u] when [b] is absent. *)
  let dyn_u =
    let jac ~theta =
      match b_rescaled ~prms:theta with
      | Some b -> fun ~k:_ ~x:_ ~u:_ -> b
      | None -> fun ~k:_ ~x:_ ~u -> AD.Mat.(eye (col_num u))
    in
    Some jac
end

module Two_area (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
end) =
struct
  module P = Owl_parameters.Make (Two_area_P)
  open Two_area_P
  open X

  let requires_linesearch = true

  (* Random initial parameters for two areas of sizes n1 and n2.
     [a11] (n1 x n1) and [a22] (n2 x n2) are gaussian with standard deviation
     radius / sqrt n1 and radius / sqrt n2 respectively; [a21] (n1 x n2) is
     gaussian with standard deviation 1 / sqrt n1; [bias] is a 1 x (n1 + n2)
     row of zeros.  The optional plain matrix [b] is packed and stored as the
     input matrix when supplied. *)
  let init ?b ~radius (set : Owl_parameters.setter) n1 n2 =
    let within size = Float.(radius / sqrt (of_int size)) in
    let a11 = AD.Mat.gaussian ~sigma:(within n1) n1 n1 in
    let a21 = AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int n1)) n1 n2 in
    let a22 = AD.Mat.gaussian ~sigma:(within n2) n2 n2 in
    let bias = AD.Mat.zeros 1 (n1 + n2) in
    let b = Option.map b ~f:(fun raw -> set (AD.pack_arr raw)) in
    { a11 = set a11; a21 = set a21; a22 = set a22; bias = set bias; b }


  (* Assemble the full connectivity matrix from its blocks:
     left block column  = [a11] stacked over an n2 x n1 block of zeros,
     right block column = [a21] stacked over [a22]. *)
  let unpack_a ~prms =
    let a11 = Owl_parameters.extract prms.a11 in
    let a21 = Owl_parameters.extract prms.a21 in
    let a22 = Owl_parameters.extract prms.a22 in
    let n1 = AD.Mat.row_num a11
    and n2 = AD.Mat.row_num a22 in
    let left = AD.Maths.concat ~axis:0 a11 (AD.Mat.zeros n2 n1) in
    let right = AD.Maths.concat ~axis:0 a21 a22 in
    AD.Maths.concat ~axis:1 left right


  (* [b], when present, divided by the square root of the sum of its squared
     entries taken along axis 0. *)
  let b_rescaled ~prms =
    let rescale p =
      let b = Owl_parameters.extract p in
      AD.Maths.(b / sqrt (sum ~axis:0 (sqr b)))
    in
    Option.map prms.b ~f:rescale


  (* Map from raw input to the term added to the state update: right
     multiplication by the rescaled [b], or the identity map when [b] is
     absent. *)
  let u_eff ~prms =
    Option.value_map
      (b_rescaled ~prms)
      ~default:(fun u -> u)
      ~f:(fun b u -> AD.Maths.(u *@ b))


  (* One step of the dynamics: (phi x *@ a) + u_eff u + bias, with [a] the
     block matrix built by [unpack_a].  The step index is ignored. *)
  let dyn ~theta =
    let weights = unpack_a ~prms:theta in
    let offset = Owl_parameters.extract theta.bias in
    let drive = u_eff ~prms:theta in
    fun ~k:_ ~x ~u ->
      let recurrent = AD.Maths.(phi x *@ weights) in
      AD.Maths.(recurrent + drive u + offset)


  (* Jacobian of [dyn] with respect to the state: the transpose of [d_phi x]
     multiplied element-wise (with broadcasting) into the block matrix. *)
  let dyn_x =
    let jac ~theta =
      let weights = unpack_a ~prms:theta in
      fun ~k:_ ~x ~u:_ -> AD.Maths.(transpose (d_phi x) * weights)
    in
    Some jac


  (* Jacobian of [dyn] with respect to the input: the rescaled [b], or an
     identity sized by the columns of [u] when [b] is absent. *)
  let dyn_u =
    let jac ~theta =
      match b_rescaled ~prms:theta with
      | Some b -> fun ~k:_ ~x:_ ~u:_ -> b
      | None -> fun ~k:_ ~x:_ ~u -> AD.Mat.(eye (col_num u))
    in
    Some jac
end

module Nonlinear_Init (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
  val n_beg : int
end) =
struct
  module P = Owl_parameters.Make (Nonlinear_Init_P)
  open Nonlinear_Init_P
  open X

  let requires_linesearch = true

  (* Random initial parameters for an n-dimensional state and m-dimensional
     input: [a] is n x n with standard deviation radius / sqrt n, [bias] is a
     1 x n row of zeros and [b] is m x n with standard deviation 1 / sqrt m. *)
  let init ?(radius = 0.1) ~n ~m (set : Owl_parameters.setter) =
    let sigma = Float.(radius / sqrt (of_int n)) in
    { a = set (AD.Mat.gaussian ~sigma n n)
    ; bias = set (AD.Mat.zeros 1 n)
    ; b = Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    }


  (* grow a network to size n; A and bias are padded with zeros, B is padded with gaussian *)
  let enlarge ~(prms : P.p) n =
    let open Owl_parameters in
    let n_prev = AD.Mat.row_num (extract prms.a) in
    let n_pad = n - n_prev in
    assert (Int.(n_pad >= 0));
    let a =
      prms.a
      |> map (fun a ->
             a
             |> AD.unpack_arr
             |> fun x -> Mat.((x @|| zeros n_prev n_pad) @= zeros n_pad n) |> AD.pack_arr)
    in
    let bias =
      prms.bias
      |> map (fun bias ->
             bias |> AD.unpack_arr |> fun x -> Mat.(x @|| zeros 1 n_pad) |> AD.pack_arr)
    in
    let b =
      Option.map
        prms.b
        ~f:
          (map (fun b ->
               b
               |> AD.unpack_arr
               |> fun x ->
               let m = Mat.row_num x in
               Mat.(x @|| gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n_pad)
               |> AD.pack_arr))
    in
    { a; bias; b }


  (*we can use n_beg = n/m (or one more)*)
  (* Input matrices used during the first [X.n_beg] steps.  With
     nr = n / X.n_beg, the i-th matrix is the transpose of
     [zeros (i * nr) m; eye nr; zeros (n - (i + 1) * nr) m] stacked along
     axis 0.
     NOTE(review): the stacked blocks have m, nr and m columns, so this
     presumably requires nr = m; the sibling modules assert it - confirm. *)
  let generate_bs ~n ~m =
    let nr = n / X.n_beg in
    Array.init X.n_beg ~f:(fun i ->
        let inr = i * nr in
        let rnr = n - ((i + 1) * nr) in
        AD.Maths.(
          transpose
            (concatenate
               ~axis:0
               [| AD.Mat.zeros inr m; AD.Mat.eye nr; AD.Mat.zeros rnr m |])))


  (* [b], when present, divided by the square root of the sum of its squared
     entries taken along axis 0. *)
  let b_rescaled ~prms =
    Option.map prms.b ~f:(fun b ->
        let b = Owl_parameters.extract b in
        AD.Maths.(b / sqrt (sum ~axis:0 (sqr b))))


  (* Map from raw input to the term added to the state update: right
     multiplication by the rescaled [b], or the identity map when [b] is
     absent. *)
  let u_eff ~prms =
    match b_rescaled ~prms with
    | None -> fun u -> u
    | Some b -> fun u -> AD.Maths.(u *@ b)


  (* One step of the dynamics.
     For k < X.n_beg the state is updated as x + u *@ beg_bs.(k); afterwards as
     (phi x *@ a) + u_eff u + bias. *)
  let dyn ~theta =
    let a = Owl_parameters.extract theta.a in
    let bias = Owl_parameters.extract theta.bias in
    let n = AD.Mat.row_num a in
    let m =
      match theta.b with
      | None -> n
      | Some b -> AD.Mat.row_num (Owl_parameters.extract b)
    in
    let u_eff = u_eff ~prms:theta in
    let beg_bs = generate_bs ~n ~m in
    fun ~k ~x ~u ->
      if Int.(k < X.n_beg)
      then AD.Maths.(x + (u *@ beg_bs.(k)))
      else AD.Maths.((phi x *@ a) + u_eff u + bias)


  (* Jacobian of [dyn] with respect to the state: the identity for
     k < X.n_beg, and the transpose of [d_phi x] multiplied element-wise into
     [a] afterwards. *)
  let dyn_x =
    let dyn_x ~theta =
      let a = Owl_parameters.extract theta.a in
      let n = AD.Mat.row_num a in
      fun ~k ~x ~u:_ ->
        if Int.(k < X.n_beg)
        then AD.Mat.eye n
        else (
          let d = d_phi x in
          AD.Maths.(transpose d * a))
    in
    Some dyn_x


  (* Jacobian of [dyn] with respect to the input: [beg_bs.(k)] for
     k < X.n_beg, and afterwards the rescaled [b] (or an identity sized by the
     columns of [u] when [b] is absent). *)
  let dyn_u =
    let dyn_u ~theta =
      let n = AD.Mat.row_num (Owl_parameters.extract theta.a) in
      let m =
        match theta.b with
        | None -> n
        | Some b -> AD.Mat.row_num (Owl_parameters.extract b)
      in
      let beg_bs = generate_bs ~n ~m in
      (* does not depend on the step: compute it once per [theta] *)
      let b = b_rescaled ~prms:theta in
      fun ~k ~x:_ ~u ->
        if Int.(k < X.n_beg)
        then beg_bs.(k)
        else (
          match b with
          | None -> AD.Mat.(eye (col_num u))
          | Some b -> b)
    in
    Some dyn_u
end

(*two areas, with the inner one being linear*)
module Two_area_mixed (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
end) =
struct
  module P = Owl_parameters.Make (Two_area_mixed_P)
  open Two_area_mixed_P
  open X

  let requires_linesearch = true

  (* Random initial parameters for two areas of sizes n1 and n2.
     [a11] (n1 x n1) and [a22] (n2 x n2) are gaussian with standard deviation
     radius / sqrt n1 and radius / sqrt n2 respectively; [a21] (n1 x n2) is
     gaussian with standard deviation 1 / sqrt n1; [bias] is a 1 x (n1 + n2)
     row of zeros.  The optional plain matrix [b] is packed and stored as the
     input matrix when supplied. *)
  let init ?b ~radius (set : Owl_parameters.setter) n1 n2 =
    let n = n1 + n2 in
    let a11 = AD.Mat.gaussian ~sigma:Float.(radius / sqrt (of_int n1)) n1 n1 in
    let a21 = AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int n1)) n1 n2 in
    let a22 = AD.Mat.gaussian ~sigma:Float.(radius / sqrt (of_int n2)) n2 n2 in
    { a11 = set a11
    ; a21 = set a21
    ; a22 = set a22
    ; bias = set (AD.Mat.zeros 1 n)
    ; b = Option.map b ~f:(fun b -> set (AD.pack_arr b))
    }


  (* The three connectivity blocks, in the order a11, a21, a22. *)
  let unpack_a ~prms =
    let a11 = Owl_parameters.extract prms.a11 in
    let a21 = Owl_parameters.extract prms.a21 in
    let a22 = Owl_parameters.extract prms.a22 in
    a11, a21, a22


  (* [b], when present, divided by the square root of the sum of its squared
     entries taken along axis 0.  The square root was missing here while the
     other dynamics modules with a rescaled [b] all apply it; it is restored
     for consistency.
     NOTE(review): this changes the effective input matrix of previously
     fitted parameters - confirm the omission was not intentional. *)
  let b_rescaled ~prms =
    Option.map prms.b ~f:(fun b ->
        let b = Owl_parameters.extract b in
        AD.Maths.(b / sqrt (sum ~axis:0 (sqr b))))


  (* Map from raw input to the term added to the state update: right
     multiplication by the rescaled [b], or the identity map when [b] is
     absent. *)
  let u_eff ~prms =
    match b_rescaled ~prms with
    | None -> fun u -> u
    | Some b -> fun u -> AD.Maths.(u *@ b)


  (* One step of the dynamics.  The state is split column-wise into x1 (first
     n1 columns) and x2 (the rest):
       new_x1 = x1 *@ a11
       new_x2 = (x1 *@ a21) + (phi x2 *@ a22)
     and the result is [new_x1, new_x2] + u_eff u + bias.  The step index is
     ignored. *)
  let dyn ~theta =
    let a11, a21, a22 = unpack_a ~prms:theta in
    let bias = Owl_parameters.extract theta.bias in
    let u_eff = u_eff ~prms:theta in
    let n1 = AD.Mat.row_num a11 in
    fun ~k:_ ~x ~u ->
      let x1 = AD.Maths.get_slice [ []; [ 0; n1 - 1 ] ] x in
      let x2 = AD.Maths.get_slice [ []; [ n1; -1 ] ] x in
      let new_x1 = AD.Maths.(x1 *@ a11) in
      let new_x2 = AD.Maths.((x1 *@ a21) + (phi x2 *@ a22)) in
      AD.Maths.(concat ~axis:1 new_x1 new_x2 + u_eff u + bias)


  (* Jacobian of [dyn] with respect to the state, assembled block-wise:
     left block column  = [a11] stacked over an n2 x n1 block of zeros,
     right block column = [a21] stacked over the transpose of [d_phi x2]
     multiplied element-wise into [a22]. *)
  let dyn_x =
    let dyn_x ~theta =
      let a11, a21, a22 = unpack_a ~prms:theta in
      let n1 = AD.Mat.row_num a11 in
      let n2 = AD.Mat.row_num a22 in
      fun ~k:_ ~x ~u:_ ->
        let x2 = AD.Maths.get_slice [ []; [ n1; -1 ] ] x in
        let d2 = AD.Maths.(transpose (d_phi x2)) in
        let z = AD.Mat.zeros n2 n1 in
        let top = AD.Maths.concat ~axis:0 a11 z in
        let bottom = AD.Maths.concat ~axis:0 a21 AD.Maths.(d2 * a22) in
        AD.Maths.concat ~axis:1 top bottom
    in
    Some dyn_x


  (* Jacobian of [dyn] with respect to the input: the rescaled [b], or an
     identity sized by the columns of [u] when [b] is absent. *)
  let dyn_u =
    let dyn_u ~theta =
      match b_rescaled ~prms:theta with
      | None -> fun ~k:_ ~x:_ ~u -> AD.Mat.(eye (col_num u))
      | Some b -> fun ~k:_ ~x:_ ~u:_ -> b
    in
    Some dyn_u
end

module Nonlinear_2 (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
  val n_beg : int Option.t
end) =
struct
  module P = Owl_parameters.Make (Nonlinear_2_P)
  open Nonlinear_2_P
  open X

  let requires_linesearch = true

  (* Random initial parameters for an n-dimensional state and m-dimensional
     input: [a] is n x n with standard deviation radius / sqrt n, [a_nl] is an
     n x n block of zeros, [bias] is a 1 x n row of zeros and [b] is m x n with
     standard deviation 1 / sqrt m. *)
  let init ?(radius = 0.1) ~n ~m (set : Owl_parameters.setter) =
    let sigma = Float.(radius / sqrt (of_int n)) in
    { a = set (AD.Mat.gaussian ~sigma n n)
    ; a_nl = set (AD.Mat.zeros n n)
    ; bias = set (AD.Mat.zeros 1 n)
    ; b = Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    }


  (* The stored [b], when present.  Despite the name no rescaling is applied
     in this module. *)
  let b_rescaled ~prms = Option.map prms.b ~f:Owl_parameters.extract


  (* Input matrices used instead of [b] during the first [n_beg] steps.
     - [X.n_beg = None]: a single 1 x 1 zero placeholder.
     - [X.n_beg = Some nb]: [nb] matrices; with nr = n / nb (asserted equal to
       [m]) the i-th one is the transpose of
       [zeros (i * nr) m; eye nr; zeros (n - (i + 1) * nr) m] stacked along
       axis 0. *)
  let generate_bs ~n ~m =
    match X.n_beg with
    | None -> Array.init 1 ~f:(fun _ -> AD.Mat.zeros 1 1)
    | Some nb ->
      let nr = n / nb in
      assert (nr = m);
      Array.init nb ~f:(fun i ->
          let inr = i * nr in
          let rnr = n - ((i + 1) * nr) in
          AD.Maths.(
            transpose
              (concatenate
                 ~axis:0
                 [| AD.Mat.zeros inr m; AD.Mat.eye nr; AD.Mat.zeros rnr m |])))


  (* [true] iff [X.n_beg = Some nb] and step [k] is one of the first [nb] steps *)
  let in_init_phase k =
    match X.n_beg with
    | None -> false
    | Some nb -> k < nb


  (* Map from raw input to the term added to the state update: right
     multiplication by [b], or the identity map when [b] is absent. *)
  let u_eff ~prms =
    match b_rescaled ~prms with
    | None -> fun u -> u
    | Some b -> fun u -> AD.Maths.(u *@ b)


  (* One step of the dynamics.
     During the initial phase the state is updated as x + u *@ beg_bs.(k);
     afterwards as (x *@ a) + (phi x *@ a_nl) + u_eff u + bias.  The regular
     update is only evaluated on the steps that actually use it. *)
  let dyn ~theta =
    let a = Owl_parameters.extract theta.a in
    let a_nl = Owl_parameters.extract theta.a_nl in
    let bias = Owl_parameters.extract theta.bias in
    let u_eff = u_eff ~prms:theta in
    let n = AD.Mat.row_num a in
    (* input dimension: rows of [b], or n when the input enters directly *)
    let m =
      match b_rescaled ~prms:theta with
      | None -> n
      | Some b -> AD.Mat.row_num b
    in
    let beg_bs = generate_bs ~n ~m in
    fun ~k ~x ~u ->
      if in_init_phase k
      then AD.Maths.(x + (u *@ beg_bs.(k)))
      else AD.Maths.((x *@ a) + (phi x *@ a_nl) + u_eff u + bias)


  (* Jacobian of [dyn] with respect to the state: the identity during the
     initial phase; afterwards [a] plus the transpose of [d_phi x] multiplied
     element-wise into [a_nl]. *)
  let dyn_x =
    let dyn_x ~theta =
      let a = Owl_parameters.extract theta.a in
      let a_nl = Owl_parameters.extract theta.a_nl in
      let n = AD.Mat.row_num a in
      fun ~k ~x ~u:_ ->
        if in_init_phase k
        then AD.Mat.eye n
        else (
          let d = d_phi x in
          AD.Maths.(a + (transpose d * a_nl)))
    in
    Some dyn_x


  (* Jacobian of [dyn] with respect to the input: [beg_bs.(k)] during the
     initial phase and [b] (or the identity when absent) afterwards.  The
     sizes handed to [generate_bs] are read off [b] itself. *)
  let dyn_u =
    let dyn_u ~theta =
      let a = Owl_parameters.extract theta.a in
      let b =
        match b_rescaled ~prms:theta with
        | None -> AD.Mat.eye (AD.Mat.row_num a)
        | Some b -> b
      in
      let beg_bs = generate_bs ~n:(AD.Mat.col_num b) ~m:(AD.Mat.row_num b) in
      fun ~k ~x:_ ~u:_ -> if in_init_phase k then beg_bs.(k) else b
    in
    Some dyn_u
end

module Mini_GRU (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
  val sigma : AD.t -> AD.t
  val d_sigma : AD.t -> AD.t
  val n_beg : int Option.t
end) =
struct
  module P = Owl_parameters.Make (Mini_GRU_P)
  open Mini_GRU_P
  open X

  let requires_linesearch = true

  (* Initial parameters for an n-dimensional state and m-dimensional input.
     Shapes, following the original notes: the state h and the effective input
     x = u B are 1 x n rows, the gate f has the shape of h, the weight
     matrices are n x n and B is m x n.
     [wh] and [uh] are gaussian with standard deviation radius / sqrt n, [wf]
     and [uf] are drawn with standard deviation 0, [bh] is all zeros, [bf] is
     all ones and [b] is gaussian with standard deviation 1 / sqrt m. *)
  let init ?(radius = 0.1) ~n ~m (set : Owl_parameters.setter) =
    let sigma = Float.(radius / sqrt (of_int n)) in
    { wf = set (AD.Mat.gaussian ~sigma:0. n n)
    ; wh = set (AD.Mat.gaussian ~sigma n n)
    ; bh = set (AD.Mat.zeros 1 n)
    ; bf = set (AD.Mat.ones 1 n)
    ; uh = set (AD.Mat.gaussian ~sigma n n)
    ; uf = set (AD.Mat.gaussian ~sigma:0. n n)
    ; b = Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    }


  (* The stored [b], when present.  Despite the name no rescaling is applied
     in this module. *)
  let b_rescaled ~prms = Option.map prms.b ~f:Owl_parameters.extract


  (* Input matrices used instead of [b] during the first [n_beg] steps.
     - [X.n_beg = None]: a single 1 x 1 zero placeholder.
     - [X.n_beg = Some nb]: [nb] matrices; with nr = n / nb (asserted equal to
       [m]) the i-th one is the transpose of
       [zeros (i * nr) m; eye nr; zeros (n - (i + 1) * nr) m] stacked along
       axis 0. *)
  let generate_bs ~n ~m =
    match X.n_beg with
    | None -> Array.init 1 ~f:(fun _ -> AD.Mat.zeros 1 1)
    | Some nb ->
      let nr = n / nb in
      assert (nr = m);
      Array.init nb ~f:(fun i ->
          let inr = i * nr in
          let rnr = n - ((i + 1) * nr) in
          AD.Maths.(
            transpose
              (concatenate
                 ~axis:0
                 [| AD.Mat.zeros inr m; AD.Mat.eye nr; AD.Mat.zeros rnr m |])))


  (* [true] iff [X.n_beg = Some nb] and step [k] is one of the first [nb] steps *)
  let in_init_phase k =
    match X.n_beg with
    | None -> false
    | Some nb -> k < nb


  (* Map from raw input to the effective input: right multiplication by [b],
     or the identity map when [b] is absent. *)
  let u_eff ~prms =
    match b_rescaled ~prms with
    | None -> fun u -> u
    | Some b -> fun u -> AD.Maths.(u *@ b)


  (* One step of the dynamics.
     During the initial phase the state is updated as x + u *@ beg_bs.(k).
     Afterwards, with h the incoming state and x = u_eff u:
       f     = sigma ((x *@ wf) + bf + (h *@ uf))
       h_hat = phi ((x *@ wh) + bh + ((h * f) *@ uh))
       out   = ((1 - f) * h) + (f * h_hat)
     The gated update is only evaluated on the steps that actually use it. *)
  let dyn ~theta =
    let wh = Owl_parameters.extract theta.wh in
    let wf = Owl_parameters.extract theta.wf in
    let bh = Owl_parameters.extract theta.bh in
    let bf = Owl_parameters.extract theta.bf in
    let uh = Owl_parameters.extract theta.uh in
    let uf = Owl_parameters.extract theta.uf in
    let n = AD.Mat.col_num bh in
    let m =
      match theta.b with
      | None -> n
      | Some b -> AD.Mat.row_num (Owl_parameters.extract b)
    in
    let beg_bs = generate_bs ~n ~m in
    let u_eff = u_eff ~prms:theta in
    fun ~k ~x ~u ->
      if in_init_phase k
      then AD.Maths.(x + (u *@ beg_bs.(k)))
      else (
        let h_pred = x in
        let x = u_eff u in
        let f = sigma AD.Maths.((x *@ wf) + bf + (h_pred *@ uf)) in
        let h_hat =
          let hf = AD.Maths.(h_pred * f) in
          phi AD.Maths.((x *@ wh) + bh + (hf *@ uh))
        in
        AD.Maths.(((F 1. - f) * h_pred) + (f * h_hat)))


  (* Hand-written Jacobian of [dyn] with respect to the state: the identity
     during the initial phase, the expression below afterwards.
     NOTE(review): the remark kept below marks this expression as possibly
     incorrect; it has not been verified here. *)
  let dyn_x =
    let _dyn_x ~theta =
      let wh = Owl_parameters.extract theta.wh in
      let wf = Owl_parameters.extract theta.wf in
      let bh = Owl_parameters.extract theta.bh in
      let bf = Owl_parameters.extract theta.bf in
      let uh = Owl_parameters.extract theta.uh in
      let uf = Owl_parameters.extract theta.uf in
      let u_eff = u_eff ~prms:theta in
      let n = AD.Mat.col_num bh in
      fun ~k ~x ~u ->
        if in_init_phase k
        then AD.Mat.eye n
        else (
          let h_pred = x in
          let x = u_eff u in
          let f_pre = AD.Maths.((x *@ wf) + bf + (h_pred *@ uf)) in
          let f = sigma f_pre in
          let h_hat_pre =
            let hf = AD.Maths.(h_pred * f) in
            AD.Maths.((x *@ wh) + bh + (hf *@ uh))
          in
          let h_hat = phi h_hat_pre in
          AD.Maths.(
            diagm (F 1. - f)
            - (uf * ((h_pred - h_hat) * d_sigma f_pre))
            + (((transpose f * uh) + (uf *@ (transpose (h_pred * d_sigma f_pre) * uh)))
              * (f * d_phi h_hat_pre))))
    in
    (* still wrong... maybe some transposes somewhere? need to check the side of the diagm + the u transposes again... *)
    Some _dyn_x


  (* Hand-written Jacobian of [dyn] with respect to the input: [beg_bs.(k)]
     during the initial phase; afterwards [b] (or the identity when absent)
     times the derivative of the gated update with respect to x = u_eff u. *)
  let dyn_u =
    let _dyn_u ~theta =
      let wh = Owl_parameters.extract theta.wh in
      let wf = Owl_parameters.extract theta.wf in
      let bh = Owl_parameters.extract theta.bh in
      let bf = Owl_parameters.extract theta.bf in
      let uh = Owl_parameters.extract theta.uh in
      let uf = Owl_parameters.extract theta.uf in
      let b =
        match b_rescaled ~prms:theta with
        | None -> AD.Mat.eye (AD.Mat.col_num bh)
        | Some b -> b
      in
      let u_eff = u_eff ~prms:theta in
      let m = AD.Mat.row_num b in
      let n = AD.Mat.col_num bh in
      let beg_bs = generate_bs ~n ~m in
      fun ~k ~x ~u ->
        if in_init_phase k
        then beg_bs.(k)
        else (
          let h_pred = x in
          let x = u_eff u in
          let f_pre = AD.Maths.((x *@ wf) + bf + (h_pred *@ uf)) in
          let f = sigma f_pre in
          let h_hat_pre =
            let hf = AD.Maths.(h_pred * f) in
            AD.Maths.((x *@ wh) + bh + (hf *@ uh))
          in
          let h_hat = phi h_hat_pre in
          AD.Maths.(
            b
            *@ ((wf * (d_sigma f_pre * (h_hat - h_pred)))
               + ((wh + (wf *@ (transpose (h_pred * d_sigma f_pre) * uh)))
                 * (f * d_phi h_hat_pre)))))
    in
    Some _dyn_u
end

module Mini_GRU_IO (X : sig
  val phi : AD.t -> AD.t
  val d_phi : AD.t -> AD.t
  val sigma : AD.t -> AD.t
  val d_sigma : AD.t -> AD.t
  val n_beg : int Option.t
end) =
struct
  module P = Owl_parameters.Make (Mini_GRU_IO_P)
  open Mini_GRU_IO_P
  open X

  let requires_linesearch = true

  (* Initial parameters for an n-dimensional state and m-dimensional input.
     Shapes, following the original notes: the state h and the effective input
     x = u B are 1 x n rows, the gate f has the shape of h, the weight
     matrices are n x n and B is m x n.
     [wh] and [uh] are gaussian with standard deviation radius / sqrt n, [uf]
     is drawn with standard deviation 0, [bh] is all zeros and [b] is gaussian
     with standard deviation 1 / sqrt m. *)
  let init ?(radius = 0.1) ~n ~m (set : Owl_parameters.setter) =
    let sigma = Float.(radius / sqrt (of_int n)) in
    { wh = set (AD.Mat.gaussian ~sigma n n)
    ; bh = set (AD.Mat.zeros 1 n)
    ; uh = set (AD.Mat.gaussian ~sigma n n)
    ; uf = set (AD.Mat.gaussian ~sigma:0. n n)
    ; b = Some (set (AD.Mat.gaussian ~sigma:Float.(1. / sqrt (of_int m)) m n))
    }


  (* The stored [b], when present.  Despite the name no rescaling is applied
     in this module. *)
  let b_rescaled ~prms = Option.map prms.b ~f:Owl_parameters.extract


  (* Input matrices used instead of [b] during the first [n_beg] steps.
     - [X.n_beg = None]: a single 1 x 1 zero placeholder.
     - [X.n_beg = Some nb]: [nb] matrices; with nr = n / nb (asserted equal to
       [m]) the i-th one is the transpose of
       [zeros (i * nr) m; eye nr; zeros (n - (i + 1) * nr) m] stacked along
       axis 0. *)
  let generate_bs ~n ~m =
    match X.n_beg with
    | None -> Array.init 1 ~f:(fun _ -> AD.Mat.zeros 1 1)
    | Some nb ->
      let nr = n / nb in
      assert (nr = m);
      Array.init nb ~f:(fun i ->
          let inr = i * nr in
          let rnr = n - ((i + 1) * nr) in
          AD.Maths.(
            transpose
              (concatenate
                 ~axis:0
                 [| AD.Mat.zeros inr m; AD.Mat.eye nr; AD.Mat.zeros rnr m |])))


  (* [true] iff [X.n_beg = Some nb] and step [k] is one of the first [nb] steps *)
  let in_init_phase k =
    match X.n_beg with
    | None -> false
    | Some nb -> k < nb


  (* Map from raw input to the effective input: right multiplication by [b],
     or the identity map when [b] is absent. *)
  let u_eff ~prms =
    match b_rescaled ~prms with
    | None -> fun u -> u
    | Some b -> fun u -> AD.Maths.(u *@ b)


  (* One step of the dynamics.
     During the initial phase the state is updated as x + u *@ beg_bs.(k).
     Afterwards, with h the incoming state and x = u_eff u:
       f     = sigma (h *@ uf)
       h_hat = phi (bh + ((h * f) *@ uh)) + (x *@ wh)
       out   = ((1 - f) * h) + (f * h_hat)
     The gated update is only evaluated on the steps that actually use it. *)
  let dyn ~theta =
    let wh = Owl_parameters.extract theta.wh in
    let bh = Owl_parameters.extract theta.bh in
    let uh = Owl_parameters.extract theta.uh in
    let uf = Owl_parameters.extract theta.uf in
    let n = AD.Mat.col_num bh in
    let m =
      match theta.b with
      | None -> n
      | Some b -> AD.Mat.row_num (Owl_parameters.extract b)
    in
    let beg_bs = generate_bs ~n ~m in
    let u_eff = u_eff ~prms:theta in
    fun ~k ~x ~u ->
      if in_init_phase k
      then AD.Maths.(x + (u *@ beg_bs.(k)))
      else (
        let h_pred = x in
        let x = u_eff u in
        let f = sigma AD.Maths.(h_pred *@ uf) in
        let h_hat =
          let hf = AD.Maths.(h_pred * f) in
          AD.Maths.(phi AD.Maths.(bh + (hf *@ uh)) + (x *@ wh))
        in
        AD.Maths.(((F 1. - f) * h_pred) + (f * h_hat)))


  (* Hand-written Jacobian of [dyn] with respect to the state: the identity
     during the initial phase, the expression below afterwards.
     NOTE(review): the remark kept below marks this expression as possibly
     incorrect; it has not been verified here. *)
  let dyn_x =
    let _dyn_x ~theta =
      let wh = Owl_parameters.extract theta.wh in
      let bh = Owl_parameters.extract theta.bh in
      let uh = Owl_parameters.extract theta.uh in
      let uf = Owl_parameters.extract theta.uf in
      let u_eff = u_eff ~prms:theta in
      let n = AD.Mat.col_num bh in
      fun ~k ~x ~u ->
        if in_init_phase k
        then AD.Mat.eye n
        else (
          let h_pred = x in
          let x = u_eff u in
          let f_pre = AD.Maths.(h_pred *@ uf) in
          let f = sigma f_pre in
          let h_hat_pre =
            let hf = AD.Maths.(h_pred * f) in
            AD.Maths.(bh + (hf *@ uh))
          in
          let h_hat = AD.Maths.(phi h_hat_pre + (x *@ wh)) in
          AD.Maths.(
            diagm (F 1. - f)
            - (uf * ((h_pred - h_hat) * d_sigma f_pre))
            + (((transpose f * uh) + (uf *@ (transpose (h_pred * d_sigma f_pre) * uh)))
              * (f * d_phi h_hat_pre))))
    in
    (* still wrong... maybe some transposes somewhere? need to check the side of the diagm + the u transposes again... *)
    Some _dyn_x


  (* Hand-written Jacobian of [dyn] with respect to the input: [beg_bs.(k)]
     during the initial phase; afterwards [b] (or the identity when absent)
     times [wh] scaled element-wise by the gate f = sigma (h *@ uf). *)
  let dyn_u =
    let _dyn_u ~theta =
      let wh = Owl_parameters.extract theta.wh in
      let bh = Owl_parameters.extract theta.bh in
      let uf = Owl_parameters.extract theta.uf in
      let b =
        match b_rescaled ~prms:theta with
        | None -> AD.Mat.eye (AD.Mat.col_num bh)
        | Some b -> b
      in
      let m = AD.Mat.row_num b in
      let n = AD.Mat.col_num bh in
      let beg_bs = generate_bs ~n ~m in
      fun ~k ~x ~u:_ ->
        if in_init_phase k
        then beg_bs.(k)
        else (
          let h_pred = x in
          let f_pre = AD.Maths.(h_pred *@ uf) in
          let f = sigma f_pre in
          AD.Maths.(b *@ (wh * f)))
    in
    Some _dyn_u
end
