Interactive Regression

#| '!! shinylive warning !!': |
#|   shinylive does not work in self-contained HTML documents.
#|   Please set `embed-resources: false` in your metadata.
#| standalone: true
#| viewerHeight: 800

library(shiny)
library(ggplot2)
library(bslib)
library(DT)       
library(tools)

# -------------------------------------------------------------------------
# Custom Theme 
# -------------------------------------------------------------------------
# Bootstrap 5 theme built on the Bootswatch "zephyr" preset, recoloured with
# the app's slate (#2C3E50) / grey (#95a5a6) palette.
# Custom fonts are switched off for debugging; to restore them, add
#   base_font = font_collection(...), heading_font = font_collection(...)
# NOTE(review): this object is currently unused because the UI's
# `theme = my_theme` line is commented out.
my_theme <- bs_theme(
  version    = 5,
  bootswatch = "zephyr",
  primary    = "#2C3E50",
  secondary  = "#95a5a6"
)

# -------------------------------------------------------------------------
# UI Definition
# -------------------------------------------------------------------------
# Page layout: a fixed-width sidebar with inputs on the left, and a main area
# holding the plot, summary statistics and a data table. The pieces are built
# as named parts inside local() so no helper names leak into the global
# environment.
ui <- local({
  # Sidebar: data entry, display controls and the point-prediction readout.
  input_panel <- sidebar(
    width = 300,
    h4("Data Input"),
    textAreaInput(
      "xdata", "X Variable",
      value = "1, 2, 3, 4, 5, 6", height = "80px"
    ),
    textAreaInput(
      "ydata", "Y Variable",
      value = "1.5, 2.8, 3.2, 5.0, 5.5, 7.2", height = "80px"
    ),
    hr(),
    h4("Controls"),
    checkboxInput("showResiduals", "Show Residuals", FALSE),
    actionButton(
      "go", "Update",
      class = "btn-primary w-100", icon = icon("rotate")
    ),
    hr(),
    numericInput("pred_x", "Predict Y at X:", value = 3.5),
    uiOutput("prediction_result")
  )

  # Visualization; an explicit plot height helps layout stability.
  plot_card <- card(
    card_header("Scatterplot"),
    plotOutput("plot", height = "400px")
  )

  # Statistics readout.
  stats_card <- card(
    card_header("Stats"),
    uiOutput("stats_ui")
  )

  # Per-observation data table.
  table_card <- card(
    card_header("Data & Residuals"),
    DTOutput("resultsTable")
  )

  page_sidebar(
    # theme = my_theme,  (disabled: the default bslib theme is used)
    fill = FALSE,
    fillable = FALSE,
    sidebar = input_panel,
    # On the 12-unit grid, 8 + 4 put plot and stats side by side; 12 gives
    # the table a full row of its own.
    layout_columns(
      col_widths = c(8, 4, 12),
      row_heights = c("auto", "auto"),
      plot_card,
      stats_card,
      table_card
    )
  )
})

# -------------------------------------------------------------------------
# Server Logic
# -------------------------------------------------------------------------
# Parse a comma-separated string into a numeric vector. Fields that are not
# numbers become NA; the coercion warning is muted because callers reject NA
# explicitly with a clear message.
parse_numbers <- function(txt) {
  suppressWarnings(as.numeric(unlist(strsplit(txt, ","))))
}

# Validate the two text inputs, fit y ~ x by OLS and collect everything the
# outputs display. Never throws: any error is returned as
# list(valid = FALSE, message = <error message>).
#
# On success returns a list with: data (data.frame with Observation, x, y and
# the fitted values y_hat), intercept, slope, y_pred (fitted values),
# residuals, r_squared, cor, r_contrib (per-point z_x * z_y / (n - 1), which
# sums to r), ssx, ssy and valid = TRUE.
analyze_regression <- function(x_text, y_text) {
  tryCatch({
    x_vals <- parse_numbers(x_text)
    y_vals <- parse_numbers(y_text)

    if (length(x_vals) != length(y_vals)) stop("X/Y mismatch")
    if (length(x_vals) < 2) stop("Need 2+ points")
    if (anyNA(x_vals) || anyNA(y_vals)) stop("Non-numeric inputs")
    # With a constant X the slope is not identifiable (lm() reports NA) and
    # sd(x) is 0, so every downstream statistic would be NA/NaN.
    if (length(unique(x_vals)) < 2) stop("X values must vary")

    df <- data.frame(Observation = seq_along(x_vals), x = x_vals, y = y_vals)
    fit <- lm(y ~ x, data = df)
    b <- unname(coef(fit))
    y_hat <- unname(fitted(fit))
    # Keep fitted values in the data so plot layers can map them by name.
    df$y_hat <- y_hat

    n <- nrow(df)
    mx <- mean(df$x)
    my <- mean(df$y)
    z_x <- (df$x - mx) / sd(df$x)
    z_y <- (df$y - my) / sd(df$y)

    list(
      data = df,
      intercept = b[1],
      slope = b[2],
      y_pred = y_hat,
      residuals = unname(resid(fit)),
      r_squared = summary(fit)$r.squared,
      cor = cor(df$x, df$y),
      r_contrib = z_x * z_y / (n - 1),
      ssx = sum((df$x - mx)^2),
      ssy = sum((df$y - my)^2),
      valid = TRUE
    )
  }, error = function(e) list(valid = FALSE, message = conditionMessage(e)))
}

