OpenAI WhisperをFine Tuningする

Amazon

はじめに

以下のサイトを参考に、openai/whisperモデルのFine Tuningをおこなってみる

Hugging Faceのブログがオリジナルの情報源

その他関連記事

Whisperについて

  • Whisperは68万時間分のラベル付きAudio-Transcriptionデータで教師ありの事前学習をしている
  • WhisperではFine Tuningの必要性は低い
  • LibriSpeech ASRのtest-clean subsetではWER約3%、TED-LIUMデータセットでは4.7%を達成
  • Transformerベースのエンコーダ-デコーダモデル(シーケンス-シーケンスモデル)で、音声スペクトログラムの特徴量をテキストトークンのシーケンスに変換する
    • まず、生音声入力を、ログメルスペクトログラムに変換します(特徴抽出器による処理)
    • 次に、Transformerエンコーダは、スペクトログラムをエンコードして、エンコーダの隠れ状態を生成する
    • デコーダは、以前のトークンとエンコーダの隠れ状態を元にして、自動回帰的にテキストトークンを予測する
  • クロスエントロピー目的関数を使って事前学習およびファインチューニングされ、ターゲットのテキストトークンを正しく分類するよう訓練されている
  • モデルは5種類のサイズで提供され、最小の4つは英語専用または多言語データで訓練され、最大のモデルは多言語専用
  • トレーニング中に、Hugging Face Hubにモデルのチェックポイントを保存することを推奨
  • Common Voiceとは、Wikipediaのテキストを様々な言語で録音したクラウドソース型データセット
  • ASRパイプラインは次の3つのコンポーネントに分解できる:
    • 特徴抽出器: 生の音声入力を前処理する
    • モデル: シーケンス間の変換(音声からテキストへのマッピング)を行う
    • トークナイザー: モデル出力をテキスト形式に後処理する
  • Whisperの特徴抽出器は、16kHzのサンプリングレートの音声入力を期待する
  • Whisperの特徴抽出器は2つの処理を行う
    • パディング/トランケーション:
      • すべての音声サンプルを30秒に調整。30秒未満のサンプルは30秒にパディングされ、30秒以上のサンプルは切り詰められる
    • ログメルスペクトログラムへの変換:
      • パディングされた音声配列をログメルスペクトログラムに変換する。これは、音声信号の周波数を時間軸に沿って視覚的に表現したもの
  • Whisperモデルは、辞書内の語彙アイテムのインデックスに対応するテキストトークンを出力し、トークナイザーは、これらのトークンを実際のテキストに変換する
    • (例: [1169, 3797, 3332] -> "the cat sat")。
  • Whisperのトークナイザーは96言語の転写データで事前学習されており、ほぼすべての多言語ASRに対応可能
  • トークナイザーは、Common Voiceデータセットの最初のサンプルをエンコード・デコードする
  • エンコード時には、トークナイザーが特別なトークン(転写の開始/終了、言語、タスクのトークン)をシーケンスの先頭と末尾に追加。デコード時には、これらの特別なトークンをスキップして、元の入力形式の文字列を返す

Fine Tuningのデモ

デモとして、244Mパラメータ(約1GB)の多言語対応の小型チェックポイントをファインチューニングします。データは、Common Voiceデータセットから選んだ低リソース言語で訓練・評価します。わずか8時間分のファインチューニングデータでも、その言語で高いパフォーマンスを達成できることを示します。

以下は、Sagemaker JupyterLabで実行すると、decode時にカーネルがリスタートし、原因が特定できていない
Sagemaker Studio Classicではdecodeに成功した

準備

!pip install --upgrade pip
!pip install --upgrade datasets transformers accelerate evaluate jiwer tensorboard gradio
from huggingface_hub import notebook_login

notebook_login()

データセットのロード

from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

# 'use_auth_token=True' を削除
common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation")
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test")

print(common_voice)

Extractorのロード

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

Tokenizerのロード

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

Tokenizerのエンコード/デコードの確認

エンコード/デコードの実行

input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")
Input:                 हमने उसका जन्मदिन मनाया।
Decoded w/ special: <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>हमने उसका जन्मदिन मनाया।<|endoftext|>
Decoded w/out special: हमने उसका जन्मदिन मनाया।
Are equal: True

WhisperProcessorへ抽出器とトークナイザーの組み合わせ

Whisperの特徴抽出器とトークナイザーの使用を簡略化するために、両方を1つのクラスであるWhisperProcessorにまとめます。このクラスは、WhisperFeatureExtractorWhisperTokenizerを継承し、音声入力とモデル予測の処理に使用されます。これにより、トレーニング中に管理するオブジェクトは、プロセッサーモデルの2つだけで済みます。

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

サンプル音声のダウンサンプリング

入力音声は1次元の配列で、対応する転写データがあります。Whisperモデルのサンプリングレート(16kHz)に音声のサンプリングレートを合わせる必要があります。入力音声が48kHzでサンプリングされている場合、Whisperの特徴抽出器に渡す前に16kHzにダウンサンプリングする必要があります。

