From 6ce0aa557d45a6555e45bc48979ad61e60a24608 Mon Sep 17 00:00:00 2001 From: "7347157+joylvliang@user.noreply.gitee.com" Date: Thu, 24 Mar 2022 20:31:24 +0800 Subject: [PATCH] chenge_singel_to_tuple_for_backward_hook --- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 15 ++++++++++++++- mindspore/ccsrc/pybind_api/ir/primitive_py.h | 1 + .../pynative/hook/test_pynative_backward_hook.py | 2 +- .../pynative/hook/test_pynative_forward_hook.py | 4 ++-- tests/st/pynative/test_pynative_hook_grad.py | 2 +- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 121f53da4177..248f91c7cb5a 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -252,6 +252,19 @@ void PrimitivePy::RemoveBackwardHookFn(const int &key) { } } +py::object PrimitivePy::UnpackRetValueOfCellHook(const py::object &grad_out) const { + if (!py::isinstance(grad_out)) { + hook_grad_.clear(); + MS_EXCEPTION(TypeError) << "The return gradient of cell backward hook function should be a tuple!"; + } + auto out_tuple = py::cast(grad_out); + if (out_tuple.size() == 1) { + // The input number of current cell is 1. + return out_tuple[0]; + } + return grad_out; +} + void PrimitivePy::CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out, const py::object &code_obj, const py::object &co_name) const { if (py::isinstance(expected_grad_out)) { @@ -345,7 +358,7 @@ BaseRef PrimitivePy::RunCellHookFunction(const py::tuple &py_args) const { py::tuple hook_fn_args = ConstructCellHookFnArgs(cell_id, iter->second, grad_output); py::object ret = elem.second(*hook_fn_args); if (!py::isinstance(ret)) { - grad_output = ret; + grad_output = UnpackRetValueOfCellHook(ret); } CheckHookConsistency(grad_output, py_args[args_size - 1], code_obj, co_name); } diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 09856c0b3671..733bc1959c28 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -78,6 +78,7 @@ class PrimitivePy : public Primitive { private: py::function GetComputeFunction() const; + py::object UnpackRetValueOfCellHook(const py::object &grad_out) const; void CheckHookConsistency(const py::object &grad_out, const py::object &expected_grad_out, const py::object &code_obj, const py::object &co_name) const; py::object python_obj_; diff --git a/tests/st/pynative/hook/test_pynative_backward_hook.py b/tests/st/pynative/hook/test_pynative_backward_hook.py index 108226ed062f..16fe113a35fa 100644 --- a/tests/st/pynative/hook/test_pynative_backward_hook.py +++ b/tests/st/pynative/hook/test_pynative_backward_hook.py @@ -54,7 +54,7 @@ def backward_hook_fn3(cell_id, grad_inp, grad_outp): def backward_hook_fn4(cell_id, grad_inp, grad_outp): - return Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 10) + return (Tensor(np.ones([2, 2, 2, 2]).astype(np.float32) * 10),) class Net(nn.Cell): diff --git a/tests/st/pynative/hook/test_pynative_forward_hook.py b/tests/st/pynative/hook/test_pynative_forward_hook.py index f8d95ce1708d..9253adf3dbd2 100644 --- a/tests/st/pynative/hook/test_pynative_forward_hook.py +++ b/tests/st/pynative/hook/test_pynative_forward_hook.py @@ -70,12 +70,12 @@ def forward_hook_fn_with_ms_func(cell_id, inp, outp): def backward_hook_fn(cell_id, grad_inp, grad_outp): print("Enter backward hook function.") - return grad_outp[0] + return grad_outp def backward_hook_fn_inner(cell_id, grad_inp, grad_outp): print("Enter backward hook function inner.") - return grad_outp[0] + return grad_outp class SingleNet(nn.Cell): diff --git a/tests/st/pynative/test_pynative_hook_grad.py b/tests/st/pynative/test_pynative_hook_grad.py index d240ccf12268..768a531b6832 100644 --- a/tests/st/pynative/test_pynative_hook_grad.py +++ b/tests/st/pynative/test_pynative_hook_grad.py @@ -48,7 +48,7 @@ class HookBase(MetaFactory): mul = P.Mul() grad = grad_output[0] output = mul(grad, y) - return output + return (output,) class FinalNet(nn.Cell, HookBase): def __init__(self): -- Gitee