1 load libraries

library(survival)
library(survminer)
library(cgdsr)
library(sparklyr)
library(dplyr)

2 Working with R Session

2.1 Load clinical Data

# Open a handle on the public cBioPortal web service and download the clinical
# table of the "gbm_tcga_pub_all" case list.
cgds <- cgdsr::CGDS("http://www.cbioportal.org/")
# Studies <- cgdsr::getCancerStudies(cgds)
clinicalData <- cgdsr::getClinicalData(
  cgds,
  "gbm_tcga_pub_all"
)

# Offline alternative: read a previously exported copy instead.
# clinicalData <- read.csv("ClinicalData.csv") #, na.strings=c("","NA")

# Preview the survival-related columns used in the rest of the document.
clinicalData[
  c("DFS_MONTHS", "DFS_STATUS", "OS_MONTHS", "OS_STATUS", "TREATMENT_STATUS")
]

2.2 Transformations 1

# Recode overall-survival status as a numeric event flag:
# "LIVING" -> 0, "DECEASED" -> 1 (case-insensitive), then convert to numeric.
clinicalData$OS_STATUS <- as.numeric(
  gsub(
    "DECEASED", "1",
    gsub("LIVING", "0", clinicalData$OS_STATUS, ignore.case = TRUE),
    ignore.case = TRUE
  )
)
# Empty or single-blank disease-free status is relabelled "DiseaseFree".
clinicalData$DFS_STATUS <- gsub(
  "^$|^ $", "DiseaseFree",
  clinicalData$DFS_STATUS,
  ignore.case = TRUE
)

2.3 survival plot

# Kaplan-Meier estimate of overall survival, stratified by disease-free status.
# OS_STATUS must already be numeric (0 = living, 1 = deceased).
fit <- survival::survfit(Surv(OS_MONTHS, OS_STATUS) ~ DFS_STATUS, data = clinicalData)

# Survival curves (in percent) with confidence bands, log-rank p-value and
# number-at-risk table.
survminer::ggsurvplot(fit, data = clinicalData,
                      type = "kaplan-meier",
                      #conf.type="log",
                      conf.int = TRUE,
                      pval = TRUE,
                      fun = "pct",
                      risk.table = TRUE,
                      size = 1,
                      linetype = "strata",
                      palette = c("#E7B800", "#2E9FDF"),
                      legend = "top",
                      # was misspelled `lengend.title`, so the title was never applied
                      legend.title = "DFS_STATUS",
                      # labels follow the order of the strata in `fit`
                      legend.labs = c("DiseaseFree", "Recurred")
)

3 R session VS Spark

3.1 Plot DiseaseFree vs Recurred during OS_MONTHS

  # In-memory (plain R / dplyr) version of the cumulative-count plot.
  clinicalData <- cgdsr::getClinicalData(cgds, "gbm_tcga_pub_all")
  # clinicalData <- read.csv("ClinicalData.csv") #, na.strings=c("","NA")
  start_time <- Sys.time()
  clinicalData %>%
    # Recode the status columns; expressions inside one mutate() run in order.
    mutate(
      OS_STATUS = gsub("LIVING", "0", OS_STATUS),
      OS_STATUS = gsub("DECEASED", "1", OS_STATUS),
      DFS_STATUS = gsub("^$|^ $", "DiseaseFree", DFS_STATUS),
      OS_STATUS = as.numeric(OS_STATUS)
    ) %>%
    arrange(OS_MONTHS) %>%
    mutate(DiseaseFree = ifelse(DFS_STATUS == "DiseaseFree", 1, 0)) %>%
    as.data.frame() %>%
    # Running totals of each group along increasing OS_MONTHS.
    mutate(
      n_DiseaseFree = cumsum(DiseaseFree == 1),
      n_Recurred = cumsum(DiseaseFree == 0)
    ) %>%
    # The layer-level aesthetics below override y and colour; the global
    # mapping only supplies the default axis / legend titles.
    ggplot(aes(x = OS_MONTHS, y = value, color = variable)) +
    geom_point(aes(y = n_DiseaseFree, col = "n_DiseaseFree")) +
    geom_point(aes(y = n_Recurred, col = "n_Recurred")) +
    labs(title = paste("Using R Session, Running time = ", Sys.time() - start_time))

3.2 Spark Node: Plot DiseaseFree vs Recurred during OS_MONTHS

 # Spark version of the cumulative-count plot: the recoding and sorting steps
 # are expressed with dplyr verbs on a Spark table, then the result is brought
 # back into R for the cumulative sums and plotting.
 clinicalData <- cgdsr::getClinicalData(cgds, "gbm_tcga_pub_all")
