TensorFlow.js and linear regression
TensorFlow.js is a JavaScript machine learning library and part of the larger TensorFlow ecosystem for building ML-powered applications. It provides functionality for training and deploying machine learning and deep learning models, and it runs both in web browsers and in Node.js.
Simple linear regression is one of the most fundamental data-analysis techniques. It models the relationship between two variables, an explanatory (independent) variable x and a response (dependent) variable y, as a straight line. As a basic form of supervised learning, it is one of the simplest tasks that can be performed with TensorFlow.js.
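To make the idea concrete, here is a minimal plain-JavaScript sketch that does not use TensorFlow.js at all. It fits the line y = w·x + b by ordinary least squares, using the closed-form estimates w = cov(x, y) / var(x) and b = mean(y) − w·mean(x). The function name `fitLine` is a hypothetical helper introduced only for illustration.

```javascript
// Ordinary least squares for one explanatory variable x and one response y.
// Returns the slope w and intercept b of the best-fitting line y = w * x + b.
function fitLine(xs, ys) {
  if (xs.length !== ys.length || xs.length < 2) {
    throw new Error('xs and ys must have the same length (at least 2)')
  }
  const n = xs.length
  const meanX = xs.reduce((sum, x) => sum + x, 0) / n
  const meanY = ys.reduce((sum, y) => sum + y, 0) / n

  // Accumulate n times the covariance of x and y and n times the variance of x;
  // the common factor n cancels in the slope.
  let covXY = 0
  let varX = 0
  for (let i = 0; i < n; i++) {
    covXY += (xs[i] - meanX) * (ys[i] - meanY)
    varX += (xs[i] - meanX) ** 2
  }
  if (varX === 0) {
    throw new Error('xs must not all be equal')
  }

  const slope = covXY / varX
  const intercept = meanY - slope * meanX
  return { slope, intercept }
}

// Usage: points lying exactly on y = 2x + 1
const { slope, intercept } = fitLine([0, 1, 2, 3], [1, 3, 5, 7])
console.log(slope, intercept) // 2 1
```

A TensorFlow.js model trained with a mean-squared-error loss targets the same least-squares line, but it reaches it iteratively by gradient descent instead of with this closed-form formula.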
A typical ML-based solution includes the following steps:
- Load and prepare the input data.
- Define the model architecture.
- Compile and train the model.
- Use the model to make predictions.
In a TensorFlow.js-based application, those steps correspond to the following sections of the code:
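```javascript
// 1. Load and prepare the input data
const { inputs, labels } = await getData()
// 2. Define the model architecture
const model = createModel()
// 3. Compile and train the model
model.compile(...)
await model.fit(inputs, labels, ...)
// 4. Use the model to make predictions
model.predict(...)
```
In the case of simple linear regression, the model architecture can be defined as follows: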
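```javascript
import * as tf from '@tensorflow/tfjs'

const model = tf.sequential()
model.add(tf.layers.dense({ inputShape: [1], units: 1, useBias: true }))
// No activation is set, so both layers are linear and the model still computes a linear function of x
model.add(tf.layers.dense({ units: 1, useBias: true }))
```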