module Pds

using StaticArrays, Distributions, Random, Functors, Optimisers, LinearAlgebra
using NNlib: logsumexp, softmax, logsoftmax
using Zygote
using ZygoteRules: @adjoint
using PyCall

import Base: -

export Gaussian, Mix, logp, kl, vi, sample, sample!

"""
Python callables cached at module initialisation, keyed by `:logp`, `:logp_adj`, `:vi`
and `:vi_adj` (attributes of the `demo_fn` Python module).
"""
const runtime = Dict{Symbol, PyObject}()

# Make the Python helper module that sits next to this file importable, then cache its
# forward kernels and their adjoint counterparts in `runtime`.
function __init__()
    sys_path = PyVector(pyimport("sys")["path"])
    pushfirst!(sys_path, @__DIR__)
    pymod = pyimport("demo_fn")
    for name in (:logp, :logp_adj, :vi, :vi_adj)
        runtime[name] = getproperty(pymod, name)
    end
    return nothing
end

"""
    Pd{T}

Abstract supertype of the distributions in this module; `T` is the scalar type of
their parameters.
"""
abstract type Pd{T} end

"""
    Gaussian{T, V, W}

`V`-dimensional Gaussian with mean `μ`, log-scales `logσ` and factor `uL`
(`W` is the number of entries of `uL`, i.e. `V^2`). The density code (`logp`,
`logp_each`, `kl`, `logintprod`) reads `uL` through `UnitLowerTriangular`, giving the
covariance `L * L'` with `L = UnitLowerTriangular(uL) * Diagonal(exp.(logσ))`.
"""
struct Gaussian{T, V, W} <: Pd{T}
    μ::SVector{V, T}
    logσ::SVector{V, T}
    uL::SMatrix{V, V, T, W}
end

"""
    Mix{T, P}

Finite mixture of components `comp` with unnormalised log-weights `logα`
(the code normalises them with `softmax`/`logsoftmax` where they are used).
"""
struct Mix{T, P <: Pd{T}} <: Pd{T}
    logα::Vector{T}
    comp::Vector{P}
end

# Shallow copy of a mixture: new `logα` and `comp` vectors that hold the same element
# values, so resizing or reassigning entries of the copy leaves `m` untouched.
function Base.copy(m::Mix)
    logα = copy(m.logα)
    comp = copy(m.comp)
    return Mix(logα, comp)
end

# copy!(dst, src) -> dst
#
# Overwrite the log-weights and components of `dst` with those of `src` and return
# `dst`, as the `Base.copy!` contract requires. Previously the method returned the
# value of the last inner `copy!`, i.e. `dst.comp` rather than `dst`.
# `copy!` on vectors resizes the destination to the source length.
function Base.copy!(dst::M, src::M) where {T, P, M <: Mix{T, P}}
    copy!(dst.logα, src.logα)
    copy!(dst.comp, src.comp)
    return dst
end

"""
    Gaussian{T}(n)

Standard `n`-dimensional Gaussian: zero mean, zero log-scales (unit scales) and an
identity `uL` factor. `n` is a runtime value, so the concrete type is resolved dynamically.
"""
function Gaussian{T}(n::Integer) where {T}
    center = zeros(SVector{n, T})
    logscale = zeros(SVector{n, T})
    factor = SMatrix{n, n, T, n^2}(I)
    return Gaussian{T, n, n^2}(center, logscale, factor)
end

"""
    Gaussian(μ::AbstractVector)

Gaussian centred at `μ` with unit scales (`logσ = 0`) and an identity `uL` factor.
"""
function Gaussian(μ::AbstractVector{T}) where {T}
    d = length(μ)
    center = SVector{d}(μ)
    logscale = zeros(SVector{d, T})
    factor = SMatrix{d, d, T, d^2}(I)
    return Gaussian{T, d, d^2}(center, logscale, factor)
end

# Register both parameter structs with Functors so Optimisers can traverse them, and
# declare every field trainable.
Functors.@functor Gaussian
Functors.@functor Mix
Optimisers.trainable(g::Gaussian) = (; g.μ, g.logσ, g.uL)
Optimisers.trainable(m::Mix) = (; m.logα, m.comp)

# g₁ - g₂
#
# Field-wise difference of two Gaussians of identical type: mean, log-scales and `uL`
# are subtracted independently. This is parameter arithmetic, not an operation on the
# distributions themselves.
function -(g₁::G, g₂::G)::G where {T, V, W, G <: Gaussian{T, V, W}}
    Δμ = g₁.μ - g₂.μ
    Δlogσ = g₁.logσ - g₂.logσ
    ΔuL = g₁.uL - g₂.uL
    return Gaussian(Δμ, Δlogσ, ΔuL)
