library(argparser)
library(magrittr)
library(rstan)
library(pracma)

source("R/simulation.R")
source("R/function/map_estimate.R")
source("R/function/get_cores.R")

# Runtime configuration ----
# Progress bars for pblapply() use the "timer" style, Stan sampling may use
# as many cores as get_cores() reports, and compiled Stan models are written
# to disk (rstan auto_write) so re-runs can reuse them.
# NOTE(review): pboptions()/pblapply() are from pbapply, which is not attached
# above — presumably one of the sourced R/ files loads it; confirm.
pboptions(type="timer")
options(mc.cores = get_cores())
rstan_options(auto_write=TRUE)

# Command-line interface ----
# Every setting is optional. argparser exposes "--em-iter" as `args$em_iter`.
# When --output is omitted, `args$output` is NA and a filename encoding all
# simulation settings is derived later in the script.
parser = arg_parser("Ablation study")
parser %<>% add_argument("--output", help="Output filename")
parser %<>% add_argument("--n", help="Sample size", default=250)
parser %<>% add_argument("--em-iter", help="Number of EM iterations", default=5)
parser %<>% add_argument("--alpha", help="Damping coefficient", default=0.0)
parser %<>% add_argument("--mu", help="Drift term", default=0)
parser %<>% add_argument("--sigma", help="State noise magnitude", default=0.3)
parser %<>% add_argument("--rx", help="Fraction of variance in x_0 explained by c", default=0)
parser %<>% add_argument("--ra", help="Fraction of variance in a explained by c", default=0)

args = parser %>% parse_args

# Simulate the cohort ----
# generate_data() presumably returns one list element per simulated patient
# (each with $covariates, $a and a state trajectory $x). The per-patient
# covariates are stacked row-wise into `covariates`. The true `a` and initial
# state x[1] are kept so the estimates can be scored and saved later.
fprintf("Generating data...\n")
data = generate_data(
    n = args$n,
    frac_x0 = args$rx,
    frac_a = args$ra,
    alpha = args$alpha,
    sigma = args$sigma,
    mu = args$mu
)
covariates = do.call(rbind, lapply(data, function(patient) patient$covariates))
a = sapply(data, function(patient) patient$a)
x0 = sapply(data, function(patient) patient$x[1])

# Initial E-step ----
# Fit the base Stan model to every patient independently and keep the full
# set of posterior draws (rstan::extract) for the first M-step.
# mu0/sigma0 are fixed here. The EM loop later replaces them with
# regression-based estimates built from the posterior of x[, 1], so they are
# presumably the prior mean/sd of the initial state.
fprintf("Initial E-step...\n")
model = stan_model("stan/opioid_response_model.stan")
ssm_results = pblapply(data, function(x) {
    fit = sampling(model,
                   data = list(
                       k = 11,
                       m = dim(x$opioid_auc)[2],
                       n = dim(x$ssm_data)[1],
                       # assumes pain is 0-based and the model wants 1-based
                       # categories (k = 11 levels) — TODO confirm
                       y = x$ssm_data$pain + 1,
                       u = x$opioid_auc,
                       beta = seq(from=-5, to=4),
                       intervals = x$ssm_data$gap,
                       mu0 = 0,
                       sigma0 = 3.622
                   ),
                   refresh = 0,
                   seed = 123456789
    )
    rstan::extract(fit)
}, cl=1)

# Iteration-0 summaries.
# NOTE(review): `x$lp` relies on `$` partial matching of rstan's `lp__` draws
# unless the model itself defines a quantity named exactly `lp` — confirm.
a_map = sapply(ssm_results, function(x) map_estimate(x$A))
logL = mean(vapply(ssm_results, function(x) mean(x$lp), numeric(1)))

estimates = list(list(
    iter = 0,
    a_map = a_map,
    a_mean = vapply(ssm_results, function(x) mean(x$A), numeric(1)),
    logL = logL
))

# Report how well the MAP estimates rank the true `a` values.
fprintf(
    "Concordance: %f, Kendall: %f, logL: %f\n",
    survival::concordance(y~x, data.frame(y=a, x=a_map))$concordance,
    cor(a, a_map, method="kendall"),
    logL
)


