1 Load packages and get datasets

library(survival)
library(survminer)
## Loading required package: ggplot2
## Loading required package: ggpubr
## Loading required package: magrittr
library(cgdsr)
## Please send questions to cbioportal@googlegroups.com
library(sparklyr)
library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
library(ggplot2)
library(grid)
library(shiny)
library(miniUI)
library(reshape2)

# Fetch the GBM (TCGA) clinical data from cBioPortal.
cgds <- cgdsr::CGDS(url = "http://www.cbioportal.org/")
Studies <- cgdsr::getCancerStudies(cgds)
clinicalData <- cgdsr::getClinicalData(cgds, "gbm_tcga_pub_all")

# Open a local Spark connection; every table below is registered on it.
sc <- sparklyr::spark_connect(master = "local")

# Copy both data frames into Spark under the table names that the server
# function later looks up ("ClinicalData_tbl" and "iris").
ClinicalData_tbl <- dplyr::copy_to(
  dest = sc,
  df = clinicalData,
  name = "ClinicalData_tbl",
  overwrite = TRUE
)

iris_tbl <- dplyr::copy_to(
  dest = sc,
  df = iris,
  name = "iris",
  overwrite = TRUE
)

2 Create the gadget user interface

# Gadget user interface: a title bar followed by a three-tab strip, with one
# plot per tab.
#
# Each plotOutput() id must match the `output$<id>` slot assigned in `server`
# exactly; ids are case-sensitive.
ui <- miniPage(
  gadgetTitleBar("ggplot Data after spark transformation"),

  miniTabstripPanel(
    miniTabPanel(
      "Plot",
      icon = icon("area-chart"),
      miniContentPanel(
        plotOutput("clinicalDataPlot")
      )
    ),
    miniTabPanel(
      "Linear model",
      icon = icon("table"),
      miniContentPanel(
        plotOutput("iris_model_plot")
      )
    ),
    miniTabPanel(
      "Survival Regression",
      icon = icon("table"),
      miniContentPanel(
        # Lower-case "s": the server assigns output$survival_regression_plot.
        # The previous id "Survival_regression_plot" matched no output, so
        # this tab stayed empty.
        plotOutput("survival_regression_plot")
      )
    )
  )
)

3 Create the shiny gadget functions