end
# m₁ - m₂
#
# Field-wise difference of two Gaussian mixtures of identical type: log-weights are
# subtracted element-wise and components pairwise (using the Gaussian `-` method).
function -(m₁::M, m₂::M)::M where {T, V, W, M <: Mix{T, Gaussian{T, V, W}}}
    Δlogα = m₁.logα .- m₂.logα
    Δcomp = m₁.comp .- m₂.comp
    return Mix(Δlogα, Δcomp)
end

# Convert an `MMatrix` to a `LowerTriangular` wrapper around that same matrix.
# `LowerTriangular` hides, but does not zero, the strictly-upper entries of `m`.
# NOTE(review): neither `convert`, `LowerTriangular` nor `MMatrix` is owned by this
# module, so this method is type piracy; confirm it is still needed.
function Base.convert(::Type{LowerTriangular{T, MMatrix{V, V, T, W}}},
                      m::MMatrix{V, V, T, W}) where {V, T, W}
    return LowerTriangular(m)
end

"""
    logp(g::Gaussian, x)

Log-density of `g` at the point `x`, or summed over the columns of `x` when it is a
matrix. With `z = UnitLowerTriangular(uL) \\ (x - μ)`, each coordinate contributes
`-log(2π)/2 - logσ - z^2 / (2σ^2)`.
"""
function logp(g::Gaussian{T}, x::AbstractVecOrMat{T})::T where {T}
    z = inv(UnitLowerTriangular(g.uL)) * (x .- g.μ)
    terms = g.logσ .+ T(0.5) .* z .^ 2 .* exp.(-2 .* g.logσ)
    return -length(x) * log(T(2)π) / 2 - sum(terms)
end

"""
    logp_f32(T, logα, μ, logσ, uL⁻¹, x)

Evaluate the mixture log-density with the Python `runtime[:logp]` kernel and convert
the result to `T`. `logp(::Mix, x)` calls this with `Float32` arrays. All array
arguments except `logα` are passed with `PyReverseDims`. Gradients come from the
`@adjoint` rule defined for this function.
"""
function logp_f32(::Type{T}, logα, μ, logσ, uL⁻¹, x) where {T}
    # `runtime` is a Dict{Symbol, PyObject}; it must be indexed with the Symbol `:logp`
    # (a String key such as "logp" is never present and throws a KeyError).
    y = runtime[:logp](logα, PyReverseDims(μ), PyReverseDims(logσ),
                       PyReverseDims(uL⁻¹), PyReverseDims(x))
    return convert(T, y)
end

# Zygote rule for `logp_f32`. The forward value and the pullback both come from the
# Python `runtime[:logp_adj]`, which must return `(y, adj)`. `adj(g)` must yield five
# values: gradients for (logα, μ, logσ, uL⁻¹), followed by one value that is ignored.
@adjoint function
logp_f32(::Type{T}, ∇logα, ∇μ, ∇logσ, ∇uL⁻¹, x) where {T}
   # Same argument marshalling as the primal: arrays other than the weights are passed
   # with their dimensions reversed.
   y, adj = runtime[:logp_adj](∇logα, PyReverseDims(∇μ), PyReverseDims(∇logσ), PyReverseDims(∇uL⁻¹), PyReverseDims(x))

   # The pullback returns one cotangent per argument of `logp_f32`: `nothing` for the
   # type argument `T` and for the evaluation points `x`, and `T`-typed arrays for the
   # parameters. Inside the `let`, the parameter names are rebound to the gradients.
   # NOTE(review): the parameters reached Python via `PyReverseDims`; confirm that
   # `adj` hands back gradients in Julia (column-major) shape, otherwise ∇μ, ∇logσ and
   # ∇uL⁻¹ arrive with their dimensions reversed.
   convert(T, y), g -> let; local (∇logα, ∇μ, ∇logσ, ∇uL⁻¹, _) = adj(g)
      nothing, convert(Vector{T}, ∇logα), convert(Matrix{T}, ∇μ), convert(Matrix{T}, ∇logσ), convert(Array{T,3}, ∇uL⁻¹), nothing
   end
end

"""
    inv_uL(A)

Inverse of the unit lower-triangular matrix formed from `A`. The diagonal and upper
triangle of `A` are ignored, as with `UnitLowerTriangular`.
"""
function inv_uL(A)
    U = UnitLowerTriangular(A)
    return inv(U)
end

