SPARKCREATIVE Tech Blog

https://www.spark-creative.jp/

Tesseract ノベルゲームに特化した日本語OCR(文字認識)をしてみた

こんにちは!!!クライアントエンジニアの小林です。

趣味で触っているコーパス作成の過程でテキスト入力を自動化したいなぁと思いOCRを導入してみました。
cloud visionAPI叩くスタイルなので鯖落ちとか面倒だなぁと思い、ローカルで動作するtesseractを選びました。

がしかし、tesseractさん、日本語に対する精度が低いです。

といっても英字は比較的読めているので読めないことはないだろうと思い、プリプロセスを頑張ったらギリギリ実用レベルに達しました。
今回はそれについてのご紹介です。

概要

特定の動作環境下においてtesseractの精度を上げる方法をご紹介します。
汎用性を求める方は巷で有名なGoogleさんのサービスを使うことをおすすめします。
精度が桁違いなので。やはりデータ量は正義なのですね。

流れとしてはコード全文、一部を抜粋して解説としています。
コードだけを見たい方はコード全文をコピペしてください。

作業環境

・windows10
visual studio code
python 3.9.12

処理フロー

  1. cv2.thresholdで2値化した画像に対してOCRを行います。
    以降cv2.thresholdな処理をSimpleThresholdと呼びます。
    2値化に使用するパラメータは、
    初回は設定された範囲の値を全て使う総当たり方式で、
    2回目以降は前回の結果から精度の高いパラメータのみを使用します。
  2. 1と同様のことを行います。
    異なる点は使用する2値化関数をcv2.adaptiveThresholdに変えることです。
    以降cv2.adaptiveThresholdな処理をAdaptiveThresholdと呼びます。
  3. SimpleThresholdとAdaptiveThresholdの結果から一定以上の精度のパラメータを使用し、
    それらで組み合わせられる全てのパラメータでOCRを行います。
    以降SimpleThresholdとAdaptiveThresholdを組み合わせた処理をHybridThresholdと呼びます。
  4. 以上をテスト画像毎に行うことで徐々に最適なパラメータが絞られていく感じです。

シンプルで愚直な方法ですが、それなりに精度を高めることができます。

SequenceMatcher

class SequenceMatcher:
    KOMOJI_SPL = {
        "ぁ":"あ",
        "ぃ":"い",
        "ぅ":"う",
        "ぇ":"え",
        "ぉ":"お",
    }
    KOMOJI_SPL_KEYS = tuple(KOMOJI_SPL.keys())
    
    SYMBOL_SPL = {
        "『":"「",
        "』":"」",
        "「":"『",
        "」":"』",
    }
    SYMBOL_SPL_KEYS = tuple(SYMBOL_SPL.keys())
    
    def __init__(self, a:str, b:str, use_only_length:bool=False):
        """コンストラクタ
        Args:
            a (str): _description_
            b (str): _description_
            use_only_length (bool, optional): 一致率を計算する際にAの長さしか考慮しないか. Defaults to False.
        """
        self.a = a
        self.b = b
        self.use_only_length = use_only_length
    
    def __equal(self, a:str, b:str) -> bool:
        try:
            temp_a = a[0]
        except Exception:
            temp_a = None
            
        try:
            temp_b = b[0]
        except Exception:
            temp_b = None
        
        if temp_a == temp_b:
            return True
        elif temp_a in self.KOMOJI_SPL_KEYS and self.KOMOJI_SPL[temp_a] == temp_b:
            return True
        elif temp_a in self.SYMBOL_SPL_KEYS and self.SYMBOL_SPL[temp_a] == temp_b:
            return True
        elif temp_a is not None and temp_b is not None:
            return self.__equal(a, b[1:])
        
        return False
    
    def ratio(self) -> float:
        if len(self.a) == 0 and len(self.b) == 0:
            return 1
        
        if len(self.a) == 0 or len(self.b) == 0:
            return 0
        
        success_count:int = 0
        
        for i in range(len(self.a)):
            if self.__equal(self.a[i:], self.b[i:]):
                success_count += 1
        
        if self.use_only_length:
            return (2*success_count)/(len(self.a)*2)
        else:
            return (2*success_count)/(len(self.a)+len(self.b))

