Easy Bayesian linear modeling

STA602 at Duke University

rstanarm and bayesplot

Download

To download rstanarm and bayesplot, run the code below

install.packages(c("rstanarm", "bayesplot"))

To load the packages, run

library(rstanarm)
library(bayesplot)

Overview

  • rstanarm contains a host of functions to make Bayesian linear modeling in R easy. See https://mc-stan.org/rstanarm/articles/ for a variety of tutorials.

    • pros: easy to test Bayesian linear models, can be fast (uses Hamiltonian Monte Carlo proposals)

    • cons: limited in scope, e.g. it requires a differentiable objective, and small model adjustments can be cumbersome to implement, e.g. placing a prior on the variance rather than the standard deviation of a normal model.

  • bayesplot contains many useful plotting wrappers that work out of the box with objects created by rstanarm in an intuitive way.

Example

library(tidyverse)
spam = read_csv(
  "https://sta602-sp25.github.io/data/spam.csv")
glimpse(spam)
Rows: 4,601
Columns: 58
$ make              <dbl> 0.00, 0.21, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15…
$ address           <dbl> 0.64, 0.28, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ all               <dbl> 0.64, 0.50, 0.71, 0.00, 0.00, 0.00, 0.00, 0.00, 0.46…
$ num3d             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ our               <dbl> 0.32, 0.14, 1.23, 0.63, 0.63, 1.85, 1.92, 1.88, 0.61…
$ over              <dbl> 0.00, 0.28, 0.19, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ remove            <dbl> 0.00, 0.21, 0.19, 0.31, 0.31, 0.00, 0.00, 0.00, 0.30…
$ internet          <dbl> 0.00, 0.07, 0.12, 0.63, 0.63, 1.85, 0.00, 1.88, 0.00…
$ order             <dbl> 0.00, 0.00, 0.64, 0.31, 0.31, 0.00, 0.00, 0.00, 0.92…
$ mail              <dbl> 0.00, 0.94, 0.25, 0.63, 0.63, 0.00, 0.64, 0.00, 0.76…
$ receive           <dbl> 0.00, 0.21, 0.38, 0.31, 0.31, 0.00, 0.96, 0.00, 0.76…
$ will              <dbl> 0.64, 0.79, 0.45, 0.31, 0.31, 0.00, 1.28, 0.00, 0.92…
$ people            <dbl> 0.00, 0.65, 0.12, 0.31, 0.31, 0.00, 0.00, 0.00, 0.00…
$ report            <dbl> 0.00, 0.21, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ addresses         <dbl> 0.00, 0.14, 1.75, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ free              <dbl> 0.32, 0.14, 0.06, 0.31, 0.31, 0.00, 0.96, 0.00, 0.00…
$ business          <dbl> 0.00, 0.07, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ email             <dbl> 1.29, 0.28, 1.03, 0.00, 0.00, 0.00, 0.32, 0.00, 0.15…
$ you               <dbl> 1.93, 3.47, 1.36, 3.18, 3.18, 0.00, 3.85, 0.00, 1.23…
$ credit            <dbl> 0.00, 0.00, 0.32, 0.00, 0.00, 0.00, 0.00, 0.00, 3.53…
$ your              <dbl> 0.96, 1.59, 0.51, 0.31, 0.31, 0.00, 0.64, 0.00, 2.00…
$ font              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ num000            <dbl> 0.00, 0.43, 1.16, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ money             <dbl> 0.00, 0.43, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15…
$ hp                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ hpl               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ george            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ num650            <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ lab               <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ labs              <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ telnet            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ num857            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ data              <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.15…
$ num415            <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ num85             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ technology        <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ num1999           <dbl> 0.00, 0.07, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ parts             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ pm                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ direct            <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ cs                <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ meeting           <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ original          <dbl> 0.00, 0.00, 0.12, 0.00, 0.00, 0.00, 0.00, 0.00, 0.30…
$ project           <dbl> 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ re                <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ edu               <dbl> 0.00, 0.00, 0.06, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00…
$ table             <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ conference        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0…
$ charSemicolon     <dbl> 0.000, 0.000, 0.010, 0.000, 0.000, 0.000, 0.000, 0.0…
$ charRoundbracket  <dbl> 0.000, 0.132, 0.143, 0.137, 0.135, 0.223, 0.054, 0.2…
$ charSquarebracket <dbl> 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.0…
$ charExclamation   <dbl> 0.778, 0.372, 0.276, 0.137, 0.135, 0.000, 0.164, 0.0…
$ charDollar        <dbl> 0.000, 0.180, 0.184, 0.000, 0.000, 0.000, 0.054, 0.0…
$ charHash          <dbl> 0.000, 0.048, 0.010, 0.000, 0.000, 0.000, 0.000, 0.0…
$ capitalAve        <dbl> 3.756, 5.114, 9.821, 3.537, 3.537, 3.000, 1.671, 2.4…
$ capitalLong       <dbl> 61, 101, 485, 40, 40, 15, 4, 11, 445, 43, 6, 11, 61,…
$ capitalTotal      <dbl> 278, 1028, 2259, 191, 191, 54, 112, 49, 1257, 749, 2…
$ type              <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1…

Description

4601 emails sent to the inbox of someone named “George”, each classified as type = 1 (spam) or 0 (non-spam). The data was collected at Hewlett-Packard Labs and contains 58 variables. The first 48 variables are specific keywords; each value records the frequency of that word as a percentage of all words in the message.
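Before modeling, it is worth a quick look at the class balance (a minimal sketch):

table(spam$type) # counts of non-spam (0) and spam (1) emails
mean(spam$type)  # overall proportion of spam, roughly 0.4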

Exercise Overview

You want to build a spam filter that blocks emails that have a high probability of being spam.

In statistical terms, your outcome \(Y\) is the type of the email (spam or not). The 57 predictors from the data set are contained in \(x\) and include the frequency of certain words, the occurrence of certain symbols, and the use of capital letters in each email.

Let \(X = \{x_1, \ldots, x_n\}\) where \(x_i \in \mathbb{R}^{58}\) (the 57 predictors plus an intercept) and \(n\) is the number of emails in the data set. The outcome vector \(Y \in \{0, 1\}^n\).

\[ \begin{aligned} p(y_i = 1 \mid x_i, \beta) &= \theta_i\\ \text{logit}(\theta_i) &= x_i^\top \beta \end{aligned} \]
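Equivalently, inverting the logit link gives the success probability directly:

\[ \theta_i = \frac{\exp(x_i^\top \beta)}{1 + \exp(x_i^\top \beta)} \]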

A priori, you believe that many of the predictors included in the data set do not in fact help you predict whether an email is spam. You express this belief with the priors specified below.

Let \(\beta_0\) be the intercept term.

\[ \beta_0 \sim \text{Normal}(0, 100) \]

where 100 is the variance, i.e. the standard deviation is 10.

For each \(\beta\) associated with the 57 predictors:

\[ \beta_i \overset{\text{iid}}{\sim} \text{Laplace}(0, 0.5) \quad \text{for } i \in \{1, \ldots, 57\} \]
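To see why this prior encodes the belief that many coefficients are near zero, compare its density to a normal density. The dlaplace helper below is written inline for illustration (it is not part of rstanarm or bayesplot):

# Laplace density: f(x) = exp(-|x - m| / b) / (2b)
dlaplace = function(x, m = 0, b = 0.5) exp(-abs(x - m) / b) / (2 * b)
curve(dlaplace(x), from = -3, to = 3, ylab = "density") # sharp peak at zero
curve(dnorm(x, 0, 0.5), add = TRUE, lty = 2)            # normal(0, 0.5) for visual comparison

The Laplace density has a sharp peak at 0 and heavier tails, so a priori most coefficients sit near zero while a few can be large.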

Standardize the data

  • Before we fit our model to the data, we need to standardize the predictors (columns of \(X\)). Why is this important? Discuss.
# the scale() function centers and re-scales the columns of a data frame
spam2 = cbind("type" = spam$type, 
              scale(select(spam, -"type"))) %>%
  data.frame()
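As a quick sanity check, every scaled predictor should now have mean approximately 0 and standard deviation 1:

round(colMeans(spam2[, -1]), 3)[1:5]     # means of the first 5 predictors, ~ 0
round(apply(spam2[, -1], 2, sd), 3)[1:5] # sds of the first 5 predictors, ~ 1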

To validate our model, we will split the data into two non-overlapping sets: a training set and a testing set.

set.seed(360) # ensures we get the same subset
N = nrow(spam2)
indices = sample(N, size = 0.8 * N)
spam_train = spam2[indices,]
spam_test = spam2[-indices,]
# sanity check 
nrow(spam_train) + nrow(spam_test) == N
[1] TRUE

Exercise 1

Read about how to fit Bayesian logistic regression using rstanarm here: https://mc-stan.org/rstanarm/articles/binomial.html and write code to fit the spam_train data set.

Hint: use the stan_glm function. If you set the argument chains = 1, stan_glm will run 1 Markov chain instead of the default 4. You can use the argument iter = 2000 to set the number of iterations in your Markov chain to 2000. This may take anywhere from 1 to 4 minutes to run locally on your machine. If you are pressed for time, you can load the resulting object directly from the website using the code below.

fit1 = readRDS(url("https://sta602-sp25.github.io/data/spam-train-fit.rds"))

Exercise 2

Examining the output

  • Did stan_glm do what we think it did? Did the Markov chain converge? Which parameters, if any, have a 90% credible interval that covers 0?

Notice sample: 1000 in the output: of the 2000 iterations, Stan discards the first half of the chain as warm-up (often called “burn-in”), i.e. the period the chain spends reaching the target distribution, leaving 1000 posterior draws.

summary(fit1)

Model Info:
 function:     stan_glm
 family:       binomial [logit]
 formula:      type ~ .
 algorithm:    sampling
 sample:       1000 (posterior sample size)
 priors:       see help('prior_summary')
 observations: 3680
 predictors:   58

Estimates:
                    mean   sd    10%   50%   90%
(Intercept)        -3.9    0.5  -4.5  -3.9  -3.4
make               -0.1    0.1  -0.2  -0.1   0.0
address            -0.2    0.1  -0.4  -0.2  -0.1
all                 0.1    0.1   0.0   0.1   0.1
num3d               0.9    0.7   0.2   0.7   1.7
our                 0.4    0.1   0.3   0.4   0.5
over                0.2    0.1   0.1   0.2   0.3
remove              0.9    0.1   0.7   0.9   1.1
internet            0.2    0.1   0.1   0.2   0.3
order               0.1    0.1   0.0   0.1   0.2
mail                0.1    0.0   0.0   0.1   0.1
receive             0.0    0.1  -0.1   0.0   0.0
will               -0.2    0.1  -0.3  -0.2  -0.1
people             -0.1    0.1  -0.1   0.0   0.0
report              0.0    0.1  -0.1   0.0   0.1
addresses           0.3    0.2   0.1   0.3   0.5
free                1.0    0.1   0.8   1.0   1.1
business            0.5    0.1   0.3   0.5   0.6
email               0.1    0.1   0.0   0.1   0.2
you                 0.1    0.1   0.0   0.1   0.2
credit              0.6    0.2   0.3   0.6   0.9
your                0.3    0.1   0.2   0.3   0.4
font                0.2    0.2   0.1   0.2   0.4
num000              0.7    0.2   0.5   0.7   0.9
money               0.2    0.1   0.1   0.2   0.3
hp                 -3.1    0.5  -3.8  -3.1  -2.5
hpl                -1.0    0.4  -1.5  -1.0  -0.5
george             -8.5    1.8 -10.9  -8.5  -6.3
num650              0.2    0.1   0.1   0.2   0.4
lab                -1.1    0.5  -1.8  -1.0  -0.5
labs               -0.1    0.1  -0.3  -0.1   0.0
telnet             -0.3    0.3  -0.8  -0.3   0.0
num857             -0.2    0.6  -1.0  -0.2   0.4
data               -0.7    0.2  -1.0  -0.7  -0.5
num415             -0.6    0.8  -1.7  -0.5   0.2
num85              -0.8    0.4  -1.3  -0.8  -0.3
technology          0.4    0.1   0.2   0.4   0.6
num1999             0.0    0.1  -0.1   0.0   0.1
parts              -0.2    0.1  -0.4  -0.2  -0.1
pm                 -0.4    0.2  -0.6  -0.4  -0.2
direct             -0.2    0.1  -0.3  -0.1   0.0
cs                 -1.4    0.8  -2.4  -1.2  -0.5
meeting            -1.5    0.5  -2.2  -1.5  -1.0
original           -0.4    0.2  -0.8  -0.4  -0.1
project            -1.3    0.4  -1.8  -1.3  -0.8
re                 -0.7    0.2  -0.9  -0.7  -0.5
edu                -1.5    0.3  -1.9  -1.5  -1.1
table              -0.3    0.1  -0.5  -0.3  -0.1
conference         -1.1    0.4  -1.7  -1.1  -0.6
charSemicolon      -0.3    0.1  -0.5  -0.3  -0.2
charRoundbracket   -0.1    0.1  -0.2  -0.1   0.0
charSquarebracket  -0.1    0.1  -0.3  -0.1   0.0
charExclamation     0.2    0.1   0.2   0.2   0.3
charDollar          1.2    0.2   1.0   1.3   1.5
charHash            0.7    0.4   0.2   0.7   1.2
capitalAve          0.0    0.3  -0.3   0.0   0.4
capitalLong         0.9    0.4   0.4   0.8   1.3
capitalTotal        0.7    0.1   0.5   0.7   0.8

Fit Diagnostics:
           mean   sd   10%   50%   90%
mean_PPD 0.4    0.0  0.4   0.4   0.4  

The mean_ppd is the sample average posterior predictive distribution of the outcome variable (for details see help('summary.stanreg')).

MCMC diagnostics
                  mcse Rhat n_eff
(Intercept)       0.0  1.0   740 
make              0.0  1.0  1161 
address           0.0  1.0  1156 
all               0.0  1.0   907 
num3d             0.0  1.0   958 
our               0.0  1.0   973 
over              0.0  1.0  1022 
remove            0.0  1.0   909 
internet          0.0  1.0  1188 
order             0.0  1.0  1064 
mail              0.0  1.0  1022 
receive           0.0  1.0  1171 
will              0.0  1.0  1022 
people            0.0  1.0  1307 
report            0.0  1.0   935 
addresses         0.0  1.0  1115 
free              0.0  1.0   830 
business          0.0  1.0  1059 
email             0.0  1.0   895 
you               0.0  1.0  1026 
credit            0.0  1.0  1328 
your              0.0  1.0  1077 
font              0.0  1.0  1226 
num000            0.0  1.0  1124 
money             0.0  1.0   910 
hp                0.0  1.0  1161 
hpl               0.0  1.0  1422 
george            0.1  1.0   779 
num650            0.0  1.0  1017 
lab               0.0  1.0  1477 
labs              0.0  1.0  1317 
telnet            0.0  1.0  1002 
num857            0.0  1.0  1388 
data              0.0  1.0  1265 
num415            0.0  1.0  1427 
num85             0.0  1.0  1247 
technology        0.0  1.0  1021 
num1999           0.0  1.0   834 
parts             0.0  1.0  1227 
pm                0.0  1.0  1135 
direct            0.0  1.0   793 
cs                0.0  1.0  1293 
meeting           0.0  1.0  1547 
original          0.0  1.0  1805 
project           0.0  1.0  1432 
re                0.0  1.0  1027 
edu               0.0  1.0  1048 
table             0.0  1.0  1082 
conference        0.0  1.0  1492 
charSemicolon     0.0  1.0  1308 
charRoundbracket  0.0  1.0   906 
charSquarebracket 0.0  1.0  1286 
charExclamation   0.0  1.0  1070 
charDollar        0.0  1.0  1147 
charHash          0.0  1.0  1111 
capitalAve        0.0  1.0  1165 
capitalLong       0.0  1.0  1421 
capitalTotal      0.0  1.0  1312 
mean_PPD          0.0  1.0   909 
log-posterior     0.5  1.0   338 

For each parameter, mcse is Monte Carlo standard error, n_eff is a crude measure of effective sample size, and Rhat is the potential scale reduction factor on split chains (at convergence Rhat=1).
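Rather than scanning this table, you can also pull the diagnostics out programmatically; stanreg objects work with bayesplot's rhat and neff_ratio helpers (a minimal sketch):

rhats = rhat(fit1)              # potential scale reduction factors
any(rhats > 1.05, na.rm = TRUE) # flag any parameter that may not have converged
mcmc_rhat(rhats)                # visual summary of the same information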
prior_summary(fit1)
Priors for model 'fit1' 
------
Intercept (after predictors centered)
 ~ normal(location = 0, scale = 10)

Coefficients
 ~ laplace(location = [0,0,0,...], scale = [0.5,0.5,0.5,...])
------
See help('prior_summary.stanreg') for more details
betaNames = names(spam_train)[2:7]
betaNames
[1] "make"    "address" "all"     "num3d"   "our"     "over"   
mcmc_trace(fit1, pars = betaNames)

mcmc_hist(fit1, pars = c(betaNames))

To plot specific parameters, use the argument pars, e.g.

  • mcmc_trace(fit1, pars = c("internet", "george"))
  • mcmc_hist(fit1, pars = "make")

To read more about bayesplot functionality, see https://mc-stan.org/bayesplot/articles/plotting-mcmc-draws.html
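Other bayesplot geoms follow the same pattern. For example, posterior density areas with shaded 90% credible intervals (the parameters shown here are an arbitrary subset):

mcmc_areas(fit1, pars = c("free", "charDollar", "george"), prob = 0.9)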

chain_draws = as_draws(fit1)
chain_draws$george[1:5] # first 5 samples of the first chain run by stan
[1]  -8.657706  -8.058416  -7.542682  -5.875524 -10.805427
  • try the following command: View(chain_draws)

Report the posterior mean, posterior median, and a 90% posterior credible interval for each coefficient.

posteriorMean = apply(chain_draws, 2, mean)        # column-wise average of the posterior draws
posteriorMedian = fit1$coefficients                # rstanarm point estimates are posterior medians
posteriorCI = posterior_interval(fit1, prob = 0.9) # central 90% credible intervals
cbind(posteriorMean, posteriorMedian, posteriorCI)
                  posteriorMean posteriorMedian            5%          95%
(Intercept)         -3.92819676     -3.89842329  -4.688486640 -3.214996079
make                -0.10511093     -0.10291579  -0.234925630  0.007473911
address             -0.23994990     -0.22496984  -0.459838123 -0.079085239
all                  0.06181680      0.06204777  -0.032348802  0.162371165
num3d                0.86239227      0.69068831   0.094477008  2.262920466
our                  0.36326097      0.36517847   0.252660614  0.482222431
over                 0.22388788      0.22526813   0.104903413  0.345376589
remove               0.89046939      0.88519560   0.660560624  1.124715173
internet             0.19219955      0.18873972   0.084482068  0.313373050
order                0.14088547      0.14175852   0.022540540  0.262790934
mail                 0.07091792      0.07049483  -0.004812217  0.148380741
receive             -0.04641076     -0.04704090  -0.155648064  0.056349226
will                -0.19669579     -0.19506749  -0.323302246 -0.079179801
people              -0.05080340     -0.04939546  -0.170276558  0.067469609
report              -0.01152332     -0.01105396  -0.102076144  0.078218538
addresses            0.27381609      0.26410052   0.016663808  0.554382803
free                 0.95909363      0.95496752   0.738679361  1.184988932
business             0.46684266      0.46497563   0.276894271  0.663670058
email                0.10488455      0.10024984   0.001909845  0.222419640
you                  0.10146793      0.10210904  -0.005066755  0.212057793
credit               0.60260737      0.58770051   0.245466733  1.011651069
your                 0.29510195      0.29496631   0.185160138  0.408134909
font                 0.23712284      0.22816521   0.014405002  0.514370136
num000               0.71052738      0.69647322   0.470113449  0.973663603
money                0.20621767      0.19864778   0.101091758  0.340661983
hp                  -3.12311740     -3.10880324  -4.016671505 -2.304686191
hpl                 -1.03485238     -1.03298227  -1.697441864 -0.401214421
george              -8.54020532     -8.47338329 -11.651960267 -5.719147532
num650               0.23079106      0.22722517   0.066179720  0.417323354
lab                 -1.06217022     -0.98480464  -2.007243002 -0.334531315
labs                -0.12552567     -0.11162875  -0.398538494  0.091786582
telnet              -0.34336793     -0.26439840  -0.971412117  0.026123689
num857              -0.24559769     -0.15437261  -1.410789083  0.608907729
data                -0.73133589     -0.71699338  -1.108487928 -0.382954624
num415              -0.64101304     -0.47534569  -2.093090996  0.285482781
num85               -0.79569146     -0.76847847  -1.439452875 -0.215331103
technology           0.39271940      0.38676108   0.173703662  0.625821033
num1999             -0.01626319     -0.01501346  -0.146351106  0.113311375
parts               -0.19387536     -0.17594695  -0.420298806 -0.034525634
pm                  -0.36307678     -0.35286882  -0.654982493 -0.099008434
direct              -0.16079193     -0.14842066  -0.405108058  0.023707001
cs                  -1.35347826     -1.21771081  -2.748096259 -0.377142268
meeting             -1.54394154     -1.48634297  -2.451147222 -0.853955898
original            -0.44492327     -0.43068267  -0.889558420 -0.068304526
project             -1.27091449     -1.25700373  -1.901896255 -0.691331691
re                  -0.68897480     -0.68667471  -0.941480034 -0.439694059
edu                 -1.52342346     -1.51146322  -2.056088445 -1.041884151
table               -0.27888780     -0.25736694  -0.549891396 -0.077155657
conference          -1.10733935     -1.06265166  -1.864536810 -0.460671052
charSemicolon       -0.32588657     -0.31416332  -0.514101364 -0.166688438
charRoundbracket    -0.08698799     -0.08292353  -0.212643020  0.023441788
charSquarebracket   -0.14980893     -0.13280261  -0.342135625 -0.004178006
charExclamation      0.23238206      0.22752584   0.137507390  0.351658871
charDollar           1.24928427      1.25351495   0.944702156  1.558125382
charHash             0.71795213      0.70213467   0.122922001  1.368832748
capitalAve          -0.01033156     -0.03262611  -0.401092859  0.470713740
capitalLong          0.85197631      0.82214223   0.303300262  1.493571408
capitalTotal         0.67337732      0.67245181   0.471182694  0.893198052

Exercise 3

  • Test your spam filter on the spam_test data set.

  • Make a table showing correct and incorrect number of classifications.

Solutions

Exercise 1 solution

fit1 = stan_glm(type ~ ., data = spam_train,
                family = binomial(link = "logit"), # logistic regression
                prior = laplace(0, 0.5),           # Laplace(0, 0.5) prior on each coefficient
                prior_intercept = normal(0, 10),   # normal prior with sd 10 on the intercept
                cores = 2, seed = 360,
                chains = 1, iter = 2000)           # 1 chain; first 1000 iterations are warm-up

Exercise 2 Solution

  • Trace plots look good; you can look through more parameters by subsetting pars.

  • Effective sample sizes (n_eff) are high.

# parameters whose 90% credible interval covers 0
which(sign(posteriorCI[,1]) != sign(posteriorCI[,2]))
            make              all             mail          receive 
               2                4               11               12 
          people           report              you             labs 
              14               15               20               31 
          telnet           num857           num415          num1999 
              32               33               35               38 
          direct charRoundbracket       capitalAve 
              41               51               56 
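The same check can be made visually with mcmc_intervals, which plots posterior medians with credible intervals; intervals crossing the vertical line at 0 correspond to the parameters listed above (subset chosen arbitrarily):

mcmc_intervals(fit1, pars = c("make", "mail", "receive", "you", "capitalAve"), prob_outer = 0.9)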

Exercise 3 Solution

data.frame(y = spam_test$type, 
           yhat = predict(object = fit1, newdata = spam_test[,-1], type = "response")) %>%
  mutate(yhat = ifelse(yhat >= 0.5, 1, 0)) %>% # classify as spam when P(spam) >= 0.5
  count(y, yhat)
  y yhat   n
1 0    0 529
2 0    1  26
3 1    0  31
4 1    1 335

864/921 classifications correct with a cutoff of 0.5.
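The accuracy follows directly from the table: (529 + 335) / 921 ≈ 0.938. Note that predict plugs in point estimates of the coefficients; a more fully Bayesian alternative averages over the posterior draws with posterior_predict (a sketch):

yrep = posterior_predict(fit1, newdata = spam_test[,-1]) # draws x test-emails matrix of 0/1
phat = colMeans(yrep)                                    # posterior predictive P(spam) per email
table(spam_test$type, ifelse(phat >= 0.5, 1, 0))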