在深度学习框架和数值计算领域,张量(Tensor)是最基础的数据结构。PyTorch和NumPy等流行库都围绕张量构建了丰富的功能集。今天我将分享如何在Rust中从零开始构建一个张量库,重点解析核心结构设计和索引操作的实现细节。这个实现会借鉴生产级框架(如Candle)的设计理念,但保持足够简单以便理解底层原理。
张量本质上是多维数组的数学抽象,但在实现时需要仔细考虑以下问题:
生产级框架通常将张量分解为两个部分:
这种分离带来三个关键优势:
以下是基础结构定义:
rust复制#[derive(Debug, Clone, PartialEq)]
struct TensorShape {
shape: Vec<usize>, // 各维度大小
}
impl TensorShape {
fn size(&self) -> usize {
self.shape.iter().product() // 计算元素总数
}
}
#[derive(Debug, Clone, PartialEq)]
struct TensorStorage<T> {
data: Vec<T>, // 连续内存存储
}
#[derive(Debug, Clone, PartialEq)]
struct Tensor<T> {
shape: TensorShape,
storage: TensorStorage<T>,
}
选择Vec<T>作为底层存储的原因:
内存布局决定多维索引如何映射到线性内存。以形状[2,2,2]的张量为例:
行优先(C风格)
code复制0 -> [0,0,0]
1 -> [0,0,1]
2 -> [0,1,0]
...
15 -> [1,1,1]
列优先(Fortran风格)
code复制0 -> [0,0,0]
1 -> [1,0,0]
2 -> [0,1,0]
...
15 -> [1,1,1]
现代深度学习框架普遍采用行优先布局,因为:
我们需要为张量实现zeros构造函数。这里利用num-traits crate提供的Zero trait:
rust复制use num_traits::Zero;
impl<T: Zero + Clone> Tensor<T> {
fn zeros(shape: Vec<usize>) -> Self {
let shape = TensorShape { shape };
let storage = TensorStorage::<T>::zeros(shape.size());
Tensor { shape, storage }
}
}
impl<T: Zero + Clone> TensorStorage<T> {
fn zeros(size: usize) -> Self {
TensorStorage {
data: vec![T::zero(); size],
}
}
}
注意:这里要求泛型参数T实现Zero trait,保证了类型安全的零值初始化。对于自定义类型,只需实现Zero trait即可获得相同能力。
Candle是Hugging Face开发的Rust张量库,其核心结构如下:
rust复制pub struct Tensor_ {
storage: Arc<RwLock<Storage>>, // 线程安全存储
layout: Layout, // 包含形状和步幅
// 其他字段省略
}
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage), // GPU支持
Metal(MetalStorage),
}
pub enum CpuStorage {
U8(Vec<u8>),
F32(Vec<f32>), // 各种数据类型
// ...
}
关键设计差异:
Arc<RwLock<>>包装存储将多维索引转换为线性索引的过程称为"展平"。计算公式为:
code复制linear_index = index[-1] + index[-2]*shape[-1] + index[-3]*shape[-1]*shape[-2] + ...
等效于索引向量与步幅(stride)向量的点积:
code复制strides = [..., shape[-2]*shape[-1], shape[-1], 1]
实现代码:
rust复制impl TensorShape {
fn ravel_index(&self, indices: &[usize]) -> usize {
assert_eq!(indices.len(), self.shape.len());
indices.iter()
.zip(self.shape.iter())
.rev()
.scan(1, |stride, (&idx, &dim_size)| {
let result = idx * *stride;
*stride *= dim_size;
Some(result)
})
.sum()
}
}
算法复杂度:O(n)其中n是维度数。实际应用中,步幅通常预计算并缓存。
将线性索引转换回多维索引的过程更为复杂。以二维矩阵为例:
code复制linear_index = i * width + j
=> j = linear_index % width
i = linear_index / width
通用实现:
rust复制impl TensorShape {
fn unravel_index(&self, index: usize) -> Vec<usize> {
let mut indices = vec![0; self.shape.len()];
let mut remaining = index;
for (i, &dim_size) in self.shape.iter().enumerate().rev() {
indices[i] = remaining % dim_size;
remaining /= dim_size;
}
indices
}
}
为标准索引语法提供支持:
rust复制use std::ops::{Index, IndexMut};
impl<T> Index<usize> for TensorStorage<T> {
type Output = T;
fn index(&self, index: usize) -> &Self::Output {
&self.data[index]
}
}
impl<T> Index<&[usize]> for Tensor<T> {
type Output = T;
fn index(&self, indices: &[usize]) -> &Self::Output {
&self.storage[self.shape.ravel_index(indices)]
}
}
现在可以这样使用:
rust复制let t = Tensor::zeros(vec![2, 3]);
println!("{}", t[&[1, 2]]); // 访问(1,2)位置元素
Candle等框架会预计算并缓存步幅信息:
rust复制pub struct Layout {
shape: Shape,
stride: Vec<usize>, // 预计算步幅
start_offset: usize, // 内存起始偏移
}
在卷积运算中的典型应用:
rust复制let (b_sz, c, h, w) = layout.shape().dims4()?;
let mut src_index = layout.start_offset();
for b_idx in 0..b_sz {
src_index += b_idx * stride[0]; // 批量维度步幅
for c_idx in 0..c {
src_index += c_idx * stride[1]; // 通道维度步幅
// ...
}
}
这种模式的优势:
生产级实现还会考虑:
例如,Candle的CPU后端会根据数据类型选择最优的内存对齐方式:
rust复制pub enum CpuStorage {
F32(Vec<f32>), // 默认32字节对齐
F64(AlignedVec<f64>), // 特殊对齐处理
// ...
}
完善的测试应覆盖:
示例测试用例:
rust复制#[test]
fn test_ravel_unravel() {
let shape = TensorShape { shape: vec![2, 3, 4] };
let indices = vec![1, 2, 3];
let linear = shape.ravel_index(&indices);
assert_eq!(shape.unravel_index(linear), indices);
}
#[test]
#[should_panic]
fn test_invalid_index() {
let t = Tensor::zeros(vec![2, 2]);
let _ = t[&[3, 0]]; // 应panic
}
使用Rust的criterion库进行性能分析:
rust复制fn indexing_bench(c: &mut Criterion) {
let t = Tensor::zeros(vec![100, 100, 100]);
c.bench_function("ravel_index", |b| {
b.iter(|| t.shape.ravel_index(&[99, 99, 99]))
});
}
重点关注:
当索引行为不符合预期时:
rust复制println!("Strides: {:?}", self.compute_strides());
使用指针运算验证内存布局:
rust复制let ptr = t.storage.data.as_ptr();
let offset = |i, j| { /* 计算偏移量 */ };
assert_eq!(unsafe { *ptr.add(offset(i,j)) }, t[&[i,j]]);
注意:unsafe代码仅用于调试,正式实现应避免
热循环优化:将形状检查移出循环
rust复制// 错误做法:循环内检查
for idx in indices {
assert!(idx < dim);
}
// 正确做法:预先验证
assert!(indices.iter().all(|&idx| idx < dim));
缓存友好访问:优化遍历顺序
rust复制// 行优先存储应按最后维度连续访问
for i in 0..rows {
for j in 0..cols { // 最内层循环遍历连续内存
// ...
}
}
批量操作:减少边界检查
rust复制// 使用get_unchecked在性能关键路径
unsafe {
data.get_unchecked(ravel_index(...))
}
现在我们已经实现了张量核心结构和基本索引操作。接下来的开发方向:
在实现这些高级特性时,当前设计的优势将显现:
建议尝试扩展当前代码:
ones构造函数reshape方法t[1..3, 2]理解这些基础原理后,你将能更好地使用甚至贡献于生产级张量库。Rust的类型系统和所有权模型为构建安全高效的数值计算基础库提供了独特优势。