浏览器端的机器学习 tensorflowjs(6) 训练模型
现在模型已经定义好了,数据也下载并进行了处理,一切准备就绪准备开始训练。
async function trainModel(model, inputs, labels) {
// 准备要训练的模型
model.compile({
optimizer: tf.train.adam(),
loss: tf.losses.meanSquaredError,
metrics: ['mse'],
});
const batchSize = 32;
const epochs = 50;
return await model.fit(inputs, labels, {
batchSize,
epochs,
shuffle: true,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
}
训练前的一些准备
model.compile({
optimizer: tf.train.adam(),
loss: tf.losses.meanSquaredError,
metrics: ['mse'],
});
在训练模型之前,需要 "编译 "该模型,那么具体应该如何做呢? 我们需要一个优化和一个损失函数,损失函数也可以理解目标函数,主要是指定训练,让我们训练一个目标,优化器这是给出一个策略如何在训练过程更新参数。
- 优化器。这是一种算法,是更新参数的算法。在 TensorFlow.js 中有许多优化器可用。这里选择了 adam 优化器,也可以尝试用其他优化器
- 损失函数:其实就是一个函数,告诉模型在学习过程中,在每个批次(数据子集)时的表现如何。这里选择 meanSquaredError 来比较模型的预测和真实值
const batchSize = 32;
const epochs = 50;
设置超参数 batchSize 和一个 epochs 的数量。
batchSize 指的是模型在每次迭代训练中看到的数据子集的大小。常见的批次大小往往在 32-512 之间取值。批次大小对于训练速度是有所影响的
epochs 完成整个数据集进行训练的次数
开始训练
return await model.fit(inputs, labels, {
batchSize,
epochs,
callbacks: tfvis.show.fitCallbacks(
{ name: 'Training Performance' },
['loss', 'mse'],
{ height: 200, callbacks: ['onEpochEnd'] }
)
});
model.fit 是来启动训练的函数。这是一个异步函数,所以返回会是一个 promise。
为了监控训练进度,回调传函数作为 model.fit 来获取训练过程中信息。然后回调函数使用 tfvis.show.fitCallbacks 来定义,然后可以绘制损失值对于迭代的图标
const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;
// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');
这的注意的这部分代码要写在 run 函数中,具体如下
async function run() {
// 加载数据
const data = await getData();
// 处理原始数据,将数据 horsepower 映射为 x 而 mpg 则映射为 y
const values = data.map(d => ({
x: d.horsepower,
y: d.mpg,
}));
// 将数据以散点图形式显示在开发者调试工具
tfvis.render.scatterplot(
{name: 'Horsepower v MPG'},
{values},
{
xLabel: 'Horsepower',
yLabel: 'MPG',
height: 300
}
);
const model = createModel();
const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;
// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');
}
发表评论 (审核通过后显示评论):