server <- function(input, output) {

  # Re-run the analysis when "Update" is clicked; ignoreNULL = FALSE also runs
  # it once at startup so the default data is shown immediately.
  analysis_core <- eventReactive(input$go, {
    req(input$xdata, input$ydata)
    analyze_regression(input$xdata, input$ydata)
  }, ignoreNULL = FALSE)

  # Point prediction from the fitted line at the user-chosen X. Reads
  # input$pred_x directly, so it updates without clicking "Update".
  output$prediction_result <- renderUI({
    res <- analysis_core()
    req(res$valid, input$pred_x)
    pred_val <- res$intercept + res$slope * input$pred_x
    div(
      style = "margin-top: 10px; padding: 10px; background-color: #f8f9fa; border-left: 5px solid #2C3E50;",
      strong("Predicted Y:"), br(),
      span(style = "font-size: 1.2em; color: #2C3E50;", round(pred_val, 4))
    )
  })

  # Scatterplot with the OLS line, the centroid marker and, optionally,
  # dashed residual segments from each point to the line.
  output$plot <- renderPlot({
    res <- analysis_core()
    validate(need(res$valid, paste("Error:", res$message)))
    df <- res$data

    # The OLS line passes through (mean(x), mean(y)); mark that point.
    mean_x <- mean(df$x)
    pred_mean_y <- res$intercept + res$slope * mean_x

    centroid_label_safe <- sprintf(
      "atop('(' * bar(x) * ',' ~ bar(y) * ')', '(' * %.1f * ',' ~ %.1f * ')')",
      mean_x, pred_mean_y
    )

    # Render "b0 - |b1| x" for negative slopes instead of "b0 + -b1 x".
    fmt3 <- function(v) format(round(v, 3), nsmall = 3)
    b0 <- fmt3(res$intercept)
    b1 <- fmt3(abs(res$slope))
    eq_label <- if (round(res$slope, 3) < 0) {
      bquote(italic(hat(y)) == .(b0) - .(b1) * italic(x))
    } else {
      bquote(italic(hat(y)) == .(b0) + .(b1) * italic(x))
    }

    p <- ggplot(df, aes(x, y)) +
      geom_smooth(method = "lm", formula = y ~ x, se = FALSE, color = "#2C3E50", linewidth = 1.2, alpha = 0.8) +
      geom_point(size = 3, alpha = 0.7, color = "#2C3E50") +
      annotate("point", x = mean_x, y = pred_mean_y, color = "#e74c3c", size = 4, shape = 18) +
      annotate("text", x = mean_x, y = pred_mean_y, label = centroid_label_safe, parse = TRUE, vjust = -0.5, color = "#e74c3c", fontface = "bold", size = 4.5) +
      labs(title = "Scatterplot with Regression Line", subtitle = eq_label, x = "X", y = "Y") +
      theme_minimal(base_size = 14) +
      theme(plot.subtitle = element_text(face = "italic", size = 14, color = "#7f8c8d"))

    if (input$showResiduals) {
      # y_hat is a column of df, so the mapping stays inside the plot data.
      p <- p + geom_segment(aes(x = x, xend = x, y = y_hat, yend = y), color = "#e74c3c", linetype = "dashed", alpha = 0.6)
    }
    p
  })

  # Key-value list of the fitted coefficients and fit statistics.
  output$stats_ui <- renderUI({
    res <- analysis_core()
    validate(need(res$valid, "Waiting..."))
    stat_row <- function(l, v) {
      div(
        style = "display:flex; justify-content:space-between; border-bottom:1px solid #eee;",
        span(style = "color:#7f8c8d;", l),
        strong(style = "color:#2C3E50;", v)
      )
    }
    tagList(
      stat_row("Slope (b₁)", round(res$slope, 4)),
      stat_row("Intercept (b₀)", round(res$intercept, 4)),
      stat_row("Correlation (r)", round(res$cor, 4)),
      stat_row("R²", round(res$r_squared, 4))
    )
  })

  # Per-observation table: raw data, residuals (red when negative, green
  # otherwise) and each point's contribution to r.
  output$resultsTable <- renderDT({
    res <- analysis_core()
    validate(need(res$valid, ""))
    tbl <- data.frame(
      X = res$data$x,
      Y = res$data$y,
      Resid = round(res$residuals, 3),
      Contrib_r = round(res$r_contrib, 4)
    )
    # dom = "tp": table plus pagination. With dom = "t" alone, rows beyond
    # pageLength were unreachable because the pager was hidden.
    datatable(tbl,
              options = list(pageLength = 4, dom = "tp", searching = FALSE),
              rownames = FALSE, selection = 'none', style = 'bootstrap4', class = 'table table-hover table-striped') %>%
      formatStyle('Resid', color = styleInterval(0, c('#e74c3c', '#27ae60')))
  })
}

# Combine the UI definition and server function into the Shiny app object.
shinyApp(ui = ui, server = server)