ひらがな小文字や記号を区別するか

if temp_a == temp_b:
    return True
elif temp_a in self.KOMOJI_SPL_KEYS and self.KOMOJI_SPL[temp_a] == temp_b:
   return True
elif temp_a in self.SYMBOL_SPL_KEYS and self.SYMBOL_SPL[temp_a] == temp_b:
   return True

探している文字がひらがな小文字や記号一覧に含まれているかを、まずはkeyで判定して、含まれているのであればその後にvalueと比較しています。

ここではひらがな小文字と記号を別々の変数に分けていますが、処理速度を優先するのであればひとつにまとめたほうがいい(はず)です。この処理に限っては見やすさを優先しているので分けています。

ゲシュタルトパターンマッチングのスコア算出方法の追加

if self.use_only_length:
    return (2*success_count)/(len(self.a)*2)
else:
    return (2*success_count)/(len(self.a)+len(self.b))

通常の計算方法がaとbの文字列の長さを足したもので除算するのに対し、今回追加したのはaの文字列の長さを2回足したものを除算する方法です。

これを実装することにより、鍵括弧内の文字だけあっていれば正解とする、といった運用ができます。

a (training_text) b (test_text) use_only_length=False use_only_length=True
こんにちは 「こんにちは」 0.8333 1.0
えっ 「えっ!?」 0.5 1.0

TesseractCommon

import pyocr
import pyocr.builders
import pyocr.tesseract
import numpy as np
from PIL import Image
from sequence_matcher import SequenceMatcher


# regist path
pyocr.tesseract.TESSERACT_CMD = r"C:/Program Files/Tesseract-OCR/tesseract.exe"


# default tool
OCR_TOOL:pyocr.tesseract = pyocr.get_available_tools()[0]
# default text builder
TEXT_BUILDER = pyocr.builders.TextBuilder(tesseract_layout=6)


def image_to_string(srcs:list[np.ndarray], training_text:str, use_only_length:bool=False) -> tuple[float, str]:
    text:str = ""
    for src in srcs:
        text += OCR_TOOL.image_to_string(Image.fromarray(src), "jpn", TEXT_BUILDER)
    # remove symbol
    text = text.replace(" ", "")
    score = SequenceMatcher(training_text, text, use_only_length).ratio()
    return score, text

pyocr.tesseract.TESSERACT_CMD

環境変数に設定してある方は不要な記述です。
私は作業時にPCを再起動するのが面倒だったという理由から使用しそのまま使い続けています。

tesseract_layout

tesseract_layout=6は画像に含まれる複数行の文字を検出可能です。
tesseract_layout=7は画像を1行として文字検出を行うので精度が6よりも高いです。その代わり画像には文字が含まれていることを前提とした挙動をするので、何も書かれていない空の画像からも文字を検出し、結果的に誤検出につながります。
これは画像に含まれる有効ピクセル数の割合から回避することも可能ですが、今回は面倒なので一番シンプルに実装できるtesseract_layout=6で実装しています。

画像から文字に変換後、精度・識字率を算出

後述しますが画像の編集にはOpenCVを利用しています。
そのためImage.fromarraynp.ndarrayからclass Image.Imageに変換しています。

文字検出が終わったら、文字列から半角スペースを排除、その後に先ほど実装したSequenceMatcherで精度・識字率を算出します。

SimpleThreshold

import numpy as np
import cv2
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from tesseract_common import image_to_string


@dataclass
class SimpleThresholdParam:
    thresh:int=30


@dataclass_json
@dataclass
class SimpleThresholdResult:
    """ファイル入出力が発生するのでデータクラスにまとめている
    """
    score:float=0.0
    text:str=""
    param:SimpleThresholdParam=SimpleThresholdParam()


def simple_threshold(srcs:list[np.ndarray], param:SimpleThresholdParam) -> list[np.ndarray]:
    """cv2.threshold

    Args:
        srcs (list[np.ndarray]): _description_
        param (SimpleThresholdParam): _description_

    Returns:
        list[np.ndarray]: _description_
    """
    return [cv2.threshold(src, param.thresh, 255, cv2.THRESH_BINARY)[1] for src in srcs]


