## ----setup, include = FALSE---------------------------------------------------
# Render code and its printed output in one block, prefixing output lines
# with "#>".
knitr::opts_chunk$set(collapse = TRUE, comment = "#>")

library(poissonsuperlearner)
library(riskRegression)

## -----------------------------------------------------------------------------
# Simulate a small competing-risks cohort (45 subjects, scenario "alpha").
# seed = 1 is passed, presumably so the simulated data are reproducible
# across builds.
d <- simulateStenoT1(
  n = 45, scenario = "alpha", competing_risks = TRUE, seed = 1
)

# Keep the subject identifier, the follow-up time, the event column and the
# five variables that the learner libraries below refer to.
d <- d[, .(
  id, time, event,
  sex, age, diabetes_duration,
  value_LDL, value_Smoking
)]

# Preview the first rows of the analysis data.
head(d)

## -----------------------------------------------------------------------------
# A single learner library used for every cause. Both members are glmnet
# learners with internal cross-validation switched off:
#   * simple: sex + diabetes_duration with lambda = 0 (no penalty);
#   * shrink: sex + age + value_LDL with a lasso penalty (alpha = 1,
#     lambda = 0.05).
shared_library <- list(
  simple = Learner_glmnet(covariates = c("sex", "diabetes_duration"),
                          cross_validation = FALSE,
                          lambda = 0),
  shrink = Learner_glmnet(covariates = c("sex", "age", "value_LDL"),
                          cross_validation = FALSE,
                          lambda = 0.05,
                          alpha = 1)
)

# Fit the super learner on `d`, with "id" as the subject identifier, "event"
# as the status column and "time" as the event time. number_of_nodes = 3 and
# nfold = 2 are kept small for a quick example (nfold is presumably the
# number of cross-validation folds; TODO confirm both against the
# Superlearner documentation).
fit_shared <- Superlearner(
  data = d, id = "id", status = "event", event_time = "time",
  learners = shared_library,
  number_of_nodes = 3, nfold = 2
)

## -----------------------------------------------------------------------------
# Print the summary of the shared-library fit.
summary(object = fit_shared)

## -----------------------------------------------------------------------------
# Prediction set: the first two subjects, evaluated at times 1 and 2.
newdata <- d[1:2]
times <- c(1, 2)

# Cause-1 risk from the shared-library fit, requested once per `model`
# choice: the ensemble ("sl"), the discrete super learner ("discrete_sl",
# presumably the single best-scoring learner), and each library member by
# its list name.
risk_shared_sl <- predictRisk(fit_shared, newdata = newdata,
                              times = times, cause = 1,
                              model = "sl")
risk_shared_discrete <- predictRisk(fit_shared, newdata = newdata,
                                    times = times, cause = 1,
                                    model = "discrete_sl")
risk_shared_simple <- predictRisk(fit_shared, newdata = newdata,
                                  times = times, cause = 1,
                                  model = "simple")
risk_shared_shrink <- predictRisk(fit_shared, newdata = newdata,
                                  times = times, cause = 1,
                                  model = "shrink")

# Show all four prediction sets side by side.
list(
  sl = risk_shared_sl,
  discrete_sl = risk_shared_discrete,
  simple = risk_shared_simple,
  shrink = risk_shared_shrink
)

## -----------------------------------------------------------------------------
# One learner library per cause ("cvd" and "death"), so each cause can use
# its own covariates. All four members are glmnet learners without internal
# cross-validation: the *_simple learners use lambda = 0 (no penalty), and
# the *_shrink learners use a lasso penalty (alpha = 1, lambda = 0.05).
libraries_by_cause <- list(
  cvd = list(
    cvd_simple = Learner_glmnet(covariates = c("sex", "diabetes_duration"),
                                cross_validation = FALSE,
                                lambda = 0),
    cvd_shrink = Learner_glmnet(covariates = c("age", "value_LDL"),
                                cross_validation = FALSE,
                                lambda = 0.05,
                                alpha = 1)
  ),
  death = list(
    death_simple = Learner_glmnet(covariates = c("sex", "age"),
                                  cross_validation = FALSE,
                                  lambda = 0),
    death_shrink = Learner_glmnet(
      covariates = c("diabetes_duration", "value_Smoking"),
      cross_validation = FALSE,
      lambda = 0.05,
      alpha = 1
    )
  )
)

# Same fitting call as for the shared library; only `learners` differs.
fit_by_cause <- Superlearner(
  data = d, id = "id", status = "event", event_time = "time",
  learners = libraries_by_cause,
  number_of_nodes = 3, nfold = 2
)

## -----------------------------------------------------------------------------
# Print the summary of the cause-specific-library fit.
summary(object = fit_by_cause)

## -----------------------------------------------------------------------------
# Cause-1 risk from the cause-specific fit, using the ensemble ("sl") and
# the discrete super learner ("discrete_sl"), for the same two subjects and
# times as before.
risk_by_cause_sl <- predictRisk(fit_by_cause, newdata = newdata,
                                times = times, cause = 1,
                                model = "sl")
risk_by_cause_discrete <- predictRisk(fit_by_cause, newdata = newdata,
                                      times = times, cause = 1,
                                      model = "discrete_sl")

# Show both prediction sets side by side.
list(
  sl = risk_by_cause_sl,
  discrete_sl = risk_by_cause_discrete
)

## -----------------------------------------------------------------------------
# Pick one library member per cause by name. Each vector pairs a "cvd"
# learner with a "death" learner; the position-to-cause mapping presumably
# follows the order of `libraries_by_cause` (TODO confirm).
cause_specific_model <- c("cvd_simple", "death_shrink")
cause_specific_model_alt <- c("cvd_shrink", "death_simple")

# Cause-1 risk under each of the two hand-picked learner combinations.
risk_by_cause_selected <- predictRisk(fit_by_cause, newdata = newdata,
                                      times = times, cause = 1,
                                      model = cause_specific_model)
risk_by_cause_selected_alt <- predictRisk(fit_by_cause, newdata = newdata,
                                          times = times, cause = 1,
                                          model = cause_specific_model_alt)

# Compare the two combinations.
list(
  selected_learners = risk_by_cause_selected,
  selected_learners_alt = risk_by_cause_selected_alt
)

## -----------------------------------------------------------------------------
# Learners can also be selected by position instead of by name. Here
# `model = c(1, 2)` is passed (presumably one index per cause-specific
# library; TODO confirm against the predictRisk method documentation).
risk_by_cause_indexed <- predictRisk(fit_by_cause, newdata = newdata,
                                     times = times, cause = 1,
                                     model = c(1, 2))

risk_by_cause_indexed

