We investigate the score inflation that may occur when using two scoring rules for evaluating survival models: the Integrated Survival Brier Score (ISBS) (Graf et al. 1999) and its proposed re-weighted version (RISBS) (Sonabend 2022). See the documentation for their respective formulas. The first (ISBS) is not a proper scoring rule (Rindt et al. 2022); the second (RISBS) is (Sonabend 2022).
In this section we work through an example where the proper RISBS gets inflated (i.e. a much larger value for the score, compared to the improper version) and show how to avoid this when evaluating model performance.
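As a rough sketch based on the cited papers (not necessarily a term-by-term match to the mlr3proba implementation), the pointwise losses underlying the two rules, for an observation with outcome (t, δ), predicted survival function S, censoring-distribution estimate G_{KM}, and evaluation time τ, can be written as:

```latex
% ISBS (Graf et al. 1999): censored observations also contribute (before
% their censoring time); the weights use G_KM at two different time points.
L_{\mathrm{ISBS}}(\tau) =
  \frac{S(\tau)^2 \, \mathbb{I}(t \le \tau,\, \delta = 1)}{G_{KM}(t)}
  + \frac{(1 - S(\tau))^2 \, \mathbb{I}(t > \tau)}{G_{KM}(\tau)}

% RISBS (Sonabend 2022): only observations that experienced the event
% contribute, each re-weighted by G_KM at its own outcome time t.
L_{\mathrm{RISBS}}(\tau) =
  \delta \, \frac{S(\tau)^2 \, \mathbb{I}(t \le \tau)
  + (1 - S(\tau))^2 \, \mathbb{I}(t > \tau)}{G_{KM}(t)}
```

The 1/G_{KM}(t) weight in RISBS is where the inflation examined in this section comes from when G_{KM}(t) is (near) zero.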
Load libraries:
library(GGally)
library(tidyverse)
library(mlr3proba)
Let’s use a dataset where in a particular train/test resampling the issue occurs:
inflated_data = readRDS(file = "inflated_data.rds")
task = inflated_data$task
part = inflated_data$part
task
<TaskSurv:mgus> (176 x 9)
* Target: time, status
* Properties: -
* Features (7):
- dbl (6): age, alb, creat, dxyr, hgb, mspike
- fct (1): sex
Separate train and test data:
task_train = task$clone()$filter(rows = part$train)
task_test = task$clone()$filter(rows = part$test)
Kaplan-Meier of the training survival data:
autoplot(task_train) +
labs(title = "Kaplan-Meier (train data)",
subtitle = "Time-to-event distribution")
Kaplan-Meier of the training censoring data:
autoplot(task_train, reverse = TRUE) +
labs(title = "Kaplan-Meier (train data)",
subtitle = "Censoring distribution")
Estimates of the censoring distribution G_{KM}(t) (values from the above figure):
km_train = task_train$kaplan(reverse = TRUE)
km_tbl = tibble(time = km_train$time, surv = km_train$surv)
tail(km_tbl)
# A tibble: 6 × 2
time surv
<dbl> <dbl>
1 12140 0.75
2 12313 0.625
3 12319 0.5
4 12349 0.25
5 12689 0.125
6 13019 0
As we can see from the above figures and table, due to having at least one censored observation at the last time point, G_{KM}(t_{max}) = 0 for t_{max} = 13019.
Is there an observation in the test set that has died (status = 1) at that last time point or after?
max_time = max(km_tbl$time) # max time point
test_times = task_test$times()
test_status = task_test$status()

# get the id of the observation in the test data
id = which(test_times >= max_time & test_status == 1)
id
[1] 14
Yes, there is such an observation!
In mlr3proba, using proper = TRUE for the RISBS calculation, this observation would be weighted by 1/0 according to the formula. In practice, to avoid division by zero, a small value eps = 0.001 is used.
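A minimal sketch of why this single weight explodes, assuming the eps fallback simply replaces the zero denominator (the exact substitution inside mlr3proba may differ):

```r
# Sketch of the RISBS weight for an event observed at the last time point.
# G is the Kaplan-Meier estimate of the censoring distribution at that time;
# here it is 0, so eps = 0.001 (per the text above) takes its place.
G <- 0
eps <- 0.001
weight <- 1 / max(G, eps)
weight
# [1] 1000
```

So this one observation's loss enters the average multiplied by roughly 1000, which is enough to dominate the whole score.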
Let’s train a simple Cox model on the train set and calculate its predictions on the test set:
cox = lrn("surv.coxph")
p = cox$train(task, part$train)$predict(task, part$test)
We calculate the ISBS (improper) and RISBS (proper) scores:
graf_improper = msr("surv.graf", proper = FALSE, id = "graf.improper")
graf_proper = msr("surv.graf", proper = TRUE, id = "graf.proper")
p$score(graf_improper, task = task, train_set = part$train)
graf.improper
0.1493429
p$score(graf_proper, task = task, train_set = part$train)
graf.proper
10.64584
As we can see, there is a huge difference between the two versions of the score. We check the observation-wise scores (integrated across all time points):
Observation-wise RISBS scores:
graf_proper$scores
[1] 0.08994417 0.02854219 0.04214266 0.15578719 0.05364692
[6] 0.12969150 0.06463256 0.32033549 2.43262450 0.11602432
[11] 0.03228501 0.10172088 0.14652850 367.10227335 0.18004727
[16] 0.21991511 0.09070024 0.03507389 0.19856844 0.07925747
[21] 0.07732517 0.06982001 0.19468406 0.05267402 0.02419841
[26] 0.17645640 0.07633691 0.04379196 0.07839955 0.06684222
[31] 0.05457688 0.02874430 0.04071108 0.00000000 0.00000000
Observation-wise ISBS scores:
graf_improper$scores
[1] 0.08994417 0.02854219 0.04214266 0.15578719 0.05364692 0.12969150
[7] 0.06463256 0.32033549 0.62971109 0.11602432 0.03228501 0.10172088
[13] 0.14652850 1.07969258 0.16743979 0.21991511 0.09070024 0.03507389
[19] 0.19856844 0.07925747 0.07732517 0.06982001 0.19468406 0.05267402
[25] 0.02419841 0.16199516 0.07633691 0.04379196 0.07839955 0.06684222
[31] 0.05457688 0.02874430 0.04071108 0.03512466 0.46541333
It is the one observation we identified earlier that causes the inflation of the RISBS score - its score is an extreme outlier compared to all the other values:
graf_proper$scores[id]
[1] 367.1023
The same is true for the improper ISBS, though less dramatically: its value is approximately 10 times larger than the other observation-wise scores:
graf_improper$scores[id]
[1] 1.079693
By setting t_max (the time horizon up to which the measure is evaluated) to the 95% quantile of the event times, we can solve the inflation problem of the proper RISBS score, since we will then divide by a G_{KM}(t) value larger than zero (see the table above). The t_max time point is:
t_max = as.integer(quantile(task_train$unique_event_times(), 0.95))
t_max
[1] 10080
Integrating up to t_max, the proper RISBS score is:
graf_proper_tmax = msr("surv.graf", id = "graf.proper", proper = TRUE, t_max = t_max)
p$score(graf_proper_tmax, task = task, train_set = part$train) # RISBS
graf.proper
0.1436484
The score for the specific observation that had experienced the event at (or beyond) the latest training time point is now:
graf_proper_tmax$scores[id]
[1] 0.141502
To avoid the inflation of RISBS, and generally to obtain a more robust estimate of both the RISBS and ISBS scoring rules, we advise setting the t_max argument (time horizon). This can be either study-driven or based on a meaningful quantile of the distribution of (usually event) times in your dataset (e.g. 80%).
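As a minimal sketch of the quantile-based choice (the event times below are made up for illustration, and the 80% level is just the suggestion from above, not a value taken from the mgus analysis):

```r
# Sketch: pick t_max as the 80% quantile of the observed event times
# (hypothetical data, not the mgus task used in this section).
event_times <- c(10, 52, 104, 198, 350, 721, 1020, 1500, 2800, 5000)
t_max <- as.integer(quantile(event_times, 0.8))
t_max
# [1] 1760

# The measure would then be constructed as before:
# msr("surv.graf", proper = TRUE, t_max = t_max)
```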