Introduction

Online machine learning is a branch of machine learning in which training samples arrive sequentially rather than as a single batch, and the model parameters are updated as new data arrive. Several tools, such as Vowpal Wabbit and scikit-learn, support online learning. In this article we look at an alternative approach using JAGS.

JAGS (Just Another Gibbs Sampler) is a DSL for model specification and an MCMC sampler. It is used to build hierarchical Bayesian models.

Simple example

In this article we look at a simple example: inferring the probability of heads, p, from a sequence of coin tosses. For the purposes of this example, the data arrive in two batches, data1 and data2, and the model is updated sequentially. First, the model specification in JAGS: each coin toss is Bernoulli distributed with parameter p, and the prior on p is Beta(a, b), which reduces to the uniform distribution when a = b = 1.

model <- "model {
            for (i in 1:N) {
              toss[i] ~ dbern(p)
            }
            p ~ dbeta(a,b)
          }"

Set the initial data, compile the model, and draw samples from the posterior distribution for the first batch of data.

library(rjags)

# Beta(1, 1) prior on p, i.e. uniform on [0, 1]
a1 <- 1
b1 <- 1
data1 <- list(toss = c(0, 0, 1, 1), N = 4, a = a1, b = b1)
# compile the model
model1 <- jags.model(textConnection(model), data1)
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 4
##    Unobserved stochastic nodes: 1
##    Total graph size: 8
## 
## Initializing model
# sample from the posterior p|toss[1:N]
out1 <- jags.samples(model1, "p", n.iter = 1000)

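The beta distribution is the conjugate prior of the Bernoulli distribution, so the posterior here is also beta distributed. We therefore fit a beta distribution to the posterior samples by maximum likelihood and use the fitted shape parameters as the prior for the next batch. In general, choosing a parametric family for the posterior is not always appropriate, because the posterior of a non-conjugate model need not follow any standard distribution.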

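library(MASS)
# fit a beta distribution to the posterior draws by maximum likelihood
fit <- fitdistr(as.numeric(out1$p), "beta", start = list(shape1 = a1, shape2 = b1))
# the fitted shape parameters become the prior for the next batch
a2 <- fit$estimate[[1]]
b2 <- fit$estimate[[2]]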
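
For this particular model the update is also available in closed form, which gives a useful sanity check on the fitted values. The derivation below is a standard result rather than part of the JAGS workflow. The symbols x_i (the i-th toss) and k (the number of heads) are notation introduced here for illustration.

```latex
% Prior: p ~ Beta(a, b); data: x_1, ..., x_N in {0, 1}; k = sum_i x_i (number of heads)
\[
\pi(p \mid x_{1:N})
  \;\propto\; \underbrace{p^{k}(1-p)^{N-k}}_{\text{likelihood}}
  \;\underbrace{p^{a-1}(1-p)^{b-1}}_{\text{prior}}
  \;=\; p^{a+k-1}(1-p)^{b+N-k-1}
\]
\[
\Longrightarrow\quad p \mid x_{1:N} \sim \mathrm{Beta}(a + k,\; b + N - k)
\]
% First batch: a = b = 1, N = 4, k = 2  =>  Beta(3, 3), with posterior mean 3 / (3 + 3) = 0.5
```

With data1 the exact posterior is therefore Beta(3, 3), so a2 and b2 should both come out close to 3, up to Monte Carlo and fitting error.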

Update the model with the new data, using the fitted beta parameters as the prior, and repeat the process.

data2 <- list(toss=c(0,1,1,0), N=4, a=a2, b=b2)
model2 <- jags.model(textConnection(model), data2)
## Compiling model graph
##    Resolving undeclared variables
##    Allocating nodes
## Graph information:
##    Observed stochastic nodes: 4
##    Unobserved stochastic nodes: 1
##    Total graph size: 8
## 
## Initializing model
out2 <- jags.samples(model2, "p", n.iter=1000)
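
Because this toy model is conjugate, the result of the two-step procedure can be checked against the exact answer. The following is a minimal sketch in base R, not part of the original workflow; `beta_update` is a hypothetical helper introduced here, and it assumes the same two batches of tosses and the uniform Beta(1, 1) starting prior used above.

```r
# Exact conjugate update for a Beta(a, b) prior and a vector of 0/1 tosses.
# NOTE: beta_update is a hypothetical helper, not a function from rjags or MASS.
beta_update <- function(prior, toss) {
  k <- sum(toss)  # number of heads in this batch
  c(a = prior[["a"]] + k, b = prior[["b"]] + length(toss) - k)
}

prior <- c(a = 1, b = 1)   # uniform prior, as in the article
toss1 <- c(0, 0, 1, 1)     # first batch (data1)
toss2 <- c(0, 1, 1, 0)     # second batch (data2)

post1 <- beta_update(prior, toss1)             # posterior after batch 1
post2 <- beta_update(post1, toss2)             # posterior after batch 2
batch <- beta_update(prior, c(toss1, toss2))   # all eight tosses at once

print(post1)
print(post2)

# Sequential updating and a single batch update give the same posterior
stopifnot(all(post2 == batch))

# Exact posterior mean and 95% credible interval after both batches
post_mean <- post2[["a"]] / (post2[["a"]] + post2[["b"]])
post_ci   <- qbeta(c(0.025, 0.975), post2[["a"]], post2[["b"]])
```

The sample mean and quantiles of `as.numeric(out2$p)` can then be compared with `post_mean` and `post_ci`. Any gap reflects Monte Carlo error plus the error introduced by fitting a beta distribution between batches.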