菜单

Administrator
发布于 2026-05-19 / 2 阅读
0
0

PyTorch 的 C++ Extension 写法

pytorch 的 C++ extension 写法

发布于 2019-12-31 · Monstarrrr

2019年的最后一天了,终于填了一个早就想了解的坑。就是关于 pytorch 如何自定义一个扩展,这里主要是说 C++ 扩展。

为什么需要扩展?

python 调用 C++ 的库也是可行的啊。刚开始我也在思考这个问题,觉得没有必要。但是后来深入了解了以后发现还是有必要的。调用始终是使用的是别人的东西,但是扩展则是通过他人的帮助来完成一个属于自己的东西。

pytorch 的 C++ extension 和 python 的 c/c++ extension 其实原理差不多,本质上都是为了扩展各自的功能,当然也为了使程序运行更加有效率,差别在于 pytorch 的 C++ extension 实施步骤较 python 的 c/c++ extension 的要简化一些。

这里以实现神经网络自定义的 layer 为例:

基本流程

  1. 利用 C++ 写好自定义层的功能,主要包括前向传播和反向传播,以及 pybind11 的内容。
  2. 写好 setup.py 脚本,并利用 python 提供的 setuptools 来编译并加载 C++ 代码。
  3. 编译安装,在 python 中调用 C++ 扩展接口。

pybind11 是 python 的一个库,主要负责 python 与 C++11 之间的通信。

示例:z = 2x + y

第一步:编写头文件 test.h

/*test.h*/
#include <torch/extension.h>
#include <vector>

// forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);

// backward propagation
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);

这里包含一个重要的头文件 <torch/extension.h>,这个头文件里面包含很多重要的模块。如用于 python 和 C++11 交互的 pybind11,以及包含 Tensor 的一系列定义操作。

第二步:源文件 test.cpp

/*test.cpp*/
#include "test.h"

// part1: forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y) {
    AT_ASSERTM(x.sizes() == y.sizes());
    torch::Tensor z = torch::zeros(x.sizes());
    z = 2 * x + y;
    return z;
}

// part2: backward propagation
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput) {
    torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());
    torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());
    return {gradOutputX, gradOutputY};
}

// part3: pybind11(将 python 与 C++11 进行绑定)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &Test_forward_cpu, "Test forward");
    m.def("backward", &Test_backward_cpu, "Test backward");
}

源文件里面包含了三个部分:

  • forward 函数
  • backward 函数
  • pytorch 和 C++ 交互的部分(pybind11)

第三步:编写 setup.py

文件目录排布:

├── setup.py
├── src/
│   ├── test.h
│   └── test.cpp
from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension

# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))

# 源代码目录
source_file = glob.glob(os.path.join(working_dirs, 'src', '*.cpp'))

setup(
    name='test_cpp',  # 模块名称
    ext_modules=[
        CppExtension('test_cpp',
            sources=source_file,
            include_dirs=[include_dirs])
    ],
    cmdclass={'build_ext': BuildExtension}
)

第四步:编译安装

python setup.py install

建议将扩展安装在个人虚拟环境中。这一步包含了 build + install,执行的是先编译链接动态链接库,然后将构建好的文件以 package 的形式安装存放在当前开发环境的 package 的集中存放处。

完整内容请参考原始知乎文章:https://www.zhihu.com


评论