# EM iterations ----
# Each iteration runs two steps.
#   M-step: regress each patient's posterior mean of x[, 1] and of log(A) on
#     the covariates. The fitted values become patient-specific prior means.
#     A single shared prior sd is pooled from each patient's posterior
#     variance plus its squared regression residual.
#   E-step: refit every patient with the prior-aware Stan model using those
#     hyperparameters, then record the new estimates.
init_model = stan_model("stan/opioid_response_model_priors.stan")
for (iteration in seq_len(args$em_iter)) {
    fprintf("EM iter %d of %d...\n", iteration, args$em_iter)

    # M-step: per-patient posterior moments of x[, 1] and log(A).
    x0_mean = vapply(ssm_results, function(x) mean(x$x[, 1]), numeric(1))
    x0_var = vapply(ssm_results, function(x) var(x$x[, 1]), numeric(1))
    a_log_mean = vapply(ssm_results, function(x) mean(log(x$A)), numeric(1))
    a_log_var = vapply(ssm_results, function(x) var(log(x$A)), numeric(1))

    # Covariate regressions (Gaussian glm); fitted values are the prior means.
    x0_mean_model = glm(
        x0 ~ ., data=data.frame(
            x = covariates,
            x0 = x0_mean
        )
    )
    x0_mean_pred = predict(x0_mean_model)
    x0_residual = x0_mean_pred - x0_mean

    a_log_mean_model = glm(
        a ~ .,
        data=data.frame(
            x = covariates,
            a = a_log_mean
        )
    )
    a_log_mean_pred = predict(a_log_mean_model)
    a_log_residual = a_log_mean_pred - a_log_mean

    # Shared prior sd = sqrt(mean over patients of E[(z - prediction)^2]),
    # i.e. posterior variance plus squared residual.
    x0_sd_pred = sqrt(mean(x0_var + x0_residual^2))
    a_log_sd_pred = sqrt(mean(a_log_var + a_log_residual^2))

    # E-step: refit each patient under its covariate-informed priors.
    ssm_results = pblapply(seq_along(data), function(index) {
        patient = data[[index]]
        fit = sampling(init_model,
                       data = list(
                           k = 11,
                           m = dim(patient$opioid_auc)[2],
                           n = dim(patient$ssm_data)[1],
                           y = patient$ssm_data$pain + 1,
                           u = patient$opioid_auc,
                           beta = seq(from=-5, to=4),
                           intervals = patient$ssm_data$gap,
                           mu0 = x0_mean_pred[index],
                           sigma0 = x0_sd_pred,
                           mu_log_a = a_log_mean_pred[index],
                           sigma_log_a = a_log_sd_pred
                       ),
                       refresh = 0,
                       seed = 123456789
        )
        rstan::extract(fit)
    }, cl=1)

    # Summaries for this iteration. The log density is computed once and then
    # used both for reporting and for storage.
    # NOTE(review): `x$lp` relies on `$` partial matching of rstan's `lp__`.
    a_map = sapply(ssm_results, function(x) map_estimate(x$A))
    logL = mean(vapply(ssm_results, function(x) mean(x$lp), numeric(1)))

    fprintf(
        "Concordance: %f, Kendall: %f, logL: %f\n",
        survival::concordance(y~x, data.frame(y=a, x=a_map))$concordance,
        cor(a, a_map, method="kendall"),
        logL
    )

    # Append this iteration's record without rebuilding the whole list.
    estimates[[length(estimates) + 1]] = list(
        iter = iteration,
        a_map = a_map,
        a_mean = vapply(ssm_results, function(x) mean(x$A), numeric(1)),
        logL = logL
    )
}

# Save results ----
# Bundle the simulation settings, the simulated covariates and true values,
# and every EM iteration's estimates into one list, then write it as .rds.
result = list(
    parameters = list(
        n = args$n,
        em_iter = args$em_iter,
        alpha = args$alpha,
        sigma = args$sigma,
        mu = args$mu,
        rx = args$rx,
        ra = args$ra
    ),
    covariates = covariates,
    # Reuse the true values extracted right after simulation instead of
    # recomputing them from `data`.
    a = a,
    x0 = x0,
    estimates = estimates
)

# Default filename encodes every simulation setting.
if (is.na(args$output)) {
    args$output = sprintf(
        "data/results/ablation_n_%d_em_%d_alpha_%g_mu_%g_sigma_%g_rx_%g_ra_%g.rds",
        args$n, args$em_iter, args$alpha, args$mu, args$sigma, args$rx, args$ra
    )
}

# Create the target directory if needed. Otherwise saveRDS() would fail only
# after all of the expensive sampling above had finished.
dir.create(dirname(args$output), recursive = TRUE, showWarnings = FALSE)

saveRDS(
    result,
    args$output
)