浏览器端的机器学习 tensorflowjs(6) 训练模型

cover_002.png

现在模型已经定义好了,数据也下载并进行了处理,一切准备就绪准备开始训练。

async function trainModel(model, inputs, labels) {
  // 准备要训练的模型
  model.compile({
    optimizer: tf.train.adam(),
    loss: tf.losses.meanSquaredError,
    metrics: ['mse'],
  });

  const batchSize = 32;
  const epochs = 50;

  return await model.fit(inputs, labels, {
    batchSize,
    epochs,
    shuffle: true,
    callbacks: tfvis.show.fitCallbacks(
      { name: 'Training Performance' },
      ['loss', 'mse'],
      { height: 200, callbacks: ['onEpochEnd'] }
    )
  });
}

训练前的一些准备

model.compile({
  optimizer: tf.train.adam(),
  loss: tf.losses.meanSquaredError,
  metrics: ['mse'],
});

在训练模型之前,需要 "编译 "该模型,那么具体应该如何做呢? 我们需要一个优化和一个损失函数,损失函数也可以理解目标函数,主要是指定训练,让我们训练一个目标,优化器这是给出一个策略如何在训练过程更新参数。

  • 优化器。这是一种算法,是更新参数的算法。在 TensorFlow.js 中有许多优化器可用。这里选择了 adam 优化器,也可以尝试用其他优化器
  • 损失函数:其实就是一个函数,告诉模型在学习过程中,在每个批次(数据子集)时的表现如何。这里选择 meanSquaredError 来比较模型的预测和真实值
const batchSize = 32;
const epochs = 50;

设置超参数 batchSize 和一个 epochs 的数量。

  • batchSize 指的是模型在每次迭代训练中看到的数据子集的大小。常见的批次大小往往在 32-512 之间取值。批次大小对于训练速度是有所影响的

  • epochs 完成整个数据集进行训练的次数

开始训练

return await model.fit(inputs, labels, {
  batchSize,
  epochs,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' },
    ['loss', 'mse'],
    { height: 200, callbacks: ['onEpochEnd'] }
  )
});

model.fit 是来启动训练的函数。这是一个异步函数,所以返回会是一个 promise。

为了监控训练进度,回调传函数作为 model.fit 来获取训练过程中信息。然后回调函数使用 tfvis.show.fitCallbacks 来定义,然后可以绘制损失值对于迭代的图标

const tensorData = convertToTensor(data);
const {inputs, labels} = tensorData;

// Train the model
await trainModel(model, inputs, labels);
console.log('Done Training');

这的注意的这部分代码要写在 run 函数中,具体如下

async function run() {
    // 加载数据
    const data = await getData();
    // 处理原始数据,将数据 horsepower 映射为 x 而 mpg 则映射为 y
    const values = data.map(d => ({
      x: d.horsepower,
      y: d.mpg,
    }));
    // 将数据以散点图形式显示在开发者调试工具
    
  
    tfvis.render.scatterplot(
      {name: 'Horsepower v MPG'},
      {values},
      {
        xLabel: 'Horsepower',
        yLabel: 'MPG',
        height: 300
      }
    );

    const model = createModel();
    const tensorData = convertToTensor(data);
    const {inputs, labels} = tensorData;

    // Train the model
    await trainModel(model, inputs, labels);
    console.log('Done Training');
}

本文章由javascript技术分享原创和收集

发表评论 (审核通过后显示评论):