# clinicalData <- read.csv("ClinicalData.csv") #, na.strings=c("","NA")
 # Connect to a local Spark instance (Spark version pinned to 2.4.0).
 sc <- spark_connect(master = "local",
                   version = "2.4.0")

 # Upload the R data frame to Spark, replacing any table of the same name.
 clinicalData_tbl <- dplyr::copy_to(sc, clinicalData, overwrite = TRUE)
  # Timer starts after the upload, so copy time is not part of the reported figure.
  start_time <- Sys.time()
  clinicalData_tbl %>%
  # regexp_replace is not an R function; it is handed through to Spark SQL.
  mutate(OS_STATUS = regexp_replace(OS_STATUS, "LIVING", "0")) %>%
  mutate(OS_STATUS = regexp_replace(OS_STATUS, "DECEASED", "1")) %>%
  mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "^$|^ $", "DiseaseFree")) %>%
  mutate(OS_STATUS = as.numeric(OS_STATUS)) %>%
  #mutate(OS_STATUS = regexp_replace(as.numeric(OS_STATUS), 'NaN', NA)) %>%
  #mutate(OS_STATUS = regexp_replace(OS_STATUS, NaN, NA)) %>%
  #na.replace('') %>%  ## not good for OS_STATUS (0,1)
  #dplyr::filter(!is.na(OS_MONTHS)) 
  # Sort on the missing-flag first so rows without OS_MONTHS come last.
  arrange(is.na(OS_MONTHS), OS_MONTHS) %>%  ## puts missing OS_MONTHS at the end of the column
  mutate(DiseaseFree = ifelse(DFS_STATUS == "DiseaseFree", 1, 0)) %>% 
  # From here on the data is a local R data frame; cumsum() runs in R.
  as.data.frame() %>%
  mutate( n_DiseaseFree = cumsum(as.numeric(DiseaseFree == 1 ))) %>%
  mutate( n_Recurred = cumsum(as.numeric(DiseaseFree == 0 ))) %>%
  # Layer-level aesthetics below override y and colour from this global mapping.
  ggplot(aes(x = OS_MONTHS, y = value, color = variable)) +
  geom_point(aes(y = n_DiseaseFree, col = "n_DiseaseFree")) +
  geom_point(aes(y = n_Recurred, col = "n_Recurred"))  +
   labs(title = paste("Using Spark Node, Running time = ", Sys.time() - start_time))

4 Survival regression using ml_aft_survival_regression

4.1 Ovarian Data from survival package

# survival is already attached at the top of the file; repeated here so this
# section can be run on its own.
library(survival)
# Print the `ovarian` data set used in the following subsections.
ovarian

4.1.1 Predict Survival regression using spark

# MAGIC - **futime**: survival or censoring time
# MAGIC - **fustat**: censoring status
# MAGIC - **age**:  in years
# MAGIC - **resid_ds**: residual disease present (1=no, 2=yes)
# MAGIC - **rx**:   treatment group
# MAGIC - **ecog.ps**:  ECOG performance status (1 is better, see reference)

# Connect to local Spark (pinned to 2.4.0) and upload the ovarian data set.
sc <- spark_connect(master = "local", version = "2.4.0")
## Re-using existing Spark connection to local
ovarian_tbl <- sdf_copy_to(
  sc,
  ovarian,
  name = "ovarian_tbl",
  overwrite = TRUE
)

#spark.survreg(ovarian_tbl, Surv(futime, fustat) ~ ecog_ps + rx)

# Seeded 70 / 30 random split into training and test sets.
partitions <- sdf_partition(ovarian_tbl, training = 0.7, test = 0.3, seed = 1111)
ovarian_training <- partitions[["training"]]
ovarian_test <- partitions[["test"]]

# AFT survival regression of follow-up time on four covariates, with
# `fustat` as the censoring column.
sur_reg <- ml_aft_survival_regression(
  ovarian_training,
  futime ~ ecog_ps + rx + age + resid_ds,
  censor_col = "fustat"
)

# Score the held-out rows and print the predictions.
pred <- ml_predict(sur_reg, ovarian_test)
pred

4.1.2 Extract parameters

