using Revise, StatsBase, StatsFuns, SpecialFunctions, Distributions,
  Flux, PyCall, PyPlot, CSV, DataFrames, StaticArrays, ProgressMeter,
  Serialization, LinearAlgebra

using TreatmentCurves


# Input variables, grouped by how they are encoded. The groups are concatenated in this
# order (binary, ordinal, real, bounded) to form the rows of `inputs`, and each group keeps
# its list order, so editing or reordering these lists shifts the rows that the
# treatment/covariate index sets refer to. With the lists as written, `inputs` has 40 rows:
#   binary 1-16, ordinal 17-22, real 23-30, bounded 31-40.
# Every name must be a column of `df`.

# Binary columns, used as-is after conversion to Float32.
binary_variables = [ "sex", "college", "sleep_7to9_hrs", "daytime_dozing_never", "night_shift_work",
  "whr_normal", "vitamin_mineral_intake", "supp_fish_oil", "never_varies_diet", "leisure_social_sports_or_gym",
  "leisure_social_pub_or_social", "leisure_social_religious_group", "leisure_social_adult_edu_class",
  "leisure_social_other_group", "friend_family_visits_2x_a_week_or_more", "able_to_confide_daily" ]

# Ordinal columns as `name => levels`. The position of a level in its list defines its rank
# (first = lowest): each value is replaced by that position and then by the empirical CDF of
# the positions. Every value in the column must appear among the listed levels.
ordinal_variables = ["e4_copies" => [0, 1, 2], "smoking" => ["Never", "Previous", "Current"],
  "alcohol" => ["Infrequent", "Occasional", "Frequent"], "aha_physical_activity" => ["poor", "intermediate", "ideal"],
  "salt_added_to_food" => ["never_rarely", "sometimes", "usually", "always"], "variation_in_diet" => ["never_rarely", "sometimes", "often"] ]

# Continuous columns, z-scored.
real_variables = [ "townsend", "age_at_scan", "bmi", "body_fat_percentage", "waist_to_hip_ratio", "tea_intake", "coffee_intake", "water_intake" ]

# DQS component columns, min-max scaled to [0, 1].
bounded_variables = [ "dqs_comp_1_fruits", "dqs_comp_2_vegetables", "dqs_comp_3_whole_grains", "dqs_comp_4_fish", "dqs_comp_5_dairy",
  "dqs_comp_6_vegetable_oil", "dqs_comp_7_refined_grains", "dqs_comp_8_processed_meat", "dqs_comp_9_unprocessed_meat", "dqs_comp_10_sugary_foods_drinks" ]

# Output (target) variables. Concatenated in this order to form the rows of `outputs`:
# with the lists as written, 34 cortical + 13 subcortical + 1 other = 48 rows.
# Every name must be a column of `df`.

# Cortical columns, z-scored.
cortical_outputs = [ "AVG_bankssts_thickavg", "AVG_caudalanteriorcingulate_thickavg", "AVG_caudalmiddlefrontal_thickavg",
  "AVG_cuneus_thickavg", "AVG_entorhinal_thickavg", "AVG_fusiform_thickavg", "AVG_inferiorparietal_thickavg", "AVG_inferiortemporal_thickavg",
  "AVG_isthmuscingulate_thickavg", "AVG_lateraloccipital_thickavg", "AVG_lateralorbitofrontal_thickavg", "AVG_lingual_thickavg",
  "AVG_medialorbitofrontal_thickavg", "AVG_middletemporal_thickavg", "AVG_parahippocampal_thickavg", "AVG_paracentral_thickavg",
  "AVG_parsopercularis_thickavg", "AVG_parsorbitalis_thickavg", "AVG_parstriangularis_thickavg", "AVG_pericalcarine_thickavg",
  "AVG_postcentral_thickavg", "AVG_posteriorcingulate_thickavg", "AVG_precentral_thickavg", "AVG_precuneus_thickavg",
  "AVG_rostralanteriorcingulate_thickavg", "AVG_rostralmiddlefrontal_thickavg", "AVG_superiorfrontal_thickavg", "AVG_superiorparietal_thickavg",
  "AVG_superiortemporal_thickavg", "AVG_supramarginal_thickavg", "AVG_frontalpole_thickavg", "AVG_temporalpole_thickavg",
  "AVG_transversetemporal_thickavg", "AVG_insula_thickavg" ]

# Subcortical columns. When `outputs` is built, each is divided by the "ICV" column of `df`
# before z-scoring.
subcortical_outputs = [ "AVG_Lateral_Ventricle", "AVG_Inf_Lat_Vent", "AVG_Cerebellum_White_Matter",
  "AVG_Cerebellum_Cortex", "AVG_Thalamus", "AVG_Caudate", "AVG_Putamen", "AVG_Pallidum", "AVG_Hippocampus", "AVG_Amygdala",
  "AVG_Accumbens_area", "AVG_VentralDC", "AVG_choroid_plexus" ]

# Remaining outputs, z-scored.
other_outputs = [ "brainage_bc" ]

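## NOTE: the UK Biobank data cannot be distributed with this script due to licensing restrictions.
# Before running, load it as a DataFrame into a global variable `df` that contains every column
# named in the lists above, plus "ICV" (used to normalise the subcortical outputs).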

# Assemble the model inputs as an (n_variables × n_patients) Float32 matrix. The groups are
# concatenated in this fixed order (binary, ordinal, real, bounded), which determines the row
# of every input variable; the treatment/covariate index sets built later depend on it.
#   binary_variables  - values converted to Float32 as-is
#   ordinal_variables - each value replaced by the position of its level in the category list,
#                       then by the empirical CDF of those positions (the share of patients at
#                       or below that level); an unlisted value raises an ArgumentError
#   real_variables    - z-scored
#   bounded_variables - min-max scaled to [0, 1]; a constant column becomes all zeros
# `reduce(hcat, vector_of_columns)` uses Base's specialised single-allocation concatenation.
inputs = hcat(
  reduce(hcat, map(binary_variables) do name
    df[:, name] .|> Float32
  end),
  reduce(hcat, map(ordinal_variables) do (name, categories)
    col = df[:, name]
    positions = indexin(col, categories)
    # `indexin` yields `nothing` for any value not listed in `categories` (a typo, an unseen
    # level, `missing`); report the column and value instead of failing inside `Int(nothing)`.
    bad = findfirst(isnothing, positions)
    bad === nothing || throw(ArgumentError(
      "ordinal column \"$name\" has value $(repr(col[bad])) not in $(repr(categories))"))
    ranks = Int.(positions)
    ecdf(ranks)(ranks) .|> Float32
  end),
  reduce(hcat, map(real_variables) do name
    zscore(df[:, name] .|> Float32)
  end),
  reduce(hcat, map(bounded_variables) do name
    col = df[:, name] .|> Float32
    lo, hi = minimum(col), maximum(col)
    # A constant column would otherwise give 0/0 = NaN for every patient. The `==` test (not
    # `>=`) is false for NaN, so columns containing NaN still propagate it as before.
    hi == lo ? zero(col) : (col .- lo) ./ (hi - lo)
  end) ) |> permutedims

# Assemble the targets as an (n_outputs × n_patients) Float32 matrix, with the groups
# stacked in order: cortical, subcortical, other. Cortical and other columns are converted
# to Float32 and then z-scored. Subcortical columns are first divided by the "ICV" column,
# z-scored at full precision, and only then converted to Float32.
outputs = let
  zscored_f32(name) = zscore(Float32.(df[:, name]))
  cortical = reduce(hcat, map(zscored_f32, cortical_outputs))
  subcortical = reduce(hcat, map(subcortical_outputs) do name
    Float32.(zscore(df[:, name] ./ df[:, "ICV"]))
  end)
  other = reduce(hcat, map(zscored_f32, other_outputs))
  permutedims(hcat(cortical, subcortical, other))
end

# Problem sizes: both matrices hold one variable per row and one patient per column.
n_inputs, n_patients = size(inputs)
n_outputs = size(outputs, 1)

# One `Instance` per patient column `p`, pairing that patient's inputs with their outputs
# as fixed-size static vectors.
instances = [
  Instance(p, SVector{n_inputs}(inputs[:, p]), SVector{n_outputs}(outputs[:, p]))
  for p in 1:n_patients ]

# Row names of `inputs`, in the order its variable groups are concatenated when it is built.
# Treatment rows are looked up by name rather than by offset arithmetic from `n_inputs`, so
# they stay correct if a variable list is edited.
input_names = vcat(binary_variables, first.(ordinal_variables), real_variables, bounded_variables)

"""
    input_rows(names) -> Vector{Int}

Return the row of `inputs` holding each variable in `names`, in the same order.
Throws `ArgumentError` if a name is not an input variable.
"""
function input_rows(names)
  return map(names) do name
    row = findfirst(==(name), input_names)
    row === nothing && throw(ArgumentError("\"$name\" is not an input variable"))
    row
  end
end

"""
    split_treatments(inputs, treatment_indices) -> (covariate_indices, n_treatments, treatments)

Split the rows of `inputs` into treatments (the rows in `treatment_indices`, in that order)
and covariates (all other rows, ascending). Build one `Instance(i, covariates, treatments)`
per patient column `i`, with both parts stored as `SVector`s.
"""
function split_treatments(inputs, treatment_indices)
  n_rows, n_cols = size(inputs)
  # `setdiff` preserves the order of `1:n_rows`, so the covariate rows are already sorted.
  covariate_indices = setdiff(1:n_rows, treatment_indices)
  n_treatments = length(treatment_indices)
  treatments = map(1:n_cols) do i
    Instance(i,
      SVector{n_rows - n_treatments}(inputs[covariate_indices, i]),
      SVector{n_treatments}(inputs[treatment_indices, i]))
  end
  return covariate_indices, n_treatments, treatments
end

# DQS components 1-7 and 9; components 8 and 10 remain among the covariates.
treatment_indices = input_rows(bounded_variables[[1:7; 9]])
covariate_indices, n_treatments, treatments = split_treatments(inputs, treatment_indices)

# exercise, bmi, body fat, waist-hip
treatment_indices_2 = input_rows(["aha_physical_activity", "bmi", "body_fat_percentage", "waist_to_hip_ratio"])
covariate_indices_2, n_treatments_2, treatments_2 = split_treatments(inputs, treatment_indices_2)

# alcohol, salt, tea, coffee, water
treatment_indices_3 = input_rows(["alcohol", "salt_added_to_food", "tea_intake", "coffee_intake", "water_intake"])
covariate_indices_3, n_treatments_3, treatments_3 = split_treatments(inputs, treatment_indices_3)


using Random
Random.seed!(1337)

# Cross-validation folds. Patient ids are shuffled once with the seeded global RNG, and the
# shuffled order is cut at `splits` into `n_splits` contiguous chunks of near-equal size.
# Fold k holds out chunk k as `test_indices[k]` and trains on every other patient
# (`train_indices[k]`). Both index lists are stored in ascending order.
indices = shuffle(1:n_patients)
n_splits = 4
splits = round.(Int, range(0, stop=n_patients, length=n_splits+1))
test_indices = map(1:n_splits) do k
  sort(indices[splits[k]+1:splits[k+1]])
end
train_indices = map(test_indices) do held_out
  sort(setdiff(1:n_patients, held_out))
end

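# Example training runs (disabled). Column `i` of each 16×4 predictor grid is trained on
# cross-validation fold `i` (`train_indices[i]`) and validated on `test_indices[i]`.
# The hard-coded sizes match the data built above: 40 = n_inputs and 48 = n_outputs for
# the outcome model; 32 = n_inputs - n_treatments and 8 = n_treatments for the DQS
# treatment model.
# NOTE(review): the `*2` factor presumably gives two distribution parameters per output — confirm.
# predictors = [NeuralPredictor(40, 48*2, n_layers=2, n_inner_dims=32, dropout=0.1f0) for _ in 1:16, _ in 1:4]
# pred_t = [ train_directly!(@view(predictors[:,i]), instances[train_indices[i]], GaussianParam{Float32}(), weights=BootstrapScheme{Float32}(), n_iterations=10_000, optimizer=ADAM(5e-3), batch_size=div(length(train_indices[i]), 10), verbose=true, validation_instances=instances[test_indices[i]], reg_weight=1f-3) for i in 1:4 ]

# linear_p = [NeuralPredictor(32, 8*2, n_layers=0, n_inner_dims=8, dropout=0.05f0) for _ in 1:16, _ in 1:4]
# prop_t = [ train_directly!(@view(linear_p[:,i]), treatments[train_indices[i]], BetaParam(64f0), weights=BootstrapScheme{Float32}(), n_iterations=10_000, optimizer=ADAM(1e-2), batch_size=div(length(train_indices[i]), 10), verbose=true, validation_instances=treatments[test_indices[i]], reg_weight=1f-3) for i in 1:4 ]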