"""
    logp(m::Mix, x; f32=true)

Total log-density of the columns of `x` under the Gaussian mixture `m`.

With `f32=true`, the computation is delegated in `Float32` to the Python kernel through
`logp_f32`. Otherwise it is evaluated in Julia as
`Σⱼ logsumexpₖ(logsoftmax(logα)[k] + logp_each(comp[k], x)[j])`.
"""
function logp(m::Mix{T, Gaussian{T, V, W}}, x::AbstractMatrix{T}; f32=true)::T where {T, V, W}
    if f32
        return logp_f32(T,
            convert(Vector{Float32}, m.logα),
            convert(Matrix{Float32}, stack(map(c -> c.μ, m.comp); dims=2)),
            convert(Matrix{Float32}, stack(map(c -> c.logσ, m.comp); dims=2)),
            convert(Array{Float32, 3}, stack(map(c -> inv_uL(c.uL), m.comp); dims=3)),
            convert(Matrix{Float32}, x))
    end
    # K × n matrix of per-component, per-column log-densities. The K log-weights must be
    # broadcast across the n columns with `.+`; plain `+` throws a DimensionMismatch
    # as soon as n > 1.
    comp_logp = stack(map(g -> logp_each(g, x), m.comp); dims=1)
    return sum(logsumexp(logsoftmax(m.logα) .+ comp_logp; dims=1))
end

"""
    logp_each(g::Gaussian, x::AbstractMatrix) -> Vector

Log-density of `g` at each column of `x` (one value per column).
"""
function logp_each(g::Gaussian{T}, x::AbstractMatrix{T})::Vector{T} where {T}
    z = inv(UnitLowerTriangular(g.uL)) * (x .- g.μ)
    terms = g.logσ .+ T(0.5) .* z .^ 2 .* exp.(-2 .* g.logσ)
    per_column = dropdims(sum(terms; dims=1); dims=1)
    return -size(x, 1) * log(T(2)π) / 2 .- per_column
end

"""
    logp_each(m::Mix, x::AbstractMatrix) -> Vector

Mixture log-density at each column of `x`: a `logsumexp` over components of the
normalised log-weight plus that component's per-column log-density.
"""
function logp_each(m::Mix{T}, x::AbstractMatrix{T})::Vector{T} where {T}
    logw = logsoftmax(m.logα)
    comp_logp = stack(map(g -> logp_each(g, x), m.comp); dims=1)   # K × n
    joint = logw .+ comp_logp
    return dropdims(logsumexp(joint; dims=1); dims=1)
end

"""
    entropy(g::Gaussian)

Differential entropy of `g`: `d * (log(2π) + 1) / 2 + sum(logσ)`. The unit-lower factor
has determinant 1, so only the scales contribute to the log-determinant.
"""
function entropy(g::Gaussian{T}) where {T}
    d = length(g.μ)
    return d * (log(T(2)π) + 1) / 2 + sum(g.logσ)
end

