In this lab we will learn to code an Artificial Neural Network (ANN) in R, both from scratch and using the Keras / TensorFlow library. For simplicity, we consider a problem with two linearly separable classes of data points.


1 Problem formulation

Let us consider a simple problem with only two features, \(X_1\) and \(X_2\), and only four observations (data points) that belong to two classes: circles and crosses. The four data points are schematically depicted below in the \(X_1\) vs. \(X_2\) coordinate system. The circles and crosses can be separated by a linear decision boundary, which in two dimensions is a straight line (in general, a hyperplane).

Let us implement the simplest possible feedforward / dense Artificial Neural Network (ANN) without hidden layers using the Keras / TensorFlow library. Later, we will reproduce the Keras results by coding the same ANN from scratch in R.

The architecture of the simplest ANN is displayed above. It includes two input nodes (the two feature vectors \(X_1\) and \(X_2\)) and one output node, where the two classes are coded in the following way: circles are coded as 0 and crosses as 1. The weights \(w_1\) and \(w_2\) of the edges of the ANN graph are the fitting parameters of the model.

2 Keras solution

Let us first define the X matrix of the feature vectors and the y vector of the class labels:

X <- matrix(c(c(0, 0, 1, 1), c(0, 1, 0, 1)), ncol = 2)
X
##      [,1] [,2]
## [1,]    0    0
## [2,]    0    1
## [3,]    1    0
## [4,]    1    1
y <- c(0, 0, 1, 1)
y
## [1] 0 0 1 1
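To get a feeling for the data, the four points can be drawn with base R graphics. This is only a minimal sketch: the plotting symbols, the axis limits and the dashed line at \(X_1 = 0.5\) are our own illustrative choices, and the line is just one of many possible linear boundaries, not a fitted one.

```r
# Feature matrix and class labels as defined above
X <- matrix(c(c(0, 0, 1, 1), c(0, 1, 0, 1)), ncol = 2)
y <- c(0, 0, 1, 1)

# pch = 1 draws an open circle (class 0), pch = 4 draws a cross (class 1)
pch_by_class <- ifelse(y == 0, 1, 4)
plot(X[, 1], X[, 2], pch = pch_by_class, cex = 2,
     xlab = "X1", ylab = "X2", xlim = c(-0.5, 1.5), ylim = c(-0.5, 1.5))

# One possible linear decision boundary (illustrative choice, not fitted)
abline(v = 0.5, lty = 2)

# Check that the simple rule "X1 > 0.5" reproduces all four labels
separable <- all(as.numeric(X[, 1] > 0.5) == y)
separable
```

Since the class label here depends on \(X_1\) only, a vertical line is sufficient to separate the two classes.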

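Now, let us define a sequential Keras model of the ANN corresponding to the scheme above and print the summary of the model. Here, we are going to use the sigmoid activation function on the output node because we have a binary classification problem. Note that a dense layer in Keras also fits a bias term by default, so the model has three trainable parameters: the two weights and the bias.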

library("keras")
model <- keras_model_sequential() 
model %>% layer_dense(units = 1, activation = 'sigmoid', input_shape = c(2))
summary(model)
## Model: "sequential_1"
## ___________________________________________________________________
## Layer (type)                  Output Shape              Param #    
## ===================================================================
## dense_1 (Dense)               (None, 1)                 3          
## ===================================================================
## Total params: 3
## Trainable params: 3
## Non-trainable params: 0
## ___________________________________________________________________
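To make the three parameters in the summary more concrete, here is a minimal sketch of what a single dense unit with sigmoid activation computes, written in plain R. The function names `sigmoid` and `dense_forward` and the parameter values are hypothetical and chosen only for illustration; they are not the values Keras will learn.

```r
sigmoid <- function(s) 1 / (1 + exp(-s))

X <- matrix(c(c(0, 0, 1, 1), c(0, 1, 0, 1)), ncol = 2)

# Hypothetical parameter values, for illustration only
w <- c(5, -0.2)  # weights w1 and w2
b <- -2.5        # bias

# Output of one dense unit: sigmoid(w1 * x1 + w2 * x2 + b) for every row of X
dense_forward <- function(X, w, b) sigmoid(as.vector(X %*% w + b))

out <- dense_forward(X, w, b)
out

# Two weights plus one bias give the three parameters reported by summary(model)
n_params <- length(w) + length(b)
n_params
```

Training the model amounts to finding values of these parameters that push the output towards 0 for circles and towards 1 for crosses.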

Next, we are going to compile and fit the Keras ANN model. Again, for simplicity, we are going to use the Mean Squared Error (MSE) loss function and Stochastic Gradient Descent (SGD) as the optimization algorithm (with a high learning rate of 0.1). We will train for 3000 epochs, which should be enough for the MSE to get close to zero.

model %>% compile(loss = 'mean_squared_error', optimizer = optimizer_sgd(lr = 0.1))
history <- model %>% fit(X, y, epochs = 3000)
plot(history$metrics$loss ~ seq_along(history$metrics$loss), xlab = "Epochs", ylab = "Loss", col = "blue", cex = 0.5)

Finally, we will make predictions on the same data set. Overfitting is not a concern here because we only want to verify that the model is capable of linearly separating the two classes of data points.

model %>% predict(X)
##            [,1]
## [1,] 0.08816894
## [2,] 0.07254702
## [3,] 0.94130015
## [4,] 0.92842996
model %>% predict_classes(X)
##      [,1]
## [1,]    0
## [2,]    0
## [3,]    1
## [4,]    1

It looks like the Keras model successfully assigns the correct labels to all four data points.
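For a single sigmoid output, the class labels above correspond to thresholding the predicted probabilities at 0.5. The sketch below reproduces this step in base R using the probabilities printed above; the variable names `p`, `y_hat` and `mse` are our own. This can be handy because `predict_classes()` was, to our knowledge, deprecated in later releases of Keras. The same snippet also computes the MSE of the predictions, which is small but not exactly zero.

```r
# Predicted probabilities copied from the Keras output above
p <- c(0.08816894, 0.07254702, 0.94130015, 0.92842996)
y <- c(0, 0, 1, 1)

# Class label is 1 if the predicted probability exceeds 0.5, otherwise 0
y_hat <- as.integer(p > 0.5)
y_hat

# Mean squared error between the predicted probabilities and the true labels
mse <- mean((y - p)^2)
mse
```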

3 Coding ANN from scratch in R

Now we are going to implement the same ANN architecture from scratch in R. This will allow us to better understand concepts such as the learning rate and gradient descent, as well as to get an intuition for forward- and back-propagation. First of all, let us denote the sigmoid activation function on the output node as

\[\phi(s)=\frac{1}{1+e^{-s}}\]

The beauty of this function is that it has a simple derivative that is expressed through the sigmoid function itself:

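\[\phi^\prime(s)=\phi(s)\left(1-\phi(s)\right)\]

Next, the loss function, i.e. the sum of squared differences between the predictions \(y_i\) and the true labels \(d_i\) (the factor 1/2 only simplifies the derivative), is given by the following simple equation:

\[E=\frac{1}{2}\sum_{i=1}^N\left(d_i-y_i\right)^2\]

Finally, the gradient descent update rule for the weights \(w_j\), \(j=1,2\), can be written as follows:

\[w_j=w_j-\mu\frac{\partial E}{\partial w_j}\]

\[\frac{\partial E}{\partial w_j}=-\sum_{i=1}^N\left(d_i-y_i\right)y_i\left(1-y_i\right)x_{ij}\]

where \(\mu\) is the learning rate and \(x_{ij}\) is the value of feature \(j\) for observation \(i\). Because the gradient itself carries a minus sign, each update effectively adds \(\mu\sum_i(d_i-y_i)\,y_i(1-y_i)\,x_{ij}\) to the weight, which is how it is implemented in the code below. Let us put it all together in a simple for-loop that updates the fitting parameters \(w_1\) and \(w_2\) by minimizing the squared error: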

phi <- function(x){return(1/(1 + exp(-x)))}

X <- matrix(c(c(0, 0, 1, 1), c(0, 1, 0, 1)), ncol = 2)
d <- matrix(c(0, 0, 1, 1), ncol = 1)

mu <- 0.1; N_epochs <- 10000; E <- vector()
w <- matrix(c(0.1, 0.5), ncol = 1) #initialization of weights w1 and w2
for(epochs in 1:N_epochs)
{
  #Forward propagation
  y <- phi(X %*% w - 3) #here we use -3 as a fixed bias (not optimized for simplicity)
  
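  #Backward propagation
  E <- append(E, sum((d-y)^2)) #track the sum of squared errors for each epoch
  dE_dw <- (d-y) * y * (1-y) #per-observation error term; t(X) %*% dE_dw is the negative gradient
  w <- w + mu * (t(X) %*% dE_dw) #adding the negative gradient is the gradient descent step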
}
plot(E ~ seq_len(N_epochs), cex = 0.5, xlab = "Epochs", ylab = "Error", col = "red")

The squared error decreases steadily and approaches zero. Let us display the final y vector of predicted labels. It should be close to the d vector of true labels.

y
##            [,1]
## [1,] 0.04742587
## [2,] 0.03166204
## [3,] 0.98444378
## [4,] 0.97650411

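Indeed, the predicted values of the labels are very close to the true ones and similar to the ones obtained from the Keras solution. They are not identical because here the bias was fixed at -3, whereas Keras fits it together with the weights. Well done, we have successfully implemented an ANN from scratch in R!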
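A useful sanity check for a from-scratch implementation is to compare the analytic gradient with a numerical one. The sketch below does this for the initial weights and the fixed bias of -3 used above, with central finite differences. The helper `loss`, the step size `eps` and the test point `s` are our own illustrative choices and are not part of the original lab code.

```r
phi <- function(s) 1 / (1 + exp(-s))

X <- matrix(c(c(0, 0, 1, 1), c(0, 1, 0, 1)), ncol = 2)
d <- matrix(c(0, 0, 1, 1), ncol = 1)
w <- matrix(c(0.1, 0.5), ncol = 1)
b <- -3  # fixed bias, as in the training loop above

# Loss E = 1/2 * sum((d - y)^2) as a function of the weights
loss <- function(w) {
  y <- phi(X %*% w + b)
  0.5 * sum((d - y)^2)
}

# Analytic gradient: dE/dw_j = -sum_i (d_i - y_i) * y_i * (1 - y_i) * x_ij
y <- phi(X %*% w + b)
grad_analytic <- -t(X) %*% ((d - y) * y * (1 - y))

# Numerical gradient via central finite differences (eps is an assumed step size)
eps <- 1e-6
grad_numeric <- sapply(1:2, function(j) {
  e_j <- matrix(0, nrow = 2, ncol = 1)
  e_j[j] <- eps
  (loss(w + e_j) - loss(w - e_j)) / (2 * eps)
})

# The two gradients should agree up to numerical error
max_abs_diff <- max(abs(as.vector(grad_analytic) - grad_numeric))
max_abs_diff

# Same kind of check for the sigmoid derivative phi'(s) = phi(s) * (1 - phi(s))
s <- 0.7
dphi_numeric <- (phi(s + eps) - phi(s - eps)) / (2 * eps)
dphi_analytic <- phi(s) * (1 - phi(s))
abs(dphi_numeric - dphi_analytic)
```

If the two gradients disagreed noticeably, this would point to a sign or indexing error in the back-propagation step.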

4 Session info

## R version 4.1.0 (2021-05-18)
## Platform: x86_64-conda-linux-gnu (64-bit)
## Running under: Ubuntu 20.04.2 LTS
## 
## Matrix products: default
## BLAS/LAPACK: /home/roy/miniconda3/envs/r-4.1/lib/libopenblasp-r0.3.15.so
## 
## locale:
##  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
##  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
##  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] parallel  stats4    stats     graphics  grDevices utils    
## [7] datasets  methods   base     
## 
## other attached packages:
##  [1] keras_2.4.0                 fontawesome_0.2.1          
##  [3] captioner_2.2.3             spatstat_2.1-0             
##  [5] spatstat.linnet_2.2-1       spatstat.core_2.2-0        
##  [7] rpart_4.1-15                nlme_3.1-152               
##  [9] spatstat.geom_2.2-0         spatstat.data_2.1-0        
## [11] MASS_7.3-54                 sm_2.2-5.6                 
## [13] xaringanExtra_0.5.2         bookdown_0.22              
## [15] knitr_1.33                  future_1.21.0              
## [17] BiocParallel_1.26.0         SingleR_1.6.1              
## [19] celldex_1.2.0               SummarizedExperiment_1.22.0
## [21] Biobase_2.52.0              GenomicRanges_1.44.0       
## [23] GenomeInfoDb_1.28.0         IRanges_2.26.0             
## [25] S4Vectors_0.30.0            BiocGenerics_0.38.0        
## [27] MatrixGenerics_1.4.0        matrixStats_0.59.0         
## [29] patchwork_1.1.1             ggplot2_3.3.4              
## [31] stringr_1.4.0               tidyr_1.1.3                
## [33] dplyr_1.0.7                 SeuratObject_4.0.2         
## [35] Seurat_4.0.3               
## 
## loaded via a namespace (and not attached):
##   [1] rappdirs_0.3.3                highcharter_0.8.2            
##   [3] scattermore_0.7               bit64_4.0.5                  
##   [5] irlba_2.3.3                   DelayedArray_0.18.0          
##   [7] data.table_1.14.0             KEGGREST_1.32.0              
##   [9] RCurl_1.98-1.3                generics_0.1.0               
##  [11] snow_0.4-3                    ScaledMatrix_1.0.0           
##  [13] callr_3.7.0                   cowplot_1.1.1                
##  [15] RSQLite_2.2.5                 shadowtext_0.0.8             
##  [17] RANN_2.6.1                    tensorflow_2.5.0             
##  [19] bit_4.0.4                     rlist_0.4.6.1                
##  [21] lubridate_1.7.10              httpuv_1.6.1                 
##  [23] assertthat_0.2.1              gower_0.2.2                  
##  [25] xfun_0.24                     jquerylib_0.1.4              
##  [27] evaluate_0.14                 promises_1.2.0.1             
##  [29] fansi_0.5.0                   dbplyr_2.1.1                 
##  [31] readxl_1.3.1                  igraph_1.2.6                 
##  [33] DBI_1.1.1                     quantmod_0.4.18              
##  [35] htmlwidgets_1.5.3             purrr_0.3.4                  
##  [37] ellipsis_0.3.2                RSpectra_0.16-0              
##  [39] backports_1.2.1               V8_3.4.2                     
##  [41] nuclon_0.0.1                  deldir_0.2-10                
##  [43] sparseMatrixStats_1.4.0       vctrs_0.3.8                  
##  [45] SingleCellExperiment_1.14.1   remotes_2.4.0                
##  [47] SeuratDisk_0.0.0.9019         TTR_0.24.2                   
##  [49] ROCR_1.0-11                   abind_1.4-5                  
##  [51] cachem_1.0.5                  withr_2.4.2                  
##  [53] sctransform_0.3.2             xts_0.12.1                   
##  [55] prettyunits_1.1.1             goftest_1.2-2                
##  [57] cluster_2.1.2                 ExperimentHub_2.0.0          
##  [59] lazyeval_0.2.2                crayon_1.4.1                 
##  [61] hdf5r_1.3.3                   recipes_0.1.16               
##  [63] pkgconfig_2.0.3               labeling_0.4.2               
##  [65] nnet_7.3-16                   rlang_0.4.11                 
##  [67] globals_0.14.0                lifecycle_1.0.0              
##  [69] miniUI_0.1.1.1                filelock_1.0.2               
##  [71] dbscan_1.1-8                  BiocFileCache_2.0.0          
##  [73] rsvd_1.0.5                    AnnotationHub_3.0.0          
##  [75] cellranger_1.1.0              tcltk_4.1.0                  
##  [77] rprojroot_2.0.2               polyclip_1.10-0              
##  [79] lmtest_0.9-38                 Matrix_1.3-4                 
##  [81] zoo_1.8-9                     base64enc_0.1-3              
##  [83] whisker_0.4                   ggridges_0.5.3               
##  [85] processx_3.5.2                png_0.1-7                    
##  [87] viridisLite_0.4.0             bitops_1.0-7                 
##  [89] KernSmooth_2.23-20            Biostrings_2.60.1            
##  [91] blob_1.2.1                    DelayedMatrixStats_1.14.0    
##  [93] parallelly_1.26.0             beachmat_2.8.0               
##  [95] scales_1.1.1                  memoise_2.0.0                
##  [97] magrittr_2.0.1                plyr_1.8.6                   
##  [99] ica_1.0-2                     zlibbioc_1.38.0              
## [101] compiler_4.1.0                RColorBrewer_1.1-2           
## [103] fitdistrplus_1.1-5            cli_2.5.0                    
## [105] XVector_0.32.0                listenv_0.8.0                
## [107] pbapply_1.4-3                 ps_1.6.0                     
## [109] mgcv_1.8-36                   tidyselect_1.1.1             
## [111] stringi_1.6.2                 forcats_0.5.1                
## [113] highr_0.9                     yaml_2.2.1                   
## [115] BiocSingular_1.8.1            ggrepel_0.9.1                
## [117] grid_4.1.0                    sass_0.4.0                   
## [119] randomcoloR_1.1.0.1           tools_4.1.0                  
## [121] future.apply_1.7.0            rstudioapi_0.13              
## [123] gridExtra_2.3                 prodlim_2019.11.13           
## [125] farver_2.1.0                  Rtsne_0.15                   
## [127] digest_0.6.27                 BiocManager_1.30.16          
## [129] shiny_1.6.0                   lava_1.6.9                   
## [131] Rcpp_1.0.6                    broom_0.7.7                  
## [133] BiocVersion_3.13.1            later_1.2.0                  
## [135] writexl_1.4.0                 RcppAnnoy_0.0.18             
## [137] httr_1.4.2                    AnnotationDbi_1.54.1         
## [139] colorspace_2.0-1              tensor_1.5                   
## [141] reticulate_1.20               splines_4.1.0                
## [143] uwot_0.1.10                   spatstat.utils_2.2-0         
## [145] plotly_4.9.4.1                xtable_1.8-4                 
## [147] jsonlite_1.7.2                timeDate_3043.102            
## [149] zeallot_0.1.0                 ipred_0.9-11                 
## [151] R6_2.5.0                      pillar_1.6.1                 
## [153] htmltools_0.5.1.1             mime_0.10                    
## [155] glue_1.4.2                    fastmap_1.1.0                
## [157] BiocNeighbors_1.10.0          class_7.3-19                 
## [159] interactiveDisplayBase_1.30.0 codetools_0.2-18             
## [161] pkgbuild_1.2.0                utf8_1.2.1                   
## [163] lattice_0.20-44               bslib_0.2.5.1                
## [165] spatstat.sparse_2.0-0         tibble_3.1.2                 
## [167] curl_4.3.1                    tfruns_1.5.0                 
## [169] leiden_0.3.8                  survival_3.2-11              
## [171] rmarkdown_2.9                 munsell_0.5.0                
## [173] GenomeInfoDbData_1.2.6        xaringan_0.21                
## [175] reshape2_1.4.4                gtable_0.3.0

Built on: 22-Jun-2021 at 13:36:42.


2021 SciLifeLab, NBIS, RaukR