查看原文
其他

老衲这里有七条炼丹经验传授与你

李中梁 机器学习算法工程师 2021-12-31

作者&编辑:李中梁

前言

用深度学习做图像分类任务也有近一年时间了,从最初模型的准确率只有60%到后来训练到80%,再到最后的90%+的准确率,摸索中踩了很多坑,也总结出了一些经验。现在将一些自己觉得非常实用的模型训练经验写下来作为记录,也方便后来者借鉴验证。

调参经验

  • 模型选择
    通常我会使用一个简单的CNN模型(这个模型一般包含5个卷积层)将数据扔进去训练跑出一个baseline,这一步工作主要是为了验证数据集的质量。如果这个模型训练结果很差就不要先调试模型,需要检查一下你的训练集数据,看看图像的质量,图像标签是否正确,模型的代码是否正确等等,否则就是在做无用功,毕竟:garbage in,garbage out


  • 超参数的选择
    调参是项技术活,调得好CVPR,调不好下海搬砖
    通常要选的超参数有卷积核大小和数目,批训练(batch size)大小,优化函数(optimizer),学习率等等,一般来说卷积核用3*3或者5*5,batch szie 用16或者32不会过拟合,optimizer用Adam(学习率建议用论文中默认的,我试过调整Adam的学习率,效果或都没有默认的好),激活函数用relu这个应该是大家的共识吧。还有就是先跑几百个epoch看loss的变化趋势。


  • 数据预处理
    训练数据对模型的影响是决定性的,提高训练数据的质量,就是在提高模型的准确率。
    图像预处理的时候一般我会抽出部分图像观察,对图像中的噪声进行滤波,图像标签要验证一下,其他的预处理就结合实际情况来看了。一般来说,数据清洗的工作占比是多于写模型的工作(通常是7:3)。


  • 数据增强
    数据增强已经是训练深度网络的常规操作了,这味丹药有利于增加训练数据量,减少网络过拟合程度,男女老少,居家旅行必备良药。
    常用的数据增强方法包括:图像缩放图像翻转图像裁剪图像色彩的饱和度、亮度和对比度变换
    海康威视在ImageNet上曾经用过PCA Jittering方的法,但是由于这个方法的计算量过大,我没有在自己的训练中使用过。他们还使用了有监督的数据增强的方法,有兴趣的同学可以研究一下。


  • 数据不平衡的处理
    如果训练数据中各类样本数目差距较大,很有可能会导致部分类别的准确率很低,从根本上解决样本不平衡的问题就是要把样本变平衡。
    一种是增加样本少的类别的图像数目,可以用上述数据增强的方法。
    另一种就是直接将样本多的类别图像数目减少,可以说是非常简单粗暴了。
    当然,也有人提出类别权重的方法,增加少样本在训练时的权重,间接地增强了图像数目。


  • 自己的数据生成器

    当任务变得复杂,数据规模变大时,框架提供的接口不能满足你的需求,这时你需要有自己的data generation function。例如,我使用keras时需要对输入图片进行多标签任务的训练,而keras本身不包含这样的接口,所以需要自己实现一个data generation function。通过查看官方文档和相关接口实现了一个多标签数据生成器(写这个数据生成器的时候被官方文档坑了一次,暂且不表,下次另起一文详说),代码如下:

# 训练集/测试集数据生成器,替换flow_from_directory()
def flow_from_2DList(directory=None, target_size=(256, 256), 
    color_mode='rgb', classes=None, class_mode='categorical', 
    batch_size=1, shuffle=True, seed=None, save_to_dir=None, 
    save_prefix='', save_format='png', follow_links=False, 
    subset=None, interpolation='nearest'):
    """   
    A DirectoryIterator yielding tuples of (x, y) 
    where x is a numpy array containing a batch of images 
    with shape (batch_size, *target_size, channels) and 
    y is a numpy array of corresponding labels.
    """
    # 每个epoch都要shuffle数据集
    random.shuffle(directory)

    # 参数初始化
    if directory is None:   # python函数的默认参数如果是list这种可变类型,
                            # 需要在函数体内进行初始化,
                            # 否则会在上次的结果后继续使用list
        directory = [ [ 99999 for x in range(4) ] for y in range(batch_size) ]

    list_len = len(directory)
    print('\nlength of directory:', list_len, '\n\n')
    print('\nbatch_size:', batch_size, '\n\n')
    step = list_len//batch_size   # 向下取整得到一个epoch需要多少个step
    print('\nsetp:',step,'\n\n')

    for i in range(step):
        # 每行一个记录读取训练/测试数据,返回(x,[y1,y2,y3])
        batch_images = []

        y_label_age = np.zeros((batch_size, 100))
        y_label_sex = np.zeros((batch_size, 2))
        y_label_sick = np.zeros((batch_size, 2))

        batch_directory = directory[i*batch_size : (i+1)*batch_size].copy()

        batch_size_num = 0 # 循环计数器

        for record in batch_directory:
            file_path = record[0]
            image = cv2.imread(file_path)
            image = cv2.resize(image, target_size)

            batch_images.append(image)

            age = record[1]
            sex = record[2]
            sick = record[3]

            # 将age,sex,sick转换成one-hot编码         
            if age != 0:
                age -= 1
            age = to_categorical(age, num_classes = 100)

            sex = to_categorical(sex-1, num_classes = 2)   
            sick = to_categorical(sick-1, num_classes = 2)

            y_label_age[batch_size_num,:] = age
            y_label_sex[batch_size_num,:] = sex
            y_label_sick[batch_size_num,:] = sick

            batch_size_num += 1

        batch_images = np.array(batch_images)
        y_labels = [y_label_age, y_label_sex, y_label_sick]
        data = (batch_images, y_labels)
        yield data


  • 其他提示
    当然,具体任务不同可能某些经验不能适用,实践是检验真理的唯一标准,祝大家炼丹愉快~

与我交流

github:  https://github.com/keloli

blog:     https://www.jianshu.com/u/d055ee434e59




往期回顾之作者李中梁

【1】【TPAMI重磅综述】 SIFT与CNN的碰撞:万字长文回顾图像检索任务十年探索历程(上篇)

【2】【TPAMI重磅综述】 SIFT与CNN的碰撞:万字长文回顾图像检索任务十年探索历程(下篇)

【3】超快速!10分钟入门Keras指南








机器学习算法工程师


                            一个用心的公众号

长按,识别,加关注

进群,学习,得帮助

你的关注,我们的热度,

我们一定给你学习最大的帮助






: . Video Mini Program Like ,轻点两下取消赞 Wow ,轻点两下取消在看

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存