def simple_threshold_to_string(srcs:list[np.ndarray], param:SimpleThresholdParam, training_text:str, use_only_length:bool=False) -> SimpleThresholdResult:
    """SimpleThreshold後にOCR

    Args:
        srcs (list[np.ndarray]): _description_
        param (SimpleThresholdParam): _description_
        training_text (str): _description_
        use_only_length (bool, optional): _description_. Defaults to False.

    Returns:
        SimpleThresholdResult: _description_
    """
    images = simple_threshold(srcs, param)
    score, text = image_to_string(images, training_text, use_only_length)
    return SimpleThresholdResult(score, text, param)


SIMPLE_THRESHOLD_PARAM_RANGE = [SimpleThresholdParam(thresh) for thresh in range(1,101)]


def get_simple_threshold_param_range() -> list[SimpleThresholdParam]:
    """SimpleThresholdのパラメータ範囲

    Returns:
        list[SimpleThresholdParam]: _description_
    """
    return SIMPLE_THRESHOLD_PARAM_RANGE

AdaptiveThreshold

import numpy as np
import cv2
from dataclasses import dataclass


@dataclass
class AdaptiveThresholdParam:
    block_size:int=3
    c:int=1


def adaptive_threshold(srcs:list[np.ndarray], param:AdaptiveThresholdParam) -> list[np.ndarray]:
    """cv2.adaptiveThreshold

    Args:
        srcs (list[np.ndarray]): _description_
        param (AdaptiveThresholdParam): _description_

    Returns:
        list[np.ndarray]: _description_
    """
    return [cv2.adaptiveThreshold(src, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, param.block_size, param.c) for src in srcs]


ADAPTIVE_THRESHOLD_PARAM_RANGE = [AdaptiveThresholdParam(block_size, c) for block_size in range(3,12,2) for c in range(1,101)]


def get_adaptive_threshold_param_range() -> list[AdaptiveThresholdParam]:
    """AdaptiveThresholdのパラメータ範囲

    Returns:
        list[AdaptiveThresholdParam]: _description_
    """
    return ADAPTIVE_THRESHOLD_PARAM_RANGE

HybridThreshold

import numpy as np
import cv2
from dataclasses import dataclass
from dataclasses_json import dataclass_json
from pathlib import Path
import json
from tesseract_common import image_to_string
from simple_threshold import (
    SimpleThresholdParam,
    simple_threshold,
    get_simple_threshold_param_range,
)
from adaptive_threshold import (
    AdaptiveThresholdParam,
    adaptive_threshold,
    get_adaptive_threshold_param_range,
)


@dataclass_json
@dataclass
class HybridThresholdParam:
    simple:SimpleThresholdParam=SimpleThresholdParam()
    adaptive:AdaptiveThresholdParam=AdaptiveThresholdParam()


@dataclass_json
@dataclass
class HybridThresholdResult:
    score:float=0.0
    text:str=""
    param:HybridThresholdParam=HybridThresholdParam()


def hybrid_threshold_to_string(srcs:list[np.ndarray], param:HybridThresholdParam, training_text:str, use_only_length:bool=False) -> HybridThresholdResult:
    """HybridThreshold後にOCR
    SimpleThresholdとAdaptiveThresholdの有効ピクセルを採用

    Args:
        srcs (list[np.ndarray]): _description_
        param (HybridThresholdParam): _description_
        training_text (str): _description_
        use_only_length (bool, optional): _description_. Defaults to False.

    Returns:
        HybridThresholdResult: _description_
    """
    simple_images = simple_threshold(srcs, param.simple)
    adaptive_images = adaptive_threshold(srcs, param.adaptive)
    hybrid_images = [cv2.min(simple_image, adaptive_image) for simple_image, adaptive_image in zip(simple_images, adaptive_images)]
    score, text = image_to_string(hybrid_images, training_text, use_only_length)
    return HybridThresholdResult(score, text, param)


