From c3861aedde7e924d8c2939cd2f77500b7ca47e21 Mon Sep 17 00:00:00 2001 From: panzhihui Date: Mon, 4 Dec 2023 06:05:47 +0800 Subject: [PATCH] Fix randomchoicewithmask error --- .../aicpu_ops/random_choice_with_mask_kernels.cc | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/random_choice_with_mask_kernels.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/random_choice_with_mask_kernels.cc index 2bf720f105a2..7b7cfd20ab91 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/random_choice_with_mask_kernels.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_ops/random_choice_with_mask_kernels.cc @@ -36,16 +36,11 @@ const size_t kIndex3 = 3; const size_t kIndex4 = 4; } // namespace -std::vector GetSamples(const bool *input, int64_t input_size, int64_t count_target) { +std::vector GetAllSamples(const bool *input, int64_t input_size) { std::vector sample_ids{}; - int64_t count{0}; for (int64_t i = 0; i < input_size; ++i) { if (input[i]) { sample_ids.push_back(i); - count++; - } - if (count >= count_target) { - break; } } return sample_ids; @@ -64,12 +59,12 @@ uint32_t RandomChoiceWithMaskKernel::DoCompute() { auto *mask = reinterpret_cast(io_addrs_[kIndex4]); int64_t input_size = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); - std::vector sample_ids = GetSamples(input, input_size, count_target_); - size_t count = sample_ids.size(); + std::vector sample_ids = GetAllSamples(input, input_size); std::random_device rd; std::mt19937 g(rd()); std::shuffle(sample_ids.begin(), sample_ids.end(), g); + size_t count = std::min(sample_ids.size(), static_cast(count_target_)); // Calculate coordinates auto *output_offset = output_coordinate; -- Gitee