在Rust中从头开始构建张量库是一个极具挑战性但也非常有价值的项目。作为深度学习的基础数据结构,张量(Tensor)的高效实现直接影响到机器学习框架的性能。这个系列教程的1.3部分聚焦于数据操作,这是张量库最核心的功能之一。
我在实现自己的深度学习框架时发现,数据操作看似简单,实则暗藏玄机。一个设计良好的张量操作接口不仅能提高开发效率,还能显著提升运行时性能。Rust的所有权系统和零成本抽象特性,让我们能够在保证安全性的同时,实现接近C++的性能。
张量操作大致可以分为以下几类:
在Rust中实现这些操作需要考虑:
rust复制pub struct Tensor<T> {
data: Vec<T>, // 连续存储的数据
shape: Vec<usize>, // 各维度大小
strides: Vec<usize>, // 各维度步长
offset: usize, // 数据起始偏移
}
这种设计借鉴了NumPy的核心思想,但加入了Rust特有的安全保证。strides和offset的引入使得视图(view)操作可以零拷贝实现。
rust复制pub trait TensorOps<T> {
fn reshape(&self, new_shape: &[usize]) -> Result<Tensor<T>, TensorError>;
fn transpose(&self, dims: &[usize]) -> Result<Tensor<T>, TensorError>;
fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError>;
// 其他操作...
}
使用trait而不是直接实现方法,可以更好地组织代码并支持扩展。
视图操作(如reshape、transpose)不实际移动数据,而是通过调整shape和strides来实现:
rust复制impl<T> Tensor<T> {
pub fn view(&self, new_shape: &[usize]) -> Result<Self, TensorError> {
// 检查元素总数是否匹配
if new_shape.iter().product::<usize>() != self.numel() {
return Err(TensorError::ShapeMismatch);
}
// 计算新strides (行主序)
let mut new_strides = vec![1; new_shape.len()];
for i in (0..new_shape.len()-1).rev() {
new_strides[i] = new_strides[i+1] * new_shape[i+1];
}
Ok(Tensor {
data: self.data.clone(), // 浅拷贝Arc
shape: new_shape.to_vec(),
strides: new_strides,
offset: self.offset,
})
}
}
矩阵乘法是深度学习中最频繁的操作,需要特别优化:
rust复制impl<T: Num + Copy> Tensor<T> {
pub fn matmul(&self, other: &Tensor<T>) -> Result<Tensor<T>, TensorError> {
// 检查形状兼容性
if self.shape().len() != 2 || other.shape().len() != 2 {
return Err(TensorError::ShapeMismatch);
}
if self.shape()[1] != other.shape()[0] {
return Err(TensorError::ShapeMismatch);
}
let m = self.shape()[0];
let n = other.shape()[1];
let k = self.shape()[1];
let mut result_data = vec![T::zero(); m * n];
// 使用分块技术提高缓存命中率
const BLOCK_SIZE: usize = 64;
for i in (0..m).step_by(BLOCK_SIZE) {
for j in (0..n).step_by(BLOCK_SIZE) {
for kk in (0..k).step_by(BLOCK_SIZE) {
// 处理当前块
for ii in i..(i+BLOCK_SIZE).min(m) {
for jj in j..(j+BLOCK_SIZE).min(n) {
let mut sum = T::zero();
for kkk in kk..(kk+BLOCK_SIZE).min(k) {
sum = sum + self.get(&[ii, kkk]) * other.get(&[kkk, jj]);
}
result_data[ii * n + jj] = result_data[ii * n + jj] + sum;
}
}
}
}
}
Tensor::new(result_data, &[m, n])
}
}
广播(broadcasting)是NumPy风格的自动维度扩展,在Rust中实现需要考虑类型安全:
rust复制impl<T: Num + Clone> Tensor<T> {
fn broadcast_to(&self, shape: &[usize]) -> Result<Tensor<T>, TensorError> {
// 检查广播可行性
if shape.len() < self.shape.len() {
return Err(TensorError::BroadcastError);
}
// 对齐维度
let mut new_shape = vec![1; shape.len() - self.shape.len()];
new_shape.extend(self.shape.iter());
let mut new_strides = vec![0; shape.len() - self.shape.len()];
new_strides.extend(self.strides.iter());
// 调整步长
for i in 0..shape.len() {
if new_shape[i] != shape[i] {
if new_shape[i] != 1 {
return Err(TensorError::BroadcastError);
}
new_strides[i] = 0; // 广播维度步长为0
}
}
Ok(Tensor {
data: self.data.clone(),
shape: shape.to_vec(),
strides: new_strides,
offset: self.offset,
})
}
}
为支持神经网络训练,需要实现基本的自动微分:
rust复制pub struct Variable<T> {
data: Tensor<T>,
grad: Option<Tensor<T>>,
requires_grad: bool,
creator: Option<Rc<dyn Function<T>>>,
}
trait Function<T> {
fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T>;
fn backward(&self, grad: &Tensor<T>) -> Vec<Tensor<T>>;
}
impl<T: Num + Clone> Variable<T> {
pub fn backward(&mut self) {
if let Some(ref mut grad) = self.grad {
if let Some(ref creator) = self.creator {
let grads = creator.backward(grad);
// 处理梯度传播...
}
}
}
}
rust复制impl<T> Tensor<T> {
pub fn ensure_contiguous(&self) -> Tensor<T> {
if self.is_contiguous() {
return self.clone();
}
// 创建新的连续存储
let mut new_data = Vec::with_capacity(self.numel());
for idx in 0..self.numel() {
let pos = self.flat_to_index(idx);
new_data.push(unsafe { self.get_unchecked(&pos) });
}
Tensor::new(new_data, self.shape()).unwrap()
}
}
利用Rayon库实现数据并行:
rust复制use rayon::prelude::*;
impl<T: Num + Send + Sync> Tensor<T> {
pub fn par_map<F>(&self, f: F) -> Tensor<T>
where
F: Fn(T) -> T + Send + Sync,
{
let new_data = self.data.par_iter().map(|&x| f(x)).collect();
Tensor {
data: new_data,
shape: self.shape.clone(),
strides: self.strides.clone(),
offset: 0,
}
}
}
rust复制#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_matmul() {
let a = Tensor::new(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
let b = Tensor::new(vec![5.0, 6.0, 7.0, 8.0], &[2, 2]).unwrap();
let c = a.matmul(&b).unwrap();
assert_eq!(c.shape(), &[2, 2]);
assert_eq!(c.get(&[0, 0]), 19.0);
assert_eq!(c.get(&[0, 1]), 22.0);
assert_eq!(c.get(&[1, 0]), 43.0);
assert_eq!(c.get(&[1, 1]), 50.0);
}
#[test]
fn test_broadcast() {
let a = Tensor::new(vec![1.0, 2.0, 3.0], &[3]).unwrap();
let b = a.broadcast_to(&[2, 3]).unwrap();
assert_eq!(b.shape(), &[2, 3]);
assert_eq!(b.get(&[0, 0]), 1.0);
assert_eq!(b.get(&[1, 2]), 3.0);
}
}
使用criterion.rs进行性能测试:
rust复制use criterion::{criterion_group, criterion_main, Criterion};
fn matmul_benchmark(c: &mut Criterion) {
let a = Tensor::rand(&[256, 256]);
let b = Tensor::rand(&[256, 256]);
c.bench_function("matmul 256x256", |bench| {
bench.iter(|| a.matmul(&b).unwrap())
});
}
criterion_group!(benches, matmul_benchmark);
criterion_main!(benches);
问题:操作链中频繁克隆张量导致性能下降
解决方案:
Arc共享数据所有权Cow(Copy-on-Write)语义rust复制impl<T> Tensor<T> {
pub fn into_shared(self) -> TensorShared<T> {
TensorShared {
data: Arc::new(self.data),
shape: self.shape,
strides: self.strides,
offset: self.offset,
}
}
}
问题:操作前需要频繁检查形状兼容性
解决方案:
debug_assert版本rust复制pub trait Shape {
fn shape(&self) -> &[usize];
fn same_shape(&self, other: &dyn Shape) -> bool {
self.shape() == other.shape()
}
}
问题:不同数值类型需要不同实现
解决方案:
num-traits定义通用数值traitrust复制macro_rules! impl_tensor_ops {
($($t:ty),*) => {
$(
impl TensorOps for Tensor<$t> {
// 通用实现...
}
)*
}
}
impl_tensor_ops!(f32, f64, i32, i64);
rust-gpu或wgpu实现GPU后端cranelift或llvm-sys实现操作融合在实现Rust张量库的过程中,我发现最困难的部分不是算法本身,而是在保证安全性的同时不牺牲性能。Rust的所有权系统虽然增加了学习曲线,但一旦掌握,就能写出既安全又高效的代码。特别是在实现广播和视图操作时,Rust的生命周期检查帮助我避免了许多潜在的内存错误。