From 4f7c9cb27f9a34f52eb753354ccc96af83f54fa1 Mon Sep 17 00:00:00 2001 From: huangbingjian Date: Thu, 10 Mar 2022 20:09:32 +0800 Subject: [PATCH] Support user-defined classes through ms_class decorators. --- .../optimizer/irpass/symbol_resolver.cc | 90 ++-- .../optimizer/irpass/symbol_resolver.h | 1 + .../pipeline/jit/parse/data_converter.cc | 10 + .../ccsrc/pipeline/jit/parse/parse_base.h | 2 + mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 36 +- mindspore/ccsrc/pipeline/jit/parse/resolve.h | 19 +- mindspore/ccsrc/pybind_api/export_flags.cc | 1 + mindspore/ccsrc/pybind_api/export_flags.h | 1 + mindspore/ccsrc/utils/convert_utils_py.cc | 6 + .../mindspore/_extends/parse/__init__.py | 6 +- .../python/mindspore/_extends/parse/parser.py | 28 +- mindspore/python/mindspore/common/__init__.py | 4 +- mindspore/python/mindspore/common/api.py | 55 ++- .../ut/python/fallback/test_graph_fallback.py | 111 ----- .../fallback/test_graph_fallback_class.py | 398 ++++++++++++++++++ 15 files changed, 600 insertions(+), 168 deletions(-) create mode 100644 tests/ut/python/fallback/test_graph_fallback_class.py diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc index 36a6d2751309..14e88df1695b 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.cc @@ -23,58 +23,63 @@ namespace mindspore { namespace opt { namespace irpass { -// {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} -// {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} -// {prim::kPrimGetAttr, namespace, attr} -// {prim::kPrimGetAttr, bool, attr} +// {prim::kPrimGetAttr, object, attr} // {prim::kPrimResolve, namespace, symbol} AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) { - PatternNode getattr_operand, ns_node, sym_node, attr_node, bool_node; - auto GetAttrResolveLambda = [&node, &getattr_operand, &attr_node, &optimizer]() -> AnfNodePtr { - auto getattr_operand_node = getattr_operand.GetNode(node); - auto attr = attr_node.GetNode(node); + PatternNode object, attr, ns_node, sym_node; + auto GetAttrLambda = [&node, &object, &attr, &optimizer]() -> AnfNodePtr { + auto object_node = object.GetNode(node); + auto attr_node = attr.GetNode(node); // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} - if (IsPrimitiveCNode(getattr_operand_node, prim::kPrimResolve)) { - auto [name_space, symbol] = parse::GetNamespaceAndSymbol(getattr_operand_node); + if (IsPrimitiveCNode(object_node, prim::kPrimResolve)) { + auto [name_space, symbol] = parse::GetNamespaceAndSymbol(object_node); auto module_name = name_space->module(); constexpr std::string_view parse_super_name = "namespace"; if (module_name.find(parse::RESOLVE_NAMESPACE_NAME_CLASS_MEMBER) != std::string::npos && symbol->symbol() != parse_super_name) { - auto obj = parse::GetSymbolObject(name_space, symbol, node); - return parse::ResolveCellWithAttr(optimizer->manager(), obj, getattr_operand_node, attr); + auto symbol_obj = parse::GetSymbolObject(name_space, symbol, node); + return parse::ResolveCellWithAttr(optimizer->manager(), symbol_obj, object_node, attr_node); } } // {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} - auto operand_cnode = getattr_operand_node->cast(); - constexpr size_t getitem_inputs_size = 3; - if (operand_cnode != nullptr && operand_cnode->size() == 
getitem_inputs_size) { - constexpr auto prim_index = 0; + if (parse::IsGetItemCNode(object_node)) { + auto getitem_cnode = object_node->cast(); constexpr auto resolve_index = 1; constexpr auto index_index = 2; - auto prim_node = operand_cnode->input(prim_index); - auto resolve_node = operand_cnode->input(resolve_index); - auto index_node = operand_cnode->input(index_index); - if (!parse::IsResolveNodeWithGetItem(prim_node) || !IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { - return nullptr; + auto resolve_node = getitem_cnode->input(resolve_index); + auto index_node = getitem_cnode->input(index_index); + if (IsPrimitiveCNode(resolve_node, prim::kPrimResolve)) { + auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); + auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); + if (py::isinstance(obj) || py::isinstance(obj)) { + return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr_node, getitem_cnode); + } + return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr_node); } - auto [name_space, symbol] = parse::GetNamespaceAndSymbol(resolve_node); - auto obj = parse::GetObjectFromSequence(name_space, symbol, resolve_node, index_node); - if (py::isinstance(obj) || py::isinstance(obj)) { - return parse::ResolveSequenceWithAttr(optimizer->manager(), obj, resolve_node, attr, operand_cnode); - } - return parse::ResolveCellWithAttr(optimizer->manager(), obj, resolve_node, attr); } - return nullptr; - }; - auto GetAttrLambda = [&node, &ns_node, &attr_node, &optimizer]() -> AnfNodePtr { - auto name_space = GetValueNode(ns_node.GetNode(node)); - auto str = GetValue(GetValueNode(attr_node.GetNode(node))); - parse::SymbolPtr symbol = std::make_shared(str); - auto manager = optimizer->manager(); - return parse::ResolveSymbol(manager, name_space, symbol, node); + // {prim::kPrimGetAttr, namespace, attr} + if (IsValueNode(object_node)) { + auto name_space = GetValueNode(object_node); + auto attr_str = GetValue(GetValueNode(attr_node)); + parse::SymbolPtr symbol = std::make_shared(attr_str); + return parse::ResolveSymbol(optimizer->manager(), name_space, symbol, node); + } + + // {prim::kPrimGetAttr, MsClassObject, attr} + if (IsValueNode(object_node)) { + auto ms_class = GetValueNode(object_node); + auto attr_str = GetValue(GetValueNode(attr_node)); + return parse::ResolveMsClassWithAttr(optimizer->manager(), ms_class, attr_str, node); + } + + // {prim::kPrimGetAttr, bool, attr} + if (IsValueNode(object_node)) { + return object_node; + } + return nullptr; }; auto ResolveLambda = [&node, &ns_node, &sym_node, &optimizer]() -> AnfNodePtr { @@ -84,18 +89,9 @@ AnfNodePtr Resolver::operator()(const OptimizerPtr &optimizer, const AnfNodePtr return parse::ResolveSymbol(manager, name_space, symbol, node); }; - // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} - // {prim::kPrimGetAttr, {getitem, {prim::kPrimResolve, namespace, symbol}, index}, attr} - MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, getattr_operand, attr_node), GetAttrResolveLambda, - attr_node.CheckFunc(IsValueNode, node)); - // {prim::kPrimGetAttr, namespace, attr} - MATCH_REPLACE_LAMBDA_IF( - node, PPrimitive(prim::kPrimGetAttr, ns_node, attr_node), GetAttrLambda, - ns_node.CheckFunc(IsValueNode, node) && attr_node.CheckFunc(IsValueNode, node)); - // {prim::kPrimGetAttr, bool, attr} - MATCH_REPLACE_IF( - node, PPrimitive(prim::kPrimGetAttr, bool_node, attr_node), bool_node, - bool_node.CheckFunc(IsValueNode, 
node) && attr_node.CheckFunc(IsValueNode, node)); + // {prim::kPrimGetAttr, object, attr} + MATCH_REPLACE_LAMBDA_IF(node, PPrimitive(prim::kPrimGetAttr, object, attr), GetAttrLambda, + attr.CheckFunc(IsValueNode, node)); // {prim::kPrimResolve, namespace, symbol} MATCH_REPLACE_LAMBDA_IF( node, PPrimitive(prim::kPrimResolve, ns_node, sym_node), ResolveLambda, diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h index 3919effc7579..2500cab0ae85 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/symbol_resolver.h @@ -40,6 +40,7 @@ namespace irpass { // {prim::kPrimGetAttr, {prim::kPrimTupleGetItem, {prim::kPrimResolve, namespace, symbol}, index}, attr} // {prim::kPrimGetAttr, {prim::kPrimResolve, namespace, symbol}, attr} // {prim::kPrimGetAttr, namespace, attr} +// {prim::kPrimGetAttr, MsClassObject, attr} // {prim::kPrimGetAttr, bool, attr} // {prim::kPrimResolve, namespace, symbol} class Resolver : public OptimizerCaller { diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index 8e40763df906..12b88c2cae43 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -253,6 +253,15 @@ ValuePtr ConvertDataClass(const py::object &obj) { return converted; } +ValuePtr ConvertMsClass(const py::object &obj) { + MS_LOG(DEBUG) << "Converting ms class"; + // Convert class instance decorated with ms_class. + py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); + py::object name = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MS_CLASS_NAME, obj); + auto cls_name = py::cast(name); + return std::make_shared(obj, cls_name); +} + ValuePtr ConvertPrimitive(const py::object &obj, bool use_signature = false) { MS_LOG(DEBUG) << "Converting primitive object" << use_signature; @@ -502,6 +511,7 @@ static const std::vector &GetDataConverters() { std::make_shared>(kEllipsis), std::make_shared>(ConvertModuleNameSpace), std::make_shared(PYTHON_DATACLASS_FIELDS, ConvertDataClass), + std::make_shared(PYTHON_MS_CLASS, ConvertMsClass), std::make_shared>(ObjCast), std::make_shared>(ObjCast), std::make_shared>(ObjCast), diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h index bd818d188269..28d231eeb770 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/jit/parse/parse_base.h @@ -67,6 +67,8 @@ const char PYTHON_MOD_CREATE_INSTANCE[] = "create_instance"; const char PYTHON_MOD_IS_SUPPORTED_CREATE_INSTANCE_TYPE[] = "is_supported_create_instance_type"; const char PYTHON_MOD_GET_DATACLASS_ATTRS[] = "get_dataclass_attributes"; const char PYTHON_MOD_GET_DATACLASS_METHODS[] = "get_dataclass_methods"; +const char PYTHON_MOD_GET_MS_CLASS_NAME[] = "get_ms_class_name"; +const char PYTHON_MOD_GET_MS_CLASS_ATTR[] = "get_ms_class_attr"; const char PYTHON_MOD_GET_MODULE_NAMESPACE[] = "get_module_namespace"; const char PYTHON_MOD_GET_ATTR_NAMESPACE_SYMBOL[] = "get_class_attr_namespace_symbol"; const char PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL[] = "get_class_member_namespace_symbol"; diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index d4e29df4a79c..a57670c9924f 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -307,7 
+307,7 @@ AnfNodePtr ResolveObjectAndAddToManager(const FuncGraphManagerPtr &manager, cons
   AnfNodePtr resolved_node = nullptr;
   bool success = ResolveObjectToNode(node->func_graph(), obj, &resolved_node);
   if (!success) {
-    MS_LOG(EXCEPTION) << "Parse Resolve covert failed NodeInfo.";
+    MS_LOG(EXCEPTION) << "Parse Resolve convert failed.";
   }
   if (IsValueNode<FuncGraph>(resolved_node)) {
     auto new_fg = GetValueNode<FuncGraphPtr>(resolved_node);
@@ -465,6 +465,40 @@ bool IsResolveNodeWithGetItem(const AnfNodePtr &node) {
   return false;
 }
 
+bool IsGetItemCNode(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  constexpr size_t getitem_inputs_size = 3;
+  if (cnode->size() != getitem_inputs_size) {
+    return false;
+  }
+  constexpr auto prim_index = 0;
+  return IsResolveNodeWithGetItem(cnode->input(prim_index));
+}
+
+AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
+                                  const std::string &attr, const AnfNodePtr &node) {
+  // Get attribute or method from ms_class obj.
+  MS_EXCEPTION_IF_NULL(manager);
+  MS_EXCEPTION_IF_NULL(node);
+  MS_LOG(DEBUG) << "Resolve ms_class obj (" << ms_class->name() << ") with attr " << attr << ".";
+  TraceGuard trace_guard(std::make_shared<TraceResolve>(node->debug_info()));
+
+  py::object cls_obj = ms_class->obj();
+  if (!py::hasattr(cls_obj, attr.c_str())) {
+    MS_LOG(EXCEPTION) << ms_class->name() << " has no attribute: " << attr << ".";
+  }
+
+  const std::string fn = PYTHON_MOD_GET_MS_CLASS_ATTR;
+  const std::string module = "mindspore._extends.parse.parser";
+  py::object attr_obj = python_adapter::GetPyFn(module, fn)(cls_obj, attr);
+  AnfNodePtr res_node = ResolveObjectAndAddToManager(manager, attr_obj, node);
+  TraceManager::ClearParseOrResolveDebugInfo();
+  return res_node;
+}
+
 namespace {
 opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) {
   // For resolve and getattr primitive.
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.h b/mindspore/ccsrc/pipeline/jit/parse/resolve.h
index f6f8cd371bec..122c6724cc3f 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.h
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.h
@@ -131,6 +131,18 @@ class InterpretedObject final : public PyObjectWrapper {
 };
 using InterpretedObjectPtr = std::shared_ptr<InterpretedObject>;
 
+class MsClassObject final : public PyObjectWrapper {
+ public:
+  explicit MsClassObject(const py::object &obj, const std::string &name = "ms class")
+      : PyObjectWrapper(obj, "MsClassObject: \'" + name + "\'") {}
+  ~MsClassObject() override = default;
+  MS_DECLARE_PARENT(MsClassObject, PyObjectWrapper);
+  abstract::AbstractBasePtr ToAbstract() override {
+    return std::make_shared<abstract::AbstractScalar>(shared_from_base<MsClassObject>(), std::make_shared());
+  }
+};
+using MsClassObjectPtr = std::shared_ptr<MsClassObject>;
+
 // ClassObject class wrappers dataclass
 class ClassObject final : public PyObjectWrapper {
  public:
@@ -168,8 +180,11 @@ AnfNodePtr ResolveCellWithAttr(const FuncGraphManagerPtr &manager, const py::obj
 AnfNodePtr ResolveSequenceWithAttr(const FuncGraphManagerPtr &manager, const py::object &obj,
                                    const AnfNodePtr &resolve_node, const AnfNodePtr &attr,
                                    const CNodePtr &operand_cnode);
-// Check if node is resolve node with getitem.
-bool IsResolveNodeWithGetItem(const AnfNodePtr &node);
+AnfNodePtr ResolveMsClassWithAttr(const FuncGraphManagerPtr &manager, const MsClassObjectPtr &ms_class,
+                                  const std::string &attr, const AnfNodePtr &node);
+
+// Check if node is cnode with getitem.
+bool IsGetItemCNode(const AnfNodePtr &node); // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); diff --git a/mindspore/ccsrc/pybind_api/export_flags.cc b/mindspore/ccsrc/pybind_api/export_flags.cc index 909951d47004..350f9df29aa4 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.cc +++ b/mindspore/ccsrc/pybind_api/export_flags.cc @@ -19,5 +19,6 @@ namespace mindspore { const char PYTHON_PRIMITIVE_FLAG[] = "__primitive_flag__"; const char PYTHON_CELL_AS_LIST[] = "__cell_as_list__"; const char PYTHON_DATACLASS_FIELDS[] = "__dataclass_fields__"; +const char PYTHON_MS_CLASS[] = "__ms_class__"; const char PYTHON_CLASS_MEMBER_NAMESPACE[] = "__class_member_namespace__"; } // namespace mindspore diff --git a/mindspore/ccsrc/pybind_api/export_flags.h b/mindspore/ccsrc/pybind_api/export_flags.h index 56e0a87ead66..117e22df1a8d 100644 --- a/mindspore/ccsrc/pybind_api/export_flags.h +++ b/mindspore/ccsrc/pybind_api/export_flags.h @@ -22,6 +22,7 @@ namespace mindspore { extern const char PYTHON_PRIMITIVE_FLAG[]; extern const char PYTHON_CELL_AS_LIST[]; extern const char PYTHON_DATACLASS_FIELDS[]; +extern const char PYTHON_MS_CLASS[]; extern const char PYTHON_CLASS_MEMBER_NAMESPACE[]; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 868c3554c168..c9951cdbdbd7 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -220,6 +220,12 @@ static ValueNameToConverterVector value_name_to_converter = { auto class_type = value->cast(); return class_type->obj(); }}, + // parse::MsClassObject + {parse::MsClassObject::kTypeId, + [](const ValuePtr &value) -> py::object { + auto ms_class_object = value->cast(); + return ms_class_object->obj(); + }}, // parse::InterpretedObject {parse::InterpretedObject::kTypeId, [](const ValuePtr &value) -> py::object { diff --git a/mindspore/python/mindspore/_extends/parse/__init__.py b/mindspore/python/mindspore/_extends/parse/__init__.py index e5494b7bdf80..99b0a40e7fe6 100644 --- a/mindspore/python/mindspore/_extends/parse/__init__.py +++ b/mindspore/python/mindspore/_extends/parse/__init__.py @@ -23,7 +23,8 @@ from .parser import (Parser, create_instance, is_supported_create_instance_type, get_args, get_args_default_values, get_ast_namespace_symbol, get_operation_symbol, get_operation_namespace_symbol, get_parse_method_of_class, get_scope_name, eval_script, expand_expr_statement, is_class_member, parse_cb, resolve_symbol, convert_to_ms_tensor, - get_object_description, get_class_attr_namespace_symbol) + get_object_description, get_class_attr_namespace_symbol, get_ms_class_name, + get_ms_class_attr) __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', 'get_object_key', 'get_class_instance_type', 'is_class_member', 'get_ast_type', 'get_node_type', @@ -32,4 +33,5 @@ __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'get_module_namespace', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_dataclass_methods', 'get_dataclass_methods', 'get_scope_name', 'eval_script', 'create_slice_obj', 'convert_to_ms_tensor', 'get_object_description', 'expand_expr_statement', - 'generate_scope', 'get_operation_symbol', 'get_class_attr_namespace_symbol'] + 'generate_scope', 
'get_operation_symbol', 'get_class_attr_namespace_symbol', 'get_ms_class_name', + 'get_ms_class_attr'] diff --git a/mindspore/python/mindspore/_extends/parse/parser.py b/mindspore/python/mindspore/_extends/parse/parser.py index 9dcaa766ead8..62272a85ff76 100644 --- a/mindspore/python/mindspore/_extends/parse/parser.py +++ b/mindspore/python/mindspore/_extends/parse/parser.py @@ -410,6 +410,30 @@ def get_dataclass_methods(cls): return methods +def get_ms_class_name(cls): + """Get the name of the class instance decorated by ms_class.""" + # Check if cls is nn.Cell. + if isinstance(cls, nn.Cell): + raise TypeError(f"ms_class is used for user-defined classes and cannot be used for nn.Cell: {cls}.") + if isinstance(cls, type): + name = cls.__name__ + else: + name = cls.__class__.__name__ + # Get the name of cls. + cls_name = cls.__module__ + '.' + name + return cls_name + + +def get_ms_class_attr(cls, name: str): + """Get attribute or method of ms_class obj.""" + # Don't take into account python magic methods and private variables. + if name.startswith('_'): + raise AttributeError(f"{name} is a private variable or magic method, which is not supported.") + if not hasattr(cls, name): + raise AttributeError(f"{cls} has no attribute: {name}.") + return getattr(cls, name) + + def convert_to_ms_tensor(data): """Convert C++ tensor to mindspore tensor.""" return Tensor(data) @@ -562,8 +586,8 @@ def eval_script(exp_str, params): local_params = _convert_data(local_params) obj = eval(exp_str, global_params, local_params) except Exception as e: - error_info = f"When eval '{exp_str}' by using Fallback feature, an error occurred: " + str(e) + \ - ". You can try to turn off the Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'." + error_info = f"When eval '{exp_str}' by using JIT Fallback feature, an error occurred: " + str(e) + \ + ". You can try to turn off JIT Fallback feature by 'export MS_DEV_ENABLE_FALLBACK=0'." logger.error(error_info) raise e diff --git a/mindspore/python/mindspore/common/__init__.py b/mindspore/python/mindspore/common/__init__.py index 8cee874cf16f..50b5554d665f 100644 --- a/mindspore/python/mindspore/common/__init__.py +++ b/mindspore/python/mindspore/common/__init__.py @@ -14,7 +14,7 @@ # ============================================================================ """Top-level reference to dtype of common module.""" from . 
import dtype -from .api import ms_function, ms_memory_recycle, _convert_data +from .api import ms_function, ms_memory_recycle, ms_class, _convert_data from .dtype import Type, int8, byte, int16, short, int32, intc, int64, intp, \ uint8, ubyte, uint16, ushort, uint32, uintc, uint64, uintp, float16, half, \ float32, single, float64, double, bool_, float_, list_, tuple_, int_, \ @@ -54,7 +54,7 @@ __all__ = [ __all__.extend([ "Tensor", "RowTensor", "SparseTensor", "COOTensor", "CSRTensor", # tensor - 'ms_function', # api + 'ms_function', 'ms_class', # api 'Parameter', 'ParameterTuple', # parameter "dtype", "_convert_data", "set_seed", "get_seed", # random seed diff --git a/mindspore/python/mindspore/common/api.py b/mindspore/python/mindspore/common/api.py index 656149e13afd..2d0295243a7d 100644 --- a/mindspore/python/mindspore/common/api.py +++ b/mindspore/python/mindspore/common/api.py @@ -20,6 +20,7 @@ import sys import os import time import ast +import inspect import importlib from collections import OrderedDict from functools import wraps @@ -439,12 +440,64 @@ def ms_function(fn=None, obj=None, input_signature=None): return wrap_mindspore(fn) return wrap_mindspore + +def ms_class(cls): + """ + Class decorator for user-defined classes. + + This allows MindSpore to identify user-defined classes and thus obtain their attributes and methods. + + Args: + cls (Class): User-defined class. + + Returns: + Class with __ms_class__ attribute. + + Supported Platforms: + ``Ascend`` ``GPU`` ``CPU`` + + Examples: + >>> import mindspore.nn as nn + >>> from mindspore import ms_class + ... + >>> @ms_class + >>> class UserDefinedNet: + ... def __init__(self): + ... self.value = 10 + ... + ... def func(self, x): + ... return 2 * x + ... + >>> class Net(nn.Cell): + ... def __init__(self): + ... super(Net, self).__init__() + ... self.net = UserDefinedNet() + ... + ... def construct(self, x): + ... out = self.net.value + self.net.func(x) + ... return out + ... + >>> net = Net() + >>> out = net(5) + >>> print(out) + 20 + """ + + # Check if cls is of type class. + if not inspect.isclass(cls): + raise TypeError(f'Decorator ms_class can only be used for class type, but got {cls}.') + logger.info(f'Found ms_class: {cls}.') + setattr(cls, '__ms_class__', True) + return cls + + def is_pynative_parallel(): run_mode = context.get_context('mode') parallel_mode = context.get_auto_parallel_context('parallel_mode') return run_mode == context.PYNATIVE_MODE and parallel_mode in ( context.ParallelMode.SEMI_AUTO_PARALLEL, context.ParallelMode.AUTO_PARALLEL) + def _get_auto_split_param_names(parameter_layout_dict): auto_split_param_names = [] for key, value in parameter_layout_dict.items(): @@ -899,4 +952,4 @@ def ms_memory_recycle(): _cell_graph_executor = _CellGraphExecutor() _pynative_executor = _PynativeExecutor() -__all__ = ['ms_function', 'ms_memory_recycle'] +__all__ = ['ms_function', 'ms_memory_recycle', 'ms_class'] diff --git a/tests/ut/python/fallback/test_graph_fallback.py b/tests/ut/python/fallback/test_graph_fallback.py index 4007d53492f8..2183f033bb72 100644 --- a/tests/ut/python/fallback/test_graph_fallback.py +++ b/tests/ut/python/fallback/test_graph_fallback.py @@ -243,117 +243,6 @@ def test_scipy_module(): print(out) -def test_self_attr(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. 
- """ - class Network(nn.Cell): - def __init__(self): - super(Network, self).__init__() - self.dim = 1 - - def construct(self, x): - batch = x.shape[0] - one = Tensor(np.ones([batch, self.dim]), mstype.float16) - return one * x - - net = Network() - x = Tensor([1, 2], mstype.float32) - out = net(x) - print(out) - - -def test_self_attr_2(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def __init__(self, fn): - super(Network, self).__init__() - self.fn = fn - - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - out = Tensor(self.fn(x, y)) - return out - - def fn(x, y): - return x + y - - net = Network(fn) - out = net() - print(out) - - -def test_self_attr_3(): - """ - Feature: JIT Fallback - Description: Test self.attr in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def __init__(self): - super(Network, self).__init__() - self.value = [2, 2, 3] - - def construct(self): - x = np.array(self.value.count(2)) - return Tensor(x) - - net = Network() - out = net() - print(out) - - -def test_self_method(): - """ - Feature: JIT Fallback - Description: Test self.method in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - out = Tensor(self.fn(x, y)) - return out - - def fn(self, x, y): - return x + y - - net = Network() - out = net() - print(out) - - -@pytest.mark.skip(reason='Not support in graph jit fallback feature yet') -def test_self_method_2(): - """ - Feature: JIT Fallback - Description: Test self.method in graph. - Expectation: No exception. - """ - class Network(nn.Cell): - def construct(self): - x = np.array([1, 2, 3]) - y = np.array([3, 4, 5]) - z = self.fn(x, y) - out = Tensor(z) - return out - - def fn(self, x, y): - return x + y - - net = Network() - out = net() - print(out) - - def test_probability_cauchy(): """ Feature: JIT Fallback diff --git a/tests/ut/python/fallback/test_graph_fallback_class.py b/tests/ut/python/fallback/test_graph_fallback_class.py new file mode 100644 index 000000000000..0bdf9d1cadd7 --- /dev/null +++ b/tests/ut/python/fallback/test_graph_fallback_class.py @@ -0,0 +1,398 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test graph fallback """ +import pytest +import numpy as np + +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, context, ms_class + +context.set_context(mode=context.GRAPH_MODE) + + +def test_fallback_self_attr(): + """ + Feature: JIT Fallback + Description: Test self.attr in graph. + Expectation: No exception. 
+ """ + class Network(nn.Cell): + def __init__(self): + super(Network, self).__init__() + self.dim = 1 + + def construct(self, x): + batch = x.shape[0] + one = Tensor(np.ones([batch, self.dim]), mstype.float32) + return one * x + + net = Network() + x = Tensor([1, 2], mstype.float32) + out = net(x) + expect = np.array([[1., 2.], [1., 2.]]) + assert np.allclose(out.asnumpy(), expect, 1.e-2, 1.e-2) + + +def test_fallback_self_attr_fn(): + """ + Feature: JIT Fallback + Description: Test self.attr in graph. + Expectation: No exception. + """ + class Network(nn.Cell): + def __init__(self, fn): + super(Network, self).__init__() + self.fn = fn + + def construct(self): + x = np.array([1, 2, 3]) + y = np.array([3, 4, 5]) + out = Tensor(self.fn(x, y)) + return out + + def fn(x, y): + return x + y + + net = Network(fn) + out = net() + expect = np.array([4, 6, 8]) + assert np.all(out.asnumpy() == expect) + + +def test_fallback_self_attr_attr(): + """ + Feature: JIT Fallback + Description: Test self.attr in graph. + Expectation: No exception. + """ + class Network(nn.Cell): + def __init__(self): + super(Network, self).__init__() + self.value = [2, 2, 3] + + def construct(self): + x = np.array(self.value.count(2)) + return Tensor(x) + + net = Network() + out = net() + assert out == 2 + + +def test_fallback_self_method(): + """ + Feature: JIT Fallback + Description: Test self.method in graph. + Expectation: No exception. + """ + class Network(nn.Cell): + def construct(self): + x = np.array([1, 2, 3]) + y = np.array([3, 4, 5]) + out = Tensor(self.fn(x, y)) + return out + + def fn(self, x, y): + return x + y + + net = Network() + out = net() + expect = np.array([4, 6, 8]) + assert np.all(out.asnumpy() == expect) + + +@pytest.mark.skip(reason='Not support in graph jit fallback feature yet') +def test_fallback_self_method_tensor(): + """ + Feature: JIT Fallback + Description: Test self.method in graph. + Expectation: No exception. + """ + class Network(nn.Cell): + def construct(self): + x = np.array([1, 2, 3]) + y = np.array([3, 4, 5]) + z = self.fn(x, y) + out = Tensor(z) + return out + + def fn(self, x, y): + return x + y + + net = Network() + out = net() + print(out) + + +def test_fallback_class_attr(): + """ + Feature: JIT Fallback + Description: Test user-defined class attributes in graph. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.number = 1 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self): + out = self.inner_net.number + return out + + net = Net() + out = net() + assert out == 1 + + +def test_fallback_class_method(): + """ + Feature: JIT Fallback + Description: Test user-defined class methods in graph. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.val = 2 + + def act(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self): + out = self.inner_net.act(1, 2) + return out + + net = Net() + out = net() + assert out == 6 + + +def test_fallback_class_input_attr(): + """ + Feature: JIT Fallback + Description: Test user-defined class attributes in graph. + Expectation: No exception. 
+ """ + @ms_class + class InnerNet: + def __init__(self): + self.number = Tensor(np.array([1, 2, 3])) + + class Net(nn.Cell): + def __init__(self, net): + super(Net, self).__init__() + self.inner_net = net() + + def construct(self): + out = self.inner_net.number + return out + + net = Net(InnerNet) + out = net() + expect_res = np.array([1, 2, 3]) + assert np.all(out.asnumpy() == expect_res) + + +def test_fallback_class_input_method(): + """ + Feature: JIT Fallback + Description: Test user-defined class methods in graph. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.val = 2 + + def act(self, x, y): + return self.val * (x + y) + + class Net(nn.Cell): + def __init__(self, net): + super(Net, self).__init__() + self.inner_net = net() + + def construct(self): + out = self.inner_net.act(1, 2) + return out + + net = Net(InnerNet) + out = net() + assert out == 6 + + +def test_fallback_class_class_nested(): + """ + Feature: JIT Fallback + Description: Test nested ms_class in graph. + Expectation: No exception. + """ + @ms_class + class Inner: + def __init__(self): + self.number = 1 + + @ms_class + class InnerNet: + def __init__(self): + self.inner = Inner() + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self): + out = self.inner_net.inner.number + return out + + net = Net() + out = net() + assert out == 1 + + +def test_fallback_class_cell_nested(): + """ + Feature: JIT Fallback + Description: Test nested ms_class and cell in graph. + Expectation: No exception. + """ + class Net(nn.Cell): + def __init__(self, val): + super().__init__() + self.val = val + + def construct(self, x): + return x + self.val + + @ms_class + class TrainNet(): + class Loss(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + + def construct(self, x): + out = self.net(x) + return out * 2 + + def __init__(self, net): + self.net = net + loss_net = self.Loss(self.net) + self.number = loss_net(10) + + global_net = Net(1) + class LearnNet(nn.Cell): + def __init__(self): + super().__init__() + self.value = TrainNet(global_net).number + + def construct(self, x): + return x + self.value + + leanrn_net = LearnNet() + out = leanrn_net(3) + print(out) + assert out == 25 + + +@pytest.mark.skip(reason='Not support in graph yet') +def test_fallback_class_isinstance(): + """ + Feature: JIT Fallback + Description: Test ms_class in graph. + Expectation: No exception. + """ + @ms_class + class InnerNet: + def __init__(self): + self.number = 1 + + class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.inner_net = InnerNet() + + def construct(self, x): + if isinstance(self.inner_net, InnerNet): + return x + 10 + return x + + net = Net() + out = net(5) + assert out == 15 + + +def test_fallback_raise_error_not_class_type(): + """ + Feature: JIT Fallback + Description: Test ms_class in graph. + Expectation: No exception. + """ + with pytest.raises(TypeError): + @ms_class + def func(x, y): + return x + y + + func(1, 2) + + +def test_fallback_raise_error_not_class_instance(): + """ + Feature: JIT Fallback + Description: Test ms_class in graph. + Expectation: No exception. 
+    """
+    @ms_class
+    class InnerNet:
+        def __init__(self):
+            self.number = 1
+
+    class Net(nn.Cell):
+        def construct(self):
+            out = InnerNet().number
+            return out
+
+    with pytest.raises(ValueError):
+        net = Net()
+        net()
+
+
+def test_fallback_raise_error_decorate_cell():
+    """
+    Feature: JIT Fallback
+    Description: Test decorating nn.Cell with ms_class in graph.
+    Expectation: Raise TypeError as expected.
+    """
+    @ms_class
+    class Net(nn.Cell):
+        def construct(self, x):
+            return x
+
+    with pytest.raises(TypeError):
+        x = Tensor(1)
+        net = Net()
+        net(x)
-- 
Gitee