Create YOLOv3 using PyTorch from scratch (Part-4)

In this post we create 3 essential tools for the object detection task: IoU (Intersection-over-Union), NMS (Non-Maximum Suppression) and mAP (mean Average Precision).

1. Overview

This is Part-4 of the series on building a YOLOv3 model from scratch.

Here is an overview of the series:

  1. Understand the YOLO model.
  2. Build the model backbone.
  3. Load pre-trained weights.
  4. Get the tools ready: this post.

    Before getting into the training part, it helps to first get some utilities ready, including the code to compute IoU (Intersection-Over-Union), NMS (Non-Maximum Suppression) and mAP (mean Average Precision).

  5. Training data preparation.
  6. Train the model.

2. Intersection-Over-Union (IoU)

2.1. Concept

IoU is a common metric in object detection tasks. As its name suggests, it is the ratio of the intersection area of two boxes over the area of their union. Figure 1 gives an illustration.


Figure 1: Schematic of IoU between 2 boxes. The intersection region is labeled “Inter”. Union is denoted by hatching.
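To make the ratio concrete, here is a minimal plain-Python sketch for two boxes given as [x1, y1, x2, y2] corners. The helper name iou_xyxy and the box coordinates are made up for illustration; this is not the compute_IOU() function developed below.

```python
def iou_xyxy(box_a, box_b):
    """IoU of two boxes given as (x1, y1, x2, y2) corner tuples."""
    # Width and height of the overlap, clamped at 0 when the boxes do not overlap
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h

    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter  # the overlap is counted only once
    return inter / union

# Two 4x4 boxes offset by 2 in x and y: overlap 2 x 2 = 4, union 16 + 16 - 4 = 28
print(iou_xyxy((0, 0, 4, 4), (2, 2, 6, 6)))  # 4 / 28, about 0.143
```

Identical boxes give an IoU of 1, and disjoint boxes give 0.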

2.2. Python code

The implementation is also fairly simple. Code first:

import numpy as np
import torch

def compute_IOU(bbox1, bbox2, x1y1x2y2=False, change_enclose=False):
    '''Compute IoU of 2 groups of bounding boxes

    Args:
        bbox1 (ndarray or tensor): bounding boxes in shape (n, 4), in either
            [x_center, y_center, w, h] format, or [x1, y1, x2, y2] format.
        bbox2 (ndarray or tensor): bounding boxes in shape (m, 4), in the same
            format as bbox1. n and m must be equal, or one of them must be 1.
    Keyword Args:
        x1y1x2y2 (bool): If True, assumes bbox1 and bbox2 are in [x1, y1, x2, y2] format.
        change_enclose (bool): If True, when either box is enclosed by the other,
            change iou to 1.
    Returns:
        iou (ndarray or tensor): intersection over union ratios.
    '''

    if not x1y1x2y2:
        bbox1 = to_x1y1x2y2(bbox1)
        bbox2 = to_x1y1x2y2(bbox2)

    b1x1, b1y1, b1x2, b1y2 = bbox1[:,0], bbox1[:,1], bbox1[:,2], bbox1[:,3]
    b2x1, b2y1, b2x2, b2y2 = bbox2[:,0], bbox2[:,1], bbox2[:,2], bbox2[:,3]

    if isinstance(bbox1, torch.Tensor):
        # intersection coordinates
        inter_x1 = torch.max(b1x1, b2x1)
        inter_x2 = torch.min(b1x2, b2x2)
        inter_y1 = torch.max(b1y1, b2y1)
        inter_y2 = torch.min(b1y2, b2y2)

        # intersection area
        inter = torch.clamp(inter_x2 - inter_x1, 0, None) * torch.clamp(inter_y2 - inter_y1, 0, None)
    else:
        # intersection coordinates
        inter_x1 = np.maximum(b1x1, b2x1)
        inter_x2 = np.minimum(b1x2, b2x2)
        inter_y1 = np.maximum(b1y1, b2y1)
        inter_y2 = np.minimum(b1y2, b2y2)

        # intersection area
        inter = np.clip(inter_x2 - inter_x1, 0, None) * np.clip(inter_y2 - inter_y1, 0, None)

    w1, h1 = (b1x2 - b1x1) , (b1y2 - b1y1)
    w2, h2 = (b2x2 - b2x1) , (b2y2 - b2y1)
    area1 = w1 * h1
    area2 = w2 * h2

    # union area
    union = area1 + area2 - inter

    # iou
    iou = inter / (union + 1e-7)

    if change_enclose:
        # check if box1 inside box2
        idx = abs(area1 / (area2 + 1e-6) - iou) <= 1e-3
        iou[idx] = (inter / (area1 + 1e-6))[idx]

        # check if box2 inside box1
        idx = abs(area2 / (area1 + 1e-6) - iou) <= 1e-3
        iou[idx] = (inter / (area2 + 1e-6))[idx]

    return iou

Some points to note:

  • We put this compute_IOU() function in the utils.py script in the YOLOv3_pytorch project folder. For the folder structure, refer back to the “Create the Darknet-53 model” section of Part-2.
  • We assume the input arguments bbox1 and bbox2 are numpy arrays or torch tensors of bounding boxes, with columns [xcenter, ycenter, width, height] or [x1, y1, x2, y2]. The two inputs either have the same number of rows, or one of them has a single row, which is then compared against every row of the other.
  • Sometimes a small bounding box is entirely enclosed by a larger one. The IoU then equals the ratio between the two areas, which can be a fairly small number. This causes trouble in NMS filtering, as we will see in the next section. To prevent it, when the change_enclose input argument is set to True, we manually replace the IoU of such cases with the intersection over the area of the smaller box, which is close to 1:

        if change_enclose:
            # check if box1 inside box2
            idx = abs(area1 / (area2 + 1e-6) - iou) <= 1e-3
            iou[idx] = (inter / (area1 + 1e-6))[idx]
    
            # check if box2 inside box1
            idx = abs(area2 / (area1 + 1e-6) - iou) <= 1e-3
            iou[idx] = (inter / (area2 + 1e-6))[idx]
    
  • To convert bounding boxes from [xcenter, ycenter, width, height] format to [x1, y1, x2, y2], we use a to_x1y1x2y2() function:

          def to_x1y1x2y2(bbox):
              '''Convert from [x_center, y_center, w, h] -> [x1, y1, x2, y2]
              '''
              x1 = bbox[:,0] - bbox[:,2] / 2
              x2 = bbox[:,0] + bbox[:,2] / 2
              y1 = bbox[:,1] - bbox[:,3] / 2
              y2 = bbox[:,1] + bbox[:,3] / 2
      
              if isinstance(bbox, torch.Tensor):
                  return torch.vstack([x1, y1, x2, y2]).T
              else:
                  return np.c_[x1, y1, x2, y2]
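The effect of the enclosure fix can be checked on a toy case. Below is a self-contained NumPy sketch; the helper names center_to_corners and iou_one_vs_many are hypothetical. Instead of the area-ratio test used in compute_IOU(), it flags enclosure by checking whether the intersection equals the smaller box's area. In exact arithmetic the two tests agree: area1 / area2 = inter / (area1 + area2 - inter) rearranges to (area1 - inter)(area1 + area2) = 0, i.e. inter = area1.

```python
import numpy as np

def center_to_corners(b):
    """Convert [x_c, y_c, w, h] rows to [x1, y1, x2, y2] rows."""
    return np.c_[b[:, 0] - b[:, 2] / 2, b[:, 1] - b[:, 3] / 2,
                 b[:, 0] + b[:, 2] / 2, b[:, 1] + b[:, 3] / 2]

def iou_one_vs_many(box, others, change_enclose=False):
    """IoU of one box, shape (1, 4), against others, shape (k, 4), both in center format."""
    a, b = center_to_corners(box), center_to_corners(others)
    # The (1,) columns of `a` broadcast against the (k,) columns of `b`
    inter_w = np.clip(np.minimum(a[:, 2], b[:, 2]) - np.maximum(a[:, 0], b[:, 0]), 0, None)
    inter_h = np.clip(np.minimum(a[:, 3], b[:, 3]) - np.maximum(a[:, 1], b[:, 1]), 0, None)
    inter = inter_w * inter_h
    area_a = box[:, 2] * box[:, 3]
    area_b = others[:, 2] * others[:, 3]
    iou = inter / (area_a + area_b - inter + 1e-7)

    if change_enclose:
        smaller = np.minimum(area_a, area_b)
        # One box lies inside the other exactly when the overlap equals the smaller area
        enclosed = np.isclose(inter, smaller)
        iou[enclosed] = inter[enclosed] / (smaller[enclosed] + 1e-7)
    return iou

big = np.array([[5.0, 5.0, 10.0, 10.0]])     # corners (0, 0, 10, 10), area 100
small = np.array([[4.0, 4.0, 4.0, 4.0]])     # corners (2, 2, 6, 6), area 16, inside `big`
print(iou_one_vs_many(big, small))                        # plain IoU: 16 / 100, about 0.16
print(iou_one_vs_many(big, small, change_enclose=True))   # about 1.0
```

A box that only partially overlaps `big` keeps its ordinary IoU, since its intersection is smaller than either area.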
      

3. Non-Maximum suppression (NMS)

3.1. Concept

NMS is often used to remove duplicate predictions, where multiple similar bounding boxes are predicted around the same target. Figure 2 gives an illustration of the process.

The steps of NMS:

  1. Sort the predictions by their confidence scores, from highest to lowest, and put them into a “waiting list”. In the example of Figure 2, we have:

    box index   confidence score
    1           0.99
    4           0.98
    2           0.90
    3           0.81
    5           0.70
    6           0.50
  2. Select the top-ranking bounding box (box-1), remove it from the “waiting list” and put it into the result list.
  3. Compute the IoU scores between the just-selected bounding box and all remaining ones in the “waiting list”. Remove boxes with IoU scores at or above a prescribed threshold (typically 0.5) from the “waiting list”.

    In the example of Figure 2, let’s assume that box-2 and box-3 overlap with box-1 with IoUs >= 0.5, so they are removed. Box-4, -5 and -6 have little or no overlap with box-1, so they stay in the “waiting list”. This makes sense: we pick a “winner” by suppressing similar competitors, while leaving unrelated predictions alone.

  4. Repeat steps 2 and 3 until the “waiting list” is empty.


Figure 2: Schematic of the NMS process. Confidence score of each bounding box is shown in parentheses in the box label.

In the example in Figure 2, after selecting box-1 and removing box-2 and -3, we pick box-4 as the next top-ranking box. Then we compute the IoUs between box-4 and each of box-5 and box-6.

Notice that box-6 is entirely enclosed by box-4, so its “original” IoU is just the ratio of the two box areas. That can be fairly small, quite possibly below the 0.5 threshold, in which case box-6 would not be removed.

That’s why when enclosing bounding boxes are detected, we manually change the IoU to the ratio of intersection over the area of the smaller box, which essentially gives an IoU of 1.0.
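The four NMS steps listed above can be sketched in a few lines of NumPy. This is a minimal, hypothetical illustration: the names greedy_nms and iou_xyxy_many and the box coordinates are made up, boxes are in [x1, y1, x2, y2] format, plain IoU is used without the enclosure fix, and, like the NMS() function below, it ignores class labels.

```python
import numpy as np

def iou_xyxy_many(box, others):
    """IoU of one [x1, y1, x2, y2] box against a (k, 4) array of such boxes."""
    inter_w = np.clip(np.minimum(box[2], others[:, 2]) - np.maximum(box[0], others[:, 0]), 0, None)
    inter_h = np.clip(np.minimum(box[3], others[:, 3]) - np.maximum(box[1], others[:, 1]), 0, None)
    inter = inter_w * inter_h
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
    return inter / (area + areas - inter)

def greedy_nms(boxes, scores, iou_thres=0.5):
    """Return the indices of the boxes kept by greedy NMS."""
    waiting = list(np.argsort(scores)[::-1])   # step 1: "waiting list", highest score first
    keep = []
    while waiting:                             # step 4: repeat until the list is empty
        best = waiting.pop(0)                  # step 2: move the top box to the results
        keep.append(int(best))
        if not waiting:
            break
        ious = iou_xyxy_many(boxes[best], boxes[waiting])
        # step 3: drop boxes whose IoU with the winner is at or above the threshold
        waiting = [w for w, iou in zip(waiting, ious) if iou < iou_thres]
    return keep

boxes = np.array([[0, 0, 10, 10],      # A
                  [1, 1, 11, 11],      # B: IoU with A is 81 / 119, about 0.68
                  [20, 20, 30, 30]],   # C: no overlap with A or B
                 dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(greedy_nms(boxes, scores))       # [0, 2]: B is suppressed by A
```

Raising iou_thres to 0.7 would keep B as well, since its IoU with A is only about 0.68.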

3.2. Python code for NMS()

Code first:

def NMS(predictions, conf_thres, iou_thres, verbose=True):
    '''Non-maximum suppression on predictions of a single image

    Args:
        predictions (ndarray): bbox predictions. In shape (n, 5+n_classes).
            First 5 columns are: [x_c, y_c, w, h, conf_score]. Last n_classes
            columns are classification predictions.
        conf_thres (float): filter out detections with object confidence score
            lower than this.
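        iou_thres (float): float in range [0, 1]. Suppress bboxes with overlaps
            greater than or equal to this ratio.
    Keyword Args:
        verbose (bool): if True, print the number of predictions left after
            each filtering step.
    Returns:
        results (ndarray or list): filtered detections, in shape (m, 6). m is the number
            of filtered detections. Columns are: [x_c, y_c, w, h, conf_score, class_idx].
            conf_score is the class probability times the objectness score, and a box
            appears once for every class whose conf_score >= conf_thres.
            An empty list is returned if no prediction survives the filtering.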
    '''

    results = []
    conf_idx = 4

    if verbose:
        print('Number of predictions before confidence filtering:', len(predictions))

    # select by conf
    idx = predictions[:, conf_idx] >= conf_thres
    predictions = predictions[idx]
    if len(predictions) == 0:
        return results

    if verbose:
        print('Number of predictions after confidence filtering:', len(predictions))

    # conditional on objectness
    predictions[:, 5:] *= predictions[:, 4:5]

    # select by class prob. idx_box select bboxes, idx_cls select classes
    idx_box, idx_cls = np.where(predictions[:, 5:] >= conf_thres)
    boxes = predictions[idx_box, :4]
    confs = predictions[idx_box, idx_cls+5]
    clss = idx_cls[:, None]
    pred = np.c_[boxes, confs, clss]  # [m, 6]

    if verbose:
        print('Number of predictions after class prob filtering:', len(pred))
    if len(pred) == 0:
        return results

    # sort by conf
    idx = np.argsort(pred[:, conf_idx])[::-1]
    pred = pred[idx]
    # NMS
    while True:
        p1, pred = pred[0:1], pred[1:]
        results.append(p1)
        if len(pred) == 0:
            break

        # compute iou with others
        ious = compute_IOU(p1[:, :4], pred[:, :4], x1y1x2y2=False, change_enclose=True)
        # remove overlaps
        pred = pred[ious < iou_thres]

    results = np.vstack(results)

    return results

Some points:

  • Low-confidence predictions are first removed using the conf_thres threshold. This drastically reduces the number of candidates to go through. Recall that with an input size of 416 x 416 and 3 size scales, YOLOv3 outputs (52 * 52 + 26 * 26 + 13 * 13) * 3 = 10647 predictions per image. Confidence thresholding can bring that down to about a hundred or even fewer, depending on the chosen threshold.
  • The classification probabilities are conditioned on the objectness confidence score, i.e. multiplied by it, to give class-specific confidence scores:

      predictions[:, 5:] *= predictions[:, 4:5]
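To see what the conditioning and the subsequent np.where() selection do, here is a toy example with hypothetical numbers for 2 boxes and 3 classes (both boxes already pass the objectness filter at conf_thres = 0.3). It shows that a single box can yield more than one candidate row when several of its class-specific scores clear the threshold.

```python
import numpy as np

conf_thres = 0.3
# Hypothetical predictions, columns [x_c, y_c, w, h, objectness, p_cls0, p_cls1, p_cls2]
predictions = np.array([[50., 50., 20., 20., 0.9, 0.50, 0.45, 0.05],
                        [80., 80., 10., 10., 0.4, 0.90, 0.05, 0.05]])

# Class-specific confidence = class probability x objectness
predictions[:, 5:] *= predictions[:, 4:5]
# row 0 becomes about [0.45, 0.405, 0.045]; row 1 about [0.36, 0.02, 0.02]

# Every (box, class) pair at or above the threshold becomes its own candidate
idx_box, idx_cls = np.where(predictions[:, 5:] >= conf_thres)
pred = np.c_[predictions[idx_box, :4], predictions[idx_box, idx_cls + 5], idx_cls]
print(idx_box, idx_cls)   # [0 0 1] [0 1 0]: box 0 is kept for class 0 and class 1
```

Box 1 has a high raw class probability (0.9) but a low objectness (0.4), so only its class-0 score (0.36) survives.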
    

3.3. Python code for batch_NMS()

The above NMS() function works on a single image. We then create another function that works on a batch:

def batch_NMS(predictions, conf_thres, iou_thres):
    '''Do Non-maximum suppression on a batch

    Args:
        predictions (ndarray): detection predictions, in shape (b, n, 5+n_classes).
            b: batch size, n: number of predictions.
        conf_thres (float): filter out detections with object confidence score
            lower than this.
        iou_thres (float): float in range [0, 1]. Suppress bboxes with overlaps
            greater than this ratio.
    Returns:
        results (ndarray): filtered detections for each image in the batch.
            Shape (m, 7). m is number of detections after filtering.
            Columns are: [batch_idx, x_c, y_c, w, h, conf_score, class_idx].
    '''

    results = []
    #---------------Loop through images---------------
    for ii, pii in enumerate(predictions):
        resii = NMS(pii, conf_thres, iou_thres)
        if len(resii) > 0:
            # prepend batch idx to 1st column
            resii = np.c_[np.full([len(resii), 1], ii), resii]
            results.append(resii)

    if len(results):
        results = np.vstack(results)
    return results

It simply iterates through the images in the batch and calls NMS() on each one. If any detections remain after filtering, we prepend an extra column containing the batch index of that image.
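A small illustration of the resulting layout, using hypothetical NMS() outputs for a batch of 3 images in which the middle image has no detections left:

```python
import numpy as np

# Hypothetical NMS() outputs, columns [x_c, y_c, w, h, conf, class_idx]:
# image 0 keeps 2 boxes, image 1 keeps none, image 2 keeps 1 box
per_image = [np.array([[10., 10., 4., 4., 0.9, 0.], [30., 30., 6., 6., 0.8, 2.]]),
             [],
             np.array([[20., 20., 5., 5., 0.7, 1.]])]

results = []
for ii, resii in enumerate(per_image):
    if len(resii) > 0:
        # prepend a column filled with this image's batch index
        results.append(np.c_[np.full([len(resii), 1], ii), resii])
results = np.vstack(results)       # shape (3, 7)
print(results[:, 0])               # [0. 0. 2.]: image 1 contributes no rows
```

Keeping the batch index in the first column is what later lets compute_AP() match each detection only against labels from the same image.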

4. mean Average Precision (mAP)

I personally find the concept of mAP rather confusing. I list a few references I used below:

And here are 2 more that I believe are inaccurate, at least in part:

The key point, I think, is how to get the Precision-Recall curve. But let’s start from the beginning.

4.1. mAP is the average of APs, and AP is parameterized on IoU threshold

mAP (mean Average Precision) is a score averaged across multiple classes, e.g. across the 80 classes in the COCO detection dataset. For each class, we compute an AP (Average Precision), and the APs of all classes are averaged to get the mAP.

The AP score is parameterized by the IoU threshold. For instance, the “traditional” way, which is also the one used in the PASCAL VOC metric, uses an IoU threshold of 0.5. The AP computed this way is denoted AP@IoU=0.5.

Instead of taking a single IoU=0.5 threshold, one could take multiple IoU thresholds and average the results. For instance, repeat the same computation at 10 different IoU thresholds, from 0.5 to 0.95 with a step of 0.05. The averaged final result is denoted AP@[0.5:0.05:0.95]. This is the so-called COCO primary challenge metric. But let’s leave the multiple-IoU-threshold part to the very end.

Note that the IoU threshold here is NOT the same thing as the threshold used in the NMS process:

  • The IoU threshold in the NMS process is used to flag duplicates.
  • The IoU threshold for AP computation is used to flag true positive and false positive predictions.

4.2. True and false positive predictions, precision and recall

So how do we decide whether a prediction is true or false?

Recall that AP is conditioned on a specific class, so the predicted object has to be of the correct class.

Then the localization of the prediction has to be reasonably good. This is measured by the IoU score: if the predicted bounding box overlaps with a ground truth label of the specific class at IoU >= IoU_threshold, then it is flagged as a true positive prediction. Otherwise, it is a false positive prediction.

An example is given in Figure 3. The predicted boxes have gone through the NMS process, and 4 boxes are left: box-1, -4, -7 and -8.


Figure 3: Demo for AP computation for the “person” class. 2 ground truth labels are given as “label-1” and “label-2”. 4 predictions are given as “box-1”, “box-4”, “box-7” and “box-8”.

Here, we are looking at the AP for the “person” class, so box-7 that predicts a laptop is irrelevant, and we ignore it.

There are 2 ground truth “person” labels in this image. Box-1 is a close match with label-1, so it is a true positive. Similarly, box-4 closely matches label-2, so it is also a true positive.

Box-8 also predicts a “person”, but its localization is too poor: its IoU with label-2 is below the IoU threshold (0.5, for instance). So box-8 is flagged as a false positive.

But how do we know box-1 should be compared against label-1, and box-8 against label-2? In fact, we compare each prediction with all labels of class “person”, and choose the best match, i.e. the one with the highest IoU.

Then, for this image and class “person”, we have 2 true positives: TP=2, and 1 false positive: FP=1.

So, precision and recall can be computed:

  • precision: P = TP / (TP + FP) = 2 / 3 ≈ 0.667.
  • recall: R = TP / #labels = 2 / 2 = 1.
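The same arithmetic as a trivial Python check, with the counts read off the Figure 3 example (box-7 is ignored because it predicts a different class):

```python
# Counts from the Figure 3 example for class "person"
tp = 2          # box-1 and box-4
fp = 1          # box-8
n_labels = 2    # label-1 and label-2

precision = tp / (tp + fp)
recall = tp / n_labels
print(f"P = {precision:.3f}, R = {recall:.3f}")   # P = 0.667, R = 1.000
```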

4.3. Area Under the P-R Curve (AUC)

Precision and recall are 2 competing scores, and typically one has to sacrifice one for the other. Only a hypothetically ideal predictor can achieve precision = recall = 1.

To measure the combined effect of precision and recall, we can compute the Area Under the Curve (AUC) of the P-R curve, with Precision (P) on the y-axis and Recall (R) on the x-axis. This area is defined as the AP for this class.

So we first need to construct such a P-R curve.

Here is where I think the 2 posts I gave above are inaccurate: they seem to suggest that one constructs a P-R curve by iterating through a range of IoU thresholds. Doing so does give a list of (P, R) pairs, but it contradicts the fact that the AUC, and hence the AP score, is computed at a single fixed IoU threshold such as 0.5.

So I believe the following method is correct:

  1. Pick an IoU threshold level, e.g. IoU=0.5.
  2. Pick a class, e.g. “person”.
  3. Take an image that has been predicted, along with its ground truth labels. Select the detections and the labels of the “person” class.
  4. Decide whether each detection is a true or false positive, as done in the previous sub-section.
  5. Repeat steps 3-4 for all the images to be evaluated.
  6. Put the results in a table like the one below:

Table 1: Example AP computing table.

conf score  TP        FP        Accum TP  Accum FP  Precision  Recall
(float)     (1 or 0)  (1 or 0)  (int)     (int)     (float)    (float)
0.98        1         0         1         0         1 / 1      1 / N
0.88        1         0         2         0         2 / 2      2 / N
0.80        0         1         2         1         2 / 3      2 / N
0.77        0         1         2         2         2 / 4      2 / N
0.74        1         0         3         2         3 / 5      3 / N
0.62        1         0         4         2         4 / 6      4 / N
0.55        0         1         4         3         4 / 7      4 / N
0.54        1         0         5         3         5 / 8      5 / N
0.44        0         1         5         4         5 / 9      5 / N

(Note that I’m filling the table with made-up numbers).

Some explanations about the table:

  • conf score is the confidence score of each prediction (for the output of NMS() above, this is the class probability multiplied by the objectness score). The table rows are sorted by this score, from high to low.
  • TP and FP are binary flags of the predictions, judged by whether the IoU score reaches the 0.5 threshold.
  • Accum TP is the cumulative sum of TP, and similarly for Accum FP.
  • Precision is computed as Accum TP / (Accum TP + Accum FP).
  • Recall is computed as Accum TP / N, where N is the total number of labels for the “person” class within these evaluated images.

The last 2 columns, Precision and Recall, then give us the desired P-R curve. We only need to prepend a starting point of \((recall=0, precision=1)\).

When plotted, the curve looks like the blue curve in Figure 4 below (assuming N = 15).

Notice that there are zig-zag patterns in the curve. To smooth these out, an 11-point interpolation method was devised, and it works as follows:

Sample the recall range [0, 1] at 11 equally spaced points: 0, 0.1, 0.2, ..., 1.

At each interpolated recall level \(r \in \{0, 0.1, \cdots, 1\}\), get an interpolated precision using:

\[ P_{interp}(r) = \max_{\widetilde{r} \ge r} P(\widetilde{r}) \]

where \(P(\widetilde{r})\) is the precision value taken at recall level \(\widetilde{r}\) in the table above. So the above equation takes the maximum observed precision at or to the right of a sampled recall level \(r\). If no observed recall reaches \(r\), \(P_{interp}(r)\) is set to 0.

When plotted out, the 11-point interpolated P-R curve is the red curve in Figure 4.

Finally, the AP is computed as:

\[ AP = \frac{1}{11} \sum_{r \in \{0, 0.1, \cdots, 1\}} P_{interp}(r) \]


Figure 4: Demo P-R curve. Blue curve is the computed P-R values taken from Table 1 (assuming N = 15). Red curve is the 11-point interpolated curve.
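As a worked check of the AP formula, take the made-up numbers of Table 1 with N = 15, so the recall never exceeds 5/15 ≈ 0.333. Taking the highest precision at or to the right of each sampled recall level gives:

\[ P_{interp}(0) = P_{interp}(0.1) = 1, \quad P_{interp}(0.2) = \max\left(\tfrac{3}{5}, \tfrac{4}{6}, \tfrac{4}{7}, \tfrac{5}{8}, \tfrac{5}{9}\right) = \tfrac{2}{3}, \quad P_{interp}(0.3) = \max\left(\tfrac{5}{8}, \tfrac{5}{9}\right) = \tfrac{5}{8} \]

and \(P_{interp}(r) = 0\) for the 7 levels \(r \ge 0.4\), since no recall value reaches them. Therefore:

\[ AP = \frac{1}{11}\left(1 + 1 + \frac{2}{3} + \frac{5}{8} + 7 \times 0\right) \approx 0.299 \]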

Below is the Python code to generate Figure 4:

import numpy as np
import matplotlib.pyplot as plt

tp = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0]) # true positive flags
fp = 1 - tp                                # false positive flags
acc_tp = np.cumsum(tp)                     # accumulative sum
acc_fp = np.cumsum(fp)

N = 15                                     # total number of labels
pre = acc_tp / (acc_tp + acc_fp)           # precision
rec = acc_tp / N                           # recall

pre = np.r_[1., pre]                       # add a beginning point
rec = np.r_[0., rec]

# do 11-point interpolation
rec_interp = np.linspace(0, 1, 11, endpoint=True)
pre_interp = []
for rii in rec_interp:
    p_on_right = pre[rec>=rii]
    if len(p_on_right) > 0:
        pii = np.max(p_on_right)
    else:
        pii = 0
    pre_interp.append(pii)

#-------------------Plot------------------------
figure = plt.figure()
ax = figure.add_subplot(1,1,1)

ax.plot(rec, pre, 'b-o', ms=5, label='P-R')
ax.step(rec_interp, pre_interp, where='post', color='r')
ax.plot(rec_interp, pre_interp, marker='^', ms=8, ls='none', color='r', label='P-R (11-point interp)')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_ylim(-0.2, 1.2)
ax.grid(True, axis='both')
ax.legend()
plt.show()

4.4. Python code for AP computation

Code first:

def compute_AP(pred, label, class_idx, conf_thres=0.25, iou_thres=0.5, interp=11):
    '''Compute Average Precision for a class

    Args:
        pred (ndarray): predictions, in shape (n, 7). n is the number of detections.
            The columns are: [batch_idx, x_c, y_c, w, h, conf, clss].
        label (ndarray): ground truth labels, in shape (m, 6). m is the number
            of labels. The columns are: [batch_idx, x_c, y_c, w, h, clss].
        class_idx (int): target class idx.
    Keyword Args:
        conf_thres (float): optionally filter prediction by confidence score.
        iou_thres (float): IoU score threshold to flag true/false positives.
        interp (int or None): if None, compute Area Under Curve using trapz
            integration. If 11, smooth the P-R curve using 11-point interpolation.
    Returns:
        precision (ndarray): 1d array, precision values.
        recall (ndarray): 1d array, recall values.
        ap (float): average precision.
    '''

    # select class
    pred = pred[pred[:, -1] == class_idx]
    label = label[label[:, -1] == class_idx]

    if len(pred) == 0:
        return 0, 0, 0

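    # filter by confidence, then visit the detections from highest to lowest
    # score, so that the most confident detection is the first to claim a label
    conf_idx = np.where(pred[:, 5] >= conf_thres)[0]
    conf_idx = conf_idx[np.argsort(pred[conf_idx, 5])[::-1]]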
    if len(conf_idx) == 0:
        # all false positive
        return 0, 0, 0

    npred = len(pred)
    nlabel = len(label)
    tp = np.zeros(npred, dtype='int')   # true positive flags
    label_visited = np.zeros(nlabel, dtype='int')  # record labels that have been matched with a detection

    # loop through detections
    for idxii in conf_idx:
        pii = pred[[idxii]]

        # select labels in same image
        labidx = np.where(label[:, 0] == pii[:, 0])[0]
        labii = label[labidx]
        if len(labii) == 0:
            # no matching label, false positive
            continue

        # compute iou with labels
        ious = compute_IOU(pii[:, 1:5], labii[:, 1:5], x1y1x2y2=False, change_enclose=False)
        # get the  closest label
        best_iou_idx = np.argmax(ious)
        best_iou = ious[best_iou_idx]
        all_labidx = labidx[best_iou_idx]

        if best_iou >= iou_thres and label_visited[all_labidx] == 0:
            tp[idxii] = 1                  # mark this prediction true positive
            label_visited[all_labidx] = 1  # mark this label as visited

    # build the AP computing table
    # sort by conf
    conf_idx = np.argsort(pred[:, 5])[::-1]
    pred = pred[conf_idx]
    tp = tp[conf_idx]
    fp = 1 - tp
    acc_tp = np.cumsum(tp)
    acc_fp = np.cumsum(fp)

    precision = acc_tp / (acc_tp + acc_fp + 1e-6)
    recall = acc_tp / nlabel

    # add starting point
    precision = np.r_[1., precision]
    recall = np.r_[0., recall]

    # do interpolation if needed
    if not interp:
        ap = np.trapz(precision, recall)
    elif interp == 11:
        rr = np.linspace(0, 1, 11, endpoint=True)
        ap = 0
        for rii in rr:
            pre_on_right = precision[recall >= rii]
            if len(pre_on_right):
                ap += np.max(pre_on_right)
        ap /= 11
    else:
        raise ValueError("<interp> must be None or 11.")

    return precision, recall, ap

Some explanations:

Remember that AP is class-specific, so we start by selecting the correct class. If no predictions of that class are found, there are no true positives, and precision, recall and AP are all returned as 0:

# select class
pred = pred[pred[:, -1] == class_idx]
label = label[label[:, -1] == class_idx]

if len(pred) == 0:
    return 0, 0, 0

We initialize a zero-valued label_visited array to record the “visited” labels. Once a ground truth label has been assigned to a prediction (by having the highest IoU score among all labels), it cannot be assigned to another prediction later, which avoids double counting.

So, when we flag a prediction as true positive, we also flag that matched label as “visited”:

# compute iou with labels
ious = compute_IOU(pii[:, 1:5], labii[:, 1:5], x1y1x2y2=False, change_enclose=False)
# get the  closest label
best_iou_idx = np.argmax(ious)
best_iou = ious[best_iou_idx]
all_labidx = labidx[best_iou_idx]

if best_iou >= iou_thres and label_visited[all_labidx] == 0:
    tp[idxii] = 1                  # mark this prediction true positive
    label_visited[all_labidx] = 1  # mark this label as visited
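The matching rule can be tried on a hypothetical IoU matrix for a single image (rows are detections already sorted by confidence from high to low, columns are labels of the target class). Because the first detection to claim a label wins, it matters that detections are visited in descending confidence order, which is the usual convention:

```python
import numpy as np

iou_thres = 0.5
# Hypothetical IoUs for one image: rows = detections sorted by confidence (high to low),
# columns = ground truth labels of the target class
ious = np.array([[0.80, 0.10],    # detection 0 overlaps label 0 well
                 [0.70, 0.05],    # detection 1 is a duplicate of detection 0
                 [0.20, 0.60]])   # detection 2 overlaps label 1

tp = np.zeros(len(ious), dtype=int)                  # true positive flags
label_visited = np.zeros(ious.shape[1], dtype=int)   # labels already matched
for det, row in enumerate(ious):
    best = np.argmax(row)                            # closest label
    if row[best] >= iou_thres and label_visited[best] == 0:
        tp[det] = 1
        label_visited[best] = 1
print(tp)   # [1 0 1]: the duplicate detection counts as a false positive
```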

Also note that when calling compute_IOU() here, we set the change_enclose flag to False. This is a different context from NMS filtering: we do want the lower IoU score, so that the poor matches in such enclosed cases are properly reflected.
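To compare the two options of the interp keyword, the sketch below recomputes both AP variants from the made-up Table 1 data (again assuming N = 15). The trapezoidal sum is written out by hand; it is the same quantity as np.trapz(precision, recall) in compute_AP(), but avoids np.trapz, which newer NumPy releases deprecate in favor of np.trapezoid.

```python
import numpy as np

# Table 1 data, assuming N = 15 labels
tp = np.array([1, 1, 0, 0, 1, 1, 0, 1, 0])
acc_tp, acc_fp = np.cumsum(tp), np.cumsum(1 - tp)
precision = np.r_[1., acc_tp / (acc_tp + acc_fp)]   # prepend (recall=0, precision=1)
recall = np.r_[0., acc_tp / 15]

# 11-point interpolation: mean of the max precision at or right of each recall level
levels = np.linspace(0, 1, 11)
ap_11 = np.mean([precision[recall >= r].max() if np.any(recall >= r) else 0.0
                 for r in levels])

# Plain area under the P-R curve by the trapezoidal rule
ap_trapz = np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) / 2)

print(round(ap_11, 3), round(ap_trapz, 3))   # 0.299 0.252
```

Here the interpolated AP comes out higher than the raw area, because each sampled recall level takes the best precision found at or to its right.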

4.5. Python code for mAP computation

Having got the AP computation code, it is straightforward to get the mAP: we simply loop through all classes present in the labels and average their APs:

def compute_mAP(pred, label, conf_thres=0.25, iou_thres=0.5, interp=11):
    '''Compute mean Average Precision across classes

    Args:
        pred (ndarray): predictions, in shape (n, 7). n is the number of detections.
            The columns are: [batch_idx, x_c, y_c, w, h, conf, clss].
        label (ndarray): ground truth labels, in shape (m, 6). m is the number
            of labels. The columns are: [batch_idx, x_c, y_c, w, h, clss].
    Keyword Args:
        conf_thres (float): optionally filter prediction by confidence score.
        iou_thres (float): IoU score threshold to flag true/false positives.
        interp (int or None): if None, compute Area Under Curve using trapz
            integration. If 11, smooth the P-R curve using 11-point interpolation.
    Returns:
        mAP (float): mean Average Precision.
    '''

    class_idx = np.unique(label[:, -1])
    class_aps = []
    # loop through classes
    for clsii in class_idx:
        _, _, apii = compute_AP(pred, label, clsii, conf_thres, iou_thres, interp)
        class_aps.append(apii)

    mAP = np.mean(class_aps)

    return mAP

4.6. Multiple IoU thresholds

Now we can extend the computation to multiple IoU thresholds.

Having computed the mAP at IoU=0.5, we can iterate through the IoU range [0.5, 0.95] with a step of 0.05. The average mAP across these 10 IoU levels is called mAP@[0.5:0.05:0.95], or, using the COCO naming convention, AP@[0.5:0.05:0.95]: COCO treats AP and mAP as synonyms (thanks COCO for the confusion).
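A minimal sketch of this averaging, assuming a function with the compute_mAP(pred, label, ..., iou_thres=...) signature from section 4.5 is passed in as map_fn (the wrapper name map_over_iou_range is hypothetical). Note that the official COCO evaluation tool also samples the P-R curve at 101 recall points rather than 11, so numbers from this sketch are not expected to match it exactly.

```python
import numpy as np

def map_over_iou_range(map_fn, pred, label, **kwargs):
    """Average map_fn(pred, label, iou_thres=t) over t = 0.50, 0.55, ..., 0.95."""
    iou_levels = np.linspace(0.5, 0.95, 10)   # 10 IoU thresholds, 0.05 apart
    scores = [map_fn(pred, label, iou_thres=t, **kwargs) for t in iou_levels]
    return float(np.mean(scores))

# Usage with the functions from this post:
#   coco_style_map = map_over_iou_range(compute_mAP, pred, label, conf_thres=0.25, interp=11)
```

Extra keyword arguments such as conf_thres and interp are passed through unchanged to every call.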

5. Summary

In this post we created 3 essential tools for object detection tasks:

  1. Intersection-Over-Union (IoU): a goodness-of-match metric between a bounding box prediction and a ground truth label.
  2. Non-Maximum suppression (NMS): a method to filter out duplicate detections.
  3. mean Average Precision (mAP): a metric of the overall accuracy of a detection model.

All these will be used in later posts when we write the model-training code, and it is a good idea to store these functions in the utils.py script in the YOLOv3_pytorch project folder.

Author: guangzhi

Created: 2022-06-22 Wed 22:40
