What is Whisper, OpenAI's open-source speech recognition model? (4)


Setting up a whisper debugging environment

Checking out the whisper source code

git clone https://github.com/openai/whisper.git
cd whisper

Creating test.py

import sys
from whisper.transcribe import cli

# Same entry point as the `whisper` console command (see setup.py), but
# runnable directly from the checkout so a debugger can step into the package.
if __name__ == '__main__':
    sys.exit(cli())

Running test.py

python test.py sample.mp4 --language English
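
Because test.py is a plain script, the standard Python debuggers work on it directly. One minimal option (an addition for illustration, not part of the original setup) is to drop a breakpoint() call in before cli(), so execution stops in pdb right before the CLI runs:

import sys
from whisper.transcribe import cli

if __name__ == '__main__':
    breakpoint()  # opens pdb here (Python 3.7+); step into cli() with `s`
    sys.exit(cli())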

transcribe.py

def cli():
    ...
    model = load_model(model_name, device=device, download_root=model_dir)
    ...
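
For reference, cli() builds model_name, device, and model_dir with argparse before this call. A condensed sketch of the relevant lines (paraphrased from transcribe.py; option defaults can differ between versions):

import argparse

import torch
from whisper import load_model

def cli():
    parser = argparse.ArgumentParser()
    parser.add_argument("audio", nargs="+", help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", help="name of the Whisper model to use")
    parser.add_argument("--model_dir", default=None, help="path to save model files")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    # ...many more options omitted...
    args = parser.parse_args().__dict__
    model_name = args.pop("model")
    model_dir = args.pop("model_dir")
    device = args.pop("device")
    model = load_model(model_name, device=device, download_root=model_dir)
    # ...transcription loop omitted...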

Reviewing the source

Reading the load_model function

With the default options, the call uses model_name=small, device=cpu, and download_root=None.

def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """ 
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )   

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp: 
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
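
As the elif os.path.isfile(name) branch shows, load_model also accepts a direct path to a checkpoint file. A small usage sketch (the path is the default cache location observed below; note that alignment_heads stays None in this branch):

import os
import whisper

# Skip the name lookup and load the cached checkpoint by path.
model = whisper.load_model(os.path.expanduser("~/.cache/whisper/small.pt"))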

In this run, checkpoint_file resolves to ~/.cache/whisper/small.pt.

The file is downloaded from https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt

The small.pt file is 483.6 MB.
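
The long hex directory in the URL is the file's SHA-256 digest, which _download() uses to verify the downloaded bytes. You can check the cached file yourself:

import hashlib
import os

path = os.path.expanduser("~/.cache/whisper/small.pt")
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
print(digest)  # expected: 9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794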

If memory allows, pass in_memory=True to load_model to speed up execution.
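
For example, using the public whisper.load_model API:

import whisper

# in_memory=True reads the whole .pt file into host memory before
# torch.load(), which can help on slow disks or network filesystems.
model = whisper.load_model("small", in_memory=True)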

Inspecting the model checkpoint

The checkpoint stores the model's parameter values.

The checkpoint's type is dict, with the keys dims and model_state_dict.

checkpoint['dims'] is also a dict; its keys are:
['n_mels', 'n_vocab', 'n_audio_ctx', 'n_audio_state', 'n_audio_head', 'n_audio_layer', 'n_text_ctx', 'n_text_state', 'n_text_head', 'n_text_layer']

checkpoint['model_state_dict'] is also a dict; its keys are:
['decoder.positional_embedding', 'encoder.positional_embedding', 'decoder.token_embedding.weight', 'decoder.blocks.0.mlp_ln.weight', 'decoder.blocks.0.mlp_ln.bias', 'decoder.blocks.0.mlp.0.weight', 'decoder.blocks.0.mlp.0.bias', 'decoder.blocks.0.mlp.2.weight', 'decoder.blocks.0.mlp.2.bias', 'decoder.blocks.0.attn_ln.weight', 'decoder.blocks.0.attn_ln.bias', 'decoder.blocks.0.attn.query.weight', 'decoder.blocks.0.attn.query.bias', 'decoder.blocks.0.attn.key.weight', 'decoder.blocks.0.attn.value.weight', 'decoder.blocks.0.attn.value.bias', 'decoder.blocks.0.attn.out.weight', 'decoder.blocks.0.attn.out.bias', 'decoder.blocks.0.cross_attn_ln.weight', 'decoder.blocks.0.cross_attn_ln.bias', 'decoder.blocks.0.cross_attn.query.weight', 'decoder.blocks.0.cross_attn.query.bias', 'decoder.blocks.0.cross_attn.key.weight', 'decoder.blocks.0.cross_attn.value.weight', 'decoder.blocks.0.cross_attn.value.bias', 'decoder.blocks.0.cross_attn.out.weight', 'decoder.blocks.0.cross_attn.out.bias', 'decoder.blocks.1.mlp_ln.weight', 'decoder.blocks.1.mlp_ln.bias', 'decoder.blocks.1.mlp.0.weight', 'decoder.blocks.1.mlp.0.bias', 'decoder.blocks.1.mlp.2.weight', 'decoder.blocks.1.mlp.2.bias', 'decoder.blocks.1.attn_ln.weight', 'decoder.blocks.1.attn_ln.bias', 'decoder.blocks.1.attn.query.weight', 'decoder.blocks.1.attn.query.bias', 'decoder.blocks.1.attn.key.weight', 'decoder.blocks.1.attn.value.weight', 'decoder.blocks.1.attn.value.bias', 'decoder.blocks.1.attn.out.weight', 'decoder.blocks.1.attn.out.bias', 'decoder.blocks.1.cross_attn_ln.weight', 'decoder.blocks.1.cross_attn_ln.bias', 'decoder.blocks.1.cross_attn.query.weight', 'decoder.blocks.1.cross_attn.query.bias', 'decoder.blocks.1.cross_attn.key.weight', 'decoder.blocks.1.cross_attn.value.weight', 'decoder.blocks.1.cross_attn.value.bias', 'decoder.blocks.1.cross_attn.out.weight', 'decoder.blocks.1.cross_attn.out.bias', 'decoder.blocks.2.mlp_ln.weight', 'decoder.blocks.2.mlp_ln.bias', 'decoder.blocks.2.mlp.0.weight', 'decoder.blocks.2.mlp.0.bias', 'decoder.blocks.2.mlp.2.weight', 'decoder.blocks.2.mlp.2.bias', 'decoder.blocks.2.attn_ln.weight', 'decoder.blocks.2.attn_ln.bias', 'decoder.blocks.2.attn.query.weight', 'decoder.blocks.2.attn.query.bias', 'decoder.blocks.2.attn.key.weight', 'decoder.blocks.2.attn.value.weight', 'decoder.blocks.2.attn.value.bias', 'decoder.blocks.2.attn.out.weight', 'decoder.blocks.2.attn.out.bias', 'decoder.blocks.2.cross_attn_ln.weight', 'decoder.blocks.2.cross_attn_ln.bias', 'decoder.blocks.2.cross_attn.query.weight', 'decoder.blocks.2.cross_attn.query.bias', 'decoder.blocks.2.cross_attn.key.weight', 'decoder.blocks.2.cross_attn.value.weight', 'decoder.blocks.2.cross_attn.value.bias', 'decoder.blocks.2.cross_attn.out.weight', 'decoder.blocks.2.cross_attn.out.bias', 'decoder.blocks.3.mlp_ln.weight', 'decoder.blocks.3.mlp_ln.bias', 'decoder.blocks.3.mlp.0.weight', 'decoder.blocks.3.mlp.0.bias', 'decoder.blocks.3.mlp.2.weight', 'decoder.blocks.3.mlp.2.bias', 'decoder.blocks.3.attn_ln.weight', 'decoder.blocks.3.attn_ln.bias', 'decoder.blocks.3.attn.query.weight', 'decoder.blocks.3.attn.query.bias', 'decoder.blocks.3.attn.key.weight', 'decoder.blocks.3.attn.value.weight', 'decoder.blocks.3.attn.value.bias', 'decoder.blocks.3.attn.out.weight', 'decoder.blocks.3.attn.out.bias', 'decoder.blocks.3.cross_attn_ln.weight', 'decoder.blocks.3.cross_attn_ln.bias', 'decoder.blocks.3.cross_attn.query.weight', 'decoder.blocks.3.cross_attn.query.bias', 'decoder.blocks.3.cross_attn.key.weight', 'decoder.blocks.3.cross_attn.value.weight', 
'decoder.blocks.3.cross_attn.value.bias', 'decoder.blocks.3.cross_attn.out.weight', 'decoder.blocks.3.cross_attn.out.bias', 'decoder.blocks.4.mlp_ln.weight', 'decoder.blocks.4.mlp_ln.bias', 'decoder.blocks.4.mlp.0.weight', 'decoder.blocks.4.mlp.0.bias', 'decoder.blocks.4.mlp.2.weight', 'decoder.blocks.4.mlp.2.bias', 'decoder.blocks.4.attn_ln.weight', 'decoder.blocks.4.attn_ln.bias', 'decoder.blocks.4.attn.query.weight', 'decoder.blocks.4.attn.query.bias', 'decoder.blocks.4.attn.key.weight', 'decoder.blocks.4.attn.value.weight', 'decoder.blocks.4.attn.value.bias', 'decoder.blocks.4.attn.out.weight', 'decoder.blocks.4.attn.out.bias', 'decoder.blocks.4.cross_attn_ln.weight', 'decoder.blocks.4.cross_attn_ln.bias', 'decoder.blocks.4.cross_attn.query.weight', 'decoder.blocks.4.cross_attn.query.bias', 'decoder.blocks.4.cross_attn.key.weight', 'decoder.blocks.4.cross_attn.value.weight', 'decoder.blocks.4.cross_attn.value.bias', 'decoder.blocks.4.cross_attn.out.weight', 'decoder.blocks.4.cross_attn.out.bias', 'decoder.blocks.5.mlp_ln.weight', 'decoder.blocks.5.mlp_ln.bias', 'decoder.blocks.5.mlp.0.weight', 'decoder.blocks.5.mlp.0.bias', 'decoder.blocks.5.mlp.2.weight', 'decoder.blocks.5.mlp.2.bias', 'decoder.blocks.5.attn_ln.weight', 'decoder.blocks.5.attn_ln.bias', 'decoder.blocks.5.attn.query.weight', 'decoder.blocks.5.attn.query.bias', 'decoder.blocks.5.attn.key.weight', 'decoder.blocks.5.attn.value.weight', 'decoder.blocks.5.attn.value.bias', 'decoder.blocks.5.attn.out.weight', 'decoder.blocks.5.attn.out.bias', 'decoder.blocks.5.cross_attn_ln.weight', 'decoder.blocks.5.cross_attn_ln.bias', 'decoder.blocks.5.cross_attn.query.weight', 'decoder.blocks.5.cross_attn.query.bias', 'decoder.blocks.5.cross_attn.key.weight', 'decoder.blocks.5.cross_attn.value.weight', 'decoder.blocks.5.cross_attn.value.bias', 'decoder.blocks.5.cross_attn.out.weight', 'decoder.blocks.5.cross_attn.out.bias', 'decoder.blocks.6.mlp_ln.weight', 'decoder.blocks.6.mlp_ln.bias', 'decoder.blocks.6.mlp.0.weight', 'decoder.blocks.6.mlp.0.bias', 'decoder.blocks.6.mlp.2.weight', 'decoder.blocks.6.mlp.2.bias', 'decoder.blocks.6.attn_ln.weight', 'decoder.blocks.6.attn_ln.bias', 'decoder.blocks.6.attn.query.weight', 'decoder.blocks.6.attn.query.bias', 'decoder.blocks.6.attn.key.weight', 'decoder.blocks.6.attn.value.weight', 'decoder.blocks.6.attn.value.bias', 'decoder.blocks.6.attn.out.weight', 'decoder.blocks.6.attn.out.bias', 'decoder.blocks.6.cross_attn_ln.weight', 'decoder.blocks.6.cross_attn_ln.bias', 'decoder.blocks.6.cross_attn.query.weight', 'decoder.blocks.6.cross_attn.query.bias', 'decoder.blocks.6.cross_attn.key.weight', 'decoder.blocks.6.cross_attn.value.weight', 'decoder.blocks.6.cross_attn.value.bias', 'decoder.blocks.6.cross_attn.out.weight', 'decoder.blocks.6.cross_attn.out.bias', 'decoder.blocks.7.mlp_ln.weight', 'decoder.blocks.7.mlp_ln.bias', 'decoder.blocks.7.mlp.0.weight', 'decoder.blocks.7.mlp.0.bias', 'decoder.blocks.7.mlp.2.weight', 'decoder.blocks.7.mlp.2.bias', 'decoder.blocks.7.attn_ln.weight', 'decoder.blocks.7.attn_ln.bias', 'decoder.blocks.7.attn.query.weight', 'decoder.blocks.7.attn.query.bias', 'decoder.blocks.7.attn.key.weight', 'decoder.blocks.7.attn.value.weight', 'decoder.blocks.7.attn.value.bias', 'decoder.blocks.7.attn.out.weight', 'decoder.blocks.7.attn.out.bias', 'decoder.blocks.7.cross_attn_ln.weight', 'decoder.blocks.7.cross_attn_ln.bias', 'decoder.blocks.7.cross_attn.query.weight', 'decoder.blocks.7.cross_attn.query.bias', 'decoder.blocks.7.cross_attn.key.weight', 
'decoder.blocks.7.cross_attn.value.weight', 'decoder.blocks.7.cross_attn.value.bias', 'decoder.blocks.7.cross_attn.out.weight', 'decoder.blocks.7.cross_attn.out.bias', 'decoder.blocks.8.mlp_ln.weight', 'decoder.blocks.8.mlp_ln.bias', 'decoder.blocks.8.mlp.0.weight', 'decoder.blocks.8.mlp.0.bias', 'decoder.blocks.8.mlp.2.weight', 'decoder.blocks.8.mlp.2.bias', 'decoder.blocks.8.attn_ln.weight', 'decoder.blocks.8.attn_ln.bias', 'decoder.blocks.8.attn.query.weight', 'decoder.blocks.8.attn.query.bias', 'decoder.blocks.8.attn.key.weight', 'decoder.blocks.8.attn.value.weight', 'decoder.blocks.8.attn.value.bias', 'decoder.blocks.8.attn.out.weight', 'decoder.blocks.8.attn.out.bias', 'decoder.blocks.8.cross_attn_ln.weight', 'decoder.blocks.8.cross_attn_ln.bias', 'decoder.blocks.8.cross_attn.query.weight', 'decoder.blocks.8.cross_attn.query.bias', 'decoder.blocks.8.cross_attn.key.weight', 'decoder.blocks.8.cross_attn.value.weight', 'decoder.blocks.8.cross_attn.value.bias', 'decoder.blocks.8.cross_attn.out.weight', 'decoder.blocks.8.cross_attn.out.bias', 'decoder.blocks.9.mlp_ln.weight', 'decoder.blocks.9.mlp_ln.bias', 'decoder.blocks.9.mlp.0.weight', 'decoder.blocks.9.mlp.0.bias', 'decoder.blocks.9.mlp.2.weight', 'decoder.blocks.9.mlp.2.bias', 'decoder.blocks.9.attn_ln.weight', 'decoder.blocks.9.attn_ln.bias', 'decoder.blocks.9.attn.query.weight', 'decoder.blocks.9.attn.query.bias', 'decoder.blocks.9.attn.key.weight', 'decoder.blocks.9.attn.value.weight', 'decoder.blocks.9.attn.value.bias', 'decoder.blocks.9.attn.out.weight', 'decoder.blocks.9.attn.out.bias', 'decoder.blocks.9.cross_attn_ln.weight', 'decoder.blocks.9.cross_attn_ln.bias', 'decoder.blocks.9.cross_attn.query.weight', 'decoder.blocks.9.cross_attn.query.bias', 'decoder.blocks.9.cross_attn.key.weight', 'decoder.blocks.9.cross_attn.value.weight', 'decoder.blocks.9.cross_attn.value.bias', 'decoder.blocks.9.cross_attn.out.weight', 'decoder.blocks.9.cross_attn.out.bias', 'decoder.blocks.10.mlp_ln.weight', 'decoder.blocks.10.mlp_ln.bias', 'decoder.blocks.10.mlp.0.weight', 'decoder.blocks.10.mlp.0.bias', 'decoder.blocks.10.mlp.2.weight', 'decoder.blocks.10.mlp.2.bias', 'decoder.blocks.10.attn_ln.weight', 'decoder.blocks.10.attn_ln.bias', 'decoder.blocks.10.attn.query.weight', 'decoder.blocks.10.attn.query.bias', 'decoder.blocks.10.attn.key.weight', 'decoder.blocks.10.attn.value.weight', 'decoder.blocks.10.attn.value.bias', 'decoder.blocks.10.attn.out.weight', 'decoder.blocks.10.attn.out.bias', 'decoder.blocks.10.cross_attn_ln.weight', 'decoder.blocks.10.cross_attn_ln.bias', 'decoder.blocks.10.cross_attn.query.weight', 'decoder.blocks.10.cross_attn.query.bias', 'decoder.blocks.10.cross_attn.key.weight', 'decoder.blocks.10.cross_attn.value.weight', 'decoder.blocks.10.cross_attn.value.bias', 'decoder.blocks.10.cross_attn.out.weight', 'decoder.blocks.10.cross_attn.out.bias', 'decoder.blocks.11.mlp_ln.weight', 'decoder.blocks.11.mlp_ln.bias', 'decoder.blocks.11.mlp.0.weight', 'decoder.blocks.11.mlp.0.bias', 'decoder.blocks.11.mlp.2.weight', 'decoder.blocks.11.mlp.2.bias', 'decoder.blocks.11.attn_ln.weight', 'decoder.blocks.11.attn_ln.bias', 'decoder.blocks.11.attn.query.weight', 'decoder.blocks.11.attn.query.bias', 'decoder.blocks.11.attn.key.weight', 'decoder.blocks.11.attn.value.weight', 'decoder.blocks.11.attn.value.bias', 'decoder.blocks.11.attn.out.weight', 'decoder.blocks.11.attn.out.bias', 'decoder.blocks.11.cross_attn_ln.weight', 'decoder.blocks.11.cross_attn_ln.bias', 'decoder.blocks.11.cross_attn.query.weight', 
'decoder.blocks.11.cross_attn.query.bias', 'decoder.blocks.11.cross_attn.key.weight', 'decoder.blocks.11.cross_attn.value.weight', 'decoder.blocks.11.cross_attn.value.bias', 'decoder.blocks.11.cross_attn.out.weight', 'decoder.blocks.11.cross_attn.out.bias', 'decoder.ln.weight', 'decoder.ln.bias', 'encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.blocks.0.mlp_ln.weight', 'encoder.blocks.0.mlp_ln.bias', 'encoder.blocks.0.mlp.0.weight', 'encoder.blocks.0.mlp.0.bias', 'encoder.blocks.0.mlp.2.weight', 'encoder.blocks.0.mlp.2.bias', 'encoder.blocks.0.attn_ln.weight', 'encoder.blocks.0.attn_ln.bias', 'encoder.blocks.0.attn.query.weight', 'encoder.blocks.0.attn.query.bias', 'encoder.blocks.0.attn.key.weight', 'encoder.blocks.0.attn.value.weight', 'encoder.blocks.0.attn.value.bias', 'encoder.blocks.0.attn.out.weight', 'encoder.blocks.0.attn.out.bias', 'encoder.blocks.1.mlp_ln.weight', 'encoder.blocks.1.mlp_ln.bias', 'encoder.blocks.1.mlp.0.weight', 'encoder.blocks.1.mlp.0.bias', 'encoder.blocks.1.mlp.2.weight', 'encoder.blocks.1.mlp.2.bias', 'encoder.blocks.1.attn_ln.weight', 'encoder.blocks.1.attn_ln.bias', 'encoder.blocks.1.attn.query.weight', 'encoder.blocks.1.attn.query.bias', 'encoder.blocks.1.attn.key.weight', 'encoder.blocks.1.attn.value.weight', 'encoder.blocks.1.attn.value.bias', 'encoder.blocks.1.attn.out.weight', 'encoder.blocks.1.attn.out.bias', 'encoder.blocks.2.mlp_ln.weight', 'encoder.blocks.2.mlp_ln.bias', 'encoder.blocks.2.mlp.0.weight', 'encoder.blocks.2.mlp.0.bias', 'encoder.blocks.2.mlp.2.weight', 'encoder.blocks.2.mlp.2.bias', 'encoder.blocks.2.attn_ln.weight', 'encoder.blocks.2.attn_ln.bias', 'encoder.blocks.2.attn.query.weight', 'encoder.blocks.2.attn.query.bias', 'encoder.blocks.2.attn.key.weight', 'encoder.blocks.2.attn.value.weight', 'encoder.blocks.2.attn.value.bias', 'encoder.blocks.2.attn.out.weight', 'encoder.blocks.2.attn.out.bias', 'encoder.blocks.3.mlp_ln.weight', 'encoder.blocks.3.mlp_ln.bias', 'encoder.blocks.3.mlp.0.weight', 'encoder.blocks.3.mlp.0.bias', 'encoder.blocks.3.mlp.2.weight', 'encoder.blocks.3.mlp.2.bias', 'encoder.blocks.3.attn_ln.weight', 'encoder.blocks.3.attn_ln.bias', 'encoder.blocks.3.attn.query.weight', 'encoder.blocks.3.attn.query.bias', 'encoder.blocks.3.attn.key.weight', 'encoder.blocks.3.attn.value.weight', 'encoder.blocks.3.attn.value.bias', 'encoder.blocks.3.attn.out.weight', 'encoder.blocks.3.attn.out.bias', 'encoder.blocks.4.mlp_ln.weight', 'encoder.blocks.4.mlp_ln.bias', 'encoder.blocks.4.mlp.0.weight', 'encoder.blocks.4.mlp.0.bias', 'encoder.blocks.4.mlp.2.weight', 'encoder.blocks.4.mlp.2.bias', 'encoder.blocks.4.attn_ln.weight', 'encoder.blocks.4.attn_ln.bias', 'encoder.blocks.4.attn.query.weight', 'encoder.blocks.4.attn.query.bias', 'encoder.blocks.4.attn.key.weight', 'encoder.blocks.4.attn.value.weight', 'encoder.blocks.4.attn.value.bias', 'encoder.blocks.4.attn.out.weight', 'encoder.blocks.4.attn.out.bias', 'encoder.blocks.5.mlp_ln.weight', 'encoder.blocks.5.mlp_ln.bias', 'encoder.blocks.5.mlp.0.weight', 'encoder.blocks.5.mlp.0.bias', 'encoder.blocks.5.mlp.2.weight', 'encoder.blocks.5.mlp.2.bias', 'encoder.blocks.5.attn_ln.weight', 'encoder.blocks.5.attn_ln.bias', 'encoder.blocks.5.attn.query.weight', 'encoder.blocks.5.attn.query.bias', 'encoder.blocks.5.attn.key.weight', 'encoder.blocks.5.attn.value.weight', 'encoder.blocks.5.attn.value.bias', 'encoder.blocks.5.attn.out.weight', 'encoder.blocks.5.attn.out.bias', 'encoder.blocks.6.mlp_ln.weight', 'encoder.blocks.6.mlp_ln.bias', 
'encoder.blocks.6.mlp.0.weight', 'encoder.blocks.6.mlp.0.bias', 'encoder.blocks.6.mlp.2.weight', 'encoder.blocks.6.mlp.2.bias', 'encoder.blocks.6.attn_ln.weight', 'encoder.blocks.6.attn_ln.bias', 'encoder.blocks.6.attn.query.weight', 'encoder.blocks.6.attn.query.bias', 'encoder.blocks.6.attn.key.weight', 'encoder.blocks.6.attn.value.weight', 'encoder.blocks.6.attn.value.bias', 'encoder.blocks.6.attn.out.weight', 'encoder.blocks.6.attn.out.bias', 'encoder.blocks.7.mlp_ln.weight', 'encoder.blocks.7.mlp_ln.bias', 'encoder.blocks.7.mlp.0.weight', 'encoder.blocks.7.mlp.0.bias', 'encoder.blocks.7.mlp.2.weight', 'encoder.blocks.7.mlp.2.bias', 'encoder.blocks.7.attn_ln.weight', 'encoder.blocks.7.attn_ln.bias', 'encoder.blocks.7.attn.query.weight', 'encoder.blocks.7.attn.query.bias', 'encoder.blocks.7.attn.key.weight', 'encoder.blocks.7.attn.value.weight', 'encoder.blocks.7.attn.value.bias', 'encoder.blocks.7.attn.out.weight', 'encoder.blocks.7.attn.out.bias', 'encoder.blocks.8.mlp_ln.weight', 'encoder.blocks.8.mlp_ln.bias', 'encoder.blocks.8.mlp.0.weight', 'encoder.blocks.8.mlp.0.bias', 'encoder.blocks.8.mlp.2.weight', 'encoder.blocks.8.mlp.2.bias', 'encoder.blocks.8.attn_ln.weight', 'encoder.blocks.8.attn_ln.bias', 'encoder.blocks.8.attn.query.weight', 'encoder.blocks.8.attn.query.bias', 'encoder.blocks.8.attn.key.weight', 'encoder.blocks.8.attn.value.weight', 'encoder.blocks.8.attn.value.bias', 'encoder.blocks.8.attn.out.weight', 'encoder.blocks.8.attn.out.bias', 'encoder.blocks.9.mlp_ln.weight', 'encoder.blocks.9.mlp_ln.bias', 'encoder.blocks.9.mlp.0.weight', 'encoder.blocks.9.mlp.0.bias', 'encoder.blocks.9.mlp.2.weight', 'encoder.blocks.9.mlp.2.bias', 'encoder.blocks.9.attn_ln.weight', 'encoder.blocks.9.attn_ln.bias', 'encoder.blocks.9.attn.query.weight', 'encoder.blocks.9.attn.query.bias', 'encoder.blocks.9.attn.key.weight', 'encoder.blocks.9.attn.value.weight', 'encoder.blocks.9.attn.value.bias', 'encoder.blocks.9.attn.out.weight', 'encoder.blocks.9.attn.out.bias', 'encoder.blocks.10.mlp_ln.weight', 'encoder.blocks.10.mlp_ln.bias', 'encoder.blocks.10.mlp.0.weight', 'encoder.blocks.10.mlp.0.bias', 'encoder.blocks.10.mlp.2.weight', 'encoder.blocks.10.mlp.2.bias', 'encoder.blocks.10.attn_ln.weight', 'encoder.blocks.10.attn_ln.bias', 'encoder.blocks.10.attn.query.weight', 'encoder.blocks.10.attn.query.bias', 'encoder.blocks.10.attn.key.weight', 'encoder.blocks.10.attn.value.weight', 'encoder.blocks.10.attn.value.bias', 'encoder.blocks.10.attn.out.weight', 'encoder.blocks.10.attn.out.bias', 'encoder.blocks.11.mlp_ln.weight', 'encoder.blocks.11.mlp_ln.bias', 'encoder.blocks.11.mlp.0.weight', 'encoder.blocks.11.mlp.0.bias', 'encoder.blocks.11.mlp.2.weight', 'encoder.blocks.11.mlp.2.bias', 'encoder.blocks.11.attn_ln.weight', 'encoder.blocks.11.attn_ln.bias', 'encoder.blocks.11.attn.query.weight', 'encoder.blocks.11.attn.query.bias', 'encoder.blocks.11.attn.key.weight', 'encoder.blocks.11.attn.value.weight', 'encoder.blocks.11.attn.value.bias', 'encoder.blocks.11.attn.out.weight', 'encoder.blocks.11.attn.out.bias', 'encoder.ln_post.weight', 'encoder.ln_post.bias']
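
To reproduce this inspection (a minimal sketch; the path assumes the default cache location):

import os
import torch

path = os.path.expanduser("~/.cache/whisper/small.pt")
checkpoint = torch.load(path, map_location="cpu")

print(type(checkpoint))         # <class 'dict'>
print(list(checkpoint.keys()))  # ['dims', 'model_state_dict']
print(checkpoint["dims"])
# Every value in model_state_dict is a weight tensor; check one shape:
print(checkpoint["model_state_dict"]["decoder.token_embedding.weight"].shape)
# torch.Size([51865, 768]) for the small model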

Creating a ModelDimensions instance from the checkpoint

The ModelDimensions class

@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int

dims = ModelDimensions(**checkpoint["dims"])
print(dims)

ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51865, n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12)
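
Some of these numbers follow from Whisper's fixed audio preprocessing. As a sanity check (constants taken from whisper/audio.py):

SAMPLE_RATE = 16000   # whisper.audio.SAMPLE_RATE
HOP_LENGTH = 160      # whisper.audio.HOP_LENGTH
CHUNK_LENGTH = 30     # whisper.audio.CHUNK_LENGTH, in seconds

# 30 s of 16 kHz audio -> 3000 mel frames; the encoder's conv2 has
# stride 2, halving that to the 1500 positions seen in n_audio_ctx.
n_frames = CHUNK_LENGTH * SAMPLE_RATE // HOP_LENGTH
print(n_frames)       # 3000
print(n_frames // 2)  # 1500 == n_audio_ctx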

Creating the Whisper instance

The Whisper class

class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function

model = Whisper(dims)
print(model)

Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (mlp_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TextDecoder(
    (token_embedding): Embedding(51865, 768)
    (blocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (cross_attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (cross_attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (mlp_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
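
As a rough cross-check against the 483.6 MB file size (the released checkpoints store fp16 weights, i.e. 2 bytes per parameter):

# Continuing from the model created above.
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")     # roughly 240M for small
print(f"{n_params * 2 / 1e6:.1f} MB as fp16")  # close to the file size above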
