using Revise, StatsBase, StatsFuns, SpecialFunctions, Distributions,
  Flux, PyCall, PyPlot, CSV, DataFrames, StaticArrays, ProgressMeter,
  Serialization, LinearAlgebra

using TreatmentCurves


using Random
Random.seed!(1337)


# insert datasets here


"""
    beta_entropy(alpha, beta)

Return the differential entropy (in nats) of the Beta(alpha, beta) distribution,

    log B(α, β) - (α - 1) ψ(α) - (β - 1) ψ(β) + (α + β - 2) ψ(α + β),

where B is the Beta function and ψ the digamma function. It is zero for the uniform
Beta(1, 1).
"""
function beta_entropy(alpha, beta)
  log_normalizer = logbeta(alpha, beta)
  alpha_term = (alpha - 1) * digamma(alpha)
  beta_term = (beta - 1) * digamma(beta)
  joint_term = (alpha + beta - 2) * digamma(alpha + beta)
  return log_normalizer - alpha_term - beta_term + joint_term
end

# ontologies are defined on the basis of projections of some real data, herein brain region volumes

# Treatment values (100 evenly spaced Float32 points from 0 to 1) at which the true and
# the bounded potential outcomes are evaluated.
potential_outcome_grid = range(0f0, 1f0, length=100)

