在深度学习推理引擎的实现中,Session.Run是整个执行流程的核心入口。作为ONNX Runtime的关键执行路径,它负责将模型定义转化为实际的硬件计算。本文将深入剖析ONNX Runtime中Session.Run的完整调用链路,特别关注计算函数的执行机制和模型的动态重编译过程。
提示:本文分析的代码基于ONNX Runtime 1.14版本,主要针对ROCm后端实现,但核心逻辑同样适用于CUDA等其他计算后端。
ONNX Runtime提供了多语言接口,但最终都会汇聚到C++核心实现。从外部调用到内部实现的转换过程如下:
SessionImpl<T>::Run作为最外层的模板接口,处理基础的类型转换和错误检查。这一层主要确保接口的通用性,同时将用户传入的Value对象转换为内部使用的OrtValue。

template <typename T>
inline void SessionImpl<T>::Run(const RunOptions& run_options,
                                const char* const* input_names,
                                const Value* input_values,
                                size_t input_count,
                                const char* const* output_names,
                                Value* output_values,
                                size_t output_count) {
  // Type-safe conversion: an array of Value objects is laid out exactly like
  // an array of OrtValue pointers, so a reinterpret_cast is sufficient.
  // The static_assert guards that layout assumption at compile time.
  static_assert(sizeof(Value) == sizeof(OrtValue*),
                "Value本质上是内存中的OrtValue*数组");
  auto in_ptrs = reinterpret_cast<const OrtValue* const*>(input_values);
  auto out_ptrs = reinterpret_cast<OrtValue**>(output_values);
  // Forward to the underlying C API; any error status becomes an exception.
  ThrowOnError(GetApi().Run(this->p_, run_options, input_names, in_ptrs,
                            input_count, output_names, output_count,
                            out_ptrs));
}
OrtApis::Run函数作为C语言接口,负责将C风格参数转换为C++对象,并处理运行选项的默认值情况。这一层实现了ONNX Runtime的跨语言支持基础。

ORT_API_STATUS_IMPL(OrtApis::Run,
_Inout_ OrtSession* sess,
                    _In_opt_ const OrtRunOptions* run_options,
                    _In_reads_(input_len) const char* const* input_names,
                    _In_reads_(input_len) const OrtValue* const* input,
                    size_t input_len,
                    _In_reads_(output_names_len) const char* const* output_names,
                    size_t output_names_len,
                    _Inout_updates_all_(output_names_len) OrtValue** output) {
  // Recover the C++ session object behind the opaque C handle and wrap the
  // raw C arrays in bounds-carrying spans.
  auto session = reinterpret_cast<::onnxruntime::InferenceSession*>(sess);
  gsl::span<const char* const> input_names_span(input_names, input_len);
  gsl::span<const OrtValue* const> input_span(input, input_len);
  gsl::span<const char* const> output_name_span(output_names, output_names_len);
  gsl::span<OrtValue*> output_span(output, output_names_len);
  // run_options is optional at the C ABI boundary: fall back to
  // default-constructed options when the caller passed nullptr.
  const RunOptions default_run_options;
  const RunOptions& effective_options =
      run_options ? *run_options : default_run_options;
  Status status = session->Run(effective_options, input_names_span, input_span,
                               output_name_span, output_span);
  // Convert the internal Status into the C API's OrtStatus representation.
  return ToOrtStatus(status);
}
真正的核心逻辑位于InferenceSession::Run方法中,该方法完成了以下关键工作:
- 通过FeedsFetchesManager管理输入输出张量的映射关系
- 调用utils::ExecuteGraph执行计算图