"""
    logintprod(g₁::Gaussian, g₂::Gaussian)

Log of `∫ g₁(x) g₂(x) dx`, i.e. the log-density of `N(0, Σ₁ + Σ₂)` at `μ₂ - μ₁`, where
`Σᵢ = Lᵢ Lᵢ'` and `Lᵢ = UnitLowerTriangular(uLᵢ) * Diagonal(exp.(logσᵢ))`.
"""
function logintprod(g₁::Gaussian{T}, g₂::Gaussian{T})::T where {T}
    cholfactor(g) = UnitLowerTriangular(g.uL) * Diagonal(exp.(g.logσ))
    L₁ = cholfactor(g₁)
    L₂ = cholfactor(g₂)
    P = inv(L₁ * L₁' + L₂ * L₂')     # precision of the summed covariance
    δ = g₂.μ - g₁.μ
    d = size(g₁.μ, 1)
    return -(d * log(T(2)π)) / 2 + T(0.5) * logdet(P) - T(0.5) * (δ' * P * δ)
end

"""
    kl(g₁::Gaussian, g₂::Gaussian)

Closed-form `KL(g₁ ‖ g₂)`:
`(tr(Σ₂⁻¹Σ₁) + (μ₂-μ₁)' Σ₂⁻¹ (μ₂-μ₁) - d + log det Σ₂ - log det Σ₁) / 2`.
Here `Σᵢ = Lᵢ Lᵢ'` with `Lᵢ = UnitLowerTriangular(uLᵢ) * Diagonal(exp.(logσᵢ))`.
"""
function kl(g₁::Gaussian{T}, g₂::Gaussian{T})::T where {T}
    d = size(g₁.μ, 1)
    L₁ = UnitLowerTriangular(g₁.uL) * Diagonal(exp.(g₁.logσ))
    L₂⁻¹ = Diagonal(exp.(-g₂.logσ)) * inv(UnitLowerTriangular(g₂.uL))
    trace_term = tr(L₂⁻¹' * L₂⁻¹ * L₁ * L₁')
    # log-determinant ratio and Mahalanobis term, summed coordinate-wise
    quad_logdet = sum(2 .* (g₂.logσ .- g₁.logσ) .+ (L₂⁻¹ * (g₂.μ .- g₁.μ)) .^ 2)
    return 2 \ (-d + trace_term + quad_logdet)
end

"""
    kl_mc(m₁, m₂; n=1000)

Monte-Carlo estimate of `KL(m₁ ‖ m₂)`: the mean log-density ratio over `n` samples
drawn from `m₁`.
"""
function kl_mc(m₁::M, m₂::M; n::Integer=1000)::T where {T, V, W, M <: Mix{T, Gaussian{T, V, W}}}
    draws = sample(m₁, n)
    log_ratio = logp_each(m₁, draws) .- logp_each(m₂, draws)
    return mean(log_ratio)
end

"""
    vi_f32(T, logα₁, logα₂, μ₁, μ₂, logσ₁, logσ₂, uL₁, uL⁻¹₁, uL⁻¹₂, x)

Evaluate the Python `runtime[:vi]` kernel and convert the scalar result to `T`.
`vi` calls this with `Float32` arrays. Every array argument except the two log-weight
vectors is passed with `PyReverseDims`. This matches the argument layout of the
`@adjoint` rule for this function.
"""
function vi_f32(::Type{T}, logα₁, logα₂, μ₁, μ₂, logσ₁, logσ₂, uL₁, uL⁻¹₁, uL⁻¹₂, x) where {T}
    # `x` is the tenth argument of the Python call (as in the adjoint rule). It must not
    # be passed to `convert`, because `convert(T, y, x)` has no method.
    y = runtime[:vi](logα₁, logα₂,
                     PyReverseDims(μ₁), PyReverseDims(μ₂),
                     PyReverseDims(logσ₁), PyReverseDims(logσ₂),
                     PyReverseDims(uL₁), PyReverseDims(uL⁻¹₁), PyReverseDims(uL⁻¹₂),
                     PyReverseDims(x))
    return convert(T, y)
end

# Zygote rule for `vi_f32`, provided by the Python `runtime[:vi_adj]`, which must return
# `(y, adj)`. `adj(g)` yields ten values, in the order of the ten array arguments. Only
# the gradients for the first mixture (logα₁, μ₁, logσ₁, uL₁, uL⁻¹₁) are kept.
@adjoint function
vi_f32(::Type{T}, logα₁, logα₂, μ₁, μ₂, logσ₁, logσ₂, uL₁, uL⁻¹₁, uL⁻¹₂, x) where {T}
   y, adj = runtime[:vi_adj](logα₁, logα₂, PyReverseDims(μ₁), PyReverseDims(μ₂), PyReverseDims(logσ₁), PyReverseDims(logσ₂), PyReverseDims(uL₁), PyReverseDims(uL⁻¹₁), PyReverseDims(uL⁻¹₂), PyReverseDims(x))

   # One cotangent per argument of `vi_f32` (11 including `T`). The values are `nothing`
   # for `T`, for every second-mixture argument and for `x`, so only the first mixture
   # receives gradients.
   # NOTE(review): the arrays reached Python via `PyReverseDims`; confirm that `adj`
   # hands back gradients in Julia (column-major) shape.
   convert(T, y), g -> let; local (∇logα₁, _, ∇μ₁, _, ∇logσ₁, _, ∇uL₁, ∇uL⁻¹₁, _, _) = adj(g)
      nothing, convert(Vector{T}, ∇logα₁), nothing, convert(Matrix{T}, ∇μ₁), nothing, convert(Matrix{T}, ∇logσ₁), nothing, convert(Array{T, 3}, ∇uL₁), convert(Array{T, 3}, ∇uL⁻¹₁), nothing, nothing
   end
end

"""
    vi(m₁, m₂, x; f32=true)

Objective computed by the Python `vi` kernel (through `vi_f32`) from mixtures `m₁`,
`m₂` and the 3-d sample array `x`. All inputs are converted to `Float32`, and only `m₁`
receives gradients (see the `vi_f32` adjoint).
NOTE(review): the kernel is presumably a variational-inference loss; confirm against
`demo_fn.vi`. The `f32` keyword is currently ignored, because there is no pure-Julia path.
"""
function vi(m₁::M, m₂::M, x::AbstractArray{T}; f32=true)::T where {T, V, W, M <: Mix{T, Gaussian{T, V, W}}}
    # Stack one per-component field across components: vectors → d×K, matrices → d×d×K.
    cols(f, m) = convert(Matrix{Float32}, stack(map(f, m.comp); dims=2))
    slabs(f, m) = convert(Array{Float32, 3}, stack(map(f, m.comp); dims=3))
    return vi_f32(T,
        convert(Vector{Float32}, m₁.logα), convert(Vector{Float32}, m₂.logα),
        cols(c -> c.μ, m₁), cols(c -> c.μ, m₂),
        cols(c -> c.logσ, m₁), cols(c -> c.logσ, m₂),
        slabs(c -> c.uL, m₁),
        slabs(c -> inv_uL(c.uL), m₁), slabs(c -> inv_uL(c.uL), m₂),
        convert(Array{Float32, 3}, x))
end

"""
    merge_mixtures(ms) -> Mix

Concatenate the log-weights and components of the mixtures in `ms` into one mixture.
The log-weights are concatenated as they are, without renormalising each input mixture.
An empty `ms` yields an empty mixture.
"""
function merge_mixtures(ms::AbstractVector{Mix{T, P}})::Mix{T, P} where {T, P}
    logα = T[a for m in ms for a in m.logα]
    comp = P[c for m in ms for c in m.comp]
    return Mix{T, P}(logα, comp)
end

"""
    sample(g::Gaussian) -> Vector

Draw one sample `μ + L ε`, with `ε ~ N(0, I)` and
`L = UnitLowerTriangular(uL) * Diagonal(exp.(logσ))`.
"""
function sample(g::Gaussian{T}) where {T}
    x = randn(T, size(g.μ))
    # Read `uL` through `UnitLowerTriangular`, as `logp`, `logp_each`, `kl` and
    # `logintprod` do, so the sampler draws from the same distribution those functions
    # describe, even if the diagonal/upper entries of `uL` are not exactly identity.
    x .= g.μ .+ UnitLowerTriangular(g.uL) * (exp.(g.logσ) .* x)
    return x
end

"""
    sample(g::Gaussian, n) -> Matrix

Draw `n` independent samples of `g`, one per column of a `length(g.μ) × n` matrix.
"""
function sample(g::Gaussian{T}, n::Integer) where {T}
    x = randn(T, length(g.μ), n)
    # Read `uL` through `UnitLowerTriangular`, as the density functions (`logp`,
    # `logp_each`, `kl`, `logintprod`) do, so samples follow the covariance they assume.
    x .= g.μ .+ UnitLowerTriangular(g.uL) * (exp.(g.logσ) .* x)
    return x
end

# Dimension of the points drawn from a distribution: the `V` parameter of a Gaussian,
# or, for a mixture, the dimension of its first component.
_dim(::Gaussian{<:Any, V}) where {V} = V
_dim(m::Mix) = _dim(first(m.comp))

"""
    sample(m::Mix, n) -> Matrix

Draw `n` samples from the mixture `m`, one per column. Each column picks a component
with probabilities `softmax(m.logα)` and is then sampled from that component only.
The previous implementation sampled all K components for every column and discarded
K-1 of them.
"""
function sample(m::Mix{T}, n::Integer) where {T}
    out = Matrix{T}(undef, _dim(m), n)
    return sample!(out, m)
end

"""
    sample!(x, g::Gaussian) -> x

Overwrite `x` in place with samples of `g`: a single sample if `x` is a vector, or one
sample per column if it is a matrix.
"""
function sample!(x, g::Gaussian)
    randn!(x)
    # Read `uL` through `UnitLowerTriangular`, matching the density functions (`logp`,
    # `logp_each`, `kl`, `logintprod`) so samples come from the distribution they define.
    x .= g.μ .+ UnitLowerTriangular(g.uL) * (exp.(g.logσ) .* x)
    return x
end

"""
    sample!(x, m::Mix) -> x

Overwrite each column of `x` in place with a sample of `m`. For every column, a component
index is drawn from `Categorical(softmax(m.logα))`, and that component fills the column.
"""
function sample!(x, m::Mix)
    weights = Categorical(softmax(m.logα))
    for col in eachcol(x)
        k = rand(weights)
        sample!(col, m.comp[k])
    end
    return x
end

"""
    sample_comp(m::Mix, n) -> Array

Draw `n` samples from every component of `m`. `out[:, k, j]` is the `j`-th sample of
component `k`.
"""
function sample_comp(m::Mix{T}, n::Integer) where {T}
    d = size(m.comp[1].μ, 1)
    out = Array{T}(undef, d, length(m.logα), n)
    for (slab, g) in zip(eachslice(out; dims=2), m.comp)
        sample!(slab, g)
    end
    return out
end

end