"""
    simulate(dataset, sample_size, order=1, n_cov=1)

Simulate a semi-synthetic treatment study from the columns (points) of `dataset`.

The points are projected onto `2n_cov+1` random directions, laid out as
`(covariate..., treatment, hidden...)`. Each projected variable is mapped into (0, 1)
through its mid-rank empirical CDF (uniform marginals). A binary outcome is drawn with
probability `normcdf` of a random linear (`order == 1`) or quadratic (`order == 2`)
form of the variables. That form is standardized by its median and its mean absolute
deviation about the median. Inside the form the treatment is weighted `2n_cov` times
more than the other variables, while the recorded treatment itself stays in (0, 1).

Returns a named tuple with:
- `projections`, `mixing`: the random projection and mixing coefficients.
- `outcome_instances`: outcome given (covariates, treatment).
- `nominal_instances`: treatment given covariates.
- `complete_instances`: treatment given (covariates, hidden variables).
- `train_indices`, `test_indices`: disjoint sorted index sets of `sample_size` and
  `round(Int, 0.25sample_size)` points.
- `potential_outcomes`: the true outcome probability of every point with its treatment
  set to each value of `potential_outcome_grid`
  (`length(potential_outcome_grid) × n_points`).

Throws `ArgumentError` if `order` is not 1 or 2. Also throws `ArgumentError` if
`dataset` has fewer than `sample_size + round(Int, 0.25sample_size)` points.
"""
function simulate(dataset, sample_size, order=1, n_cov=1)
  order in (1, 2) || throw(ArgumentError("order must be 1 or 2, got $order"))
  n_dims, n_points = size(dataset)
  n_testing = round(Int, 0.25sample_size)
  sample_size + n_testing <= n_points || throw(ArgumentError(
    "dataset has $n_points points but $(sample_size + n_testing) are needed for training and testing"))
  n_vars = 2n_cov + 1
  treatment_row = n_cov + 1

  # (covariate..., treatment, hidden...)
  projections = randn(Float32, n_vars, n_dims)

  # Mid-rank ECDF, (rank - 1/2) / n_points for untied values. The values stay strictly
  # inside (0, 1), so no treatment sits on the boundary 1. A Beta log-density is
  # infinite there unless beta == 1.
  unnormalized_variables = projections * dataset
  variables = mapslices(unnormalized_variables, dims=2) do v
    ecdf(v)(v) .- 0.5 / length(v)
  end

  # Upweigh the impact of the treatment compared to the other variables. The weight is
  # applied inside the outcome model only. The recorded treatment thus stays in (0, 1),
  # which is the support of the Beta propensity targets and the range of
  # `potential_outcome_grid`.
  variable_weights = ones(n_vars)
  variable_weights[treatment_row] = 2n_cov

  mixing = if order == 1
    randn(Float32, 1, n_vars) # logit or tan transformation of the unit-interval variables?
  else
    randn(Float32, n_vars, n_vars) # Wishart would just put more constraints on the interaction terms
  end

  # Unnormalized outcome score of every column of `x` (1 × size(x, 2)).
  function latent(x)
    weighted = variable_weights .* x
    if order == 1
      return mixing * weighted
    else
      return mapslices(v -> v' * mixing * v, weighted, dims=1)
    end
  end

  unnormalized_output = latent(variables)
  output_center = median(unnormalized_output) # half & half
  output_scale = mean(abs, unnormalized_output .- output_center)
  standardize(u) = (u .- output_center) ./ output_scale
  probabilities = normcdf.(standardize(unnormalized_output)) # or sigmoid
  # interestingly, with normcdf we seem to get values closer to 0 and 1, and with sigmoid we get closer to the middle
  # tail behaviors! even though the curves don't look so different visually
  # partly stems from the fact that mean absolute deviation is smaller than stdev by a fixed fraction, for normal variables
  outcome = Float32.(rand.(Bernoulli.(probabilities)))

  # Expectations under intervention: every point's treatment set to each grid value.
  intervened = copy(variables)
  per_treatment = map(potential_outcome_grid) do t
    intervened[treatment_row, :] .= t
    normcdf.(standardize(latent(intervened)))
  end
  potential_outcomes = reduce(vcat, per_treatment) # (grid × points)

  outcome_instances = map(1:n_points) do i
    input = SVector{n_cov+1}(variables[1:treatment_row, i])
    Instance(i, input, SVector(outcome[1, i]))
  end
  nominal_instances = map(1:n_points) do i
    input = SVector{n_cov}(variables[1:n_cov, i])
    Instance(i, input, SVector(variables[treatment_row, i]))
  end
  complete_instances = map(1:n_points) do i
    input = SVector{2n_cov}(vcat(variables[1:n_cov, i], variables[(treatment_row+1):n_vars, i]))
    Instance(i, input, SVector(variables[treatment_row, i]))
  end

  shuffled = shuffle(1:n_points)
  train_indices = sort(shuffled[1:sample_size])
  test_indices = sort(shuffled[sample_size .+ (1:n_testing)])

  return (; projections, mixing, outcome_instances, nominal_instances, complete_instances,
          train_indices, test_indices, potential_outcomes)
end # uniform marginals

# Likelihood parameterizations shared by training (`infer`) and bounding (`evaluate`).
# `pred_param` is Bernoulli, for the binary outcome; `prop_param` is Beta, for the
# unit-interval treatment. The argument is presumably a scale (the comment in `infer`
# speaks of the Bernoulli scale). TODO: confirm in TreatmentCurves.
pred_param = BernoulliParam(100f0)
prop_param = BetaParam(100f0)

"""
    infer(data, n_cov=1)

Train three single-member ensembles (networks with no hidden layers) on the simulated
study `data`. Each is fit on `data.train_indices` and validated on `data.test_indices`:

- `predictor`: outcome from covariates and treatment, with the `pred_param` likelihood.
- `nominal_prop`: treatment from covariates only, with `prop_param`.
- `complete_prop`: treatment from covariates and hidden variables, with `prop_param`.

Returns:
- the three training results `predictor_t`, `nominal_t` and `complete_t`;
- the mean negative entropy of the predicted Beta propensity distributions,
  `nominal_info` and `complete_info`;
- the mean ensemble log-likelihood of each model, `*_fit`.

The means are taken over every instance, training and test alike.
"""
function infer(data, n_cov=1)
  shallow(n_in, n_out) = [ NeuralPredictor(n_in, n_out, n_layers=0, n_inner_dims=0, dropout=0.0f0) ]
  predictor = shallow(n_cov + 1, 1)
  nominal_prop = shallow(n_cov, 2)
  complete_prop = shallow(2n_cov, 2)

  # Every model is trained identically; only the instances and the likelihood differ.
  # The Bernoulli scale basically scales with the learning rate, but it is harmonized
  # with the propensities here.
  batch_size = div(length(data.train_indices), 4)
  fit_models!(models, instances, param) = train_directly!(models,
    instances[data.train_indices], param, weights=UnitScheme{Float32}(),
    n_iterations=50, optimizer=ADAM(1e1), batch_size=batch_size, verbose=false,
    validation_instances=instances[data.test_indices])

  predictor_t = fit_models!(predictor, data.outcome_instances, pred_param)
  nominal_t = fit_models!(nominal_prop, data.nominal_instances, prop_param)
  complete_t = fit_models!(complete_prop, data.complete_instances, prop_param)

  # NOTE(review): the summaries below read the vectors that were passed to
  # `train_directly!`, whereas `evaluate` reads `*_t.predictors`. Confirm that both
  # refer to the same trained models.
  function negentropy(prop, instances)
    mean(instances) do instance
      predicted = map_prediction(prop, prop_param, instance.input) do (alpha, beta)
        -beta_entropy(alpha, beta) # the uniform Beta(1, 1) has zero entropy
      end
      only(predicted)
    end
  end
  mean_log_likelihood(models, param, instances) =
    mean(instance -> ensemble_log_likelihood(models, param, instance), instances)

  nominal_info = negentropy(nominal_prop[1], data.nominal_instances)
  complete_info = negentropy(complete_prop[1], data.complete_instances)
  predictor_fit = mean_log_likelihood(predictor, pred_param, data.outcome_instances)
  nominal_fit = mean_log_likelihood(nominal_prop, prop_param, data.nominal_instances)
  complete_fit = mean_log_likelihood(complete_prop, prop_param, data.complete_instances)

  return (; predictor_t, nominal_t, complete_t, nominal_info, complete_info,
          predictor_fit, nominal_fit, complete_fit)
end

# Trust models to compare. Each is paired with the keyword options that `evaluate`
# forwards to `bound_potential_outcome` as `trust_options`.
trusts = [
  PolyTrustAdapter(BalancedBetaTrust{Float32}, [0f0, 1f0]) => NamedTuple(),
  DirectFactorTrust(1f-4) => NamedTuple(),
  SimpleTrust{Float32}() => NamedTuple(),
  BernoulliTrust{Float32}() => (trigger=0.5f0,),
]

# Increasing sensitivity values (100 Float32 points from 1 to 2.5) at which every trust
# model is evaluated. `run` picks the first one that reaches the target coverage.
sensitivities = range(1f0, stop=2.5f0, length=100)

"""
    integrate_divergences(truth, lower, upper)

Integrate the Kullback-Leibler divergence KL(Bernoulli(truth) ‖ Bernoulli(p)) over `p`
from `lower` to `upper`, in closed form. With h(x) = x log x (and h(0) = 0) the
integral is

    (upper - lower) * (1 + h(truth) + h(1 - truth))
      - truth * (h(upper) - h(lower))
      + (1 - truth) * (h(1 - upper) - h(1 - lower))

Behavior at the edges:
- Bounds exactly at 0 or 1 give a finite result.
- NaN inputs propagate to a NaN result.
- Arguments outside [0, 1] make `log` throw a `DomainError`.
- If `lower > upper`, the result is the negated integral over [upper, lower].
"""
function integrate_divergences(truth, lower, upper)
  # x log x, continuously extended by 0 at x = 0
  plogp(x) = (y = float(x); iszero(y) ? zero(y) : y * log(y))
  width = upper - lower
  # Antiderivatives: ∫ log(t/p) dp = p log(t/p) + p and
  # ∫ log((1-t)/(1-p)) dp = -(1-p) log((1-t)/(1-p)) + p; expanded into h(⋅) terms.
  return width * (1 + plogp(truth) + plogp(1 - truth)) -
    truth * (plogp(upper) - plogp(lower)) +
    (1 - truth) * (plogp(1 - upper) - plogp(1 - lower))
end

"""
    evaluate(data, models, n_cov=1, trust_order=2)

Bound the potential outcome over the test points for every combination of a trust
model in `trusts`, a sensitivity in `sensitivities` and a treatment in
`potential_outcome_grid`. Score each bound against the mean true potential outcome of
those points.

Returns three arrays indexed by (trust × sensitivity × treatment):
- `coverage`: whether the truth lies strictly between the bounds.
- `ignorance`: logit of the upper bound minus logit of the lower bound.
- `divergence`: `integrate_divergences` of the truth over the bounds.
"""
function evaluate(data, models, n_cov=1, trust_order=2)
  outcome_models = models.predictor_t.predictors
  propensity_models = models.nominal_t.predictors

  # Lower and upper bound (the first two returned values) for one configuration.
  bound(trust, sensitivity, treatment) = bound_potential_outcome(
    identity, data.outcome_instances[data.test_indices],
    outcome_models, pred_param, propensity_models, prop_param,
    UnitScheme{Float32}(), first(trust), n_cov + 1;
    n_resamples=1, n_draws=2, static_draws=[0f0, 1f0], order=trust_order,
    generator=Bernoulli(0.5), trust_options=last(trust),
    sensitivity=sensitivity, treatment=treatment, verbose=false)[1:2]

  intervals = [ bound(trust, sensitivity, treatment)
                for trust in trusts, sensitivity in sensitivities, treatment in potential_outcome_grid ]
  lower = map(interval -> only(interval[1]), intervals)
  upper = map(interval -> only(interval[2]), intervals)

  # Mean true potential outcome over the test points, laid along the treatment axis.
  truth = reshape(mean(data.potential_outcomes[:, data.test_indices], dims=2), 1, 1, :)

  coverage = @. (lower < truth) & (upper > truth)
  ignorance = @. logit(upper) - logit(lower) # a rough width; `divergence` integrates KL-divergences instead
  divergence = integrate_divergences.(truth, lower, upper)
  return (; coverage, ignorance, divergence) # (trusts x sensitivities x treatments)
end

using Base.Threads

"""
    run(dataset; sample_size, order=1, trust_order=2, n_cov=1, n_tests, target_coverage)

Repeat the `simulate` → `infer` → `evaluate` pipeline `n_tests` times, with the
replications spread over threads.

Returns:
- `complete_info`, `nominal_info` (length `n_tests`): the propensity negative entropies.
- `fits` (3 × n_tests): the outcome, nominal and complete mean log-likelihoods.
- `score` (2 × length(trusts) × n_tests): for each trust model, the treatment-averaged
  ignorance (row 1) and divergence (row 2). Both are taken at the first sensitivity
  whose treatment-averaged coverage reaches `target_coverage`, or at the last
  sensitivity if none does.
"""
function run(dataset; sample_size, order=1, trust_order=2, n_cov=1, n_tests, target_coverage)
  complete_info = zeros(Float32, n_tests)
  nominal_info = zeros(Float32, n_tests)
  fits = zeros(Float32, 3, n_tests)
  score = zeros(Float32, 2, length(trusts), n_tests)
  progress = Progress(n_tests, desc="Running..")
  # (trust × sensitivity × treatment) -> (trust × sensitivity), averaging over treatments
  treatment_average(x) = dropdims(mean(x, dims=3), dims=3)

  @threads for t in 1:n_tests
    data = simulate(dataset, sample_size, order, n_cov)
    models = infer(data, n_cov)
    complete_info[t] = models.complete_info
    nominal_info[t] = models.nominal_info
    fits[1, t] = models.predictor_fit
    fits[2, t] = models.nominal_fit
    fits[3, t] = models.complete_fit

    results = evaluate(data, models, n_cov, trust_order)
    coverage = treatment_average(results.coverage)
    ignorance = treatment_average(results.ignorance)
    divergence = treatment_average(results.divergence)
    for trust_index in axes(coverage, 1)
      reached = findfirst(>=(target_coverage), coverage[trust_index, :])
      k = something(reached, lastindex(coverage, 2))
      score[1, trust_index, t] = ignorance[trust_index, k]
      score[2, trust_index, t] = divergence[trust_index, k]
    end
    next!(progress)
  end
  return (; complete_info, nominal_info, fits, score)
end