Based on TensorFlow.js concepts
2026-01-12
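Based on TensorFlow.js concepts, write a visual neural network training demonstrator. Users can customize the number of hidden layers and the number of neurons in each layer. The canvas shows, in real time, the training data points (a classification problem) and how the decision boundary changes. The loss curve is plotted live on the right. For interaction, let users adjust hyperparameters such as the learning rate and activation function and see directly how they affect convergence speed. The UI should convey the educational goal of "making the black box transparent".
Thinking process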
**Envisioning the Demo**
I've been focusing on the core user need: a web-based, visual neural network demonstrator built with TensorFlow.js. Right now, I'm concentrating on the essential features. Custom network architecture and real-time visualization are at the top of my list. I'm visualizing the UI, thinking about how users will interact with the architecture controls and how the network's training will be displayed.
**Planning the Architecture**
I'm now deep in the design phase. I've broken down the project into key components: data generation for datasets like XOR and Spirals, a model management system using `tf.sequential`, and a training loop leveraging `model.fit()`. I'm leaning toward using HTML5 Canvas for the decision boundaries and a custom line plotter for the loss curve to keep things lean.
**Structuring the Components**
I'm now outlining how the components will fit together. I'm focusing on the canvas elements: one for the decision boundary visualization and another for the loss curve. I'll need to generate datasets (XOR, circles, spirals) and create the model using `tf.sequential`. I plan to control the training loop for UI responsiveness.
**Mapping Out the Code**
I'm presently building the code outlines. First, I'm constructing a simple HTML layout with controls and visualization areas. The core logic involves dataset creation, model compilation using TensorFlow.js layers, and a training loop, likely using `model.fit` with custom callbacks. I'll need to visualize the decision boundaries and loss, considering performance implications for the former. I'm focusing on the main steps, like model setup, the training process, and UI interactions.
**Planning the Implementation**
Now I'm moving toward drafting the code. I'm focusing on the canvas elements, along with the decision boundary visualization and the loss curve. I'll need to generate datasets, like spirals, XOR, etc., and then create the model with `tf.sequential`. I plan to control the training loop for UI responsiveness, using `model.fit()` with a callback for visualizing updates. I'm taking notes on potential performance optimizations as I go.
**Outlining Refinements**
I'm now iterating on the implementation, focusing on performance. I've pinpointed the decision boundary drawing as a potential bottleneck. I'm exploring optimizations such as drawing on an offscreen canvas and scaling it up, and I plan to use `tf.tidy()` to manage memory efficiently. I'm also considering how to structure the training loop, opting for an async `model.fit` approach with callbacks.
**Defining the Architecture**
I'm now outlining how the model architecture will be defined within the UI, considering input fields or sliders for hidden layers and neurons, leaning toward a text input format like "8, 4, 2" for layer specifications. This allows for dynamic and flexible model structures with a minimal interface. I'm focusing on the training loop to ensure a smooth experience.
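This is a single-file solution containing HTML, CSS, and JavaScript. It uses TensorFlow.js to build a neural network visualization workshop that runs entirely in the browser.
### Key features:
1. **Custom architecture**: the hidden-layer structure is defined in a text field (for example, `8, 4` means 8 neurons in the first hidden layer and 4 in the second).
2. **Real-time rendering**:
    * **Decision boundary (heatmap)**: the background color shows the network's predicted probability for each region (blue for class 0, orange for class 1, white at the decision boundary).
    * **Training data**: a scatter plot drawn on top of the heatmap.
    * **Loss curve**: the downward trend of the loss is plotted live on the right.
3. **Interactive controls**: the learning rate, activation function, and data distribution can be changed on the fly, and training can be paused or reset.
4. **Performance optimizations**: `tf.tidy` manages memory, and inference runs on a low-resolution 100×100 grid that the browser smooths when it stretches the canvas, which keeps rendering responsive. The actual frame rate depends on the network size and the hardware.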
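As a rough illustration of the architecture field described in item 1, the sketch below turns a comma-separated string into a list of layer sizes. The helper name `parseLayerSpec` is hypothetical and is not part of the demo code; it assumes that empty, non-numeric, and non-positive entries should simply be dropped.

```javascript
// Hypothetical helper: turn a string such as "8, 4" into [8, 4].
// Entries that are not positive integers are discarded.
function parseLayerSpec(spec) {
  return spec
    .split(',')
    .map(s => Number.parseInt(s.trim(), 10))
    .filter(n => Number.isInteger(n) && n > 0);
}

console.log(parseLayerSpec('8, 4'));      // two hidden layers: 8 and 4 neurons
console.log(parseLayerSpec('8, abc, 0')); // only the valid entry 8 survives
console.log(parseLayerSpec(''));          // empty list: no hidden layers
```

An empty result can be read as "no hidden layers", in which case the network reduces to a single sigmoid output unit.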
### Implementation
You can save the following code as an `.html` file and open it directly in a browser.
```html
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>TF.js Neural Network Visualizer</title>
  <!-- Load TensorFlow.js -->
  <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@4.10.0/dist/tf.min.js"></script>
  <style>
    :root {
      --bg-color: #1a1a2e;
      --panel-color: #16213e;
      --text-color: #e0e0e0;
      --accent-color: #0f3460;
      --highlight: #e94560;
    }
    body {
      font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
      background-color: var(--bg-color);
      color: var(--text-color);
      margin: 0;
      display: flex;
      height: 100vh;
      overflow: hidden;
    }
    /* Sidebar controls */
    .sidebar {
      width: 300px;
      background-color: var(--panel-color);
      padding: 20px;
      display: flex;
      flex-direction: column;
      gap: 15px;
      box-shadow: 2px 0 10px rgba(0,0,0,0.5);
      z-index: 10;
    }
    h2 { margin: 0 0 10px 0; font-size: 1.2rem; color: var(--highlight); }
    h3 { margin: 5px 0; font-size: 0.9rem; opacity: 0.8; }
    .control-group {
      background: rgba(0,0,0,0.2);
      padding: 10px;
      border-radius: 8px;
    }
    label { display: block; font-size: 0.85rem; margin-bottom: 5px; }
    input[type="text"], select, input[type="number"] {
      width: 100%;
      padding: 8px;
      background: #0f3460;
      border: 1px solid #333;
      color: white;
      border-radius: 4px;
      box-sizing: border-box;
    }
    .slider-container { display: flex; align-items: center; gap: 10px; }
    input[type="range"] { flex-grow: 1; }
    .btn-group { display: flex; gap: 10px; margin-top: 10px; }
    button {
      flex: 1;
      padding: 10px;
      border: none;
      border-radius: 4px;
      cursor: pointer;
      font-weight: bold;
      transition: 0.2s;
    }
    .btn-primary { background-color: var(--highlight); color: white; }
    .btn-primary:hover { background-color: #ff2e4c; }
    .btn-secondary { background-color: #444; color: white; }
    .btn-secondary:hover { background-color: #666; }
    /* Main stage: decision-boundary canvas with the loss panel to its right */
    .main-stage {
      flex-grow: 1;
      display: flex;
      flex-direction: row; /* side by side; the loss panel wraps below on narrow windows */
      flex-wrap: wrap;
      align-items: center;
      justify-content: center;
      position: relative;
      padding: 20px;
      gap: 20px;
    }
    .canvas-container {
      position: relative;
      box-shadow: 0 0 20px rgba(0,0,0,0.5);
      border-radius: 4px;
      overflow: hidden;
    }
    /* Decision-boundary canvas */
    #vizCanvas {
      background: #000;
      /* Key trick: CSS stretches the low-resolution canvas. The default
         image-rendering (auto) smooths the result; pixelated would show blocks. */
      width: 500px;
      height: 500px;
      image-rendering: auto;
    }
    /* Status readout */
    .stats {
      position: absolute;
      top: 10px;
      left: 10px;
      background: rgba(0,0,0,0.6);
      padding: 5px 10px;
      border-radius: 4px;
      font-family: monospace;
      pointer-events: none;
    }
    /* Loss curve panel (to the right of the canvas) */
    .loss-panel {
      width: 360px;
      height: 240px;
      background: var(--panel-color);
      border-radius: 4px;
      position: relative;
      padding: 10px;
      box-sizing: border-box;
    }
    #lossCanvas { width: 100%; height: 100%; }
  </style>
</head>
<body>
  <!-- Control panel -->
  <div class="sidebar">
    <h2>Neural Network Workshop</h2>
    <div class="control-group">
      <h3>Data</h3>
      <select id="dataSelect">
        <option value="circle">Concentric circles</option>
        <option value="xor">XOR</option>
        <option value="gauss">Gaussian</option>
        <option value="spiral">Spiral (hard)</option>
      </select>
      <label style="margin-top:10px">Noise: <span id="noiseVal">0</span></label>
      <div class="slider-container">
        <input type="range" id="noise" min="0" max="50" value="0">
      </div>
    </div>
    <div class="control-group">
      <h3>Architecture</h3>
      <label>Hidden layers (comma-separated):</label>
      <input type="text" id="layersInput" value="8, 4" placeholder="e.g. 8, 4, 2">
      <small style="color: #888">Number of neurons in each layer</small>
    </div>
    <div class="control-group">
      <h3>Hyperparameters</h3>
      <label>Learning rate:</label>
      <select id="lrSelect">
        <option value="0.001">0.001</option>
        <option value="0.01">0.01</option>
        <option value="0.03" selected>0.03</option>
        <option value="0.1">0.1</option>
        <option value="0.3">0.3</option>
      </select>
      <label style="margin-top:10px">Activation:</label>
      <select id="actSelect">
        <option value="tanh" selected>Tanh</option>
        <option value="relu">ReLU</option>
        <option value="sigmoid">Sigmoid</option>
      </select>
    </div>
    <div class="btn-group">
      <button id="trainBtn" class="btn-primary">Start training</button>
      <button id="resetBtn" class="btn-secondary">Reset network</button>
    </div>
    <div style="margin-top: auto; font-size: 0.8rem; color: #888;">
      Built with TensorFlow.js<br>
      Visualization demo
    </div>
  </div>
  <!-- Main display area -->
  <div class="main-stage">
    <!-- Decision boundary and data points -->
    <div class="canvas-container">
      <!-- Rendered at low resolution and stretched with CSS for performance -->
      <canvas id="vizCanvas" width="100" height="100"></canvas>
      <div class="stats" id="epochStats">Epoch: 0 | Loss: -</div>
    </div>
    <!-- Loss curve -->
    <div class="loss-panel">
      <h3 style="position:absolute; top:5px; left:10px; margin:0;">Training Loss</h3>
      <canvas id="lossCanvas"></canvas>
    </div>
  </div>
  <script>
    /**
     * Global state and configuration
     */
    let model;
    let isTraining = false;
    let dataPoints = []; // [{x, y, label}, ...]
    let lossHistory = [];
    let epoch = 0;
    // Canvas contexts
    const vizCanvas = document.getElementById('vizCanvas');
    const vizCtx = vizCanvas.getContext('2d');
    const lossCanvas = document.getElementById('lossCanvas');
    const lossCtx = lossCanvas.getContext('2d');
    // Internal inference resolution (100x100); CSS displays the canvas at 500x500
    const VIZ_RES = 100;
    /**
     * 1. Data generation
     */
    function generateData(type, count = 400) {
      const data = [];
      // Slider range 0..50 maps to a noise amplitude of 0..0.05
      const noise = parseInt(document.getElementById('noise').value, 10) / 1000;
      for (let i = 0; i < count; i++) {
        let x = 0, y = 0, label = 0;
        if (type === 'gauss') {
          // Two clusters centered at (-0.5, -0.5) and (0.5, 0.5).
          // Samples are drawn uniformly from a box, a simple stand-in for Gaussian blobs.
          label = Math.random() > 0.5 ? 1 : 0;
          const centerX = label === 0 ? -0.5 : 0.5;
          const centerY = label === 0 ? -0.5 : 0.5;
          x = centerX + (Math.random() - 0.5) * 0.8 + (Math.random() - 0.5) * noise * 5;
          y = centerY + (Math.random() - 0.5) * 0.8 + (Math.random() - 0.5) * noise * 5;
        } else if (type === 'circle') {
          const angle = Math.random() * 2 * Math.PI;
          // sqrt() spreads the points uniformly over a disc of radius 2.5
          const dist = Math.sqrt(Math.random()) * 2.5;
          x = Math.cos(angle) * dist;
          y = Math.sin(angle) * dist;
          // Inner disc (dist < 1.3) is class 0, the outer ring is class 1
          label = dist < 1.3 ? 0 : 1;
          // Add noise
          x += (Math.random() - 0.5) * noise;
          y += (Math.random() - 0.5) * noise;
          // Normalize to approximately -1..1
          x /= 2.5; y /= 2.5;
        } else if (type === 'xor') {
          x = (Math.random() - 0.5) * 2;
          y = (Math.random() - 0.5) * 2;
          const padding = 0.05;
          x += (x > 0 ? padding : -padding); // nudge x away from the vertical axis
          // Same-sign quadrants are class 0, opposite-sign quadrants are class 1
          label = (x > 0 && y > 0) || (x < 0 && y < 0) ? 0 : 1;
          x += (Math.random() - 0.5) * noise * 2;
          y += (Math.random() - 0.5) * noise * 2;
        } else if (type === 'spiral') {
          // Classic two-arm spiral: even i goes to class 0, odd i to class 1
          const n = count / 2;          // points per arm
          const j = i % 2;              // class 0 or 1
          const k = Math.floor(i / 2);  // index of this point along its arm
          const r = k / n + 0.2;        // radius grows from 0.2 to about 1.2
          const t = 1.75 * (k / n) * 2 * Math.PI + j * Math.PI; // the arms are offset by pi
          x = r * Math.sin(t) + (Math.random() - 0.5) * noise;
          y = r * Math.cos(t) + (Math.random() - 0.5) * noise;
          label = j;
          // Scale down to fit within -1..1
          x /= 1.5; y /= 1.5;
        }
        data.push({ x, y, label });
      }
      return data;
    }
    /**
     * 2. Model construction
     */
    function createModel() {
      const layersStr = document.getElementById('layersInput').value;
      // Keep only positive integers; empty entries, text, and 0 are dropped
      const units = layersStr.split(',')
        .map(s => parseInt(s.trim(), 10))
        .filter(n => Number.isInteger(n) && n > 0);
      const lr = parseFloat(document.getElementById('lrSelect').value);
      const act = document.getElementById('actSelect').value;
      // Dispose of the previous model
      if (model) model.dispose();
      model = tf.sequential();
      // Build the hidden layers dynamically
      units.forEach((u, i) => {
        const config = {
          units: u,
          activation: act,
          kernelInitializer: 'glorotNormal'
        };
        // The first layer needs inputShape
        if (i === 0) config.inputShape = [2];
        model.add(tf.layers.dense(config));
      });
      // Output layer: sigmoid for binary classification.
      // With no hidden layers it is also the first layer, so it needs inputShape.
      const outputConfig = { units: 1, activation: 'sigmoid' };
      if (units.length === 0) outputConfig.inputShape = [2];
      model.add(tf.layers.dense(outputConfig));
      model.compile({
        optimizer: tf.train.adam(lr),
        loss: 'binaryCrossentropy',
        metrics: ['accuracy']
      });
      epoch = 0;
      lossHistory = [];
      document.getElementById('epochStats').innerText = `Epoch: 0 | Loss: -`;
      // Draw the initial (randomly initialized) state
      draw();
    }
    /**
     * 3. Visualization: decision boundary and loss
     */
    // The grid coordinate tensor is computed once and cached
    let gridTensor;
    function getGridTensor() {
      if (gridTensor) return gridTensor;
      const coords = [];
      // Build a VIZ_RES x VIZ_RES grid covering -1..1 on both axes
      for (let y = 0; y < VIZ_RES; y++) {
        for (let x = 0; x < VIZ_RES; x++) {
          // Map 0..99 to -1..1 (the Y axis is flipped because canvas row 0 is the top)
          const u = (x / (VIZ_RES - 1)) * 2 - 1;
          const v = -((y / (VIZ_RES - 1)) * 2 - 1);
          coords.push([u, v]);
        }
      }
      gridTensor = tf.tensor2d(coords);
      return gridTensor;
    }
    async function drawBoundary() {
      if (!model) return;
      // Fetch the cached grid outside tf.tidy: a tensor created inside tidy is
      // disposed when the scope ends, which would invalidate the cache.
      const grid = getGridTensor();
      // tf.tidy frees the intermediate prediction tensor to prevent memory leaks
      const predictions = tf.tidy(() => model.predict(grid).dataSync());
      const imgData = vizCtx.createImageData(VIZ_RES, VIZ_RES);
      const data = imgData.data;
      // Fill the pixels
      for (let i = 0; i < predictions.length; i++) {
        const p = predictions[i]; // probability between 0 and 1
        // Color scheme:
        //   p < 0.5 (class 0): blue background
        //   p > 0.5 (class 1): orange background
        //   p ~ 0.5 (boundary): white
        let r, g, b;
        // Simple heatmap interpolation from white toward the class color
        if (p < 0.5) {
          // Toward blue (0, 100, 255)
          const intensity = (0.5 - p) * 2;
          r = 255 * (1 - intensity) + 0 * intensity;
          g = 255 * (1 - intensity) + 100 * intensity;
          b = 255 * (1 - intensity) + 255 * intensity;
        } else {
          // Toward orange (255, 150, 0)
          const intensity = (p - 0.5) * 2;
          r = 255 * (1 - intensity) + 255 * intensity;
          g = 255 * (1 - intensity) + 150 * intensity;
          b = 255 * (1 - intensity) + 0 * intensity;
        }
        data[i * 4 + 0] = r;
        data[i * 4 + 1] = g;
        data[i * 4 + 2] = b;
        data[i * 4 + 3] = 255; // alpha
      }
      vizCtx.putImageData(imgData, 0, 0);
    }
    function drawDataPoints() {
      // Map coordinates in -1..1 to canvas pixels in 0..VIZ_RES
      const toPx = (v, isY) => {
        const norm = (v + 1) / 2; // 0..1
        return isY ? (1 - norm) * VIZ_RES : norm * VIZ_RES;
      };
      dataPoints.forEach(pt => {
        vizCtx.beginPath();
        const cx = toPx(pt.x, false);
        const cy = toPx(pt.y, true);
        vizCtx.arc(cx, cy, 2, 0, 2 * Math.PI);
        // Class 0: solid blue with a white outline
        // Class 1: solid orange with a white outline
        vizCtx.fillStyle = pt.label === 0 ? '#4d96ff' : '#ff9f1c';
        vizCtx.strokeStyle = 'white';
        vizCtx.lineWidth = 0.5;
        vizCtx.fill();
        vizCtx.stroke();
      });
    }
    function drawLoss() {
      // Plot the loss curve; resizing the canvas also clears it
      const w = lossCanvas.width = lossCanvas.clientWidth;
      const h = lossCanvas.height = lossCanvas.clientHeight;
      lossCtx.clearRect(0, 0, w, h);
      if (lossHistory.length < 2) return;
      // Scale the curve by the largest loss seen so far
      const maxLoss = Math.max(...lossHistory) || 1;
      lossCtx.beginPath();
      lossCtx.strokeStyle = '#00fff0'; // cyan line
      lossCtx.lineWidth = 2;
      for (let i = 0; i < lossHistory.length; i++) {
        const x = (i / (lossHistory.length - 1)) * w;
        // Higher loss means smaller y (closer to the top); loss = 0 sits near the bottom
        const y = h - (lossHistory[i] / maxLoss) * (h * 0.8) - 10;
        if (i === 0) lossCtx.moveTo(x, y);
        else lossCtx.lineTo(x, y);
      }
      lossCtx.stroke();
    }
    async function draw() {
      // 1. Decision boundary
      await drawBoundary();
      // 2. Raw data points
      drawDataPoints();
      // 3. Loss curve
      drawLoss();
    }
    /**
     * 4. Training loop
     */
    const trainBtn = document.getElementById('trainBtn');
    // Promise for the running loop, or null when idle. It lets the UI wait for an
    // in-flight model.fit() to finish before the model is disposed and rebuilt.
    let loopPromise = null;

    function setTrainButton(text, primary) {
      trainBtn.innerText = text;
      trainBtn.classList.toggle('btn-primary', primary);
      trainBtn.classList.toggle('btn-secondary', !primary);
    }

    // Run one epoch, record the loss, and redraw
    async function trainStep() {
      // For simplicity the tensors are rebuilt every epoch; converting them once
      // whenever the data changes would be more efficient.
      const xs = tf.tensor2d(dataPoints.map(p => [p.x, p.y]));
      const ys = tf.tensor2d(dataPoints.map(p => [p.label]));
      try {
        const history = await model.fit(xs, ys, {
          batchSize: 32,
          epochs: 1,
          shuffle: true,
        });
        const loss = history.history.loss[0];
        lossHistory.push(loss);
        epoch++;
        document.getElementById('epochStats').innerText = `Epoch: ${epoch} | Loss: ${loss.toFixed(4)}`;
      } finally {
        // Free the input tensors even if fit() throws
        xs.dispose();
        ys.dispose();
      }
      await draw();
    }

    async function trainLoop() {
      while (isTraining) {
        await trainStep();
        // Yield to the browser so the UI can repaint between epochs
        await new Promise(resolve => requestAnimationFrame(resolve));
      }
    }

    function startTraining() {
      if (!model) createModel();
      isTraining = true;
      setTrainButton('Pause training', false);
      // Start a loop only if none is running, so a quick pause/resume cannot spawn two
      if (!loopPromise) {
        loopPromise = trainLoop()
          .catch(err => {
            console.error(err);
            isTraining = false;
            setTrainButton('Start training', true);
          })
          .finally(() => { loopPromise = null; });
      }
    }

    async function stopTraining(label) {
      isTraining = false;
      setTrainButton(label, true);
      if (loopPromise) await loopPromise; // wait for the current epoch to finish
    }

    // Stop training, wait for any in-flight epoch, then rebuild the model
    async function rebuildModel() {
      trainBtn.disabled = true;
      try {
        await stopTraining('Start training');
        createModel(); // recompiles the model and resets the weights
      } finally {
        trainBtn.disabled = false;
      }
    }
    /**
     * 5. UI event bindings
     */
    // Regenerate the data
    function resetData() {
      const type = document.getElementById('dataSelect').value;
      dataPoints = generateData(type);
      draw();
    }
    document.getElementById('noise').addEventListener('input', (e) => {
      document.getElementById('noiseVal').innerText = e.target.value;
      resetData();
    });
    document.getElementById('dataSelect').addEventListener('change', resetData);
    // Training controls
    trainBtn.addEventListener('click', () => {
      if (isTraining) stopTraining('Resume training');
      else startTraining();
    });
    document.getElementById('resetBtn').addEventListener('click', rebuildModel);
    // Architecture or hyperparameter changes stop training and rebuild the model
    ['layersInput', 'lrSelect', 'actSelect'].forEach(id => {
      document.getElementById(id).addEventListener('change', rebuildModel);
    });
    // Initialize
    resetData();
    createModel();
  </script>
</body>
</html>
```
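The coloring rule inside `drawBoundary` can also be read in isolation as a pure function from a probability to an RGB triple. The sketch below restates that rule under the hypothetical names `lerp` and `probToColor`, neither of which appears in the listing above; the explicit rounding is an addition made so the values are easy to inspect.

```javascript
// Linear interpolation between a and b for t in [0, 1].
const lerp = (a, b, t) => a * (1 - t) + b * t;

// Hypothetical helper mirroring the heatmap rule: white at p = 0.5,
// blue (0, 100, 255) as p approaches 0, orange (255, 150, 0) as p approaches 1.
function probToColor(p) {
  const target = p < 0.5 ? [0, 100, 255] : [255, 150, 0];
  const intensity = Math.abs(p - 0.5) * 2; // 0 at the boundary, 1 when fully confident
  return target.map(c => Math.round(lerp(255, c, intensity)));
}

console.log(probToColor(0.5)); // white: [255, 255, 255]
console.log(probToColor(0));   // full blue: [0, 100, 255]
console.log(probToColor(1));   // full orange: [255, 150, 0]
```

Because the color fades to white as the probability nears 0.5, the white band on the canvas marks where the network is least certain, which is the decision boundary.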
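### Design notes (how the UI serves the educational goal):
1. **Making the black box transparent**:
    * The `drawBoundary` function maps the network's output (probabilities) over the whole 2D plane directly to colors, so the "black box" becomes visible.
    * Students can see directly that a network with no hidden layers (only the sigmoid output layer, which is logistic regression) produces a straight-line decision boundary, while adding hidden layers with a non-linear activation (Tanh/ReLU) lets the boundary bend and wrap around the data.
2. **Live feedback**:
    * **Learning rate (LR)**: with a large value (0.3) the loss curve tends to oscillate; with a very small one (0.001) convergence is very slow.
    * **Activation function**: switching to `ReLU` can make the decision boundary piecewise linear, whereas `Tanh` gives a smoother one.
3. **Technical details**:
    * **Coordinate normalization**: to keep the math simple, all data (X, Y) is generated in roughly the -1 to 1 range.
    * **Asynchronous training**: `requestAnimationFrame` is combined with `model.fit({epochs: 1})`. Although `fit` is itself asynchronous, stepping one epoch at a time is the most controllable way to refresh the UI after every epoch without freezing the browser.
    * **Grid tensor caching**: `getGridTensor()` builds the 10,000-point coordinate grid once and reuses it, which avoids recreating a large coordinate array on every frame and reduces per-frame overhead.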
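The learning-rate behavior described above can be reproduced without a neural network. The sketch below is a deliberate simplification: it runs plain gradient descent (not the Adam optimizer the demo uses) on the one-dimensional loss f(w) = w², whose gradient is 2w, so each step multiplies w by (1 - 2·lr). The helper name `descend` and the chosen rates are illustrative assumptions.

```javascript
// Hypothetical helper: plain gradient descent on f(w) = w^2, starting from w.
// Returns the sequence of weights, including the starting point.
function descend(lr, steps = 5, w = 1) {
  const trace = [w];
  for (let i = 0; i < steps; i++) {
    w -= lr * 2 * w; // gradient of w^2 is 2w
    trace.push(w);
  }
  return trace;
}

for (const lr of [0.01, 0.9, 1.1]) {
  const trace = descend(lr).map(v => v.toFixed(3)).join(', ');
  console.log(`lr=${lr}: ${trace}`);
}
```

With `lr = 0.01` the weight shrinks by only 2% per step, so progress is slow. With `lr = 0.9` it flips sign on every step while shrinking, which shows up as oscillation. With `lr = 1.1` it flips sign and grows, so the loss diverges. The thresholds for a real network trained with Adam differ, but the qualitative picture is the same.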