
TensorFlow.js and linear regression

TensorFlow.js is a JavaScript machine learning library and part of the larger TensorFlow ecosystem for building ML-powered applications. It provides tools for training and deploying machine learning and deep learning models, and it runs both in the web browser and in Node.js.

Simple linear regression is a widely used data-analysis technique that models the relationship between two variables, an independent variable x and a dependent variable y, as a straight line. As a special case of supervised learning, it is one of the simplest tasks that can be performed with TensorFlow.js.
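Simple linear regression fits a line y = w·x + b to the data. As a reference point, the ordinary least-squares line can be computed in closed form; the sketch below does this in plain JavaScript without TensorFlow.js, and the function name `fitLeastSquares` is hypothetical. A TensorFlow.js model trained with a mean-squared-error loss should approach the same line iteratively rather than computing it directly.

```javascript
// Hypothetical helper: fit y = w * x + b by ordinary least squares (closed form).
// Plain JavaScript for illustration; it does not use TensorFlow.js.
function fitLeastSquares(xs, ys) {
  if (xs.length !== ys.length || xs.length < 2) {
    throw new Error('need at least two (x, y) pairs of equal length');
  }
  const n = xs.length;
  const meanX = xs.reduce((s, x) => s + x, 0) / n;
  const meanY = ys.reduce((s, y) => s + y, 0) / n;

  // Slope = covariance(x, y) / variance(x); the 1/n factors cancel out.
  let sxy = 0;
  let sxx = 0;
  for (let i = 0; i < n; i++) {
    const dx = xs[i] - meanX;
    sxy += dx * (ys[i] - meanY);
    sxx += dx * dx;
  }
  if (sxx === 0) throw new Error('x values must not all be equal');
  const w = sxy / sxx;
  // The least-squares line always passes through (meanX, meanY).
  const b = meanY - w * meanX;
  return { w, b };
}

// Example: points lying exactly on y = 2x + 1.
const { w, b } = fitLeastSquares([0, 1, 2, 3], [1, 3, 5, 7]);
console.log(w, b); // 2 1
```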

A typical ML-based solution includes the following steps:

  1. Load and prepare the input data.
  2. Define the model architecture.
  3. Compile and train the model.
  4. Use the model to make predictions.
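To make the four steps concrete without any library, the sketch below walks through them in plain JavaScript, fitting y = w·x + b by batch gradient descent on the mean squared error. It only illustrates roughly what compiling and training involve; in TensorFlow.js, `model.compile` and `model.fit` handle step 3 with more general optimizers. The toy data, learning rate, and epoch count are assumptions chosen for this example.

```javascript
// 1. Load and prepare the input data (toy points on y = 3x - 2).
const xs = [0, 0.25, 0.5, 0.75, 1];
const ys = xs.map((x) => 3 * x - 2);

// 2. Define the model: a single weight and bias, y = w * x + b.
let w = 0;
let b = 0;
const predict = (x) => w * x + b;

// 3. "Compile" (pick a loss and learning rate) and train.
//    Loss: mean squared error; optimizer: plain batch gradient descent.
const learningRate = 0.5; // assumed value, suitable for this toy data
const epochs = 2000;      // assumed value, suitable for this toy data
for (let epoch = 0; epoch < epochs; epoch++) {
  let gradW = 0;
  let gradB = 0;
  for (let i = 0; i < xs.length; i++) {
    const error = predict(xs[i]) - ys[i];
    // d(error^2)/dw = 2 * error * x and d(error^2)/db = 2 * error,
    // averaged over the batch.
    gradW += (2 * error * xs[i]) / xs.length;
    gradB += (2 * error) / xs.length;
  }
  w -= learningRate * gradW;
  b -= learningRate * gradB;
}

// 4. Use the model to make predictions.
console.log(w.toFixed(3), b.toFixed(3), predict(2).toFixed(3)); // 3.000 -2.000 4.000
```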

In a TensorFlow.js-based application, those steps correspond to the following sections of the code:

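// 1. Load and prepare the input data
const { inputs, labels } = await getData()
// 2. Define the model architecture
const model = createModel()
// 3. Compile and train the model
model.compile(...)
await model.fit(inputs, labels, ...)
// 4. Use the model to make predictions
model.predict(...)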
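`getData()` is not defined in the outline above. One common preparation step is min-max normalization, which rescales values into the range [0, 1] and often helps gradient-based training converge; predictions made on normalized data then have to be mapped back to the original units. The sketch below shows that step in plain JavaScript; the helper names `minMaxNormalize` and `denormalize` are hypothetical and are not part of TensorFlow.js, and the input numbers are made up.

```javascript
// Hypothetical helper: rescale values into [0, 1] with min-max normalization.
// Returns the normalized values plus the min/max needed to undo the scaling.
function minMaxNormalize(values) {
  const min = Math.min(...values);
  const max = Math.max(...values);
  const range = max - min;
  if (range === 0) throw new Error('cannot normalize constant data');
  return { normalized: values.map((v) => (v - min) / range), min, max };
}

// Hypothetical helper: map normalized values back to the original units.
function denormalize(normalized, min, max) {
  return normalized.map((v) => v * (max - min) + min);
}

// Example with made-up input values.
const { normalized, min, max } = minMaxNormalize([50, 100, 150, 250]);
console.log(normalized); // → [0, 0.25, 0.5, 1]
console.log(denormalize(normalized, min, max)); // → [50, 100, 150, 250]
```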

In the case of simple linear regression, the model architecture can be defined as follows:

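import * as tf from '@tensorflow/tfjs'
const model = tf.sequential()
// One input feature mapped to one unit: y = w1 * x + b1
model.add(tf.layers.dense({ inputShape: [1], units: 1, useBias: true }))
// No activation is set, so the stack stays linear: y = w2 * (w1 * x + b1) + b2
model.add(tf.layers.dense({ units: 1, useBias: true }))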
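Because neither dense layer specifies an activation, no activation is applied, so the two stacked layers compute y = w2·(w1·x + b1) + b2 = (w2·w1)·x + (w2·b1 + b2). That is still a single straight line, so one dense layer with `inputShape: [1]` and `units: 1` can express the same family of functions. The plain-JavaScript check below illustrates this with arbitrary, made-up weights; it does not call TensorFlow.js.

```javascript
// Two stacked one-unit layers with no activation.
// The weights are arbitrary values chosen only for illustration.
const layer1 = { w: 0.8, b: -0.5 };
const layer2 = { w: 1.5, b: 2.0 };

const stacked = (x) => layer2.w * (layer1.w * x + layer1.b) + layer2.b;

// Equivalent single layer: w = w2 * w1, b = w2 * b1 + b2.
const merged = { w: layer2.w * layer1.w, b: layer2.w * layer1.b + layer2.b };
const single = (x) => merged.w * x + merged.b;

for (const x of [-2, 0, 1, 3.5]) {
  // Both forms agree up to floating-point rounding.
  console.log(x, stacked(x).toFixed(6), single(x).toFixed(6));
}
```

The second layer therefore adds parameters without adding expressive power; it does no harm for a linear fit, but a single layer is the minimal architecture for this task.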