July 25, 2018
Posted by Nick Kreeger
In this post we’ll use TensorFlow.js, D3.js, and the power of the web to visualize the process of training a model to predict balls (blue areas) and strikes (orange areas) from baseball data. As we go, we’ll visualize how the model’s understanding of the strike zone evolves throughout training. You can run this model entirely in the browser by visiting this Observable notebook.
The GIF above visualizes the neural network learning to call balls (blue areas) and strikes (orange areas). After each training step, the heatmap updates with the model’s latest predictions.
import * as tf from '@tensorflow/tfjs';

const model = tf.sequential();
// Two hidden fully connected layers, each followed by light dropout.
// inputShape [5]: one row of 5 numeric features per pitch.
model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));
model.add(tf.layers.dropout({rate: 0.01}));
model.add(tf.layers.dense({units: 16, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.01}));
// Only two output classes, as softmax probabilities: index 0 = "ball", index 1 = "strike":
model.add(tf.layers.dense({units: 2, activation: 'softmax'}));
model.compile({
  optimizer: tf.train.adam(0.01),
  loss: 'categoricalCrossentropy',
  metrics: ['accuracy']
});
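The final `softmax` layer turns the network’s two output scores into probabilities that sum to 1, and `categoricalCrossentropy` scores those probabilities against a one-hot label. The plain-JavaScript sketch below shows that arithmetic for a single pitch; `softmax` and `categoricalCrossentropy` here are our own illustrative helpers, not TensorFlow.js APIs, and the [ball, strike] output order is assumed to match the model above.

```javascript
// Illustrative helpers (not TensorFlow.js APIs): the math behind a 2-unit
// softmax output layer and categorical cross-entropy for one example.

// Convert raw scores (logits) into probabilities that sum to 1.
// Subtracting the max first keeps Math.exp from overflowing.
function softmax(logits) {
  const max = Math.max(...logits);
  const exps = logits.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Cross-entropy between a one-hot label and predicted probabilities:
// -sum(label_i * log(p_i)). Only the true class's probability contributes.
function categoricalCrossentropy(oneHotLabel, probs) {
  const eps = 1e-7; // avoid log(0)
  return -oneHotLabel.reduce(
    (acc, y, i) => acc + y * Math.log(Math.max(probs[i], eps)), 0);
}

// A pitch whose true label is "strike": [ball, strike] = [0, 1].
const probs = softmax([0.2, 1.4]);
const loss = categoricalCrossentropy([0, 1], probs);
console.log(probs, loss);
```

The loss shrinks toward 0 as the probability assigned to the true class approaches 1, which is what training pushes the model toward.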
// csvData holds the parsed CSV rows, one object per pitch:
const data = [];
csvData.forEach((values) => {
  // Each pitch is described by 5 numeric features:
  const x = [
    parseFloat(values.px),
    parseFloat(values.pz),
    parseFloat(values.sz_top),
    parseFloat(values.sz_bot),
    parseFloat(values.left_handed_batter)
  ];
  // The label is simply 'is strike' (1) or 'is ball' (0):
  const y = parseInt(values.is_strike, 10);
  data.push({x, y});
});
// Shuffle the contents so the model does not always train on the same
// sequence of pitch data:
tf.util.shuffle(data);
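`tf.util.shuffle` shuffles an array in place using the Fisher-Yates algorithm. For readers curious about what that does, here is a minimal plain-JavaScript sketch; `fisherYatesShuffle` is a hypothetical name, and the optional `random` parameter is our own addition so the result can be made deterministic in tests.

```javascript
// Sketch of an in-place Fisher-Yates shuffle (the algorithm tf.util.shuffle
// uses). `random` defaults to Math.random and must return values in [0, 1).
function fisherYatesShuffle(array, random = Math.random) {
  // Walk backwards, swapping each element with a randomly chosen element
  // at or before it, so every permutation is equally likely.
  for (let i = array.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [array[i], array[j]] = [array[j], array[i]];
  }
  return array;
}

const pitches = [{y: 0}, {y: 1}, {y: 1}, {y: 0}, {y: 1}];
fisherYatesShuffle(pitches);
console.log(pitches.map((p) => p.y));
```

Because the shuffle only swaps elements, the set of pitches is unchanged; only their order differs.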
After parsing the CSV data, the plain JavaScript arrays need to be converted into Tensor batches for training and evaluation. See the code lab for more details on this process. The TensorFlow.js team is working on a new Data API to make this kind of data ingestion much easier in the future.
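One way to prepare those batches is to slice the shuffled examples into fixed-size chunks and one-hot encode each label as [ball, strike], matching the two-unit softmax output. The sketch below stays in plain JavaScript and is an assumption about the approach, not the code lab’s exact implementation: `toBatches` and the batch size are hypothetical, the `x`/`y` field names mirror the `batches[index].x` and `batches[index].y` used by `trainBatch` below, and in the real pipeline each chunk would still need wrapping with `tf.tensor2d(...)` before `model.fit`.

```javascript
// Hypothetical helper: split [{x: number[5], y: 0|1}, ...] into batches of
// plain arrays. Labels are one-hot encoded as [ball, strike], so y = 1
// (strike) becomes [0, 1] and y = 0 (ball) becomes [1, 0].
function toBatches(data, batchSize) {
  const batches = [];
  for (let start = 0; start < data.length; start += batchSize) {
    const chunk = data.slice(start, start + batchSize);
    batches.push({
      x: chunk.map((d) => d.x),
      y: chunk.map((d) => (d.y === 1 ? [0, 1] : [1, 0]))
    });
  }
  return batches;
}

// Made-up feature values for illustration only:
// [px, pz, sz_top, sz_bot, left_handed_batter].
const examples = [
  {x: [0.1, 2.5, 3.4, 1.6, 0], y: 1},
  {x: [-1.2, 1.0, 3.3, 1.5, 1], y: 0},
  {x: [0.4, 3.1, 3.5, 1.6, 0], y: 1},
  {x: [1.9, 4.2, 3.4, 1.5, 1], y: 0},
  {x: [-0.3, 2.2, 3.6, 1.7, 0], y: 1}
];
// Five examples with a batch size of 2 yield 3 batches; the last is smaller.
const batches = toBatches(examples, 2);
console.log(batches.length, batches[2].y);
```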
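// Trains on one batch of data and returns the History object, whose
// `history` field holds the loss and accuracy values used to report progress.
// Note: validationData reuses the training batch, so the "validation"
// metrics measure fit to this batch rather than generalization.
async function trainBatch(index) {
  const history = await model.fit(batches[index].x, batches[index].y, {
    epochs: 1,
    shuffle: false,
    validationData: [batches[index].x, batches[index].y],
    batchSize: CONSTANTS.BATCH_SIZE
  });
  // Yield to the browser with tf.nextFrame() so the UI isn't blocked,
  // then redraw the heatmap with the updated model:
  await tf.nextFrame();
  updateHeatmap();
  await tf.nextFrame();
  return history;
}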
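The `accuracy` metric reported during training counts how often the highest-probability class matches the label. A plain-JavaScript sketch of that calculation for one batch, assuming predictions and one-hot labels in [ball, strike] order; `argMax` and `batchAccuracy` are hypothetical helpers, not part of TensorFlow.js.

```javascript
// Index of the largest value (ties resolve to the first index).
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Fraction of examples whose predicted class matches the one-hot label.
function batchAccuracy(predictions, oneHotLabels) {
  let correct = 0;
  predictions.forEach((probs, i) => {
    if (argMax(probs) === argMax(oneHotLabels[i])) correct++;
  });
  return correct / predictions.length;
}

const preds = [[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.1, 0.9]];
const labels = [[1, 0], [0, 1], [0, 1], [0, 1]];
console.log(batchAccuracy(preds, labels)); // 0.75
```

Here the third pitch is a strike that the model calls a ball (0.6 vs 0.4), so 3 of 4 predictions are correct.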
This visual shows how the strike zone and the prediction matrix relate to home plate and the field of play.
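`predictZone()` below feeds the model a `predictionMatrix` that isn’t defined in this excerpt. A minimal sketch of how such a grid could be built, assuming the same five features used in training: each cell varies the pitch location (`px`, `pz`) while holding the strike-zone top, bottom, and batter handedness fixed. The function name, grid bounds, step count, zone heights, and the assumption that units are feet are all illustrative, not values from the original notebook.

```javascript
// Hypothetical: build a grid of pitch locations in front of home plate.
// Units are assumed to be feet: px is horizontal offset from the center of
// the plate and pz is height above the ground. Requires steps >= 2.
function buildPredictionGrid({
  xMin = -2, xMax = 2, zMin = 0, zMax = 5, steps = 5,
  szTop = 3.5, szBot = 1.5, leftHanded = 0
} = {}) {
  const coords = [];   // {x, y} positions for the heatmap
  const features = []; // [px, pz, sz_top, sz_bot, left_handed_batter] rows
  for (let row = 0; row < steps; row++) {
    for (let col = 0; col < steps; col++) {
      const px = xMin + (col / (steps - 1)) * (xMax - xMin);
      const pz = zMin + (row / (steps - 1)) * (zMax - zMin);
      coords.push({x: px, y: pz});
      features.push([px, pz, szTop, szBot, leftHanded]);
    }
  }
  return {coords, features};
}

const grid = buildPredictionGrid({steps: 3});
console.log(grid.features.length); // 9
```

In a TensorFlow.js pipeline, `tf.tensor2d(grid.features)` would turn these rows into a `[numCells, 5]` input tensor, and `grid.coords` keeps each cell’s position in the same order for drawing.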
function predictZone() {
  const predictions = model.predictOnBatch(predictionMatrix.data);
  // Copy the [numPoints, 2] softmax output into a flat typed array, then
  // free the tensor's memory:
  const values = predictions.dataSync();
  predictions.dispose();
  // Each grid point has two probabilities, [ball, strike]. Sort each pair so
  // the higher prediction is the first element in the array:
  const results = [];
  for (let i = 0; i < values.length; i += 2) {
    const list = [
      {value: values[i], strike: 0},
      {value: values[i + 1], strike: 1}
    ];
    list.sort((a, b) => b.value - a.value);
    results.push(list);
  }
  return results;
}
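`updateHeatmap()` below binds the output of `generateHeatmapData()`, which isn’t shown in this excerpt. Since each entry returned by `predictZone()` is sorted so the likelier class comes first, one plausible approach pairs each grid coordinate with that top entry, producing the `x`, `y`, `value`, and `strike` fields the D3 code reads. `buildHeatmapCells` is a hypothetical stand-in, and its signature (coordinates and sorted predictions passed as arguments) is an assumption.

```javascript
// Hypothetical: merge grid coordinates with predictZone()-style results.
// coords:  [{x, y}, ...] in the same order as the rows fed to the model.
// results: [[{value, strike}, {value, strike}], ...], likelier class first.
function buildHeatmapCells(coords, results) {
  if (coords.length !== results.length) {
    throw new Error('coords and results must have the same length');
  }
  return coords.map((coord, i) => {
    const top = results[i][0]; // most likely class for this cell
    return {x: coord.x, y: coord.y, value: top.value, strike: top.strike};
  });
}

const cells = buildHeatmapCells(
  [{x: 0, y: 2.5}, {x: 1.8, y: 4.5}],
  [
    [{value: 0.9, strike: 1}, {value: 0.1, strike: 0}],
    [{value: 0.7, strike: 0}, {value: 0.3, strike: 1}]
  ]
);
console.log(cells);
```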
function updateHeatmap() {
  rects.data(generateHeatmapData());
  rects
    .attr('x', (coord) => scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE)
    .attr('y', (coord) => scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE)
    .attr('width', CONSTANTS.HEATMAP_SIZE)
    .attr('height', CONSTANTS.HEATMAP_SIZE)
    // Orange shades for predicted strikes, blue shades for predicted balls:
    .style('fill', (coord) => coord.strike
        ? strikeColorScale(coord.value)
        : ballColorScale(coord.value));
}
See this section for complete details on drawing a heatmap with D3.
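`strikeColorScale` and `ballColorScale` are defined elsewhere in the notebook; in D3 they would typically be linear scales (`d3.scaleLinear`) that map a probability to a color. The dependency-free sketch below shows the same idea. Because the winning class’s probability is always at least 0.5 when there are only two classes, it maps the domain [0.5, 1] onto a white-to-saturated ramp. The `makeColorScale` helper and the endpoint colors are assumptions for illustration.

```javascript
// Hypothetical replacement for a D3 linear color scale: maps a probability
// in [domainMin, domainMax] to an interpolated "rgb(r, g, b)" string.
function makeColorScale(fromHex, toHex, domainMin = 0.5, domainMax = 1) {
  // Parse "#rrggbb" into [r, g, b] integers.
  const parse = (hex) => [1, 3, 5].map((i) => parseInt(hex.slice(i, i + 2), 16));
  const from = parse(fromHex);
  const to = parse(toHex);
  return (value) => {
    // Normalize into [0, 1] and clamp so out-of-range inputs stay valid.
    const t = Math.min(1, Math.max(0, (value - domainMin) / (domainMax - domainMin)));
    const rgb = from.map((c, i) => Math.round(c + t * (to[i] - c)));
    return `rgb(${rgb.join(', ')})`;
  };
}

// Orange for strikes, blue for balls, both fading to white at 0.5.
const strikeColorScale = makeColorScale('#ffffff', '#ff8c00');
const ballColorScale = makeColorScale('#ffffff', '#1f77b4');
console.log(strikeColorScale(1), ballColorScale(0.5));
```

Cells where the model is barely more confident in one class than the other fade toward white, so the heatmap’s pale band traces the model’s current estimate of the strike-zone boundary.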