diff --git a/mindspore/python/mindspore/nn/layer/flash_attention.py b/mindspore/python/mindspore/nn/layer/flash_attention.py
index fa200f2ff1d038f76aa9fcdfcf39eb8380e76be6..b8ef1e145599df98a5dba544c54a4386ccc22805 100644
--- a/mindspore/python/mindspore/nn/layer/flash_attention.py
+++ b/mindspore/python/mindspore/nn/layer/flash_attention.py
@@ -110,7 +110,7 @@ class FlashAttention(Cell):
             high_precision=high_precision
         )
         self.flash_attention.add_prim_attr("primitive_target", "Ascend")
-        scaling_constant = math.sqrt(head_dim)
+        scaling_constant = math.sqrt(math.sqrt(head_dim))
         if scaling_constant != 0:
             self.scale_factor = Tensor([1. / scaling_constant], dtype=mstype.float16)
         else:
@@ -186,6 +186,7 @@ class FlashAttention(Cell):
         :return: output [bsz, head_num, seq_len, head_dim]
         """
        query = self.scale_mul(query, self.scale_factor)
+        key = self.scale_mul(key, self.scale_factor)
         bsz, head_num, seq_len, head_dim = query.shape
         _, k_head_num, k_seq_len, _ = key.shape
         _, v_head_num, v_seq_len, _ = value.shape
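
The patch splits the usual 1/sqrt(head_dim) softmax scaling into two factors of 1/head_dim**0.25, applied to query and key separately. The resulting attention scores are mathematically unchanged, but the fp16 intermediates stay in a smaller range. Below is a minimal NumPy sketch (not part of the patch, and not MindSpore code) illustrating the equivalence; the shapes and seed are arbitrary.

```python
# Sketch only: shows that scaling query and key each by 1/d**0.25 yields the
# same scores as scaling only the query by 1/sqrt(d), up to float rounding.
import numpy as np

rng = np.random.default_rng(0)
head_dim = 128
q = rng.standard_normal((4, head_dim))
k = rng.standard_normal((4, head_dim))

# Original scheme: scale only the query by 1/sqrt(d).
scores_single = (q / np.sqrt(head_dim)) @ k.T

# Patched scheme: scale query and key each by 1/d**0.25 (sqrt of sqrt).
quarter_scale = np.sqrt(np.sqrt(head_dim))
scores_split = (q / quarter_scale) @ (k / quarter_scale).T

# (q / d**0.25) @ (k / d**0.25).T == (q @ k.T) / sqrt(d)
assert np.allclose(scores_single, scores_split)
```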