
[ST][MS][MF][llama2_70b] Network evaluation fails, TypeError: For Operator[ReshapeAndCache], slot_mapping's type 'None' does not match expected type 'Tensor'

DONE
Bug-Report
Created on 2024-05-07 15:04
name: Bug Report
about: Use this template for reporting a bug
labels: kind/bug

Describe the current behavior (Mandatory)

Distributed evaluation of llama2_70b: if use_past in the config file is changed to True, the evaluation fails with a ReshapeAndCache operator error. With use_past set to False, the evaluation runs normally but is very slow: after more than 2 hours it is still less than halfway through.
Model repo (evaluation docs): https://gitee.com/mindspore/mindformers/blob/dev/docs/model_cards/llama2.md#%E8%AF%84%E6%B5%8B
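For reference, the flag in question sits in the model config section of the YAML. A minimal sketch is below; the exact key nesting is assumed from typical MindFormers llama2 configs and should be checked against the actual predict_llama2_70b.yaml:

    # predict_llama2_70b.yaml (excerpt; nesting assumed, not copied from the repo)
    model:
      model_config:
        type: LlamaConfig
        seq_length: 2048
        use_past: True    # True  -> incremental (KV-cache) inference; hits the ReshapeAndCache error here
                          # False -> full-sequence inference; runs, but evaluation is very slow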

Environment (Mandatory)

  • Hardware Environment (Ascend/GPU/CPU):

/device ascend

  • Software Environment (Mandatory):
    -- MindSpore version (e.g., 1.7.0.Bxxx) :
    -- Python version (e.g., Python 3.7.5) :
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04):
    -- GCC/Compiler version (if compiled from source):

CANN: Milan_C17/20240414
MS: master_20240506061517_d8802c69db29
MF: dev_20240506121520_6ffde9b33612a3

  • Execute Mode (Mandatory) (PyNative/Graph):

/mode pynative
/mode graph

Related testcase (Mandatory)

Test case repo path: MindFormers_Test/cases/llama2/70b/train/
Test case:
test_mf_llama2_70b_eval_squad_8p_0001

Steps to reproduce the issue (Mandatory)

  1. Get the code from mindformers
  2. cd mindformers
  3. In predict_llama2_70b.yaml, set the checkpoint and tokenizer paths, change seq_length to 2048, and change type: PerplexityMetric to type: EmF1Metric (see the sketch after this list)
  4. bash scripts/msrun_launcher.sh "run_mindformer.py --config {1} --run_mode eval" 8
  5. Check whether the network evaluation succeeds
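The YAML edits from step 3, sketched below. Paths are placeholders and the field names are assumed from the standard MindFormers llama2 configs, so verify them against the real predict_llama2_70b.yaml:

    # predict_llama2_70b.yaml (excerpt; <...> are placeholders)
    load_checkpoint: "<path/to/llama2_70b_checkpoint>"
    model:
      model_config:
        seq_length: 2048            # changed for this evaluation
    processor:
      tokenizer:
        vocab_file: "<path/to/tokenizer.model>"
    metric:
      type: EmF1Metric              # changed from PerplexityMetric for SQuAD-style EM/F1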

Describe the expected behavior (Mandatory)

The network evaluation succeeds.

Related log / screenshot (Mandatory)

2024-05-07 14:56:55,871 - mindformers[mindformers/trainer/utils.py:345] - INFO - .........Building model.........
[CRITICAL] ANALYZER(935442,ffffaac83020,python):2024-05-07-14:57:09.610.720 [mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc:1263] CheckArgsSizeAndType] For Operator[ReshapeAndCache], slot_mapping's type 'None' does not match expected type 'Tensor'.
The reason may be: lack of definition of type cast, or incorrect type when creating the node.
This exception is caused by framework's unexpected error. Please create an issue at https://gitee.com/mindspore/mindspore/issues to get help.
2024-05-07 14:57:20,216 - mindformers[mindformers/tools/cloud_adapter/cloud_monitor.py:43] - ERROR - Traceback (most recent call last):
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/tools/cloud_adapter/cloud_monitor.py", line 34, in wrapper
   result = run_func(*args, **kwargs)
 File "run_mindformer.py", line 41, in main
   trainer.evaluate(eval_checkpoint=config.load_checkpoint)
 File "/home/miniconda3/envs/ci/lib/python3.7/site-packages/mindspore/_checkparam.py", line 1372, in wrapper
   return func(*args, **kwargs)
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/trainer/trainer.py", line 603, in evaluate
   compute_metrics=self.compute_metrics, is_full_config=True, **kwargs)
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/trainer/causal_language_modeling/causal_language_modeling.py", line 161, in evaluate
   **kwargs)
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/trainer/causal_language_modeling/causal_language_modeling.py", line 229, in generate_evaluate
   transform_and_load_checkpoint(config, model, network, dataset, do_eval=True)
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/trainer/utils.py", line 346, in transform_and_load_checkpoint
   build_model(config, model, dataset, do_eval=do_eval, do_predict=do_predict)
 File "/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/trainer/utils.py", line 470, in build_model
   model.infer_predict_layout(*next(dataset.create_tuple_iterator()))
 File "/home/miniconda3/envs/ci/lib/python3.7/site-packages/mindspore/train/model.py", line 1896, in infer_predict_layout
   predict_net.compile(*predict_data)
 File "/home/miniconda3/envs/ci/lib/python3.7/site-packages/mindspore/nn/cell.py", line 997, in compile
   jit_config_dict=self._jit_config_dict, **kwargs)
 File "/home/miniconda3/envs/ci/lib/python3.7/site-packages/mindspore/common/api.py", line 1642, in compile
   result = self._graph_executor.compile(obj, args, kwargs, phase, self._use_vm_mode())
TypeError: For Operator[ReshapeAndCache], slot_mapping's type 'None' does not match expected type 'Tensor'.
The reason may be: lack of definition of type cast, or incorrect type when creating the node.

----------------------------------------------------
- Framework Unexpected Exception Raised:
----------------------------------------------------
This exception is caused by framework's unexpected error. Please create an issue at https://gitee.com/mindspore/mindspore/issues to get help.

----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/pipeline/jit/ps/static_analysis/prim.cc:1263 CheckArgsSizeAndType

----------------------------------------------------
- The Traceback of Net Construct Code:
----------------------------------------------------
# 0 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:356
       if self.use_past:
# 1 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:357
           if not isinstance(batch_valid_length, Tensor):
# 2 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:359
       if self.training:
# 3 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:362
           tokens = input_ids
           ^
# 4 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:365
       if not self.is_first_iteration:
# 5 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:369
       if pre_gather:
# 6 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:367
       output = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables, slot_mapping)
                ^
# 7 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:193
       if self.use_past:
# 8 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:194
           if self.is_first_iteration:
# 9 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:195
               freqs_cis = self.freqs_mgr.prefill(bs, seq_len)
               ^
# 10 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:194
           if self.is_first_iteration:
# 11 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:367
       output = self.model(tokens, batch_valid_length, batch_index, zactivate_len, block_tables, slot_mapping)
                ^
# 12 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:206
       for i in range(self.num_layers):
# 13 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:207
           h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
               ^
# 14 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:500
       if not self.use_past:
# 15 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama.py:207
           h = self.layers[i](h, freqs_cis, mask, batch_valid_length=batch_valid_length, block_tables=block_tables,
               ^
# 16 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:507
       h = self.attention(input_x, freqs_cis, mask, batch_valid_length, block_tables, slot_mapping)
           ^
# 17 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:248
       if self.qkv_concat:
# 18 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:258
           query = self.cast(self.wq(x), self.dtype)  # dp, 1 -> dp, mp
           ^
# 19 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:263
       if self.use_past:
# 20 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:264
           freqs_cos, freqs_sin, _ = freqs_cis
# 21 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/models/llama/llama_transformer.py:265
           context_layer = self.infer_attention(query, key, value, batch_valid_length, block_tables, slot_mapping,
                           ^
# 22 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:281
       if self.use_rope_rotary_emb:
# 23 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:283
           freqs_cos = self.cast(freqs_cos, mstype.float16)
# 24 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:289
       if self.is_first_iteration:
# 25 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:290
           if self.input_layout == "BSH":
           ^
# 26 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:291
               context_layer = self.flash_attention(query, key, value, attn_mask, alibi_mask)
               ^
# 27 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:290
           if self.input_layout == "BSH":
           ^
# 28 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/infer_attention.py:286
       key_out = self.paged_attention_mgr(key, value, slot_mapping)
                 ^
# 29 In file /home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/mindformers/modules/paged_attention_mgr.py:61
       return self.reshape_and_cache(key, value, self.key_cache, self.value_cache, slot_mapping)
              ^
(See file '/home/jenkins0/MindFormers_Test/cases/llama2/70b/train/test_mf_llama2_70b_eval_squad_8p_0001/rank_0/om/analyze_fail.ir' for more details. Get instructions about `analyze_fail.ir` at https://www.mindspore.cn/search?inputValue=analyze_fail.ir)

Special notes for this issue (Optional)

Route this to 谭纬城 (tan-wei-cheng).

Comments (8)

zhangjie18 created the Bug-Report
zhangjie18 added the kind/bug label
zhangjie18 added the v2.3.0.rc2 label
zhangjie18 added the attr/function label
zhangjie18 added the stage/func-debug label
zhangjie18 added the sig/mindformers label
zhangjie18 added the device/ascend label
zhangjie18 added collaborator xiangminshan

Please assign a maintainer to check this issue.
@zhangjie18

Thanks for your question. You can comment //mindspore-assistant to get help faster:

  1. If you are new to MindSpore, you may find the answer in the tutorials
  2. If you are an experienced PyTorch user, you may need:
  3. If you run into PyNative (dynamic graph) issues, you can set set_context(pynative_synchronize=True) to get a clearer error stack for locating the problem
  4. For model accuracy tuning, refer to the tuning guide on the official website
  5. If you are reporting a framework bug, please make sure the issue provides the MindSpore version, the backend used (CPU, GPU, Ascend), the environment, an official link to the training code, and the way to launch the code that reproduces the error
  6. If you have already found the root cause, you are welcome to submit a PR to the MindSpore open-source community; we will review it as soon as possible
zhangjie18 changed the title
wangxingyan added collaborator wangxingyan
wangxingyan changed the assignee from wangxingyan to tan-wei-cheng

Aligned with the test team:

PPL evaluation can only run with use_past=false; it does not support incremental inference. The non-PPL, generative evaluations run with use_past=true and use incremental inference.

For distributed evaluation of llama2-70B, the infer_predict_layout logic needs to be added.

Forwarding to 冯浩 for verification and tracking.
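In other words, the metric type and use_past have to be chosen together. A hedged sketch of the two combinations (key names assumed from the usual MindFormers configs, not taken from this issue):

    # Perplexity (PPL) evaluation: full-sequence forward pass, no incremental inference
    metric:
      type: PerplexityMetric
    model:
      model_config:
        use_past: False

    # Generative evaluation (e.g. SQuAD EM/F1): incremental inference with the KV cache
    metric:
      type: EmF1Metric
    model:
      model_config:
        use_past: True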

tan-wei-cheng changed the assignee from tan-wei-cheng to 冯浩
冯浩 added collaborator 冯浩
冯浩 changed the assignee from 冯浩 to renyujin
tan-wei-cheng changed the assignee from renyujin to tan-wei-cheng
tan-wei-cheng added collaborator renyujin
tan-wei-cheng changed the task status from TODO to VALIDATION
tan-wei-cheng changed the assignee from tan-wei-cheng to zhangjie18
i-robot added the gitee label
tan-wei-cheng removed the gitee label
tan-wei-cheng added the rct/oldrelease label
tan-wei-cheng added the rca/others label
tan-wei-cheng added the ctl/solutiontest label

Verification result:

2024-05-09 09:46:06,469 - mindformers[mindformers/generation/text_generator.py:886] - INFO - total time: 52.50571131706238 s; generated tokens: 1812 tokens; generate speed: 34.51053141739208 tokens/s
2024-05-09 09:46:06,470 - mindformers[mindformers/modules/block_tables.py:129] - INFO - Clear block table cache engines.
2024-05-09 09:46:06,471 - mindformers[mindformers/trainer/causal_language_modeling/causal_language_modeling.py:283] - INFO - Step[1/2067], cost time 52.5124s, every example cost time is 52.5124, generate speed: 34.5062 tokens/s, avg speed: 0.0000 tokens/s, remaining time: 0:00:00
Building prefix dict from the default dictionary ...
DEBUG:jieba:Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
DEBUG:jieba:Loading model from cache /tmp/jieba.cache

i-robot added the gitee label

Regression version: MS: master_20240509061515_b6f9201324ff; MF: dev_20240509021521_8e97c2b1676e3b
Regression steps: follow the reproduction steps in this issue
Basic problem: resolved
(screenshot omitted)
Test conclusion: regression passed
Regression date: 2024-05-09

i-robot added the foruda label
zhangjie18 changed the task status from VALIDATION to DONE
fangwenyi removed the v2.3.0.rc2 label
fangwenyi added the master label
