部分修复:https ://github.com/pytorch/pytorch/issues/394
实施细节:
Codegen 被修改为生成如下所示的代码:
static PyObject * THPVariable_svd(PyObject* self_, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"svd(Tensor input, bool some=True, bool compute_uv=True, *, TensorList[3] out=None)",
}, /*traceable=*/true);
ParsedArgs<6> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
static PyStructSequence_Field fields0[] = {
{"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
};
static PyStructSequence_Desc desc0 = {
"torch.return_types.svd_out", nullptr,
fields0, 3
};
static PyTypeObject type0;
static bool namedtuple_type_initialized0 = false;
if (!namedtuple_type_initialized0) {
PyStructSequence_InitType(&type0, &desc0);
namedtuple_type_initialized0 = true;
}
static PyStructSequence_Field fields1[] = {
{"U", ""}, {"S", ""}, {"V", ""}, {nullptr}
};
static PyStructSequence_Desc desc1 = {
"torch.return_types.svd", nullptr,
fields1, 3
};
static PyTypeObject type1;
static bool namedtuple_type_initialized1 = false;
if (!namedtuple_type_initialized1) {
PyStructSequence_InitType(&type1, &desc1);
namedtuple_type_initialized1 = true;
}
if (r.idx == 0) {
if (r.isNone(3)) {
return wrap(&type1, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2)));
} else {
auto results = r.tensorlist_n<3>(3);
return wrap(&type0, dispatch_svd(r.tensor(0), r.toBool(1), r.toBool(2), results[0], results[1], results[2]));
}
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
类型被定义为THPVariable_${op_name}
函数的静态成员,并在第一次调用函数时初始化。
当解析 中的函数原型时native_functions.yaml
,解析器将设置指定的名称,就像field_name
看到-> (Tensor t1, ...)
. 这些字段名称将是namedtuple的字段名称。命名元组的类将被命名为torch.return_types.${op_name}
。
在某些Python 2中,PyStructSequence
它不是元组的子类型,因此我们必须创建一些函数来检查对象是否是元组或namedtuple,以解决兼容性问题。
中的运算符native_functions.yaml
已更改,仅max
和svd
生成为命名元组。为这两个运算符添加了测试,以查看返回值是否按预期工作。这两个操作的文档也已更新,明确提到返回值是一个命名元组。更多操作将在后续 PR 中添加。
Windows 版本的链接器存在一些无法解决的问题PyStructSequence_UnnamedField
,并且添加了一些解决方法来处理这种情况。