def get_hybrid_threshold_param_range(best_path:str="") -> list[HybridThresholdParam]:
    """SimpleThreshod+AdaptiveThresholdのパラメータ範囲
    Bestパラメータファイルが指定されていない場合は全ての組み合わせを返す
    
    Returns:
        list[HybridThresholdParam]: _description_
    """
    if Path(best_path).is_file():
        with open(best_path, mode="r", encoding="utf-8") as f:
            json_data = json.load(f)
            return [HybridThresholdParam.from_json(data) for data in json_data]
    else:
        return [HybridThresholdParam(simple, adaptive) for adaptive in get_adaptive_threshold_param_range() for simple in get_simple_threshold_param_range()]

Train

from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
import json
from operator import attrgetter
from argparse_dataclass import ArgumentParser
import cv2
from pathlib import Path
import numpy as np
from tqdm import tqdm
from transcript_table import Transcript
from simple_threshold import (
    simple_threshold_to_string,
    get_simple_threshold_param_range,
    SimpleThresholdResult,
)
from hybrid_threshold import (
    hybrid_threshold_to_string,
    get_hybrid_threshold_param_range,
    HybridThresholdResult,
)


@dataclass
class Config:
    stage:int=field(metadata=dict(choices=[1,2], help="1:validate_testdata 2:validate_testdata"))
    transcripts:str=field(metadata=dict(required=True, help="select transcript"))
    output_path:str=field(metadata=dict(required=True, help="出力結果の保存先"))
    best_score_threshold:int=field(default=99, metadata=dict(help="ベストスコアの閾値"))
    image_color_invert:bool=field(default=True, metadata=dict(help="画像の色を反転させる"))
    ignore_best:bool=field(default=False, metadata=dict(required=True, help="最初のみ最適なパラメータを無視するか"))


def load_images(config:Config, images:list[str]) -> list[np.ndarray]:
    """画像読込

    Args:
        config (Config): _description_
        images (list[str]): _description_

    Returns:
        list[np.ndarray]: _description_
    """
    if config.image_color_invert:
        return [cv2.bitwise_not(cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2GRAY)) for image in images]
    else:
        return [cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2GRAY) for image in images]


def load_transcripts(filename:str) -> list[Transcript]:
    if Path(filename).is_file():
        with open(filename, mode="r", encoding="utf-8") as f:
            json_data:dict[str,str] = json.load(f)
            return [Transcript.from_json(data) for data in json_data.values()]


def validate_testdata(config:Config):
    # load transcripts.
    transcripts = load_transcripts(config.transcripts)
    
    # create output_dir.
    output_path = Path(config.output_path) / "dump" / "validate"
    output_path.mkdir(parents=True, exist_ok=True)
    
    with ThreadPoolExecutor() as executor:
        for transcript in transcripts:
            images = load_images(config, transcript.images)
            
            futures = [
                executor.submit(
                    simple_threshold_to_string,
                    images,
                    param,
                    transcript.training_text,
                    transcript.text!=transcript.training_text,
                )
                for param in get_simple_threshold_param_range()
            ]
            
            success_count:int = 0
            results:list[SimpleThresholdResult] = []
            
            with tqdm(futures, desc=f"{transcript.name}", postfix={"Success":success_count}) as iters:
                for future in iters:
                    result = future.result()
                    results.append(result)
                    success_count += 1 if result.score >= 1.0 else 0
                    iters.set_postfix({"Success":success_count})
            
            results = sorted(results, key=attrgetter("score"), reverse=True)
            
            with open(str(output_path / f"{transcript.name}.json"), mode="w", encoding="utf-8") as f:
                temp_results = [result.to_json(ensure_ascii=False) for result in results]
                json.dump(temp_results, f, indent=2, ensure_ascii=False)


