Get Started with Keras 3.0 + MLflow

This is an end-to-end tutorial on training an MNIST classifier with Keras 3.0 and logging the results with MLflow. It demonstrates how to use mlflow.keras.MlflowCallback, and how to subclass it to implement custom logging logic.

Keras is a high-level API designed to be simple, flexible, and powerful, allowing everyone from beginners to advanced users to quickly build, train, and evaluate models. Keras 3.0, also known as Keras Core, is a full rewrite of the Keras codebase that rebases it on top of a modular backend architecture. It makes it possible to run Keras workflows on top of arbitrary frameworks, starting with TensorFlow, JAX, and PyTorch.

Install Packages

pip install -q keras mlflow jax jaxlib torch tensorflow

Import Packages / Configure Backend

Keras 3.0 is inherently multi-backend, so you need to set the KERAS_BACKEND environment variable before importing the package.

import os

# You can use 'tensorflow', 'torch' or 'jax' as backend. Make sure to set the environment variable before importing.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import numpy as np

import mlflow
Using TensorFlow backend
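
To confirm the configuration took effect, you can query the active backend after importing; keras.backend.backend() returns the name of the backend in use (here, "tensorflow").

print(keras.backend.backend())  # -> tensorflow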

Load Dataset

We will use the MNIST dataset, a collection of handwritten digit images, for an image classification task. There are 10 classes corresponding to the digits 0-9.

(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Add a channel dimension: (28, 28) -> (28, 28, 1)
x_train = np.expand_dims(x_train, axis=3)
x_test = np.expand_dims(x_test, axis=3)
x_train[0].shape
(28, 28, 1)
# Visualize Dataset
import matplotlib.pyplot as plt

grid = 3
fig, axes = plt.subplots(grid, grid, figsize=(6, 6))
for i in range(grid):
    for j in range(grid):
        axes[i][j].imshow(x_train[i * grid + j])
        axes[i][j].set_title(f"label={y_train[i * grid + j]}")
plt.tight_layout()

Build Model

We will use the Keras 3.0 Sequential API to build a simple CNN.

NUM_CLASSES = 10
INPUT_SHAPE = (28, 28, 1)


def initialize_model():
    return keras.Sequential(
        [
            keras.Input(shape=INPUT_SHAPE),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(NUM_CLASSES, activation="softmax"),
        ]
    )


model = initialize_model()
model.summary()
Model: "sequential"
鈹忊攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹斥攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹斥攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲敁
鈹�span style="font-weight: bold"> Layer (type)                    鈹�span style="font-weight: bold"> Output Shape              鈹�span style="font-weight: bold">    Param # 鈹�
鈹♀攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈺団攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈺団攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲攣鈹佲敥
鈹� conv2d (Conv2D)                 鈹� (None, 26, 26, 32)        鈹�        320 鈹�
鈹溾攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�
鈹� conv2d_1 (Conv2D)               鈹� (None, 24, 24, 32)        鈹�      9,248 鈹�
鈹溾攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�
鈹� conv2d_2 (Conv2D)               鈹� (None, 22, 22, 32)        鈹�      9,248 鈹�
鈹溾攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�
鈹� global_average_pooling2d        鈹� (None, 32)                鈹�          0 鈹�
鈹� (GlobalAveragePooling2D)        鈹�                           鈹�            鈹�
鈹溾攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹尖攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�
鈹� dense (Dense)                   鈹� (None, 10)                鈹�        330 鈹�
鈹斺攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹粹攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹粹攢鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�鈹�
 Total params: 19,146 (74.79 KB)
 Trainable params: 19,146 (74.79 KB)
 Non-trainable params: 0 (0.00 B)

Train Model (Default Callback)

We will fit the model on the dataset, using MLflow's mlflow.keras.MlflowCallback to log metrics during training.

BATCH_SIZE = 64  # adjust this based on the memory of your machine
EPOCHS = 3

Log Per Epoch

An epoch is defined as one full pass through the entire training dataset. By default, the callback logs metrics once per epoch.

model = initialize_model()

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

run = mlflow.start_run()
model.fit(
    x_train,
    y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.1,
    callbacks=[mlflow.keras.MlflowCallback(run)],
)
mlflow.end_run()
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 30s 34ms/step - accuracy: 0.5922 - loss: 1.2862 - val_accuracy: 0.9427 - val_loss: 0.2075
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 28s 33ms/step - accuracy: 0.9330 - loss: 0.2286 - val_accuracy: 0.9348 - val_loss: 0.2020
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 28s 33ms/step - accuracy: 0.9499 - loss: 0.1671 - val_accuracy: 0.9558 - val_loss: 0.1491

Log Results

The callback logs parameters, metrics, and artifacts for the run to the MLflow dashboard.

run page
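
If you are running MLflow locally, the dashboard can be started with the mlflow ui command. You can also retrieve the logged values programmatically; below is a minimal sketch, assuming the run object from the training cell above is still in scope:

from mlflow.tracking import MlflowClient

client = MlflowClient()
# Fetch the full history of the "loss" metric recorded by the callback
for metric in client.get_metric_history(run.info.run_id, "loss"):
    print(metric.step, metric.value)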

Log Per Batch

Within each epoch, the training dataset is broken into batches based on the defined BATCH_SIZE. If we set log_every_epoch=False to turn off per-epoch logging and log_every_n_steps=5 to log every 5 batches, the callback will log on a per-batch basis instead.

model = initialize_model()

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

with mlflow.start_run() as run:
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.1,
        callbacks=[mlflow.keras.MlflowCallback(run, log_every_epoch=False, log_every_n_steps=5)],
    )
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 30s 34ms/step - accuracy: 0.6151 - loss: 1.2100 - val_accuracy: 0.9373 - val_loss: 0.2144
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 29s 34ms/step - accuracy: 0.9274 - loss: 0.2459 - val_accuracy: 0.9608 - val_loss: 0.1338
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 28s 34ms/step - accuracy: 0.9477 - loss: 0.1738 - val_accuracy: 0.9577 - val_loss: 0.1454

Log Results

If we log per epoch, we will only have three data points, since there are only 3 epochs:

log per epoch

By logging per batch, we get far more data points (with 844 batches per epoch and logging every 5 batches, roughly 168 per epoch), but they can be noisier:

log per batch
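
Customize MLflow Logging

As mentioned in the introduction, mlflow.keras.MlflowCallback can be subclassed to implement custom logging logic. As an illustration, the subclass below overrides on_batch_end so that the raw batch metrics are recorded every log_every_n_steps batches.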

class MlflowCallbackLogPerBatch(mlflow.keras.MlflowCallback):
    # Record the batch-level metrics every `log_every_n_steps` batches.
    def on_batch_end(self, batch, logs=None):
        if self.log_every_n_steps is None or logs is None:
            return
        if (batch + 1) % self.log_every_n_steps == 0:
            self.metrics_logger.record_metrics(logs, self._log_step)
            self._log_step += self.log_every_n_steps

model = initialize_model()

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
)

with mlflow.start_run() as run:
    model.fit(
        x_train,
        y_train,
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        validation_split=0.1,
        callbacks=[MlflowCallbackLogPerBatch(run, log_every_epoch=False, log_every_n_steps=5)],
    )
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 29s 34ms/step - accuracy: 0.5645 - loss: 1.4105 - val_accuracy: 0.9187 - val_loss: 0.2826
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 29s 34ms/step - accuracy: 0.9257 - loss: 0.2615 - val_accuracy: 0.9602 - val_loss: 0.1368
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 29s 34ms/step - accuracy: 0.9456 - loss: 0.1800 - val_accuracy: 0.9678 - val_loss: 0.1037

Evaluation

Similar to training, you can use the callback to log the evaluation result.

with mlflow.start_run() as run:
    model.evaluate(x_test, y_test, callbacks=[mlflow.keras.MlflowCallback(run)])
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.9541 - loss: 0.1487