Status InferenceSession::Run(const RunOptions& run_options,
gsl::span<const std::string> feed_names,
gsl::span<const OrtValue> feeds,
gsl::span<const std::string> output_names,
std::vector<OrtValue>* p_fetches,
const std::vector<OrtDevice>* p_fetches_device_info) {
// Profiling setup: record a start timestamp only when the session-level
// profiler is enabled.
TimePoint tp;
if (session_profiler_.IsEnabled()) {
tp = session_profiler_.Start();
}
// Graph-annotation-id handling (used for CUDA/ROCm graph capture): the id is
// read from the run options as a string; an unparseable value is rejected
// with INVALID_ARGUMENT. An absent option leaves the id at 0.
int graph_annotation_id = 0;
const std::string& graph_annotation_str =
run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigCudaGraphAnnotation, "");
if (!graph_annotation_str.empty()) {
if (!TryParseStringWithClassicLocale<int>(graph_annotation_str, graph_annotation_id)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Failed to parse the cuda graph annotation id: ",
graph_annotation_str);
}
}
// Thread-pool spinning control: spinning is only toggled between runs when
// per-session thread pools are used, the stop-spinning option is set, and no
// graph has already been captured for this annotation id. Null pools mean
// "do not touch spinning".
const bool control_spinning = use_per_session_threads_ &&
force_spinning_stop_between_runs_ &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id);
auto* intra_tp = (control_spinning) ? thread_pool_.get() : nullptr;
auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr;
ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_);
// Fast path: if a CUDA/ROCm graph was already captured for this annotation
// id, replay it directly instead of executing the graph node by node.
if (cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) {
LOGS(*session_logger_, INFO) << "Replaying the captured "
<< cached_execution_provider_for_graph_replay_.Type()
<< " CUDA Graph for this model with tag: " << run_options.run_tag
<< " with graph annotation id: " << graph_annotation_id;
ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id));
} else {
// Normal execution path.
// ... (details elided in this excerpt; see below)
}
return Status::OK();
}
在执行实际计算前,ONNX Runtime会进行一系列准备工作:
// 检查会话是否已初始化
// Pre-run preparation excerpt: initialization check, input/output validation,
// optional arena shrinkage, feeds/fetches mapping, and per-EP OnRunStart.
if (!is_inited_) {
LOGS(*session_logger_, ERROR) << "Session was not initialized";
return Status(common::ONNXRUNTIME, common::FAIL, "Session not initialized.");
}
// Validate that feed/fetch names and values are consistent with the model.
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds));
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches));
// Memory-arena shrinkage: parse the run option listing which arenas to
// shrink after the run. NOTE(review): `arenas_to_shrink` is declared outside
// this excerpt.
const std::string& shrink_memory_arenas =
run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "");
if (!shrink_memory_arenas.empty()) {
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateAndParseShrinkArenaString(shrink_memory_arenas, arenas_to_shrink));
}
// Build the feeds/fetches manager that maps feed/fetch names to internal
// OrtValue indices for this run.
FeedsFetchesInfo info(feed_names, output_names, session_state_->GetOrtValueNameIdxMap());
FeedsFetchesManager feeds_fetches_manager{std::move(info)};
// Notify each execution provider that a run is starting. Providers that
// started successfully are recorded so OnRunEnd can be called later.
// NOTE(review): `exec_providers_to_stop` is declared outside this excerpt.
for (auto& xp : execution_providers_) {
auto start_func = [&xp, &exec_providers_to_stop, &run_options]() {
auto status = xp->OnRunStart(run_options);
if (status.IsOK())
exec_providers_to_stop.push_back(xp.get());
return status;
};
ORT_CHECK_AND_SET_RETVAL(start_func());
}
utils::ExecuteGraph是实际执行计算图的核心函数,它完成了以下工作:
- 获取会话的计算信息并调用其中注册的compute_func

