在学术研究和日常学习中,手写笔记是最常见的信息载体之一。这些笔记往往带有强烈的个人书写风格,导致他人难以辨认和理解。当这些手写内容需要被数字化共享时,传统的人工转录方式效率低下且成本高昂。光学字符识别(OCR)技术的出现为解决这一难题提供了可能。
TrOCR(Transformer-based Optical Character Recognition)是微软基于Transformer架构开发的先进OCR模型。与传统的OCR系统相比,TrOCR具有以下显著优势:
提示:在实际应用中,手写OCR面临的最大挑战是书写风格的多样性。即使是同一人的笔迹,也会因书写工具、书写速度和情绪状态而产生显著差异。
GNHK(GoodNotes Handwriting Kollection)数据集由Goodnotes公司收集,包含全球各地学生的手写英文笔记。该数据集的主要特点包括:
数据集目录结构如下:
code复制├── test_data
│ └── test
│ ├── eng_AF_004.jpg
│ ├── eng_AF_004.json
│ ...
└── train_data
└── train
├── eng_AF_001.jpg
├── eng_AF_001.json
...
原始数据集包含整页文档图像,而TrOCR模型设计用于识别单个单词或短句。因此需要进行以下预处理步骤:
预处理后的数据结构:
code复制├── train_processed
│ ├── images
│ │ ├── eng_AF_001_0.jpg
│ │ ...
│ └── train_processed.csv
└── test_processed
├── images
│ ├── eng_AF_004_0.jpg
│ ...
└── test_processed.csv
关键预处理代码片段:
python复制def polygon_to_bbox(polygon):
points = np.array([(polygon[f'x{i}'], polygon[f'y{i}']) for i in range(4)])
x, y, w, h = cv2.boundingRect(points)
return x, y, w, h
def process_dataset(input_folder, output_folder, csv_path):
with open(csv_path, 'w') as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['image_filename', 'text'])
for filename in os.listdir(input_folder):
if filename.endswith('.json'):
# 处理每个JSON文件...
我们使用Hugging Face提供的microsoft/trocr-small-handwritten作为基础模型,其主要参数如下:
关键模型配置:
python复制model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
model.config.max_length = 64
model.config.num_beams = 4
训练采用以下优化策略:
训练参数设置:
python复制training_args = Seq2SeqTrainingArguments(
output_dir='trocr_handwritten/',
per_device_train_batch_size=48,
per_device_eval_batch_size=48,
num_train_epochs=10,
fp16=True,
evaluation_strategy='epoch',
save_strategy='epoch'
)
使用预训练模型和微调后的模型在测试集上的对比结果:
| 模型版本 | CER | 识别准确率 |
|---|---|---|
| 预训练模型 | 0.82 | 18% |
| 微调模型 | 0.12 | 88% |
训练过程中的CER变化曲线显示,模型性能持续提升直至训练结束:

注意:实际训练中发现,学习率设置对模型收敛影响显著。过大的学习率会导致CER波动,而过小的学习率会延长训练时间。
完整的推理流程包括以下步骤:
关键推理代码:
python复制def ocr(image, processor, model):
pixel_values = processor(image, return_tensors='pt').pixel_values.to(device)
generated_ids = model.generate(pixel_values)
return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
对比预训练模型和微调模型在相同样本上的表现:
原始图像:

识别结果对比:
在实际部署中,可采用以下优化策略:
问题1:过拟合
问题2:梯度爆炸
max_grad_norm=1.0)问题1:特殊字符识别失败
问题2:连笔字识别困难
当前模型仅支持英文,可通过以下方式扩展多语言能力:
单词级识别的局限性促使我们考虑句子级识别方案:
针对学术笔记中的数学内容,需要:
在实际部署这套手写OCR系统时,我发现模型对书写工具的敏感性比预期要高。圆珠笔和铅笔的识别效果差异可达15%,这提示我们在数据收集阶段需要尽可能覆盖各种书写工具和纸张类型。另一个实用建议是,对于重要的文档数字化项目,可以保留人工校验环节,将模型置信度低于90%的识别结果自动标记为需要人工复核。