# Turn the fitted AFT coefficients into one Weibull scale per patient and
# evaluate the implied survival function at every observed follow-up time.
model_coefs <- sur_reg$coefficients
intercept <- model_coefs[["(Intercept)"]]
# Take all four slopes, looked up by name in the same order as the columns of
# `plotParams` (the Spark table uses "_" where `ovarian` uses ".").
# Previously only positions 2:3 were used and matched to the wrong columns.
coefficients <- model_coefs[c("resid_ds", "rx", "ecog_ps", "age")]
sur_reg$coefficients
## (Intercept)     ecog_ps          rx         age    resid_ds 
## 11.69793764 -0.27694569  0.65752740 -0.07862066 -0.79872945
plotParams <- round(ovarian[c("resid.ds", "rx", "ecog.ps", "age")])
# The linear predictor is a matrix product (one value per row), not an
# element-wise product with a recycled coefficient vector.
scale <- as.vector(exp(intercept + as.matrix(plotParams) %*% coefficients))
cbind(plotParams, scale)
tSeq <- as_tibble(ovarian_tbl)$futime # seq(0, 5E3, 50)
probs <- tibble(t = tSeq)
# One survival curve per covariate profile, for at most the first 2^4 rows.
# NOTE(review): shape = 1 ignores the scale parameter estimated by the AFT
# model; presumably the Weibull shape should be 1 / (fitted scale) - confirm.
for (i in seq_len(min(2^4, nrow(plotParams)))) {
  probs[, paste("(resid.ds, rx, ecog.ps, age) = (", toString(plotParams[i, ]), ")", sep = "")] <-
    pweibull(tSeq, shape = 1, scale = scale[i], lower.tail = FALSE)
}
probs

4.1.3 Melt the DataFrame

library(reshape2)
# MAGIC - **futime**: survival or censoring time
# MAGIC - **fustat**: censoring status
# MAGIC - **age**:  in years
# MAGIC - **resid_ds**: residual disease present (1=no, 2=yes)
# MAGIC - **rx**:   treatment group
# MAGIC - **ecog.ps**:  ECOG performance status (1 is better, see reference)

# Reshape `probs` from wide (one column per covariate profile) to long format:
# one row per (t, group) pair with the survival probability in `prob`.
melted <- reshape2::melt(
  data = probs,
  id.vars = "t",
  variable.name = "group",
  value.name = "prob"
)
melted

4.1.4 Plot survival regression

library(ggplot2)

# Scatter of survival probability against time, one colour per covariate profile.
ggplot(
  data = melted,
  mapping = aes(x = t, y = prob, group = group, colour = group)
) +
  geom_point() +
  # geom_smooth() +
  # geom_jitter() +
  labs(
    x = "time",
    y = "Survival probability"
  )

4.2 Clinical Data from gbm_tcga_pub_all Study from cBioPortal

# > ovarian_training
# # Source: spark<?> [?? x 6]
# futime fustat   age resid_ds    rx ecog_ps
# *  <dbl>  <dbl> <dbl>    <dbl> <dbl>   <dbl>
# 1    156      1  66.5        2     1       2
# 2    329      1  43.1        2     1       1
# 3    353      1  63.2        1     2       2
# 4    365      1  64.4        2     2       1
# 5    377      0  58.3        1     2       1
# 6    421      0  53.4        2     2       1
# 7    448      0  56.4        1     1       2
# 8    464      1  56.9        2     2       2
# 9    475      1  59.9        2     2       2
# 10   563      1  55.2        1     2       2
# # ... with more rows

# Fit an AFT survival regression of overall survival (OS_MONTHS) on
# disease-free status for the gbm_tcga_pub_all clinical data, then plot the
# survival curve implied by the model for each DFS_STATUS level.
clinicalData <- cgdsr::getClinicalData(cgds, "gbm_tcga_pub_all")

#clinicalData <- read.csv("ClinicalData.csv") #, na.strings=c("","NA")

clinicalData <- clinicalData[c("OS_MONTHS", "OS_STATUS", "DFS_STATUS")]
sc <- spark_connect(master = "local",
                    version = "2.4.0")
## Re-using existing Spark connection to local
clinicalData_tbl <- dplyr::copy_to(sc, clinicalData, overwrite = TRUE)
start_time <- Sys.time()
clinicalData_trans_tbl <-
  clinicalData_tbl %>%
  # OS_STATUS becomes the event indicator used as censor column:
  # 1 = death observed (uncensored), 0 = still living (censored).
  # It was previously coded the other way round, which inverted the censoring.
  # Replacements are passed as strings, as in the Spark pipeline of section 3.2.
  mutate(OS_STATUS = regexp_replace(OS_STATUS, "LIVING", "0")) %>%
  mutate(OS_STATUS = regexp_replace(OS_STATUS, "DECEASED", "1")) %>%
  # DFS_STATUS: empty -> DiseaseFree -> 1, Recurred -> 2.
  mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "^$|^ $", "DiseaseFree")) %>%
  mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "DiseaseFree", "1")) %>%
  mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "Recurred", "2")) %>%
  # mutate(xr = ifelse(TREATMENT_STATUS == "Untreated", 1 , 2)) %>%
  # mutate(xr = ifelse(TREATMENT_STATUS == "Treated", 2, 1)) %>%
  mutate(OS_STATUS = as.numeric(OS_STATUS)) %>%
  mutate(DFS_STATUS = as.numeric(DFS_STATUS)) %>%
  #arrange(is.na(OS_MONTHS), OS_MONTHS) %>% ## puts missing OS_MONTHS at the end of the column
  filter(!is.na(OS_STATUS)) ## rm all NA in OS_STATUS column
