import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
35 Neural Networks in Python (with Keras)
35.1 Introduction
There are a few Python packages available to train neural networks. See Table 35.1 for some examples. The packages vary in capabilities.
Several frameworks for ANNs and deep learning exist. TensorFlow, Microsoft CNTK, PyTorch, and Theano are among the most important ones.
Table 35.1: Python packages for neural network analysis.
Package | Notes |
---|---|
scikit-learn | Includes a neural_network module for training multilayer perceptrons (MLP) and Bernoulli restricted Boltzmann machines (RBM) |
tensorflow | Interface to TensorFlow, a free and open-source software library for machine learning and artificial intelligence |
torch | PyTorch: Tensors and neural networks with GPU acceleration |
keras | Deep learning library |
Keras has emerged as an important API (Application Programming Interface) for deep learning. It provides a consistent interface on top of JAX, TensorFlow, or PyTorch. TensorFlow is very powerful and gives you complete control over the types of models you build and train, but the learning curve can be steep and you tend to write a lot of code. That is what makes Keras so relevant: you can tap into the capabilities of TensorFlow with a simpler API.
Tools from the modern machine learning toolbox tend to be written in Python. The keras package in Python calls into TensorFlow, or whatever deep learning framework Keras is running on.
35.2 Running Keras in Python
Packages
Keras Basics
Training a neural network with keras involves three steps:
Defining the network
Setting up the optimization
Fitting the model
Not until the third step does the algorithm get in contact with actual data. However, we need to know some things about the data in order to define the network in step 1: the dimensions of the input and output.
Defining the network
The most convenient way of specifying a multilayer neural network is by adding layers sequentially, from the input layer to the output layer. This starts with a call to keras.models.Sequential(). Suppose we want to predict a continuous response (regression application) based on inputs \(x_1, \cdots, x_{19}\) with one hidden layer and dropout regularization.
The following statements define the model sequentially:
firstANN = keras.models.Sequential()
firstANN.add(layers.Input(shape=(19,))) #preferred method to specify Input shape separately
firstANN.add(layers.Dense(units=50, activation='relu'))
firstANN.add(layers.Dropout(rate=0.4))
firstANN.add(layers.Dense(units=1, name='Output'))
layers.Input specifies the shape of the model input. layers.Dense() adds a fully connected layer to the network; the units= option specifies the number of neurons in the layer. In summary, the hidden layer receives 19 inputs and has 50 output units (neurons) and ReLU activation. The output from the hidden layer is passed on (piped) to a dropout layer with a dropout rate of \(\phi = 0.4\). The result of the dropout layer is passed on to another fully connected layer with a single neuron. This is the output layer of the network. In other words, the last layer in the sequence is automatically the output layer. Since we are in a regression context to predict a numeric target variable, there is only one output unit in the final layer. If this were a classification problem with \(5\) categories, the last layer would have 5 units.
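To make the flow through the three layers concrete, the following NumPy sketch mimics what the layers compute during training for a hypothetical batch of four observations. The helper functions `dense` and `dropout`, the random weights, and the batch itself are illustrative assumptions and not Keras internals. The rescaling of the surviving activations by 1/(1 - rate) follows the behavior documented for the Keras Dropout layer.

```python
import numpy as np

rng = np.random.default_rng(0)

def dense(x, W, b, activation=None):
    """Fully connected layer: affine map followed by an optional ReLU."""
    z = x @ W + b
    return np.maximum(z, 0.0) if activation == "relu" else z

def dropout(h, rate, rng):
    """Training-time dropout: zero a random subset, rescale the rest."""
    keep = rng.random(h.shape) >= rate
    return h * keep / (1.0 - rate)

# Hypothetical batch: 4 observations with 19 inputs each
x = rng.normal(size=(4, 19))

# Random weights standing in for trained parameters (assumption)
W1, b1 = rng.normal(size=(19, 50)), np.zeros(50)
W2, b2 = rng.normal(size=(50, 1)), np.zeros(1)

hidden = dense(x, W1, b1, activation="relu")   # shape (4, 50)
dropped = dropout(hidden, rate=0.4, rng=rng)   # shape (4, 50)
output = dense(dropped, W2, b2)                # shape (4, 1), identity activation

print(hidden.shape, dropped.shape, output.shape)
```

At prediction time the dropout step is skipped, so the hidden layer output feeds the output layer directly.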
You can assign a name to each layer with the name= option; this makes it easier to identify the layers in output. If you do not specify a name, Keras assigns one that typically combines a description of the layer type with a numerical index. The numeric indices can be confusing because they depend on counters internal to the Python code. Assigning an explicit name is recommended practice.
The activation= option specifies the activation function \(\sigma()\) for the hidden layers and the output function \(g()\) for the output layer. The default is the identity (“linear”) activation, \(\sigma(x) = x\). This default is appropriate for the output layer in a regression application. For the hidden layer we choose the ReLU activation.
A list of the activation functions supported by Keras can be found in the Keras documentation.
The basic neural network is now defined and we can find out how many parameters it entails.
print(firstANN.summary())
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense (Dense)                   │ (None, 50)             │         1,000 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 50)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ Output (Dense)                  │ (None, 1)              │            51 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 1,051 (4.11 KB)
Trainable params: 1,051 (4.11 KB)
Non-trainable params: 0 (0.00 B)
None
With 19 inputs and 50 neurons, the first layer has 50 x 20 = 1,000 parameters (19 slopes and an intercept for each neuron). The dropout layer does not add any parameters to the estimation; it chooses output neurons of the previous layer at random and sets their activation to zero. The 50 neurons (some with activation set randomly to zero) are the input to the final layer, adding fifty weights (slopes) and one bias (intercept). The total number of parameters of this neural network is 1,051.
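The parameter counts in the summary can be checked by hand. The short sketch below uses a hypothetical helper, `dense_params`, that counts one weight per input-neuron pair plus one bias per neuron.

```python
def dense_params(n_inputs, n_units):
    """Number of weights plus biases in a fully connected layer."""
    return n_inputs * n_units + n_units

hidden_params = dense_params(19, 50)          # 19 weights + 1 bias per neuron
output_params = dense_params(50, 1)           # 50 weights + 1 bias
total_params = hidden_params + output_params  # the dropout layer adds none

print(hidden_params, output_params, total_params)  # 1000 51 1051
```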
Setting up the optimization
The second step in training a model in Keras is to specify the particulars of the optimization with the compile() method of the model. Typical specifications include the loss function, the type of optimization algorithm, and the metrics evaluated by the model during training.
The following function call uses the RMSProp algorithm with mean-squared error loss function to estimate the parameters of the network. During training, the mean absolute error is also monitored in addition to the mean squared error.
firstANN.compile(
    loss='mse', # mean squared error
    optimizer='rmsprop',
    metrics=['mean_absolute_error']
)
Depending on your environment, not all optimization algorithms are supported.
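As a small illustration of the two quantities tracked during training, the following sketch computes the mean squared error and the mean absolute error for three hypothetical observations. The function names and the data are made up for this example. The mean squared error is measured in squared units of the response, which is why it is numerically much larger than the mean absolute error when the response takes values in the hundreds.

```python
def mean_squared_error(y_true, y_pred):
    """Average of the squared prediction errors."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def mean_absolute_error(y_true, y_pred):
    """Average of the absolute prediction errors."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical observed and predicted responses
y_true = [500.0, 750.0, 100.0]
y_pred = [450.0, 800.0, 300.0]

mse = mean_squared_error(y_true, y_pred)    # (2500 + 2500 + 40000) / 3
mae = mean_absolute_error(y_true, y_pred)   # (50 + 50 + 200) / 3

print(mse, mae)  # 15000.0 100.0
```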
Fitting the model
The last step in training the network is to connect the defined and compiled model with training data and, possibly, test data.
For this example we use the Hitters data from the ISLP package. This is a data set with 322 observations of major league baseball players from the 1986 and 1987 seasons. The following code removes observations with missing values from the data frame, defines a vector of ids for the test data (1/3 of the observations) and computes a scaled and centered model matrix using all 19 input variables.
import numpy as np
import pandas as pd
from ISLP import load_data
from sklearn.preprocessing import StandardScaler

Hitters = load_data('Hitters')

# Remove rows with missing values
Gitters = Hitters.dropna()

# Get number of rows
n = len(Gitters)

# Set random seed for reproducibility
keras.utils.set_random_seed(13)

# Create test set (1/3 of data)
ntest = int(n / 3)
testid = np.random.choice(n, size=ntest, replace=False)

# Create feature matrix (excluding Salary)
X = Gitters.drop('Salary', axis=1)

# Convert categorical variables to dummy variables
X = pd.get_dummies(X, drop_first=True, dtype='int')

# Scale the features
scaler = StandardScaler()
x = pd.DataFrame(scaler.fit_transform(X))

# Target variable
y = Gitters['Salary'].values

# Split into training and test sets.
# The training rows are the complement of testid; negative indices
# count from the end in Python and do not exclude rows.
trainid = np.setdiff1d(np.arange(n), testid)
x_train = x.iloc[trainid]
y_train = y[trainid]

x_test = x.iloc[testid]
y_test = y[testid]
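In NumPy and pandas, a negative index counts from the end of the array rather than excluding an element, so the complement of a set of test indices has to be built explicitly. The following is a minimal sketch with a hypothetical ten-row response vector; the name `trainid` is chosen for this illustration.

```python
import numpy as np

n = 10
testid = np.array([2, 5, 7])       # hypothetical test row positions
y = np.arange(100, 100 + n)        # hypothetical responses 100, ..., 109

# Complement of testid: every row position that is not in the test set
trainid = np.setdiff1d(np.arange(n), testid)

print(y[trainid])   # the 7 rows outside the test set
print(y[-testid])   # rows -2, -5, -7 counted from the end: [108 105 103]
```

The second print statement shows that negating the index vector selects as many rows as the test set has, not the remaining rows.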
Note that the model contains several factors (League, Division, NewLeague) whose levels are encoded as binary variables in the model matrix. One could exclude those from scaling and centering as they already are in the proper range. In a regression model you would not want to scale these variables, in order to preserve the interpretation of their coefficients. In a neural network, interpretation of the model coefficients is not important, and we include all columns of the model matrix in the scaling operation.
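The effect of centering and scaling on a numeric column and on a 0/1 dummy column can be seen in the following sketch. The small matrix is hypothetical. The computation divides by the standard deviation with divisor n, which is the convention StandardScaler uses.

```python
import numpy as np

# Hypothetical model matrix: column 0 is numeric, column 1 is a 0/1 dummy
X = np.array([[1.0, 0.0],
              [3.0, 1.0],
              [5.0, 1.0],
              [7.0, 0.0]])

# Center each column at zero and scale it to unit standard deviation
Z = (X - X.mean(axis=0)) / X.std(axis=0)

print(Z[:, 1])   # the dummy column now takes the values -1 and 1
```

After scaling, the dummy column no longer takes the values 0 and 1, which is why its coefficient would lose its usual interpretation in a regression model.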
The following code fits the model to the training data (the observations not in testid) using 200 epochs and a minibatch size of 32. That means the gradient is computed based on 32 randomly chosen observations in each step of the stochastic gradient descent algorithm. Since there are 176 training observations, it takes \(176/32=5.5\), rounded up to 6, SGD steps to process all training observations: five full minibatches and a final one with the remaining 16 observations. This is known as an epoch and is akin to the concept of an iteration in numerical optimization: a full pass through the data. The fundamental difference between an epoch and an iteration lies in the fact that updates of the parameters occur after each gradient computation. In a full iteration, there is one update after the pass through the entire data. In SGD with minibatch, there are multiple updates of the parameters, one for each minibatch.
Running 200 epochs with a batch size of 32 and a training set size of 176 results in 200 * 6 = 1,200 gradient evaluations.
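Because the last minibatch of an epoch may be smaller than the batch size, the number of parameter updates per epoch is the ceiling of the training set size divided by the batch size. The arithmetic can be sketched as follows, with variable names chosen for this illustration.

```python
import math

n_train = 176
batch_size = 32
epochs = 200

steps_per_epoch = math.ceil(n_train / batch_size)                # 6
last_batch_size = n_train - (steps_per_epoch - 1) * batch_size   # 16
total_updates = epochs * steps_per_epoch                         # 1200

print(steps_per_epoch, last_batch_size, total_updates)  # 6 16 1200
```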
The validation_data= option supplies the test data to the training run. The objective function and metrics specified in the compile() call earlier are computed at each epoch for the training data and, if specified, for the test data. If you do not have a validation data set, you can specify validation_split= and request that a fraction of the training data is held back for validation.
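According to the Keras documentation, validation_split= holds back the last samples of the supplied arrays, before any shuffling takes place. The following sketch mimics that selection with a hypothetical helper function; the exact rounding of the split position is an assumption and not taken from the Keras source.

```python
def validation_split_indices(n_samples, validation_split):
    """Return (training, validation) row positions; validation is the tail."""
    # Rounding rule is an assumption made for this illustration
    split_at = int(n_samples * (1.0 - validation_split))
    return list(range(split_at)), list(range(split_at, n_samples))

train_idx, val_idx = validation_split_indices(n_samples=176, validation_split=0.25)

print(len(train_idx), len(val_idx))  # 132 44
print(val_idx[0], val_idx[-1])       # 132 175
```

Because the held-back rows come from the end of the arrays, data that are sorted in a meaningful order should be shuffled before calling fit().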
history = firstANN.fit(
    x_train, y_train,
    epochs=200,
    batch_size=32,
    validation_data=(x_test, y_test),
    verbose=2
)
Epoch 1/200
3/3 - 0s - 60ms/step - loss: 388486.0625 - mean_absolute_error: 502.0297 - val_loss: 416839.9062 - val_mean_absolute_error: 491.9469
Epoch 2/200
3/3 - 0s - 6ms/step - loss: 388364.3750 - mean_absolute_error: 501.8449 - val_loss: 416676.4062 - val_mean_absolute_error: 491.8086
Epoch 3/200
3/3 - 0s - 6ms/step - loss: 388021.6875 - mean_absolute_error: 501.6020 - val_loss: 416534.7500 - val_mean_absolute_error: 491.6915
Epoch 4/200
3/3 - 0s - 6ms/step - loss: 387976.5625 - mean_absolute_error: 501.5332 - val_loss: 416402.5312 - val_mean_absolute_error: 491.5871
Epoch 5/200
3/3 - 0s - 6ms/step - loss: 387753.9375 - mean_absolute_error: 501.4447 - val_loss: 416273.6562 - val_mean_absolute_error: 491.4811
Epoch 6/200
3/3 - 0s - 6ms/step - loss: 387682.5625 - mean_absolute_error: 501.3298 - val_loss: 416149.3438 - val_mean_absolute_error: 491.3789
Epoch 7/200
3/3 - 0s - 6ms/step - loss: 387462.8438 - mean_absolute_error: 501.2206 - val_loss: 416027.9688 - val_mean_absolute_error: 491.2835
Epoch 8/200
3/3 - 0s - 6ms/step - loss: 387380.5000 - mean_absolute_error: 501.1319 - val_loss: 415906.0312 - val_mean_absolute_error: 491.1881
Epoch 9/200
3/3 - 0s - 6ms/step - loss: 387306.4375 - mean_absolute_error: 501.0211 - val_loss: 415788.3750 - val_mean_absolute_error: 491.0915
Epoch 10/200
3/3 - 0s - 6ms/step - loss: 387075.7812 - mean_absolute_error: 500.8499 - val_loss: 415671.4375 - val_mean_absolute_error: 490.9995
Epoch 11/200
3/3 - 0s - 6ms/step - loss: 387019.7812 - mean_absolute_error: 500.8143 - val_loss: 415556.8750 - val_mean_absolute_error: 490.9101
Epoch 12/200
3/3 - 0s - 6ms/step - loss: 386916.2188 - mean_absolute_error: 500.8023 - val_loss: 415439.7812 - val_mean_absolute_error: 490.8142
Epoch 13/200
3/3 - 0s - 6ms/step - loss: 386714.5625 - mean_absolute_error: 500.5801 - val_loss: 415319.0938 - val_mean_absolute_error: 490.7181
Epoch 14/200
3/3 - 0s - 6ms/step - loss: 386508.4688 - mean_absolute_error: 500.4550 - val_loss: 415199.4375 - val_mean_absolute_error: 490.6267
Epoch 15/200
3/3 - 0s - 6ms/step - loss: 386514.4688 - mean_absolute_error: 500.4272 - val_loss: 415080.5938 - val_mean_absolute_error: 490.5330
Epoch 16/200
3/3 - 0s - 6ms/step - loss: 386406.0625 - mean_absolute_error: 500.3553 - val_loss: 414956.5625 - val_mean_absolute_error: 490.4357
Epoch 17/200
3/3 - 0s - 6ms/step - loss: 386105.4375 - mean_absolute_error: 500.1644 - val_loss: 414826.9375 - val_mean_absolute_error: 490.3358
Epoch 18/200
3/3 - 0s - 6ms/step - loss: 386024.9062 - mean_absolute_error: 500.0773 - val_loss: 414705.8750 - val_mean_absolute_error: 490.2396
Epoch 19/200
3/3 - 0s - 6ms/step - loss: 385818.9375 - mean_absolute_error: 499.9171 - val_loss: 414579.2188 - val_mean_absolute_error: 490.1403
Epoch 20/200
3/3 - 0s - 6ms/step - loss: 385879.2188 - mean_absolute_error: 499.9145 - val_loss: 414452.5938 - val_mean_absolute_error: 490.0388
Epoch 21/200
3/3 - 0s - 6ms/step - loss: 385689.2812 - mean_absolute_error: 499.7969 - val_loss: 414322.3438 - val_mean_absolute_error: 489.9374
Epoch 22/200
3/3 - 0s - 6ms/step - loss: 385556.7812 - mean_absolute_error: 499.6568 - val_loss: 414194.6562 - val_mean_absolute_error: 489.8388
Epoch 23/200
3/3 - 0s - 6ms/step - loss: 385348.9062 - mean_absolute_error: 499.5273 - val_loss: 414063.4375 - val_mean_absolute_error: 489.7364
Epoch 24/200
3/3 - 0s - 6ms/step - loss: 385294.2500 - mean_absolute_error: 499.4364 - val_loss: 413931.2500 - val_mean_absolute_error: 489.6312
Epoch 25/200
3/3 - 0s - 6ms/step - loss: 384787.8750 - mean_absolute_error: 499.1058 - val_loss: 413789.4375 - val_mean_absolute_error: 489.5216
Epoch 26/200
3/3 - 0s - 6ms/step - loss: 384841.4375 - mean_absolute_error: 499.1262 - val_loss: 413649.7500 - val_mean_absolute_error: 489.4167
Epoch 27/200
3/3 - 0s - 6ms/step - loss: 384789.0938 - mean_absolute_error: 499.1046 - val_loss: 413511.1250 - val_mean_absolute_error: 489.3115
Epoch 28/200
3/3 - 0s - 6ms/step - loss: 384774.5938 - mean_absolute_error: 499.0273 - val_loss: 413375.6875 - val_mean_absolute_error: 489.2089
Epoch 29/200
3/3 - 0s - 6ms/step - loss: 384371.7188 - mean_absolute_error: 498.8341 - val_loss: 413227.2500 - val_mean_absolute_error: 489.0982
Epoch 30/200
3/3 - 0s - 6ms/step - loss: 384642.8438 - mean_absolute_error: 498.8668 - val_loss: 413087.7188 - val_mean_absolute_error: 488.9924
Epoch 31/200
3/3 - 0s - 6ms/step - loss: 384046.3125 - mean_absolute_error: 498.6273 - val_loss: 412939.0312 - val_mean_absolute_error: 488.8818
Epoch 32/200
3/3 - 0s - 6ms/step - loss: 383864.5000 - mean_absolute_error: 498.4449 - val_loss: 412789.1875 - val_mean_absolute_error: 488.7663
Epoch 33/200
3/3 - 0s - 6ms/step - loss: 383875.9688 - mean_absolute_error: 498.2418 - val_loss: 412640.9062 - val_mean_absolute_error: 488.6510
Epoch 34/200
3/3 - 0s - 6ms/step - loss: 383899.5000 - mean_absolute_error: 498.4843 - val_loss: 412492.4062 - val_mean_absolute_error: 488.5359
Epoch 35/200
3/3 - 0s - 6ms/step - loss: 383103.0000 - mean_absolute_error: 497.7596 - val_loss: 412333.2812 - val_mean_absolute_error: 488.4176
Epoch 36/200
3/3 - 0s - 6ms/step - loss: 383119.1875 - mean_absolute_error: 497.9456 - val_loss: 412171.9688 - val_mean_absolute_error: 488.2989
Epoch 37/200
3/3 - 0s - 6ms/step - loss: 383015.1875 - mean_absolute_error: 497.6767 - val_loss: 412013.6875 - val_mean_absolute_error: 488.1805
Epoch 38/200
3/3 - 0s - 6ms/step - loss: 383179.0625 - mean_absolute_error: 497.8082 - val_loss: 411863.2500 - val_mean_absolute_error: 488.0657
Epoch 39/200
3/3 - 0s - 6ms/step - loss: 382813.6875 - mean_absolute_error: 497.3881 - val_loss: 411706.0625 - val_mean_absolute_error: 487.9463
Epoch 40/200
3/3 - 0s - 6ms/step - loss: 382060.5625 - mean_absolute_error: 497.0763 - val_loss: 411535.0000 - val_mean_absolute_error: 487.8198
Epoch 41/200
3/3 - 0s - 6ms/step - loss: 382711.9688 - mean_absolute_error: 497.4388 - val_loss: 411381.7500 - val_mean_absolute_error: 487.7069
Epoch 42/200
3/3 - 0s - 6ms/step - loss: 382573.0312 - mean_absolute_error: 497.3329 - val_loss: 411222.6562 - val_mean_absolute_error: 487.5864
Epoch 43/200
3/3 - 0s - 6ms/step - loss: 381977.4062 - mean_absolute_error: 496.9503 - val_loss: 411052.0312 - val_mean_absolute_error: 487.4578
Epoch 44/200
3/3 - 0s - 6ms/step - loss: 381707.4688 - mean_absolute_error: 496.9191 - val_loss: 410879.7188 - val_mean_absolute_error: 487.3334
Epoch 45/200
3/3 - 0s - 6ms/step - loss: 380939.5312 - mean_absolute_error: 496.3333 - val_loss: 410693.9375 - val_mean_absolute_error: 487.1967
Epoch 46/200
3/3 - 0s - 6ms/step - loss: 381652.7188 - mean_absolute_error: 496.6101 - val_loss: 410528.0312 - val_mean_absolute_error: 487.0757
Epoch 47/200
3/3 - 0s - 6ms/step - loss: 381093.3750 - mean_absolute_error: 496.3058 - val_loss: 410351.0312 - val_mean_absolute_error: 486.9452
Epoch 48/200
3/3 - 0s - 6ms/step - loss: 381167.4375 - mean_absolute_error: 496.1671 - val_loss: 410178.7188 - val_mean_absolute_error: 486.8167
Epoch 49/200
3/3 - 0s - 6ms/step - loss: 380788.5625 - mean_absolute_error: 495.9537 - val_loss: 409995.3438 - val_mean_absolute_error: 486.6807
Epoch 50/200
3/3 - 0s - 6ms/step - loss: 380414.4375 - mean_absolute_error: 495.8411 - val_loss: 409810.1250 - val_mean_absolute_error: 486.5479
Epoch 51/200
3/3 - 0s - 6ms/step - loss: 379982.8438 - mean_absolute_error: 495.4834 - val_loss: 409614.5625 - val_mean_absolute_error: 486.4047
Epoch 52/200
3/3 - 0s - 6ms/step - loss: 380572.8125 - mean_absolute_error: 495.8604 - val_loss: 409435.3438 - val_mean_absolute_error: 486.2751
Epoch 53/200
3/3 - 0s - 6ms/step - loss: 380101.1875 - mean_absolute_error: 495.5234 - val_loss: 409249.0000 - val_mean_absolute_error: 486.1403
Epoch 54/200
3/3 - 0s - 6ms/step - loss: 379681.0000 - mean_absolute_error: 495.4297 - val_loss: 409056.4688 - val_mean_absolute_error: 485.9995
Epoch 55/200
3/3 - 0s - 6ms/step - loss: 379885.4688 - mean_absolute_error: 495.3477 - val_loss: 408865.3750 - val_mean_absolute_error: 485.8605
Epoch 56/200
3/3 - 0s - 6ms/step - loss: 379842.3125 - mean_absolute_error: 495.1807 - val_loss: 408678.2188 - val_mean_absolute_error: 485.7268
Epoch 57/200
3/3 - 0s - 6ms/step - loss: 379560.5625 - mean_absolute_error: 495.2909 - val_loss: 408487.0000 - val_mean_absolute_error: 485.5900
Epoch 58/200
3/3 - 0s - 6ms/step - loss: 379032.6562 - mean_absolute_error: 494.6447 - val_loss: 408290.2188 - val_mean_absolute_error: 485.4455
Epoch 59/200
3/3 - 0s - 6ms/step - loss: 378539.0312 - mean_absolute_error: 494.3192 - val_loss: 408083.9062 - val_mean_absolute_error: 485.2957
Epoch 60/200
3/3 - 0s - 6ms/step - loss: 378500.6875 - mean_absolute_error: 494.2646 - val_loss: 407879.3125 - val_mean_absolute_error: 485.1460
Epoch 61/200
3/3 - 0s - 6ms/step - loss: 378845.3438 - mean_absolute_error: 494.4706 - val_loss: 407679.7188 - val_mean_absolute_error: 485.0021
Epoch 62/200
3/3 - 0s - 6ms/step - loss: 378631.4375 - mean_absolute_error: 494.4530 - val_loss: 407476.1250 - val_mean_absolute_error: 484.8536
Epoch 63/200
3/3 - 0s - 6ms/step - loss: 377979.7812 - mean_absolute_error: 494.0929 - val_loss: 407262.9375 - val_mean_absolute_error: 484.6985
Epoch 64/200
3/3 - 0s - 6ms/step - loss: 378048.9062 - mean_absolute_error: 493.8495 - val_loss: 407053.8438 - val_mean_absolute_error: 484.5495
Epoch 65/200
3/3 - 0s - 6ms/step - loss: 377537.5312 - mean_absolute_error: 493.5985 - val_loss: 406837.7812 - val_mean_absolute_error: 484.3951
Epoch 66/200
3/3 - 0s - 6ms/step - loss: 376416.3438 - mean_absolute_error: 493.1370 - val_loss: 406608.0312 - val_mean_absolute_error: 484.2307
Epoch 67/200
3/3 - 0s - 6ms/step - loss: 377670.3125 - mean_absolute_error: 493.6776 - val_loss: 406399.5938 - val_mean_absolute_error: 484.0809
Epoch 68/200
3/3 - 0s - 6ms/step - loss: 376703.9688 - mean_absolute_error: 492.9100 - val_loss: 406177.2812 - val_mean_absolute_error: 483.9204
Epoch 69/200
3/3 - 0s - 6ms/step - loss: 376626.1562 - mean_absolute_error: 492.8196 - val_loss: 405957.0000 - val_mean_absolute_error: 483.7623
Epoch 70/200
3/3 - 0s - 6ms/step - loss: 375350.9062 - mean_absolute_error: 492.0239 - val_loss: 405723.7188 - val_mean_absolute_error: 483.5962
Epoch 71/200
3/3 - 0s - 6ms/step - loss: 376024.3438 - mean_absolute_error: 492.4986 - val_loss: 405495.7188 - val_mean_absolute_error: 483.4353
Epoch 72/200
3/3 - 0s - 6ms/step - loss: 376656.5000 - mean_absolute_error: 492.6935 - val_loss: 405280.3125 - val_mean_absolute_error: 483.2761
Epoch 73/200
3/3 - 0s - 6ms/step - loss: 375158.9375 - mean_absolute_error: 492.1847 - val_loss: 405043.5312 - val_mean_absolute_error: 483.1081
Epoch 74/200
3/3 - 0s - 6ms/step - loss: 374905.4375 - mean_absolute_error: 491.7352 - val_loss: 404805.8750 - val_mean_absolute_error: 482.9377
Epoch 75/200
3/3 - 0s - 6ms/step - loss: 375157.6250 - mean_absolute_error: 491.9117 - val_loss: 404572.1250 - val_mean_absolute_error: 482.7654
Epoch 76/200
3/3 - 0s - 6ms/step - loss: 373985.5938 - mean_absolute_error: 491.5028 - val_loss: 404327.5938 - val_mean_absolute_error: 482.5939
Epoch 77/200
3/3 - 0s - 6ms/step - loss: 374335.0312 - mean_absolute_error: 491.3471 - val_loss: 404086.1250 - val_mean_absolute_error: 482.4163
Epoch 78/200
3/3 - 0s - 6ms/step - loss: 374333.3438 - mean_absolute_error: 491.1887 - val_loss: 403847.9062 - val_mean_absolute_error: 482.2443
Epoch 79/200
3/3 - 0s - 6ms/step - loss: 373872.4062 - mean_absolute_error: 490.8742 - val_loss: 403601.6562 - val_mean_absolute_error: 482.0674
Epoch 80/200
3/3 - 0s - 6ms/step - loss: 373543.1250 - mean_absolute_error: 490.7766 - val_loss: 403352.6875 - val_mean_absolute_error: 481.8855
Epoch 81/200
3/3 - 0s - 6ms/step - loss: 373597.9688 - mean_absolute_error: 490.3448 - val_loss: 403108.3125 - val_mean_absolute_error: 481.7078
Epoch 82/200
3/3 - 0s - 6ms/step - loss: 373545.3125 - mean_absolute_error: 490.5653 - val_loss: 402863.0938 - val_mean_absolute_error: 481.5305
Epoch 83/200
3/3 - 0s - 6ms/step - loss: 372347.4375 - mean_absolute_error: 490.1262 - val_loss: 402597.1562 - val_mean_absolute_error: 481.3428
Epoch 84/200
3/3 - 0s - 6ms/step - loss: 372487.0312 - mean_absolute_error: 490.2118 - val_loss: 402337.4375 - val_mean_absolute_error: 481.1567
Epoch 85/200
3/3 - 0s - 6ms/step - loss: 371893.0000 - mean_absolute_error: 489.6255 - val_loss: 402073.5312 - val_mean_absolute_error: 480.9651
Epoch 86/200
3/3 - 0s - 6ms/step - loss: 372337.5938 - mean_absolute_error: 489.6902 - val_loss: 401818.6562 - val_mean_absolute_error: 480.7779
Epoch 87/200
3/3 - 0s - 6ms/step - loss: 371839.6875 - mean_absolute_error: 489.2011 - val_loss: 401557.0938 - val_mean_absolute_error: 480.5858
Epoch 88/200
3/3 - 0s - 6ms/step - loss: 370071.7500 - mean_absolute_error: 488.2936 - val_loss: 401275.6875 - val_mean_absolute_error: 480.3847
Epoch 89/200
3/3 - 0s - 6ms/step - loss: 369790.2188 - mean_absolute_error: 488.0841 - val_loss: 400994.1250 - val_mean_absolute_error: 480.1784
Epoch 90/200
3/3 - 0s - 6ms/step - loss: 371252.4688 - mean_absolute_error: 488.7569 - val_loss: 400733.5625 - val_mean_absolute_error: 479.9895
Epoch 91/200
3/3 - 0s - 6ms/step - loss: 370293.1875 - mean_absolute_error: 488.5912 - val_loss: 400457.7812 - val_mean_absolute_error: 479.7899
Epoch 92/200
3/3 - 0s - 6ms/step - loss: 368767.4062 - mean_absolute_error: 487.4139 - val_loss: 400172.7500 - val_mean_absolute_error: 479.5848
Epoch 93/200
3/3 - 0s - 6ms/step - loss: 369646.2188 - mean_absolute_error: 487.8867 - val_loss: 399896.3750 - val_mean_absolute_error: 479.3791
Epoch 94/200
3/3 - 0s - 6ms/step - loss: 369074.8125 - mean_absolute_error: 487.6077 - val_loss: 399614.9062 - val_mean_absolute_error: 479.1714
Epoch 95/200
3/3 - 0s - 6ms/step - loss: 369946.9375 - mean_absolute_error: 487.5156 - val_loss: 399342.9375 - val_mean_absolute_error: 478.9716
Epoch 96/200
3/3 - 0s - 6ms/step - loss: 369152.4688 - mean_absolute_error: 487.2570 - val_loss: 399059.8750 - val_mean_absolute_error: 478.7653
Epoch 97/200
3/3 - 0s - 6ms/step - loss: 368583.3438 - mean_absolute_error: 487.4776 - val_loss: 398770.9375 - val_mean_absolute_error: 478.5570
Epoch 98/200
3/3 - 0s - 6ms/step - loss: 367835.6250 - mean_absolute_error: 487.1383 - val_loss: 398468.5625 - val_mean_absolute_error: 478.3348
Epoch 99/200
3/3 - 0s - 6ms/step - loss: 367526.2188 - mean_absolute_error: 486.2899 - val_loss: 398170.3438 - val_mean_absolute_error: 478.1170
Epoch 100/200
3/3 - 0s - 6ms/step - loss: 368908.9688 - mean_absolute_error: 486.9029 - val_loss: 397889.4375 - val_mean_absolute_error: 477.9042
Epoch 101/200
3/3 - 0s - 6ms/step - loss: 366864.1562 - mean_absolute_error: 485.9761 - val_loss: 397585.1875 - val_mean_absolute_error: 477.6821
Epoch 102/200
3/3 - 0s - 6ms/step - loss: 366421.4688 - mean_absolute_error: 485.2467 - val_loss: 397277.6562 - val_mean_absolute_error: 477.4545
Epoch 103/200
3/3 - 0s - 6ms/step - loss: 366058.6250 - mean_absolute_error: 485.5532 - val_loss: 396964.1250 - val_mean_absolute_error: 477.2196
Epoch 104/200
3/3 - 0s - 6ms/step - loss: 366459.0312 - mean_absolute_error: 485.3553 - val_loss: 396662.0312 - val_mean_absolute_error: 476.9945
Epoch 105/200
3/3 - 0s - 6ms/step - loss: 363244.0938 - mean_absolute_error: 483.5682 - val_loss: 396326.8438 - val_mean_absolute_error: 476.7440
Epoch 106/200
3/3 - 0s - 6ms/step - loss: 365870.4062 - mean_absolute_error: 484.8974 - val_loss: 396028.0938 - val_mean_absolute_error: 476.5174
Epoch 107/200
3/3 - 0s - 6ms/step - loss: 366684.7500 - mean_absolute_error: 485.0379 - val_loss: 395737.2812 - val_mean_absolute_error: 476.2988
Epoch 108/200
3/3 - 0s - 6ms/step - loss: 364142.7188 - mean_absolute_error: 484.1183 - val_loss: 395413.1562 - val_mean_absolute_error: 476.0589
Epoch 109/200
3/3 - 0s - 6ms/step - loss: 363561.7812 - mean_absolute_error: 483.0749 - val_loss: 395088.8125 - val_mean_absolute_error: 475.8187
Epoch 110/200
3/3 - 0s - 6ms/step - loss: 364529.3750 - mean_absolute_error: 483.5722 - val_loss: 394780.2812 - val_mean_absolute_error: 475.5893
Epoch 111/200
3/3 - 0s - 6ms/step - loss: 362655.6250 - mean_absolute_error: 482.4728 - val_loss: 394459.3438 - val_mean_absolute_error: 475.3486
Epoch 112/200
3/3 - 0s - 6ms/step - loss: 362565.3125 - mean_absolute_error: 482.0943 - val_loss: 394133.3750 - val_mean_absolute_error: 475.1020
Epoch 113/200
3/3 - 0s - 6ms/step - loss: 361419.8125 - mean_absolute_error: 481.9980 - val_loss: 393799.3438 - val_mean_absolute_error: 474.8523
Epoch 114/200
3/3 - 0s - 6ms/step - loss: 364027.3438 - mean_absolute_error: 483.1765 - val_loss: 393487.9062 - val_mean_absolute_error: 474.6173
Epoch 115/200
3/3 - 0s - 6ms/step - loss: 362293.7812 - mean_absolute_error: 483.1160 - val_loss: 393152.5000 - val_mean_absolute_error: 474.3635
Epoch 116/200
3/3 - 0s - 6ms/step - loss: 361563.4375 - mean_absolute_error: 482.0281 - val_loss: 392820.5938 - val_mean_absolute_error: 474.1121
Epoch 117/200
3/3 - 0s - 6ms/step - loss: 361742.2188 - mean_absolute_error: 481.3855 - val_loss: 392484.7812 - val_mean_absolute_error: 473.8531
Epoch 118/200
3/3 - 0s - 6ms/step - loss: 361105.6562 - mean_absolute_error: 481.3491 - val_loss: 392142.7500 - val_mean_absolute_error: 473.5945
Epoch 119/200
3/3 - 0s - 6ms/step - loss: 358738.0625 - mean_absolute_error: 480.1275 - val_loss: 391791.6875 - val_mean_absolute_error: 473.3267
Epoch 120/200
3/3 - 0s - 6ms/step - loss: 360015.8750 - mean_absolute_error: 480.2613 - val_loss: 391452.5000 - val_mean_absolute_error: 473.0651
Epoch 121/200
3/3 - 0s - 6ms/step - loss: 359674.7812 - mean_absolute_error: 480.6773 - val_loss: 391114.0625 - val_mean_absolute_error: 472.8065
Epoch 122/200
3/3 - 0s - 6ms/step - loss: 358852.1562 - mean_absolute_error: 479.5386 - val_loss: 390762.2500 - val_mean_absolute_error: 472.5412
Epoch 123/200
3/3 - 0s - 6ms/step - loss: 357259.7812 - mean_absolute_error: 478.7014 - val_loss: 390400.0938 - val_mean_absolute_error: 472.2665
Epoch 124/200
3/3 - 0s - 6ms/step - loss: 359645.1250 - mean_absolute_error: 480.0151 - val_loss: 390067.8750 - val_mean_absolute_error: 472.0100
Epoch 125/200
3/3 - 0s - 6ms/step - loss: 357001.6875 - mean_absolute_error: 478.5434 - val_loss: 389702.0625 - val_mean_absolute_error: 471.7307
Epoch 126/200
3/3 - 0s - 6ms/step - loss: 357230.5625 - mean_absolute_error: 479.2144 - val_loss: 389351.5000 - val_mean_absolute_error: 471.4699
Epoch 127/200
3/3 - 0s - 6ms/step - loss: 356324.8750 - mean_absolute_error: 477.7635 - val_loss: 388993.5625 - val_mean_absolute_error: 471.1945
Epoch 128/200
3/3 - 0s - 6ms/step - loss: 357363.1875 - mean_absolute_error: 478.5337 - val_loss: 388645.3750 - val_mean_absolute_error: 470.9240
Epoch 129/200
3/3 - 0s - 6ms/step - loss: 357999.4062 - mean_absolute_error: 478.4773 - val_loss: 388301.3438 - val_mean_absolute_error: 470.6612
Epoch 130/200
3/3 - 0s - 6ms/step - loss: 356246.7500 - mean_absolute_error: 477.6912 - val_loss: 387930.4375 - val_mean_absolute_error: 470.3739
Epoch 131/200
3/3 - 0s - 6ms/step - loss: 355565.1562 - mean_absolute_error: 475.9735 - val_loss: 387563.9062 - val_mean_absolute_error: 470.0876
Epoch 132/200
3/3 - 0s - 6ms/step - loss: 356281.3750 - mean_absolute_error: 477.5853 - val_loss: 387204.7500 - val_mean_absolute_error: 469.8106
Epoch 133/200
3/3 - 0s - 6ms/step - loss: 354654.7812 - mean_absolute_error: 476.2086 - val_loss: 386823.1875 - val_mean_absolute_error: 469.5172
Epoch 134/200
3/3 - 0s - 6ms/step - loss: 354162.9062 - mean_absolute_error: 475.1717 - val_loss: 386447.8125 - val_mean_absolute_error: 469.2269
Epoch 135/200
3/3 - 0s - 6ms/step - loss: 354565.6875 - mean_absolute_error: 476.3767 - val_loss: 386069.6875 - val_mean_absolute_error: 468.9341
Epoch 136/200
3/3 - 0s - 6ms/step - loss: 352412.3125 - mean_absolute_error: 474.8342 - val_loss: 385688.0312 - val_mean_absolute_error: 468.6381
Epoch 137/200
3/3 - 0s - 6ms/step - loss: 355332.1562 - mean_absolute_error: 475.9840 - val_loss: 385323.6562 - val_mean_absolute_error: 468.3482
Epoch 138/200
3/3 - 0s - 6ms/step - loss: 353287.7812 - mean_absolute_error: 474.6594 - val_loss: 384944.5938 - val_mean_absolute_error: 468.0461
Epoch 139/200
3/3 - 0s - 6ms/step - loss: 351783.1562 - mean_absolute_error: 474.1898 - val_loss: 384552.0312 - val_mean_absolute_error: 467.7370
Epoch 140/200
3/3 - 0s - 6ms/step - loss: 352167.1875 - mean_absolute_error: 474.6198 - val_loss: 384169.4375 - val_mean_absolute_error: 467.4398
Epoch 141/200
3/3 - 0s - 6ms/step - loss: 351374.1250 - mean_absolute_error: 473.8765 - val_loss: 383776.3125 - val_mean_absolute_error: 467.1318
Epoch 142/200
3/3 - 0s - 6ms/step - loss: 351726.0000 - mean_absolute_error: 473.9620 - val_loss: 383386.2188 - val_mean_absolute_error: 466.8254
Epoch 143/200
3/3 - 0s - 6ms/step - loss: 351790.5625 - mean_absolute_error: 473.5245 - val_loss: 383011.8750 - val_mean_absolute_error: 466.5314
Epoch 144/200
3/3 - 0s - 6ms/step - loss: 349712.2500 - mean_absolute_error: 471.1650 - val_loss: 382615.9062 - val_mean_absolute_error: 466.2201
Epoch 145/200
3/3 - 0s - 6ms/step - loss: 349838.4688 - mean_absolute_error: 472.1855 - val_loss: 382221.1562 - val_mean_absolute_error: 465.9094
Epoch 146/200
3/3 - 0s - 6ms/step - loss: 351217.9375 - mean_absolute_error: 472.5852 - val_loss: 381833.3438 - val_mean_absolute_error: 465.5952
Epoch 147/200
3/3 - 0s - 6ms/step - loss: 348790.5625 - mean_absolute_error: 472.0415 - val_loss: 381427.2812 - val_mean_absolute_error: 465.2798
Epoch 148/200
3/3 - 0s - 6ms/step - loss: 348132.9688 - mean_absolute_error: 471.3526 - val_loss: 381020.1250 - val_mean_absolute_error: 464.9601
Epoch 149/200
3/3 - 0s - 6ms/step - loss: 344812.0938 - mean_absolute_error: 470.3674 - val_loss: 380587.4062 - val_mean_absolute_error: 464.6228
Epoch 150/200
3/3 - 0s - 6ms/step - loss: 346755.1875 - mean_absolute_error: 470.8776 - val_loss: 380178.4688 - val_mean_absolute_error: 464.3010
Epoch 151/200
3/3 - 0s - 6ms/step - loss: 345183.5312 - mean_absolute_error: 469.1143 - val_loss: 379762.5625 - val_mean_absolute_error: 463.9720
Epoch 152/200
3/3 - 0s - 6ms/step - loss: 346121.9062 - mean_absolute_error: 469.6398 - val_loss: 379354.3125 - val_mean_absolute_error: 463.6471
Epoch 153/200
3/3 - 0s - 6ms/step - loss: 347805.7500 - mean_absolute_error: 469.9405 - val_loss: 378954.3750 - val_mean_absolute_error: 463.3254
Epoch 154/200
3/3 - 0s - 6ms/step - loss: 345419.6250 - mean_absolute_error: 468.3078 - val_loss: 378538.7500 - val_mean_absolute_error: 462.9927
Epoch 155/200
3/3 - 0s - 6ms/step - loss: 341886.5938 - mean_absolute_error: 465.8555 - val_loss: 378108.0000 - val_mean_absolute_error: 462.6513
Epoch 156/200
3/3 - 0s - 6ms/step - loss: 345938.9062 - mean_absolute_error: 469.7661 - val_loss: 377705.3750 - val_mean_absolute_error: 462.3337
Epoch 157/200
3/3 - 0s - 6ms/step - loss: 344007.1875 - mean_absolute_error: 468.2422 - val_loss: 377289.0938 - val_mean_absolute_error: 462.0031
Epoch 158/200
3/3 - 0s - 6ms/step - loss: 341974.6250 - mean_absolute_error: 467.2412 - val_loss: 376849.7812 - val_mean_absolute_error: 461.6506
Epoch 159/200
3/3 - 0s - 6ms/step - loss: 343292.5938 - mean_absolute_error: 467.3942 - val_loss: 376440.5938 - val_mean_absolute_error: 461.3245
Epoch 160/200
3/3 - 0s - 6ms/step - loss: 341729.8438 - mean_absolute_error: 466.0130 - val_loss: 376016.3750 - val_mean_absolute_error: 460.9867
Epoch 161/200
3/3 - 0s - 6ms/step - loss: 343597.0625 - mean_absolute_error: 466.8106 - val_loss: 375597.5938 - val_mean_absolute_error: 460.6504
Epoch 162/200
3/3 - 0s - 6ms/step - loss: 343411.9688 - mean_absolute_error: 466.3511 - val_loss: 375182.1250 - val_mean_absolute_error: 460.3161
Epoch 163/200
3/3 - 0s - 6ms/step - loss: 342387.8750 - mean_absolute_error: 466.0270 - val_loss: 374752.4375 - val_mean_absolute_error: 459.9651
Epoch 164/200
3/3 - 0s - 6ms/step - loss: 341017.6562 - mean_absolute_error: 465.0678 - val_loss: 374313.1562 - val_mean_absolute_error: 459.6122
Epoch 165/200
3/3 - 0s - 6ms/step - loss: 343248.5938 - mean_absolute_error: 465.4084 - val_loss: 373895.4375 - val_mean_absolute_error: 459.2758
Epoch 166/200
3/3 - 0s - 6ms/step - loss: 336636.3750 - mean_absolute_error: 461.8055 - val_loss: 373432.9062 - val_mean_absolute_error: 458.9002
Epoch 167/200
3/3 - 0s - 6ms/step - loss: 339104.5000 - mean_absolute_error: 464.2389 - val_loss: 372992.4688 - val_mean_absolute_error: 458.5395
Epoch 168/200
3/3 - 0s - 6ms/step - loss: 339672.5000 - mean_absolute_error: 464.0140 - val_loss: 372555.0938 - val_mean_absolute_error: 458.1811
Epoch 169/200
3/3 - 0s - 6ms/step - loss: 336042.4688 - mean_absolute_error: 462.9228 - val_loss: 372095.9688 - val_mean_absolute_error: 457.8123
Epoch 170/200
3/3 - 0s - 6ms/step - loss: 338547.9688 - mean_absolute_error: 462.9911 - val_loss: 371659.2500 - val_mean_absolute_error: 457.4561
Epoch 171/200
3/3 - 0s - 6ms/step - loss: 336286.4375 - mean_absolute_error: 462.3979 - val_loss: 371202.0000 - val_mean_absolute_error: 457.0852
Epoch 172/200
3/3 - 0s - 6ms/step - loss: 337823.8750 - mean_absolute_error: 462.3825 - val_loss: 370758.0625 - val_mean_absolute_error: 456.7201
Epoch 173/200
3/3 - 0s - 6ms/step - loss: 336933.5625 - mean_absolute_error: 461.0280 - val_loss: 370315.3750 - val_mean_absolute_error: 456.3624
Epoch 174/200
3/3 - 0s - 6ms/step - loss: 335185.1562 - mean_absolute_error: 459.7150 - val_loss: 369863.5312 - val_mean_absolute_error: 455.9883
Epoch 175/200
3/3 - 0s - 6ms/step - loss: 332966.6562 - mean_absolute_error: 460.3362 - val_loss: 369396.4375 - val_mean_absolute_error: 455.6133
Epoch 176/200
3/3 - 0s - 6ms/step - loss: 334832.5312 - mean_absolute_error: 460.7783 - val_loss: 368941.8750 - val_mean_absolute_error: 455.2387
Epoch 177/200
3/3 - 0s - 6ms/step - loss: 330226.8438 - mean_absolute_error: 459.3182 - val_loss: 368449.8438 - val_mean_absolute_error: 454.8383
Epoch 178/200
3/3 - 0s - 6ms/step - loss: 336422.1250 - mean_absolute_error: 461.2665 - val_loss: 368019.9688 - val_mean_absolute_error: 454.4809
Epoch 179/200
3/3 - 0s - 6ms/step - loss: 336272.6562 - mean_absolute_error: 459.8713 - val_loss: 367587.5000 - val_mean_absolute_error: 454.1153
Epoch 180/200
3/3 - 0s - 6ms/step - loss: 331326.3438 - mean_absolute_error: 457.9808 - val_loss: 367108.5625 - val_mean_absolute_error: 453.7223
Epoch 181/200
3/3 - 0s - 6ms/step - loss: 329859.2812 - mean_absolute_error: 456.1172 - val_loss: 366637.2812 - val_mean_absolute_error: 453.3390
Epoch 182/200
3/3 - 0s - 6ms/step - loss: 333142.6562 - mean_absolute_error: 458.0789 - val_loss: 366183.4688 - val_mean_absolute_error: 452.9633
Epoch 183/200
3/3 - 0s - 6ms/step - loss: 334313.2812 - mean_absolute_error: 458.4320 - val_loss: 365741.9375 - val_mean_absolute_error: 452.5912
Epoch 184/200
3/3 - 0s - 6ms/step - loss: 335057.6875 - mean_absolute_error: 458.7414 - val_loss: 365293.4375 - val_mean_absolute_error: 452.2067
Epoch 185/200
3/3 - 0s - 6ms/step - loss: 335695.5000 - mean_absolute_error: 458.2279 - val_loss: 364845.5625 - val_mean_absolute_error: 451.8333
Epoch 186/200
3/3 - 0s - 6ms/step - loss: 329460.7500 - mean_absolute_error: 455.0563 - val_loss: 364349.9688 - val_mean_absolute_error: 451.4233
Epoch 187/200
3/3 - 0s - 6ms/step - loss: 327034.1250 - mean_absolute_error: 453.8841 - val_loss: 363853.3750 - val_mean_absolute_error: 451.0091
Epoch 188/200
3/3 - 0s - 6ms/step - loss: 328330.5938 - mean_absolute_error: 456.2253 - val_loss: 363369.9688 - val_mean_absolute_error: 450.6119
Epoch 189/200
3/3 - 0s - 6ms/step - loss: 330951.7812 - mean_absolute_error: 456.7161 - val_loss: 362904.1875 - val_mean_absolute_error: 450.2242
Epoch 190/200
3/3 - 0s - 6ms/step - loss: 328024.2812 - mean_absolute_error: 454.6515 - val_loss: 362428.4688 - val_mean_absolute_error: 449.8307
Epoch 191/200
3/3 - 0s - 6ms/step - loss: 328093.2500 - mean_absolute_error: 453.9495 - val_loss: 361941.9688 - val_mean_absolute_error: 449.4154
Epoch 192/200
3/3 - 0s - 6ms/step - loss: 327136.0625 - mean_absolute_error: 453.1562 - val_loss: 361461.6562 - val_mean_absolute_error: 449.0129
Epoch 193/200
3/3 - 0s - 6ms/step - loss: 328581.0625 - mean_absolute_error: 455.1703 - val_loss: 360980.9062 - val_mean_absolute_error: 448.6056
Epoch 194/200
3/3 - 0s - 6ms/step - loss: 323409.6875 - mean_absolute_error: 451.0608 - val_loss: 360477.9688 - val_mean_absolute_error: 448.1843
Epoch 195/200
3/3 - 0s - 6ms/step - loss: 326472.3750 - mean_absolute_error: 453.1009 - val_loss: 359997.6562 - val_mean_absolute_error: 447.7808
Epoch 196/200
3/3 - 0s - 6ms/step - loss: 324789.8438 - mean_absolute_error: 452.3427 - val_loss: 359501.6562 - val_mean_absolute_error: 447.3632
Epoch 197/200
3/3 - 0s - 6ms/step - loss: 325098.6250 - mean_absolute_error: 453.4882 - val_loss: 359012.3750 - val_mean_absolute_error: 446.9469
Epoch 198/200
3/3 - 0s - 6ms/step - loss: 324289.4375 - mean_absolute_error: 450.9069 - val_loss: 358521.1562 - val_mean_absolute_error: 446.5310
Epoch 199/200
3/3 - 0s - 6ms/step - loss: 324940.4688 - mean_absolute_error: 452.2688 - val_loss: 358025.5938 - val_mean_absolute_error: 446.1035
Epoch 200/200
3/3 - 0s - 6ms/step - loss: 319032.0312 - mean_absolute_error: 448.2297 - val_loss: 357505.8750 - val_mean_absolute_error: 445.6594
For each epoch, Keras reports the value of the loss function (mean squared error) and of the monitored metric (mean absolute error) for both the training and the validation data. As you can see from the lengthy output, all criteria are still decreasing after 200 epochs. It is helpful to view the epoch history graphically. If you run the code in an interactive environment (e.g., Anaconda Spyder), the epoch history is displayed and updated live. You can always plot the epoch history using the matplotlib package:
# Plot training history
import matplotlib.pyplot as plt
plt.figure(figsize=(12, 6))
# Top panel: loss (mean squared error) per epoch
plt.subplot(2, 1, 1)
plt.scatter(range(len(history.history['loss'])), history.history['loss'], label='Training Loss', alpha=0.7)
plt.scatter(range(len(history.history['val_loss'])), history.history['val_loss'], label='Validation Loss', alpha=0.7)
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
# Bottom panel: mean absolute error per epoch
plt.subplot(2, 1, 2)
plt.scatter(range(len(history.history['mean_absolute_error'])), history.history['mean_absolute_error'], label='Training MAE')
plt.scatter(range(len(history.history['val_mean_absolute_error'])), history.history['val_mean_absolute_error'], label='Validation MAE')
plt.title('Model Mean Absolute Error')
plt.xlabel('Epoch')
plt.ylabel('Mean Absolute Error')
plt.legend()
plt.tight_layout()
plt.show()

All criteria are steadily declining and have not leveled out after 200 epochs (Figure 35.1). As expected, the mean squared error is higher in the validation data than in the training data, although this is not always the case when training neural networks. In the later epochs of the log above, for example, the validation mean absolute error is slightly lower than the training value. Maybe surprisingly, after about 75 epochs the metrics are showing more variability from epoch to epoch in the training data than in the validation data. Also, there is no guarantee that the criteria decrease monotonically; the mean squared error of epoch \(t\) can be higher than that of epoch \(t-1\). We are looking for the results to settle down and stabilize before calling the optimization completed. More epochs need to be run in this example. Fortunately, you can continue where the previous run has left off. The following code trains the network for another 100 epochs:
firstANN.fit(
    x_train, y_train,
    epochs=100,
    batch_size=32,
    validation_data=(x_test, y_test),
    verbose=2)
Epoch 1/100
3/3 - 0s - 19ms/step - loss: 323937.6562 - mean_absolute_error: 449.4143 - val_loss: 357014.8125 - val_mean_absolute_error: 445.2343
Epoch 2/100
3/3 - 0s - 12ms/step - loss: 319381.0000 - mean_absolute_error: 448.2860 - val_loss: 356502.3750 - val_mean_absolute_error: 444.7992
Epoch 3/100
3/3 - 0s - 6ms/step - loss: 321118.4688 - mean_absolute_error: 448.6617 - val_loss: 356009.1875 - val_mean_absolute_error: 444.3873
Epoch 4/100
3/3 - 0s - 6ms/step - loss: 319522.9375 - mean_absolute_error: 446.2070 - val_loss: 355502.8125 - val_mean_absolute_error: 443.9497
Epoch 5/100
3/3 - 0s - 6ms/step - loss: 320085.8438 - mean_absolute_error: 446.9688 - val_loss: 355007.8750 - val_mean_absolute_error: 443.5231
Epoch 6/100
3/3 - 0s - 6ms/step - loss: 317181.1250 - mean_absolute_error: 446.1050 - val_loss: 354494.1562 - val_mean_absolute_error: 443.0797
Epoch 7/100
3/3 - 0s - 6ms/step - loss: 322741.6250 - mean_absolute_error: 448.2556 - val_loss: 354010.1250 - val_mean_absolute_error: 442.6494
Epoch 8/100
3/3 - 0s - 6ms/step - loss: 321246.1562 - mean_absolute_error: 447.0703 - val_loss: 353520.4062 - val_mean_absolute_error: 442.2297
Epoch 9/100
3/3 - 0s - 6ms/step - loss: 317558.5938 - mean_absolute_error: 447.2191 - val_loss: 353000.5312 - val_mean_absolute_error: 441.7872
Epoch 10/100
3/3 - 0s - 6ms/step - loss: 318489.2500 - mean_absolute_error: 445.2418 - val_loss: 352494.0312 - val_mean_absolute_error: 441.3463
Epoch 11/100
3/3 - 0s - 6ms/step - loss: 319823.4688 - mean_absolute_error: 446.9185 - val_loss: 351995.5938 - val_mean_absolute_error: 440.9074
Epoch 12/100
3/3 - 0s - 6ms/step - loss: 311913.2500 - mean_absolute_error: 442.1082 - val_loss: 351452.0000 - val_mean_absolute_error: 440.4411
Epoch 13/100
3/3 - 0s - 6ms/step - loss: 316366.4062 - mean_absolute_error: 443.6360 - val_loss: 350947.3125 - val_mean_absolute_error: 440.0007
Epoch 14/100
3/3 - 0s - 6ms/step - loss: 321783.8125 - mean_absolute_error: 445.9359 - val_loss: 350468.7812 - val_mean_absolute_error: 439.5779
Epoch 15/100
3/3 - 0s - 6ms/step - loss: 313855.0312 - mean_absolute_error: 441.1244 - val_loss: 349938.4375 - val_mean_absolute_error: 439.1177
Epoch 16/100
3/3 - 0s - 6ms/step - loss: 309676.9062 - mean_absolute_error: 440.2711 - val_loss: 349392.9062 - val_mean_absolute_error: 438.6456
Epoch 17/100
3/3 - 0s - 6ms/step - loss: 313989.5625 - mean_absolute_error: 443.3115 - val_loss: 348874.6250 - val_mean_absolute_error: 438.1914
Epoch 18/100
3/3 - 0s - 6ms/step - loss: 312668.9062 - mean_absolute_error: 440.8082 - val_loss: 348347.5000 - val_mean_absolute_error: 437.7237
Epoch 19/100
3/3 - 0s - 6ms/step - loss: 308587.5000 - mean_absolute_error: 438.5932 - val_loss: 347813.5938 - val_mean_absolute_error: 437.2508
Epoch 20/100
3/3 - 0s - 6ms/step - loss: 311675.0000 - mean_absolute_error: 438.8261 - val_loss: 347290.9062 - val_mean_absolute_error: 436.7880
Epoch 21/100
3/3 - 0s - 6ms/step - loss: 310333.0938 - mean_absolute_error: 440.5679 - val_loss: 346770.3125 - val_mean_absolute_error: 436.3299
Epoch 22/100
3/3 - 0s - 6ms/step - loss: 305580.8750 - mean_absolute_error: 437.5335 - val_loss: 346215.1875 - val_mean_absolute_error: 435.8399
Epoch 23/100
3/3 - 0s - 6ms/step - loss: 311452.3438 - mean_absolute_error: 439.6997 - val_loss: 345704.4688 - val_mean_absolute_error: 435.3791
Epoch 24/100
3/3 - 0s - 6ms/step - loss: 308057.2812 - mean_absolute_error: 439.9796 - val_loss: 345172.5938 - val_mean_absolute_error: 434.9092
Epoch 25/100
3/3 - 0s - 6ms/step - loss: 307704.7812 - mean_absolute_error: 437.3419 - val_loss: 344643.9688 - val_mean_absolute_error: 434.4417
Epoch 26/100
3/3 - 0s - 6ms/step - loss: 309932.1250 - mean_absolute_error: 437.9238 - val_loss: 344137.6562 - val_mean_absolute_error: 433.9914
Epoch 27/100
3/3 - 0s - 6ms/step - loss: 310016.7500 - mean_absolute_error: 439.2341 - val_loss: 343611.7188 - val_mean_absolute_error: 433.5265
Epoch 28/100
3/3 - 0s - 6ms/step - loss: 306950.3125 - mean_absolute_error: 437.6606 - val_loss: 343075.2500 - val_mean_absolute_error: 433.0461
Epoch 29/100
3/3 - 0s - 6ms/step - loss: 306805.5312 - mean_absolute_error: 436.8515 - val_loss: 342540.7500 - val_mean_absolute_error: 432.5713
Epoch 30/100
3/3 - 0s - 6ms/step - loss: 304741.7812 - mean_absolute_error: 435.3903 - val_loss: 341999.9062 - val_mean_absolute_error: 432.0836
Epoch 31/100
3/3 - 0s - 6ms/step - loss: 305041.5938 - mean_absolute_error: 435.9172 - val_loss: 341464.6562 - val_mean_absolute_error: 431.6056
Epoch 32/100
3/3 - 0s - 6ms/step - loss: 304501.7812 - mean_absolute_error: 436.7936 - val_loss: 340922.2188 - val_mean_absolute_error: 431.1154
Epoch 33/100
3/3 - 0s - 6ms/step - loss: 306947.7812 - mean_absolute_error: 435.7789 - val_loss: 340386.8438 - val_mean_absolute_error: 430.6361
Epoch 34/100
3/3 - 0s - 6ms/step - loss: 306232.4688 - mean_absolute_error: 434.5029 - val_loss: 339853.8750 - val_mean_absolute_error: 430.1491
Epoch 35/100
3/3 - 0s - 6ms/step - loss: 301526.3438 - mean_absolute_error: 432.2083 - val_loss: 339300.4688 - val_mean_absolute_error: 429.6428
Epoch 36/100
3/3 - 0s - 6ms/step - loss: 303033.1250 - mean_absolute_error: 430.3025 - val_loss: 338766.2188 - val_mean_absolute_error: 429.1497
Epoch 37/100
3/3 - 0s - 6ms/step - loss: 299555.6875 - mean_absolute_error: 430.6850 - val_loss: 338202.0312 - val_mean_absolute_error: 428.6314
Epoch 38/100
3/3 - 0s - 6ms/step - loss: 301898.3438 - mean_absolute_error: 432.1886 - val_loss: 337657.2812 - val_mean_absolute_error: 428.1342
Epoch 39/100
3/3 - 0s - 6ms/step - loss: 298342.1562 - mean_absolute_error: 432.6975 - val_loss: 337098.0312 - val_mean_absolute_error: 427.6256
Epoch 40/100
3/3 - 0s - 6ms/step - loss: 300030.0312 - mean_absolute_error: 430.5144 - val_loss: 336556.0625 - val_mean_absolute_error: 427.1268
Epoch 41/100
3/3 - 0s - 6ms/step - loss: 299962.9375 - mean_absolute_error: 429.9703 - val_loss: 336013.1875 - val_mean_absolute_error: 426.6354
Epoch 42/100
3/3 - 0s - 6ms/step - loss: 295985.1875 - mean_absolute_error: 428.3843 - val_loss: 335444.1875 - val_mean_absolute_error: 426.1105
Epoch 43/100
3/3 - 0s - 6ms/step - loss: 298032.0938 - mean_absolute_error: 431.0013 - val_loss: 334891.7812 - val_mean_absolute_error: 425.6140
Epoch 44/100
3/3 - 0s - 6ms/step - loss: 298007.6250 - mean_absolute_error: 430.5838 - val_loss: 334346.2500 - val_mean_absolute_error: 425.1152
Epoch 45/100
3/3 - 0s - 6ms/step - loss: 292888.5938 - mean_absolute_error: 423.8717 - val_loss: 333787.4375 - val_mean_absolute_error: 424.6048
Epoch 46/100
3/3 - 0s - 6ms/step - loss: 298915.0938 - mean_absolute_error: 429.8017 - val_loss: 333251.5625 - val_mean_absolute_error: 424.1134
Epoch 47/100
3/3 - 0s - 6ms/step - loss: 294517.0000 - mean_absolute_error: 423.5296 - val_loss: 332691.8125 - val_mean_absolute_error: 423.5997
Epoch 48/100
3/3 - 0s - 6ms/step - loss: 302118.1562 - mean_absolute_error: 427.5919 - val_loss: 332179.7188 - val_mean_absolute_error: 423.1152
Epoch 49/100
3/3 - 0s - 6ms/step - loss: 294722.6562 - mean_absolute_error: 427.0931 - val_loss: 331622.2812 - val_mean_absolute_error: 422.5969
Epoch 50/100
3/3 - 0s - 6ms/step - loss: 295143.6875 - mean_absolute_error: 426.6412 - val_loss: 331056.9062 - val_mean_absolute_error: 422.0786
Epoch 51/100
3/3 - 0s - 6ms/step - loss: 298170.3438 - mean_absolute_error: 427.5193 - val_loss: 330519.4375 - val_mean_absolute_error: 421.5737
Epoch 52/100
3/3 - 0s - 6ms/step - loss: 289761.7812 - mean_absolute_error: 423.0247 - val_loss: 329929.4375 - val_mean_absolute_error: 421.0258
Epoch 53/100
3/3 - 0s - 6ms/step - loss: 290783.1250 - mean_absolute_error: 422.3107 - val_loss: 329368.7188 - val_mean_absolute_error: 420.5050
Epoch 54/100
3/3 - 0s - 6ms/step - loss: 291920.6875 - mean_absolute_error: 422.6753 - val_loss: 328803.6250 - val_mean_absolute_error: 419.9784
Epoch 55/100
3/3 - 0s - 6ms/step - loss: 290103.8125 - mean_absolute_error: 422.4146 - val_loss: 328239.1250 - val_mean_absolute_error: 419.4466
Epoch 56/100
3/3 - 0s - 6ms/step - loss: 295783.0000 - mean_absolute_error: 424.5052 - val_loss: 327701.5938 - val_mean_absolute_error: 418.9373
Epoch 57/100
3/3 - 0s - 6ms/step - loss: 288562.8125 - mean_absolute_error: 421.2257 - val_loss: 327135.7812 - val_mean_absolute_error: 418.4074
Epoch 58/100
3/3 - 0s - 6ms/step - loss: 290584.2500 - mean_absolute_error: 423.0610 - val_loss: 326570.5625 - val_mean_absolute_error: 417.8769
Epoch 59/100
3/3 - 0s - 6ms/step - loss: 291334.3438 - mean_absolute_error: 422.8824 - val_loss: 326012.7500 - val_mean_absolute_error: 417.3474
Epoch 60/100
3/3 - 0s - 6ms/step - loss: 287939.5000 - mean_absolute_error: 418.3311 - val_loss: 325451.0625 - val_mean_absolute_error: 416.8220
Epoch 61/100
3/3 - 0s - 6ms/step - loss: 282171.0938 - mean_absolute_error: 416.9257 - val_loss: 324867.3125 - val_mean_absolute_error: 416.2711
Epoch 62/100
3/3 - 0s - 6ms/step - loss: 287579.2500 - mean_absolute_error: 420.3936 - val_loss: 324305.9688 - val_mean_absolute_error: 415.7350
Epoch 63/100
3/3 - 0s - 6ms/step - loss: 281387.3438 - mean_absolute_error: 416.7201 - val_loss: 323708.0000 - val_mean_absolute_error: 415.1718
Epoch 64/100
3/3 - 0s - 6ms/step - loss: 290604.7500 - mean_absolute_error: 421.6886 - val_loss: 323158.4688 - val_mean_absolute_error: 414.6393
Epoch 65/100
3/3 - 0s - 6ms/step - loss: 285920.9375 - mean_absolute_error: 417.9337 - val_loss: 322587.2812 - val_mean_absolute_error: 414.0877
Epoch 66/100
3/3 - 0s - 6ms/step - loss: 281280.2812 - mean_absolute_error: 416.1843 - val_loss: 321997.1562 - val_mean_absolute_error: 413.5276
Epoch 67/100
3/3 - 0s - 6ms/step - loss: 281973.1875 - mean_absolute_error: 414.9070 - val_loss: 321432.7812 - val_mean_absolute_error: 412.9807
Epoch 68/100
3/3 - 0s - 6ms/step - loss: 284085.4375 - mean_absolute_error: 417.1419 - val_loss: 320878.0000 - val_mean_absolute_error: 412.4518
Epoch 69/100
3/3 - 0s - 6ms/step - loss: 280357.2812 - mean_absolute_error: 414.9259 - val_loss: 320300.5000 - val_mean_absolute_error: 411.8957
Epoch 70/100
3/3 - 0s - 6ms/step - loss: 279103.4375 - mean_absolute_error: 414.0888 - val_loss: 319723.6875 - val_mean_absolute_error: 411.3466
Epoch 71/100
3/3 - 0s - 6ms/step - loss: 284805.9375 - mean_absolute_error: 416.6817 - val_loss: 319161.7500 - val_mean_absolute_error: 410.7971
Epoch 72/100
3/3 - 0s - 6ms/step - loss: 277462.9062 - mean_absolute_error: 412.4871 - val_loss: 318587.7812 - val_mean_absolute_error: 410.2541
Epoch 73/100
3/3 - 0s - 6ms/step - loss: 282340.3125 - mean_absolute_error: 415.6058 - val_loss: 318038.1875 - val_mean_absolute_error: 409.7235
Epoch 74/100
3/3 - 0s - 6ms/step - loss: 283341.9688 - mean_absolute_error: 413.2817 - val_loss: 317479.9688 - val_mean_absolute_error: 409.1671
Epoch 75/100
3/3 - 0s - 6ms/step - loss: 278622.7188 - mean_absolute_error: 411.2881 - val_loss: 316899.5312 - val_mean_absolute_error: 408.6055
Epoch 76/100
3/3 - 0s - 6ms/step - loss: 274410.1250 - mean_absolute_error: 408.5329 - val_loss: 316304.5625 - val_mean_absolute_error: 408.0222
Epoch 77/100
3/3 - 0s - 6ms/step - loss: 285485.3750 - mean_absolute_error: 412.4106 - val_loss: 315776.0312 - val_mean_absolute_error: 407.4972
Epoch 78/100
3/3 - 0s - 6ms/step - loss: 277831.4688 - mean_absolute_error: 410.0921 - val_loss: 315212.7500 - val_mean_absolute_error: 406.9456
Epoch 79/100
3/3 - 0s - 6ms/step - loss: 278685.3438 - mean_absolute_error: 409.8770 - val_loss: 314628.0000 - val_mean_absolute_error: 406.3659
Epoch 80/100
3/3 - 0s - 6ms/step - loss: 274865.4375 - mean_absolute_error: 405.5993 - val_loss: 314041.3438 - val_mean_absolute_error: 405.7914
Epoch 81/100
3/3 - 0s - 6ms/step - loss: 273142.6562 - mean_absolute_error: 405.4822 - val_loss: 313445.8125 - val_mean_absolute_error: 405.1954
Epoch 82/100
3/3 - 0s - 6ms/step - loss: 271829.0000 - mean_absolute_error: 405.5426 - val_loss: 312858.3125 - val_mean_absolute_error: 404.6177
Epoch 83/100
3/3 - 0s - 6ms/step - loss: 276238.4062 - mean_absolute_error: 409.4067 - val_loss: 312286.1875 - val_mean_absolute_error: 404.0399
Epoch 84/100
3/3 - 0s - 6ms/step - loss: 266623.1875 - mean_absolute_error: 406.3788 - val_loss: 311682.4688 - val_mean_absolute_error: 403.4432
Epoch 85/100
3/3 - 0s - 6ms/step - loss: 279720.5000 - mean_absolute_error: 409.8026 - val_loss: 311141.3125 - val_mean_absolute_error: 402.8980
Epoch 86/100
3/3 - 0s - 6ms/step - loss: 274504.8125 - mean_absolute_error: 407.5530 - val_loss: 310562.7500 - val_mean_absolute_error: 402.3270
Epoch 87/100
3/3 - 0s - 6ms/step - loss: 271432.3125 - mean_absolute_error: 406.2660 - val_loss: 309991.0000 - val_mean_absolute_error: 401.7680
Epoch 88/100
3/3 - 0s - 6ms/step - loss: 272125.1875 - mean_absolute_error: 403.4680 - val_loss: 309428.1875 - val_mean_absolute_error: 401.2198
Epoch 89/100
3/3 - 0s - 6ms/step - loss: 269321.3125 - mean_absolute_error: 405.2219 - val_loss: 308818.5312 - val_mean_absolute_error: 400.6218
Epoch 90/100
3/3 - 0s - 6ms/step - loss: 273044.2812 - mean_absolute_error: 407.9886 - val_loss: 308237.0312 - val_mean_absolute_error: 400.0498
Epoch 91/100
3/3 - 0s - 6ms/step - loss: 267253.1875 - mean_absolute_error: 402.7285 - val_loss: 307660.6875 - val_mean_absolute_error: 399.4750
Epoch 92/100
3/3 - 0s - 6ms/step - loss: 271568.5938 - mean_absolute_error: 403.8477 - val_loss: 307085.6875 - val_mean_absolute_error: 398.8963
Epoch 93/100
3/3 - 0s - 6ms/step - loss: 270870.6875 - mean_absolute_error: 403.1011 - val_loss: 306494.6875 - val_mean_absolute_error: 398.3112
Epoch 94/100
3/3 - 0s - 6ms/step - loss: 264551.4375 - mean_absolute_error: 400.2390 - val_loss: 305901.8125 - val_mean_absolute_error: 397.7414
Epoch 95/100
3/3 - 0s - 6ms/step - loss: 263769.2500 - mean_absolute_error: 402.5366 - val_loss: 305307.2500 - val_mean_absolute_error: 397.1617
Epoch 96/100
3/3 - 0s - 6ms/step - loss: 262891.0000 - mean_absolute_error: 399.1795 - val_loss: 304707.5000 - val_mean_absolute_error: 396.5775
Epoch 97/100
3/3 - 0s - 6ms/step - loss: 263679.6875 - mean_absolute_error: 401.0838 - val_loss: 304125.5312 - val_mean_absolute_error: 396.0099
Epoch 98/100
3/3 - 0s - 6ms/step - loss: 269905.3750 - mean_absolute_error: 404.7471 - val_loss: 303570.8438 - val_mean_absolute_error: 395.4639
Epoch 99/100
3/3 - 0s - 6ms/step - loss: 261381.2812 - mean_absolute_error: 397.3543 - val_loss: 302974.4062 - val_mean_absolute_error: 394.8804
Epoch 100/100
3/3 - 0s - 6ms/step - loss: 262630.5938 - mean_absolute_error: 401.3788 - val_loss: 302377.5312 - val_mean_absolute_error: 394.2883
<keras.src.callbacks.history.History at 0x37eca58b0>
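Whether the optimization has leveled out can be roughly quantified by looking at how much the validation loss still changes from epoch to epoch. The sketch below is only an illustration: it copies the last five val_loss values from the log above into a list (with a fitted model they would come from the history object returned by fit()), and the helper name mean_relative_improvement is made up for this example.

```python
# Last five val_loss values of the second run, copied from the log above
# (epochs 96-100 of 100).
val_loss = [304707.5000, 304125.5312, 303570.8438, 302974.4062, 302377.5312]


def mean_relative_improvement(values):
    """Average fractional decrease per epoch over the given window."""
    steps = [(prev - curr) / prev for prev, curr in zip(values, values[1:])]
    return sum(steps) / len(steps)


improvement = mean_relative_improvement(val_loss)
print(f"Mean relative improvement per epoch: {improvement:.4%}")
```

A value near zero over a reasonably long window would suggest that the loss has leveled out. Here the validation loss is still dropping by roughly 0.2% per epoch, so the optimization has not stabilized.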
When training models this way, you keep your eyes on the epoch history to study the behavior of the loss function and other metrics on the training and test data sets. You have to make a judgement call as to when the optimization has stabilized and further progress is minimal. Alternatively, you can install a function that stops the optimization when certain conditions are met.
This is done in the following code with the EarlyStopping callback (results not shown here). The options of the early stopping function ask it to monitor the loss function on the validation data (monitor='val_loss') and to stop the optimization when the criterion fails to decrease (mode='min') over 10 epochs (patience=10). Any change of the monitored metric has to be at least 0.1 in magnitude to qualify as an improvement (min_delta=0.1). With restore_best_weights=True, the model weights are reset to those of the epoch with the best value of the monitored metric.
from tensorflow.keras.callbacks import EarlyStopping
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    min_delta=0.1,
    mode='min',
    restore_best_weights=True
)
firstANN.fit(
    x_train, y_train,
    epochs=400,
    batch_size=32,
    validation_data=(x_test, y_test),
    callbacks=[early_stopping],
    verbose=0
)
<keras.src.callbacks.history.History at 0x37fd15a90>
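The stopping rule controlled by patience and min_delta can be sketched in plain Python. This is a simplified illustration of the logic for mode='min', not the Keras implementation; the function name early_stopping_epoch and the loss values are hypothetical, and restoring the best weights is not modeled.

```python
def early_stopping_epoch(val_losses, patience=10, min_delta=0.1):
    """Return the 1-based epoch at which training would stop, or None.

    An epoch counts as an improvement only if the loss drops below the
    best value seen so far by more than min_delta.
    """
    best = float("inf")
    wait = 0  # number of consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: a steady decrease followed by a plateau.
losses = [10.0, 8.0, 6.5, 6.0, 5.95, 5.96, 5.94, 5.97, 5.95]
print(early_stopping_epoch(losses, patience=3, min_delta=0.1))  # prints 7
```

The best value, 6.0, is reached in epoch 4. None of the next three epochs improves on it by more than 0.1, so the rule stops after epoch 7.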
A list of the callback functions supported by Keras can be found in the Keras documentation. Alternatively, you can type the following at the console prompt:
# help(tf.keras.callbacks)
Finally, we predict from the final model and evaluate its performance on the test data. Due to the use of random elements in the fit (stochastic gradient descent, random dropout, …), the results vary slightly with each fit. Unfortunately, setting the seed of a single random number generator does not ensure identical results, because several generators are involved (see Random numbers below), so your results will differ slightly.
# Make predictions on test set and calculate mean absolute error
predvals = firstANN.predict(x_test)
np.mean(np.abs(y_test - predvals.flatten()))
3/3 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step
286.4261176628025
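The call to flatten() in the calculation above matters. The following sketch uses small made-up arrays and assumes, as for a network with a single output unit, that predict() returns a two-dimensional array with one column. Without flattening, NumPy broadcasting silently produces a matrix of all pairwise differences instead of elementwise errors.

```python
import numpy as np

# Hypothetical targets with shape (4,) and predictions with shape (4, 1),
# the shape a single-output model is assumed to return from predict().
y_true = np.array([100.0, 200.0, 300.0, 400.0])
preds = np.array([[110.0], [190.0], [330.0], [400.0]])

# Correct: flatten to shape (4,) so the subtraction is elementwise.
mae = np.mean(np.abs(y_true - preds.flatten()))

# Pitfall: (4,) minus (4, 1) broadcasts to a (4, 4) matrix.
pairwise = y_true - preds

print(mae)             # 12.5
print(pairwise.shape)  # (4, 4)
```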
Random numbers
Neural networks rely on random numbers for picking starting values, selecting observations into mini batches, selecting neurons in dropout layers, etc. The underlying code relies on the random number generators of Python's random module and of NumPy. TensorFlow has its own random number generator on top of that. Python code that uses Keras with the TensorFlow backend needs to set the seed for each random number generator to obtain reproducible results. The keras.utils.set_random_seed() function sets seeds for base Python (with the random package), NumPy, and TensorFlow.
keras.utils.set_random_seed(1)
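The effect of seeding can be seen without training a network. The sketch below seeds only Python's random module and NumPy's global generator, two of the generators mentioned above; it does not touch TensorFlow. The helper name draw is hypothetical.

```python
import random

import numpy as np


def draw(seed):
    """Seed both generators, then draw one number from each."""
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand()


first = draw(1)
second = draw(1)
third = draw(2)

print(first == second)  # True: same seed, same draws
print(first == third)   # False: a different seed gives different draws
```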
Even with this control, Python might generate non-reproducible results. Multi-threading operations on CPUs—and GPUs in particular—can produce a non-deterministic order of operations.
One recommendation to deal with non-deterministic results is training the model several times and averaging the results, essentially ensembling them. However, when a single training run takes several hours, doing it thirty times is not practical.
MNIST Image Classification
We now return to the MNIST image classification data introduced in Section 32.4. Recall that the data comprise 60,000 training images and 10,000 test images of handwritten digits (0–9). Each image has 28 x 28 pixels recording a grayscale value.
The MNIST data is provided by Keras:
Set up the data
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np
# Load MNIST dataset
(x_train, g_train), (x_test, g_test) = mnist.load_data()
# Check dimensions
print("x_train shape:", x_train.shape)
print("x_test shape:", x_test.shape)
x_train shape: (60000, 28, 28)
x_test shape: (10000, 28, 28)
The images are stored as a three-dimensional array and need to be reshaped into a matrix. For classification tasks with \(k\) categories, Keras expects as the target values a matrix of \(k\) columns. Column \(j\) contains ones in the rows for observations where the observed category is \(j\), and zeros otherwise. This is called one-hot encoding of the target variable. Luckily, NumPy and keras have built-in functions that handle both tasks for us.
# Reshape the data to flatten the 28x28 images into 784-dimensional vectors
x_train = x_train.reshape(x_train.shape[0], 784)
x_test = x_test.reshape(x_test.shape[0], 784)
# Convert labels to categorical (one-hot encoding)
y_train = to_categorical(g_train, 10)
y_test = to_categorical(g_test, 10)
Let’s look at the one-hot encoding of the target data. g_test contains the value of the digit from 0–9. y_test is a matrix with 10 columns; each column corresponds to one digit. If observation \(i\) represents digit \(j\), then there is a 1 in row \(i\), column \(j+1\) of the encoded matrix (column index \(j\) in Python's zero-based indexing). For example, for the first twenty images:
# Display first 20 original labels
print(g_test[:20])
[7 2 1 0 4 1 4 9 5 9 0 6 9 0 1 5 9 7 3 4]
# Display first 20 one-hot encoded labels (all 10 columns)
print(y_test[:20, :10])
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
[0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]]
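The encoding shown above can be reproduced with plain NumPy, which makes the structure explicit. This is a sketch for illustration, not how to_categorical() is implemented; it uses the first five test labels from the output above.

```python
import numpy as np

# First five test labels shown above.
labels = np.array([7, 2, 1, 0, 4])

# Row j of the 10 x 10 identity matrix is the one-hot vector for digit j.
one_hot = np.eye(10, dtype="float32")[labels]

print(one_hot[0])                  # 1 in column index 7, zeros elsewhere
print(np.argmax(one_hot, axis=1))  # recovers the original labels
```

Taking the column index of the largest entry in each row (argmax) inverts the encoding. The same operation turns predicted class probabilities into predicted digits.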
Let’s look at the matrix of inputs. The next array shows the 28 x 28 = 784 input columns for the third image. The values are grayscale values between 0 and 255.
# Display the 3rd test sample (index 2)
print(x_test[2])
[ 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 38 254 109 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 87 252 82 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 135 241 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 45 244 150 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 84 254 63 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 202 223 11
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 32 254 216 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 95 254
195 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 140 254 77 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 57
237 205 8 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 124 255 165 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 171 254 81 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 24 232 215 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 120 254 159 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 151 254 142 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 228 254 66 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 61 251 254 66 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 141 254 205 3 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 10 215 254 121
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 5 198 176 10 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0 0]
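The 784 values printed above are the 28 rows of the image laid end to end, because NumPy's reshape() uses row-major order by default. The sketch below checks this mapping on two made-up 28 x 28 arrays filled with distinct integers rather than on the MNIST data.

```python
import numpy as np

# Two hypothetical 28 x 28 "images" filled with distinct integers.
images = np.arange(2 * 28 * 28).reshape(2, 28, 28)

# Same reshape as in the text: one row of 784 values per image.
flat = images.reshape(images.shape[0], 784)

# Pixel (row r, column c) of an image ends up at position 28 * r + c.
r, c = 3, 5
print(flat.shape)                              # (2, 784)
print(flat[1, 28 * r + c] == images[1, r, c])  # True

# Reshaping back recovers the original 28 x 28 layout.
restored = flat.reshape(-1, 28, 28)
print(np.array_equal(restored, images))        # True
```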
Finally, prior to training the network, we scale the input values to lie between 0 and 1.
# Normalize pixel values to range [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0
The target variable does not need to be scaled. The one-hot encoding, together with the use of a softmax output function, ensures that the output for each category is a value between 0 and 1 and that the outputs sum to 1 across the 10 categories. We will interpret them as predicted probabilities that an observed image is assigned to a particular digit.
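The two properties of the softmax output, values between 0 and 1 that sum to 1, can be verified with a few lines of NumPy. The ten input values below are invented for illustration; in the network they would be the linear predictors of the ten output units.

```python
import numpy as np


def softmax(z):
    """Softmax over the last axis, shifted by the max for numerical stability."""
    shifted = z - np.max(z, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


# Hypothetical pre-activation values of the 10 output units for one image.
logits = np.array([1.2, -0.3, 0.5, 3.1, -2.0, 0.0, 0.7, 4.5, -1.1, 0.2])
probs = softmax(logits)

print(probs.round(3))
print(probs.sum())            # 1 up to floating-point rounding
print(int(np.argmax(probs)))  # 7, the index of the largest input
```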
To classify the MNIST images we consider two types of neural networks in the remainder of this chapter: a multi layer ANN and a network without a hidden layer. The latter is a multi category perceptron and very similar to a multinomial logistic regression model.
Multi layer neural network
We now train the network shown in Figure 32.14, an ANN with two hidden layers, and add a dropout regularization layer after each fully connected hidden layer. The input layer specifies the input shape of 28 x 28 = 784. The first hidden layer has 128 neurons and ReLU activation. Why? There is no deeper reason; these are common default choices.
This is followed by a first dropout layer with rate \(\phi_1 = 0.3\), another fully connected hidden layer with 64 nodes and hyperbolic tangent activation function, a second dropout layer with rate \(\phi_2 = 0.2\), and a final softmax output layer. The sizes of the hidden layers and the dropout rates are likewise choices made by the analyst rather than values dictated by the data.
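As a rough illustration of what a dropout layer with rate 0.3 does during training, here is a minimal NumPy sketch of "inverted" dropout: a random 30% of the activations are set to zero and the survivors are scaled by 1 / (1 - rate) so that the expected activation is unchanged. The function name dropout and the all-ones input are assumptions for illustration. This is a sketch of the idea, not the Keras implementation.

```python
import numpy as np

def dropout(activations, rate, rng):
    """Inverted dropout: zero a fraction `rate` of the values, rescale the rest."""
    keep = rng.random(activations.shape) >= rate   # True with probability 1 - rate
    return activations * keep / (1.0 - rate)

rng = np.random.default_rng(42)
h = np.ones((1000, 128))            # stand-in for the output of a 128-unit layer
h_drop = dropout(h, rate=0.3, rng=rng)

print("fraction zeroed:", np.mean(h_drop == 0.0))   # close to 0.3
print("mean activation:", h_drop.mean())            # close to 1.0
```

Dropout is applied only while the network is being trained; when predictions are computed, all units are kept.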
Set up the network
The following statements set up the network in keras:
modelnn = keras.models.Sequential()
modelnn.add(layers.Input(shape=(784,)))
modelnn.add(layers.Dense(units=128, activation="relu", name="FirstHidden"))
modelnn.add(layers.Dropout(rate=0.3, name="FirstDropOut"))
modelnn.add(layers.Dense(units=64, activation="tanh", name="SecondHidden"))
modelnn.add(layers.Dropout(rate=0.2, name="SecondDropOut"))
modelnn.add(layers.Dense(units=10, activation="softmax", name="Output"))
The summary() function lets us inspect whether we got it all right.
# summary() prints the table itself and returns None, so no print() is needed
modelnn.summary()
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ FirstHidden (Dense)             │ (None, 128)            │       100,480 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ FirstDropOut (Dropout)          │ (None, 128)            │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ SecondHidden (Dense)            │ (None, 64)             │         8,256 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ SecondDropOut (Dropout)         │ (None, 64)             │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ Output (Dense)                  │ (None, 10)             │           650 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 109,386 (427.29 KB)
Trainable params: 109,386 (427.29 KB)
Non-trainable params: 0 (0.00 B)
The total number of parameters in this network is 109,386, a sizeable network but not a huge network.
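The parameter counts in the summary can be checked by hand: a fully connected layer with p inputs and q units has p × q weights plus q biases, and dropout layers have no parameters. The short sketch below redoes that arithmetic; the helper name dense_params is an illustrative assumption.

```python
def dense_params(n_inputs, n_units):
    """Weights plus one bias per unit in a fully connected layer."""
    return n_inputs * n_units + n_units

layer_sizes = [784, 128, 64, 10]   # input, two hidden layers, output
counts = [dense_params(a, b) for a, b in zip(layer_sizes[:-1], layer_sizes[1:])]

print(counts)        # [100480, 8256, 650]
print(sum(counts))   # 109386
```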
Set up the optimization
Next, we add details to the model to specify the fitting algorithm. We fit the model by minimizing the categorical cross-entropy function and monitor the classification accuracy during the iterations.
modelnn.compile(loss="categorical_crossentropy",
                optimizer=keras.optimizers.RMSprop(),
                metrics=["accuracy"])
Fit the model
We are ready to go. The final step is to supply the training data and fit the model. With a batch size of 128 observations, each epoch corresponds to 60,000 / 128 = 468.75, rounded up to 469, gradient evaluations.
history = modelnn.fit(x_train,
                      y_train,
                      epochs=20,
                      batch_size=128,
                      validation_data=(x_test, y_test),
                      verbose=2)

plt.figure(figsize=(12, 6))

plt.subplot(2, 1, 1)
plt.scatter(range(len(history.history['loss'])), history.history['loss'], label='Training Loss')
plt.scatter(range(len(history.history['val_loss'])), history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.scatter(range(len(history.history['accuracy'])), history.history['accuracy'], label='Training Accuracy')
plt.scatter(range(len(history.history['val_accuracy'])), history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')

plt.tight_layout()
plt.show()
Epoch 1/20
469/469 - 1s - 1ms/step - accuracy: 0.8852 - loss: 0.3960 - val_accuracy: 0.9477 - val_loss: 0.1608
Epoch 2/20
469/469 - 0s - 875us/step - accuracy: 0.9439 - loss: 0.1883 - val_accuracy: 0.9657 - val_loss: 0.1122
Epoch 3/20
469/469 - 0s - 868us/step - accuracy: 0.9558 - loss: 0.1482 - val_accuracy: 0.9704 - val_loss: 0.0952
Epoch 4/20
469/469 - 0s - 869us/step - accuracy: 0.9627 - loss: 0.1248 - val_accuracy: 0.9747 - val_loss: 0.0873
Epoch 5/20
469/469 - 0s - 870us/step - accuracy: 0.9667 - loss: 0.1121 - val_accuracy: 0.9745 - val_loss: 0.0828
Epoch 6/20
469/469 - 0s - 875us/step - accuracy: 0.9692 - loss: 0.1021 - val_accuracy: 0.9774 - val_loss: 0.0772
Epoch 7/20
469/469 - 0s - 877us/step - accuracy: 0.9714 - loss: 0.0939 - val_accuracy: 0.9774 - val_loss: 0.0776
Epoch 8/20
469/469 - 0s - 875us/step - accuracy: 0.9726 - loss: 0.0891 - val_accuracy: 0.9783 - val_loss: 0.0742
Epoch 9/20
469/469 - 0s - 874us/step - accuracy: 0.9747 - loss: 0.0819 - val_accuracy: 0.9788 - val_loss: 0.0720
Epoch 10/20
469/469 - 0s - 879us/step - accuracy: 0.9754 - loss: 0.0783 - val_accuracy: 0.9796 - val_loss: 0.0696
Epoch 11/20
469/469 - 0s - 888us/step - accuracy: 0.9776 - loss: 0.0730 - val_accuracy: 0.9807 - val_loss: 0.0663
Epoch 12/20
469/469 - 0s - 881us/step - accuracy: 0.9774 - loss: 0.0719 - val_accuracy: 0.9797 - val_loss: 0.0687
Epoch 13/20
469/469 - 0s - 873us/step - accuracy: 0.9786 - loss: 0.0688 - val_accuracy: 0.9804 - val_loss: 0.0672
Epoch 14/20
469/469 - 0s - 881us/step - accuracy: 0.9797 - loss: 0.0648 - val_accuracy: 0.9819 - val_loss: 0.0656
Epoch 15/20
469/469 - 0s - 874us/step - accuracy: 0.9796 - loss: 0.0635 - val_accuracy: 0.9794 - val_loss: 0.0677
Epoch 16/20
469/469 - 0s - 872us/step - accuracy: 0.9807 - loss: 0.0606 - val_accuracy: 0.9811 - val_loss: 0.0648
Epoch 17/20
469/469 - 0s - 881us/step - accuracy: 0.9815 - loss: 0.0591 - val_accuracy: 0.9801 - val_loss: 0.0688
Epoch 18/20
469/469 - 0s - 886us/step - accuracy: 0.9818 - loss: 0.0579 - val_accuracy: 0.9802 - val_loss: 0.0675
Epoch 19/20
469/469 - 0s - 878us/step - accuracy: 0.9833 - loss: 0.0544 - val_accuracy: 0.9808 - val_loss: 0.0686
Epoch 20/20
469/469 - 0s - 888us/step - accuracy: 0.9824 - loss: 0.0542 - val_accuracy: 0.9797 - val_loss: 0.0698
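The 469/469 counter in the log above is the number of mini-batches per epoch. A minimal sketch of that arithmetic follows, assuming the last, incomplete batch is kept rather than dropped, which is consistent with the count shown in the log:

```python
import math

n_train = 60_000
batch_size = 128

steps_per_epoch = math.ceil(n_train / batch_size)
last_batch = n_train - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch)        # 469
print(last_batch)             # 96 observations in the final batch
print(20 * steps_per_epoch)   # 9380 gradient updates over 20 epochs
```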
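After about 10 epochs the validation accuracy stabilizes near 98%. The training loss continues to decrease, while the validation loss levels off between 0.065 and 0.07. Interestingly, for roughly the first dozen epochs the accuracy and loss in the 10,000 image validation set are better than in the 60,000 image training data set. A likely explanation is that dropout is active while the training metrics are computed but turned off when the validation data are scored, and that the training metrics are averaged over all batches of an epoch while the weights are still improving. By the final epochs the training metrics have overtaken the validation metrics. Considering that the grayscale values are entered into this neural network as 784 numeric input variables without taking into account any spatial arrangement of the pixels on the image, a classification accuracy of 98% on unseen images is quite good. Whether that is sufficient depends on the application.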
As we will see in Chapter 36, neural networks that specialize in the processing of grid-like data such as images easily improve on this performance.
Calculate predicted categories
To calculate the predicted categories for the images in the test data set, we use the predict function. The result of that operation is a vector of 10 predicted probabilities for each observation.
predvals = modelnn.predict(x_test)
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 252us/step
For the first image, the probabilities that its digit belongs to each of the 10 classes are given by this vector:
predvals[0].round(4)
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.], dtype=float32)
predvals[0].argmax()
7
predvals[0].max().round(4)
1.0
The maximum probability is 1.0 in position 7. The image is classified as a “7” (the digits are 0-based, as is Python’s index numbering scheme).
numpy provides the convenience function argmax() to perform this operation; it returns the index of the maximum value:
predcl = predvals.argmax(axis=1)
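To see what argmax(axis=1) does on a matrix of probabilities, consider this small stand-alone sketch. The probabilities and labels are made up for illustration (three images, four classes) and are unrelated to the MNIST results.

```python
import numpy as np

# Made-up probabilities for three images and four classes (illustration only)
probs = np.array([[0.10, 0.70, 0.15, 0.05],
                  [0.25, 0.25, 0.40, 0.10],
                  [0.90, 0.05, 0.03, 0.02]])
labels = np.array([1, 3, 0])          # hypothetical true classes

pred = probs.argmax(axis=1)           # index of the maximum in each row
accuracy = np.mean(pred == labels)    # share of correct predictions

print(pred)       # [1 2 0]
print(accuracy)   # two of the three predictions are correct
```

Applied to the objects in this chapter, and assuming g_test holds the integer digit labels of the test images, np.mean(predcl == g_test) would give the overall test accuracy in the same way.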
Which of the first 500 observations were misclassified?
miscl = [i for i in range(500) if predcl[i] != g_test[i]]
miscl
[124, 217, 233, 247, 259, 321, 340, 381, 445, 449]
print(f"Observed value for obs # {miscl[0]}: {g_test[miscl[0]]}")
print(f"Predicted value for obs # {miscl[0]}: {predcl[miscl[0]]}")
Observed value for obs # 124: 7
Predicted value for obs # 124: 4
The first misclassified observation is #124. The observed digit is 7, the predicted digit is 4. The softmax probabilities for this observation show why the network predicted category 4:
float_formatter = "{:.4f}".format
np.set_printoptions(formatter={'float_kind': float_formatter})
predvals[miscl[0]]
array([0.0000, 0.0003, 0.0000, 0.0002, 0.7088, 0.0000, 0.0000, 0.2899,
0.0001, 0.0007], dtype=float32)
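The probabilities above are concentrated on two classes. The following sketch ranks the classes for this observation, using the values copied from the printed output; the variable names are illustrative.

```python
# Softmax probabilities for observation #124, copied from the output above
p = [0.0000, 0.0003, 0.0000, 0.0002, 0.7088, 0.0000, 0.0000, 0.2899,
     0.0001, 0.0007]

# Sort the class indices by decreasing probability
ranked = sorted(range(len(p)), key=lambda k: p[k], reverse=True)
top1, top2 = ranked[0], ranked[1]

print(top1, p[top1])                  # 4 0.7088
print(top2, p[top2])                  # 7 0.2899
print(round(p[top1] + p[top2], 4))    # 0.9987
```

The correct digit, 7, is the runner-up with a probability of about 0.29; the network is torn between a 4 and a 7 and essentially rules out all other digits.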
We can visualize the data with the imshow function from matplotlib. The next code segment does this for the first observation in the data set and for the first two misclassified observations:
# visualize the digits
def plotIt(idx=0):
    pixels = x_test[idx]            # flattened image with 784 values
    im = pixels.reshape((28, 28))   # back to a 28 x 28 grid for display

    plt.figure(figsize=(6, 6))
    plt.imshow(im, cmap='gray', origin='upper')
    plt.title(f"Index #{idx} -- Image label: {g_test[idx]}, Predicted: {predcl[idx]}")
    plt.show()

# Visualize specific digits
plotIt(0)
plotIt(miscl[0])
plotIt(miscl[1])
Multinomial logistic regression
A 98% accuracy is impressive, but maybe it is not good enough. In applications where the consequences of errors are high, this accuracy might be insufficient. Suppose we are using the trained network to recognize written digits on personal checks. Getting 200 out of 10,000 digits wrong would be unacceptable. Banks would deposit incorrect amounts all the time.
If that is the application for the trained algorithm, we should consider other models for these data. This raises an interesting question: how much did we gain by adding the layers of the network? If this is an effective strategy to increase accuracy then we could consider adding more layers. If not, then maybe we need to research an entirely different network architecture.
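To put the check example in numbers, the sketch below computes the expected number of misread digits and the chance that an amount with several digits is read entirely correctly. It assumes, purely for illustration, that digit errors are independent and that a check amount has five digits; neither assumption comes from the MNIST analysis.

```python
accuracy = 0.98
n_digits_total = 10_000

expected_errors = round((1 - accuracy) * n_digits_total)
print(expected_errors)             # 200

# Hypothetical: a five-digit amount, errors independent across digits
digits_per_check = 5
p_all_correct = accuracy ** digits_per_check
print(round(p_all_correct, 4))     # 0.9039
```

Under these assumptions roughly one in ten checks would contain at least one misread digit.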
Before trying deeper alternatives we can establish one performance benchmark by removing the hidden layers and training what essentially is a single layer perceptron (Section 32.1). This model has an input layer and an output layer. In terms of the keras syntax it is specified with a single layer:
modellr = keras.models.Sequential()
modellr.add(layers.Input(shape=(784,)))
modellr.add(layers.Dense(units=10, activation='softmax'))

modellr.summary()
Model: "sequential_2"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ dense_1 (Dense)                 │ (None, 10)             │         7,850 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
Total params: 7,850 (30.66 KB)
Trainable params: 7,850 (30.66 KB)
Non-trainable params: 0 (0.00 B)
This is essentially a multinomial logistic regression model with a 10-category target variable and 784 input variables. The model is much smaller than the previous network (it has only 7,850 parameters) but is huge if we think of it as a multinomial logistic regression model. Many software packages for multinomial regression would struggle to fit a model of this size. When articulated as a neural network, training such a model is actually a breeze.
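The connection to multinomial logistic regression can be seen by writing out what the single dense layer computes: a linear predictor for each of the 10 categories, followed by a softmax. The NumPy sketch below does this with randomly generated weights and inputs. The names W, b, and x and their values are assumptions for illustration, not the fitted model.

```python
import numpy as np

rng = np.random.default_rng(0)

n_inputs, n_classes = 784, 10
W = rng.normal(scale=0.01, size=(n_inputs, n_classes))   # hypothetical weights
b = np.zeros(n_classes)                                   # hypothetical biases

x = rng.random((3, n_inputs))           # three fake "images" with values in [0, 1)

# Linear predictor for each class, followed by a row-wise softmax
eta = x @ W + b
eta -= eta.max(axis=1, keepdims=True)   # stabilize the exponentials
probs = np.exp(eta) / np.exp(eta).sum(axis=1, keepdims=True)

print(W.size + b.size)     # 7850 parameters: 784 * 10 weights + 10 biases
print(probs.sum(axis=1))   # each row sums to 1
```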
We proceed just as before.
modellr.compile(loss="categorical_crossentropy",
                optimizer=keras.optimizers.RMSprop(),
                metrics=["accuracy"])

# Train the model
history_lr = modellr.fit(x_train,
                         y_train,
                         epochs=20,
                         batch_size=128,
                         validation_data=(x_test, y_test),
                         verbose=2)

# Plot training history
plt.figure(figsize=(12, 6))

plt.subplot(2, 1, 1)
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['loss'], label='Training Loss')
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(2, 1, 2)
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['accuracy'], label='Training Accuracy')
plt.scatter(range(len(history_lr.history['loss'])), history_lr.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
Epoch 1/20
469/469 - 0s - 737us/step - accuracy: 0.8531 - loss: 0.5938 - val_accuracy: 0.9078 - val_loss: 0.3434
Epoch 2/20
469/469 - 0s - 426us/step - accuracy: 0.9071 - loss: 0.3331 - val_accuracy: 0.9170 - val_loss: 0.3018
Epoch 3/20
469/469 - 0s - 427us/step - accuracy: 0.9147 - loss: 0.3045 - val_accuracy: 0.9209 - val_loss: 0.2875
Epoch 4/20
469/469 - 0s - 426us/step - accuracy: 0.9187 - loss: 0.2912 - val_accuracy: 0.9223 - val_loss: 0.2803
Epoch 5/20
469/469 - 0s - 424us/step - accuracy: 0.9213 - loss: 0.2831 - val_accuracy: 0.9237 - val_loss: 0.2759
Epoch 6/20
469/469 - 0s - 434us/step - accuracy: 0.9233 - loss: 0.2775 - val_accuracy: 0.9248 - val_loss: 0.2730
Epoch 7/20
469/469 - 0s - 431us/step - accuracy: 0.9245 - loss: 0.2733 - val_accuracy: 0.9265 - val_loss: 0.2711
Epoch 8/20
469/469 - 0s - 429us/step - accuracy: 0.9257 - loss: 0.2700 - val_accuracy: 0.9260 - val_loss: 0.2697
Epoch 9/20
469/469 - 0s - 426us/step - accuracy: 0.9265 - loss: 0.2674 - val_accuracy: 0.9260 - val_loss: 0.2686
Epoch 10/20
469/469 - 0s - 430us/step - accuracy: 0.9273 - loss: 0.2651 - val_accuracy: 0.9264 - val_loss: 0.2678
Epoch 11/20
469/469 - 0s - 426us/step - accuracy: 0.9280 - loss: 0.2632 - val_accuracy: 0.9262 - val_loss: 0.2672
Epoch 12/20
469/469 - 0s - 429us/step - accuracy: 0.9286 - loss: 0.2616 - val_accuracy: 0.9260 - val_loss: 0.2668
Epoch 13/20
469/469 - 0s - 426us/step - accuracy: 0.9291 - loss: 0.2601 - val_accuracy: 0.9262 - val_loss: 0.2664
Epoch 14/20
469/469 - 0s - 427us/step - accuracy: 0.9292 - loss: 0.2588 - val_accuracy: 0.9266 - val_loss: 0.2661
Epoch 15/20
469/469 - 0s - 429us/step - accuracy: 0.9296 - loss: 0.2577 - val_accuracy: 0.9265 - val_loss: 0.2659
Epoch 16/20
469/469 - 0s - 430us/step - accuracy: 0.9300 - loss: 0.2566 - val_accuracy: 0.9268 - val_loss: 0.2658
Epoch 17/20
469/469 - 0s - 429us/step - accuracy: 0.9303 - loss: 0.2557 - val_accuracy: 0.9268 - val_loss: 0.2657
Epoch 18/20
469/469 - 0s - 427us/step - accuracy: 0.9307 - loss: 0.2548 - val_accuracy: 0.9268 - val_loss: 0.2656
Epoch 19/20
469/469 - 0s - 428us/step - accuracy: 0.9310 - loss: 0.2540 - val_accuracy: 0.9269 - val_loss: 0.2656
Epoch 20/20
469/469 - 0s - 428us/step - accuracy: 0.9313 - loss: 0.2533 - val_accuracy: 0.9271 - val_loss: 0.2656
Even with just a single layer, the model performs quite well; its accuracy is around 93%. Adding the hidden layers in the previous ANN did improve the accuracy. On the other hand, it took more than 100,000 extra parameters to move from 93% to 98% accuracy.
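The trade-off can be quantified from numbers already reported in this chapter: the parameter counts from the two model summaries and the validation accuracies in the final epoch of each training log. The sketch below only does that arithmetic; the variable names are illustrative.

```python
params_ann, params_lr = 109_386, 7_850   # from the two model summaries
acc_ann, acc_lr = 0.9797, 0.9271         # final-epoch validation accuracy

n_test = 10_000
extra_params = params_ann - params_lr
errors_ann = round((1 - acc_ann) * n_test)
errors_lr = round((1 - acc_lr) * n_test)

print(extra_params)            # 101536
print(errors_lr, errors_ann)   # 729 203
```

In terms of misclassified test images, the hidden layers reduce the error count from 729 to 203.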