import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
35 Neural Networks in Python (with Keras)
35.1 Introduction
There are a few Python packages available to train neural networks. See Table 35.1 for some examples. The packages vary in capabilities.
Several frameworks for ANNs and deep learning exist. TensorFlow, Microsoft CNTK, PyTorch, and Theano are among the most important ones.
Table 35.1: Python packages for neural network analysis.

Package | Notes |
---|---|
scikit-learn (sklearn) | Includes a neural_network module for training Multi-layer Perceptrons (MLP) and Bernoulli Restricted Boltzmann Machines (RBM) |
tensorflow | Interface to TensorFlow, a free and open-source software library for machine learning and artificial intelligence |
torch | PyTorch: Tensors and neural networks with GPU acceleration |
keras | Deep learning library |
Keras has emerged as an important API (Application Programming Interface) for deep learning. It provides a consistent interface on top of JAX, TensorFlow, or PyTorch. TensorFlow is very powerful and gives you complete control over the types of models you build and train, but its learning curve can be steep and you tend to write a lot of code. This is what makes Keras so relevant: you can tap into the capabilities of TensorFlow with a simpler API.
Tools from the modern machine learning toolbox tend to be written in Python. The keras package in Python calls into TensorFlow, or into whichever deep learning framework (backend) Keras is running on.
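You can ask Keras directly which backend a given installation uses. The following is a minimal sketch. It assumes Keras 3, which reads the KERAS_BACKEND environment variable only once, when keras is first imported. It also assumes TensorFlow is installed. If JAX or PyTorch is installed instead, set the variable to "jax" or "torch".

```python
import os

# Keras 3 reads KERAS_BACKEND only at import time, so set it before importing.
# setdefault leaves an existing setting untouched; "tensorflow" is also what
# Keras 3 uses when the variable is not set at all.
os.environ.setdefault("KERAS_BACKEND", "tensorflow")

import keras

backend_name = keras.backend.backend()
print(f"Keras {keras.__version__} runs on the {backend_name} backend")
```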
35.2 Running Keras in Python
Packages
Keras Basics
Training a neural network with keras involves three steps:
1. Defining the network
2. Setting up the optimization
3. Fitting the model
Not until the third step does the algorithm come into contact with actual data. However, we need to know some things about the data in order to define the network in step 1: the dimensions of the input and output.
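The three steps can be sketched end to end on synthetic data. This is a minimal illustration, not the chapter's model. The random arrays, the layer sizes, and the Adam optimizer are assumptions made only to keep the example self-contained, and the sketch imports the keras package directly rather than through tensorflow. Only the input dimension (4) and the output dimension (1) have to be known in step 1.

```python
import numpy as np
import keras
from keras import layers

# Synthetic data (assumption for illustration): 100 rows, 4 inputs, 1 numeric target
rng = np.random.default_rng(0)
x = rng.normal(size=(100, 4)).astype("float32")
y = (x @ np.array([1.0, -2.0, 0.5, 3.0]) + rng.normal(scale=0.1, size=100)).astype("float32")

# Step 1: define the network; only the input and output sizes depend on the data
model = keras.Sequential([
    layers.Input(shape=(4,)),
    layers.Dense(8, activation="relu"),
    layers.Dense(1),
])

# Step 2: set up the optimization
model.compile(loss="mse", optimizer="adam")

# Step 3: fit the model; only now does the data enter
history = model.fit(x, y, epochs=5, batch_size=16, verbose=0)
print(len(history.history["loss"]))  # one training loss value per epoch
```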
Defining the network
The most convenient way of specifying a multi-layer neural network is by adding layers sequentially, from the input layer to the output layer. This starts with a call to keras.models.Sequential(). Suppose we want to predict a continuous response (regression application) based on inputs \(x_1, \cdots, x_{19}\) with one hidden layer and dropout regularization.
The following statements define the model sequentially:
firstANN = keras.models.Sequential()
firstANN.add(layers.Input(shape=(19,)))  # preferred method to specify Input shape separately
firstANN.add(layers.Dense(units=50, activation='relu'))
firstANN.add(layers.Dropout(rate=0.4))
firstANN.add(layers.Dense(units=1, name='Output'))
layers.Input specifies the shape of the model input. layers.Dense() adds a fully connected layer to the network; the units= option specifies the number of neurons in the layer. In summary, the hidden layer receives 19 inputs and has 50 output units (neurons) with ReLU activation. The output from the hidden layer is passed on (piped) to a dropout layer with a dropout rate of \(\phi = 0.4\). The result of the dropout layer is passed on to another fully connected layer with a single neuron. This is the output layer of the network: the last layer in the sequence is automatically the output layer. Since we are in a regression context predicting a numeric target variable, there is only one output unit in the final layer. If this were a classification problem with \(5\) categories, the last layer would have 5 units.
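The following NumPy sketch makes the dropout step concrete. It mimics what a dropout layer with rate \(\phi = 0.4\) does to 50 hidden activations during training. It follows the "inverted dropout" scheme that Keras uses: each unit is zeroed with probability \(\phi\), and the surviving units are scaled by \(1/(1-\phi)\) so that the expected activation is unchanged. At prediction time the layer passes its input through unchanged. The helper name dropout_train is hypothetical, not a Keras function.

```python
import numpy as np

def dropout_train(h, rate, rng):
    """Hypothetical helper: inverted dropout as applied during training."""
    keep = rng.random(h.shape) >= rate              # keep each unit with probability 1 - rate
    return np.where(keep, h / (1.0 - rate), 0.0)    # zero dropped units, rescale the survivors

rng = np.random.default_rng(13)
h = np.ones(50)                      # 50 hidden activations, all equal to 1 for clarity
out = dropout_train(h, rate=0.4, rng=rng)

print((out == 0).mean())             # fraction of units dropped in this draw (near 0.4)
print(out[out > 0][:3])              # surviving units are scaled up to 1 / 0.6
```

Averaged over many draws, the output has the same mean as the input. That is why no rescaling is needed at prediction time.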
You can assign a name to each layer with the name= option; this makes it easier to identify the layers in the output. If you do not specify a name, Keras assigns one derived from the layer type, such as dense or dense_1. The numeric suffixes can be confusing because they come from counters that Keras maintains across the Python session, so they depend on how many layers of that type were created earlier. Assigning an explicit name is recommended practice.
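Explicit names pay off when you need to look up a layer later. A named layer can be retrieved with get_layer(), while unnamed layers receive generated names. The small model and the names hidden and Output are illustrative assumptions. The generated name of the unnamed dropout layer depends on the session, so the code prints it rather than assuming it.

```python
import keras
from keras import layers

model = keras.Sequential([
    layers.Input(shape=(19,)),
    layers.Dense(50, activation="relu", name="hidden"),
    layers.Dropout(0.4),                      # no name given: Keras generates one
    layers.Dense(1, name="Output"),
])

for layer in model.layers:
    print(layer.name, type(layer).__name__)

# Named layers are easy to retrieve, e.g. to inspect their weights
kernel, bias = model.get_layer("Output").get_weights()
print(kernel.shape, bias.shape)
```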
The activation= option specifies the activation function \(\sigma()\) for the hidden layers and the output function \(g()\) for the output layer. The default is the identity (“linear”) activation, \(\sigma(x) = x\). This default is appropriate for the output layer in a regression application. For the hidden layer we choose the ReLU activation.
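For the 5-category classification case, only the output layer and the loss change. The sketch below is a minimal example under two assumptions. First, the class labels are integer-coded 0 to 4, which is why it uses the sparse_categorical_crossentropy loss. Second, it uses a softmax output function \(g()\), so the five outputs form a probability distribution. The name clfANN is hypothetical.

```python
import numpy as np
import keras
from keras import layers

# Same architecture as the regression network, but 5 output units with softmax
clfANN = keras.Sequential([
    layers.Input(shape=(19,)),
    layers.Dense(units=50, activation="relu"),
    layers.Dropout(rate=0.4),
    layers.Dense(units=5, activation="softmax", name="Output"),
])
clfANN.compile(
    loss="sparse_categorical_crossentropy",  # assumes integer labels 0..4
    optimizer="rmsprop",
    metrics=["accuracy"],
)

# Predicted class probabilities for two random inputs; each row sums to 1
probs = clfANN.predict(np.random.default_rng(1).normal(size=(2, 19)), verbose=0)
print(probs.shape, probs.sum(axis=1))
```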
A list of the activation functions supported by Keras can be found in the documentation of the keras.activations module.
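The two activation functions used in this network are simple enough to write out directly. The NumPy sketch below is for illustration only. In a model you pass the strings 'relu' or 'linear' (the default) to activation= rather than these helper functions, whose names are hypothetical.

```python
import numpy as np

def linear(x):
    """Identity activation, sigma(x) = x (the Keras default)."""
    return x

def relu(x):
    """Rectified linear unit, sigma(x) = max(0, x), applied elementwise."""
    return np.maximum(0.0, x)

z = np.array([-2.0, -0.5, 0.0, 1.5])
print(linear(z))
print(relu(z))   # negative inputs are set to zero, positive ones pass through
```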
The basic neural network is now defined and we can find out how many parameters it entails.
firstANN.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 50)             │         1,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 50)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ Output (Dense)                  │ (None, 1)              │            51 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 1,051 (4.11 KB)
Trainable params: 1,051 (4.11 KB)
Non-trainable params: 0 (0.00 B)
With 19 inputs and 50 neurons, the first layer has 50 × 20 = 1,000 parameters (19 slopes and an intercept for each of the 50 neurons). The dropout layer does not add any parameters to the estimation. During training it chooses output neurons of the previous layer at random and sets their activation to zero. The 50 neurons (some with activation randomly set to zero) are the input to the final layer, which adds fifty weights (slopes) and one bias (intercept). The total number of parameters of this neural network is 1,000 + 51 = 1,051.
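The counting rule generalizes. A fully connected layer with \(p\) inputs and \(u\) units has \((p+1) \times u\) parameters: one weight per input plus a bias for each unit. Dropout layers contribute none. A small helper reproduces the Param # column of the summary (the name dense_params is hypothetical).

```python
def dense_params(n_inputs, n_units):
    """Weights plus biases of a fully connected (Dense) layer."""
    return (n_inputs + 1) * n_units

hidden = dense_params(19, 50)   # 19 slopes + 1 intercept for each of 50 neurons
dropout = 0                     # dropout has no trainable parameters
output = dense_params(50, 1)    # 50 weights + 1 bias
print(hidden, dropout, output, hidden + dropout + output)  # 1000 0 51 1051
```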
Setting up the optimization
The second step in training a model in Keras is to specify the particulars of the optimization with the model's compile() method. Typical specifications include the loss function, the type of optimization algorithm, and the metrics evaluated by the model during training.
The following call uses the RMSProp algorithm with the mean squared error loss function to estimate the parameters of the network. During training, the mean absolute error is also monitored in addition to the mean squared error.
firstANN.compile(
    loss='mse',  # mean squared error
    optimizer='rmsprop',
    metrics=['mean_absolute_error']
)
Depending on your environment, not all optimization algorithms are supported.
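Passing the string 'rmsprop' to compile() uses the optimizer with its default settings. To control those settings, such as the learning rate, pass an optimizer object instead. The learning rate of 0.001 below is the Keras default for RMSprop. The one-layer model is only a stand-in so that the snippet runs on its own.

```python
import keras
from keras import layers

model = keras.Sequential([layers.Input(shape=(19,)), layers.Dense(1)])

# Equivalent to optimizer='rmsprop', but with the learning rate spelled out
opt = keras.optimizers.RMSprop(learning_rate=0.001)

model.compile(
    loss=keras.losses.MeanSquaredError(),
    optimizer=opt,
    metrics=[keras.metrics.MeanAbsoluteError()],
)
```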
Fitting the model
The last step in training the network is to connect the defined and compiled model with training—and possibly test—data.
For this example we use the Hitters data from the ISLP package. This is a data set with 322 observations of major league baseball players from the 1986 and 1987 seasons. The following code removes observations with missing values from the data frame, defines a vector of ids for the test data (1/3 of the observations), and computes a scaled and centered model matrix using all 19 input variables.
import numpy as np
import pandas as pd
from ISLP import load_data
Hitters = load_data('Hitters')
# Remove rows with missing values
Gitters = Hitters.dropna()
# Get number of rows
n = len(Gitters)
# Set random seed for reproducibility
keras.utils.set_random_seed(13)
# Create test set (1/3 of data)
ntest = int(n / 3)
testid = np.random.choice(n, size=ntest, replace=False)
# Create feature matrix (excluding Salary)
X = Gitters.drop('Salary', axis=1)
# Convert categorical variables to dummy variables
X = pd.get_dummies(X, drop_first=True, dtype='int')
from sklearn.preprocessing import StandardScaler
# Scale the features
scaler = StandardScaler()
x = pd.DataFrame(scaler.fit_transform(X))
# Target variable
y = Gitters['Salary'].values
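# Split into training and test sets.
# A boolean mask keeps every row not in testid; x.iloc[-testid] would not do
# this in Python, because negative positions count from the end.
train = np.ones(n, dtype=bool)
train[testid] = False
x_train = x.iloc[train]
y_train = y[train]
x_test = x.iloc[testid]
y_test = y[testid]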
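Unlike R, where a negative index drops elements, Python treats a negative position as counting from the end. As a result, x.iloc[-testid] does not remove the test rows. The small NumPy example below contrasts the two approaches. The toy sizes and the positions in testid are arbitrary, hypothetical values. A boolean mask (or, equivalently, np.setdiff1d(np.arange(n), testid)) keeps exactly the complement of the test positions.

```python
import numpy as np

rows = np.arange(10)                 # stand-in for 10 row positions
testid = np.array([0, 2, 5])         # hypothetical test positions

# Negative indexing: -0 is 0 and -2 means "second from the end",
# so this picks rows 0, 8 and 5, two of which are test rows
wrong = rows[-testid]

# Boolean mask: True for rows that are not in testid
train = np.ones(len(rows), dtype=bool)
train[testid] = False
right = rows[train]

print(wrong)   # [0 8 5]
print(right)   # [1 3 4 6 7 8 9]
```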
Note that the model contains several factors (League, Division, NewLeague) whose levels are encoded as binary variables in the model matrix. One could exclude those from scaling and centering, as they already are in the proper range. In a regression model you would not want to scale these variables, in order to preserve the interpretation of their coefficients. In a neural network, the interpretation of the model coefficients is not important, and we include all columns of the model matrix in the scaling operation.
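StandardScaler centers each column at its mean and divides by its standard deviation, computed with divisor \(n\) rather than \(n-1\). The NumPy sketch below applies the same transformation to a made-up matrix with one numeric column and one 0/1 dummy column. It shows that the dummy column is rescaled along with the others.

```python
import numpy as np

# Made-up data: one numeric column and one 0/1 dummy column
X = np.array([[200.0, 0.0],
              [450.0, 1.0],
              [600.0, 1.0],
              [150.0, 0.0]])

# Center and scale each column (population std, ddof=0, as StandardScaler does)
Z = (X - X.mean(axis=0)) / X.std(axis=0)

print(Z.mean(axis=0))   # column means are (numerically) 0
print(Z.std(axis=0))    # column standard deviations are 1
print(Z[:, 1])          # the dummy column becomes -1 / +1
```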
The following code fits the model to the training data (the observations not in testid) using 200 epochs and a minibatch size of 32. That means the gradient is computed based on 32 randomly chosen observations in each step of the stochastic gradient descent (SGD) algorithm. Since there are 176 training observations, a full pass through the training data takes \(176/32=5.5\) minibatches. In practice this means 6 SGD steps: five full minibatches of 32 and a final one of 16 observations. Such a full pass is known as an epoch and is akin to the concept of an iteration in numerical optimization. The fundamental difference between an epoch and an iteration lies in how often the parameters are updated. In a full iteration, there is one update after the pass through the entire data. In SGD with minibatches, there are multiple updates of the parameters per epoch, one for each minibatch.
Running 200 epochs with a batch size of 32 and a training set size of 176 therefore results in 200 × 6 = 1,200 gradient evaluations (and parameter updates).
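A few lines of arithmetic confirm this bookkeeping. Keras keeps the last, incomplete minibatch, so the number of steps per epoch is the ceiling of the training set size divided by the batch size.

```python
import math

n_train, batch_size, epochs = 176, 32, 200

steps_per_epoch = math.ceil(n_train / batch_size)          # 5 full batches + 1 partial batch
last_batch = n_train - (steps_per_epoch - 1) * batch_size  # size of the final minibatch
total_updates = epochs * steps_per_epoch

print(steps_per_epoch, last_batch, total_updates)  # 6 16 1200
```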
The validation_data= option supplies the test data to the training process. The loss function and metrics specified in the compile() call earlier are reported at each epoch for the training data and, if specified, for the validation data. The validation data are only evaluated; they are never used to update the parameters. If you do not have a separate validation data set, you can specify validation_split= to request that a fraction of the training data be held back for validation.
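The following is a minimal sketch of the validation_split= alternative. The data and model are synthetic stand-ins, not the Hitters example. Keras takes the validation fraction from the last rows of the arrays, before any shuffling, so shuffle the rows beforehand if they are ordered. The per-epoch values end up in history.history, with a val_ prefix for the validation data.

```python
import numpy as np
import keras
from keras import layers

rng = np.random.default_rng(7)
x = rng.normal(size=(120, 19)).astype("float32")   # synthetic inputs
y = rng.normal(size=(120,)).astype("float32")      # synthetic target

model = keras.Sequential([layers.Input(shape=(19,)), layers.Dense(1)])
model.compile(loss="mse", optimizer="rmsprop", metrics=["mean_absolute_error"])

# Hold back the last 25% of the rows (30 of 120) for validation
history = model.fit(x, y, epochs=3, batch_size=32,
                    validation_split=0.25, verbose=0)

print(sorted(history.history.keys()))
print(len(history.history["val_loss"]))   # one validation loss per epoch
```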
history = firstANN.fit(
    x_train, y_train,
    epochs=200,
    batch_size=32,
    validation_data=(x_test, y_test),
    verbose=2
)
Epoch 1/200
3/3 - 0s - 72ms/step - loss: 389340.0312 - mean_absolute_error: 502.5588 - val_loss: 417511.1875 - val_mean_absolute_error: 492.3558
Epoch 2/200
3/3 - 0s - 8ms/step - loss: 389020.7812 - mean_absolute_error: 502.3522 - val_loss: 417332.5625 - val_mean_absolute_error: 492.2245
Epoch 3/200
3/3 - 0s - 9ms/step - loss: 388874.0625 - mean_absolute_error: 502.2130 - val_loss: 417194.1250 - val_mean_absolute_error: 492.1146
Epoch 4/200
3/3 - 0s - 9ms/step - loss: 388773.2812 - mean_absolute_error: 502.1937 - val_loss: 417067.2188 - val_mean_absolute_error: 492.0135
Epoch 5/200
3/3 - 0s - 8ms/step - loss: 388594.0312 - mean_absolute_error: 501.9636 - val_loss: 416946.8125 - val_mean_absolute_error: 491.9159
Epoch 6/200
3/3 - 0s - 8ms/step - loss: 388381.6562 - mean_absolute_error: 501.8437 - val_loss: 416836.7812 - val_mean_absolute_error: 491.8320
Epoch 7/200
3/3 - 0s - 8ms/step - loss: 388408.2188 - mean_absolute_error: 501.8445 - val_loss: 416726.9375 - val_mean_absolute_error: 491.7469
Epoch 8/200
3/3 - 0s - 8ms/step - loss: 388200.2812 - mean_absolute_error: 501.7393 - val_loss: 416617.5938 - val_mean_absolute_error: 491.6615
Epoch 9/200
3/3 - 0s - 8ms/step - loss: 388065.0000 - mean_absolute_error: 501.5553 - val_loss: 416506.8125 - val_mean_absolute_error: 491.5742
Epoch 10/200
3/3 - 0s - 8ms/step - loss: 387950.4062 - mean_absolute_error: 501.4736 - val_loss: 416403.3438 - val_mean_absolute_error: 491.4915
Epoch 11/200
3/3 - 0s - 8ms/step - loss: 387918.5312 - mean_absolute_error: 501.4520 - val_loss: 416301.3750 - val_mean_absolute_error: 491.4099
Epoch 12/200
3/3 - 0s - 8ms/step - loss: 387795.7188 - mean_absolute_error: 501.3584 - val_loss: 416201.7812 - val_mean_absolute_error: 491.3300
Epoch 13/200
3/3 - 0s - 8ms/step - loss: 387706.3438 - mean_absolute_error: 501.3469 - val_loss: 416100.2188 - val_mean_absolute_error: 491.2463
Epoch 14/200
3/3 - 0s - 8ms/step - loss: 387407.3438 - mean_absolute_error: 501.0790 - val_loss: 415996.6562 - val_mean_absolute_error: 491.1637
Epoch 15/200
3/3 - 0s - 8ms/step - loss: 387381.1562 - mean_absolute_error: 501.0563 - val_loss: 415889.7812 - val_mean_absolute_error: 491.0791
Epoch 16/200
3/3 - 0s - 8ms/step - loss: 387257.8438 - mean_absolute_error: 500.9598 - val_loss: 415785.8750 - val_mean_absolute_error: 490.9952
Epoch 17/200
3/3 - 0s - 8ms/step - loss: 387203.9688 - mean_absolute_error: 500.8877 - val_loss: 415686.4688 - val_mean_absolute_error: 490.9164
Epoch 18/200
3/3 - 0s - 8ms/step - loss: 386995.5938 - mean_absolute_error: 500.7290 - val_loss: 415581.0000 - val_mean_absolute_error: 490.8327
Epoch 19/200
3/3 - 0s - 8ms/step - loss: 387161.6562 - mean_absolute_error: 500.7883 - val_loss: 415486.8438 - val_mean_absolute_error: 490.7541
Epoch 20/200
3/3 - 0s - 8ms/step - loss: 386758.3438 - mean_absolute_error: 500.6389 - val_loss: 415379.0312 - val_mean_absolute_error: 490.6654
Epoch 21/200
3/3 - 0s - 8ms/step - loss: 386771.3438 - mean_absolute_error: 500.5865 - val_loss: 415274.7500 - val_mean_absolute_error: 490.5825
Epoch 22/200
3/3 - 0s - 8ms/step - loss: 386799.4375 - mean_absolute_error: 500.5797 - val_loss: 415172.9688 - val_mean_absolute_error: 490.4991
Epoch 23/200
3/3 - 0s - 8ms/step - loss: 386456.2812 - mean_absolute_error: 500.2074 - val_loss: 415067.0312 - val_mean_absolute_error: 490.4099
Epoch 24/200
3/3 - 0s - 8ms/step - loss: 386355.6875 - mean_absolute_error: 500.2285 - val_loss: 414961.3438 - val_mean_absolute_error: 490.3196
Epoch 25/200
3/3 - 0s - 8ms/step - loss: 386425.3750 - mean_absolute_error: 500.2291 - val_loss: 414854.3125 - val_mean_absolute_error: 490.2314
Epoch 26/200
3/3 - 0s - 8ms/step - loss: 386270.8125 - mean_absolute_error: 500.1129 - val_loss: 414746.4062 - val_mean_absolute_error: 490.1427
Epoch 27/200
3/3 - 0s - 9ms/step - loss: 385948.5625 - mean_absolute_error: 499.9853 - val_loss: 414635.8125 - val_mean_absolute_error: 490.0521
Epoch 28/200
3/3 - 0s - 8ms/step - loss: 386006.7500 - mean_absolute_error: 499.8953 - val_loss: 414529.4375 - val_mean_absolute_error: 489.9622
Epoch 29/200
3/3 - 0s - 8ms/step - loss: 386002.2500 - mean_absolute_error: 499.8737 - val_loss: 414421.5312 - val_mean_absolute_error: 489.8708
Epoch 30/200
3/3 - 0s - 8ms/step - loss: 385765.7812 - mean_absolute_error: 499.8581 - val_loss: 414308.1875 - val_mean_absolute_error: 489.7787
Epoch 31/200
3/3 - 0s - 8ms/step - loss: 385381.1875 - mean_absolute_error: 499.5080 - val_loss: 414190.0312 - val_mean_absolute_error: 489.6848
Epoch 32/200
3/3 - 0s - 8ms/step - loss: 385460.2188 - mean_absolute_error: 499.5469 - val_loss: 414080.0000 - val_mean_absolute_error: 489.5907
Epoch 33/200
3/3 - 0s - 8ms/step - loss: 385175.0938 - mean_absolute_error: 499.2926 - val_loss: 413965.0000 - val_mean_absolute_error: 489.4966
Epoch 34/200
3/3 - 0s - 8ms/step - loss: 385276.1250 - mean_absolute_error: 499.3307 - val_loss: 413852.5000 - val_mean_absolute_error: 489.4005
Epoch 35/200
3/3 - 0s - 8ms/step - loss: 384804.1875 - mean_absolute_error: 499.0443 - val_loss: 413731.0312 - val_mean_absolute_error: 489.3018
Epoch 36/200
3/3 - 0s - 8ms/step - loss: 384691.5625 - mean_absolute_error: 498.9231 - val_loss: 413609.4688 - val_mean_absolute_error: 489.2015
Epoch 37/200
3/3 - 0s - 8ms/step - loss: 384677.1875 - mean_absolute_error: 498.8477 - val_loss: 413489.6562 - val_mean_absolute_error: 489.1023
Epoch 38/200
3/3 - 0s - 8ms/step - loss: 384259.0938 - mean_absolute_error: 498.6401 - val_loss: 413363.5938 - val_mean_absolute_error: 488.9981
Epoch 39/200
3/3 - 0s - 8ms/step - loss: 384543.7188 - mean_absolute_error: 498.8074 - val_loss: 413245.5938 - val_mean_absolute_error: 488.9006
Epoch 40/200
3/3 - 0s - 8ms/step - loss: 384308.6875 - mean_absolute_error: 498.5863 - val_loss: 413125.7812 - val_mean_absolute_error: 488.8018
Epoch 41/200
3/3 - 0s - 8ms/step - loss: 384587.2188 - mean_absolute_error: 498.5696 - val_loss: 413011.5938 - val_mean_absolute_error: 488.7083
Epoch 42/200
3/3 - 0s - 8ms/step - loss: 384145.8750 - mean_absolute_error: 498.4301 - val_loss: 412891.0312 - val_mean_absolute_error: 488.6079
Epoch 43/200
3/3 - 0s - 8ms/step - loss: 383963.1875 - mean_absolute_error: 498.2796 - val_loss: 412765.4688 - val_mean_absolute_error: 488.5054
Epoch 44/200
3/3 - 0s - 8ms/step - loss: 384246.1250 - mean_absolute_error: 498.3518 - val_loss: 412647.8750 - val_mean_absolute_error: 488.4030
Epoch 45/200
3/3 - 0s - 8ms/step - loss: 383733.2500 - mean_absolute_error: 498.0811 - val_loss: 412520.0938 - val_mean_absolute_error: 488.2993
Epoch 46/200
3/3 - 0s - 8ms/step - loss: 383349.5312 - mean_absolute_error: 497.6969 - val_loss: 412386.1562 - val_mean_absolute_error: 488.1903
Epoch 47/200
3/3 - 0s - 8ms/step - loss: 383349.6562 - mean_absolute_error: 497.7733 - val_loss: 412255.1250 - val_mean_absolute_error: 488.0803
Epoch 48/200
3/3 - 0s - 8ms/step - loss: 383455.2188 - mean_absolute_error: 497.7986 - val_loss: 412125.1562 - val_mean_absolute_error: 487.9725
Epoch 49/200
3/3 - 0s - 8ms/step - loss: 383755.3438 - mean_absolute_error: 497.9748 - val_loss: 412003.9688 - val_mean_absolute_error: 487.8687
Epoch 50/200
3/3 - 0s - 8ms/step - loss: 382649.4375 - mean_absolute_error: 497.3878 - val_loss: 411860.3750 - val_mean_absolute_error: 487.7536
Epoch 51/200
3/3 - 0s - 8ms/step - loss: 382818.3438 - mean_absolute_error: 497.2996 - val_loss: 411723.6250 - val_mean_absolute_error: 487.6452
Epoch 52/200
3/3 - 0s - 8ms/step - loss: 382722.5000 - mean_absolute_error: 497.2722 - val_loss: 411586.8125 - val_mean_absolute_error: 487.5352
Epoch 53/200
3/3 - 0s - 10ms/step - loss: 382742.3438 - mean_absolute_error: 497.1627 - val_loss: 411448.7812 - val_mean_absolute_error: 487.4180
Epoch 54/200
3/3 - 0s - 8ms/step - loss: 382132.0312 - mean_absolute_error: 496.8816 - val_loss: 411307.6250 - val_mean_absolute_error: 487.3025
Epoch 55/200
3/3 - 0s - 8ms/step - loss: 381995.5938 - mean_absolute_error: 496.6758 - val_loss: 411165.0625 - val_mean_absolute_error: 487.1872
Epoch 56/200
3/3 - 0s - 8ms/step - loss: 382269.9375 - mean_absolute_error: 496.8251 - val_loss: 411023.2500 - val_mean_absolute_error: 487.0744
Epoch 57/200
3/3 - 0s - 8ms/step - loss: 382409.4062 - mean_absolute_error: 496.8909 - val_loss: 410887.5938 - val_mean_absolute_error: 486.9630
Epoch 58/200
3/3 - 0s - 8ms/step - loss: 381919.6875 - mean_absolute_error: 496.6476 - val_loss: 410739.5938 - val_mean_absolute_error: 486.8400
Epoch 59/200
3/3 - 0s - 8ms/step - loss: 381490.9062 - mean_absolute_error: 496.3203 - val_loss: 410591.1875 - val_mean_absolute_error: 486.7131
Epoch 60/200
3/3 - 0s - 8ms/step - loss: 381556.4062 - mean_absolute_error: 496.2415 - val_loss: 410444.0000 - val_mean_absolute_error: 486.5936
Epoch 61/200
3/3 - 0s - 8ms/step - loss: 381420.2188 - mean_absolute_error: 496.1440 - val_loss: 410292.1250 - val_mean_absolute_error: 486.4695
Epoch 62/200
3/3 - 0s - 8ms/step - loss: 380701.5000 - mean_absolute_error: 495.7862 - val_loss: 410129.7500 - val_mean_absolute_error: 486.3413
Epoch 63/200
3/3 - 0s - 8ms/step - loss: 381487.8750 - mean_absolute_error: 495.9503 - val_loss: 409982.7500 - val_mean_absolute_error: 486.2198
Epoch 64/200
3/3 - 0s - 8ms/step - loss: 380641.4688 - mean_absolute_error: 495.3927 - val_loss: 409818.7500 - val_mean_absolute_error: 486.0856
Epoch 65/200
3/3 - 0s - 8ms/step - loss: 380755.3438 - mean_absolute_error: 495.4585 - val_loss: 409658.7500 - val_mean_absolute_error: 485.9571
Epoch 66/200
3/3 - 0s - 8ms/step - loss: 380269.5000 - mean_absolute_error: 494.9557 - val_loss: 409494.6250 - val_mean_absolute_error: 485.8271
Epoch 67/200
3/3 - 0s - 8ms/step - loss: 379841.4062 - mean_absolute_error: 495.1231 - val_loss: 409326.8438 - val_mean_absolute_error: 485.6925
Epoch 68/200
3/3 - 0s - 8ms/step - loss: 380039.3750 - mean_absolute_error: 494.8650 - val_loss: 409161.5625 - val_mean_absolute_error: 485.5595
Epoch 69/200
3/3 - 0s - 8ms/step - loss: 379736.4062 - mean_absolute_error: 494.9195 - val_loss: 408993.0625 - val_mean_absolute_error: 485.4256
Epoch 70/200
3/3 - 0s - 8ms/step - loss: 379426.5625 - mean_absolute_error: 494.5951 - val_loss: 408821.1562 - val_mean_absolute_error: 485.2875
Epoch 71/200
3/3 - 0s - 8ms/step - loss: 379471.5312 - mean_absolute_error: 494.5508 - val_loss: 408656.1875 - val_mean_absolute_error: 485.1555
Epoch 72/200
3/3 - 0s - 8ms/step - loss: 379524.5000 - mean_absolute_error: 494.5577 - val_loss: 408486.9375 - val_mean_absolute_error: 485.0185
Epoch 73/200
3/3 - 0s - 8ms/step - loss: 379088.4688 - mean_absolute_error: 494.2766 - val_loss: 408315.4375 - val_mean_absolute_error: 484.8772
Epoch 74/200
3/3 - 0s - 8ms/step - loss: 377639.8125 - mean_absolute_error: 493.7297 - val_loss: 408122.7188 - val_mean_absolute_error: 484.7244
Epoch 75/200
3/3 - 0s - 8ms/step - loss: 378130.7500 - mean_absolute_error: 493.5963 - val_loss: 407940.8125 - val_mean_absolute_error: 484.5728
Epoch 76/200
3/3 - 0s - 8ms/step - loss: 378473.9062 - mean_absolute_error: 494.0424 - val_loss: 407763.7188 - val_mean_absolute_error: 484.4297
Epoch 77/200
3/3 - 0s - 8ms/step - loss: 378299.4375 - mean_absolute_error: 493.9529 - val_loss: 407585.6562 - val_mean_absolute_error: 484.2831
Epoch 78/200
3/3 - 0s - 8ms/step - loss: 378255.8125 - mean_absolute_error: 493.4101 - val_loss: 407407.5938 - val_mean_absolute_error: 484.1393
Epoch 79/200
3/3 - 0s - 8ms/step - loss: 378561.4688 - mean_absolute_error: 493.3312 - val_loss: 407234.4062 - val_mean_absolute_error: 483.9932
Epoch 80/200
3/3 - 0s - 8ms/step - loss: 378031.0625 - mean_absolute_error: 493.1086 - val_loss: 407050.0312 - val_mean_absolute_error: 483.8425
Epoch 81/200
3/3 - 0s - 8ms/step - loss: 377724.0625 - mean_absolute_error: 493.2688 - val_loss: 406866.4062 - val_mean_absolute_error: 483.6950
Epoch 82/200
3/3 - 0s - 8ms/step - loss: 377577.6562 - mean_absolute_error: 493.0010 - val_loss: 406675.9688 - val_mean_absolute_error: 483.5368
Epoch 83/200
3/3 - 0s - 8ms/step - loss: 378161.5312 - mean_absolute_error: 493.3419 - val_loss: 406489.3750 - val_mean_absolute_error: 483.3834
Epoch 84/200
3/3 - 0s - 8ms/step - loss: 377602.0312 - mean_absolute_error: 493.0906 - val_loss: 406303.0312 - val_mean_absolute_error: 483.2324
Epoch 85/200
3/3 - 0s - 8ms/step - loss: 377337.3750 - mean_absolute_error: 492.4758 - val_loss: 406108.8750 - val_mean_absolute_error: 483.0738
Epoch 86/200
3/3 - 0s - 8ms/step - loss: 376220.7188 - mean_absolute_error: 491.9900 - val_loss: 405899.9688 - val_mean_absolute_error: 482.9014
Epoch 87/200
3/3 - 0s - 8ms/step - loss: 376640.7812 - mean_absolute_error: 492.4343 - val_loss: 405699.8125 - val_mean_absolute_error: 482.7402
Epoch 88/200
3/3 - 0s - 8ms/step - loss: 376136.8125 - mean_absolute_error: 491.7885 - val_loss: 405493.0000 - val_mean_absolute_error: 482.5703
Epoch 89/200
3/3 - 0s - 8ms/step - loss: 376049.2812 - mean_absolute_error: 491.6884 - val_loss: 405288.5625 - val_mean_absolute_error: 482.3979
Epoch 90/200
3/3 - 0s - 8ms/step - loss: 375846.2188 - mean_absolute_error: 491.3922 - val_loss: 405085.8750 - val_mean_absolute_error: 482.2295
Epoch 91/200
3/3 - 0s - 8ms/step - loss: 375962.8438 - mean_absolute_error: 491.7348 - val_loss: 404879.0938 - val_mean_absolute_error: 482.0623
Epoch 92/200
3/3 - 0s - 8ms/step - loss: 375796.4688 - mean_absolute_error: 491.6170 - val_loss: 404674.6250 - val_mean_absolute_error: 481.8931
Epoch 93/200
3/3 - 0s - 8ms/step - loss: 373532.0312 - mean_absolute_error: 490.2703 - val_loss: 404437.8750 - val_mean_absolute_error: 481.7017
Epoch 94/200
3/3 - 0s - 8ms/step - loss: 373986.9375 - mean_absolute_error: 490.3645 - val_loss: 404215.2500 - val_mean_absolute_error: 481.5163
Epoch 95/200
3/3 - 0s - 8ms/step - loss: 374281.9688 - mean_absolute_error: 490.4469 - val_loss: 403995.4375 - val_mean_absolute_error: 481.3341
Epoch 96/200
3/3 - 0s - 8ms/step - loss: 374884.3125 - mean_absolute_error: 490.3987 - val_loss: 403785.9375 - val_mean_absolute_error: 481.1588
Epoch 97/200
3/3 - 0s - 8ms/step - loss: 374143.0000 - mean_absolute_error: 490.0483 - val_loss: 403566.2500 - val_mean_absolute_error: 480.9715
Epoch 98/200
3/3 - 0s - 8ms/step - loss: 374027.7188 - mean_absolute_error: 490.2426 - val_loss: 403339.0312 - val_mean_absolute_error: 480.7791
Epoch 99/200
3/3 - 0s - 8ms/step - loss: 373892.1562 - mean_absolute_error: 490.4114 - val_loss: 403113.2500 - val_mean_absolute_error: 480.5948
Epoch 100/200
3/3 - 0s - 8ms/step - loss: 374321.9688 - mean_absolute_error: 490.0417 - val_loss: 402897.1562 - val_mean_absolute_error: 480.4067
Epoch 101/200
3/3 - 0s - 8ms/step - loss: 372083.5000 - mean_absolute_error: 488.7498 - val_loss: 402656.6562 - val_mean_absolute_error: 480.2061
Epoch 102/200
3/3 - 0s - 8ms/step - loss: 372429.1250 - mean_absolute_error: 488.9255 - val_loss: 402422.7188 - val_mean_absolute_error: 480.0108
Epoch 103/200
3/3 - 0s - 8ms/step - loss: 372664.9062 - mean_absolute_error: 489.0961 - val_loss: 402195.0000 - val_mean_absolute_error: 479.8232
Epoch 104/200
3/3 - 0s - 8ms/step - loss: 372267.5312 - mean_absolute_error: 488.7855 - val_loss: 401961.7812 - val_mean_absolute_error: 479.6321
Epoch 105/200
3/3 - 0s - 8ms/step - loss: 371703.0312 - mean_absolute_error: 488.3731 - val_loss: 401718.2500 - val_mean_absolute_error: 479.4314
Epoch 106/200
3/3 - 0s - 8ms/step - loss: 370769.5000 - mean_absolute_error: 488.1224 - val_loss: 401466.3438 - val_mean_absolute_error: 479.2188
Epoch 107/200
3/3 - 0s - 8ms/step - loss: 371384.2188 - mean_absolute_error: 487.3106 - val_loss: 401229.2812 - val_mean_absolute_error: 479.0202
Epoch 108/200
3/3 - 0s - 8ms/step - loss: 371506.7188 - mean_absolute_error: 488.1428 - val_loss: 400983.6250 - val_mean_absolute_error: 478.8108
Epoch 109/200
3/3 - 0s - 8ms/step - loss: 371570.6875 - mean_absolute_error: 488.0952 - val_loss: 400741.3750 - val_mean_absolute_error: 478.6074
Epoch 110/200
3/3 - 0s - 8ms/step - loss: 370542.6875 - mean_absolute_error: 486.6576 - val_loss: 400496.2188 - val_mean_absolute_error: 478.3976
Epoch 111/200
3/3 - 0s - 8ms/step - loss: 370127.1875 - mean_absolute_error: 487.2672 - val_loss: 400238.7500 - val_mean_absolute_error: 478.1807
Epoch 112/200
3/3 - 0s - 8ms/step - loss: 369925.1875 - mean_absolute_error: 486.3921 - val_loss: 399982.3438 - val_mean_absolute_error: 477.9602
Epoch 113/200
3/3 - 0s - 8ms/step - loss: 369002.9375 - mean_absolute_error: 486.6703 - val_loss: 399712.5938 - val_mean_absolute_error: 477.7305
Epoch 114/200
3/3 - 0s - 8ms/step - loss: 369428.1250 - mean_absolute_error: 486.5616 - val_loss: 399452.6875 - val_mean_absolute_error: 477.5117
Epoch 115/200
3/3 - 0s - 8ms/step - loss: 368339.7188 - mean_absolute_error: 485.9211 - val_loss: 399186.7188 - val_mean_absolute_error: 477.2860
Epoch 116/200
3/3 - 0s - 8ms/step - loss: 368323.3438 - mean_absolute_error: 485.6072 - val_loss: 398920.4688 - val_mean_absolute_error: 477.0613
Epoch 117/200
3/3 - 0s - 8ms/step - loss: 370188.3750 - mean_absolute_error: 486.2615 - val_loss: 398679.7812 - val_mean_absolute_error: 476.8512
Epoch 118/200
3/3 - 0s - 8ms/step - loss: 368119.9062 - mean_absolute_error: 485.3836 - val_loss: 398402.0312 - val_mean_absolute_error: 476.6199
Epoch 119/200
3/3 - 0s - 8ms/step - loss: 369088.2188 - mean_absolute_error: 485.5779 - val_loss: 398146.4062 - val_mean_absolute_error: 476.3984
Epoch 120/200
3/3 - 0s - 8ms/step - loss: 366633.4375 - mean_absolute_error: 484.6035 - val_loss: 397862.3438 - val_mean_absolute_error: 476.1578
Epoch 121/200
3/3 - 0s - 8ms/step - loss: 367158.2188 - mean_absolute_error: 484.3943 - val_loss: 397583.7812 - val_mean_absolute_error: 475.9184
Epoch 122/200
3/3 - 0s - 8ms/step - loss: 367018.9062 - mean_absolute_error: 484.4288 - val_loss: 397309.8438 - val_mean_absolute_error: 475.6894
Epoch 123/200
3/3 - 0s - 8ms/step - loss: 367583.0312 - mean_absolute_error: 484.4478 - val_loss: 397035.5938 - val_mean_absolute_error: 475.4474
Epoch 124/200
3/3 - 0s - 8ms/step - loss: 367400.1250 - mean_absolute_error: 484.2590 - val_loss: 396770.7500 - val_mean_absolute_error: 475.2192
Epoch 125/200
3/3 - 0s - 8ms/step - loss: 365601.6562 - mean_absolute_error: 482.9560 - val_loss: 396480.4062 - val_mean_absolute_error: 474.9660
Epoch 126/200
3/3 - 0s - 8ms/step - loss: 364871.5625 - mean_absolute_error: 482.7477 - val_loss: 396190.7188 - val_mean_absolute_error: 474.7201
Epoch 127/200
3/3 - 0s - 8ms/step - loss: 364507.3750 - mean_absolute_error: 482.6626 - val_loss: 395893.2500 - val_mean_absolute_error: 474.4673
Epoch 128/200
3/3 - 0s - 8ms/step - loss: 366005.8750 - mean_absolute_error: 483.5061 - val_loss: 395613.1562 - val_mean_absolute_error: 474.2216
Epoch 129/200
3/3 - 0s - 8ms/step - loss: 365127.9688 - mean_absolute_error: 482.6041 - val_loss: 395322.6562 - val_mean_absolute_error: 473.9750
Epoch 130/200
3/3 - 0s - 8ms/step - loss: 365349.1875 - mean_absolute_error: 483.2705 - val_loss: 395036.7812 - val_mean_absolute_error: 473.7239
Epoch 131/200
3/3 - 0s - 8ms/step - loss: 365244.0938 - mean_absolute_error: 482.2295 - val_loss: 394754.5312 - val_mean_absolute_error: 473.4706
Epoch 132/200
3/3 - 0s - 8ms/step - loss: 363895.7812 - mean_absolute_error: 481.1255 - val_loss: 394451.1250 - val_mean_absolute_error: 473.2129
Epoch 133/200
3/3 - 0s - 8ms/step - loss: 364859.6562 - mean_absolute_error: 482.4388 - val_loss: 394158.9375 - val_mean_absolute_error: 472.9590
Epoch 134/200
3/3 - 0s - 8ms/step - loss: 364073.2500 - mean_absolute_error: 481.0860 - val_loss: 393860.5625 - val_mean_absolute_error: 472.7028
Epoch 135/200
3/3 - 0s - 8ms/step - loss: 362870.8125 - mean_absolute_error: 480.2015 - val_loss: 393553.3438 - val_mean_absolute_error: 472.4390
Epoch 136/200
3/3 - 0s - 8ms/step - loss: 361279.8750 - mean_absolute_error: 479.7006 - val_loss: 393233.4688 - val_mean_absolute_error: 472.1605
Epoch 137/200
3/3 - 0s - 8ms/step - loss: 361741.0938 - mean_absolute_error: 480.1780 - val_loss: 392922.4688 - val_mean_absolute_error: 471.8917
Epoch 138/200
3/3 - 0s - 8ms/step - loss: 361407.7500 - mean_absolute_error: 479.9792 - val_loss: 392613.4375 - val_mean_absolute_error: 471.6332
Epoch 139/200
3/3 - 0s - 8ms/step - loss: 361352.8750 - mean_absolute_error: 479.7640 - val_loss: 392295.6875 - val_mean_absolute_error: 471.3591
Epoch 140/200
3/3 - 0s - 8ms/step - loss: 361043.6875 - mean_absolute_error: 480.0116 - val_loss: 391980.9688 - val_mean_absolute_error: 471.0856
Epoch 141/200
3/3 - 0s - 8ms/step - loss: 362394.8438 - mean_absolute_error: 480.0092 - val_loss: 391678.8125 - val_mean_absolute_error: 470.8279
Epoch 142/200
3/3 - 0s - 8ms/step - loss: 359436.1250 - mean_absolute_error: 478.6681 - val_loss: 391353.1875 - val_mean_absolute_error: 470.5403
Epoch 143/200
3/3 - 0s - 8ms/step - loss: 360798.0312 - mean_absolute_error: 478.8389 - val_loss: 391037.6875 - val_mean_absolute_error: 470.2642
Epoch 144/200
3/3 - 0s - 8ms/step - loss: 359371.3438 - mean_absolute_error: 478.1148 - val_loss: 390709.5625 - val_mean_absolute_error: 469.9763
Epoch 145/200
3/3 - 0s - 8ms/step - loss: 358636.7500 - mean_absolute_error: 477.7131 - val_loss: 390375.0000 - val_mean_absolute_error: 469.6906
Epoch 146/200
3/3 - 0s - 8ms/step - loss: 360723.4375 - mean_absolute_error: 478.2954 - val_loss: 390068.7500 - val_mean_absolute_error: 469.4168
Epoch 147/200
3/3 - 0s - 8ms/step - loss: 358771.5000 - mean_absolute_error: 477.4695 - val_loss: 389742.7500 - val_mean_absolute_error: 469.1360
Epoch 148/200
3/3 - 0s - 8ms/step - loss: 358842.2188 - mean_absolute_error: 477.4017 - val_loss: 389404.9062 - val_mean_absolute_error: 468.8461
Epoch 149/200
3/3 - 0s - 8ms/step - loss: 358303.5938 - mean_absolute_error: 476.8569 - val_loss: 389066.8438 - val_mean_absolute_error: 468.5553
Epoch 150/200
3/3 - 0s - 8ms/step - loss: 358150.1562 - mean_absolute_error: 476.6016 - val_loss: 388734.0312 - val_mean_absolute_error: 468.2742
Epoch 151/200
3/3 - 0s - 8ms/step - loss: 357958.1562 - mean_absolute_error: 476.9546 - val_loss: 388396.8125 - val_mean_absolute_error: 467.9798
Epoch 152/200
3/3 - 0s - 9ms/step - loss: 356990.8438 - mean_absolute_error: 475.8693 - val_loss: 388052.9688 - val_mean_absolute_error: 467.6772
Epoch 153/200
3/3 - 0s - 8ms/step - loss: 356785.2812 - mean_absolute_error: 475.7096 - val_loss: 387710.8125 - val_mean_absolute_error: 467.3735
Epoch 154/200
3/3 - 0s - 8ms/step - loss: 355202.0625 - mean_absolute_error: 474.8802 - val_loss: 387354.4688 - val_mean_absolute_error: 467.0642
Epoch 155/200
3/3 - 0s - 8ms/step - loss: 355401.7500 - mean_absolute_error: 474.2985 - val_loss: 387007.9688 - val_mean_absolute_error: 466.7566
Epoch 156/200
3/3 - 0s - 8ms/step - loss: 354895.7812 - mean_absolute_error: 473.8029 - val_loss: 386658.6562 - val_mean_absolute_error: 466.4541
Epoch 157/200
3/3 - 0s - 8ms/step - loss: 354721.7812 - mean_absolute_error: 473.8486 - val_loss: 386309.5312 - val_mean_absolute_error: 466.1498
Epoch 158/200
3/3 - 0s - 8ms/step - loss: 353718.9062 - mean_absolute_error: 472.4281 - val_loss: 385959.4375 - val_mean_absolute_error: 465.8422
Epoch 159/200
3/3 - 0s - 8ms/step - loss: 353364.5000 - mean_absolute_error: 472.6308 - val_loss: 385607.9375 - val_mean_absolute_error: 465.5382
Epoch 160/200
3/3 - 0s - 8ms/step - loss: 352616.7812 - mean_absolute_error: 472.4825 - val_loss: 385258.0938 - val_mean_absolute_error: 465.2285
Epoch 161/200
3/3 - 0s - 8ms/step - loss: 355689.5938 - mean_absolute_error: 474.0358 - val_loss: 384921.7812 - val_mean_absolute_error: 464.9329
Epoch 162/200
3/3 - 0s - 8ms/step - loss: 353691.4375 - mean_absolute_error: 472.7648 - val_loss: 384573.5000 - val_mean_absolute_error: 464.6322
Epoch 163/200
3/3 - 0s - 8ms/step - loss: 351961.7812 - mean_absolute_error: 471.1688 - val_loss: 384203.4375 - val_mean_absolute_error: 464.2984
Epoch 164/200
3/3 - 0s - 8ms/step - loss: 353458.0938 - mean_absolute_error: 473.1573 - val_loss: 383835.5938 - val_mean_absolute_error: 463.9696
Epoch 165/200
3/3 - 0s - 8ms/step - loss: 349905.3750 - mean_absolute_error: 470.8405 - val_loss: 383453.6875 - val_mean_absolute_error: 463.6334
Epoch 166/200
3/3 - 0s - 8ms/step - loss: 350526.7500 - mean_absolute_error: 470.5634 - val_loss: 383087.8750 - val_mean_absolute_error: 463.3057
Epoch 167/200
3/3 - 0s - 8ms/step - loss: 351670.4375 - mean_absolute_error: 470.7737 - val_loss: 382726.1562 - val_mean_absolute_error: 462.9872
Epoch 168/200
3/3 - 0s - 8ms/step - loss: 351872.2188 - mean_absolute_error: 471.5077 - val_loss: 382356.5938 - val_mean_absolute_error: 462.6509
Epoch 169/200
3/3 - 0s - 8ms/step - loss: 352088.9062 - mean_absolute_error: 472.0206 - val_loss: 381999.5312 - val_mean_absolute_error: 462.3377
Epoch 170/200
3/3 - 0s - 8ms/step - loss: 350996.2188 - mean_absolute_error: 470.5678 - val_loss: 381625.6250 - val_mean_absolute_error: 462.0062
Epoch 171/200
3/3 - 0s - 8ms/step - loss: 349649.4062 - mean_absolute_error: 469.5199 - val_loss: 381244.9688 - val_mean_absolute_error: 461.6767
Epoch 172/200
3/3 - 0s - 8ms/step - loss: 348967.3125 - mean_absolute_error: 469.0111 - val_loss: 380857.0000 - val_mean_absolute_error: 461.3362
Epoch 173/200
3/3 - 0s - 8ms/step - loss: 349746.4375 - mean_absolute_error: 469.1169 - val_loss: 380484.1875 - val_mean_absolute_error: 461.0071
Epoch 174/200
3/3 - 0s - 8ms/step - loss: 348441.4688 - mean_absolute_error: 468.4128 - val_loss: 380101.8750 - val_mean_absolute_error: 460.6662
Epoch 175/200
3/3 - 0s - 8ms/step - loss: 348371.2188 - mean_absolute_error: 467.8301 - val_loss: 379712.7812 - val_mean_absolute_error: 460.3181
Epoch 176/200
3/3 - 0s - 8ms/step - loss: 348406.7500 - mean_absolute_error: 468.9245 - val_loss: 379332.5312 - val_mean_absolute_error: 459.9739
Epoch 177/200
3/3 - 0s - 8ms/step - loss: 347340.7500 - mean_absolute_error: 466.9924 - val_loss: 378942.3750 - val_mean_absolute_error: 459.6194
Epoch 178/200
3/3 - 0s - 8ms/step - loss: 346512.2812 - mean_absolute_error: 466.8334 - val_loss: 378542.7188 - val_mean_absolute_error: 459.2625
Epoch 179/200
3/3 - 0s - 8ms/step - loss: 348368.7812 - mean_absolute_error: 468.8810 - val_loss: 378160.7188 - val_mean_absolute_error: 458.9225
Epoch 180/200
3/3 - 0s - 8ms/step - loss: 346363.8750 - mean_absolute_error: 467.5661 - val_loss: 377760.8125 - val_mean_absolute_error: 458.5555
Epoch 181/200
3/3 - 0s - 8ms/step - loss: 348722.0312 - mean_absolute_error: 467.1577 - val_loss: 377390.8125 - val_mean_absolute_error: 458.2152
Epoch 182/200
3/3 - 0s - 8ms/step - loss: 342808.4688 - mean_absolute_error: 465.4976 - val_loss: 376965.0625 - val_mean_absolute_error: 457.8446
Epoch 183/200
3/3 - 0s - 8ms/step - loss: 344330.7500 - mean_absolute_error: 465.4167 - val_loss: 376563.2500 - val_mean_absolute_error: 457.4790
Epoch 184/200
3/3 - 0s - 8ms/step - loss: 343227.9688 - mean_absolute_error: 463.1898 - val_loss: 376156.5000 - val_mean_absolute_error: 457.0987
Epoch 185/200
3/3 - 0s - 8ms/step - loss: 341250.4062 - mean_absolute_error: 463.2173 - val_loss: 375733.3750 - val_mean_absolute_error: 456.7181
Epoch 186/200
3/3 - 0s - 8ms/step - loss: 343570.6562 - mean_absolute_error: 463.1572 - val_loss: 375331.5312 - val_mean_absolute_error: 456.3586
Epoch 187/200
3/3 - 0s - 8ms/step - loss: 343374.5625 - mean_absolute_error: 463.9510 - val_loss: 374921.9375 - val_mean_absolute_error: 455.9871
Epoch 188/200
3/3 - 0s - 8ms/step - loss: 342133.5312 - mean_absolute_error: 462.3269 - val_loss: 374506.7812 - val_mean_absolute_error: 455.6096
Epoch 189/200
3/3 - 0s - 8ms/step - loss: 339564.2812 - mean_absolute_error: 461.1679 - val_loss: 374078.4375 - val_mean_absolute_error: 455.2228
Epoch 190/200
3/3 - 0s - 8ms/step - loss: 341416.5938 - mean_absolute_error: 462.4289 - val_loss: 373667.0625 - val_mean_absolute_error: 454.8430
Epoch 191/200
3/3 - 0s - 8ms/step - loss: 340533.2188 - mean_absolute_error: 461.2674 - val_loss: 373248.1250 - val_mean_absolute_error: 454.4638
Epoch 192/200
3/3 - 0s - 8ms/step - loss: 339216.4688 - mean_absolute_error: 460.2765 - val_loss: 372822.5000 - val_mean_absolute_error: 454.0745
Epoch 193/200
3/3 - 0s - 8ms/step - loss: 342794.6562 - mean_absolute_error: 461.7596 - val_loss: 372424.0625 - val_mean_absolute_error: 453.7064
Epoch 194/200
3/3 - 0s - 8ms/step - loss: 341436.7812 - mean_absolute_error: 461.2386 - val_loss: 372008.2812 - val_mean_absolute_error: 453.3326
Epoch 195/200
3/3 - 0s - 8ms/step - loss: 339580.1250 - mean_absolute_error: 460.0045 - val_loss: 371584.2812 - val_mean_absolute_error: 452.9468
Epoch 196/200
3/3 - 0s - 8ms/step - loss: 337419.0000 - mean_absolute_error: 459.3830 - val_loss: 371140.0312 - val_mean_absolute_error: 452.5490
Epoch 197/200
3/3 - 0s - 8ms/step - loss: 338767.1875 - mean_absolute_error: 458.5282 - val_loss: 370723.0938 - val_mean_absolute_error: 452.1649
Epoch 198/200
3/3 - 0s - 8ms/step - loss: 337111.3750 - mean_absolute_error: 459.9707 - val_loss: 370281.2812 - val_mean_absolute_error: 451.7559
Epoch 199/200
3/3 - 0s - 8ms/step - loss: 337389.5000 - mean_absolute_error: 456.9238 - val_loss: 369857.6875 - val_mean_absolute_error: 451.3605
Epoch 200/200
3/3 - 0s - 8ms/step - loss: 337480.4062 - mean_absolute_error: 457.6003 - val_loss: 369429.6875 - val_mean_absolute_error: 450.9656
Keras reports for each epoch the value of the loss function (mean squared error) and of the monitored metric (mean absolute error), for both the training and the validation data. As you can see from the lengthy output, all criteria are still decreasing after 200 epochs. It is helpful to view the epoch history graphically. In an interactive session the per-epoch log is printed as training proceeds (with verbose=2, one line per epoch). You can plot the epoch history stored in history using the matplotlib package:
# Plot training history
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))

plt.subplot(2, 1, 1)
plt.scatter(range(len(history.history['loss'])), history.history['loss'], label='Training Loss', alpha=0.7)
plt.scatter(range(len(history.history['val_loss'])), history.history['val_loss'], label='Validation Loss', alpha=0.7)
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.scatter(range(len(history.history['mean_absolute_error'])), history.history['mean_absolute_error'], label='Training MAE')
plt.scatter(range(len(history.history['val_mean_absolute_error'])), history.history['val_mean_absolute_error'], label='Validation MAE')
plt.title('Model Mean Absolute Error')
plt.xlabel('Epoch')
plt.ylabel('Mean Absolute Error')
plt.legend()

plt.tight_layout()
plt.show()

All criteria are steadily declining and have not leveled out after 200 epochs (Figure 35.1). As expected, the mean squared error and mean absolute error are higher in the validation data than in the training data. This is not always the case when training neural networks. Maybe surprisingly, after about 75 epochs the metrics show more variability from epoch to epoch in the training data than in the validation data. One likely reason is that Keras reports the training metrics as averages over the mini batches of an epoch, computed while the weights are still being updated (and with any dropout layers active), whereas the validation metrics are computed once, with the weights at the end of the epoch. Also, there is no guarantee that the criteria decrease monotonically; the mean squared error of epoch \(t\) can be higher than that of epoch \(t-1\). We are looking for the results to settle down and stabilize before calling the optimization complete. More epochs need to be run in this example. Fortunately, you can continue where the previous run left off: calling fit() again on the same model continues from its current weights. The following code trains the network for another 100 epochs:
firstANN.fit(
    x_train, y_train,
    epochs=100,
    batch_size=32,
    validation_data=(x_test, y_test),
    verbose=2)
Epoch 1/100
3/3 - 0s - 15ms/step - loss: 338878.7812 - mean_absolute_error: 458.2691 - val_loss: 369008.8750 - val_mean_absolute_error: 450.5800
Epoch 2/100
3/3 - 0s - 8ms/step - loss: 334863.4375 - mean_absolute_error: 456.6618 - val_loss: 368559.8750 - val_mean_absolute_error: 450.1714
Epoch 3/100
3/3 - 0s - 8ms/step - loss: 335287.4062 - mean_absolute_error: 457.1601 - val_loss: 368121.0312 - val_mean_absolute_error: 449.7616
Epoch 4/100
3/3 - 0s - 8ms/step - loss: 334838.4062 - mean_absolute_error: 455.0921 - val_loss: 367684.1875 - val_mean_absolute_error: 449.3633
Epoch 5/100
3/3 - 0s - 8ms/step - loss: 329148.4062 - mean_absolute_error: 453.5590 - val_loss: 367209.2812 - val_mean_absolute_error: 448.9365
Epoch 6/100
3/3 - 0s - 8ms/step - loss: 333515.8750 - mean_absolute_error: 453.2549 - val_loss: 366771.7812 - val_mean_absolute_error: 448.5345
Epoch 7/100
3/3 - 0s - 8ms/step - loss: 335810.2188 - mean_absolute_error: 455.5214 - val_loss: 366349.4688 - val_mean_absolute_error: 448.1279
Epoch 8/100
3/3 - 0s - 8ms/step - loss: 330326.4688 - mean_absolute_error: 452.4984 - val_loss: 365882.4062 - val_mean_absolute_error: 447.6920
Epoch 9/100
3/3 - 0s - 8ms/step - loss: 332975.3750 - mean_absolute_error: 454.1907 - val_loss: 365448.1250 - val_mean_absolute_error: 447.2830
Epoch 10/100
3/3 - 0s - 8ms/step - loss: 337247.4375 - mean_absolute_error: 457.0173 - val_loss: 365027.0000 - val_mean_absolute_error: 446.8822
Epoch 11/100
3/3 - 0s - 8ms/step - loss: 332007.6250 - mean_absolute_error: 453.7567 - val_loss: 364582.5312 - val_mean_absolute_error: 446.4705
Epoch 12/100
3/3 - 0s - 8ms/step - loss: 335806.3438 - mean_absolute_error: 455.2414 - val_loss: 364155.5312 - val_mean_absolute_error: 446.0692
Epoch 13/100
3/3 - 0s - 8ms/step - loss: 333222.0312 - mean_absolute_error: 452.7229 - val_loss: 363698.4688 - val_mean_absolute_error: 445.6436
Epoch 14/100
3/3 - 0s - 8ms/step - loss: 329377.0625 - mean_absolute_error: 451.6534 - val_loss: 363220.4375 - val_mean_absolute_error: 445.2020
Epoch 15/100
3/3 - 0s - 8ms/step - loss: 333061.8438 - mean_absolute_error: 453.4913 - val_loss: 362773.8438 - val_mean_absolute_error: 444.7910
Epoch 16/100
3/3 - 0s - 8ms/step - loss: 325156.3125 - mean_absolute_error: 448.1430 - val_loss: 362285.0000 - val_mean_absolute_error: 444.3306
Epoch 17/100
3/3 - 0s - 8ms/step - loss: 330444.8125 - mean_absolute_error: 451.6080 - val_loss: 361824.6250 - val_mean_absolute_error: 443.9007
Epoch 18/100
3/3 - 0s - 8ms/step - loss: 328990.5312 - mean_absolute_error: 450.1672 - val_loss: 361357.9688 - val_mean_absolute_error: 443.4631
Epoch 19/100
3/3 - 0s - 8ms/step - loss: 326936.3125 - mean_absolute_error: 449.8767 - val_loss: 360887.5312 - val_mean_absolute_error: 443.0224
Epoch 20/100
3/3 - 0s - 8ms/step - loss: 327672.4062 - mean_absolute_error: 449.2604 - val_loss: 360417.7500 - val_mean_absolute_error: 442.5815
Epoch 21/100
3/3 - 0s - 8ms/step - loss: 330911.9688 - mean_absolute_error: 450.9475 - val_loss: 359975.8125 - val_mean_absolute_error: 442.1567
Epoch 22/100
3/3 - 0s - 8ms/step - loss: 328743.0000 - mean_absolute_error: 450.0015 - val_loss: 359511.4375 - val_mean_absolute_error: 441.7124
Epoch 23/100
3/3 - 0s - 8ms/step - loss: 325571.7188 - mean_absolute_error: 446.8181 - val_loss: 359034.6250 - val_mean_absolute_error: 441.2596
Epoch 24/100
3/3 - 0s - 8ms/step - loss: 327623.9688 - mean_absolute_error: 448.6258 - val_loss: 358575.0625 - val_mean_absolute_error: 440.8286
Epoch 25/100
3/3 - 0s - 8ms/step - loss: 325030.4375 - mean_absolute_error: 447.0710 - val_loss: 358102.4688 - val_mean_absolute_error: 440.3892
Epoch 26/100
3/3 - 0s - 8ms/step - loss: 323709.0000 - mean_absolute_error: 444.2144 - val_loss: 357615.4062 - val_mean_absolute_error: 439.9251
Epoch 27/100
3/3 - 0s - 8ms/step - loss: 322602.3438 - mean_absolute_error: 444.6203 - val_loss: 357131.7188 - val_mean_absolute_error: 439.4618
Epoch 28/100
3/3 - 0s - 8ms/step - loss: 321078.8125 - mean_absolute_error: 443.8509 - val_loss: 356644.7500 - val_mean_absolute_error: 438.9994
Epoch 29/100
3/3 - 0s - 8ms/step - loss: 322219.6250 - mean_absolute_error: 446.2008 - val_loss: 356159.5938 - val_mean_absolute_error: 438.5360
Epoch 30/100
3/3 - 0s - 8ms/step - loss: 321905.4375 - mean_absolute_error: 443.3015 - val_loss: 355676.4062 - val_mean_absolute_error: 438.0792
Epoch 31/100
3/3 - 0s - 8ms/step - loss: 321208.3125 - mean_absolute_error: 445.4268 - val_loss: 355186.5000 - val_mean_absolute_error: 437.6194
Epoch 32/100
3/3 - 0s - 8ms/step - loss: 320751.4062 - mean_absolute_error: 443.2891 - val_loss: 354691.2812 - val_mean_absolute_error: 437.1460
Epoch 33/100
3/3 - 0s - 8ms/step - loss: 317576.6562 - mean_absolute_error: 441.4828 - val_loss: 354193.5938 - val_mean_absolute_error: 436.6757
Epoch 34/100
3/3 - 0s - 8ms/step - loss: 317415.7188 - mean_absolute_error: 441.4968 - val_loss: 353698.6562 - val_mean_absolute_error: 436.1998
Epoch 35/100
3/3 - 0s - 8ms/step - loss: 321688.5938 - mean_absolute_error: 442.4872 - val_loss: 353233.1562 - val_mean_absolute_error: 435.7613
Epoch 36/100
3/3 - 0s - 8ms/step - loss: 316865.0000 - mean_absolute_error: 438.9899 - val_loss: 352736.5312 - val_mean_absolute_error: 435.2791
Epoch 37/100
3/3 - 0s - 8ms/step - loss: 320893.0938 - mean_absolute_error: 442.1071 - val_loss: 352256.1562 - val_mean_absolute_error: 434.8124
Epoch 38/100
3/3 - 0s - 8ms/step - loss: 319322.1250 - mean_absolute_error: 440.7699 - val_loss: 351762.4062 - val_mean_absolute_error: 434.3334
Epoch 39/100
3/3 - 0s - 8ms/step - loss: 313257.4688 - mean_absolute_error: 436.2627 - val_loss: 351248.0938 - val_mean_absolute_error: 433.8530
Epoch 40/100
3/3 - 0s - 8ms/step - loss: 315215.1250 - mean_absolute_error: 438.8935 - val_loss: 350742.0312 - val_mean_absolute_error: 433.3588
Epoch 41/100
3/3 - 0s - 8ms/step - loss: 319677.5312 - mean_absolute_error: 439.3379 - val_loss: 350257.8750 - val_mean_absolute_error: 432.8792
Epoch 42/100
3/3 - 0s - 8ms/step - loss: 312479.4375 - mean_absolute_error: 435.7321 - val_loss: 349737.0312 - val_mean_absolute_error: 432.3784
Epoch 43/100
3/3 - 0s - 8ms/step - loss: 313209.3125 - mean_absolute_error: 436.6355 - val_loss: 349225.6562 - val_mean_absolute_error: 431.8856
Epoch 44/100
3/3 - 0s - 8ms/step - loss: 315749.4062 - mean_absolute_error: 437.7092 - val_loss: 348732.6875 - val_mean_absolute_error: 431.4084
Epoch 45/100
3/3 - 0s - 8ms/step - loss: 311243.9062 - mean_absolute_error: 435.4446 - val_loss: 348211.5625 - val_mean_absolute_error: 430.8958
Epoch 46/100
3/3 - 0s - 8ms/step - loss: 312979.5312 - mean_absolute_error: 436.2506 - val_loss: 347703.2812 - val_mean_absolute_error: 430.3990
Epoch 47/100
3/3 - 0s - 8ms/step - loss: 312922.3438 - mean_absolute_error: 434.9267 - val_loss: 347202.2188 - val_mean_absolute_error: 429.9159
Epoch 48/100
3/3 - 0s - 8ms/step - loss: 312363.6875 - mean_absolute_error: 433.7431 - val_loss: 346700.2500 - val_mean_absolute_error: 429.4303
Epoch 49/100
3/3 - 0s - 9ms/step - loss: 315776.4688 - mean_absolute_error: 436.3251 - val_loss: 346209.9375 - val_mean_absolute_error: 428.9495
Epoch 50/100
3/3 - 0s - 8ms/step - loss: 309339.5938 - mean_absolute_error: 434.7451 - val_loss: 345681.4062 - val_mean_absolute_error: 428.4406
Epoch 51/100
3/3 - 0s - 8ms/step - loss: 310154.3125 - mean_absolute_error: 433.8997 - val_loss: 345169.6875 - val_mean_absolute_error: 427.9389
Epoch 52/100
3/3 - 0s - 9ms/step - loss: 308910.2500 - mean_absolute_error: 429.3655 - val_loss: 344661.6875 - val_mean_absolute_error: 427.4356
Epoch 53/100
3/3 - 0s - 8ms/step - loss: 308637.5312 - mean_absolute_error: 431.3277 - val_loss: 344144.8125 - val_mean_absolute_error: 426.9249
Epoch 54/100
3/3 - 0s - 8ms/step - loss: 309635.5938 - mean_absolute_error: 431.3557 - val_loss: 343637.5312 - val_mean_absolute_error: 426.4245
Epoch 55/100
3/3 - 0s - 8ms/step - loss: 309459.4375 - mean_absolute_error: 431.4037 - val_loss: 343128.9062 - val_mean_absolute_error: 425.9232
Epoch 56/100
3/3 - 0s - 8ms/step - loss: 309945.0000 - mean_absolute_error: 431.5732 - val_loss: 342616.3750 - val_mean_absolute_error: 425.4187
Epoch 57/100
3/3 - 0s - 8ms/step - loss: 300704.2812 - mean_absolute_error: 428.0811 - val_loss: 342049.9062 - val_mean_absolute_error: 424.8633
Epoch 58/100
3/3 - 0s - 8ms/step - loss: 302770.4062 - mean_absolute_error: 429.4569 - val_loss: 341515.9375 - val_mean_absolute_error: 424.3341
Epoch 59/100
3/3 - 0s - 8ms/step - loss: 311568.5938 - mean_absolute_error: 430.2013 - val_loss: 341040.4688 - val_mean_absolute_error: 423.8449
Epoch 60/100
3/3 - 0s - 8ms/step - loss: 302909.0625 - mean_absolute_error: 426.2511 - val_loss: 340505.5625 - val_mean_absolute_error: 423.3093
Epoch 61/100
3/3 - 0s - 8ms/step - loss: 306957.5000 - mean_absolute_error: 427.2728 - val_loss: 340001.0000 - val_mean_absolute_error: 422.8109
Epoch 62/100
3/3 - 0s - 8ms/step - loss: 306142.4062 - mean_absolute_error: 428.0296 - val_loss: 339475.0312 - val_mean_absolute_error: 422.3002
Epoch 63/100
3/3 - 0s - 8ms/step - loss: 305467.4375 - mean_absolute_error: 427.9183 - val_loss: 338954.4062 - val_mean_absolute_error: 421.8003
Epoch 64/100
3/3 - 0s - 8ms/step - loss: 303861.4688 - mean_absolute_error: 427.0388 - val_loss: 338418.3125 - val_mean_absolute_error: 421.2736
Epoch 65/100
3/3 - 0s - 8ms/step - loss: 301830.2188 - mean_absolute_error: 426.7215 - val_loss: 337865.5625 - val_mean_absolute_error: 420.7334
Epoch 66/100
3/3 - 0s - 8ms/step - loss: 305477.8750 - mean_absolute_error: 429.0033 - val_loss: 337342.2188 - val_mean_absolute_error: 420.2216
Epoch 67/100
3/3 - 0s - 8ms/step - loss: 301234.2500 - mean_absolute_error: 425.2712 - val_loss: 336801.5000 - val_mean_absolute_error: 419.6997
Epoch 68/100
3/3 - 0s - 8ms/step - loss: 301662.2188 - mean_absolute_error: 424.2025 - val_loss: 336270.9688 - val_mean_absolute_error: 419.1961
Epoch 69/100
3/3 - 0s - 8ms/step - loss: 300614.2500 - mean_absolute_error: 423.9529 - val_loss: 335746.1250 - val_mean_absolute_error: 418.7067
Epoch 70/100
3/3 - 0s - 8ms/step - loss: 298111.9062 - mean_absolute_error: 422.3536 - val_loss: 335198.4062 - val_mean_absolute_error: 418.1861
Epoch 71/100
3/3 - 0s - 8ms/step - loss: 299844.7812 - mean_absolute_error: 422.9518 - val_loss: 334662.1250 - val_mean_absolute_error: 417.6949
Epoch 72/100
3/3 - 0s - 8ms/step - loss: 296022.9062 - mean_absolute_error: 421.6513 - val_loss: 334112.6875 - val_mean_absolute_error: 417.1930
Epoch 73/100
3/3 - 0s - 8ms/step - loss: 300743.7812 - mean_absolute_error: 426.4748 - val_loss: 333577.6875 - val_mean_absolute_error: 416.7086
Epoch 74/100
3/3 - 0s - 8ms/step - loss: 296373.2812 - mean_absolute_error: 422.4988 - val_loss: 333022.1562 - val_mean_absolute_error: 416.1972
Epoch 75/100
3/3 - 0s - 8ms/step - loss: 297467.7188 - mean_absolute_error: 421.9259 - val_loss: 332491.8750 - val_mean_absolute_error: 415.7056
Epoch 76/100
3/3 - 0s - 8ms/step - loss: 298410.0625 - mean_absolute_error: 421.5951 - val_loss: 331966.1562 - val_mean_absolute_error: 415.2253
Epoch 77/100
3/3 - 0s - 8ms/step - loss: 297633.9375 - mean_absolute_error: 420.7800 - val_loss: 331428.9062 - val_mean_absolute_error: 414.7278
Epoch 78/100
3/3 - 0s - 8ms/step - loss: 292024.6250 - mean_absolute_error: 415.3840 - val_loss: 330871.2500 - val_mean_absolute_error: 414.2109
Epoch 79/100
3/3 - 0s - 8ms/step - loss: 299579.7812 - mean_absolute_error: 422.8188 - val_loss: 330338.0312 - val_mean_absolute_error: 413.7285
Epoch 80/100
3/3 - 0s - 8ms/step - loss: 295126.9062 - mean_absolute_error: 418.9749 - val_loss: 329805.6875 - val_mean_absolute_error: 413.2499
Epoch 81/100
3/3 - 0s - 8ms/step - loss: 288677.7812 - mean_absolute_error: 414.9009 - val_loss: 329235.0938 - val_mean_absolute_error: 412.7441
Epoch 82/100
3/3 - 0s - 8ms/step - loss: 296059.8438 - mean_absolute_error: 418.8122 - val_loss: 328707.5938 - val_mean_absolute_error: 412.2660
Epoch 83/100
3/3 - 0s - 8ms/step - loss: 288580.0938 - mean_absolute_error: 416.1070 - val_loss: 328131.6875 - val_mean_absolute_error: 411.7457
Epoch 84/100
3/3 - 0s - 8ms/step - loss: 295256.0938 - mean_absolute_error: 419.9008 - val_loss: 327593.7812 - val_mean_absolute_error: 411.2584
Epoch 85/100
3/3 - 0s - 8ms/step - loss: 299297.6562 - mean_absolute_error: 419.7556 - val_loss: 327091.5312 - val_mean_absolute_error: 410.8019
Epoch 86/100
3/3 - 0s - 8ms/step - loss: 287452.4062 - mean_absolute_error: 413.2182 - val_loss: 326508.1250 - val_mean_absolute_error: 410.2713
Epoch 87/100
3/3 - 0s - 8ms/step - loss: 293774.3438 - mean_absolute_error: 418.9258 - val_loss: 325967.0000 - val_mean_absolute_error: 409.7813
Epoch 88/100
3/3 - 0s - 8ms/step - loss: 289258.6875 - mean_absolute_error: 414.4059 - val_loss: 325406.4688 - val_mean_absolute_error: 409.2814
Epoch 89/100
3/3 - 0s - 8ms/step - loss: 289764.4688 - mean_absolute_error: 415.6427 - val_loss: 324853.9062 - val_mean_absolute_error: 408.7940
Epoch 90/100
3/3 - 0s - 8ms/step - loss: 284947.5312 - mean_absolute_error: 411.3497 - val_loss: 324272.1250 - val_mean_absolute_error: 408.2770
Epoch 91/100
3/3 - 0s - 8ms/step - loss: 288999.3438 - mean_absolute_error: 412.6786 - val_loss: 323725.2812 - val_mean_absolute_error: 407.7930
Epoch 92/100
3/3 - 0s - 8ms/step - loss: 289545.7188 - mean_absolute_error: 414.0194 - val_loss: 323183.1250 - val_mean_absolute_error: 407.3061
Epoch 93/100
3/3 - 0s - 8ms/step - loss: 284122.5625 - mean_absolute_error: 410.3819 - val_loss: 322605.4688 - val_mean_absolute_error: 406.7966
Epoch 94/100
3/3 - 0s - 8ms/step - loss: 284313.7812 - mean_absolute_error: 410.0853 - val_loss: 322038.9375 - val_mean_absolute_error: 406.2941
Epoch 95/100
3/3 - 0s - 8ms/step - loss: 286282.5625 - mean_absolute_error: 410.0935 - val_loss: 321480.0312 - val_mean_absolute_error: 405.7941
Epoch 96/100
3/3 - 0s - 8ms/step - loss: 282012.1875 - mean_absolute_error: 408.2502 - val_loss: 320901.5625 - val_mean_absolute_error: 405.2786
Epoch 97/100
3/3 - 0s - 8ms/step - loss: 280432.1875 - mean_absolute_error: 406.8622 - val_loss: 320330.3125 - val_mean_absolute_error: 404.7721
Epoch 98/100
3/3 - 0s - 8ms/step - loss: 280127.4375 - mean_absolute_error: 405.5800 - val_loss: 319763.2500 - val_mean_absolute_error: 404.2691
Epoch 99/100
3/3 - 0s - 8ms/step - loss: 282455.1562 - mean_absolute_error: 407.5822 - val_loss: 319207.3438 - val_mean_absolute_error: 403.7832
Epoch 100/100
3/3 - 0s - 8ms/step - loss: 285201.0625 - mean_absolute_error: 409.6684 - val_loss: 318662.8125 - val_mean_absolute_error: 403.3157
<keras.src.callbacks.history.History at 0x38b2b9520>
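Each call to fit() returns a new History object, so the object from the continued run holds only the 100 additional epochs (note that the log above restarts at Epoch 1/100). To plot all epochs together, the per-epoch lists can be concatenated. A minimal sketch in plain Python, assuming the first run's result is stored in history (as in the plotting code) and the second run's result in a hypothetical history2 = firstANN.fit(...); combine_histories is a helper introduced here for illustration, not a Keras function:

```python
def combine_histories(*histories):
    """Concatenate several History.history dictionaries metric by metric.

    Each argument is a dict mapping a metric name (e.g. 'loss',
    'val_loss') to a list with one value per epoch, the structure
    Keras stores in History.history. The inputs are not modified.
    """
    combined = {}
    for hist in histories:
        for metric, values in hist.items():
            # Append this run's epochs after the epochs collected so far
            combined.setdefault(metric, []).extend(values)
    return combined


# Made-up histories standing in for history.history and history2.history
first = {"loss": [4.0, 3.0], "val_loss": [4.5, 3.6]}
second = {"loss": [2.5], "val_loss": [3.1]}
print(combine_histories(first, second))

# With real runs (history2 is hypothetical):
# all_epochs = combine_histories(history.history, history2.history)
```

The combined dictionary has the same keys as history.history, so the plotting code above works on it after replacing history.history with the combined dictionary.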
When training models this way, you keep your eyes on the epoch history to study the behavior of the loss function and other metrics on the training and test data sets. You have to make a judgment call as to when the optimization has stabilized and further progress is minimal. Alternatively, you can register a callback function that stops the optimization when certain conditions are met.
This is done in the following code with the EarlyStopping callback (detailed results not shown here). Its options ask it to monitor the loss function on the validation data (monitor='val_loss') and to stop the optimization when the criterion fails to decrease (mode='min') over 10 epochs (patience=10). A change of the monitored metric has to be at least 0.1 in magnitude to qualify as an improvement (min_delta=0.1). With restore_best_weights=True, the model keeps the weights from the epoch with the best validation loss rather than those from the last epoch.
from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.1,
    mode='min',
    restore_best_weights=True
)

firstANN.fit(
    x_train, y_train,
    epochs=400,
    batch_size=32,
    validation_data=(x_test, y_test),
    callbacks=[early_stopping],
    verbose=0
)
<keras.src.callbacks.history.History at 0x38a645b50>
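To make the roles of patience and min_delta concrete, here is a simplified re-implementation of the stopping rule for mode='min' in plain Python. It is a sketch of the logic, not the Keras source; the function name early_stopping_epoch and the loss values are made up for illustration, and the strict inequality used for "improvement" is an assumption about how ties are treated.

```python
def early_stopping_epoch(val_losses, patience=10, min_delta=0.1):
    """Simplified early stopping rule for a metric that should decrease.

    Returns (stop_epoch, best_epoch) as 0-based epoch indices;
    stop_epoch is None if the rule never triggers.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last qualifying improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # A decrease by more than min_delta counts as an improvement
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Made-up validation losses that stall after the second epoch
print(early_stopping_epoch([10.0, 9.0, 8.95, 8.99, 8.97], patience=3))  # (4, 1)
```

In this example the small decrease from 9.0 to 8.95 does not exceed min_delta, so three non-improving epochs accumulate and the rule stops at epoch 4; with restore_best_weights=True the weights from epoch 1 would be the ones kept.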
A list of all callback functions supported by Keras can be found in the Keras documentation. Alternatively, you can type the following at the console prompt:
help(tf.keras.callbacks)
Finally, we predict from the final model and evaluate its performance on the test data. Because the fit uses random elements (random starting values, stochastic gradient descent, random dropout, …), the results vary slightly with each fit. Even setting a random seed does not guarantee identical results (see the discussion of random numbers below), so your results will differ slightly.
import numpy as np

# Make predictions on test set and calculate mean absolute error
predvals = firstANN.predict(x_test)
np.mean(np.abs(y_test - predvals.flatten()))
1/3 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step 3/3 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step
289.744113589232
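The flatten() call in the mean absolute error computation is not cosmetic. For a model with a single output, predict() returns a two-dimensional array with one column, and subtracting an (n, 1) array from a length-n NumPy vector triggers broadcasting: the result is an n × n matrix of all pairwise differences. A small illustration with made-up numbers, assuming the targets are a one-dimensional NumPy array:

```python
import numpy as np

y = np.array([100.0, 200.0, 300.0])           # observed targets, shape (3,)
pred = np.array([[110.0], [190.0], [330.0]])  # predict()-style output, shape (3, 1)

# Without flatten(), (3,) - (3, 1) broadcasts to a (3, 3) matrix of pairwise differences
wrong = np.mean(np.abs(y - pred))
# With flatten(), the differences are taken observation by observation
right = np.mean(np.abs(y - pred.flatten()))

print((y - pred).shape, wrong)
print((y - pred.flatten()).shape, right)
```

Here right is (10 + 10 + 30) / 3, while wrong averages over all nine target/prediction pairs and is much larger.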
Random numbers
Neural networks rely on random numbers for picking starting values, selecting observations into mini batches, selecting neurons in dropout layers, etc. The underlying code relies on the random number generators of Python's random module and of NumPy. TensorFlow has its own random number generator on top of that. Python code that uses Keras with the TensorFlow backend needs to set the seed for each random number generator to obtain reproducible results. The keras.utils.set_random_seed() function sets the seeds for base Python (the random module), NumPy, and TensorFlow.
keras.utils.set_random_seed(1)
Even with this control, the results might not be exactly reproducible. Multi-threaded operations on CPUs, and on GPUs in particular, can execute in a non-deterministic order, and floating-point arithmetic is sensitive to the order of operations.
One recommendation to deal with non-deterministic results is training the model several times and averaging the results, essentially ensembling them. When a single training run takes several hours, doing it thirty times is not practical.
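Averaging over runs is straightforward once the predictions are collected. A minimal sketch, assuming the test-set predictions of each run (for example from firstANN.predict(x_test) after retraining with a different seed) are stored in a list; ensemble_average is a hypothetical helper, and for a classifier the same averaging would be applied to the predicted class probabilities:

```python
import numpy as np


def ensemble_average(prediction_runs):
    """Average predictions from several independent training runs.

    prediction_runs: list of arrays with identical shapes, e.g. the
    output of model.predict(x_test) from each run.
    """
    stacked = np.stack(prediction_runs, axis=0)  # shape (runs, n, ...)
    return stacked.mean(axis=0)                  # element-wise mean over runs


# Three made-up runs of predictions for two observations
runs = [np.array([[1.0], [4.0]]),
        np.array([[2.0], [5.0]]),
        np.array([[3.0], [6.0]])]
print(ensemble_average(runs))
```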
MNIST Image Classification
We now return to the MNIST image classification data introduced in Section 32.4. Recall that the data comprise 60,000 training images and 10,000 test images of handwritten digits (0–9). Each image has 28 x 28 pixels recording a grayscale value.
The MNIST data is provided by Keras:
Set up the data
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
# Load MNIST dataset
(x_train, g_train), (x_test, g_test) = mnist.load_data()
# Check dimensions
print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)
x_train shape: (60000, 28, 28)
x_test shape: (10000, 28, 28)
The images are stored as a three-dimensional array (image × row × column) and need to be reshaped into a matrix with one row per image and one column per pixel. For classification tasks with \(k\) categories, Keras expects as the target values a matrix with \(k\) columns. Column \(j\) contains ones in the rows of observations whose observed category is \(j\), and zeros otherwise. This is called one-hot encoding of the target variable. Luckily, both tasks are handled by built-in functions: the reshape() method of NumPy arrays and the Keras function to_categorical().
# Reshape the data to flatten the 28x28 images into 784-dimensional vectors
x_train = x_train.reshape(x_train.shape[0], 784)
x_test = x_test.reshape(x_test.shape[0], 784)

# Convert labels to categorical (one-hot encoding)
y_train = to_categorical(g_train, 10)
y_test = to_categorical(g_test, 10)
Let’s look at the one-hot encoding of the target data. g_test contains the value of the digit from 0–9. y_test is a matrix with 10 columns, and each column corresponds to one digit. If observation \(i\) represents digit \(j\), then there is a 1 in row \(i\), column \(j+1\) of the encoded matrix; with Python's zero-based indexing, that entry is y_test[i, j]. For example, for the first twenty images:
# Display first 20 original labels
print(g_test[:20])
[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
# Display first 20 one-hot encoded labels (all 10 columns)
print(y_test[:20, :10])
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]
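The encoding itself is simple enough to reproduce without Keras. A sketch using NumPy indexing into an identity matrix, which for integer labels 0–9 gives the same 0/1 pattern as to_categorical(); argmax() along each row reverses the encoding:

```python
import numpy as np

labels = np.array([7, 2, 1, 0, 4])  # first five entries of g_test

# Row k of the 10 x 10 identity matrix is the one-hot vector for digit k
one_hot = np.eye(10)[labels]
print(one_hot)

# argmax along each row recovers the digit, i.e. the column holding the 1
print(one_hot.argmax(axis=1))  # [7 2 1 0 4]
```

The argmax step is also how predicted probabilities from a softmax output are turned into a predicted digit.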
Let’s look at the matrix of inputs. The next array shows the 28 x 28 = 784 input columns for the third image. The values are grayscale values between 0 and 255.
# Display the 3rd test sample (index 2)
print(x_test[2])
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 38 254 109 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 87 252 82 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 135 241 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 45 244 150 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 84 254 63 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 202 223 11
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 32 254 216 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 95 254
195 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 140 254 77 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 57
237 205 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 124 255 165 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 171 254 81 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 24 232 215 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 120 254 159 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 151 254 142 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 228 254 66 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 61 251 254 66 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 141 254 205 3 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 215 254 121
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 5 198 176 10 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
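The flat vector is easier to read when it is put back into image form. reshape() uses row-major (C) order by default, so the 784 values are the 28 image rows laid end to end, and pixel (row \(r\), column \(c\)) is element \(28r + c\). A sketch on a made-up image with a vertical stroke; applied to the real data before the scaling below, the same loop on x_test[2].reshape(28, 28) should trace the digit 1, the third label in g_test:

```python
import numpy as np

# A made-up 28 x 28 "image": a vertical stroke in column 14
image = np.zeros((28, 28), dtype=np.uint8)
image[5:23, 14] = 255

flat = image.reshape(784)         # row-major: row 0, then row 1, ...
restored = flat.reshape(28, 28)   # undo the flattening

# Pixel (row r, column c) sits at position 28 * r + c of the flat vector
r, c = 10, 14
print(flat[28 * r + c], restored[r, c])

# Crude text rendering: '#' for dark pixels, '.' for background
for row in restored:
    print("".join("#" if v > 127 else "." for v in row))
```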
Finally, prior to training the network, we scale the input values to lie between 0 and 1.
# Normalize pixel values to range [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
The target variable does not need to be scaled: the one-hot encoding together with the use of a softmax output function ensures that the output for each category is a value between 0 and 1, and that the outputs sum to 1 across the 10 categories. We will interpret them as predicted probabilities that an observed image shows a particular digit.
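A quick numerical check of these properties: a NumPy sketch of the softmax function \(\exp(z_j)/\sum_{k}\exp(z_k)\) applied to made-up scores for the ten digits. This is the textbook formula, not the Keras implementation; subtracting the maximum score first leaves the result unchanged but avoids overflow.

```python
import numpy as np


def softmax(z):
    """Softmax of a vector of scores (logits)."""
    z = z - np.max(z)      # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


# Made-up output-layer scores for the digits 0-9
scores = np.array([2.0, 1.0, 0.1, -1.0, 0.5, 0.0, 3.0, -0.5, 1.5, 0.2])
probs = softmax(scores)
print(probs.round(3))
print(probs.sum())      # 1 up to floating-point rounding
print(probs.argmax())   # 6, the digit with the largest score
```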
To classify the MNIST images we consider two types of neural networks in the remainder of this chapter: a multi layer ANN and a network without a hidden layer. The latter is a multi category perceptron and very similar to a multinomial logistic regression model.
Multi layer neural network
We now train the network shown in Figure 32.14, an ANN with two hidden layers, and add a dropout regularization layer after each fully connected hidden layer. The input layer specifies the input shape 28 x 28 = 784. The first hidden layer has 128 neurons and ReLU activation. There is no deeper reason for these particular choices.
This is followed by a first dropout layer with rate \(\phi_1 = 0.3\), a second fully connected hidden layer with 64 nodes and a hyperbolic tangent activation function, a second dropout layer with rate \(\phi_2 = 0.2\), and a final softmax output layer with 10 nodes. Again, these settings are choices, not requirements.
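Here is a plain NumPy imitation of a dropout layer with rate φ during training. Each input value is set to zero with probability φ. The surviving values are scaled by 1/(1 − φ), so the expected sum stays unchanged; the Keras documentation describes this same scaling for its `Dropout` layer. At prediction time the layer passes its inputs through unchanged. The function name `dropout_train` and the random seed are our own choices for illustration. This is a sketch, not the Keras code.

```python
import numpy as np

def dropout_train(x, rate, rng):
    """Inverted dropout as applied during training (illustrative sketch)."""
    keep = rng.random(x.shape) >= rate          # keep each unit with probability 1 - rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(seed=1)
h = np.ones(128)                                # pretend output of the 128-unit first hidden layer
h_drop = dropout_train(h, rate=0.3, rng=rng)

print((h_drop == 0).mean())                     # fraction dropped, close to 0.3
print(h_drop.max())                             # survivors scaled up to 1 / 0.7
print(h.mean(), h_drop.mean())                  # means agree on average
```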
Set up the network
The following statements set up the network in keras:
modelnn = keras.models.Sequential()
modelnn.add(layers.Input(shape=(784,)))
modelnn.add(layers.Dense(units=128, activation="relu", name="FirstHidden"))
modelnn.add(layers.Dropout(rate=0.3, name="FirstDropOut"))
modelnn.add(layers.Dense(units=64, activation="tanh", name="SecondHidden"))
modelnn.add(layers.Dropout(rate=0.2, name="SecondDropOut"))
modelnn.add(layers.Dense(units=10, activation="softmax", name="Output"))
The summary() method lets us inspect whether we got it all right.
modelnn.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ FirstHidden (Dense)             │ (None, 128)            │       100,480 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ FirstDropOut (Dropout)          │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ SecondHidden (Dense)            │ (None, 64)             │         8,256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ SecondDropOut (Dropout)         │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ Output (Dense)                  │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 109,386 (427.29 KB)
Trainable params: 109,386 (427.29 KB)
Non-trainable params: 0 (0.00 B)
The total number of parameters in this network is 109,386: a sizeable network, but not a huge one.
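The parameter count can be checked by hand. A fully connected (Dense) layer with n inputs and m units has n × m weights plus m biases, and dropout layers add no parameters. Below is a short check. `dense_params` is a helper written only for this illustration.

```python
def dense_params(n_inputs, n_units):
    """Weights (n_inputs x n_units) plus one bias per unit."""
    return n_inputs * n_units + n_units

layer_params = {
    "FirstHidden": dense_params(784, 128),   # 100,480
    "FirstDropOut": 0,                       # dropout has no trainable parameters
    "SecondHidden": dense_params(128, 64),   # 8,256
    "SecondDropOut": 0,
    "Output": dense_params(64, 10),          # 650
}
print(f"{sum(layer_params.values()):,}")     # 109,386
```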
Set up the optimization
Next, we add details to the model to specify the fitting algorithm. We fit the model by minimizing the categorical cross-entropy function and monitor the classification accuracy during the iterations.
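With one-hot targets, the categorical cross-entropy of one observation is \(-\sum_k y_k \log p_k\). This reduces to minus the log of the probability assigned to the observed category. The loss reported during training is the average over observations. The sketch below computes it in NumPy for two made-up observations with hypothetical probabilities. Keras additionally guards against log(0) internally; that step is omitted here.

```python
import numpy as np

# Two observations, 10 categories; rows of y are one-hot, rows of p are softmax outputs
y = np.zeros((2, 10))
y[0, 7] = 1.0          # first image is a "7"
y[1, 2] = 1.0          # second image is a "2"

p = np.full((2, 10), 0.01)
p[0, 7] = 0.91         # confident and correct
p[1, 2] = 0.10         # most of the probability mass sits on the wrong class
p[1, 5] = 0.82

# Categorical cross-entropy per observation: -sum_k y_k * log(p_k)
per_obs = -(y * np.log(p)).sum(axis=1)
print(per_obs.round(4))   # -log(0.91) and -log(0.10), roughly 0.0943 and 2.3026
print(per_obs.mean())     # the minimized loss is the average over observations
```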
modelnn.compile(loss="categorical_crossentropy",
                optimizer=keras.optimizers.RMSprop(),
                metrics=["accuracy"])
Fit the model
We are ready to go. The final step is to supply the training data and fit the model. With a batch size of 128 observations, each epoch consists of 60,000 / 128 = 468.75 batches, rounded up to 469 gradient evaluations (the last batch holds only 96 observations).
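The batch arithmetic as a quick Python check. `math.ceil` counts the partial final batch as a step of its own.

```python
import math

n_train, batch_size, epochs = 60_000, 128, 20

steps = math.ceil(n_train / batch_size)           # gradient evaluations per epoch
last_batch = n_train - (steps - 1) * batch_size   # size of the final, partial batch
print(steps, last_batch, steps * epochs)          # 469 96 9380
```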
history = modelnn.fit(x_train,
                      y_train,
                      epochs=20,
                      batch_size=128,
                      validation_data=(x_test, y_test),
                      verbose=2)
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))

# Top panel: loss per epoch
plt.subplot(2, 1, 1)
plt.scatter(range(len(history.history['loss'])), history.history['loss'], label='Training Loss')
plt.scatter(range(len(history.history['val_loss'])), history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# Bottom panel: accuracy per epoch
plt.subplot(2, 1, 2)
plt.scatter(range(len(history.history['accuracy'])), history.history['accuracy'], label='Training Accuracy')
plt.scatter(range(len(history.history['val_accuracy'])), history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

plt.tight_layout()
plt.show()
Epoch 1/20
469/469 - 1s - 2ms/step - accuracy: 0.8852 - loss: 0.3960 - val_accuracy: 0.9477 - val_loss: 0.1608
Epoch 2/20
469/469 - 0s - 943us/step - accuracy: 0.9439 - loss: 0.1883 - val_accuracy: 0.9657 - val_loss: 0.1122
Epoch 3/20
469/469 - 0s - 920us/step - accuracy: 0.9558 - loss: 0.1482 - val_accuracy: 0.9704 - val_loss: 0.0952
Epoch 4/20
469/469 - 0s - 915us/step - accuracy: 0.9627 - loss: 0.1248 - val_accuracy: 0.9747 - val_loss: 0.0873
Epoch 5/20
469/469 - 0s - 914us/step - accuracy: 0.9667 - loss: 0.1121 - val_accuracy: 0.9745 - val_loss: 0.0828
Epoch 6/20
469/469 - 0s - 921us/step - accuracy: 0.9692 - loss: 0.1021 - val_accuracy: 0.9774 - val_loss: 0.0772
Epoch 7/20
469/469 - 0s - 927us/step - accuracy: 0.9714 - loss: 0.0939 - val_accuracy: 0.9774 - val_loss: 0.0776
Epoch 8/20
469/469 - 0s - 933us/step - accuracy: 0.9726 - loss: 0.0891 - val_accuracy: 0.9783 - val_loss: 0.0742
Epoch 9/20
469/469 - 0s - 928us/step - accuracy: 0.9747 - loss: 0.0819 - val_accuracy: 0.9788 - val_loss: 0.0720
Epoch 10/20
469/469 - 0s - 931us/step - accuracy: 0.9754 - loss: 0.0783 - val_accuracy: 0.9796 - val_loss: 0.0696
Epoch 11/20
469/469 - 0s - 931us/step - accuracy: 0.9776 - loss: 0.0730 - val_accuracy: 0.9807 - val_loss: 0.0663
Epoch 12/20
469/469 - 0s - 940us/step - accuracy: 0.9774 - loss: 0.0719 - val_accuracy: 0.9797 - val_loss: 0.0687
Epoch 13/20
469/469 - 0s - 931us/step - accuracy: 0.9786 - loss: 0.0688 - val_accuracy: 0.9804 - val_loss: 0.0672
Epoch 14/20
469/469 - 0s - 942us/step - accuracy: 0.9797 - loss: 0.0648 - val_accuracy: 0.9819 - val_loss: 0.0656
Epoch 15/20
469/469 - 0s - 935us/step - accuracy: 0.9796 - loss: 0.0635 - val_accuracy: 0.9794 - val_loss: 0.0677
Epoch 16/20
469/469 - 0s - 938us/step - accuracy: 0.9807 - loss: 0.0606 - val_accuracy: 0.9811 - val_loss: 0.0648
Epoch 17/20
469/469 - 0s - 928us/step - accuracy: 0.9815 - loss: 0.0591 - val_accuracy: 0.9801 - val_loss: 0.0688
Epoch 18/20
469/469 - 0s - 940us/step - accuracy: 0.9818 - loss: 0.0579 - val_accuracy: 0.9802 - val_loss: 0.0675
Epoch 19/20
469/469 - 0s - 956us/step - accuracy: 0.9833 - loss: 0.0544 - val_accuracy: 0.9808 - val_loss: 0.0686
Epoch 20/20
469/469 - 0s - 966us/step - accuracy: 0.9824 - loss: 0.0542 - val_accuracy: 0.9797 - val_loss: 0.0698
After about 10 epochs the training and validation accuracy stabilize. The training loss continues to decrease, while the validation loss levels off between roughly 0.065 and 0.070. Interestingly, for roughly the first 13 epochs the accuracy and loss on the 10,000-image validation set are better than on the 60,000-image training set. This is not unusual, for two reasons. Dropout is active only during training. In addition, the training metrics are averaged over each epoch while the weights are still improving. The grayscale values enter this neural network as 784 numeric input variables, and the network ignores the spatial arrangement of the pixels in the image. Given that, a classification accuracy of 98% on unseen images is quite good. Whether that is sufficient depends on the application.
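To find the epoch with the lowest validation loss, scan the per-epoch values. Here they are copied by hand from the training log above. With the fitted model, the same values are available as `history.history['val_loss']`. This is only a sketch; Keras also provides a `keras.callbacks.EarlyStopping` callback that can monitor such a quantity during training.

```python
# Validation loss per epoch, copied from the training log above
val_loss = [0.1608, 0.1122, 0.0952, 0.0873, 0.0828, 0.0772, 0.0776, 0.0742,
            0.0720, 0.0696, 0.0663, 0.0687, 0.0672, 0.0656, 0.0677, 0.0648,
            0.0688, 0.0675, 0.0686, 0.0698]

best_idx = val_loss.index(min(val_loss))   # 0-based position of the minimum
best_epoch = best_idx + 1                  # the log numbers epochs from 1
print(best_epoch, val_loss[best_idx])      # 16 0.0648
```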
As we will see in Chapter 36, neural networks that specialize in the processing of grid-like data such as images easily improve on this performance.
Calculate predicted categories
To calculate the predicted categories for the images in the test data set, we use the predict method. For each observation it returns a vector of 10 predicted probabilities.
predvals = modelnn.predict(x_test)
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 281us/step
For the first image, the probabilities that its digit belongs to each of the 10 classes are given by this vector:
predvals[0].round(4)
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], dtype=float32)
predvals[0].argmax()
7
predvals[0].max().round(4)
1.0
The maximum probability, 1.0 after rounding to four decimals, is in position 7, so the image is classified as a “7”. Because Python indexing is 0-based, position k corresponds directly to digit k.
NumPy's argmax() returns the index of the maximum value. With axis=1 it works row by row, which gives the predicted digit for every test image:
predcl = predvals.argmax(axis=1)
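A toy example of the `axis` argument, using a hypothetical 3 × 4 array of probabilities (three images, four classes):

```python
import numpy as np

# Hypothetical softmax outputs: one row per image, one column per class
probs = np.array([[0.10, 0.70, 0.15, 0.05],
                  [0.60, 0.20, 0.10, 0.10],
                  [0.05, 0.05, 0.30, 0.60]])

print(probs.argmax(axis=1))   # [1 0 3]: one predicted class per row
print(probs.argmax())         # 1: without axis, an index into the flattened array
```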
Which of the first 500 observations were misclassified?
miscl = [i for i in range(500) if predcl[i] != g_test[i]]
miscl
[124, 217, 233, 247, 259, 321, 340, 381, 445, 449]
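The list comprehension can also be written as a vectorized NumPy comparison. The sketch uses small hypothetical label arrays `g_obs` and `g_pred`. With the arrays of this chapter, the analogous call would be `np.flatnonzero(predcl[:500] != g_test[:500])`, assuming `g_test` is a NumPy integer array.

```python
import numpy as np

# Hypothetical observed and predicted labels for eight test images
g_obs = np.array([7, 2, 1, 0, 4, 1, 4, 9])
g_pred = np.array([7, 2, 1, 0, 9, 1, 4, 4])

wrong = g_pred != g_obs               # boolean mask of misclassifications
miscl_idx = np.flatnonzero(wrong)     # positions where the prediction is wrong
error_rate = wrong.mean()             # share of misclassified images
print(miscl_idx)                      # [4 7]
print(error_rate)                     # 0.25
```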
print(f"Observed value for obs # {miscl[0]}: {g_test[miscl[0]]}")
print(f"Predicted value for obs # {miscl[0]}: {predcl[miscl[0]]}")
Observed value for obs # 124: 7
Predicted value for obs # 124: 4
The first misclassified observation is #124. The observed digit value is 7, the predicted value is 4. The softmax probabilities for this observation show why it predicted category 4:
import numpy as np

# Print floats with four decimals
float_formatter = "{:.4f}".format
np.set_printoptions(formatter={'float_kind': float_formatter})
predvals[miscl[0]]
array([0.0000, 0.0003, 0.0000, 0.0002, 0.7088, 0.0000, 0.0000, 0.2899,
0.0001, 0.0007], dtype=float32)
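Sorting the probabilities shows the runner-up directly. The snippet uses the values printed above, copied by hand and therefore rounded to four decimals:

```python
import numpy as np

# Softmax probabilities for observation #124, copied from the output above
p124 = np.array([0.0000, 0.0003, 0.0000, 0.0002, 0.7088,
                 0.0000, 0.0000, 0.2899, 0.0001, 0.0007])

top2 = np.argsort(p124)[::-1][:2]   # class indices ordered by decreasing probability
print(top2, p124[top2])             # [4 7] [0.7088 0.2899]
```

The second most likely class is 7, the observed digit, with a probability of about 0.29.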
We can visualize the images with matplotlib's imshow function. The next code segment does this for the first observation in the test data set and for the first two misclassified observations:
# visualize the digits
def plotIt(idx=0):
    pixels = x_test[idx]              # flattened image, shape (784,)
    im = pixels.reshape((28, 28))     # restore the 28 x 28 pixel grid

    plt.figure(figsize=(6, 6))
    plt.imshow(im, cmap='gray', origin='upper')
    plt.title(f"Index #{idx} -- Image label: {g_test[idx]}, Predicted: {predcl[idx]}")
    plt.show()

# Visualize specific digits
plotIt(0)
plotIt(miscl[0])
plotIt(miscl[1])
Multinomial logistic regression
A 98% accuracy is impressive, but maybe it is not good enough. In applications where the consequences of errors are high, this accuracy might be insufficient. Suppose we are using the trained network to recognize written digits on personal checks. Getting 200 out of 10,000 digits wrong would be unacceptable. Banks would deposit incorrect amounts all the time.
If that is the application for the trained algorithm, we should consider other models for these data. This raises an interesting question: how much did we gain by adding the hidden layers of the network? If adding layers is an effective strategy to increase accuracy, then we could consider adding more of them. If not, we may need to explore an entirely different network architecture.
Before trying deeper alternatives, we can establish a performance benchmark by removing the hidden layers and training what is essentially a single-layer perceptron (Section 32.1). This model has an input layer and an output layer. In keras syntax it is specified with a single Dense layer after the input:
modellr = keras.models.Sequential()
modellr.add(layers.Input(shape=(784,)))
modellr.add(layers.Dense(units=10, activation='softmax'))
modellr.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_1 (Dense)                 │ (None, 10)             │         7,850 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 7,850 (30.66 KB)
Trainable params: 7,850 (30.66 KB)
Non-trainable params: 0 (0.00 B)
This is essentially a multinomial logistic regression model with a 10-category target variable and 784 input variables. The model is much smaller than the previous network, with only 7,850 parameters. As a multinomial logistic regression model, however, it is huge. Many software packages for multinomial regression may struggle to fit a model of this size. Articulated as a neural network, training such a model is a breeze.
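The equivalence can be made concrete with a sketch of the model's forward pass in NumPy. The model has a 784 × 10 weight matrix `W` and a bias vector `b` of length 10, for 7,850 parameters in total, followed by a softmax over each row. The weights here are random and untrained. Their initialization is arbitrary and is not the one Keras uses. `X` is a hypothetical batch of five scaled images.

```python
import numpy as np

rng = np.random.default_rng(0)
n_inputs, n_classes = 784, 10

# Parameters of the single Dense layer: a weight matrix and a bias vector
W = rng.normal(scale=0.01, size=(n_inputs, n_classes))
b = np.zeros(n_classes)
print(W.size + b.size)                          # 7850 parameters, matching the summary

# Forward pass for a hypothetical batch of five images with pixel values in [0, 1)
X = rng.random((5, n_inputs))
logits = X @ W + b                              # linear predictor, one score per class
logits -= logits.max(axis=1, keepdims=True)     # stabilize exp() row by row
P = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
print(P.shape, P.sum(axis=1))                   # (5, 10); each row sums to 1
```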
We proceed just as before.
modellr.compile(loss="categorical_crossentropy",
                optimizer=keras.optimizers.RMSprop(),
                metrics=["accuracy"])
# Train the model
history_lr = modellr.fit(x_train,
                         y_train,
                         epochs=20,
                         batch_size=128,
                         validation_data=(x_test, y_test),
                         verbose=2)
# Plot training history
plt.figure(figsize=(12, 6))

plt.subplot(2, 1, 1)
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['loss'], label='Training Loss')
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['accuracy'], label='Training Accuracy')
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
Epoch 1/20
469/469 - 0s - 832us/step - accuracy: 0.8531 - loss: 0.5938 - val_accuracy: 0.9078 - val_loss: 0.3434
Epoch 2/20
469/469 - 0s - 486us/step - accuracy: 0.9071 - loss: 0.3331 - val_accuracy: 0.9170 - val_loss: 0.3018
Epoch 3/20
469/469 - 0s - 495us/step - accuracy: 0.9147 - loss: 0.3045 - val_accuracy: 0.9209 - val_loss: 0.2875
Epoch 4/20
469/469 - 0s - 485us/step - accuracy: 0.9187 - loss: 0.2912 - val_accuracy: 0.9223 - val_loss: 0.2803
Epoch 5/20
469/469 - 0s - 495us/step - accuracy: 0.9213 - loss: 0.2831 - val_accuracy: 0.9237 - val_loss: 0.2759
Epoch 6/20
469/469 - 0s - 469us/step - accuracy: 0.9233 - loss: 0.2775 - val_accuracy: 0.9248 - val_loss: 0.2730
Epoch 7/20
469/469 - 0s - 481us/step - accuracy: 0.9245 - loss: 0.2733 - val_accuracy: 0.9265 - val_loss: 0.2711
Epoch 8/20
469/469 - 0s - 484us/step - accuracy: 0.9257 - loss: 0.2700 - val_accuracy: 0.9260 - val_loss: 0.2697
Epoch 9/20
469/469 - 0s - 488us/step - accuracy: 0.9265 - loss: 0.2674 - val_accuracy: 0.9260 - val_loss: 0.2686
Epoch 10/20
469/469 - 0s - 484us/step - accuracy: 0.9273 - loss: 0.2651 - val_accuracy: 0.9264 - val_loss: 0.2678
Epoch 11/20
469/469 - 0s - 478us/step - accuracy: 0.9280 - loss: 0.2632 - val_accuracy: 0.9262 - val_loss: 0.2672
Epoch 12/20
469/469 - 0s - 482us/step - accuracy: 0.9286 - loss: 0.2616 - val_accuracy: 0.9260 - val_loss: 0.2668
Epoch 13/20
469/469 - 0s - 483us/step - accuracy: 0.9291 - loss: 0.2601 - val_accuracy: 0.9262 - val_loss: 0.2664
Epoch 14/20
469/469 - 0s - 483us/step - accuracy: 0.9292 - loss: 0.2588 - val_accuracy: 0.9266 - val_loss: 0.2661
Epoch 15/20
469/469 - 0s - 487us/step - accuracy: 0.9296 - loss: 0.2577 - val_accuracy: 0.9265 - val_loss: 0.2659
Epoch 16/20
469/469 - 0s - 488us/step - accuracy: 0.9300 - loss: 0.2566 - val_accuracy: 0.9268 - val_loss: 0.2658
Epoch 17/20
469/469 - 0s - 484us/step - accuracy: 0.9303 - loss: 0.2557 - val_accuracy: 0.9268 - val_loss: 0.2657
Epoch 18/20
469/469 - 0s - 487us/step - accuracy: 0.9307 - loss: 0.2548 - val_accuracy: 0.9268 - val_loss: 0.2656
Epoch 19/20
469/469 - 0s - 485us/step - accuracy: 0.9310 - loss: 0.2540 - val_accuracy: 0.9269 - val_loss: 0.2656
Epoch 20/20
469/469 - 0s - 486us/step - accuracy: 0.9313 - loss: 0.2533 - val_accuracy: 0.9271 - val_loss: 0.2656
Even with just a single layer the model performs quite well: its accuracy is around 93%. Adding the two hidden layers in the previous ANN did improve the accuracy. On the other hand, it took more than 100,000 extra parameters to move from 93% to 98% accuracy.
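The trade-off in numbers, using the parameter counts from the two model summaries and the final-epoch validation accuracies from the two training logs:

```python
# Values taken from the model summaries and the last epoch of each training log
params_mlp, params_lr = 109_386, 7_850
val_acc_mlp, val_acc_lr = 0.9797, 0.9271
n_test = 10_000

extra_params = params_mlp - params_lr
errors_mlp = round(n_test * (1 - val_acc_mlp))   # misclassified test images, two hidden layers
errors_lr = round(n_test * (1 - val_acc_lr))     # misclassified test images, no hidden layer
print(extra_params, errors_lr, errors_mlp)       # 101536 729 203
```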