[ PROMPT_NODE_22686 ]
MLOps TensorBoard 集成
[ SKILL_DOCUMENTATION ]
# 框架集成指南
将 TensorBoard 与主流机器学习框架集成的完整指南。
## 目录
- PyTorch
- TensorFlow/Keras
- PyTorch Lightning
- HuggingFace Transformers
- Fast.ai
- JAX
- scikit-learn
## PyTorch
### 基础集成
python
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

# Basic PyTorch + TensorBoard integration example.
# NOTE(review): `ResNet50` and `train_loader` are not defined in this snippet;
# they are assumed to be provided by the surrounding project.

# Create the writer; event files are written under runs/pytorch_experiment.
writer = SummaryWriter('runs/pytorch_experiment')

# Model, optimizer and loss function.
model = ResNet50()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Log the model graph by tracing it with a dummy input.
dummy_input = torch.randn(1, 3, 224, 224)
writer.add_graph(model, dummy_input)

# Training loop.
for epoch in range(100):
    model.train()
    train_loss = 0.0

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # Log batch-level metrics every 100 batches.
        if batch_idx % 100 == 0:
            # Step counter that keeps increasing across epochs.
            global_step = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Loss/train_batch', loss.item(), global_step)

    # Epoch-level metric: mean of the per-batch losses.
    train_loss /= len(train_loader)
    writer.add_scalar('Loss/train_epoch', train_loss, epoch)

    # Log a histogram of every named parameter once per epoch.
    for name, param in model.named_parameters():
        writer.add_histogram(name, param, epoch)

# Close the writer once training has finished.
writer.close()
### torchvision 集成
python
from torchvision.utils import make_grid

# Log one grid of training images to TensorBoard.
# NOTE(review): `train_loader`, `writer` and `epoch` are not defined in this
# snippet; they are assumed to come from the surrounding training code.
for batch_idx, (images, labels) in enumerate(train_loader):
    if batch_idx == 0:  # first batch only
        # Tile at most 64 images into a grid with 8 images per row.
        img_grid = make_grid(images[:64], nrow=8)
        writer.add_image('Training_batch', img_grid, epoch)
        break
### 分布式训练
python
import os

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

# Distributed training example that logs to TensorBoard from rank 0 only.
# NOTE(review): `model` and `train_epoch` are not defined in this snippet;
# they are assumed to be provided by the surrounding training code.

# Setup.
dist.init_process_group(backend='nccl')
rank = dist.get_rank()

# `device_ids` needs the GPU index on the current node. The global rank only
# equals that index on a single node, so prefer LOCAL_RANK (exported by
# torchrun) and fall back to the global rank when it is not set.
local_rank = int(os.environ.get('LOCAL_RANK', rank))

# Create the writer on rank 0 only so that a single process writes events.
if rank == 0:
    writer = SummaryWriter('runs/distributed_experiment')

model = DDP(model, device_ids=[local_rank])

for epoch in range(100):
    train_loss = train_epoch()

    # Log from rank 0 only.
    if rank == 0:
        writer.add_scalar('Loss/train', train_loss, epoch)

# Close the rank-0 writer once training has finished.
if rank == 0:
    writer.close()
## TensorFlow/Keras
### Keras 回调
python
import tensorflow as tf
# TensorBoard 回调
tensorboard_callback = tf.keras.callbacks.TensorBoard(
log_dir='logs/keras_experiment',
histogram_freq=1,