[ST][MS][memory] Memory leak during network training

ACCEPTED
RFC
Created on 2022-12-19 20:34

name: Bug Report
about: Use this template for reporting a bug
labels: kind/bug

Describe the current behavior (Mandatory)

Host memory keeps growing during network training.

Environment (Mandatory)

  • Hardware Environment (Ascend/GPU/CPU):

Please delete the backend not involved:
/device ascend/GPU/CPU

  • Software Environment (Mandatory):
    -- MindSpore version (e.g., 1.7.0.Bxxx): commit_id = '[sha1]:afc602d1, [branch]:(HEAD, origin/r1.10, r1.10)'
    -- Python version (e.g., Python 3.7.5):
    -- OS platform and distribution (e.g., Linux Ubuntu 16.04):
    -- GCC/Compiler version (if compiled from source):

  • Execute Mode (Mandatory) (PyNative/Graph):

Please delete the mode not involved:
/mode pynative
/mode graph

Related testcase (Mandatory)

ms_lenet_mnist_train_infer_0003

Steps to reproduce the issue (Mandatory)

  1. Official tutorial: https://www.mindspore.cn/tutorials/zh-CN/master/beginner/train.html
  2. Adapt the tutorial code as needed; see the appendix below.
  3. Run training.

Describe the expected behavior (Mandatory)

Training completes successfully with no memory leak.

Related log / screenshot (Mandatory)

1. Code 1: host memory grows quickly, by roughly 75 MB every 10 steps. The code is in the appendix and can be run directly.

2. Code 2: host memory grows slowly, by roughly 5 MB every 500 steps. The code is in the appendix and can be run directly.
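For reference, the growth can also be tracked from inside the script instead of shelling out to free -m. A minimal sketch, assuming psutil is installed (the analysis further below uses the same psutil call); the helper name log_rss is only illustrative:

import os
import psutil

_proc = psutil.Process(os.getpid())

def log_rss(step, every=10):
    # Print the resident set size (in MB) of this process every `every` steps
    if step % every == 0:
        rss_mb = _proc.memory_info().rss / (1024 * 1024)
        print(f"step {step}: rss = {rss_mb:.1f} MB")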

Special notes for this issue (Optional)

Code 1

import os

import torch.nn as nn1
import mindspore.nn as nn2
import torch
import mindspore
import torchvision


class CNN_torch(nn1.Module):

    def __init__(self):
        # Call the parent class constructor
        super(CNN_torch, self).__init__()
        # First convolution + pooling block; layers inside Sequential run in order
        # The original paper uses sigmoid activations; ReLU works better here
        self.conv_pool1 = nn1.Sequential(
            nn1.Conv2d(in_channels=1,  # input (1, 28, 28), padded to (1, 32, 32)
                       # in/out shapes are per sample; training feeds a whole batch at a time
                       out_channels=6,
                       kernel_size=(5, 5),
                       padding=2
                       ),  # output(6, 28, 28)
            nn1.ReLU(),  # activation
            nn1.MaxPool2d(2, stride=2)  # output(6, 14, 14)
        )

        self.conv_pool2 = nn1.Sequential(
            nn1.Conv2d(in_channels=6,
                       out_channels=16,
                       kernel_size=(5, 5)
                       ),  # output(16, 10, 10)
            nn1.ReLU(),
            nn1.MaxPool2d(2, stride=2)  # output(16, 5, 5)
        )

        # Fully connected layer
        self.fc1 = nn1.Sequential(  # a fully connected layer replaces the conv layer in the original paper
            nn1.Linear(16 * 5 * 5, 120),
            nn1.ReLU()
        )

        # Fully connected layer
        self.fc2 = nn1.Sequential(
            nn1.Linear(120, 84),
            nn1.ReLU()
        )
        # Output layer
        self.out = nn1.Sequential(
            nn1.Linear(84, 10),

        )

    # Forward pass
    def forward(self, x):
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = x.view(x.size(0), -1)  # flatten to 2-D (batch_size, 16*5*5)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x


class CNN_ms(nn2.Cell):

    def __init__(self):
        # Call the parent class constructor
        super(CNN_ms, self).__init__()
        # First convolution + pooling block; layers inside SequentialCell run in order
        # The original paper uses sigmoid activations; ReLU works better here
        self.conv_pool1 = nn2.SequentialCell(
            [nn2.Conv2d(in_channels=1,
                        pad_mode="pad",
                        padding=2,
                        out_channels=6,
                        kernel_size=(5, 5),
                        ),
             nn2.ReLU(),
             nn2.MaxPool2d(2, stride=2)]
        )
        self.reshape = mindspore.ops.Reshape()
        self.conv_pool2 = nn2.SequentialCell(
            [nn2.Conv2d(in_channels=6,
                        out_channels=16,
                        pad_mode="valid",
                        kernel_size=(5, 5)
                        ),
             nn2.ReLU(),
             nn2.MaxPool2d(2, stride=2)]
        )

        # Fully connected layer
        self.fc1 = nn2.SequentialCell(  # a fully connected layer replaces the conv layer in the original paper
            [nn2.Dense(16 * 5 * 5, 120),
             nn2.ReLU()]
        )

        # Fully connected layer
        self.fc2 = nn2.SequentialCell([nn2.Dense(120, 84), nn2.ReLU()])
        # Output layer
        self.out = nn2.SequentialCell([nn2.Dense(84, 10), ])

    # Forward pass
    def construct(self, x):
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = self.reshape(x, (-1, 400))  # flatten to 2-D (batch_size, 16*5*5)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x


def data_loader(batch_size):
    # Transform that converts the image data to tensors
    transform = torchvision.transforms.ToTensor()

    train_set = torchvision.datasets.MNIST(root='minist', train=True, transform=transform, download=True)
    train_loaders = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    test_set = torchvision.datasets.MNIST(root='minist', train=False, transform=transform, download=True)
    test_loaders = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True, num_workers=2)

    return train_loaders, test_loaders


def train_loop(model, x, y, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = mindspore.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = mindspore.ops.depend(loss, optimizer(grads))
        return loss

    loss = train_step(x, y)
    return loss


if __name__ == '__main__':
    # model1 = CNN_torch()
    model2 = CNN_ms()
    epochs = 3

    batch_size = 32
    train_loader, test_loader = data_loader(batch_size)

    # Loss function
    # loss_fun1 = nn1.CrossEntropyLoss()
    loss_fun2 = nn2.CrossEntropyLoss()

    # Optimizer
    learning_rate = 1e-2
    # optimizer1 = torch.optim.SGD(model1.parameters(), lr=learning_rate)
    optimizer2 = nn2.SGD(params=model2.trainable_params(), learning_rate=learning_rate)

    for epoch in range(epochs):
        # model1.train()
        model2.set_train()
        print("epoch:" + str(epoch))
        batch = 0
        for data in train_loader:
            imgs, targets = data
            imgs_array, targets_array = imgs.numpy(), targets.numpy()
            imgs_ms, targets_ms = mindspore.Tensor(imgs_array, mindspore.float32), mindspore.Tensor(targets_array,
                                                                                                    mindspore.int32)

            # output_torch = model1(imgs)
            # loss_torch = loss_fun1(output_torch, targets)
            # Optimizer updates the model
            # optimizer1.zero_grad()
            # loss_torch.backward()
            # optimizer1.step()
            loss_ms = train_loop(model2, imgs_ms, targets_ms, loss_fun2, optimizer2)
            if batch % 10 == 0:
                # print("batch:" + str(batch) + " torch_loss:" + str(loss_torch.item()) + " ms_loss:" + str(
                print("epoch:" + str(epoch) + "step:" + str(batch) + " ms_loss:" + str(
                    loss_ms.asnumpy()))
                print("free -m :", os.system("free -m"))

            batch += 1

        # Evaluation phase begins
        # model1.eval()
        model2.set_train(False)
        test_data_size = 0
        total_test_loss = 0
        total_accuracy = 0
        test_loss_ms, correct_ms = 0, 0
        with torch.no_grad():
            for data in test_loader:
                imgs, targets = data
                imgs_array, targets_array = imgs.numpy(), targets.numpy()
                imgs_ms, targets_ms = mindspore.Tensor(imgs_array, mindspore.float32), mindspore.Tensor(targets_array,
                                                                                                        mindspore.int32)
                test_data_size += len(imgs_ms)

                # outputs_torch = model1(imgs)
                # loss_torch = loss_fun1(outputs_torch, targets)

                pred_ms = model2(imgs_ms)
                test_loss_ms += loss_fun2(pred_ms, targets_ms).asnumpy()
                correct_ms += (pred_ms.argmax(1) == targets_ms).asnumpy().sum()

                # total_test_loss = total_test_loss + loss_torch.item()
                # accuracy = (outputs_torch.argmax(1) == targets).asnumpy().sum()
                # total_accuracy = total_accuracy + accuracy

        test_loss_ms /= test_data_size
        correct_ms /= test_data_size

        print("Pytorch Test Accuracy: {}%".format(
            100 * total_accuracy / test_data_size))  # +" "+"Pytorch Test Loss: {}".format(total_test_loss/ test_data_size)
        print(f"Mindspore Test Accuracy: {(100 * correct_ms)}%\n")  # , Mindspore Test Loss: {test_loss_ms}

Code 2

import os

import torch.nn as nn1
import mindspore.nn as nn2
import torch
import mindspore
import torchvision


class CNN_torch(nn1.Module):

    def __init__(self):
        # Call the parent class constructor
        super(CNN_torch, self).__init__()
        # First convolution + pooling block; layers inside Sequential run in order
        # The original paper uses sigmoid activations; ReLU works better here
        self.conv_pool1 = nn1.Sequential(
            nn1.Conv2d(in_channels=1,  # input (1, 28, 28), padded to (1, 32, 32)
                       # in/out shapes are per sample; training feeds a whole batch at a time
                       out_channels=6,
                       kernel_size=(5, 5),
                       padding=2
                       ),  # output(6, 28, 28)
            nn1.ReLU(),  # activation
            nn1.MaxPool2d(2, stride=2)  # output(6, 14, 14)
        )

        self.conv_pool2 = nn1.Sequential(
            nn1.Conv2d(in_channels=6,
                       out_channels=16,
                       kernel_size=(5, 5)
                       ),  # output(16, 10, 10)
            nn1.ReLU(),
            nn1.MaxPool2d(2, stride=2)  # output(16, 5, 5)
        )

        # Fully connected layer
        self.fc1 = nn1.Sequential(  # a fully connected layer replaces the conv layer in the original paper
            nn1.Linear(16 * 5 * 5, 120),
            nn1.ReLU()
        )

        # Fully connected layer
        self.fc2 = nn1.Sequential(
            nn1.Linear(120, 84),
            nn1.ReLU()
        )
        # Output layer
        self.out = nn1.Sequential(
            nn1.Linear(84, 10),

        )

    # Forward pass
    def forward(self, x):
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = x.view(x.size(0), -1)  # flatten to 2-D (batch_size, 16*5*5)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x


class CNN_ms(nn2.Cell):

    def __init__(self):
        # Call the parent class constructor
        super(CNN_ms, self).__init__()
        # First convolution + pooling block; layers inside SequentialCell run in order
        # The original paper uses sigmoid activations; ReLU works better here
        self.conv_pool1 = nn2.SequentialCell(
            [nn2.Conv2d(in_channels=1,
                        pad_mode="pad",
                        padding=2,
                        out_channels=6,
                        kernel_size=(5, 5),
                        ),
             nn2.ReLU(),
             nn2.MaxPool2d(2, stride=2)]
        )
        self.reshape = mindspore.ops.Reshape()
        self.conv_pool2 = nn2.SequentialCell(
            [nn2.Conv2d(in_channels=6,
                        out_channels=16,
                        pad_mode="valid",
                        kernel_size=(5, 5)
                        ),
             nn2.ReLU(),
             nn2.MaxPool2d(2, stride=2)]
        )

        # Fully connected layer
        self.fc1 = nn2.SequentialCell(  # a fully connected layer replaces the conv layer in the original paper
            [nn2.Dense(16 * 5 * 5, 120),
             nn2.ReLU()]
        )

        # Fully connected layer
        self.fc2 = nn2.SequentialCell([nn2.Dense(120, 84), nn2.ReLU()])
        # Output layer
        self.out = nn2.SequentialCell([nn2.Dense(84, 10), ])

    # Forward pass
    def construct(self, x):
        x = self.conv_pool1(x)
        x = self.conv_pool2(x)
        x = self.reshape(x, (-1, 400))  # flatten to 2-D (batch_size, 16*5*5)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.out(x)
        return x


def data_loader(batch_size):
    # Transform that converts the image data to tensors
    transform = torchvision.transforms.ToTensor()

    train_set = torchvision.datasets.MNIST(root='minist', train=True, transform=transform, download=True)
    train_loaders = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=2)

    test_set = torchvision.datasets.MNIST(root='minist', train=False, transform=transform, download=True)
    test_loaders = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True, num_workers=2)

    return train_loaders, test_loaders


def train_loop(model, x, y, loss_fn, optimizer):
    # Define forward function
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss, logits

    # Get gradient function
    grad_fn = mindspore.ops.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

    # Define function of one-step training
    def train_step(data, label):
        (loss, _), grads = grad_fn(data, label)
        loss = mindspore.ops.depend(loss, optimizer(grads))
        return loss

    loss = train_step(x, y)
    return loss


if __name__ == '__main__':
    model1 = CNN_torch()
    model2 = CNN_ms()
    epochs = 3

    batch_size = 32
    train_loader, test_loader = data_loader(batch_size)

    # Loss function
    loss_fun1 = nn1.CrossEntropyLoss()
    loss_fun2 = nn2.CrossEntropyLoss()

    # Optimizer
    learning_rate = 1e-2
    optimizer1 = torch.optim.SGD(model1.parameters(), lr=learning_rate)
    optimizer2 = nn2.SGD(params=model2.trainable_params(), learning_rate=learning_rate)

    for epoch in range(epochs):
        # model1.train()
        model2.set_train()
        print("epoch:" + str(epoch))
        batch = 0

        def forward_fn(data, label):
            logits = model2(data)
            loss = loss_fun2(logits, label)
            return loss, logits

        # Get gradient function
        grad_fn = mindspore.ops.value_and_grad(forward_fn, None, optimizer2.parameters, has_aux=True)

        # Define function of one-step training
        def train_step(data, label):
            (loss, _), grads = grad_fn(data, label)
            loss = mindspore.ops.depend(loss, optimizer2(grads))
            return loss

        for data in train_loader:
            imgs, targets = data
            imgs_array, targets_array = imgs.numpy(), targets.numpy()
            imgs_ms, targets_ms = mindspore.Tensor(imgs_array, mindspore.float32), mindspore.Tensor(targets_array,
                                                                                                    mindspore.int32)

            output_torch = model1(imgs)
            loss_torch = loss_fun1(output_torch, targets)
            # Optimizer updates the model
            optimizer1.zero_grad()
            loss_torch.backward()
            optimizer1.step()

            loss_ms = train_step(imgs_ms, targets_ms)
            # loss_ms = train_loop(model2, imgs_ms, targets_ms, loss_fun2, optimizer2)
            if batch % 10 == 0:
                print("epoch:" + str(epoch) + "batch:" + str(batch) + " torch_loss:" + str(loss_torch.item()) +
                      " ms_loss:" + str(loss_ms.asnumpy()))
                print("free -m :", os.system("free -m"))

            batch += 1

        # Evaluation phase begins
        model1.eval()
        model2.set_train(False)
        test_data_size = 0
        total_test_loss = 0
        total_accuracy = 0
        test_loss_ms, correct_ms = 0, 0
        with torch.no_grad():
            for data in test_loader:
                imgs, targets = data
                imgs_array, targets_array = imgs.numpy(), targets.numpy()
                imgs_ms, targets_ms = mindspore.Tensor(imgs_array, mindspore.float32), mindspore.Tensor(targets_array,
                                                                                                        mindspore.int32)
                test_data_size += len(imgs_ms)

                outputs_torch = model1(imgs)
                loss_torch = loss_fun1(outputs_torch, targets)

                pred_ms = model2(imgs_ms)
                test_loss_ms += loss_fun2(pred_ms, targets_ms).asnumpy()
                correct_ms += (pred_ms.argmax(1) == targets_ms).asnumpy().sum()

                total_test_loss = total_test_loss + loss_torch.item()
                accuracy = (outputs_torch.argmax(1) == targets).sum()
                total_accuracy = total_accuracy + accuracy

        test_loss_ms /= test_data_size
        correct_ms /= test_data_size

        print("Pytorch Test Accuracy: {}%".format(
            100 * total_accuracy / test_data_size))  # +" "+"Pytorch Test Loss: {}".format(total_test_loss/ test_data_size)
        print(f"Mindspore Test Accuracy: {(100 * correct_ms)}%\n")  # , Mindspore Test Loss: {test_loss_ms}

Comments (7)


Triaged together with 余坚峰 and 黄辉: the current ms_function implementation has a memory leak, which is issue 1. For issue 2, please check whether anything else is also causing memory growth.

Analyze the repeated-compilation issue first.

1. Every call to value_and_grad creates a new _Grad object.
2. forward_fn is defined inside the for loop, so it is a different function object on every iteration, and the self.fn cache inside _Grad can never hit.
3. The relevant code in _Grad:

class _Grad(GradOperation_):
    ...
    def __call__(self, fn, weights=None, grad_position=0):
        ...
    
        @jit(input_signature=dynamic_shape_inputs)
        def after_grad(*args):
            return grad_(fn, weights, grad_position)(*args)

For a jit-decorated function defined inside another function, after_grad is only created when __call__ is entered, and a new after_grad is created on every call to __call__. Since each after_grad is a freshly created object, the jit compilation cache can never hit.
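The effect can be illustrated without MindSpore at all. A minimal, hypothetical sketch of a compile cache keyed by the function object (fake_compile stands in for graph compilation and is not a real API): because a new closure is created on every call, the key is always different and every call "compiles" again:

# Hypothetical compile cache keyed by the function object, mirroring a cache keyed on fn
_cache = {}

def fake_compile(fn):
    if fn in _cache:          # hits only if the *same* function object is passed again
        return _cache[fn]
    print("compiling", fn)    # stands in for an expensive graph compilation
    _cache[fn] = fn
    return fn

def make_step():
    # A new closure object is created on every call, so the cache never hits
    def step(x):
        return x + 1
    return fake_compile(step)

step_a = make_step()
step_b = make_step()
print(step_a is step_b)       # False: two compilations, and the cache grows each time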

In summary, because code 1 calls train_loop inside the for loop, forward_fn and grad_fn are re-created on every step, which triggers repeated compilation and makes memory grow continuously.

With the data-loading and torch-related code in code 1 commented out, the script is as follows:

import numpy as np  # needed for the constant test batch below
import psutil       # needed for the RSS measurement below

if __name__ == '__main__':
    #model1 = CNN_torch()
    model2 = CNN_ms()
    epochs = 3

    batch_size = 32
    #train_loader, test_loader = data_loader(batch_size)

    # Loss function
    #loss_fun1 = nn1.CrossEntropyLoss()
    loss_fun2 = nn2.CrossEntropyLoss()

    # Optimizer
    learning_rate = 1e-2
    #optimizer1 = torch.optim.SGD(model1.parameters(), lr=learning_rate)
    optimizer2 = nn2.SGD(params=model2.trainable_params(), learning_rate=learning_rate)

    for epoch in range(epochs):
        # model1.train()
        model2.set_train()
        print("epoch:" + str(epoch))
        batch = 0

        def forward_fn(data, label):
            logits = model2(data)
            loss = loss_fun2(logits, label)
            return loss, logits

        # Get gradient function
        grad_fn = mindspore.ops.value_and_grad(forward_fn, None, optimizer2.parameters, has_aux=True)

        # Define function of one-step training
        def train_step(data, label):
            (loss, _), grads = grad_fn(data, label)
            loss = mindspore.ops.depend(loss, optimizer2(grads))
            return loss
        
        data = (mindspore.Tensor(np.ones([32, 1, 28,28], dtype = np.float32)), mindspore.Tensor(np.ones([32], dtype = np.int32)))
        while batch < 1800:
        #for data in train_loader:
            imgs, targets = data
            imgs_array, targets_array = imgs.numpy(), targets.numpy()
            imgs_ms, targets_ms = mindspore.Tensor(imgs_array, mindspore.float32), mindspore.Tensor(targets_array,
                                                                                                    mindspore.int32)

            #output_torch = model1(imgs)
            #loss_torch = loss_fun1(output_torch, targets)
            # Optimizer updates the model
            #optimizer1.zero_grad()
            #loss_torch.backward()
            #optimizer1.step()

            loss_ms = train_step(imgs_ms, targets_ms)
            # loss_ms = train_loop(model2, imgs_ms, targets_ms, loss_fun2, optimizer2)
            if batch % 50 == 0:
                #print("epoch:" + str(epoch) + "batch:" + str(batch) + " torch_loss:" + str(loss_torch.item()) +
                # print("epoch:" + str(epoch) + "step:" + str(batch) + " ms_loss:" + str(
                #      " ms_loss:" + str(loss_ms.asnumpy()))
                #print("free -m :", os.system("free -m"))
                print("epoch:" + str(epoch) + "batch:" + str(batch), psutil.Process(os.getpid()).memory_info().rss / 1024)

            batch += 1

Results:

epoch:0batch:0 838972.0
epoch:0batch:50 886256.0
epoch:0batch:100 877084.0
epoch:0batch:150 876308.0
epoch:0batch:200 879956.0
epoch:0batch:250 879788.0
epoch:0batch:300 882016.0
epoch:0batch:350 877404.0
epoch:0batch:400 879392.0
epoch:0batch:450 881336.0
epoch:0batch:500 882280.0
epoch:0batch:550 884076.0
epoch:0batch:600 885076.0
epoch:0batch:650 877680.0
epoch:0batch:700 881732.0
epoch:0batch:750 882808.0
epoch:0batch:800 884548.0
epoch:0batch:850 884164.0
epoch:0batch:900 880000.0
epoch:0batch:950 877096.0
epoch:0batch:1000 882108.0
epoch:0batch:1050 879740.0
epoch:0batch:1100 885292.0
epoch:0batch:1150 886488.0
epoch:0batch:1200 878544.0
epoch:0batch:1250 882800.0
epoch:0batch:1300 889340.0
epoch:0batch:1350 882784.0
epoch:0batch:1400 887128.0
epoch:0batch:1450 883168.0
epoch:0batch:1500 885100.0
epoch:0batch:1550 882144.0
epoch:0batch:1600 885100.0
epoch:0batch:1650 885168.0
epoch:0batch:1700 888200.0
epoch:0batch:1750 882752.0

The process memory stays at around 882 MB; there is no leak.


@huanghui Testing with the same data on every step doesn't skip the gradient computation and update, does it?


@沐燕舟 No, it doesn't.

CCB decision:
1. Improve the sample code in the official tutorial so that it no longer triggers repeated compilation.
2. The leak caused by jit functions continually being added to the compilation cache cannot be reclaimed by the current mechanism; this is converted into a Q1 requirement. #I684RX
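For reference, the non-leaking pattern implied by the analysis above is to create forward_fn, grad_fn and train_step once, outside the epoch and step loops, and reuse them. A minimal sketch based on the scripts in this issue (model2, loss_fun2, optimizer2 and epochs as defined there; `dataset` is a placeholder for any iterable yielding MindSpore tensors), not the final tutorial wording:

# Build the gradient and step functions once, then reuse the same objects every step
def forward_fn(data, label):
    logits = model2(data)
    loss = loss_fun2(logits, label)
    return loss, logits

grad_fn = mindspore.ops.value_and_grad(forward_fn, None, optimizer2.parameters, has_aux=True)

def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    loss = mindspore.ops.depend(loss, optimizer2(grads))
    return loss

for epoch in range(epochs):
    model2.set_train()
    for batch, (imgs_ms, targets_ms) in enumerate(dataset):  # `dataset` yields MindSpore tensors
        loss_ms = train_step(imgs_ms, targets_ms)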