def train_threshold_parameter(config:Config):
    # load transcripts.
    transcripts = load_transcripts(config.transcripts)
    
    # create output_dir.
    output_path = Path(config.output_path) / "dump" / "parameter"
    output_path.mkdir(parents=True, exist_ok=True)
    
    # best param path.
    best_path = output_path / "best.json"
    
    # 0-1
    best_score_threshold = config.best_score_threshold / 100
    
    with ThreadPoolExecutor() as executor:
        for idx, transcript in enumerate(transcripts):
            # TODO: この除外設定正式に実装したい
            #       できればstage1の結果が一定数未満の場合は自動的に除外設定に入れるような
            #if transcript.name in ["TEXT013", "TEXT014"]:
            #    continue
            
            images = load_images(config, transcript.images)
            
            futures = [
                executor.submit(
                    hybrid_threshold_to_string,
                    images,
                    param,
                    transcript.training_text,
                    transcript.text!=transcript.training_text,
                )
                for param in get_hybrid_threshold_param_range("" if config.ignore_best and idx == 0 else best_path)
            ]
            
            success_count:int = 0
            results:list[HybridThresholdResult] = []
            
            with tqdm(futures, desc=f"{transcript.name}", postfix={"Success":success_count}) as iters:
                for future in iters:
                    result = future.result()
                    results.append(result)
                    success_count += 1 if result.score >= 1.0 else 0
                    iters.set_postfix({"Success":success_count})
            
            results = sorted(results, key=attrgetter("score"), reverse=True)
            
            # save all result.
            with open(str(output_path / f"{transcript.name}.json"), mode="w", encoding="utf-8") as f:
                temp_results = [result.to_json(ensure_ascii=False) for result in results]
                json.dump(temp_results, f, indent=2, ensure_ascii=False)
            
            # save best result.
            temp_bests = [result.param.to_json(ensure_ascii=False) for result in results if result.score >= best_score_threshold]
            if len(temp_bests) > 0:
                with open(str(best_path), mode="w", encoding="utf-8") as f:
                    json.dump(temp_bests, f, indent=2, ensure_ascii=False)
            else:
                print(f"\"{transcript.name}\" not found best parameters.")
        
        print("")
        
        # try with best parameters.
        best_param = get_hybrid_threshold_param_range(best_path)[0]
        
        for transcript in transcripts:
            images = load_images(config, transcript.images)
            result = hybrid_threshold_to_string(
                images,
                best_param,
                transcript.training_text,
                transcript.text!=transcript.training_text
            )
            print(f"{transcript.name}|{result.score:.2f}|{result.text}")


def my_app(config:Config):
    if config.stage == 1:
        validate_testdata(config)
    elif config.stage == 2:
        train_threshold_parameter(config)


if __name__ == "__main__":
    my_app(ArgumentParser(Config).parse_args())

テキストと画像データの読込

@dataclass_json
@dataclass
class Transcript:
    name:str=""
    text:str=""
    training_text:str=""
    images:list[str]=field(default_factory=list)


def load_transcripts(filename:str) -> list[Transcript]:
    if Path(filename).is_file():
        with open(filename, mode="r", encoding="utf-8") as f:
            json_data:dict[str,str] = json.load(f)
            return [Transcript.from_json(data) for data in json_data.values()]

画像とテキストデータの作成フローをツール化しちゃっているためにこのような実装になっていますが、必要なものはテキストと画像データパスの2つだけなのでdict[str,str]等で再現可能です。

画像を読込後 白黒化等の下処理

def load_images(config:Config, images:list[str]) -> list[np.ndarray]:
    """画像読込

    Args:
        config (Config): _description_
        images (list[str]): _description_

    Returns:
        list[np.ndarray]: _description_
    """
    if config.image_color_invert:
        return [cv2.bitwise_not(cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2GRAY)) for image in images]
    else:
        return [cv2.cvtColor(cv2.imread(image), cv2.COLOR_BGR2GRAY) for image in images]

OpenCVで画像を読込後に白黒化をし必要に応じて色反転を行っています。
2値化処理内でもパラメータを指定すればできますが、ここで処理したほうが高速化につながります。

このグレースケール化は何もいじっていないのでガンマ補正等を行う高度なグレースケール化をするとさらに精度が高まる気がします。

validate_testdata

最適なパラメータを見つける前に正常なテストデータかを調べます。

SimpleThresholdのパラメータ範囲を全て試した精度結果が出力されます。全て試すので正解が1つも無いということはあり得ません。その場合は、誤字が含まれているか、そもそもtesseractが未学習な文字の2択に絞られます。こういったデータは修正するか除外をします。

これらが終わってようやく最適なパラメータ探索ができます。

train_threshold_parameter

最適なパラメータ探索の処理です。
コード内のコメント通りです。

demo

step1: 画像とテキストを用意
step2: 最適なパラメータ探索

