Batch Optimizations for CRAFT OCR
Optimizing OCR Text Detection for Large-Scale Font Dataset Processing
The Problem: Sequential Processing Bottleneck
When building our font recognition system, we faced a significant performance bottleneck during dataset preparation. Our approach used the CRAFT (Character Region Awareness for Text detection) model to detect individual characters in font samples, which allowed our neural network to learn character-level features for better font recognition.
However, the default CRAFT implementation was designed to process images one at a time, creating a major bottleneck in our pipeline. With hundreds of thousands of font images to process, this sequential approach was prohibitively slow:
- Each image went through multiple CPU↔GPU data transfers
- Forward passes through the neural network were performed one image at a time
- Post-processing steps like polygon extraction ran on CPU sequentially
For a dataset with 2M images, this sequential approach would take nearly 15 hours to complete a single epoch on my machine (roughly 27 ms per image, much of it per-image overhead that batching can amortize).
The Solution: Monkeypatching CRAFT for Batch Processing
Instead of rewriting the CRAFT library, I implemented a monkeypatching approach that added batch processing capabilities while maintaining compatibility with the original codebase. The key insight was that most of the processing could be parallelized across a batch of images by careful refactoring.
1. Adding Batch Polygon Detection
The key optimization was implementing a get_batch_polygons method that processes multiple images in a single pass. In the original CRAFT codebase, polygon detection happened sequentially:
# Original sequential approach (simplified)
def get_polygons(self, image: Image.Image) -> List[List[List[int]]]:
    # Preprocess a single image
    x, ratio_w, ratio_h = preprocess_image(np.array(image), self.canvas_size, self.mag_ratio)

    # Forward pass for a single image
    score_text, score_link = self.get_text_map(x, ratio_w, ratio_h)

    # Post-processing for a single image
    boxes, polys = getDetBoxes(score_text, score_link, ...)
    boxes = adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = adjustResultCoordinates(polys, ratio_w, ratio_h)

    # Convert to the desired output format
    res = []
    for poly in polys:
        res.append(poly.astype(np.int32).tolist())
    return res
I implemented a batch version that operates on tensors directly, keeping data on the GPU throughout the process:
def get_batch_polygons(self, batch_images: torch.Tensor, ratios_w: torch.Tensor, ratios_h: torch.Tensor):
    """Batch-process pre-normalized images, keeping data on the GPU."""
    # Forward pass for the entire batch
    with torch.no_grad():
        y, feature = self.net(batch_images)
        text_scores = y[..., 0]  # [B, H, W]
        link_scores = y[..., 1]  # [B, H, W]
        if self.refiner:
            # The refiner replaces the link score; the text score stays from the base net
            y_refined, _ = self.refiner(y, feature)
            link_scores = y_refined[..., 0]

    # Batch post-processing on GPU: threshold the score maps
    text_mask = (text_scores > self.text_threshold)
    link_mask = (link_scores > self.link_threshold)
    combined_mask = text_mask & link_mask

    # Process each image in the batch (still much faster than the fully sequential path)
    batch_polys = []
    for b_idx in range(batch_images.size(0)):
        # Extract polygons with GPU-accelerated connected components
        # ... processing code ...
        batch_polys.append(polys)
    return batch_polys
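Because this is a monkeypatch rather than a fork, the method above is simply bound onto the existing CRAFT wrapper class. A minimal sketch of the attachment, where CraftDetector stands in for whatever class your CRAFT wrapper actually exposes (the name is an assumption):

# Hypothetical monkeypatch: bind the batch method (defined above with `self` as its
# first argument) onto the existing detector class, so every instance gains
# get_batch_polygons without touching the library source. CraftDetector is a placeholder.
CraftDetector.get_batch_polygons = get_batch_polygons

The original get_polygons stays untouched, so any existing single-image callers keep working.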
2. Optimizing Preprocessing and Post-processing
I also refactored the preprocessing and post-processing steps to handle batches efficiently:
def batch_preprocess_image_np(batch_images, canvas_size, mag_ratio):
    """Process a batch of images with vectorized operations where possible."""
    batch_size = len(batch_images)
    resized_images = []
    ratios_w = []
    ratios_h = []

    # Resize each image (could be parallelized further with multiprocessing)
    for i in range(batch_size):
        img_resized, target_ratio, _ = resize_aspect_ratio(
            batch_images[i], canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio
        )
        ratio_h = ratio_w = 1 / target_ratio
        resized_images.append(img_resized)
        ratios_w.append(ratio_w)
        ratios_h.append(ratio_h)

    # Stack images into a single batch array: [B, H, W, C]
    batch_resized = np.stack(resized_images, axis=0).astype(np.float32)

    # Vectorized normalization (much faster than normalizing one image at a time)
    batch_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(1, 1, 1, 3)
    batch_std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(1, 1, 1, 3)
    batch_normalized = (batch_resized / 255.0 - batch_mean) / batch_std

    # Transpose from [B, H, W, C] to [B, C, H, W] for PyTorch
    batch_transposed = np.transpose(batch_normalized, (0, 3, 1, 2))
    return batch_transposed, ratios_w, ratios_h
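Putting the two pieces together, the calling side looks roughly like this (a hedged sketch: raw_images, device, craft and the canvas_size/mag_ratio values are illustrative, not the project's actual identifiers or settings):

# Hypothetical glue code: preprocess a list of numpy images on the CPU, move the
# whole batch to the GPU in one transfer, then run batch polygon detection.
batch_np, ratios_w, ratios_h = batch_preprocess_image_np(raw_images, canvas_size=1280, mag_ratio=1.5)

batch_tensor = torch.from_numpy(batch_np).to(device)   # single CPU-to-GPU copy per batch
ratios_w_t = torch.tensor(ratios_w, device=device)
ratios_h_t = torch.tensor(ratios_h, device=device)

batch_polys = craft.get_batch_polygons(batch_tensor, ratios_w_t, ratios_h_t)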
3. Preprocessing Character Patches with CRAFT
While the batch processing optimization significantly improved the initial CRAFT text detection speed, we still faced another critical bottleneck during training. In our character-level approach, the model needed to process individual character patches extracted by CRAFT for each image. Doing this extraction during every training epoch was extremely computationally expensive.
The naive approach was to run character extraction during each forward pass of the model:
# Original approach: extract patches during EVERY forward pass
def forward(self, images, targets=None):
    # Extract character patches using CRAFT (SLOW!)
    patches_data = self.extract_patches_with_craft(images)

    # Process patches through the character classifier
    output = self.char_classifier(patches_data['patches'], patches_data['attention_mask'])
    return output
This took ages to run because:
- The CRAFT model would run for every image in every batch, in every epoch
- The same character patches were re-extracted over and over, even though they never change between epochs
- Training a single epoch took hours instead of minutes
The solution was to preprocess the CRAFT character extractions once and save them to disk. This costs some disk space, but in exchange we get that sweet, sweet speedup.
def preextract_craft_patches(data_dir, output_dir, craft_model, batch_size=32):
    """Extract and save character boxes for the entire dataset once."""
    for mode in ['train', 'test']:
        # Load dataset
        dataset = FontDataset(data_dir, mode=mode)

        # Create output storage (using HDF5 for efficient storage)
        with h5py.File(os.path.join(output_dir, f'{mode}_craft_boxes.h5'), 'w') as h5f:
            boxes_group = h5f.create_group('boxes')

            # Process all images with the batch-optimized CRAFT
            for idx, batch in enumerate(tqdm(DataLoader(dataset, batch_size=batch_size))):
                # Use batch-optimized CRAFT to get character polygons
                batch_polys = craft_model.get_batch_polygons(batch['images'], ...)

                # Convert polygons to bounding boxes and save one entry per image
                for i, polygons in enumerate(batch_polys):
                    img_idx = idx * batch_size + i
                    boxes = convert_polygons_to_boxes(polygons)
                    boxes_group.create_dataset(
                        name=str(img_idx),
                        data=np.array(boxes, dtype=np.int32),
                        compression="gzip"
                    )
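The convert_polygons_to_boxes helper isn't shown above; one way it could look is a simple axis-aligned reduction of each character polygon (a minimal sketch under that assumption, not the project's actual implementation):

def convert_polygons_to_boxes(polygons):
    """Reduce each character polygon to an axis-aligned [x_min, y_min, x_max, y_max] box."""
    boxes = []
    for poly in polygons:
        pts = np.asarray(poly, dtype=np.int32).reshape(-1, 2)
        x_min, y_min = pts.min(axis=0)
        x_max, y_max = pts.max(axis=0)
        boxes.append([int(x_min), int(y_min), int(x_max), int(y_max)])
    return boxes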
Then, during training, we modified the dataset to load these precomputed boxes:
class CharacterFontDataset(Dataset):
    def __init__(self, data_dir, train=True, use_precomputed_craft=True):
        # ...initialization code...
        self.use_precomputed_craft = use_precomputed_craft
        if use_precomputed_craft:
            # Load precomputed CRAFT boxes
            self.boxes_h5 = h5py.File(
                os.path.join(data_dir, f'{"train" if train else "test"}_craft_boxes.h5'),
                'r'
            )

    def __getitem__(self, idx):
        # Load image and label
        image = self.images[idx]
        label = self.labels[idx]

        if self.use_precomputed_craft:
            # Use precomputed boxes instead of running CRAFT
            boxes = self.boxes_h5['boxes'][str(idx)][:]
            # Extract patches using the boxes
            patches, attention_mask = self.extract_patches_from_boxes(image, boxes)
        else:
            # Fallback: compute boxes on the fly (very slow)
            patches, attention_mask = self.extract_patches_with_craft(image)

        return {
            'patches': patches,
            'attention_mask': attention_mask,
            'label': label
        }
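extract_patches_from_boxes is where the precomputed boxes actually get used. A minimal sketch of one possible implementation, assuming fixed-size grayscale patches and a cap on characters per image (patch_size and max_chars are illustrative parameters, not the project's actual values):

def extract_patches_from_boxes(self, image, boxes, patch_size=32, max_chars=32):
    """Crop each character box, resize to a fixed patch size, and pad to max_chars."""
    patches = torch.zeros(max_chars, 1, patch_size, patch_size)
    attention_mask = torch.zeros(max_chars, dtype=torch.bool)

    img = np.asarray(image)
    for i, (x_min, y_min, x_max, y_max) in enumerate(boxes[:max_chars]):
        crop = img[y_min:y_max, x_min:x_max]
        if crop.size == 0:
            continue
        crop = cv2.resize(crop, (patch_size, patch_size), interpolation=cv2.INTER_AREA)
        if crop.ndim == 3:  # collapse RGB to grayscale for the character classifier
            crop = cv2.cvtColor(crop, cv2.COLOR_RGB2GRAY)
        patches[i, 0] = torch.from_numpy(crop.astype(np.float32) / 255.0)
        attention_mask[i] = True

    return patches, attention_mask

Padding to a fixed max_chars with an attention mask is what lets patches from different images be stacked into a single training batch.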
This optimization had a dramatic impact on training performance:
- Without preextraction: ~15 hours per epoch
- With preextraction: ~5 hours per epoch
A one-time cost of ~10 hours to preprocess the entire dataset pays for itself after the first epoch (each epoch saves ~10 hours) and saved us hundreds of hours over the course of model development and experimentation.
Key Takeaways
- Keep data on the GPU: Minimize CPU↔GPU transfers by operating on batches
- Vectorize operations: Use numpy/PyTorch’s vectorized operations instead of loops
- Use prefetching: DataLoader’s prefetching capabilities minimize idle GPU time
- Precompute where possible: Trading storage for computation often makes sense
- Consider robustness: For long-running processes, implement checkpointing and recovery (see the sketch after this list)
- Monkeypatching vs. rewriting: Extending existing libraries through monkeypatching can be an efficient approach when full rewrites aren’t practical
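On the robustness point: the ~10-hour preextraction run can be made resumable by writing to the HDF5 file in append mode and skipping indices that already exist. A minimal sketch, assuming the same file layout as preextract_craft_patches above (extract_boxes_for_batch stands in for the preprocessing plus batch CRAFT call shown earlier):

# Hypothetical resumable variant of the preextraction loop: open the HDF5 file in
# append mode and skip entries already written, so a crash only costs the current batch.
with h5py.File(os.path.join(output_dir, f'{mode}_craft_boxes.h5'), 'a') as h5f:
    boxes_group = h5f.require_group('boxes')
    for idx, batch in enumerate(tqdm(DataLoader(dataset, batch_size=batch_size))):
        start = idx * batch_size
        names = [str(start + i) for i in range(len(batch['images']))]
        if all(name in boxes_group for name in names):
            continue  # this whole batch was finished in a previous run
        for name, boxes in zip(names, extract_boxes_for_batch(batch)):
            if name not in boxes_group:
                boxes_group.create_dataset(name, data=np.array(boxes, dtype=np.int32),
                                           compression="gzip")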
These optimizations were crucial for training anything in a reasonable time on my good ol' A4500, and let us iterate quickly on our character-level font recognition approach.