Status ExecuteGraph(const SessionState& session_state,
const FeedsFetchesManager& feeds_fetches_manager,
gsl::span<const OrtValue> feeds,
std::vector<OrtValue>& fetches,
ExecutionMode execution_mode,
const RunOptions& run_options,
#ifdef ORT_ENABLE_STREAM
const DeviceStreamCollection* device_stream_collection,
#endif
const logging::Logger& logger) {
// Fetch the compute info registered for this session's graph at compile time.
const auto& compute_info = session_state.GetComputeInfo();
// Invoke the registered compute function. NOTE(review): the
// device_stream_collection argument exists only when ORT_ENABLE_STREAM is
// defined, so the #ifdef below changes the arity of the
// GetDeviceStreamCollection call — confirm both configurations compile.
return compute_info.compute_func(feeds, fetches,
session_state.GetMutableInitializedTensors(),
session_state.GetDeviceStreamCollection(
#ifdef ORT_ENABLE_STREAM
device_stream_collection
#endif
),
session_state.GetMutablePatternPlanner(),
session_state.GetMutableKernelCreateInfoMap(),
execution_mode, run_options, logger);
}
在模型编译阶段,ONNX Runtime会为每个计算图注册一个compute_func。对于ROCm后端,这个函数通常是一个lambda表达式,封装了MIGraphX的执行逻辑。
// 在编译阶段注册compute_func
// Register the compute function at model-compile time. The lambda captures
// `this` so it can reach the provider's compiled program and settings when
// invoked later by ExecuteGraph.
compute_info_->compute_func = [this](const std::vector<OrtValue>& inputs,
std::vector<OrtValue>& outputs,
const std::unordered_map<std::string, OrtValue>& initializers,
const DeviceStreamCollection* device_streams,
const PatternPlanner* pattern_planner,
const KernelCreateInfoMap* kernel_create_info_map,
ExecutionMode execution_mode,
const RunOptions& run_options,
const logging::Logger& logger) {
// Actual computation logic.
// ... (elided in this excerpt)
};
当输入张量的形状与编译时形状不匹配时,ONNX Runtime会触发模型重编译:
// 检查输入形状是否匹配
// Compare the current first input's shape against the shape the program was
// compiled with; a mismatch (or a missing compiled program) forces a
// recompile. NOTE(review): only inputs[0] is checked here — confirm that a
// shape change in any other input also triggers recompilation.
bool input_shape_match = true;
if (prog_ == nullptr || !program_shape_.empty()) {
if (program_shape_.empty() || program_shape_ != inputs[0].Get<Tensor>().Shape()) {
input_shape_match = false;
}
}
// Recompilation path.
if (!input_shape_match) {
// 1. Load a precompiled model if a cache path is configured, otherwise
//    re-parse the original ONNX model from the in-memory buffer.
if (!compiled_model_path_.empty()) {
load_precompiled_model(compiled_model_path_);
} else {
parse_onnx_buffer(model_data_.data(), model_data_.size());
}
// 2. Apply quantization (INT8 takes precedence over FP16).
if (enable_int8_) {
apply_int8_quantization();
} else if (enable_fp16_) {
apply_fp16_quantization();
}
// 3. Compile the model. NOTE(review): `prog_` was compared to nullptr above
//    (pointer semantics) but is used here with `.` (value semantics), and
//    `t`/`co` (target / compile options) are not declared in this excerpt —
//    verify against the actual MIGraphX EP source.
prog_.compile(t, co);
// 4. Persist the compiled program when a cache path is configured. Note the
//    same non-empty path condition as step 1 — here it gates saving.
if (!compiled_model_path_.empty()) {
save_compiled_model(compiled_model_path_);
}
}
对于ROCm后端,计算函数的执行会调用MIGraphX的运行时:
// 将ORT输入转换为MIGraphX argument
std::vector<migraphx::argument> migraphx_inputs;
for (const auto& input : inputs) {
const Tensor& tensor = input.Get<Tensor>();
migraphx_inputs.push_back(create_migraphx_argument(tensor));
}
// 执行计算图
auto prog_outputs = prog_.run_async(migraphx_inputs,
static_cast<hipStream_t>(rocm_stream));
// 将输出拷贝回ORT张量
for (size_t i = 0; i < outputs.size(); ++i) {
Tensor* output_tensor = outputs[i].GetMutable<Tensor>();
copy_migraphx_output_to_ort_tensor(prog_outputs[i], output_tensor);
}
ONNX Runtime支持CUDA/ROCm图形捕获,可以显著减少内核启动开销:
// 检查是否启用图形捕获
if (cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) {
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
}
ONNX Runtime提供了精细的线程池控制,以优化多线程环境下的性能:
// 线程池旋转控制
// RAII switch: bumps the active-run refcount and controls intra-/inter-op
// thread-pool spinning for the duration of this run (restored on scope exit).
ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_);
// Execution-provider synchronization. Note the double negative: the option is
// "Disable...SynchronizeExecutionProviders", so a value of "0" (the default)
// means synchronization stays ON.
bool synchronize_execution_providers =
run_options.config_options.GetConfigOrDefault(
kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0";
当遇到输入输出形状不匹配时,可以:

- 设置 ORT_LOG_LEVEL=VERBOSE 获取详细执行信息
- 通过 RunOptions 启用相关的调试与性能分析配置

在实际使用中,我发现模型重编译的开销往往被低估。特别是在处理动态形状输入时,频繁的重编译会严重影响性能。一个实用的优化策略是为常见的输入形状预先编译多个版本,运行时根据实际输入形状选择最接近的预编译版本。