Discriminant analysis examples

Raymaekers, J. and Rousseeuw, P.J.

2023-04-23

Introduction

This vignette visualizes classification results from discriminant analysis, using tools from the classmap package.

library("classmap")
library("ggplot2")
## Warning: package 'ggplot2' was built under R version 4.1.3
library("gridExtra")

Iris data

As a first small example, we consider the Iris data. We first load the data and inspect it.

data(iris)
X <- iris[, 1:4]
y <- iris[, 5]
is.factor(y)
## [1] TRUE
table(y)
## y
##     setosa versicolor  virginica 
##         50         50         50
pairs(X, col = as.numeric(y) + 1, pch = 19)

Now we carry out quadratic discriminant analysis and inspect the output. Note that we can also do linear discriminant analysis by choosing rule = “LDA”.

vcr.train <- vcr.da.train(X, y, rule = "QDA")
names(vcr.train)
##  [1] "yint"      "y"         "levels"    "predint"   "pred"      "altint"   
##  [7] "altlab"    "PAC"       "figparams" "fig"       "farness"   "ofarness" 
## [13] "classMS"   "lCurrent"  "lPred"     "lAlt"
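
For intuition, the sketch below computes QDA posterior probabilities directly in base R. It assumes Gaussian classes, each with its own sample mean and covariance matrix, and prior probabilities equal to the class proportions. The helper `qda_posterior()` is hypothetical and written only for this illustration; it is not a classmap function, and vcr.da.train() may estimate its parameters differently.

```r
# Hypothetical helper (not part of classmap): QDA posterior probabilities,
# assuming Gaussian classes, sample means/covariances, and priors equal to
# the class proportions.
qda_posterior <- function(X, y, Xnew = X) {
  X <- as.matrix(X)
  Xnew <- as.matrix(Xnew)
  y <- as.factor(y)
  classes <- levels(y)
  # One column per class: log(prior) + log Gaussian density, up to a constant
  logdens <- do.call(cbind, lapply(classes, function(g) {
    Xg <- X[y == g, , drop = FALSE]
    S <- cov(Xg)
    log(mean(y == g)) -
      0.5 * as.numeric(determinant(S, logarithm = TRUE)$modulus) -
      0.5 * mahalanobis(Xnew, center = colMeans(Xg), cov = S)
  }))
  colnames(logdens) <- classes
  # Subtract the row maximum before exponentiating, for numerical stability
  post <- exp(logdens - apply(logdens, 1, max))
  post / rowSums(post)
}

data(iris)
post <- qda_posterior(iris[, 1:4], iris$Species)
predlab <- colnames(post)[max.col(post, ties.method = "first")]
table(given = iris$Species, predicted = predlab)
```

The resulting table can be compared with the output of confmat.vcr() below; small differences are possible if classmap estimates the class parameters in another way.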

We now inspect the output in detail. First we look at the predicted class, as an integer and as a label, and then at the alternative class, as an integer and as a label. The labels are shown for the first 10 objects of each class:

vcr.train$predint 
##   [1] 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
##  [38] 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 2 2 2
##  [75] 2 2 2 2 2 2 2 2 2 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3
## [112] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3
## [149] 3 3
vcr.train$pred[c(1:10, 51:60, 101:110)]
##  [1] "setosa"     "setosa"     "setosa"     "setosa"     "setosa"    
##  [6] "setosa"     "setosa"     "setosa"     "setosa"     "setosa"    
## [11] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [16] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [21] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [26] "virginica"  "virginica"  "virginica"  "virginica"  "virginica"
vcr.train$altint  
##   [1] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
##  [38] 2 2 2 2 2 2 2 2 2 2 2 2 2 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3
##  [75] 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2
## [112] 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
## [149] 2 2
vcr.train$altlab[c(1:10, 51:60, 101:110)]
##  [1] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
##  [6] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [11] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [16] "virginica"  "virginica"  "virginica"  "virginica"  "virginica" 
## [21] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
## [26] "versicolor" "versicolor" "versicolor" "versicolor" "versicolor"
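
The relation between the predicted class and the alternative class can be illustrated on a small, made-up matrix of posterior probabilities. Consistent with the iris output above (setosa objects have alternative class versicolor, for instance), the sketch takes the alternative class of object i to be the most probable class other than its given class. The helper `pred_and_alt()` and the numbers are hypothetical.

```r
# Hypothetical helper: predicted class, and alternative class defined as the
# most probable class other than the given class yint[i].
pred_and_alt <- function(post, yint) {
  predint <- max.col(post, ties.method = "first")
  altint <- vapply(seq_len(nrow(post)), function(i) {
    p <- post[i, ]
    p[yint[i]] <- -Inf   # exclude the given class
    which.max(p)
  }, integer(1))
  list(predint = predint, altint = altint)
}

# Made-up posteriors for 3 objects and 3 classes (rows sum to 1)
post <- rbind(c(0.70, 0.20, 0.10),
              c(0.05, 0.15, 0.80),
              c(0.10, 0.60, 0.30))
yint <- c(1, 3, 3)   # given classes
res <- pred_and_alt(post, yint)
res$predint   # 1 3 2
res$altint    # 2 2 2
```

The third object is misclassified (given class 3, predicted class 2); for such objects the alternative class coincides with the predicted class.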

The Probability of Alternative Class (PAC) of each object is found in the $PAC element of the output:

vcr.train$PAC[1:3] 
## [1] 4.918517e-26 7.655808e-19 1.552279e-21
summary(vcr.train$PAC)
##      Min.   1st Qu.    Median      Mean   3rd Qu.      Max. 
## 0.0000000 0.0000000 0.0000081 0.0237098 0.0010938 0.8456517
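
Assuming PAC(i) is defined as the posterior probability of the alternative class divided by the sum of the posterior probabilities of the given class and the alternative class, it can be computed as in this minimal sketch. The helper `pac()` and the posterior matrix are hypothetical.

```r
# Hypothetical helper: PAC(i) = p(alt | x_i) / (p(given | x_i) + p(alt | x_i)),
# where alt is the most probable class other than the given class yint[i].
pac <- function(post, yint) {
  idx <- cbind(seq_len(nrow(post)), yint)
  p_given <- post[idx]
  p_other <- post
  p_other[idx] <- -Inf              # remove the given class
  p_alt <- apply(p_other, 1, max)   # probability of the alternative class
  p_alt / (p_given + p_alt)
}

post <- rbind(c(0.70, 0.20, 0.10),
              c(0.05, 0.15, 0.80),
              c(0.10, 0.60, 0.30))
yint <- c(1, 3, 3)
round(pac(post, yint), 3)   # 0.222 0.158 0.667
```

Under this definition a PAC above 0.5 means that the given class is not the most probable one, i.e. the object is misclassified, as for the third object here.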

The $fig element of the output contains f(i, g), the farness of case i from each class g, on a scale from 0 to 1. Let’s look at it for the first 5 objects:

vcr.train$fig[1:5, ]
##            [,1] [,2] [,3]
## [1,] 0.02675535    1    1
## [2,] 0.33639794    1    1
## [3,] 0.16134074    1    1
## [4,] 0.25293196    1    1
## [5,] 0.06600114    1    1
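
The exact way classmap converts distances into farness values is not shown in this vignette. As a simplified stand-in, assuming Gaussian classes, one can map the squared Mahalanobis distance of object i to class g onto [0, 1] with the chi-squared distribution function with p degrees of freedom, p being the number of variables. The helper `farness_chisq()` is hypothetical and is not expected to reproduce vcr.train$fig exactly.

```r
# Hypothetical stand-in for farness (NOT classmap's computation): the
# chi-squared CDF of the squared Mahalanobis distance to each class,
# which assumes Gaussian classes.
farness_chisq <- function(X, y, Xnew = X) {
  X <- as.matrix(X)
  Xnew <- as.matrix(Xnew)
  y <- as.factor(y)
  fig <- do.call(cbind, lapply(levels(y), function(g) {
    Xg <- X[y == g, , drop = FALSE]
    d2 <- mahalanobis(Xnew, center = colMeans(Xg), cov = cov(Xg))
    pchisq(d2, df = ncol(X))
  }))
  colnames(fig) <- levels(y)
  fig
}

data(iris)
fig <- farness_chisq(iris[, 1:4], iris$Species)
round(fig[1:5, ], 3)
```

A point at the center of class g gets farness 0 from that class, and points far from it get values close to 1.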

From fig, the farness of each object can be obtained: the farness of object i is f(i, g) for its own (given) class g:

vcr.train$farness[1:5]
## [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114
summary(vcr.train$farness)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0153  0.2396  0.5159  0.4996  0.7617  0.9862

The “overall farness” of an object is defined as the lowest f(i, g) it has to any class g (including its own):

summary(vcr.train$ofarness)
##    Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
##  0.0153  0.2396  0.5145  0.4957  0.7543  0.9862
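
Given the fig matrix and the given labels, farness and ofarness are simple row-wise operations. The sketch below uses the first five rows of vcr.train$fig printed above. The helper `farness_ofarness()` is hypothetical; it returns NA farness for an object without a given label (as happens for the new data with missing labels later in this vignette), while ofarness remains defined.

```r
# Hypothetical helper: farness = fig entry of the given class (NA if the
# label is missing); ofarness = smallest fig entry over all classes.
farness_ofarness <- function(fig, yint) {
  farness <- rep(NA_real_, nrow(fig))
  known <- which(!is.na(yint))
  farness[known] <- fig[cbind(known, yint[known])]
  list(farness = farness, ofarness = apply(fig, 1, min))
}

# First five rows of vcr.train$fig as printed above
fig <- cbind(c(0.02675535, 0.33639794, 0.16134074, 0.25293196, 0.06600114),
             1, 1)
yint <- c(1, 1, NA, 1, 1)   # pretend the label of object 3 is missing
farness_ofarness(fig, yint)
```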

Objects with ofarness > cutoff are flagged as “outliers”. These can be included in a separate column in the confusion matrix. This confusion matrix can be computed using confmat.vcr, which also returns the accuracy.

To illustrate this we choose a rather low cutoff:

confmat.vcr(vcr.train, cutoff = 0.98)
## 
## Confusion matrix:
##             predicted
## given        setosa versicolor virginica outl
##   setosa         48          0         0    2
##   versicolor      0         48         2    0
##   virginica       0          1        48    1
## 
## The accuracy is 98%.

With the default cutoff = 0.99 no objects are flagged in this example:

confmat.vcr(vcr.train)
## 
## Confusion matrix:
##             predicted
## given        setosa versicolor virginica
##   setosa         50          0         0
##   versicolor      0         48         2
##   virginica       0          1        49
## 
## The accuracy is 98%.

Note that the accuracy is computed before any objects are flagged, so it does not depend on the cutoff.
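
The logic of a confusion matrix with an outlier column can be sketched in a few lines of base R. In the hypothetical helper `confmat_sketch()`, the accuracy is computed from all predictions before any flagging, and an object is moved to the "outl" column when its ofarness exceeds the cutoff. Unlike confmat.vcr(), this sketch always shows the outl column.

```r
# Hypothetical helper (not classmap's confmat.vcr): confusion matrix with an
# extra "outl" column for objects whose ofarness exceeds the cutoff.
confmat_sketch <- function(y, pred, ofarness, cutoff = 0.99) {
  y <- factor(y)
  pred <- factor(pred, levels = levels(y))
  accuracy <- mean(pred == y)          # computed before flagging outliers
  shown <- as.character(pred)
  shown[ofarness > cutoff] <- "outl"   # flagged objects go to the outl column
  confmat <- table(given = y,
                   predicted = factor(shown, levels = c(levels(y), "outl")))
  list(confmat = confmat, accuracy = accuracy)
}

res <- confmat_sketch(y = c("a", "a", "b", "b"),
                      pred = c("a", "b", "b", "b"),
                      ofarness = c(0.50, 0.20, 0.995, 0.10))
res$confmat
res$accuracy   # 0.75
```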

The confusion matrix can also be constructed showing class numbers instead of labels. This option can be useful for long level names.

confmat.vcr(vcr.train, showClassNumbers = TRUE)
## 
## Confusion matrix:
##      predicted
## given  1  2  3
##     1 50  0  0
##     2  0 48  2
##     3  0  1 49
## 
## The accuracy is 98%.

A stacked mosaic plot made with the stackedplot() function can be used to visualize the confusion matrix. The outliers, if there are any, appear as grey areas on top.

cols <- c("red", "darkgreen", "blue")
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE)
stackedplot(vcr.train, classCols = cols, separSize = 1.5,
            minSize = 1, showLegend = TRUE, cutoff = 0.98)

The default stacked mosaic plot has no legend:

stplot <- stackedplot(vcr.train, classCols = cols, 
                     separSize = 1.5, minSize = 1,
                     main = "QDA on iris data")
stplot
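
The quantities behind a stacked mosaic plot can be obtained from the confusion matrix: each given class gets a bar whose width is proportional to its size, split according to the proportions of its predicted classes. The sketch below uses the iris confusion matrix shown above and base R's barplot() as a rough stand-in for stackedplot(), which draws a more refined version; outliers (the grey areas) are ignored here.

```r
# Confusion matrix of QDA on the iris training data, copied from above
cm <- matrix(c(50,  0,  0,
                0, 48,  2,
                0,  1, 49), nrow = 3, byrow = TRUE,
             dimnames = list(given = c("setosa", "versicolor", "virginica"),
                             predicted = c("setosa", "versicolor", "virginica")))
widths  <- rowSums(cm) / sum(cm)        # relative size of each given class
heights <- prop.table(cm, margin = 1)   # predicted-class proportions per given class
# One stacked bar per given class, colored by predicted class
barplot(t(heights), width = widths, space = 0.05,
        col = c("red", "darkgreen", "blue"),
        xlab = "given class", ylab = "proportion predicted")
```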

We also make the silhouette plot using the silplot() function:

# pdf("Iris_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols, 
        main = "Silhouette plot of QDA on iris data")      
##  classNumber classLabel classSize classAveSi
##            1     setosa        50       1.00
##            2 versicolor        50       0.91
##            3  virginica        50       0.95
# dev.off()
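
Assuming that silplot() uses the silhouette width s(i) = 1 - 2 PAC(i), the silhouette values and the class averages in the classAveSi column can be sketched as follows. This assumption agrees with the output above: with equal class sizes, the mean PAC of 0.0237 from summary(vcr.train$PAC) gives an average silhouette width of about 1 - 2 x 0.0237 = 0.953, in line with the class averages 1.00, 0.91 and 0.95. The helper `silhouette_sketch()` and its input values are hypothetical.

```r
# Hypothetical helper: silhouette widths s(i) = 1 - 2 * PAC(i) and their
# average per given class. Negative values correspond to PAC > 0.5,
# i.e. misclassified objects.
silhouette_sketch <- function(PAC, y) {
  s <- 1 - 2 * PAC
  list(s = s, classAveSi = tapply(s, y, mean))
}

res <- silhouette_sketch(PAC = c(0, 0.1, 0.6, 0.2),
                         y = c("a", "a", "b", "b"))
res$s            # 1.0  0.8 -0.2  0.6
res$classAveSi   # a: 0.9, b: 0.2
```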

We now make the class maps based on the vcr object. This can be done using the classmap() function. We make a separate class map for each of the three classes. We see that class 1 is a very tight class (low PAC, no high farness). Class 2 is not so tight, and has two points which are predicted as virginica. Class 3 has one point predicted as versicolor.

classmap(vcr.train, 1, classCols = cols)
classmap(vcr.train, 2, classCols = cols)
classmap(vcr.train, 3, classCols = cols) # With the default cutoff no farness values stand out:
# With a lower cutoff:
classmap(vcr.train, 3, classCols = cols, cutoff = 0.98)
# Now one point is to the right of the vertical line.
# It also has a black border, meaning that it is flagged
# as an outlier, in the sense that its farness to _all_
# classes is above 0.98.

To illustrate the use of new data, we create a fake dataset which is a subset of the training data in which not all classes occur, and where some entries of ynew are NA.

Xnew <- X[c(1:50, 101:150), ]
ynew <- y[c(1:50, 101:150)]
ynew[c(1:10, 51:60)] <- NA
pairs(X, col = as.numeric(y) + 1, pch = 19) # 3 colors
pairs(Xnew, col = as.numeric(ynew) + 1, pch = 19) # only red and blue

Now we build the vcr object for the new data, using the classifier trained on the training data.

vcr.test <- vcr.da.newdata(Xnew, ynew, vcr.train)

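Inspect some of the output to confirm that it corresponds with what we would expect. Since the new data are a subset of the training data, their output should coincide with that of the corresponding training objects, so the points should lie on the identity line. The farness is NA for the objects whose label is missing, whereas the ofarness does not depend on the label: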

plot(vcr.test$predint, vcr.train$predint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$altint, vcr.train$altint[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$PAC, vcr.train$PAC[c(1:50, 101:150)]); abline(0, 1)
vcr.test$farness 
##   [1]         NA         NA         NA         NA         NA         NA
##   [7]         NA         NA         NA         NA 0.29421328 0.32178116
##  [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
##  [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
##  [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
##  [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
##  [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
##  [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
##  [49] 0.15482827 0.03145084         NA         NA         NA         NA
##  [55]         NA         NA         NA         NA         NA         NA
##  [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
##  [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
##  [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
##  [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
##  [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
##  [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
##  [97] 0.63534800 0.11247730 0.62691061 0.42737442
plot(vcr.test$farness, vcr.train$farness[c(1:50, 101:150)]); abline(0, 1)
plot(vcr.test$fig, vcr.train$fig[c(1:50, 101:150), ]); abline(0, 1)
vcr.test$ofarness 
##   [1] 0.02675535 0.33639794 0.16134074 0.25293196 0.06600114 0.63210603
##   [7] 0.59041424 0.01732745 0.52024594 0.55494759 0.29421328 0.32178116
##  [13] 0.51150351 0.89366298 0.96650511 0.91516067 0.82782724 0.04831270
##  [19] 0.78801603 0.23207165 0.80057706 0.46966655 0.97498614 0.90086633
##  [25] 0.96031572 0.63996116 0.43078605 0.07648892 0.16940167 0.35680444
##  [31] 0.31726541 0.76311424 0.91656430 0.79289643 0.15775639 0.57139387
##  [37] 0.82646629 0.53574220 0.56635148 0.04234576 0.24816964 0.98405396
##  [43] 0.69347810 0.98395599 0.93996829 0.36120046 0.47605905 0.20490701
##  [49] 0.15482827 0.03145084 0.93068594 0.26632690 0.06831897 0.31105631
##  [55] 0.20258388 0.65479346 0.90867003 0.58975324 0.56295693 0.68728265
##  [61] 0.44870413 0.07632847 0.07073258 0.58646621 0.79167886 0.26259281
##  [67] 0.11731858 0.94491559 0.98620090 0.84267322 0.14000705 0.41260605
##  [73] 0.85509019 0.51571738 0.14623030 0.51605016 0.53499547 0.40202972
##  [79] 0.10552473 0.74488824 0.53885697 0.96493110 0.23632013 0.66887310
##  [85] 0.93026914 0.83628519 0.68987941 0.32345397 0.49557670 0.35256862
##  [91] 0.32403877 0.91811245 0.26632690 0.15321938 0.50795409 0.69766299
##  [97] 0.63534800 0.11247730 0.62691061 0.42737442
plot(vcr.test$ofarness, vcr.train$ofarness[c(1:50, 101:150)]); abline(0, 1)
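
As its third argument suggests, vcr.da.newdata() reuses the classifier fitted by vcr.da.train() rather than refitting it on the new data. The sketch below mimics this idea for plain Gaussian QDA with two hypothetical helpers, `qda_fit()` and `qda_predict()` (not classmap functions): the class means, covariance matrices and priors are estimated once, and scoring a subset of the training rows reproduces the posteriors of those rows. This is the reason to expect the points in the plots above to lie on the identity line. Posterior probabilities do not involve the labels, so predictions are available for every new object, while label-dependent quantities such as the farness become NA when ynew is NA.

```r
# Hypothetical helpers (not classmap functions): fit Gaussian QDA once ...
qda_fit <- function(X, y) {
  X <- as.matrix(X)
  y <- as.factor(y)
  lapply(setNames(levels(y), levels(y)), function(g) {
    Xg <- X[y == g, , drop = FALSE]
    list(mu = colMeans(Xg), S = cov(Xg), prior = mean(y == g))
  })
}
# ... and score any new rows with the stored parameters (no refitting).
qda_predict <- function(fit, Xnew) {
  Xnew <- as.matrix(Xnew)
  logdens <- do.call(cbind, lapply(fit, function(cl) {
    log(cl$prior) -
      0.5 * as.numeric(determinant(cl$S, logarithm = TRUE)$modulus) -
      0.5 * mahalanobis(Xnew, center = cl$mu, cov = cl$S)
  }))
  post <- exp(logdens - apply(logdens, 1, max))
  post / rowSums(post)
}

data(iris)
fit <- qda_fit(iris[, 1:4], iris$Species)
rows <- c(1:50, 101:150)
post_all <- qda_predict(fit, iris[, 1:4])
post_new <- qda_predict(fit, iris[rows, 1:4])
all.equal(post_new, post_all[rows, ])   # TRUE
```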

The confusion matrix for the test data, as for the training data, can be constructed by the confmat.vcr() function. A cutoff of 0.98 flags three outliers in this example.

confmat.vcr(vcr.test)
## 
## Confusion matrix:
##            predicted
## given       setosa versicolor virginica
##   setosa        40          0         0
##   virginica      0          1        39
## 
## The accuracy is 98.75%.
confmat.vcr(vcr.test, cutoff = 0.98)
## 
## Confusion matrix:
##            predicted
## given       setosa versicolor virginica outl
##   setosa        38          0         0    2
##   virginica      0          1        38    1
## 
## The accuracy is 98.75%.

The stacked mosaic plot can also be constructed for the test data:

stplot # to compare with:
stackedplot(vcr.test, classCols = cols, separSize = 1.5, minSize = 1)
## 
## Not all classes occur in these data. The classes to plot are:
## [1] 1 3

We now make the silhouette plot on the test data:

#pdf("Iris_test_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.test, classCols = cols, 
        main = "Silhouette plot of QDA on iris subset") 
##  classNumber classLabel classSize classAveSi
##            1     setosa        40       1.00
##            3  virginica        40       0.94
#dev.off()

Finally, we construct the class maps for the test data. We compare the class map of the training data with that of the test data for each class.

classmap(vcr.train, 1, classCols = cols)
classmap(vcr.test, 1, classCols = cols) 
classmap(vcr.train, 2, classCols = cols)
classmap(vcr.test, 2, classCols = cols) # fails: the new data contain no versicolor
## Error in classmap(vcr.test, 2, classCols = cols): Class number 2 with label versicolor has no objects to visualize.
classmap(vcr.train, 3, classCols = cols)
classmap(vcr.test, 3, classCols = cols) 

Floral buds data (shown in paper)

We now analyze the floral buds data, which was also used as an illustration in the paper. First load and inspect the data.

data(data_floralbuds)
X <- as.matrix(data_floralbuds[, 1:6])
y <- data_floralbuds$y
dim(X) # 550  6
## [1] 550   6
length(y) # 550
## [1] 550
table(y)
## y
##  branch     bud  scales support 
##      49     363      94      44
# branch     bud  scales support 
#     49     363      94      44 

# Pairs plot
cols <- c("saddlebrown", "orange", "olivedrab4", "royalblue3")
pairs(X, gap = 0, col = cols[as.numeric(y)]) # hard to separate visually

Now we perform quadratic discriminant analysis:

vcr.obj <- vcr.da.train(X, y)

Construct the confusion matrix without and with outliers shown:

confmat <- confmat.vcr(vcr.obj, showOutliers = FALSE)
## 
## Confusion matrix:
##          predicted
## given     branch bud scales support
##   branch      45   1      1       2
##   bud          0 358      1       4
##   scales       2   0     90       2
##   support      6   3      0      35
## 
## The accuracy is 96%.
confmat.vcr(vcr.obj) 
## 
## Confusion matrix:
##          predicted
## given     branch bud scales support outl
##   branch      45   1      1       2    0
##   bud          0 353      1       4    5
##   scales       2   0     86       2    4
##   support      6   3      0      35    0
## 
## The accuracy is 96%.

Construct the stacked mosaic plot:

stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
            minSize = 1.5,  main = "stacked plot of QDA on floral buds")
# Version in paper:
# pdf("Floralbuds_QDA_stackplot_without_outliers.pdf",
#     width=5, height=4.3)
# stackedplot(vcr.obj, classCols = cols, separSize = 0.6,
#             minSize = 1.5, showOutliers = FALSE,
#             htitle = "given class", vtitle = "predicted class")
# dev.off()

Now make the silhouette plot:

#pdf("Floralbuds_QDA_silhouettes.pdf", width=5.0, height=4.3)
silplot(vcr.obj, classCols = cols,
        main = "Silhouette plot of QDA on floral bud data")      
##  classNumber classLabel classSize classAveSi
##            1     branch        49       0.75
##            2        bud       363       0.96
##            3     scales        94       0.93
##            4    support        44       0.57
#dev.off()

The quasi residual plot can be made with the qresplot() function. We illustrate this below by plotting the PAC against the sum of the variables. A Spearman correlation test confirms that the images with higher sums are significantly easier to classify:

PAC <- vcr.obj$PAC
feat <- rowSums(X); xlab = "rowSums(X)"
# pdf("Floralbuds_QDA_quasi_residual_plot.pdf", width=5, height=4.8)
qresplot(PAC, feat, xlab = xlab, plotErrorBars = TRUE, fac = 2, 
         main = "Floral buds: quasi residual plot")
# dev.off()

cor.test(feat, PAC, method = "spearman") 
## 
## 	Spearman's rank correlation rho
## 
## data:  feat and PAC
## S = 39255896, p-value < 2.2e-16
## alternative hypothesis: true rho is not equal to 0
## sample estimates:
##        rho 
## -0.4156944
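
The idea behind the quasi residual plot can be sketched without qresplot(): group the objects into bins of the feature and look at the average PAC per bin, with error bars of plus or minus fac standard errors. The quantile binning, the helper `binned_pac()` and the simulated data below are assumptions made for illustration; qresplot() may bin the data differently.

```r
# Hypothetical helper (not qresplot itself): mean PAC per quantile bin of a
# feature, with error bars of +/- fac standard errors.
binned_pac <- function(PAC, feat, nbins = 5, fac = 2) {
  breaks <- unique(quantile(feat, probs = seq(0, 1, length.out = nbins + 1)))
  bin <- cut(feat, breaks = breaks, include.lowest = TRUE)
  avg <- tapply(PAC, bin, mean)
  se  <- tapply(PAC, bin, function(p) sd(p) / sqrt(length(p)))
  data.frame(bin = names(avg), mean = as.numeric(avg),
             lower = as.numeric(avg - fac * se),
             upper = as.numeric(avg + fac * se))
}

# Simulated data in which objects with a larger feature are easier to classify
set.seed(1)
feat_sim <- runif(200)
pac_sim <- pmin(1, pmax(0, 0.5 - 0.4 * feat_sim + rnorm(200, sd = 0.05)))
binned_pac(pac_sim, feat_sim)
cor(feat_sim, pac_sim, method = "spearman")   # negative, as for the floral buds
```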

Construct the class maps, as shown in the paper:

labels <- c("branch", "bud", "scale", "support")

# classmap of class "bud"
#
# To identify the points that stand out:
# classmap(vcr.obj, 2, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf("Floralbuds_QDA_classmap_bud.pdf", width=7, height=7)
oldpar <- par(no.readonly = TRUE) # save settings, restored by par(oldpar) below
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds",
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
         cex.main = 1.5) 
# For marking points:
indstomark <- c(294, 70, 69, 152, 204) # from identify = TRUE above
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, +0.04)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("topleft", fill = cols[1:4], legend = labels, 
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)

All class maps:

#
# pdf(file = "Floralbuds_all_class_maps.pdf", width = 7, height = 7)
oldpar <- par(no.readonly = TRUE) # save settings, restored by par(oldpar) below
par(mfrow = c(2, 2))
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 1, classCols = cols,
         main = "predictions of branches")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 2, classCols = cols,
         main = "predictions of buds")
labs  <- letters[seq_len(5)]
xvals <- coords[indstomark, 1] +
  c(0, 0.10, 0.14, 0.10, 0.08) # visual finetuning
yvals <- coords[indstomark, 2] +
  c(0.04, 0.04, 0, -0.03, 0.04)
# xvals <- c( 1.75, 1.68, 1.25, 3.25, 4.00)
# yvals <- c(0.045, 0.92, 0.54, 0.97, 0.045)
text(x = xvals, y = yvals, labels = labs, cex = 1.0)
legend("topleft", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
#
par(mar = c(3.3, 3.2, 2.7, 1.0))
classmap(vcr.obj, 3, classCols = cols,
         main = "predictions of scales")
legend("left", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# 
par(mar = c(3.3, 0.5, 2.7, 0.3))
classmap(vcr.obj, 4, classCols = cols,
         main = "predictions of supports")
legend("topright", fill = cols, legend = labels,
       cex = 1, ncol = 1, bg = "white")
# dev.off()
par(oldpar)

MNIST data

We now analyze the MNIST data, originally from the website of Yann LeCun. As the link on his website is currently down, we use a different source. Note that downloading the data may take a minute or two, depending on the speed of the internet connection.

mnist_url <- "https://wis.kuleuven.be/statdatascience/robust/data/mnist-rdata"
# Check whether the data source can be reached, then close the test connection:
con <- url(mnist_url)
url.exists <- !inherits(suppressWarnings(try(open(con, open = "rb"), silent = TRUE)),
                        "try-error")
close(con)

if (url.exists) {
  load(url(mnist_url)) # creates the list 'mnist' used below
} else {
  print(paste("The data source", mnist_url, "is not active at the moment. The example can nevertheless be reproduced by downloading the mnist data from another source, formatting the training data to dimensions 60000 x 28 x 28, and running the code below."))
}
X_train <- mnist$train$x
y_train <- as.factor(mnist$train$y)

head(y_train)
## [1] 5 0 4 1 9 2
## Levels: 0 1 2 3 4 5 6 7 8 9
# Levels: 0 1 2 3 4 5 6 7 8 9
dim(X_train) # 60000    28    28
## [1] 60000    28    28
length(y_train) # 60000
## [1] 60000

We now inspect the data by plotting a few images:

plotImage = function(tempImage) {
  tdm = reshape2::melt(apply((tempImage), 2, rev))
  p = ggplot(tdm, aes(x = Var2, y = Var1, fill = (value))) +
    geom_raster() +
    guides(color = "none", size = "none", fill = "none") +
    theme(axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks.x = element_blank(),
          axis.ticks.y = element_blank()) +
    scale_fill_gradient(low = "white", high = "black")
  p
}

plotImage(X_train[1, , ])
plotImage(X_train[2, , ])
plotImage(X_train[3, , ])

We now unfold the array containing the data to a matrix, and inspect some sample images as well as the average image per digit:

# Change the dimensions of X for the sequel:
dim(X_train) <- c(60000, 28 * 28)
dim(X_train) # 60000    784
## [1] 60000   784
# Sampled digit images:
set.seed(123)
sampledigits <- list()
for (i in 0:9) {
  digit <- i
  idx <- sample(which(y_train == digit), size = 1)
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  sampledigits[[i + 1]] <- plotImage(tempImage) 
}
psampledigits <- grid.arrange(grobs = sampledigits, ncol = 5)
# ggsave("MNIST_sampled_images.pdf", plot = psampledigits,
#        width = 10, height = 1)


# Averaged digit images:
meanPlots <- list()
for (j in 0:9) {
  m.out <- colMeans(X_train[which(y_train == j), ])
  dim(m.out) <- c(28, 28)
  meanPlots[[j + 1]] <- plotImage(m.out) 
}
meanplot <- grid.arrange(grobs = meanPlots, ncol = 5)
# ggsave("MNIST_averaged_images.pdf", plot = meanplot,
#        width = 10, height = 1)
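
Setting dim() on the 60000 x 28 x 28 array keeps R's column-major storage order, so row i of the unfolded matrix holds the pixels of image i column by column, and matrix(X_train[idx, ], 28, 28) folds it back, as done in the loops above. A toy check with a 2 x 3 x 3 array:

```r
# Toy array: 2 "images" of 3 x 3 pixels, stored like X_train (image index first)
A <- array(1:18, dim = c(2, 3, 3))
img2 <- A[2, , ]           # second image as a 3 x 3 matrix
M <- A
dim(M) <- c(2, 3 * 3)      # unfold: one row per image, 9 pixel columns
identical(M[2, ], as.vector(img2))      # TRUE: pixels stored column by column
identical(matrix(M[2, ], 3, 3), img2)   # TRUE: folding back recovers the image
```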

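Before performing discriminant analysis, we reduce the dimension of the data by a PCA-type projection: we compute the first 50 right singular vectors of X_train (without centering the data) and project the data onto them.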

library(svd)
## Warning: package 'svd' was built under R version 4.1.3
ptm <- proc.time()
svd.out <- svd::propack.svd(X_train, neig = 50)
(proc.time() - ptm)[3]
## elapsed 
##   22.12
loadings <- svd.out$v
rm(svd.out)
dataProj <- as.matrix(X_train %*% loadings)
dim(dataProj)
## [1] 60000    50
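
The projection step can be checked on a small random matrix with base R's svd(), used here in place of svd::propack.svd(), which computes only the leading singular triplets. Projecting onto the first k right singular vectors gives the same scores as U_k D_k, and new data (such as the MNIST test set further below) are projected with the loadings from the training data rather than with their own. Like the code above, this sketch does not center the data; centering the columns first, for instance with scale(Z, scale = FALSE), would give classical PCA scores.

```r
set.seed(1)
Z <- matrix(rnorm(100 * 10), nrow = 100, ncol = 10)
k <- 3
s <- svd(Z, nu = k, nv = k)   # leading k singular vectors
V_k <- s$v                    # 10 x k, plays the role of 'loadings' above
scores_k <- Z %*% V_k         # 100 x k projected data
all.equal(scores_k, s$u %*% diag(s$d[1:k]))   # TRUE: Z V_k = U_k D_k
Znew <- matrix(rnorm(5 * 10), nrow = 5, ncol = 10)
scoresNew <- Znew %*% V_k     # new data reuse the training loadings
dim(scoresNew)                # 5 3
```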

Now we perform discriminant analysis, which takes roughly 5 seconds.

vcr.train <- vcr.da.train(X = dataProj, y_train)

We compute the confusion matrix and make the stacked mosaic plot:

confmat.vcr(vcr.train, showOutliers = FALSE)
## 
## Confusion matrix:
##      predicted
## given    0    1    2    3    4    5    6    7    8    9
##     0 5833    0   22    6    1   14    2    0   42    3
##     1    0 6436  104   14   32    0    2   13  138    3
##     2   14    1 5807   26   14    0    9   12   69    6
##     3    3    1   88 5821    4   52    0   18  120   24
##     4    6    1   21    3 5704    1   12   14   30   50
##     5   14    0    4   71    2 5222   17    0   78   13
##     6   27    2    6    2    8  114 5703    0   56    0
##     7   13    8   94   14   34   14    0 5936   54   98
##     8   10   24   40   72    8   40    2    4 5625   26
##     9   17    2   23   65   59   14    1   77   93 5598
## 
## The accuracy is 96.14%.
cols <- c("red3", "darkorange", "gold2", "darkolivegreen3",
         "darkolivegreen4", "cadetblue3", "deepskyblue4", 
         "darkslateblue", "darkorchid3", "deeppink4")

# stacked plot in paper:
# pdf("MNIST_stackplot_with_outliers.pdf", width=5, height=4.3)
stackedplot(vcr.train, classCols = cols, separSize = 0.6,
            minSize = 1.5, htitle = "given class",
            main = "Stacked plot of QDA on MNIST training data", vtitle = "predicted class")
# dev.off()

The silhouette plot:

# pdf("MNIST_QDA_silhouettes.pdf", width=5.0, height=4.6)
silplot(vcr.train, classCols = cols,
        main = "Silhouette plot of QDA on MNIST training data")      
##  classNumber classLabel classSize classAveSi
##            1          0      5923       0.97
##            2          1      6742       0.91
##            3          2      5958       0.95
##            4          3      6131       0.90
##            5          4      5842       0.95
##            6          5      5421       0.92
##            7          6      5918       0.93
##            8          7      6265       0.89
##            9          8      5851       0.92
##           10          9      5949       0.88
# dev.off()

Now we make the class maps.

wnq <- function(string, qwrite=TRUE) { # auxiliary function
  # writes a line without quotes
  if (qwrite) write(noquote(string), file = "", ncolumns = 100)
}

showdigit <- function(digit, i, plotIt = TRUE) {
  # i indexes the training images of this digit, not all training images
  idx <- which(y_train == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.train$pred[idx]), sep=""))
  tempImage <- matrix(unlist(X_train[idx, ]), 28, 28)
  p <- plotImage(tempImage)
  if (plotIt) plot(p)
  p
}

Class map of digit 0, shown in paper:

digit <- 0
#
# To identify outliers:
# classmap(vcr.train, digit+1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
oldpar <- par(no.readonly = TRUE) # save settings, restored by par(oldpar) below
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ",digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
indstomark <- c(4000, 3964, 5891, 2485, 822, 
               2280, 2504, 3906, 5869, 1034) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.04, -0.01, 0, -0.11, 0.06,
    0.07, 0.06, 0.10, 0.06, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, 0.022, -0.025, 
    -0.025, -0.035, -0.025, 0.03, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, plotIt = FALSE)
  tempplot <- arrangeGrob(tempplot, 
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] = tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))

Class map of digit 1, shown in paper:

digit <- 1
# pdf(paste0("MNIST_classmap_digit", digit, ".pdf"), width = 7, height = 7)
oldpar <- par(no.readonly = TRUE) # save settings, restored by par(oldpar) below
par(mar = c(3.6, 3.5, 2.4, 3.5))
classmap(vcr.train, digit + 1, classCols = cols,
         main = paste0("predictions of digit ", digit),
         cex = 1.5, cex.lab = 1.5, cex.axis = 1.5, 
         cex.main = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
# indices of the 1s predicted as 2 (takes a while):
#
indstomark <- which(vcr.train$predint[which(y_train == digit)] == 3)
length(indstomark) # 104
## [1] 104
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("\"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 8)
# ggsave(paste0("MNIST_discussionplot_digit", digit, "predictedAs2b.pdf"),
#        plot = discussionPlot, width = 10,
#        height = (length(indstomark) %/% 10 +
#                    (length(indstomark) %% 10 > 0)))
                 
# The digits 1 predicted as a 2 are mostly ones written with
# a horizontal line at the bottom.

Class map of digit 2:

digit <- 2
# To identify outliers:
# classmap(vcr.train, digit + 1, classCols = cols, identify = TRUE)
# Press "Esc" to get out.
#
# pdf(paste0("MNIST_classmap_digit", digit,".pdf"), width = 7, height = 7)
oldpar <- par(no.readonly = TRUE) # save settings, restored by par(oldpar) below
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.train, digit + 1, classCols = cols,
                   main = paste0("predictions of digit ", digit),
                   cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
                   cex.main = 1.5)
indstomark <- c(3164, 5434, 2319, 4224, 3682,
                2642, 4920, 1233, 3741, 3993) # from identify = TRUE
labs  <- letters[1:length(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0, 0.08, 0, 0, 0, 0, 0, 0, 0, 0)
yvals <- coords[indstomark, 2] +
  c(-0.03, -0.03, -0.03, -0.03, -0.03, 
    -0.03, -0.03, -0.03, 0.03, 0.03)  
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
# dev.off()
par(oldpar)
pred <- vcr.train$pred # needed for discussion plots
tempPreds    <- (pred[which(y_train == digit)])[indstomark]
discussionPlots <- list()
for (i in 1:length(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot, 
        bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
# ggsave(paste0("MNIST_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))

Now we analyze the MNIST test data. First extract and inspect the test data, and then project it onto the PCA subspace extracted from the training data.

X_test <- mnist$test$x
y_test <- as.factor(mnist$test$y)

head(y_test)
## [1] 7 2 1 0 4 1
## Levels: 0 1 2 3 4 5 6 7 8 9
#
dim(X_test) # 10000    28    28
## [1] 10000    28    28
length(y_test) # 10000
## [1] 10000
plotImage(X_test[1, , ])
plotImage(X_test[2, , ])
plotImage(X_test[3, , ])
dim(X_test) <- c(10000, 28 * 28)
dim(X_test) # 10000  784
## [1] 10000   784
dataProj_test <- as.matrix(X_test %*% loadings)

Now prepare the vcr object for the test data, based on the classifier trained on the training data:

vcr.test <- vcr.da.newdata(Xnew = dataProj_test,
                           ynew = y_test,
                           vcr.da.train.out = vcr.train)

Build the confusion matrix and plot a stacked mosaic plot of the classification performance on the test data. With showClassNumbers = TRUE, class number k corresponds to digit k - 1:

confmat.vcr(vcr.test, showOutliers = FALSE, showClassNumbers = TRUE)
## 
## Confusion matrix:
##      predicted
## given    1    2    3    4    5    6    7    8    9   10
##    1   970    0    1    0    0    2    1    1    5    0
##    2     0 1097   11    3    2    1    1    0   20    0
##    3     2    0 1002    3    3    0    2    1   19    0
##    4     1    0    9  972    0    5    0    2   17    4
##    5     0    0    4    0  965    0    3    2    2    6
##    6     2    0    1   18    0  859    1    1   10    0
##    7     8    1    2    0    4   12  924    0    7    0
##    8     1    2   28    1    3    2    0  958   14   19
##    9     3    0    9   12    1    5    1    2  935    6
##    10    5    1   11    6   10    2    0    6   18  950
## 
## The accuracy is 96.32%.
# In supplementary material:
# pdf("MNISTtest_stackplot_with_outliers.pdf", width = 5, height = 4.3)
stackedplot(vcr.test, classCols = cols, separSize = 0.6,
            main = "Stacked plot of QDA on MNIST test data",
            minSize = 1.5)
# dev.off()

Silhouette plot:

#pdf("MNIST_test_QDA_silhouettes.pdf", width = 5.0, height = 4.6)
silplot(vcr.test, classCols = cols,
        main = "Silhouette plot of QDA on MNIST test data")      
##  classNumber classLabel classSize classAveSi
##            1          0       980       0.98
##            2          1      1135       0.93
##            3          2      1032       0.94
##            4          3      1010       0.92
##            5          4       982       0.96
##            6          5       892       0.92
##            7          6       958       0.93
##            8          7      1028       0.86
##            9          8       974       0.92
##           10          9      1009       0.88
[Figure: Silhouette plot of QDA on MNIST test data]
#dev.off()
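As a rough illustration of how a silhouette width can be computed from classification probabilities, the sketch below assumes two definitions that should be checked against ?vcr.da.train and ?silplot: PAC(i) = p_alt / (p_given + p_alt), where p_alt is the largest posterior probability among the other classes, and s(i) = 1 - 2 PAC(i). The posterior matrix is made-up toy data, not output of the package.

```r
# Hypothetical posterior probabilities for 3 cases and 3 classes (rows sum to 1).
post <- rbind(c(0.80, 0.20, 0.00),
              c(0.60, 0.30, 0.10),
              c(0.25, 0.25, 0.50))
given <- c(1, 2, 3)   # given class of each case

# Assumed PAC: probability of the best alternative class,
# relative to the given class plus that alternative.
PAC <- sapply(seq_len(nrow(post)), function(i) {
  pGiven <- post[i, given[i]]
  pAlt <- max(post[i, -given[i]])
  pAlt / (pGiven + pAlt)
})

# Assumed silhouette width: 1 when the given class is certain,
# negative when an alternative class is more probable.
s <- 1 - 2 * PAC
round(cbind(PAC, s), 3)
```

Under these assumptions the second case, whose given class is less probable than an alternative, gets a negative width (-1/3), while the first case gets 0.6.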

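Now we can construct the class maps of the test data. We first define a helper function that returns (and optionally plots) the image of the i-th test case of a given digit, and then draw the class map for digit 0: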

showdigit_test <- function(digit, i, plotIt = TRUE) {
  # i indexes the test cases whose given label is `digit`
  idx <- which(y_test == digit)[i]
  # wnq(paste("Estimated digit: ", as.numeric(vcr.test$pred[idx]), sep = ""))
  # each row of X_test holds one 28 x 28 image stored column by column
  tempImage <- matrix(X_test[idx, ], 28, 28)
  imagePlot <- plotImage(tempImage)
  if (plotIt) plot(imagePlot)
  imagePlot
}

digit <- 0
# classmap(vcr.test, digit+1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit,".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
                   main = paste0("predictions of digit ", digit),
                   cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
                   cex.main = 1.5)
indstomark <- c(140, 630, 241, 967, 189,
                377, 78, 943, 64, 354)
labs <- letters[seq_along(indstomark)]
xvals <- coords[indstomark, 1] +
  c(0.08, 0.07, -0.07, 0.06, 0,
    0.04, 0.05, 0.09, -0.04, 0.09)
yvals <- coords[indstomark, 2] +
  c(-0.025, -0.03, -0.024, -0.025, -0.03, 
    -0.03, -0.03, 0.022, 0.035, 0.03)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("left", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
[Figure: class map "predictions of digit 0", with the ten marked cases labeled a to j]
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in seq_along(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot,
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
[Figure: the ten marked test images of digit 0, labeled (a) to (j) with their predicted digit]
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))
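The commented-out ggsave() call sets the height to n %/% 5 + (n %% 5 > 0) for n marked images. This is the number of rows of the 5-column grid, that is ceiling(n / 5), so with ggsave's default unit (inches) and width = 5 each image gets about one square inch. A quick check in base R, where gridRows() is a hypothetical helper name:

```r
# Number of rows needed to show n images in a grid with ncol columns.
gridRows <- function(n, ncol = 5) n %/% ncol + (n %% ncol > 0)

# The expression is an integer ceiling of n / ncol.
n <- 1:23
print(all(gridRows(n) == ceiling(n / 5)))   # TRUE

gridRows(10)   # 10 marked images -> 2 rows, so height = 2
```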

Now for digit 3:

digit <- 3
# classmap(vcr.test, digit + 1, classCols = cols, identify = TRUE)
# pdf(paste0("MNISTtest_classmap_digit", digit, ".pdf"))
par(mar = c(3.6, 3.5, 2.4, 3.5))
coords <- classmap(vcr.test, digit + 1, classCols = cols,
                   main = paste0("predictions of digit ", digit),
                   cex = 1.5, cex.lab = 1.5, cex.axis = 1.5,
                   cex.main = 1.5)
indstomark <- c(883, 659, 262, 60, 310,
                832, 223, 784, 835, 289)
labs <- letters[seq_along(indstomark)]
xvals <- coords[indstomark, 1] +
  c(-0.01, 0.08, -0.10, 0.06, 0.07, 
    0.06, 0.03, 0.11, 0.02, 0.06)
yvals <- coords[indstomark, 2] +
  c(0.035, 0.033, -0.017, -0.022, -0.025, 
    -0.025, -0.033, -0.022, 0.035, 0.038)
text(x = xvals, y = yvals, labels = labs, cex = 1.5)
legend("right", fill = cols,
       legend = 0:9, cex = 1, ncol = 2, bg = "white")
[Figure: class map "predictions of digit 3", with the ten marked cases labeled a to j]
# dev.off()
par(oldpar)
pred <- vcr.test$pred # needed for discussion plots
tempPreds <- (pred[which(y_test == digit)])[indstomark]
discussionPlots <- list()
for (i in seq_along(indstomark)) {
  idx <- indstomark[i]
  tempplot <- showdigit_test(digit, idx, FALSE)
  tempplot <- arrangeGrob(tempplot,
    bottom = paste0("(", labs[i], ") \"", tempPreds[i], "\""))
  discussionPlots[[i]] <- tempplot
}
discussionPlot <- grid.arrange(grobs = discussionPlots, ncol = 5)
[Figure: the ten marked test images of digit 3, labeled (a) to (j) with their predicted digit]
# ggsave(paste0("MNISTtest_discussionplot_digit", digit, ".pdf"),
#        plot = discussionPlot, width = 5,
#        height = (length(indstomark) %/% 5 +
#                    (length(indstomark) %% 5 > 0)))
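In the loops above, indstomark holds positions within the test cases of the current digit, not row numbers of X_test: showdigit_test() looks up which(y_test == digit)[i], and tempPreds subsets the predictions in the same way. A toy example (hypothetical labels and predictions standing in for y_test and pred) shows how such positions map to overall indices:

```r
# Toy labels and predictions (hypothetical), standing in for y_test and pred.
y    <- factor(c(0, 3, 0, 3, 3, 0, 3))
pred <- factor(c(0, 3, 0, 5, 3, 6, 8))
digit <- 3
indstomark <- c(2, 4)            # positions *within* the cases of digit 3

inClass <- which(y == digit)     # overall indices of the digit-3 cases: 2 4 5 7
overall <- inClass[indstomark]   # overall indices of the marked cases: 4 7
pred[inClass][indstomark]        # same as pred[overall]: 5 8
```

So a marked case such as (a) in the class map of digit 3 is the 883rd test image of a 3, not row 883 of X_test.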
