别再手动切图了!用PyTorch从零搭建CRNN,搞定不定长文本识别(附完整代码)

张开发
2026/4/18 4:56:33 15 分钟阅读

分享文章

别再手动切图了!用PyTorch从零搭建CRNN,搞定不定长文本识别(附完整代码)
用PyTorch从零实现CRNN端到端不定长文本识别实战指南在票据识别、车牌检测等实际场景中我们常遇到文字长度不固定、排版复杂的图像。传统OCR方案需要先切割单字再分类流程繁琐且误差累积。本文将带你用PyTorch实现CRNN卷积循环神经网络这种端到端模型能直接输入整图输出文本省去切割步骤。我们会从数据生成、模型构建一直讲到训练技巧和API封装提供可直接运行的完整代码。1. CRNN架构解析与PyTorch实现1.1 为什么CRNN适合文本识别传统OCR流程需要精确的字符定位切割而CRNN通过三种核心组件实现端到端识别CNN部分使用7层卷积网络提取视觉特征最后输出512维的特征图。特别设计的1x2池化窗口保留宽度方向信息适应文本水平排列特性。RNN部分采用双向LSTM处理特征序列每个时间步对应原图的一个垂直切片。双向结构能同时利用前后文信息对3和8等相似字符区分效果显著。CTC层解决输入输出对齐问题允许模型输出比输入短的序列通过blank机制处理重复字符。例如hello可能被编码为h-e-l-l-o。class CRNN(nn.Module): def __init__(self, imgH, nc, nclass, nh): super(CRNN, self).__init__() # CNN结构 self.cnn nn.Sequential( nn.Conv2d(nc, 64, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2), nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d(2, 2), nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), nn.ReLU(True), nn.Conv2d(256, 256, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)), nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), nn.ReLU(True), nn.Conv2d(512, 512, 3, 1, 1), nn.ReLU(True), nn.MaxPool2d((2,2), (2,1), (0,1)), nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), nn.ReLU(True) ) # RNN结构 self.rnn nn.Sequential( BidirectionalLSTM(512, nh, nh), BidirectionalLSTM(nh, nh, nclass) )1.2 关键实现细节输入尺寸处理图像高度必须为16的倍数通过assert imgH % 16 0校验。典型输入为(1, 32, 160)的灰度图输出40个时间步的特征序列。双向LSTM实现自定义BidirectionalLSTM模块处理序列方向class BidirectionalLSTM(nn.Module): def __init__(self, nIn, nHidden, nOut): super(BidirectionalLSTM, self).__init__() self.rnn nn.LSTM(nIn, nHidden, bidirectionalTrue) self.embedding nn.Linear(nHidden * 2, nOut)特征序列转换CNN输出需从(b,c,h,w)转换为(w,b,c)的序列格式满足LSTM输入要求conv conv.squeeze(2) # 移除高度维度 conv conv.permute(2, 0, 1) # [w, b, c]2. 数据准备与增强策略2.1 生成不定长文本图像使用TextRecognitionDataGenerator工具创建训练数据关键参数配置参数说明典型值字体支持中文需添加.ttf文件SimSun, Arial背景模拟真实场景噪声/渐变/纹理变形增加多样性透视/旋转/模糊长度控制输出复杂度1-15个字符from PIL import Image, ImageDraw, ImageFont import random def generate_text_image(text, width160, height32): 生成单张文本图像 font ImageFont.truetype(simsun.ttf, 28) image Image.new(L, (width, height), color255) draw ImageDraw.Draw(image) # 随机位置扰动 x random.randint(0, max(0, width - len(text)*20)) y random.randint(0, height - 30) draw.text((x, y), text, fontfont, fill0) return image2.2 数据增强技巧椒盐噪声模拟扫描文档的噪点def add_noise(img, noise_level0.02): noisy img.copy() num_noise int(noise_level * img.size) coords [np.random.randint(0, i-1, num_noise) for i in img.shape] noisy[coords[0], coords[1]] 255 * (np.random.rand(num_noise) 0.5) return noisy弹性变形使用OpenCV的网格变换模拟手写变形模糊处理高斯模糊模拟失焦场景3. 训练流程与调优技巧3.1 CTC损失函数配置PyTorch的CTCLoss需要特别注意输入格式输入Log softmax后的概率矩阵 (T, N, C)目标标签序列的稀疏表示输入长度每个样本的序列长度目标长度每个标签的实际长度criterion nn.CTCLoss(blank0, reductionmean) optimizer torch.optim.Adam(model.parameters(), lr0.001) # 训练循环示例 for epoch in range(100): for batch_idx, (data, target, input_len, target_len) in enumerate(train_loader): output model(data) # (T, N, C) output F.log_softmax(output, dim2) loss criterion(output, target, input_len, target_len) optimizer.zero_grad() loss.backward() optimizer.step()3.2 学习率调度策略采用热启动(warmup)配合余弦退火scheduler torch.optim.lr_scheduler.OneCycleLR( optimizer, max_lr0.01, steps_per_epochlen(train_loader), epochs100, anneal_strategycos )3.3 常见问题解决梯度爆炸添加梯度裁剪nn.utils.clip_grad_norm_(model.parameters(), 5)过拟合在CNN后添加Dropout层概率设为0.2-0.5收敛慢使用预训练CNN部分冻结前几层参数4. 模型部署与API封装4.1 导出TorchScript模型model.eval() example torch.rand(1, 1, 32, 160) # 示例输入 traced_script torch.jit.trace(model, example) traced_script.save(crnn.pt)4.2 Flask接口实现from flask import Flask, request, jsonify import torch import cv2 import numpy as np app Flask(__name__) model torch.jit.load(crnn.pt, map_locationcpu) app.route(/ocr, methods[POST]) def ocr(): file request.files[image] img cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_GRAYSCALE) img cv2.resize(img, (160, 32)) tensor torch.from_numpy(img).float().unsqueeze(0).unsqueeze(0) / 255.0 with torch.no_grad(): output model(tensor) text decode_ctc(output) # CTC解码实现 return jsonify({text: text})4.3 性能优化技巧量化推理使用torch.quantization.quantize_dynamic减少模型大小批处理预测修改模型支持batch输入提升吞吐量缓存机制对相似尺寸图片复用预处理结果5. 进阶改进方向5.1 注意力机制增强在CRNN基础上加入注意力模块提升长文本识别效果class AttentionCRNN(nn.Module): def __init__(self, imgH, nc, nclass, nh): super().__init__() self.crnn CRNN(imgH, nc, nclass1, nh) # 1 for attention mask self.attention nn.Sequential( nn.Linear(nh*2, nh), nn.Tanh(), nn.Linear(nh, 1) )5.2 多语言支持方案扩展字符集合并中英文和符号的字符表混合训练交替使用不同语言数据批次语言检测前置分类器自动切换识别模型5.3 实际部署考量内存占用量化后模型通常10MB适合移动端推理速度在CPU上约50ms/图GPU加速可达10ms错误修正结合词典和N-gram进行后处理我在实际项目中发现对于复杂背景的身份证识别添加透视变换数据增强能提升15%的准确率。另外将CNN部分的最后两层改为可变形卷积(DCN)对弯曲文本的识别效果有明显改善。

更多文章