Implementation of a Modern Machine Learning Library

If you want to understand how a modern machine learning library works, there is no better alternative than Flux.


What does a Machine Learning Library do?

Before getting into the details, let's take a bird's-eye view and talk about the different tasks an ML library has to carry out.

  1. A way to define a loss/error/cost function. A loss function measures the difference between the output we want from the model and the actual prediction the model makes.
  2. A way to calculate the gradient of the loss function with respect to its parameters. The parameters are the ones used to define the model.
  3. A way of defining different optimizers, or training strategies.
  4. A training function which tries to minimize the loss function by adjusting the model parameters using the calculated gradient and the optimizer.
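To make these four pieces concrete before we look at Flux itself, here is a minimal sketch in plain Julia (no Flux), with hand-derived gradients standing in for the automatic differentiation a real library would use. All names here are made up for illustration:

```julia
# A toy linear model y = w*x + b with a single weight and bias.
w, b = 0.0, 0.0

# 1. Loss: mean squared error between prediction and target.
loss(x, y) = sum(((w .* x .+ b) .- y).^2) / length(x)

# 2. Gradient of the loss with respect to w and b (derived by hand here;
#    a real library computes this with automatic differentiation).
function gradients(x, y)
    err = (w .* x .+ b) .- y
    (2sum(err .* x) / length(x), 2sum(err) / length(x))
end

# 3. Optimizer: plain gradient descent with a fixed learning rate.
eta = 0.1

# 4. Training loop: repeatedly step the parameters downhill.
x, y = [1.0, 2.0, 3.0], [3.0, 5.0, 7.0]   # data sampled from y = 2x + 1
for _ in 1:2000
    gw, gb = gradients(x, y)
    global w -= eta * gw
    global b -= eta * gb
end
```

After training, w and b should be close to 2 and 1, the coefficients the data was generated from.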

What is a Mathematical Model?

A mathematical model is a model defined in mathematics, or in our case in code. Its purpose is much the same as for a physical model: a simplified representation of some system we want to study or make predictions about. Hooke's law F = -kx, for instance, is a mathematical model of how a spring responds to being stretched.

What is a Model in Flux?

In the Flux machine learning library, a model is simply a function which takes one input. The input is a matrix that needs to be organized in a particular way: every row represents a different property. In machine learning we refer to these properties as features. Another way to look at it is that you can simulate a function with multiple inputs by treating each row as a separate input argument.

# A simple linear model
W = rand(2, 5)
b = rand(2)

model(x) = W*x .+ b

# Or a multi-layer model built with Chain
model = Chain(
    Dense(784, 64, relu),
    Dense(64, 32, relu),
    Dense(32, 10))
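To see how the input matrix is organized, here is a self-contained version of the first, linear model above (plain Julia, no Flux needed). Each row of the input is a feature and each column is a separate sample:

```julia
using Random
Random.seed!(0)                # make the example reproducible

W = rand(2, 5)                 # weights: 2 outputs, 5 input features
b = rand(2)                    # bias: one entry per output

model(x) = W*x .+ b            # same model shape as in the text

x = rand(5, 3)                 # 5 features (rows) for 3 samples (columns)
y = model(x)

size(y)                        # (2, 3): 2 outputs for each of the 3 samples
```

Because `.+` broadcasts the bias vector across columns, the same b is added to every sample.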

Model Parameters

The parameters of the model are variables you can adjust to change the behavior of your model. In our first example, the parameters were W and b. W is referred to as the weights, and will typically be a matrix as well. b is referred to as the bias. The bias is not affected by the input.

# For the simple model: collect W and b explicitly
params = Flux.params(W, b)
Flux.train!(loss, params, data, optimizer)

# For a Chain: collect the parameters from the model itself
params = Flux.params(model)
Flux.train!(loss, params, data, optimizer)

Flux Implementation of Model Params

Params contains an array, order, holding the list of parameters in the order they were added. The params member is primarily used to check that a parameter (typically an array object) has not already been added to the Params object.

struct Params
    order::Buffer{Any, Vector{Any}}
    params::IdSet{Any}
    Params() = new(Buffer([], false), IdSet())
end

function Base.push!(ps::Params, x)
    if !(x in ps.params)
        push!(ps.order, x)
        push!(ps.params, x)
    end
    return ps
end

Params(xs) = push!(Params(), xs...)

function params(m...)
    ps = Params()
    params!(ps, m)
    return ps
end
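The deduplication behavior is easy to see with a simplified, self-contained stand-in for Params (plain Julia, using a Vector and an IdDict in place of Flux's Buffer and IdSet; the name SimpleParams is made up for this sketch):

```julia
struct SimpleParams
    order::Vector{Any}            # parameters in insertion order
    seen::IdDict{Any,Nothing}     # identity-based set for fast membership tests
    SimpleParams() = new(Any[], IdDict{Any,Nothing}())
end

function Base.push!(ps::SimpleParams, x)
    if !haskey(ps.seen, x)        # only add each array object once
        push!(ps.order, x)
        ps.seen[x] = nothing
    end
    return ps
end

W = rand(2, 5)
b = rand(2)

ps = SimpleParams()
push!(ps, W); push!(ps, b); push!(ps, W)   # W pushed twice

length(ps.order)                  # 2: the duplicate W was ignored
```

Identity-based membership matters here: two weight arrays can be equal in value but must still be tracked as distinct parameters.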

Updating Weights and Optimizers

The parameters of our model are usually referred to as weights in machine learning. The learning strategy is often referred to as an optimizer.

function update!(opt, x, x̄)
    x .-= apply!(opt, x, x̄)
end

mutable struct Descent
    eta::Float64
end

function apply!(o::Descent, x, Δ)
    Δ .*= o.eta
end

For Descent, update! is thus equivalent to:

function update!(opt, x, x̄)
    x .-= (x̄ .*= opt.eta)
end

which we can unpack into:

function update!(opt, x, x̄)
    x̄ .*= opt.eta
    x .-= x̄
end

or, without mutating the gradient:

function update!(opt, x, x̄)
    x .-= (x̄ .* opt.eta)
end

Finally, there is a version which updates every parameter in a Params collection, skipping parameters with no recorded gradient:

function update!(opt, xs::Params, gs)
    for x in xs
        gs[x] == nothing && continue
        update!(opt, x, gs[x])
    end
end
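Here is the same update step as a self-contained sketch on a plain array (names like MyDescent are made up to avoid clashing with Flux's own types):

```julia
mutable struct MyDescent
    eta::Float64                  # learning rate
end

# Scale the gradient in place by the learning rate ...
apply!(o::MyDescent, x, Δ) = (Δ .*= o.eta)

# ... and subtract it from the parameter, also in place.
function update!(opt, x, x̄)
    x .-= apply!(opt, x, x̄)
end

x = [1.0, 2.0]                    # current parameter values
x̄ = [0.5, 1.0]                    # gradient of the loss at x
update!(MyDescent(0.1), x, x̄)

x                                 # [0.95, 1.9]
```

Note that both the parameter and the gradient array are mutated in place; that is why Flux's apply! takes mutable arrays rather than returning fresh ones.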

Training Function

Normally the user does not call the update! function directly. Instead, update! gets called by the training function train!.

function train!(loss, ps, data, opt; cb = () -> ())
    ps = Params(ps)
    cb = runall(cb)
    for d in data
        gs = gradient(ps) do
            loss(d...)
        end
        update!(opt, ps, gs)
        cb()
    end
end
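The whole pipeline, from loss to training loop, can be sketched end to end in plain Julia. Gradients are supplied analytically here instead of being computed by Zygote, and mytrain! is a made-up stand-in for Flux's train!:

```julia
# Toy stand-in for Flux's train! loop: one gradient step per data point.
eta = 0.1

W = zeros(1, 1)
b = zeros(1)

model(x) = W*x .+ b
loss(x, y) = sum((model(x) .- y).^2)

# Analytic gradients of the squared-error loss for this linear model;
# Flux would obtain these from automatic differentiation instead.
function grads(x, y)
    err = model(x) .- y
    (2 .* err * x', 2 .* err)     # (dL/dW, dL/db)
end

function mytrain!(data)
    for (x, y) in data
        gW, gb = grads(x, y)
        W .-= eta .* gW           # the update! step from above
        b .-= eta .* gb
    end
end

# Data sampled from y = 3x + 2.
data = [([x], [3x + 2]) for x in 0.0:0.25:1.0]
for epoch in 1:500
    mytrain!(data)
end
```

After a few hundred passes over the data, W and b should recover the generating coefficients 3 and 2.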

Final Remarks

This story is a bit of a mess, because I realized upon writing it that the scope of this was too large. But rather than not publishing my writing I decided to put it out there because people who know what they are looking for may find sections of this story useful.

