library(magrittr)
library(pbapply)

source("R/function/get_cores.R")

pboptions(type="timer")

# Logistic function 1 / (1 + exp(-z)), applied elementwise to `z`.
sigmoid <- function(z) {
    denominator <- 1 + exp(-z)
    1 / denominator
}

# Draw one ordinal category in 0, 1, ..., length(beta) for latent value `z`.
#
# sigmoid(beta[k] - z) is used as the cumulative probability P(Y <= k - 1);
# this assumes `beta` is increasing so those probabilities are ordered
# (TODO confirm with callers). The trailing cutoff of 1 guarantees that some
# category is always selected, since runif() never returns 1.
sample_ordinal <- function(z, beta) {
    cumulative <- c(sigmoid(beta - z), 1)
    u <- runif(1)
    which(u < cumulative)[1] - 1
}

# Event times of a homogeneous Poisson process with intensity `rate` on (0, t).
#
# Exponential inter-arrival gaps are drawn one at a time until they reach the
# horizon `t`; the arrival that reaches or passes `t` is discarded. Drawing
# exactly one gap per step keeps the number of RNG draws stable under a seed.
observation_times <- function(t, rate) {
    gaps <- numeric(0)
    repeat {
        if (sum(gaps) >= t) {
            break
        }
        gaps <- c(gaps, rexp(1, rate))
    }
    arrivals <- cumsum(gaps)
    head(arrivals, -1)
}

#' Euler-Maruyama integrator for the latent-state SDE
#'
#' Integrates dX = (-alpha * X - a * u(t) + mu) dt + sigma dW from `start_t`
#' to `end_t` using `n` equal steps and returns the state at `end_t`.
#' The name `rk` is kept for existing callers; the scheme is Euler-Maruyama,
#' not Runge-Kutta.
#'
#' @param x0 State at `start_t`.
#' @param u Input function; `u(t)$y` is used as the input at time `t`
#'   (for example a closure returning the list produced by `stats::approx`).
#' @param start_t,end_t Finite integration bounds.
#' @param alpha Linear decay coefficient on the state.
#' @param a Gain applied to the input `u(t)$y`.
#' @param mu Constant drift.
#' @param sigma Diffusion (noise) scale.
#' @param n Number of integration steps.
#' @return The simulated state at `end_t`.
rk <- function(x0, u, start_t, end_t, alpha, a, mu, sigma, n = 100) {
    # The former seq()-based grid raised an error for non-finite bounds; keep that.
    if (!is.finite(start_t) || !is.finite(end_t)) {
        stop("`start_t` and `end_t` must be finite numbers", call. = FALSE)
    }
    dt <- (end_t - start_t) / n
    x <- x0

    for (i in seq_len(n)) {
        dw <- rnorm(1, sd = sqrt(dt))
        # Left endpoint of step i, computed directly. seq(by = dt) returns a
        # single point when dt is 0 or negligible relative to start_t, which
        # left every later step with t = NA and turned the state into NA.
        t_i <- start_t + (i - 1) * dt
        x <- x + dt * (-alpha * x - a * u(t_i)$y + mu) + sigma * dw
    }
    x
}

# Helper for simulation(): trapezoidal area under a piecewise-linear curve,
# split at consecutive observation times.
#
# `times`/`values` are the curve's knots (linearly interpolated, treated as 0
# outside their range). Element k of the result is the area from the previous
# boundary (the first grid point when k == 1) up to obs_t[k].
.interval_auc <- function(times, values, obs_t) {
    # Integration grid: 0, every observation time and every curve knot.
    grid <- sort(unique(c(0, obs_t, times)))
    curve <- approx(times, values, grid, yleft = 0, yright = 0)$y
    n_grid <- length(grid)

    widths <- grid[-1] - grid[-n_grid]
    heights <- (curve[-1] + curve[-n_grid]) / 2
    # Element j is the area from grid[j - 1] to grid[j]; the leading 0 lets the
    # first interval start at grid[1] without special-casing.
    segment_area <- c(0, widths * heights)

    # obs_t values are copied verbatim into `grid`, so exact matching is safe.
    end_idx <- match(obs_t, grid)
    start_idx <- c(1, end_idx[-length(end_idx)] + 1)

    vapply(
        seq_along(end_idx),
        function(k) sum(segment_area[start_idx[k]:end_idx[k]]),
        numeric(1)
    )
}