datasetscast_column メソッドを使用して、音声のサンプリングレートを正しい値に設定します。この操作は音声データ自体を即座に変更するのではなく、最初にデータを読み込む際に動的にリサンプリングを行います。

from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

サンプルの確認(48000Hz)

print(common_voice["train"][0])

{‘audio’: {‘path’: ‘/root/.cache/huggingface/datasets/downloads/extracted/1bfc12b9ee30f73bf143fa237d4ba38488008883c25816876e1a35295c9575d3/hi_train_0/common_voice_hi_26008353.mp3’, ‘array’: array([ 5.81611368e-26, -1.48634016e-25, -9.37040538e-26, …,
1.06425901e-07, 4.46416450e-08, 2.61450239e-09]), ‘sampling_rate’: 48000}, ‘sentence’: ‘हमने उसका जन्मदिन मनाया।’}

サンプルの確認(16000Hz)

print(common_voice["train"][0])
{'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/1bfc12b9ee30f73bf143fa237d4ba38488008883c25816876e1a35295c9575d3/hi_train_0/common_voice_hi_26008353.mp3', 'array': array([ 3.81639165e-17,  2.42861287e-17, -1.73472348e-17, ...,
       -1.30981789e-07,  2.63096808e-07,  4.77157300e-08]), 'sampling_rate': 16000}, 'sentence': 'हमने उसका जन्मदिन मनाया।'}

データをモデルに渡す準備をする関数

  • batch["audio"]を呼び出して、音声データを読み込みリサンプリングします。 Datasetsは必要なリサンプリングを自動で行います。
  • 特徴抽出器を使用して、1次元の音声配列からログメルスペクトログラム入力特徴を計算します。
  • トークナイザーを使って、転写データをラベルIDにエンコードします。
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

データ準備の関数は、datasets.map メソッドを使って、すべてのトレーニングデータに適用できます。これにより、各トレーニング例に対して音声のリサンプリングや特徴抽出、ラベルIDへのエンコードが一括で実行されます。

common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)

トレーニングと評価

データの準備ができたら、トレーニングパイプラインに進みます。 Trainerが主要な処理を担当するため、私たちは次のことを行うだけです:

  1. 事前学習済みチェックポイントの読み込み: トレーニング用に正しく設定された事前学習済みモデルをロードします。
  2. データコレーターの定義: 前処理済みデータをモデル用のPyTorchテンソルに変換します。
  3. 評価指標: 評価時に単語誤り率(WER)を使用し、計算を行うcompute_metrics関数を定義します。
  4. トレーニング引数の定義: Trainerがトレーニングスケジュールを構築するために使います。

モデルをファインチューニングした後、テストデータで評価し、ヒンディー語の音声転写が正しく学習されたかを確認します。

Pre−Trainedチェックポイントのロード

事前学習済みのWhisper小型モデルのチェックポイントからファインチューニングを開始します。これを行うために、Hugging Face Hubから事前学習済みの重みをロードします。Transformersを使うことで、これは非常に簡単に行えます。

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

推論時、Whisperモデルは自動的に音声の言語を検出し、その言語でトークンIDを予測します。ただし、音声の言語が事前に分かっている場合(例: 多言語ファインチューニング)では、言語を明示的に設定する方が有益です。これにより、誤った言語が予測され、生成中に予測されたテキストが正しい言語からずれることを防げます。そのため、言語とタスクを生成設定に指定し、旧方式のforced_decoder_idsは使わずにNoneに設定します。

model.generation_config.language = "hindi"
model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None

Data Collatorの定義

シーケンス間の音声モデルのデータコレーターは、入力特徴とラベルを別々に扱う点で独特です。入力特徴は特徴抽出器で、ラベルはトークナイザーで処理されます。
入力特徴はすでに30秒にパディングされ、固定次元のログメルスペクトログラムに変換されているため、PyTorchテンソルに変換するだけです。これは特徴抽出器の padメソッドを使用し、return_tensors=ptで行います。追加のパディングは必要ありません。
一方、ラベルはパディングされていません。まず、トークナイザーの padメソッドを使ってバッチ内の最大長にシーケンスをパディングし、パディングトークンを-100に置き換え、損失計算時に無視されるようにします。ラベルシーケンスの冒頭にある開始トークンは、トレーニング中に追加されるため削除します。
これらの操作は、以前定義したWhisperProcessorを使って行います。

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

これから、先ほど定義したデータコレーターを初期化します。

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

Evaluation Metrics

次に、評価セットで使用する評価指標を定義します。ASRシステムの評価に一般的に使われる単語誤り率(WER)を使用します。詳細についてはWERのドキュメントを参照してください。
EvaluateからWER指標を読み込みます。

import evaluate

metric = evaluate.load("wer")

compute_metrics関数

音声認識モデルの評価に使用される「WER(Word Error Rate)」メトリックを計算するためのcompute_metrics関数の流れは以下の通りです。

  1. label_idsの中で-100pad_token_idに置き換え、損失計算で無視されていたパディングトークンを元に戻す。
  2. モデルの予測とラベルのIDをデコードし、それぞれ文字列に変換する。
  3. 予測と参照ラベル間のWERを計算する。

簡単に言うと、モデルの予測精度を評価するために、ラベルと予測結果を文字列化し、WERを計算します。

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

