library(argparser)
library(magrittr)
library(rstan)
library(pracma)

source("R/simulation.R")
source("R/function/map_estimate.R")
source("R/function/get_cores.R")

# Runtime configuration ----
# NOTE(review): pboptions()/pblapply() come from pbapply, which is not attached
# by the library() calls above -- presumably R/simulation.R loads it; confirm.
# Progress bars show a timer (elapsed / estimated remaining time).
pboptions(type = "timer")
# Default core count for parallel back-ends; rstan::sampling() uses
# getOption("mc.cores") as its default number of cores for chains.
options(mc.cores = get_cores())
# Cache the compiled Stan model on disk so reruns skip recompilation.
rstan_options(auto_write = TRUE)

# Command-line interface ----
# --output has no default: when omitted it is NA and a file name derived from
# the simulation settings is generated just before saving.
parser <- arg_parser("Simulation study") %>%
    add_argument("--output", help = "Output filename") %>%
    add_argument("--n", help = "Sample size", default = 1000) %>%
    add_argument("--alpha", help = "Damping coefficient", default = 0) %>%
    add_argument("--mu", help = "Drift term", default = 0) %>%
    add_argument("--sigma", help = "State noise magnitude", default = 0.1)

args <- parse_args(parser)

# Simulated data ----
fprintf("Generating data...\n")
# NOTE(review): arguments are positional; their meaning (including the two
# literal zeros and the alpha, sigma, mu order) is defined by generate_data()
# in R/simulation.R -- confirm against that signature.
data <- generate_data(
    args$n, 0, 0,
    args$alpha, args$sigma, args$mu
)

# Initial E-step ----
# Fit the Stan response model to every simulated dataset and keep the
# posterior draws returned by rstan::extract() for each fit.
fprintf("Initial E-step...\n")
model <- stan_model("stan/opioid_response_model.stan")
ssm_results <- pblapply(data, function(x) {
    stan_data <- list(
        k = 11,
        m = ncol(x$opioid_auc),
        n = nrow(x$ssm_data),
        # NOTE(review): +1 shift presumably maps pain scores onto the model's
        # 1-based categories (k = 11); confirm against the .stan file.
        y = x$ssm_data$pain + 1,
        u = x$opioid_auc,
        # 10 fixed values (-5..4); presumably cut-points for the k categories.
        beta = seq(from = -5, to = 4),
        intervals = x$ssm_data$gap,
        mu0 = 0,
        sigma0 = 3.622
    )
    fit <- sampling(
        model,
        data = stan_data,
        refresh = 0,       # suppress per-iteration sampler output
        seed = 123456789   # same seed for every dataset, for reproducibility
    )
    rstan::extract(fit)
}, cl = 1)


# Evaluation ----
# Compare the true per-dataset `a` with posterior summaries of `A`. Each
# summary is computed once and reused for both the console report and the
# saved result, so the two can never disagree. vapply() enforces one numeric
# value per dataset instead of silently changing shape as sapply() can.
a <- vapply(data, function(x) x$a, numeric(1))
a_map <- vapply(ssm_results, function(x) map_estimate(x$A), numeric(1))
a_mean <- vapply(ssm_results, function(x) mean(x$A), numeric(1))

fprintf(
    "Concordance: %f, Kendall: %f\n",
    survival::concordance(y ~ x, data.frame(y = a, x = a_map))$concordance,
    cor(a, a_map, method = "kendall")
)

# Simulation settings plus per-dataset truth and estimates.
result <- list(
    parameters = list(
        n = args$n,
        alpha = args$alpha,
        sigma = args$sigma,
        mu = args$mu
    ),
    a = a,
    a_map = a_map,
    a_mean = a_mean
)

# Save results ----
# Without --output, derive a file name from the simulation settings.
# sprintf()'s %d requires args$n to be whole-valued.
if (is.na(args$output)) {
    args$output <- sprintf(
        "data/results/simulation_n_%d_alpha_%g_mu_%g_sigma_%g.rds",
        args$n, args$alpha, args$mu, args$sigma
    )
}

# saveRDS() does not create missing directories; make sure the target exists
# so a long simulation run is not lost at the final write.
dir.create(dirname(args$output), recursive = TRUE, showWarnings = FALSE)
saveRDS(result, args$output)