#' Simulate one synthetic patient
#'
#' Draws demographics, observation and hydromorphone dosing schedules over a
#' 1440-unit horizon (presumably minutes -- confirm), computes the effect-site
#' opioid curve with stanpumpR, integrates the latent state between
#' observations with rk(), and samples an ordinal pain score at each one.
#'
#' @param a Input gain passed to rk().
#' @param x0 Latent state at time 0.
#' @param covariates Not used in the simulation; returned unchanged.
#' @param alpha,sigma,mu Dynamics parameters passed to rk().
#' @param min_obs,min_opioids Schedules are redrawn until they contain
#'   strictly MORE than this many observation / dose times.
#' @return A list with `a`, `x` (x0 followed by the state at each observation),
#'   `ssm_data` (data frame of `t`, `pain`, `gap`), `opioid_auc` (one-column
#'   array of effect-site AUC per observation interval), `covariates`,
#'   `opioid_t`, `age`, `height`, `weight` and `sex`.
simulation <- function(a, x0, covariates, alpha, sigma, mu, min_obs = 10, min_opioids = 2) {
    # NOTE: the order of random draws below determines results under the
    # per-patient seeds set by the caller; do not reorder.
    age <- rnorm(1, 54.90201, 17.25901)              # Age in years
    height <- rnorm(1, 66.60628, 4.265764) * 2.54    # Height in cm
    weight <- rnorm(1, 88.1026, 20.87864)            # Weight in kg
    sex <- if (runif(1) < 0.47) "Male" else "Female"

    # Redraw both schedules until each is long enough (strict inequality).
    repeat {
        obs_t <- observation_times(1440, 10 / 1440)
        opioid_t <- observation_times(1440, 5 / 1440)
        if (length(obs_t) > min_obs && length(opioid_t) > min_opioids) {
            break
        }
    }

    dose_table <- data.frame(
        Drug = "hydromorphone",
        Time = opioid_t,
        Dose = 0.5,
        Units = "mg"
    )
    event_table <- data.frame(Time = double(), Event = character(), Fill = character())
    output <- stanpumpR::simulateDrugsWithCovariates(
        dose_table, event_table, weight, height, age, sex, 1441, FALSE
    )
    opioid_data <- output[[1]]$results %>% dplyr::filter(Site == "Effect Site")

    # Input to the latent dynamics: linear interpolation of the effect-site
    # curve, 0 outside its time range.
    u <- function(t) approx(opioid_data$Time, opioid_data$Y, t, yleft = 0, yright = 0)

    start_times <- c(0, obs_t[-length(obs_t)])
    end_times <- obs_t

    # Integrate interval by interval, carrying the state forward.
    x <- rep(NA_real_, length(obs_t))
    current_x <- x0
    for (i in seq_along(obs_t)) {
        current_x <- rk(current_x, u, start_times[i], end_times[i],
                        alpha = alpha, a = a, sigma = sigma, mu = mu)
        x[i] <- current_x
    }

    thresholds <- seq(from = -5, to = 4)
    y <- vapply(x, function(z) sample_ordinal(z, thresholds), numeric(1))

    auc <- .interval_auc(opioid_data$Time, opioid_data$Y, obs_t)
    opioid_auc <- array(auc, dim = c(length(auc), 1))

    ssm_data <- data.frame(
        t = obs_t,
        pain = y,
        gap = obs_t - start_times
    )

    list(
        a = a,
        x = c(x0, x),
        ssm_data = ssm_data,
        opioid_auc = opioid_auc,
        covariates = covariates,
        opioid_t = opioid_t,
        age = age,
        height = height,
        weight = weight,
        sex = sex
    )
}

#' Generate a reproducible synthetic cohort by repeated calls to simulation()
#'
#' Covariates (9 standard normals per patient), initial states `x0` and input
#' gains `a` are drawn for all patients under a fixed seed. Each patient is
#' then simulated under its own seed (1234567890 + i) and re-simulated until
#' it passes the acceptance checks: more than one distinct pain score,
#' positive total opioid AUC, and more than two pain scores other than 0
#' and 10.
#'
#' @param n Number of patients.
#' @param frac_x0 Share of the variance of `x0` (total 5) explained by the
#'   covariate sum; must lie in [0, 1].
#' @param frac_a Share of the variance of `log(a)` (total 2.25, mean -5)
#'   explained by the covariate sum; must lie in [0, 1].
#' @param alpha,sigma,mu,min_obs,min_opioids Passed to simulation().
#' @param max_attempts Maximum simulations per patient before stopping with an
#'   error. The default `Inf` retries forever, as before.
#' @return A list of length `n` of simulation() results.
#' @note Calls set.seed(), so the global RNG state is changed on return.
generate_data <- function(n = 1000, frac_x0, frac_a, alpha, sigma, mu, min_obs = 10, min_opioids = 2,
                          max_attempts = Inf) {
    check_fraction <- function(value, name) {
        if (!is.numeric(value) || anyNA(value) || any(value < 0 | value > 1)) {
            stop(sprintf("`%s` must be numeric with all values in [0, 1]", name), call. = FALSE)
        }
    }
    # Out-of-range shares make the sqrt() terms below NaN. NaN latent states
    # yield constant pain scores, so the acceptance loop could never succeed.
    check_fraction(frac_x0, "frac_x0")
    check_fraction(frac_a, "frac_a")

    set.seed(1234567890)
    covariates <- matrix(rnorm(n * 9), nrow = n, ncol = 9)
    cov_sum <- rowSums(covariates)
    x0 <- cov_sum * sqrt(5 * frac_x0 / 9) + rnorm(n, sd = sqrt(5 * (1 - frac_x0)))
    log_a <- -5 - cov_sum * sqrt(2.25 * frac_a / 9) + rnorm(n, sd = sqrt(2.25 * (1 - frac_a)))
    a <- exp(log_a)

    accepted <- function(sim) {
        pain <- sim$ssm_data$pain
        length(unique(pain)) > 1 &&
            sum(sim$opioid_auc) > 0 &&
            sum(pain != 0 & pain != 10) > 2
    }

    pblapply(seq_len(n), function(i) {
        set.seed(1234567890 + i)
        attempts <- 0
        while (attempts < max_attempts) {
            attempts <- attempts + 1
            sim <- simulation(a[i], x0[i], covariates[i, ], alpha = alpha, sigma = sigma, mu = mu,
                              min_obs = min_obs, min_opioids = min_opioids)
            if (accepted(sim)) {
                return(sim)
            }
        }
        stop(sprintf("patient %d was not accepted after %s attempts", i, format(max_attempts)),
             call. = FALSE)
    }, cl = 1)
}

