# Exponential families
include("binary_search.jl");

# Gaussian family parameterised by a single variance value.
struct Gaussian
    σ2  # variance; the field carries no type annotation, so any value is accepted
end

# Default Gaussian, simplify computations: the zero-argument constructor uses σ2 = 1.
function Gaussian()
    return Gaussian(1)
end

# Sample
# Draw one value `μ + sqrt(σ2) * z`, where `z` is a single `randn(rng)` draw.
function sample(rng, expfam::Gaussian, μ)
    scale = sqrt(expfam.σ2)  # standard deviation, computed before touching the RNG
    return μ + scale * randn(rng)
end

# KL divergence
# Squared distance between the means `μ` and `λ`, divided by twice the variance.
function d(expfam::Gaussian, μ, λ)
    gap = μ - λ
    return gap^2 / (2 * expfam.σ2)
end

# KL derivatives
# Partial derivative of the Gaussian `d` with respect to its first mean argument `μ`.
function dµ_d(expfam::Gaussian, μ, λ)
    return (μ - λ) / expfam.σ2
end

# Partial derivative of the Gaussian `d` with respect to its second mean argument `λ`
# (the negative of the derivative in `μ`).
function dλ_d(expfam::Gaussian, μ, λ)
    return (λ - μ) / expfam.σ2
end

# upward and downward confidence intervals (box confidence region)
# Upper endpoint: the λ ≥ μ at which the Gaussian `d(expfam, μ, λ)` equals `v`.
function dup(expfam::Gaussian, μ, v)
    halfwidth = sqrt(2 * expfam.σ2 * v)
    return μ + halfwidth
end

# Lower endpoint: the λ ≤ μ at which the Gaussian `d(expfam, μ, λ)` equals `v`.
function ddn(expfam::Gaussian, μ, v)
    halfwidth = sqrt(2 * expfam.σ2 * v)
    return μ - halfwidth
end


# Bernoulli family marker type: it has no fields and is used only for dispatch.
struct Bernoulli end

# Relative-entropy building blocks used by the Bernoulli methods below.

# x * log(x / y), with the value 0 at x == 0 (whatever y is).
rel_entr(x, y) = x == 0 ? 0. : x * log(x / y)

# log(x / y), with the value 0 at x == 0. This is the x-derivative of `rel_entr`
# without its constant +1 term; the Bernoulli `dµ_d` takes a difference of two such
# terms, in which that constant cancels.
dx_rel_entr(x, y) = x == 0 ? 0. : log(x / y)

# y-derivative of `rel_entr`, i.e. -x / y. At x == 0 `rel_entr` is the constant 0, so
# the derivative is 0; returning it explicitly avoids the NaN that -0/0 would give
# when y == 0 as well.
dy_rel_entr(x, y) = x == 0 ? 0. : -x / y

# Bernoulli divergence between means `μ` and `λ`: the sum of the two `rel_entr` terms
# (success and failure outcomes), floored at 0.
# NOTE(review): the floor presumably guards against tiny negative rounding — confirm.
function d(expfam::Bernoulli, μ, λ)
    kl = rel_entr(μ, λ) + rel_entr(1 - μ, 1 - λ)
    return max(0, kl)
end
# Derivative in `μ` of the unclamped Bernoulli divergence, assembled from `dx_rel_entr`
# on the success term minus `dx_rel_entr` on the failure term.
function dµ_d(expfam::Bernoulli, μ, λ)
    return dx_rel_entr(μ, λ) - dx_rel_entr(1 - μ, 1 - λ)
end

# Derivative in `λ` of the unclamped Bernoulli divergence, assembled from `dy_rel_entr`
# on the success term minus `dy_rel_entr` on the failure term.
function dλ_d(expfam::Bernoulli, μ, λ)
    return dy_rel_entr(μ, λ) - dy_rel_entr(1 - μ, 1 - λ)
end
# Closed-form expression 2μ / (1 - x + sqrt((x - 1)^2 + 4xμ)).
# NOTE(review): presumably the inverse of some map `h` defined elsewhere — confirm.
function invh(expfam::Bernoulli, μ, x)
    root = sqrt((x - 1)^2 + 4 * x * μ)
    return 2 * μ / (1 - x + root)
end
# Draw one Boolean outcome: `true` exactly when a single `rand(rng)` draw is ≤ `μ`.
function sample(rng, expfam::Bernoulli, μ)
    u = rand(rng)
    return u ≤ μ
end

# Upper endpoint for the Bernoulli family: returns 1.0 outright when μ == 1, otherwise
# hands `λ -> d(expfam, μ, λ) - v` to `binary_search` on the bracket (μ, 1).
function dup(expfam::Bernoulli, μ, v)
    if μ == 1
        return 1.
    end
    excess(λ) = d(expfam, μ, λ) - v
    return binary_search(excess, μ, 1)
end

# Lower endpoint for the Bernoulli family: returns 0.0 outright when μ == 0, otherwise
# hands `λ -> v - d(expfam, μ, λ)` to `binary_search` on the bracket (0, μ).
function ddn(expfam::Bernoulli, μ, v)
    if μ == 0
        return 0.
    end
    slack(λ) = v - d(expfam, μ, λ)
    return binary_search(slack, 0, μ)
end

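# ----------------------------------------------------------------------
# Specifying variance
# ----------------------------------------------------------------------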

# Increasing factor
# Squared gap between arms `astar` and `a`, divided by the smaller of their variances.
C_cst(μ, σ2, astar, a) = (μ[astar] - μ[a])^2 / min(σ2[a], σ2[astar])

# Largest pairwise `C_cst` over every arm other than `astar`. Throws if there is no
# other arm, since `maximum` cannot reduce an empty collection.
function C_cst(μ, σ2, astar)
    return maximum(C_cst(μ, σ2, astar, a) for a in eachindex(μ) if a != astar)
end

# Pairwise increasing factor C / log(1 + C), with C = C_cst(μ, σ2, astar, a).
function increasing_factor(μ, σ2, astar, a)
    c = C_cst(μ, σ2, astar, a)  # computed once and reused
    # C / log(1 + C) tends to 1 as C -> 0; return that limit rather than 0/0 = NaN.
    iszero(c) && return one(c)
    # log1p stays accurate for tiny C, where 1 + C would round to 1 and log give 0.
    return c / log1p(c)
end

# Largest pairwise increasing factor over every arm other than `astar`.
function increasing_factor(μ, σ2, astar)
    return maximum(increasing_factor(μ, σ2, astar, a) for a in eachindex(μ) if a != astar)
end

# Additive term
# Fourth power of the gap between `hμ[astar]` and `hμ[a]`, times the larger of the two
# ratios N / hσ2^2 (taken for arm `a` and for arm `astar`), halved.
function additive_term(N, hμ, hσ2, astar, a)
    gap4 = (hμ[astar] - hμ[a])^4
    worst_ratio = max(N[a] / hσ2[a]^2, N[astar] / hσ2[astar]^2)
    return gap4 * worst_ratio / 2
end

# KL with variance specified
# Unknown-variance divergence: 0.5 * log(1 + (μ - λ)^2 / σ2).
# log1p keeps precision when the ratio is tiny, where 1 + ratio would round to 1.
dUV(μ, σ2, λ) = 0.5 * log1p((μ - λ)^2 / σ2)

# Known-variance divergence: 0.5 * (μ - λ)^2 / σ2.
dKV(μ, σ2, λ) = 0.5 * (μ - λ)^2 / σ2

# Select the known-variance (`known_var == true`) or unknown-variance form.
d(μ, σ2, λ, known_var) = known_var ? dKV(μ, σ2, λ) : dUV(μ, σ2, λ)

# upward and downward confidence intervals (box confidence region)
# Known-variance upper endpoint: μ + sqrt(2 σ2 v).
function dupKV(μ, σ2, v)
    return μ + sqrt(2 * σ2 * v)
end

# Unknown-variance upper endpoint: μ + sqrt(2 σ2 v (1 + 2v)).
# NOTE(review): this is not the algebraic inverse of `dUV`, which would use
# σ2 * (exp(2v) - 1) under the root — presumably a deliberate bound; confirm.
function dupUV(μ, σ2, v)
    return μ + sqrt(2 * σ2 * v * (1 + 2 * v))
end

# Upper endpoint, choosing the known- or unknown-variance formula via `known_var`.
function dup(μ, σ2, v, known_var)
    if known_var
        return dupKV(μ, σ2, v)
    else
        return dupUV(μ, σ2, v)
    end
end

# Known-variance lower endpoint: μ - sqrt(2 σ2 v).
function ddnKV(μ, σ2, v)
    return μ - sqrt(2 * σ2 * v)
end

# Unknown-variance lower endpoint: μ - sqrt(2 σ2 v (1 + 2v)); mirror image of `dupUV`.
function ddnUV(μ, σ2, v)
    return μ - sqrt(2 * σ2 * v * (1 + 2 * v))
end

# Lower endpoint, choosing the known- or unknown-variance formula via `known_var`.
function ddn(μ, σ2, v, known_var)
    if known_var
        return ddnKV(μ, σ2, v)
    else
        return ddnUV(μ, σ2, v)
    end
end
