diff --git a/mindspore/ccsrc/kernel/framework_utils.cc b/mindspore/ccsrc/kernel/framework_utils.cc index aac72d28d34fa7ff8cd202e18a59d29e89aadfbb..74d10334f8e5cf4c683c1e123d7ad3c866fb1aab 100644 --- a/mindspore/ccsrc/kernel/framework_utils.cc +++ b/mindspore/ccsrc/kernel/framework_utils.cc @@ -756,10 +756,22 @@ std::vector GetReduceAttrAxis(const CNodePtr &cnode) { return {}; } std::vector axis_list; - if (axis_attr->isa()) { - (void)axis_list.emplace_back(GetValue(axis_attr)); + if (axis_attr->isa() || axis_attr->isa()) { + (void)axis_list.emplace_back(AnfUtils::GetIntValue(axis_attr)); + } else if (axis_attr->isa()) { + auto axis_vec = axis_attr->cast()->value(); + if (axis_vec.empty() || axis_vec[0]->isa() || axis_vec[0]->isa()) { + for (auto ax : axis_vec) { + auto ax_int64 = AnfUtils::GetIntValue(ax); + axis_list.push_back(ax_int64); + } + } else { + MS_LOG(EXCEPTION) << "Axis of reduce node[" << cnode->fullname_with_scope() << "] should be int32 or int64, but " + << "got " << axis_attr->ToString(); + } } else { - axis_list = GetValue>(axis_attr); + MS_LOG(EXCEPTION) << "Axis of reduce node[" << cnode->fullname_with_scope() << "] should be int32 or int64, but " + << "got " << axis_attr->ToString(); } return axis_list; }