鉴于最近的空闲,我准备找个机会接触机器学习, 需求的一部分是识别特定场景中的扑克牌.

Convolutional Neural Network

Google 的机器学习高级课程 包含了大量动图, 对构建 CNN 的基础概念和架构很有效. 卷积用来提取局部特赠, ReLU 引入非线性, 池化减少数据量.

PyTorch 的图像分类入门 对代码层面有一个基础的认知, 同时也了解到训练是准备数据和正确答案.

Modern Face Recognition with Deep Learning 直观的讲述了人脸识别实际使用到具体技术, 对了解工业界有很好的帮助.

看完 PyTorch 的入门后, 我脑门上写着的大概是: 模型这么简单是因为这是个案例? 还是工业界的也长这个样子, 全靠数据?

Image Classification on ImageNet 展示了图像分类领域的领先者. PyTorch 正好提供了对应的代码 examples/imagenet, 让我可以看看这些模型到底长什么杨.

看上去对使用者来说, 确实数据是关键. 那么这些预训练好的模型对识别扑克牌的花色和点数效果好吗? 感觉应该不好. 使用少量的数据可以训练/微调模型吗? 不知道, 先试试吧. 稍微修改 examples/imagenet 使其能够运行在 macbook 的 GPU 后, 第一个问题出现了: 太慢了. 同时我对从零训练完全没有信心, 同时也感觉自己无法凑够那么多数据量.

OpenCV

微调现有的模型是否可以呢? 我脑回路比较大, 准备用 ocr 模型做底, 考虑到主要是字母, 所以准备用 EasyOCR. 在准备微调数据的过程中, 我发现 opencv 是不是也够了?

核心的点在于 opencv 能够找到轮廓, contour, 而且在这个场景中非常符合我的预期. 随后仅需要判断是否是矩阵和长宽比在一定范围内基本就能定位所有的扑克牌. contour.png

def is_rectangle(contour):
    # 使用多边形逼近
    epsilon = 0.1 * cv2.arcLength(contour, True)
    approx = cv2.approxPolyDP(contour, epsilon, True)

    # 判断是否有4个顶点
    if len(approx) == 4:
        # 检查是否是凸形
        if cv2.isContourConvex(approx):
            # 检查每个角度是否接近90度
            for i in range(4):
                pt1 = approx[i][0]
                pt2 = approx[(i + 1) % 4][0]
                pt3 = approx[(i + 2) % 4][0]
                
                # 向量 (pt1 -> pt2) 和 (pt2 -> pt3)
                v1 = pt2 - pt1
                v2 = pt3 - pt2

                # 计算向量之间的夹角
                angle = np.degrees(np.arccos(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))))

                # 检查角度是否接近 90 度 (可以允许一定的误差,如±10度)
                if not (80 <= angle <= 100):
                    return False

            return True
    return False


if __name__ == '__main__':
    orig = cv2.imread("poker.png")
    _, thresh = cv2.threshold(cv2.cvtColor(orig, cv2.COLOR_BGR2GRAY), 150, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(image=thresh, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)

    def filter_contour(contour):
        _, _, w, h = cv2.boundingRect(contour)
        if w*h > 200 and h/w > 1.2 and h/w < 1.5:
            return is_rectangle(contour)
        return False

    contours = list(filter(filter_contour, contours))
    cv2.drawContours(orig, contours, -1, (0, 0, 255), 3)

    cv2.imwrite("contour.png", orig)
    cv2.imshow("contours", orig)
    cv2.waitKey(0)
    cv2.destroyAllWindows()

在定位了扑克牌后, 我们需要识别点数和花色.

点数的识别仅需要在裁剪出左上部分后使用 ocr 即可. 为了处理少量不符合预期的场景, 我们在发现 easyocr 的识别结果可能不正确时, 会使用 paddleocr 再次识别.

valid_ocr_results = {
    "2": "2",
    "3": "3",
    "4": "4",
    "5": "5",
    "6": "6",
    "7": "7",
    "8": "8",
    "9": "9",
    "IO": "T",
    "J": "J",
    "Q": "Q",
    "K": "K",
    "A": "A",
}


def parse_rank(img):
    data = np.asarray(img)
    ocr = reader.readtext(data, detail=0)
    if len(ocr) == 1:
        if ocr[0] in valid_ocr_results:
            return valid_ocr_results[ocr[0]]

    ocr = paddle.ocr(data, cls=True)
    if len(ocr) == 1 and ocr[0] is not None:
        r = ocr[0][0][1][0]
        if r in valid_ocr_results:
            return valid_ocr_results[r]

    return ""

花色的识别一种思路是微调训练 ocr 模型, 但效果未知, 过程对我来说也有挑战. 另一种思路是继续利用 opencv 的轮廓识别. 裁剪出右下部分, 提取出轮廓和已知的四种花色进行对比.

def parse_suit(img):
    contours = find_suit(img)

    bestSuit = ""
    bestSimilarity = 0.03
    idx = len(contours)-1
    contour = contours[idx]
    for suit, scontour in suit_contours.items():
        similarity = cv2.matchShapes(contour, scontour, cv2.CONTOURS_MATCH_I1, 0)
        if similarity < bestSimilarity:
            bestSimilarity = similarity
            bestSuit = suit
    return bestSuit

上述三个步骤, 在固定场景下基本可以百分百的识别出扑克牌.

从这次的体验来说, 感受主要有几点:

  • 机器学习发展的即厉害又容易落地, 图像是其中一个重要而突出的领域
  • 显卡和数据是深入使用的主要拦路虎