# Local copy of the transformed table, reused below for the plot inputs.
clinicalData_trans <- as_tibble(clinicalData_trans_tbl)
clinicalData_trans_tbl
partitions_clinicalData <- clinicalData_trans_tbl %>%
  sdf_partition(training = 0.9, test = 0.1, seed = 1111)

clinicalData_training <- partitions_clinicalData$training
clinicalData_test <- partitions_clinicalData$test

# OS_STATUS is the censoring indicator only. It was previously also listed as
# a predictor, which leaks the outcome into the model.
sur_reg_clinicalData <- clinicalData_training %>%
  ml_aft_survival_regression(OS_MONTHS ~ DFS_STATUS, censor_col = "OS_STATUS")

# We can save and load model by  `ml_save()` and `ml_load()`, which should work fine at least with a local Spark connection.

pred_clinicalData <- ml_predict(sur_reg_clinicalData, clinicalData_test)
pred_clinicalData

4.2.1 Extract parameters for Clinical Data

# Coefficients are looked up by name rather than by position.
intercept_clinicalData <- sur_reg_clinicalData$coefficients[["(Intercept)"]]
coefficients_clinicalData <- sur_reg_clinicalData$coefficients["DFS_STATUS"]
sur_reg_clinicalData$coefficients
# One row per distinct covariate profile (sort() also drops missing values).
plotParams_clinicalData <- tibble(
  DFS_STATUS = sort(unique(clinicalData_trans$DFS_STATUS))
)

# Weibull scale per profile: exp(intercept + X %*% beta). The previous
# element-wise `*` recycled the coefficients down the rows instead.
scale_clinicalData <- tibble(
  scale = as.vector(exp(
    intercept_clinicalData +
      as.matrix(plotParams_clinicalData) %*% coefficients_clinicalData
  ))
)
cbind(plotParams_clinicalData, scale_clinicalData)
plotParams_clinicalData[1, ]
# Observed survival times at which the curves are evaluated.
tSeq_clinicalData <- clinicalData_trans$OS_MONTHS
probs_clinicalData <- data.frame(OS_MONTHS = tSeq_clinicalData)
# NOTE(review): shape = 1 ignores the scale parameter estimated by the AFT
# model; presumably the Weibull shape should be 1 / (fitted scale) - confirm.
for (i in seq_len(nrow(plotParams_clinicalData))) {
  probs_clinicalData[[paste0("DFS_STATUS = ", plotParams_clinicalData$DFS_STATUS[i])]] <-
    pweibull(tSeq_clinicalData, shape = 1, scale = scale_clinicalData$scale[i], lower.tail = FALSE)
}
probs_clinicalData
library(reshape2)
# Long format: one row per (OS_MONTHS, group) pair.
melted_clinicalData <- melt(probs_clinicalData, id.vars = "OS_MONTHS", variable.name = "group", value.name = "prob")
melted_clinicalData
library(ggplot2)
library(grid)
ggplot(data = melted_clinicalData, aes(x = OS_MONTHS, y = prob, group = group, color = group)) +
  geom_point() +
  #geom_smooth() +
  #geom_jitter() +
  labs(x = "time", y = "Survival probability") +
  # annotation_custom(grob = textGrob("Read all about it"),
  #       xmin = 120, xmax = 120, ymin = 0.3, ymax = 0.3) +
  theme(legend.position = c(0.8, 0.85), legend.background = element_rect(color = "grey90", fill = "grey90")) +
  geom_text(aes(label = "1: DiseaseFree", x = 95, y = 0.7), color = "grey60", size = 3.5) +
  geom_text(aes(label = "2: Recurred", x = 95, y = 0.65), color = "grey60", size = 3.5) +
  # Force seconds so the "s" suffix is always correct (difftime otherwise
  # picks its own unit).
  geom_text(aes(label = paste("running time: ", round(as.numeric(difftime(Sys.time(), start_time, units = "secs")), digits = 2), "s"), x = 95, y = 0.6), color = "#a0a0a0", size = 3.5)

  #scale_color_manual(labels = c("T999", "T888"), values = c("blue", "red"))