トレーニングのパラメータ定義

最終ステップでは、トレーニングに関連するすべてのパラメータを定義します。以下に一部のパラメータを説明します。

  • output_dir: モデルの重みを保存するローカルディレクトリ。これはHugging Face Hub上のリポジトリ名にもなります。
  • generation_max_length: 評価時に自動生成されるトークンの最大数を設定します。
  • save_steps: トレーニング中、指定したsave_stepsごとに中間チェックポイントが保存され、Hubに非同期でアップロードされます。
  • eval_steps: トレーニング中、指定したeval_stepsごとに中間チェックポイントの評価が行われます。
  • report_to: トレーニングログをどこに保存するかを指定します。サポートされているプラットフォームは「azure_ml」、「comet_ml」、「mlflow」、「neptune」、「tensorboard」、「wandb」です。好みのプラットフォームを選ぶか、「tensorboard」を選んでHubにログを記録します。
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-hi",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

注意: モデルのチェックポイントをHubにアップロードしたくない場合は、push_to_hub=Falseに設定してください。

トレーニング引数は、モデル、データセット、データコレータ、compute_metrics関数と一緒にTrainerに渡すことができます。

HuggingFaceのアクセストークンが書き込みOKになっている必要がある。

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

トレーニング<以下はリソースを多く使うため、参考ブログの和訳>

トレーニングの実行

trainer.train()

トレーニングには、使用しているGPUやGoogle Colabで割り当てられたGPUに応じて、おおよそ5~10時間かかります。GPUによっては、トレーニングを開始する際にCUDAの「メモリ不足」エラーが発生することがあります。その場合、per_device_train_batch_sizeを段階的に2分の1ずつ減らし、gradient_accumulation_stepsを使用して補うことができます。

私たちの最良のWER(単語誤り率)は、4000ステップのトレーニング後に32.0%でした。参考として、事前訓練されたWhisperのsmallモデルは63.5%のWERを達成しており、ファインチューニングによって絶対値で31.5%の改善を達成しています。わずか8時間分のトレーニングデータでこれだけの成果を得られたのは悪くありません!

これで、ファインチューニングしたモデルをHugging Face Hubで共有する準備が整いました。適切なタグやREADME情報を追加してよりアクセスしやすくするために、モデルをアップロードする際には適切なキーワード引数(kwargs)を設定できます。これらの値は、使用したデータセット、言語、モデル名に応じて変更できます。

kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",  # a 'pretty' name for the training dataset
    "dataset_args": "config: hi, split: test",
    "language": "hi",
    "model_name": "Whisper Small Hi - Sanchit Gandhi",  # a 'pretty' name for your model
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
}

トレーニング結果をHubにアップロードできるようになりました。これを行うには、push_to_hubコマンドを実行してください。

trainer.push_to_hub(**kwargs)

このモデルをHub上のリンクを使って誰とでも共有できるようになりました。たとえば、「your-username/the-name-you-picked」という識別子を使って、他の人もこのモデルを読み込むことができます。

from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")

ファインチューニングしたモデルはCommon Voiceのヒンディー語テストデータで満足のいく結果を得られましたが、決して最適ではありません。このノートブックの目的は、事前訓練されたWhisperのチェックポイントを任意の多言語ASR(自動音声認識)データセットに対してファインチューニングする方法を示すことです。学習率やドロップアウトなどのトレーニングハイパーパラメータを最適化し、より大きな事前訓練済みチェックポイント(mediumやlarge-v3)を使用することで、結果がさらに改善される可能性があります。

デモのビルド

モデルのファインチューニングが完了したので、そのASR(自動音声認識)機能を披露するデモを作成できます!ここでは、音声入力の前処理からモデルの予測をデコードするまでのすべてを処理してくれるTransformersのパイプラインを使用します。そして、Gradioを使ってインタラクティブなデモを構築します。Gradioは、機械学習デモを構築する最も簡単な方法と言われており、数分でデモを作成することが可能です。

以下の例を実行すると、Gradioデモが生成され、コンピュータのマイクを使って音声を録音し、それをファインチューニングしたWhisperモデルに入力して、対応するテキストに書き起こすことができます。

xxxxxx/whisper-small-hiは、HuggingFaceに先ほどアップロードしたモデル

from transformers import pipeline
import gradio as gr

pipe = pipeline(model="xxxxxx/whisper-small-hi")  # change to "your-username/the-name-you-picked"

def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(source="microphone", type="filepath"), 
    outputs="text",
    title="Whisper Small Hindi",
    description="Realtime demo for Hindi speech recognition using a fine-tuned Whisper small model.",
)

iface.launch()

Datasets、Transformers、およびHugging Face Hubを使用して、Whisperを多言語ASR向けにファインチューニングするためのステップバイステップガイドを紹介しました。自分でファインチューニングを試してみたい場合は、Google Colabを参照してください。英語や多言語ASR向けに他のTransformersモデルをファインチューニングすることに興味がある場合は、examples/pytorch/speech-recognitionにあるサンプルスクリプトもぜひチェックしてください。

関連記事

カテゴリー

アーカイブ

Lang »