ここで紹介するstep1はツール化してありますが、ペイントソフトとテキストソフトを使用することで手動対応可能なので必須ではありません。テスト内容が多い場合はツール化をしたほうが作業の効率化につながるのでおすすめです。

step1

ツール起動

撮影範囲を指定
著作権の関係で弊社コーポレートサイトをキャプチャしていますが、実際にはゲーム画面を撮っています。

テキストとトレーニングテキストをセット
キストが画像から抽出されるべき文字列、トレーニングテキストがOCRで識字してほしい箇所の文字列として分けています。

step2

コマンドプロンプトvscodeからtrainを実行します。
今回はvscodeから実行します。vscodeから実行する際にコマンドプロンプトと同じようにargparseを受け取らせたい場合は、launch.jsonのargsを編集することで出来ます。

まずはstage1を実行してtesseractが絶対に読めない文字を抽出していきます。
今回のテストデータでは13,14, 24のテキストが読めないようです。
読めない文字を抽出すると「雫、頬、軋」でした。

読めないデータを除外したらstage2を実行します。
ここではゴリゴリにマシンパワーを消費します。

スペックによりますが大体1時間で終わります。

全ての画像を探索し終えるとベストパラメータでのOCR結果を出力されます。
これも公開すると引用に該当しちゃうのでほぼモザイクです。
tesseractが読めない文字が含まれている13,14,24を省いた正解数が34/37なので大体90%ほどの精度です。

tesseract-jpnが未学習な文字

ちょっくら未学習な文字が気になったので調べてみました。

ノベルゲームでよく実装されているモトヤLマルベリ3等幅で対応されている文字のうち未対応なものを抽出してみました。

まずはフォントのバイナリから対応文字のデータを抽出しようと思いましたが、フリーソフトがあったのでおとなしくそちらを使います。

出力結果がこちらです。

かなり幅広く対応していますね。

全てを比較するのは少々無駄な気もするので今回比較する内容は

・Basic Latin(基本ラテン文字
・Hiragana(ひらがな)
・Katakana(カタカナ)
・CJK Unified Ideographs(CJK統合漢字)

の4つとします。

ちなみにCJK Compatibility Ideographsをtesseract(CJK互換漢字)を学習しようとすると文字列まわりのエンコードエラーが出たので比較対象から外しています。「塚﨑諸館鶴」ここら辺の文字は使いそうなのでちょっと不満です。

それではlangdata_lstmのjpn.training_textから重複文字を削除して差分を見てみましょう。

基本ラテン文字 完璧
ひらがな ゑ゛゜
カタカナ
CJK統合漢字 たくさん

結果としてはこのような感じです。

ノベルゲームだと悪ふざけで「ゑ」という文字も使うので正直欲しかったです。

濁点と半濁点は、上半分にあるか、下半分にあるかという判定でも入れないと誤認しちゃいそうなので何となく理解できます。bboxをいじればいけそうですが、そのアルゴリズムは入ってないんでしょうかね。
カタカナの小さい「ヮ」もその流れでしょうか。

そして漢字が壊滅的です。これだと精度以前に実用性の問題が出てきちゃいますね。
fine-tuningで解消はできますが、これだけ多いとスクラッチ学習したほうがマシな気がします。

OCR開発目線から見ると基本ラテン文字だけで網羅できる英語圏ってうらやましいですね。

終わり

お疲れさまでした!!!

今後

テスト動作では大津の結果が悪かったので非採用としたのですが自動算出系は間違いなく便利なはずなのでもう少し詰めてみたいなぁといった気持ちです。

あとはそもそもTesseract-OCRを使い続けるには限界がありそうなので、アルゴリズムを変えるようなリリースがないなら自作しちゃうのもアリかなと思えてきました。
Googleさんを使えばいいじゃない」と思うかもしれませんが、APIってどうしても不安なんですよねぇ。

Cloud Vision 試してみた

hybrid_thresholdをベストパラメータで処理して、リクエスト数を削減するために行ごとに分けた画像は結合してvision apiに投げてみました。

。。。

なんということでしょう。
精度100%です。

tesseract。。。。。。。。

左はhybrid_thresholdをしないで投げた結果でいくつかミスがあります。
右はhybrid_thresholdを適用してから投げた結果です。