查看原文
其他

从零开始行人重识别

酱油妹 OpenCV学堂 2020-02-04

点击上方↑↑↑“OpenCV学堂”关注我

教你一步一步如何运行行人重识别代码

作者知乎: https://zhuanlan.zhihu.com/p/50387521

序言

探索了行人特征的基本学习方法。在这个实践中,我们将会学到如何一步一步搭建简单的行人重识别系统。欢迎任何建议。


pytorch源码

https://github.com/layumi/Person_reID_baseline_pytorch

行人重识别可以看成为图像检索的问题。给定一张摄像头A拍摄到的查询图像,我们需要找到这个人在其他摄像头下的图像。行人重识别的核心在于如何找到有鉴别力的行人表达。很多近期的方法使用了深度学习模型来抽取视觉特征,达到了SOTA的结果。

需要安装的软件


Python 3.6
GPU Memory >= 6G
Numpy
Pytorch 0.3+ (http://pytorch.org/)
Torchvision from the source


git clone https://github.com/pytorch/vision
cd vision
python setup.py install


Part.1训练

你可能注意到下载下来的数据集是如下分布的:

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt

那么现在打开刚刚下载的代码prepare.py。将第五行的地址改为你本地的地址,比如 \home\zzd\Download\Market,然后在终端中跑一下。

python prepare.py

我们在下载的文件夹中创建了一个子文件夹叫 pytorch

├── Market/
│   ├── bounding_box_test/          /* Files for testing (candidate images pool)
│   ├── bounding_box_train/         /* Files for training
│   ├── gt_bbox/                    /* We do not use it
│   ├── gt_query/                   /* Files for multiple query testing
│   ├── query/                      /* Files for testing (query images)
│   ├── readme.txt
│   ├── pytorch/
│       ├── train/                   /* train
│           ├── 0002
|           ├── 0007
|           ...
│       ├── val/                     /* val
│       ├── train_all/               /* train+val      
│       ├── query/                   /* query files  
│       ├── gallery/                 /* gallery files

跑完之后,在pytorch的每个子文件夹中,图像都是按ID来排列的。

现在我们已经成功准备好了图像来做后面的训练了。

快速问答:prepare.py 是如何识别同ID的图像?

+ Quick Question. How to recognize the images of the same ID?

对于Market1501这个数据集而言,图像的文件名中就包含了 ID label 和 CameraID, 具体命名可在这个链接看到here.


Part 1.2: Build Neural Network (model.py)

我们可以利用预训练的模型。普遍来说,利用ImageNet预训练的网络能达到更好的结果,因为它保留了一些好的特征。

在pytorch里,我们可以通过两行代码来引入他们。

from torchvision import models
model = models.resnet50(pretrained=True)

你可以使用下面这行代码来简单检查网络结构

print(model)

但在实际使用中,我们需要修改网络。因为Market1501中有751个种类(不同的人)。而不是像ImageNet一样有1000类。所以我们需要改变我们的模型来训练我们的分类器。

import torch
import torch.nn as nn
from torchvision import models

# Define the ResNet50-based Model
class ft_net(nn.Module):
   def __init__(self, class_num = 751):
       super(ft_net, self).__init__()
       #load the model
       model_ft = models.resnet50(pretrained=True)
       # change avg pooling to global pooling
       model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1))
       self.model = model_ft
       self.classifier = ClassBlock(2048, class_num) #define our classifier.

   def forward(self, x):
       x = self.model.conv1(x)
       x = self.model.bn1(x)
       x = self.model.relu(x)
       x = self.model.maxpool(x)
       x = self.model.layer1(x)
       x = self.model.layer2(x)
       x = self.model.layer3(x)
       x = self.model.layer4(x)
       x = self.model.avgpool(x)
       x = torch.squeeze(x)
       x = self.classifier(x) #use our classifier.
       return x
+ Quick Question. Why we use AdaptiveAvgPool2d? What is the difference between the AvgPool2d and AdaptiveAvgPool2d?
+ Quick Question. Does the model have parameters now? How to initialize the parameter in the new layer?
  • 快速问题

  1. 为什么我们使用AdaptiveAvgPool2d? AvgPool2d和 AdaptiveAvgPool2d区别在哪里?

  2. 模型现在有参数么?我们怎么初始化参数?

更多细节在 model.py中. 你可以等看完这个实践再回过头去看一下代码。

Part 1.3: 训练 (python train.py)

好的。现在我们准备好了训练数据 和定义好的网络结构。
我们可以输入如下命令开始训练:

python train.py --gpu_ids 0 --name ft_ResNet50 --train_all --batchsize 32  --data_dir your_data_path
--gpu_ids which gpu to run.
--name the name of the model.
--data_dir the path of the training data.
--train_all using all images to train.
--batchsize batch size.
--erasing_p random erasing probability.

让我们来看一下train.py.当中我们做了什么。第一件事情是如何读数据和他们的label. 我们使用了 torch.utils.data.DataLoader, 可以获得两个迭代器dataloaders['train'] and dataloaders['val'] 来读数据.

image_datasets = {}
image_datasets['train'] = datasets.ImageFolder(os.path.join(data_dir, 'train'),
                                         data_transforms['train'])
image_datasets['val'] = datasets.ImageFolder(os.path.join(data_dir, 'val'),
                                         data_transforms['val'])

dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                            shuffle=True, num_workers=8) # 8 workers may work faster
             for x in ['train''val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train''val']}

以下则是主要的代码来训练模型。是的,一共只有20行,但确保你要理解每一行。

# Iterate over data.
           for data in dataloaders[phase]:
               # get a batch of inputs
               inputs, labels = data
               now_batch_size,c,h,w = inputs.shape
               if now_batch_size<opt.batchsize: # skip the last batch
                   continue
               # print(inputs.shape)
               # wrap them in Variable, if gpu is used, we transform the data to cuda.
               if use_gpu:
                   inputs = Variable(inputs.cuda())
                   labels = Variable(labels.cuda())
               else:
                   inputs, labels = Variable(inputs), Variable(labels)

               # zero the parameter gradients
               optimizer.zero_grad()

               #-------- forward --------
               outputs = model(inputs)
               _, preds = torch.max(outputs.data, 1)
               loss = criterion(outputs, labels)

               #-------- backward + optimize --------
               # only if in training phase
               if phase == 'train':
                   loss.backward()
                   optimizer.step()
+ Quick Question. Why we need optimizer.zero_grad()? What happens if we remove it?
+ Quick Question. The dimension of the outputs is batchsize*751. Why?
  • 快速问答。
    为什么我们需要optimizer.zero_grad() ?如果我们去掉这一行会发生什么?
    输出的维度是batchsize*751. 为什么?

if epoch%10 == 9:
                   save_network(model, epoch)
               draw_curve(epoch)

每十轮,我们会保存网络和更新loss曲线。可以去看看这两个函数具体怎么写。


Part.2 测试

Part 2.1: 特征提取 (python test.py)

这一部分, 我们载入我们刚刚训练的模型 来抽取每张图片的视觉特征

python test.py --gpu_ids 0 --name ft_ResNet50 --test_dir your_data_path  --batchsize 32 --which_epoch 59

--gpu_ids which gpu to run.
--name the dir name of the trained model.
--batchsize batch size.
--which_epoch select the i-th model.
--data_dir the path of the testing data.

让我们看看我们在 test.py中做了什么。首先,我们需要载入模型的结构,然后载入weight。

model_structure = ft_net(751)
model = load_network(model_structure)

对于每张查询图片(query)和 查询库图像(gallery),我们抽取特征通过简单的前向传播.

outputs = model(input_img)
# ---- L2-norm Feature ------
ff = outputs.data.cpu()
fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
ff = ff.div(fnorm.expand_as(ff))
+ Quick Question. Why we flip the test image horizontally when testing? How to fliplr in pytorch?
+ Quick Question. Why we L2-norm the feature?

Part 2.2: 评测

是的,现在我们有了每张图片的特征。我们需要做的事情只有用特征去匹配图像。

python evaluate_gpu.py

让我们看看我们在 evaluate_gpu.py做了什么. 我们将图像按他们的相似度排序。

query = qf.view(-1,1)
# print(query.shape)
score = torch.mm(gf,query) # Cosine Distance
score = score.squeeze(1).cpu()
score = score.numpy()
# predict index
index = np.argsort(score)  #from small to large
index = index[::-1]

注意到有两种图像我们不把他们考虑为true-matches

  • 一种是Junk_index1 错误检测的图像,主要是包含一些人的部件。

  • 一种是Junk_index2 相同的人在同一摄像头下,按照reid的定义,我们不需要检索这一类图像。

query_index = np.argwhere(gl==ql)
   camera_index = np.argwhere(gc==qc)
   # The images of the same identity in different cameras
   good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
   # Only part of body is detected.
   junk_index1 = np.argwhere(gl==-1)
   # The images of the same identity in same cameras
   junk_index2 = np.intersect1d(query_index, camera_index)

我们可以使用 compute_mAP 来计算最后的结果. 在这个函数中,我们忽略了junk_index带来的影响。

CMC_tmp = compute_mAP(index, good_index, junk_index)

Part.3一个简单的可视化程序

可视化结果,

python demo.py --query_index 777

--query_index which query you want to test. You may select a number in the range of 0 ~ 3367.

代码类似 evaluate.py. 我们加入了可视化的部分。

try# Visualize Ranking Result
   # Graphical User Interface is needed
   fig = plt.figure(figsize=(16,4))
   ax = plt.subplot(1,11,1)
   ax.axis('off')
   imshow(query_path,'query')
   for i in range(10): #Show top-10 images
       ax = plt.subplot(1,11,i+2)
       ax.axis('off')
       img_path, _ = image_datasets['gallery'].imgs[index[i]]
       label = gallery_label[index[i]]
       imshow(img_path)
       if label == query_label:
           ax.set_title('%d'%(i+1), color='green'# true matching
       else:
           ax.set_title('%d'%(i+1), color='red'# false matching
       print(img_path)
except RuntimeError:
   for i in range(10):
       img_path = image_datasets.imgs[index[i]]
       print(img_path[0])
   print('If you want to see the visualization of the ranking result, graphical user interface is needed.')

Part.4:轮到你啦

  • Market-1501 是一个在清华大学夏天收集的数据集.
    让我们试试另一个数据集 DukeMTMC-reID, 是在Duke大学冬天采集的。
    你可以在这里 Here 下到数据集. 试试去训练这个数据集
    这个数据集和Market类似. 你可以 Here 看SOTA的结果

  • Quick Question. Could we directly apply the model trained on Market-1501 to DukeMTMC-reID? Why?
    快速问答。我们能直接用Market训好的模型放到DukeMTMC-reID上测试么? 为什么?

试试 Triplet Loss. Triplet loss是另一种广泛使用的目标函数. 你可以看看

https://github.com/layumi/Person-reID-triplet-loss. 

我把代码风格和本实践保持了一致, 你可以看看我改了什么.

Part.5 :其它相关工作

我们可以使用语句描述来找人么? 看这篇论文吧

https://arxiv.org/pdf/1711.05535.pdf

我们也可以用其他loss来进一步提升结果 (比如contrastive loss) ? 看看 this paper.

https://arxiv.org/abs/1611.05666

Person-reID 数据集不够大? You may check this paper and try some data augmentation method like random erasing.

https://arxiv.org/abs/1701.07717

行人检测的不好? 试试 Open Pose和 Spatial Transformer 来对其图像。

https://github.com/CMU-Perceptual-Computing-Lab/openpose
https://github.com/layumi/Pedestrian_Alignment


Reference

[1] Deng, Jia, Wei Dong, Richard Socher, Li-Jia Li, Kai Li, and Li Fei-Fei. "Imagenet: A large-scale hierarchical image database." In Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on, pp. 248-255. Ieee, 2009.

往期精选

告诉大家你 在看

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存