From 9de104efa9b4ad063b278bb9b88652abb449c275 Mon Sep 17 00:00:00 2001 From: yide12 Date: Mon, 18 Mar 2024 10:02:15 +0800 Subject: [PATCH] ckpt_type_convert_add_bf16 --- mindspore/python/mindspore/train/serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/python/mindspore/train/serialization.py b/mindspore/python/mindspore/train/serialization.py index 15d9fb5d3940..46304e74a18f 100644 --- a/mindspore/python/mindspore/train/serialization.py +++ b/mindspore/python/mindspore/train/serialization.py @@ -176,7 +176,7 @@ def _update_param(param, new_param, strict_load): def _type_convert(param, new_param, strict_load): """Whether to convert parameter's type during load checkpoint into network.""" - float_type = (mstype.float16, mstype.float32, mstype.float64) + float_type = (mstype.float16, mstype.float32, mstype.float64, mstype.bfloat16) int_type = (mstype.int8, mstype.int16, mstype.int32, mstype.int64) if not strict_load and ({param.data.dtype, new_param.data.dtype}.issubset(float_type) or {param.data.dtype, new_param.data.dtype}.issubset(int_type)): -- Gitee