pytorch 的 C++ extension 写法
发布于 2019-12-31 · Monstarrrr
2019年的最后一天了,终于填了一个早就想了解的坑。就是关于 pytorch 如何自定义一个扩展,这里主要是说 C++ 扩展。
为什么需要扩展?
python 调用 C++ 的库也是可行的啊。刚开始我也在思考这个问题,觉得没有必要。但是后来深入了解了以后发现还是有必要的。调用始终是使用的是别人的东西,但是扩展则是通过他人的帮助来完成一个属于自己的东西。
pytorch 的 C++ extension 和 python 的 c/c++ extension 其实原理差不多,本质上都是为了扩展各自的功能,当然也为了使程序运行更加有效率,差别在于 pytorch 的 C++ extension 实施步骤较 python 的 c/c++ extension 的要简化一些。
这里以实现神经网络自定义的 layer 为例:
基本流程
- 利用 C++ 写好自定义层的功能,主要包括前向传播和反向传播,以及 pybind11 的内容。
- 写好 setup.py 脚本,并利用 python 提供的 setuptools 来编译并加载 C++ 代码。
- 编译安装,在 python 中调用 C++ 扩展接口。
pybind11 是 python 的一个库,主要负责 python 与 C++11 之间的通信。
示例:z = 2x + y
第一步:编写头文件 test.h
/*test.h*/
#include <torch/extension.h>
#include <vector>
// forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& inputA, const torch::Tensor& inputB);
// backward propagation
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput);
这里包含一个重要的头文件 <torch/extension.h>,这个头文件里面包含很多重要的模块。如用于 python 和 C++11 交互的 pybind11,以及包含 Tensor 的一系列定义操作。
第二步:源文件 test.cpp
/*test.cpp*/
#include "test.h"
// part1: forward propagation
torch::Tensor Test_forward_cpu(const torch::Tensor& x, const torch::Tensor& y) {
AT_ASSERTM(x.sizes() == y.sizes());
torch::Tensor z = torch::zeros(x.sizes());
z = 2 * x + y;
return z;
}
// part2: backward propagation
std::vector<torch::Tensor> Test_backward_cpu(const torch::Tensor& gradOutput) {
torch::Tensor gradOutputX = 2 * gradOutput * torch::ones(gradOutput.sizes());
torch::Tensor gradOutputY = gradOutput * torch::ones(gradOutput.sizes());
return {gradOutputX, gradOutputY};
}
// part3: pybind11(将 python 与 C++11 进行绑定)
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &Test_forward_cpu, "Test forward");
m.def("backward", &Test_backward_cpu, "Test backward");
}
源文件里面包含了三个部分:
- forward 函数
- backward 函数
- pytorch 和 C++ 交互的部分(pybind11)
第三步:编写 setup.py
文件目录排布:
├── setup.py
├── src/
│ ├── test.h
│ └── test.cpp
from setuptools import setup
import os
import glob
from torch.utils.cpp_extension import BuildExtension, CppExtension
# 头文件目录
include_dirs = os.path.dirname(os.path.abspath(__file__))
# 源代码目录
source_file = glob.glob(os.path.join(working_dirs, 'src', '*.cpp'))
setup(
name='test_cpp', # 模块名称
ext_modules=[
CppExtension('test_cpp',
sources=source_file,
include_dirs=[include_dirs])
],
cmdclass={'build_ext': BuildExtension}
)
第四步:编译安装
python setup.py install
建议将扩展安装在个人虚拟环境中。这一步包含了 build + install,执行的是先编译链接动态链接库,然后将构建好的文件以 package 的形式安装存放在当前开发环境的 package 的集中存放处。
完整内容请参考原始知乎文章:https://www.zhihu.com