From eeb6c5fc2a211e84f95eefeee150f2945053f64a Mon Sep 17 00:00:00 2001
From: yinglailin
Date: Sun, 28 Apr 2024 16:53:09 +0800
Subject: [PATCH] Add a pass that swaps bias and allreduce under
 MatMul+Bias+AllReduce
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../parallel/pass/bias_add_comm_swap.cc       | 198 ++++++++++++++++++
 .../parallel/pass/bias_add_comm_swap.h        |  28 +++
 mindspore/ccsrc/pipeline/jit/ps/pass.cc       |   8 +
 .../ccsrc/pybind_api/utils/ms_context_py.cc   |   1 +
 mindspore/core/utils/ms_context.cc            |   1 +
 mindspore/core/utils/ms_context.h             |   1 +
 mindspore/python/mindspore/context.py         |   1 +
 .../parallel/test_bias_add_comm_swap.py       | 117 +++++++++++
 8 files changed, 355 insertions(+)
 create mode 100644 mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.cc
 create mode 100644 mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.h
 create mode 100644 tests/ut/python/parallel/test_bias_add_comm_swap.py

diff --git a/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.cc b/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.cc
new file mode 100644
index 000000000000..a1902859301c
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.cc
@@ -0,0 +1,198 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "frontend/parallel/pass/bias_add_comm_swap.h"
+#include <algorithm>
+#include <list>
+#include <memory>
+#include <utility>
+#include <vector>
+#include "include/common/utils/utils.h"
+#include "frontend/parallel/step_parallel.h"
+#include "frontend/parallel/step_parallel_utils.h"
+#include "mindspore/core/ops/math_ops.h"
+#include "mindspore/core/ops/other_ops.h"
+
+namespace mindspore {
+namespace parallel {
+namespace {
+constexpr const char BIAS_ADD_COMM_SWAP[] = "bias_add_comm_swap";
+
+bool IsSubRankList(const RankList &child_list, const RankList &parent_list) {
+  for (auto &child : child_list) {
+    if (std::find(parent_list.begin(), parent_list.end(), child) == parent_list.end()) {
+      return false;
+    }
+  }
+  return true;
+}
+bool IsAddNodeValid(const CNodePtr &add_node, const AnfNodePtr &comm_node) {
+  OperatorInfoPtr add_distribute_operator = add_node->user_data<OperatorInfo>();
+  if (add_distribute_operator == nullptr) {
+    return false;
+  }
+  TensorInfo node_add_tensor_in = add_distribute_operator->inputs_tensor_info()[LongToSize(1)];
+  TensorLayout node_add_tensor_layout = node_add_tensor_in.tensor_layout();
+  auto node_add_rank_list = node_add_tensor_layout.InferRepeatedGroup();
+
+  auto comm_prim = GetCNodePrimitive(comm_node);
+  if (!comm_prim->HasAttr(GROUP)) {
+    return false;
+  }
+  auto comm_group = GetValue<std::string>(comm_prim->GetAttr(GROUP));
+  MS_EXCEPTION_IF_NULL(g_device_manager);
+  auto comm_rank_list = g_device_manager->FindRankListByHashName(comm_group);
+  return IsSubRankList(comm_rank_list, node_add_rank_list);
+}
+
+// find matmul node
+AnfNodePtr FindMatMulNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto matmul_node = GetInputNodeWithFilter(node, [&](const CNodePtr &cnode) {
+    bool filter = !IsPrimitiveCNode(cnode, prim::kPrimMatMul);
+    return std::make_pair(filter, 1);
+  });
+  return matmul_node;
+}
+
+// find allreduce/reduce_scatter node
+AnfNodePtr FindValidCommNode(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto comm_node = GetInputNodeWithFilter(node, [&](const AnfNodePtr &anode) {
+    bool filter = !IsPrimitiveCNode(anode, prim::kPrimAllReduce) && !IsPrimitiveCNode(anode, prim::kPrimReduceScatter);
+    return std::make_pair(filter, 1);
+  });
+  if (comm_node == nullptr ||
+      (!IsPrimitiveCNode(comm_node, prim::kPrimAllReduce) && !IsPrimitiveCNode(comm_node, prim::kPrimReduceScatter))) {
+    return nullptr;
+  }
+  auto matmul_node = FindMatMulNode(comm_node);
+  if (matmul_node == nullptr || !IsPrimitiveCNode(matmul_node, prim::kPrimMatMul)) {
+    return nullptr;
+  }
+  return comm_node;
+}
+
+void FindAllValidAddNode(const FuncGraphPtr &graph, HashMap<CNodePtr, AnfNodePtr> *add_node_map) {
+  std::list<CNodePtr> graph_orders = graph->GetOrderedCnodes();
+  std::vector<CNodePtr> origin_nodes_topological(graph_orders.cbegin(), graph_orders.cend());
+  for (const auto &node : origin_nodes_topological) {
+    if (!IsPrimitiveCNode(node, prim::kPrimAdd)) {
+      MS_LOG(INFO) << "Cur node must be an Add node with an all-ones strategy, but got "
+                   << node->DebugString();
+      continue;
+    }
+
+    auto comm_node = FindValidCommNode(node->cast<CNodePtr>());
+    if (comm_node == nullptr) {
+      MS_LOG(INFO) << "For cur node, cannot find a valid comm node, cur node is " << node->DebugString();
+      continue;
+    }
+    if (!IsAddNodeValid(node, comm_node)) {
+      MS_LOG(INFO) << "For cur node, its strategy does not match the comm node, cur node is " << node->DebugString()
+                   << " comm node is " << comm_node->DebugString();
+      continue;
+    }
+    (*add_node_map)[node] = comm_node;
+  }
+}
+
+void HandleNodePullUp(const AnfNodePtr &comm_node, const CNodePtr &add_node) {
+  auto graph = comm_node->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  // handle matmul node, connect it to next node of reduce_scatter/allreduce
+  auto comm_node_input = comm_node->cast<CNodePtr>()->input(1);
+  (void)manager->Replace(comm_node, comm_node_input);
+}
+
+void HandleNodeBiasAdd(const AnfNodePtr &comm_node, const CNodePtr &add_node) {
+  auto comm_prim = GetCNodePrimitive(comm_node);
+  if (!comm_prim->HasAttr(kAttrRankSize)) {
+    MS_LOG(ERROR) << "cur prim does not have attr " << kAttrRankSize << ", cur node is " << comm_node->DebugString();
+    return;
+  }
+  auto rank_size = GetValue<int64_t>(comm_prim->GetAttr(kAttrRankSize));
+  auto bias_node = add_node->input(2);
+  const auto bias_dtype = bias_node->abstract()->cast<abstract::AbstractTensorPtr>();
+  MS_EXCEPTION_IF_NULL(bias_dtype);
+  mindspore::tensor::TensorPtr tensor_ptr =
+    std::make_shared<mindspore::tensor::Tensor>(rank_size, bias_dtype->element()->GetType());
+  auto const_node = NewValueNode(MakeValue(tensor_ptr));
+  const_node->set_abstract(bias_node->abstract());
+  AnfNodePtrList div_node_inputs = {NewValueNode(prim::kPrimRealDiv), bias_node, const_node};
+
+  auto fg = comm_node->func_graph();
+  auto div_node = fg->NewCNode(div_node_inputs);
+  div_node->set_abstract(bias_node->abstract()->Clone());
+  auto graph = comm_node->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  (void)manager->Replace(bias_node, div_node);
+}
+
+void HandleNodePullDown(const AnfNodePtr &comm_node, const CNodePtr &add_node) {
+  auto graph = comm_node->func_graph();
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+
+  AnfNodePtrList new_comm_node_inputs = {comm_node->cast<CNodePtr>()->input(0), add_node};
+  auto new_comm_node = graph->NewCNode(new_comm_node_inputs);
+  new_comm_node->set_abstract(comm_node->abstract());
+  auto prim = GetCNodePrimitive(new_comm_node);
+  (void)prim->AddAttr(BIAS_ADD_COMM_SWAP, MakeValue(true));
+  (void)manager->Replace(add_node, new_comm_node);
+}
+
+void HandleAddNode(HashMap<CNodePtr, AnfNodePtr> *add_node_map) {
+  for (auto node_pair : (*add_node_map)) {
+    auto add_node = node_pair.first;
+    auto comm_node = node_pair.second;
+    HandleNodePullUp(comm_node, add_node);
+    HandleNodeBiasAdd(comm_node, add_node);
+    // pull down comm node, change add node user's input to allreduce
+    HandleNodePullDown(comm_node, add_node);
+  }
+}
+}  // namespace
+
+void BiasAddCommSwap(const FuncGraphPtr &graph) {
+  if (parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kSemiAutoParallel &&
+      parallel::ParallelContext::GetInstance()->parallel_mode() != parallel::kAutoParallel) {
+    MS_LOG(INFO) << "BiasAddCommSwap is only supported under [semi_]auto_parallel, skip it.";
+    return;
+  }
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (!ms_context->get_param<bool>(MS_CTX_BIAS_ADD_COMM_SWAP)) {
+    return;
+  }
+
+  MS_EXCEPTION_IF_NULL(graph);
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  HashMap<CNodePtr, AnfNodePtr> add_node_map;
+  for (auto &each_graph : manager->func_graphs()) {
+    FindAllValidAddNode(each_graph, &add_node_map);
+  }
+  // pull up add node, pull down allreduce/reduce_scatter node
+  HandleAddNode(&add_node_map);
+}
+}  // namespace parallel
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.h b/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.h
new file mode 100644
index 000000000000..a31c88d72e5e
--- /dev/null
+++ b/mindspore/ccsrc/frontend/parallel/pass/bias_add_comm_swap.h
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2024 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_BIAS_ADD_COMM_SWAP_H
+#define MINDSPORE_BIAS_ADD_COMM_SWAP_H
+#include "ir/anf.h"
+
+namespace mindspore {
+namespace parallel {
+// Pull down allreduce
+void BiasAddCommSwap(const FuncGraphPtr &graph);
+}  // namespace parallel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_BIAS_ADD_COMM_SWAP_H
diff --git a/mindspore/ccsrc/pipeline/jit/ps/pass.cc b/mindspore/ccsrc/pipeline/jit/ps/pass.cc
index 70def375b944..89b3f2e7119f 100644
--- a/mindspore/ccsrc/pipeline/jit/ps/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/ps/pass.cc
@@ -59,6 +59,7 @@
 #include "frontend/parallel/pass/float32_redistribution.h"
 #include "frontend/parallel/pass/merge_cast_opt.h"
 #include "frontend/parallel/pass/remove_cast_before_assign_add.h"
+#include "frontend/parallel/pass/bias_add_comm_swap.h"
 #include "frontend/parallel/pass/comp_comm_scheduling.h"
 #include "frontend/parallel/pass/overlap_opt_shard_in_pipeline.h"
 #include "frontend/parallel/pass/slice_activation_in_cell_share_recompute.h"
@@ -808,6 +809,12 @@ bool RemoveCastBeforeAssignAdd(const ResourcePtr &resource) {
   return true;
 }
 
+bool BiasAddCommSwap(const ResourcePtr &resource) {
+  MS_EXCEPTION_IF_NULL(resource);
+  parallel::BiasAddCommSwap(resource->func_graph());
+  return true;
+}
+
 bool ReorderSendRecvBetweenFpBpPass(const ResourcePtr &resource) {
   MS_EXCEPTION_IF_NULL(resource);
   parallel::ReorderSendRecvBetweenFpBp(resource->func_graph());
@@ -1195,6 +1202,7 @@ std::vector<PassItem> kVmPasses = {{"py_interpret_to_execute", PyInterpretToExec
   {"add_recomputation", AddRecomputationPass},
   {"cse_after_recomputation", OptAfterRecomputeGroup},
   {"environ_conv", EnvironConversionPass},
+  {"bias_add_comm_swap", BiasAddCommSwap},
   {"label_micro_interleaved_index", LabelMicroInterleavedIndexPass},
   {"label_fine_grained_interleaved_index", LabelFineGrainedInterleavedIndexPass},
   {"merge_cast_opt", MergeCastOpt},
diff --git a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
index 6c0de5a6173a..11f544776b7c 100644
--- a/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
+++ b/mindspore/ccsrc/pybind_api/utils/ms_context_py.cc
@@ -133,6 +133,7 @@ void RegMsContext(const py::module *m) {
     .value("debug_level", MsCtxParam::MS_CTX_DEBUG_LEVEL)
     .value("interleaved_matmul_comm", MsCtxParam::MS_CTX_INTERLEAVED_MATMUL_COMM)
     .value("interleaved_layernorm_comm", MsCtxParam::MS_CTX_INTERLEAVED_LAYERNORM_COMM)
+    .value("bias_add_comm_swap", MsCtxParam::MS_CTX_BIAS_ADD_COMM_SWAP)
     .value("enable_begin_end_inline_opt", MsCtxParam::MS_CTX_ENABLE_BEGIN_END_INLINE_OPT)
     .value("enable_concat_eliminate_opt", MsCtxParam::MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT)
     .value("host_scheduling_max_threshold", MsCtxParam::MS_CTX_HOST_SCHEDULING_MAX_THRESHOLD)
diff --git a/mindspore/core/utils/ms_context.cc b/mindspore/core/utils/ms_context.cc
index bdf470e3410b..3d4b85cfd1d7 100644
--- a/mindspore/core/utils/ms_context.cc
+++ b/mindspore/core/utils/ms_context.cc
@@ -121,6 +121,7 @@ MsContext::MsContext(const std::string &policy, const std::string &target) {
   set_param<bool>(MS_CTX_ENABLE_GRAD_COMM_OPT, false);
   set_param<bool>(MS_CTX_INTERLEAVED_MATMUL_COMM, false);
   set_param<bool>(MS_CTX_INTERLEAVED_LAYERNORM_COMM, false);
+  set_param<bool>(MS_CTX_BIAS_ADD_COMM_SWAP, false);
   set_param<bool>(MS_CTX_ENABLE_BEGIN_END_INLINE_OPT, false);
   set_param<bool>(MS_CTX_ENABLE_CONCAT_ELIMINATE_OPT, false);
   set_param<uint32_t>(MS_CTX_OP_TIMEOUT, kOpTimeout);
diff --git a/mindspore/core/utils/ms_context.h b/mindspore/core/utils/ms_context.h
index 6beb5160a7ab..30e7c2e1a4e1 100644
--- a/mindspore/core/utils/ms_context.h
+++ b/mindspore/core/utils/ms_context.h
@@ -123,6 +123,7 @@ enum MsCtxParam : unsigned {
   MS_CTX_ENABLE_OPT_SHARD_COMM_OPT,
   MS_CTX_INTERLEAVED_MATMUL_COMM,
   MS_CTX_INTERLEAVED_LAYERNORM_COMM,
+  MS_CTX_BIAS_ADD_COMM_SWAP,
   MS_CTX_ENABLE_COMPILE_CACHE,
   MS_CTX_CONV_ALLOW_TF32,
   MS_CTX_MATMUL_ALLOW_TF32,
diff --git a/mindspore/python/mindspore/context.py b/mindspore/python/mindspore/context.py
index 40e7934234c0..8667af39f00f 100644
--- a/mindspore/python/mindspore/context.py
+++ b/mindspore/python/mindspore/context.py
@@ -701,6 +701,7 @@ class _Context:
             "enable_task_opt": (ms_ctx_param.enable_task_opt, bool),
             "enable_grad_comm_opt": (ms_ctx_param.enable_grad_comm_opt, bool),
             "interleaved_matmul_comm": (ms_ctx_param.interleaved_matmul_comm, bool),
+            "bias_add_comm_swap": (ms_ctx_param.bias_add_comm_swap, bool),
             "enable_opt_shard_comm_opt": (ms_ctx_param.enable_opt_shard_comm_opt, bool),
             "enable_begin_end_inline_opt": (ms_ctx_param.enable_begin_end_inline_opt, bool),
             "enable_concat_eliminate_opt": (ms_ctx_param.enable_concat_eliminate_opt, bool),
diff --git a/tests/ut/python/parallel/test_bias_add_comm_swap.py b/tests/ut/python/parallel/test_bias_add_comm_swap.py
new file mode 100644
index 000000000000..d7a560a553c1
--- /dev/null
+++ b/tests/ut/python/parallel/test_bias_add_comm_swap.py
@@ -0,0 +1,117 @@
+# Copyright 2020-2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import json
+import os
+import subprocess
+import shutil
+import numpy as np
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import context
+from mindspore import Tensor
+from mindspore.common.api import _cell_graph_executor
+from mindspore.ops import composite as C
+from mindspore.ops import operations as P
+from tests.ut.python.ops.test_math_ops import VirtualLoss
+
+
+def setup_function():
+    context.set_auto_parallel_context(dataset_strategy="full_batch")
+
+
+class NetWithLoss(nn.Cell):
+    def __init__(self, network):
+        super(NetWithLoss, self).__init__()
+        self.loss = VirtualLoss()
+        self.network = network
+
+    def construct(self, x, y, b):
+        predict = self.network(x, y, b)
+        return self.loss(predict)
+
+
+def compile_net(net, x, y, b):
+    net.set_train()
+    _cell_graph_executor.compile(net, x, y, b)
+
+
+grad_all = C.GradOperation(get_all=True)
+
+
+class GradWrap(nn.Cell):
+    def __init__(self, network):
+        super(GradWrap, self).__init__()
+        self.network = network
+
+    def construct(self, x, y, b):
+        return grad_all(self.network)(x, y, b)
+
+
+def test_bias_add_comm_swap():
+    """
+    Feature: test bias_add comm swap
+    Description: change structure like matmul+allreduce/reduce_scatter+bias to matmul+bias+allreduce/reduce_scatter
+    Expectation: compile success
+    """
+
+    class Net(nn.Cell):
+        def __init__(self, matmul_in_strategy, matmul_out_strategy, add_strategy):
+            super().__init__()
+            self.matmul = P.MatMul().shard(matmul_in_strategy, matmul_out_strategy)
+            self.add = P.Add().shard(add_strategy)
+
+        def construct(self, x, w, b):
+            out = self.matmul(x, w)
+            out = self.add(out, b)
+            return out
+
+    context.set_auto_parallel_context(
+        device_num=8, global_rank=0)
+    context.set_context(save_graphs=True)
+    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+    if os.path.exists("./speed_up.json"):
+        os.remove("./speed_up.json")
+    a = {"bias_add_comm_swap": True}
+    f = open("speed_up.json", "w")
+    f.write(json.dumps(a))
+    f.close()
+    context.set_context(ascend_config={"parallel_speed_up_json_path": "speed_up.json"})
+
+    matmul_in_strategy = ((2, 2), (2, 1))
+    matmul_out_strategy = ((4, 1),)
+    add_strategy = ((4, 1), (1,))
+    net = GradWrap(NetWithLoss(Net(matmul_in_strategy, matmul_out_strategy, add_strategy)))
+
+    x = Tensor(np.ones([128, 32]), dtype=ms.float16)
+    w = Tensor(np.ones([32, 64]), dtype=ms.float16)
+    b = Tensor(np.ones([64]), dtype=ms.float16)
+    if os.path.exists("./rank_0"):
+        shutil.rmtree("./rank_0")
+    # compile
+    compile_net(net, x, w, b)
+
+    file = "./rank_0/*validate*.ir"
+    prim_name = "ReduceScatter"
+    para = "bias_add_comm_swap"
+    output = subprocess.check_output(
+        ["grep -r '%s' %s |grep '%s' | wc -l" % (prim_name, file, para)],
+        shell=True)
+    out = str(output, 'utf-8').strip()
+    assert out == "1"
+    if os.path.exists("./rank_0"):
+        shutil.rmtree("./rank_0")
+    if os.path.exists("./speed_up.json"):
+        os.remove("./speed_up.json")
+    context.set_context(save_graphs=False)
-- 
Gitee
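
Usage note (not part of the patch): a minimal sketch of how a user would turn the new pass on, assuming the flag is consumed through the parallel speed-up JSON exactly as the unit test above does; the file name speed_up.json is illustrative.

    # Illustrative only: enable bias_add_comm_swap the same way the unit test does,
    # by pointing ascend_config at a parallel speed-up JSON that sets the flag.
    import json
    from mindspore import context

    with open("speed_up.json", "w") as f:  # arbitrary file name, assumption
        json.dump({"bias_add_comm_swap": True}, f)

    context.set_context(ascend_config={"parallel_speed_up_json_path": "speed_up.json"})
    # The pass only runs under (semi_)auto_parallel, matching the check in BiasAddCommSwap.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)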