# Shiny server function for the gadget.
#
# Relies on the global Spark connection `sc` and on the Spark tables
# "ClinicalData_tbl" and "iris" having been registered on it beforehand.
#
# Outputs:
#   output$clinicalDataPlot          cumulative DiseaseFree / Recurred counts
#                                    against OS_MONTHS
#   output$iris_model_plot           iris scatter plot plus the Spark ML
#                                    linear-regression line
#   output$survival_regression_plot  survival curves derived from a Spark ML
#                                    AFT survival regression
#
# Pressing "Done" in the title bar stops the gadget and returns TRUE.
server <- function(input, output) {

  # Handle on the iris table registered in Spark.
  iris_tbl <- spark_read_table(sc, "iris")

  # Reference point for the "Running Time" title of the first plot.
  start_time <- Sys.time()

  # Seconds elapsed since `since`, as a plain number. difftime picks its own
  # units (secs, mins, ...) unless told otherwise, so pin them to seconds
  # because the labels below always print "s".
  elapsed_secs <- function(since) {
    as.numeric(difftime(Sys.time(), since, units = "secs"))
  }

  # Clinical data recoded in Spark, then collected into R, where the running
  # counts are computed.
  spark_clinicalData_trans <- reactive({
    spark_read_table(sc, "ClinicalData_tbl") %>%
      mutate(OS_STATUS = regexp_replace(OS_STATUS, "LIVING", "0")) %>%
      mutate(OS_STATUS = regexp_replace(OS_STATUS, "DECEASED", "1")) %>%
      # An empty or single-blank DFS_STATUS is treated as "DiseaseFree".
      mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "^$|^ $", "DiseaseFree")) %>%
      filter(!is.na(OS_STATUS)) %>%
      mutate(OS_STATUS = as.numeric(OS_STATUS)) %>%
      # Sort by OS_MONTHS with missing values placed last.
      arrange(is.na(OS_MONTHS), OS_MONTHS) %>%
      mutate(DiseaseFree = ifelse(DFS_STATUS == "DiseaseFree", 1, 0)) %>%
      # Bring the rows into R once; the cumulative sums below depend on the
      # row order established by arrange().
      collect() %>%
      mutate(
        n_DiseaseFree = cumsum(as.numeric(DiseaseFree == 1)),
        n_Recurred = cumsum(as.numeric(DiseaseFree == 0))
      )
  })

  # The two iris columns used by the linear-model tab, collected into R so
  # ggplot receives an ordinary data frame.
  spark_iris_data <- reactive({
    iris_tbl %>%
      select(Petal_Width, Petal_Length) %>%
      collect()
  })

  # Spark ML linear regression of Petal_Length on Petal_Width. The fitted
  # model object is returned as-is. An earlier attempt piped it into
  # collect(), which is presumably why that version failed.
  spark_iris_model <- reactive({
    iris_tbl %>%
      select(Petal_Width, Petal_Length) %>%
      ml_linear_regression(Petal_Length ~ Petal_Width)
  })

  output$clinicalDataPlot <- renderPlot({
    ggplot(spark_clinicalData_trans(), aes(x = OS_MONTHS)) +
      geom_point(aes(y = n_DiseaseFree, colour = "n_DiseaseFree")) +
      geom_point(aes(y = n_Recurred, colour = "n_Recurred")) +
      labs(
        title = paste(
          "Running Time = ", round(elapsed_secs(start_time), digits = 2), " s"
        ),
        y = "value",
        colour = "variable"
      )
  })

  output$iris_model_plot <- renderPlot({
    ml_model <- spark_iris_model()

    # coefficients[1] is used as the intercept and coefficients[2] as the
    # slope of the fitted line.
    ggplot(spark_iris_data(), aes(x = Petal_Width, y = Petal_Length)) +
      geom_point(size = 2, alpha = 0.5) +
      geom_abline(
        slope = ml_model$coefficients[[2]],
        intercept = ml_model$coefficients[[1]],
        color = "red"
      ) +
      labs(
        x = "Petal Width",
        y = "Petal Length",
        title = "Linear Regression: Petal Length ~ Petal Width",
        subtitle = "Use Spark.ML linear regression to predict petal length as a function of petal width."
      )
  })

  # Fit a Spark ML AFT survival regression on the clinical data and return a
  # long data frame with columns OS_MONTHS, group, prob: one survival curve
  # per distinct (DFS_STATUS, OS_STATUS) combination present in the data.
  spark_clinical_trans_for_regression <- function() {
    clinical_trans_tbl <- spark_read_table(sc, "ClinicalData_tbl") %>%
      # Censor indicator for Spark's AFT model: 1 = event occurred
      # (DECEASED), 0 = censored (LIVING).
      mutate(OS_STATUS = regexp_replace(OS_STATUS, "LIVING", "0")) %>%
      mutate(OS_STATUS = regexp_replace(OS_STATUS, "DECEASED", "1")) %>%
      mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "^$|^ $", "DiseaseFree")) %>%
      mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "DiseaseFree", "1")) %>%
      mutate(DFS_STATUS = regexp_replace(DFS_STATUS, "Recurred", "2")) %>%
      mutate(OS_STATUS = as.numeric(OS_STATUS)) %>%
      mutate(DFS_STATUS = as.numeric(DFS_STATUS)) %>%
      filter(!is.na(OS_STATUS))

    # Train on a 90% split; the remaining 10% is not used here.
    partitions <- clinical_trans_tbl %>%
      sdf_partition(training = 0.9, test = 0.1, seed = 1111)

    # NOTE(review): OS_STATUS is both a covariate and the censor column, as
    # in the original formula. Confirm this is intended.
    aft_fit <- partitions$training %>%
      ml_aft_survival_regression(
        OS_MONTHS ~ DFS_STATUS + OS_STATUS,
        censor_col = "OS_STATUS"
      )

    intercept <- aft_fit$coefficients[[1]]
    slopes <- as.numeric(aft_fit$coefficients[c(2, 3)])

    # Collect once and do the remaining (small) computations in R.
    local_df <- clinical_trans_tbl %>%
      select(OS_MONTHS, DFS_STATUS, OS_STATUS) %>%
      collect()

    # One curve per distinct covariate combination; combinations with a
    # missing covariate cannot be scored and are dropped.
    combos <- local_df %>%
      distinct(DFS_STATUS, OS_STATUS) %>%
      arrange(DFS_STATUS, OS_STATUS)
    combos <- combos[stats::complete.cases(combos), , drop = FALSE]

    if (nrow(combos) == 0L) {
      return(data.frame(
        OS_MONTHS = numeric(0),
        group = factor(character(0)),
        prob = numeric(0)
      ))
    }

    # Scale per combination: exp(intercept + x %*% beta).
    linear_pred <- intercept + as.matrix(combos) %*% slopes
    scales <- exp(as.vector(linear_pred))

    times <- local_df$OS_MONTHS
    probs <- data.frame(OS_MONTHS = times)
    for (i in seq_len(nrow(combos))) {
      curve_name <- paste0(
        "(DFS_STATUS, OS_STATUS) = (", toString(unlist(combos[i, ])), ")"
      )
      # shape = 1 is kept from the original code, i.e. an exponential
      # survival curve. NOTE(review): the fitted model's own scale
      # parameter is not used. Confirm whether it should be.
      probs[[curve_name]] <- pweibull(
        times,
        shape = 1,
        scale = scales[i],
        lower.tail = FALSE
      )
    }

    melt(
      probs,
      id.vars = "OS_MONTHS",
      variable.name = "group",
      value.name = "prob"
    )
  }

  output$survival_regression_plot <- renderPlot({
    start_time_surv_reg <- Sys.time()
    surv_curves <- spark_clinical_trans_for_regression()
    # Time spent preparing the data and fitting the model, in seconds.
    running_time <- round(elapsed_secs(start_time_surv_reg), digits = 2)

    ggplot(
      surv_curves,
      aes(x = OS_MONTHS, y = prob, group = group, color = group)
    ) +
      geom_point() +
      labs(
        title = "plot the spark ml_aft_survival_regression modeling",
        x = "time",
        y = "Survival probability"
      ) +
      theme(
        legend.position = c(0.8, 0.85),
        legend.background = element_rect(color = "grey90", fill = "grey90")
      ) +
      # annotate() draws each note once; geom_text() with a constant label
      # would draw it once per data row.
      annotate(
        "text", x = 95, y = 0.7,
        label = "DFS_STATUS  1: DiseaseFree, 2: Recurred",
        color = "grey60", size = 3.5
      ) +
      annotate(
        "text", x = 95, y = 0.65,
        label = "OS_STATUS  0: Living, 1: Deceased",
        color = "grey60", size = 3.5
      ) +
      annotate(
        "text", x = 95, y = 0.6,
        label = paste("running time: ", running_time, "s"),
        color = "#a0a0a0", size = 3.5
      )
  })

  # "Done" button in the gadget title bar: stop the app and return TRUE.
  observeEvent(input$done, {
    stopApp(TRUE)
  })
}
# Launch the gadget with the UI and server function defined above.
runGadget(app = ui, server = server)
## 
## Listening on http://127.0.0.1:5407
## [1] TRUE