Setting up a whisper debugging environment
Check out the whisper source code
git clone https://github.com/openai/whisper.git
cd whisper
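If the dependencies are not installed yet, an editable install of the checkout is one option (assuming the standard pip workflow; adjust to your environment):

pip install -e .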
Create test.py
import sys
from whisper.transcribe import cli
if __name__ == '__main__':
    sys.exit(cli())
Run test.py
python test.py sample.mp4 --language English
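To step through the code interactively, one option is to run the same script under the standard-library debugger (any IDE debugger works just as well):

python -m pdb test.py sample.mp4 --language English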
transcribe.py
def cli():
    ...
    model = load_model(model_name, device=device, download_root=model_dir)
    ...
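To stop execution right at this call during a debug run, a minimal sketch is to drop a breakpoint() just above it (a temporary local edit, not part of the upstream source):

def cli():
    ...
    breakpoint()  # drops into pdb just before the model is loaded
    model = load_model(model_name, device=device, download_root=model_dir)
    ...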
Reviewing the source
Reading the load_model function
When run with the defaults, the arguments are model_name=small, device=cpu, and download_root=None.
def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)
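As a quick check of the default download-path logic above, the resolution can be reproduced on its own (output shown for a typical Linux home directory; an assumption):

import os

default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
print(download_root)  # e.g. /home/<user>/.cache/whisper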
checkpoint_file resolves to ~/.cache/whisper/small.pt
The small.pt file is 483.6 MB.
If memory allows, pass in_memory=True to load_model to speed up execution (a sketch follows).
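A minimal sketch of the in_memory option (assumes roughly 500 MB of spare RAM for the small model):

import whisper

# The whole small.pt file is read into host memory before torch.load parses it.
model = whisper.load_model("small", device="cpu", in_memory=True)
result = model.transcribe("sample.mp4", language="English")
print(result["text"])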
Inspecting the model checkpoint
The checkpoint stores the model's parameter values.
checkpoint is a dict with two keys: dims and model_state_dict.
checkpoint['dims'] is also a dict, with the following keys:
['n_mels', 'n_vocab', 'n_audio_ctx', 'n_audio_state', 'n_audio_head', 'n_audio_layer', 'n_text_ctx', 'n_text_state', 'n_text_head', 'n_text_layer']
checkpoint['model_state_dict'] is also a dict, with the following keys (a snippet to reproduce this listing follows the key list):
['decoder.positional_embedding', 'encoder.positional_embedding', 'decoder.token_embedding.weight', 'decoder.blocks.0.mlp_ln.weight', 'decoder.blocks.0.mlp_ln.bias', 'decoder.blocks.0.mlp.0.weight', 'decoder.blocks.0.mlp.0.bias', 'decoder.blocks.0.mlp.2.weight', 'decoder.blocks.0.mlp.2.bias', 'decoder.blocks.0.attn_ln.weight', 'decoder.blocks.0.attn_ln.bias', 'decoder.blocks.0.attn.query.weight', 'decoder.blocks.0.attn.query.bias', 'decoder.blocks.0.attn.key.weight', 'decoder.blocks.0.attn.value.weight', 'decoder.blocks.0.attn.value.bias', 'decoder.blocks.0.attn.out.weight', 'decoder.blocks.0.attn.out.bias', 'decoder.blocks.0.cross_attn_ln.weight', 'decoder.blocks.0.cross_attn_ln.bias', 'decoder.blocks.0.cross_attn.query.weight', 'decoder.blocks.0.cross_attn.query.bias', 'decoder.blocks.0.cross_attn.key.weight', 'decoder.blocks.0.cross_attn.value.weight', 'decoder.blocks.0.cross_attn.value.bias', 'decoder.blocks.0.cross_attn.out.weight', 'decoder.blocks.0.cross_attn.out.bias', 'decoder.blocks.1.mlp_ln.weight', 'decoder.blocks.1.mlp_ln.bias', 'decoder.blocks.1.mlp.0.weight', 'decoder.blocks.1.mlp.0.bias', 'decoder.blocks.1.mlp.2.weight', 'decoder.blocks.1.mlp.2.bias', 'decoder.blocks.1.attn_ln.weight', 'decoder.blocks.1.attn_ln.bias', 'decoder.blocks.1.attn.query.weight', 'decoder.blocks.1.attn.query.bias', 'decoder.blocks.1.attn.key.weight', 'decoder.blocks.1.attn.value.weight', 'decoder.blocks.1.attn.value.bias', 'decoder.blocks.1.attn.out.weight', 'decoder.blocks.1.attn.out.bias', 'decoder.blocks.1.cross_attn_ln.weight', 'decoder.blocks.1.cross_attn_ln.bias', 'decoder.blocks.1.cross_attn.query.weight', 'decoder.blocks.1.cross_attn.query.bias', 'decoder.blocks.1.cross_attn.key.weight', 'decoder.blocks.1.cross_attn.value.weight', 'decoder.blocks.1.cross_attn.value.bias', 'decoder.blocks.1.cross_attn.out.weight', 'decoder.blocks.1.cross_attn.out.bias', 'decoder.blocks.2.mlp_ln.weight', 'decoder.blocks.2.mlp_ln.bias', 'decoder.blocks.2.mlp.0.weight', 'decoder.blocks.2.mlp.0.bias', 'decoder.blocks.2.mlp.2.weight', 'decoder.blocks.2.mlp.2.bias', 'decoder.blocks.2.attn_ln.weight', 'decoder.blocks.2.attn_ln.bias', 'decoder.blocks.2.attn.query.weight', 'decoder.blocks.2.attn.query.bias', 'decoder.blocks.2.attn.key.weight', 'decoder.blocks.2.attn.value.weight', 'decoder.blocks.2.attn.value.bias', 'decoder.blocks.2.attn.out.weight', 'decoder.blocks.2.attn.out.bias', 'decoder.blocks.2.cross_attn_ln.weight', 'decoder.blocks.2.cross_attn_ln.bias', 'decoder.blocks.2.cross_attn.query.weight', 'decoder.blocks.2.cross_attn.query.bias', 'decoder.blocks.2.cross_attn.key.weight', 'decoder.blocks.2.cross_attn.value.weight', 'decoder.blocks.2.cross_attn.value.bias', 'decoder.blocks.2.cross_attn.out.weight', 'decoder.blocks.2.cross_attn.out.bias', 'decoder.blocks.3.mlp_ln.weight', 'decoder.blocks.3.mlp_ln.bias', 'decoder.blocks.3.mlp.0.weight', 'decoder.blocks.3.mlp.0.bias', 'decoder.blocks.3.mlp.2.weight', 'decoder.blocks.3.mlp.2.bias', 'decoder.blocks.3.attn_ln.weight', 'decoder.blocks.3.attn_ln.bias', 'decoder.blocks.3.attn.query.weight', 'decoder.blocks.3.attn.query.bias', 'decoder.blocks.3.attn.key.weight', 'decoder.blocks.3.attn.value.weight', 'decoder.blocks.3.attn.value.bias', 'decoder.blocks.3.attn.out.weight', 'decoder.blocks.3.attn.out.bias', 'decoder.blocks.3.cross_attn_ln.weight', 'decoder.blocks.3.cross_attn_ln.bias', 'decoder.blocks.3.cross_attn.query.weight', 'decoder.blocks.3.cross_attn.query.bias', 'decoder.blocks.3.cross_attn.key.weight', 'decoder.blocks.3.cross_attn.value.weight', 
'decoder.blocks.3.cross_attn.value.bias', 'decoder.blocks.3.cross_attn.out.weight', 'decoder.blocks.3.cross_attn.out.bias', 'decoder.blocks.4.mlp_ln.weight', 'decoder.blocks.4.mlp_ln.bias', 'decoder.blocks.4.mlp.0.weight', 'decoder.blocks.4.mlp.0.bias', 'decoder.blocks.4.mlp.2.weight', 'decoder.blocks.4.mlp.2.bias', 'decoder.blocks.4.attn_ln.weight', 'decoder.blocks.4.attn_ln.bias', 'decoder.blocks.4.attn.query.weight', 'decoder.blocks.4.attn.query.bias', 'decoder.blocks.4.attn.key.weight', 'decoder.blocks.4.attn.value.weight', 'decoder.blocks.4.attn.value.bias', 'decoder.blocks.4.attn.out.weight', 'decoder.blocks.4.attn.out.bias', 'decoder.blocks.4.cross_attn_ln.weight', 'decoder.blocks.4.cross_attn_ln.bias', 'decoder.blocks.4.cross_attn.query.weight', 'decoder.blocks.4.cross_attn.query.bias', 'decoder.blocks.4.cross_attn.key.weight', 'decoder.blocks.4.cross_attn.value.weight', 'decoder.blocks.4.cross_attn.value.bias', 'decoder.blocks.4.cross_attn.out.weight', 'decoder.blocks.4.cross_attn.out.bias', 'decoder.blocks.5.mlp_ln.weight', 'decoder.blocks.5.mlp_ln.bias', 'decoder.blocks.5.mlp.0.weight', 'decoder.blocks.5.mlp.0.bias', 'decoder.blocks.5.mlp.2.weight', 'decoder.blocks.5.mlp.2.bias', 'decoder.blocks.5.attn_ln.weight', 'decoder.blocks.5.attn_ln.bias', 'decoder.blocks.5.attn.query.weight', 'decoder.blocks.5.attn.query.bias', 'decoder.blocks.5.attn.key.weight', 'decoder.blocks.5.attn.value.weight', 'decoder.blocks.5.attn.value.bias', 'decoder.blocks.5.attn.out.weight', 'decoder.blocks.5.attn.out.bias', 'decoder.blocks.5.cross_attn_ln.weight', 'decoder.blocks.5.cross_attn_ln.bias', 'decoder.blocks.5.cross_attn.query.weight', 'decoder.blocks.5.cross_attn.query.bias', 'decoder.blocks.5.cross_attn.key.weight', 'decoder.blocks.5.cross_attn.value.weight', 'decoder.blocks.5.cross_attn.value.bias', 'decoder.blocks.5.cross_attn.out.weight', 'decoder.blocks.5.cross_attn.out.bias', 'decoder.blocks.6.mlp_ln.weight', 'decoder.blocks.6.mlp_ln.bias', 'decoder.blocks.6.mlp.0.weight', 'decoder.blocks.6.mlp.0.bias', 'decoder.blocks.6.mlp.2.weight', 'decoder.blocks.6.mlp.2.bias', 'decoder.blocks.6.attn_ln.weight', 'decoder.blocks.6.attn_ln.bias', 'decoder.blocks.6.attn.query.weight', 'decoder.blocks.6.attn.query.bias', 'decoder.blocks.6.attn.key.weight', 'decoder.blocks.6.attn.value.weight', 'decoder.blocks.6.attn.value.bias', 'decoder.blocks.6.attn.out.weight', 'decoder.blocks.6.attn.out.bias', 'decoder.blocks.6.cross_attn_ln.weight', 'decoder.blocks.6.cross_attn_ln.bias', 'decoder.blocks.6.cross_attn.query.weight', 'decoder.blocks.6.cross_attn.query.bias', 'decoder.blocks.6.cross_attn.key.weight', 'decoder.blocks.6.cross_attn.value.weight', 'decoder.blocks.6.cross_attn.value.bias', 'decoder.blocks.6.cross_attn.out.weight', 'decoder.blocks.6.cross_attn.out.bias', 'decoder.blocks.7.mlp_ln.weight', 'decoder.blocks.7.mlp_ln.bias', 'decoder.blocks.7.mlp.0.weight', 'decoder.blocks.7.mlp.0.bias', 'decoder.blocks.7.mlp.2.weight', 'decoder.blocks.7.mlp.2.bias', 'decoder.blocks.7.attn_ln.weight', 'decoder.blocks.7.attn_ln.bias', 'decoder.blocks.7.attn.query.weight', 'decoder.blocks.7.attn.query.bias', 'decoder.blocks.7.attn.key.weight', 'decoder.blocks.7.attn.value.weight', 'decoder.blocks.7.attn.value.bias', 'decoder.blocks.7.attn.out.weight', 'decoder.blocks.7.attn.out.bias', 'decoder.blocks.7.cross_attn_ln.weight', 'decoder.blocks.7.cross_attn_ln.bias', 'decoder.blocks.7.cross_attn.query.weight', 'decoder.blocks.7.cross_attn.query.bias', 'decoder.blocks.7.cross_attn.key.weight', 
'decoder.blocks.7.cross_attn.value.weight', 'decoder.blocks.7.cross_attn.value.bias', 'decoder.blocks.7.cross_attn.out.weight', 'decoder.blocks.7.cross_attn.out.bias', 'decoder.blocks.8.mlp_ln.weight', 'decoder.blocks.8.mlp_ln.bias', 'decoder.blocks.8.mlp.0.weight', 'decoder.blocks.8.mlp.0.bias', 'decoder.blocks.8.mlp.2.weight', 'decoder.blocks.8.mlp.2.bias', 'decoder.blocks.8.attn_ln.weight', 'decoder.blocks.8.attn_ln.bias', 'decoder.blocks.8.attn.query.weight', 'decoder.blocks.8.attn.query.bias', 'decoder.blocks.8.attn.key.weight', 'decoder.blocks.8.attn.value.weight', 'decoder.blocks.8.attn.value.bias', 'decoder.blocks.8.attn.out.weight', 'decoder.blocks.8.attn.out.bias', 'decoder.blocks.8.cross_attn_ln.weight', 'decoder.blocks.8.cross_attn_ln.bias', 'decoder.blocks.8.cross_attn.query.weight', 'decoder.blocks.8.cross_attn.query.bias', 'decoder.blocks.8.cross_attn.key.weight', 'decoder.blocks.8.cross_attn.value.weight', 'decoder.blocks.8.cross_attn.value.bias', 'decoder.blocks.8.cross_attn.out.weight', 'decoder.blocks.8.cross_attn.out.bias', 'decoder.blocks.9.mlp_ln.weight', 'decoder.blocks.9.mlp_ln.bias', 'decoder.blocks.9.mlp.0.weight', 'decoder.blocks.9.mlp.0.bias', 'decoder.blocks.9.mlp.2.weight', 'decoder.blocks.9.mlp.2.bias', 'decoder.blocks.9.attn_ln.weight', 'decoder.blocks.9.attn_ln.bias', 'decoder.blocks.9.attn.query.weight', 'decoder.blocks.9.attn.query.bias', 'decoder.blocks.9.attn.key.weight', 'decoder.blocks.9.attn.value.weight', 'decoder.blocks.9.attn.value.bias', 'decoder.blocks.9.attn.out.weight', 'decoder.blocks.9.attn.out.bias', 'decoder.blocks.9.cross_attn_ln.weight', 'decoder.blocks.9.cross_attn_ln.bias', 'decoder.blocks.9.cross_attn.query.weight', 'decoder.blocks.9.cross_attn.query.bias', 'decoder.blocks.9.cross_attn.key.weight', 'decoder.blocks.9.cross_attn.value.weight', 'decoder.blocks.9.cross_attn.value.bias', 'decoder.blocks.9.cross_attn.out.weight', 'decoder.blocks.9.cross_attn.out.bias', 'decoder.blocks.10.mlp_ln.weight', 'decoder.blocks.10.mlp_ln.bias', 'decoder.blocks.10.mlp.0.weight', 'decoder.blocks.10.mlp.0.bias', 'decoder.blocks.10.mlp.2.weight', 'decoder.blocks.10.mlp.2.bias', 'decoder.blocks.10.attn_ln.weight', 'decoder.blocks.10.attn_ln.bias', 'decoder.blocks.10.attn.query.weight', 'decoder.blocks.10.attn.query.bias', 'decoder.blocks.10.attn.key.weight', 'decoder.blocks.10.attn.value.weight', 'decoder.blocks.10.attn.value.bias', 'decoder.blocks.10.attn.out.weight', 'decoder.blocks.10.attn.out.bias', 'decoder.blocks.10.cross_attn_ln.weight', 'decoder.blocks.10.cross_attn_ln.bias', 'decoder.blocks.10.cross_attn.query.weight', 'decoder.blocks.10.cross_attn.query.bias', 'decoder.blocks.10.cross_attn.key.weight', 'decoder.blocks.10.cross_attn.value.weight', 'decoder.blocks.10.cross_attn.value.bias', 'decoder.blocks.10.cross_attn.out.weight', 'decoder.blocks.10.cross_attn.out.bias', 'decoder.blocks.11.mlp_ln.weight', 'decoder.blocks.11.mlp_ln.bias', 'decoder.blocks.11.mlp.0.weight', 'decoder.blocks.11.mlp.0.bias', 'decoder.blocks.11.mlp.2.weight', 'decoder.blocks.11.mlp.2.bias', 'decoder.blocks.11.attn_ln.weight', 'decoder.blocks.11.attn_ln.bias', 'decoder.blocks.11.attn.query.weight', 'decoder.blocks.11.attn.query.bias', 'decoder.blocks.11.attn.key.weight', 'decoder.blocks.11.attn.value.weight', 'decoder.blocks.11.attn.value.bias', 'decoder.blocks.11.attn.out.weight', 'decoder.blocks.11.attn.out.bias', 'decoder.blocks.11.cross_attn_ln.weight', 'decoder.blocks.11.cross_attn_ln.bias', 'decoder.blocks.11.cross_attn.query.weight', 
'decoder.blocks.11.cross_attn.query.bias', 'decoder.blocks.11.cross_attn.key.weight', 'decoder.blocks.11.cross_attn.value.weight', 'decoder.blocks.11.cross_attn.value.bias', 'decoder.blocks.11.cross_attn.out.weight', 'decoder.blocks.11.cross_attn.out.bias', 'decoder.ln.weight', 'decoder.ln.bias', 'encoder.conv1.weight', 'encoder.conv1.bias', 'encoder.conv2.weight', 'encoder.conv2.bias', 'encoder.blocks.0.mlp_ln.weight', 'encoder.blocks.0.mlp_ln.bias', 'encoder.blocks.0.mlp.0.weight', 'encoder.blocks.0.mlp.0.bias', 'encoder.blocks.0.mlp.2.weight', 'encoder.blocks.0.mlp.2.bias', 'encoder.blocks.0.attn_ln.weight', 'encoder.blocks.0.attn_ln.bias', 'encoder.blocks.0.attn.query.weight', 'encoder.blocks.0.attn.query.bias', 'encoder.blocks.0.attn.key.weight', 'encoder.blocks.0.attn.value.weight', 'encoder.blocks.0.attn.value.bias', 'encoder.blocks.0.attn.out.weight', 'encoder.blocks.0.attn.out.bias', 'encoder.blocks.1.mlp_ln.weight', 'encoder.blocks.1.mlp_ln.bias', 'encoder.blocks.1.mlp.0.weight', 'encoder.blocks.1.mlp.0.bias', 'encoder.blocks.1.mlp.2.weight', 'encoder.blocks.1.mlp.2.bias', 'encoder.blocks.1.attn_ln.weight', 'encoder.blocks.1.attn_ln.bias', 'encoder.blocks.1.attn.query.weight', 'encoder.blocks.1.attn.query.bias', 'encoder.blocks.1.attn.key.weight', 'encoder.blocks.1.attn.value.weight', 'encoder.blocks.1.attn.value.bias', 'encoder.blocks.1.attn.out.weight', 'encoder.blocks.1.attn.out.bias', 'encoder.blocks.2.mlp_ln.weight', 'encoder.blocks.2.mlp_ln.bias', 'encoder.blocks.2.mlp.0.weight', 'encoder.blocks.2.mlp.0.bias', 'encoder.blocks.2.mlp.2.weight', 'encoder.blocks.2.mlp.2.bias', 'encoder.blocks.2.attn_ln.weight', 'encoder.blocks.2.attn_ln.bias', 'encoder.blocks.2.attn.query.weight', 'encoder.blocks.2.attn.query.bias', 'encoder.blocks.2.attn.key.weight', 'encoder.blocks.2.attn.value.weight', 'encoder.blocks.2.attn.value.bias', 'encoder.blocks.2.attn.out.weight', 'encoder.blocks.2.attn.out.bias', 'encoder.blocks.3.mlp_ln.weight', 'encoder.blocks.3.mlp_ln.bias', 'encoder.blocks.3.mlp.0.weight', 'encoder.blocks.3.mlp.0.bias', 'encoder.blocks.3.mlp.2.weight', 'encoder.blocks.3.mlp.2.bias', 'encoder.blocks.3.attn_ln.weight', 'encoder.blocks.3.attn_ln.bias', 'encoder.blocks.3.attn.query.weight', 'encoder.blocks.3.attn.query.bias', 'encoder.blocks.3.attn.key.weight', 'encoder.blocks.3.attn.value.weight', 'encoder.blocks.3.attn.value.bias', 'encoder.blocks.3.attn.out.weight', 'encoder.blocks.3.attn.out.bias', 'encoder.blocks.4.mlp_ln.weight', 'encoder.blocks.4.mlp_ln.bias', 'encoder.blocks.4.mlp.0.weight', 'encoder.blocks.4.mlp.0.bias', 'encoder.blocks.4.mlp.2.weight', 'encoder.blocks.4.mlp.2.bias', 'encoder.blocks.4.attn_ln.weight', 'encoder.blocks.4.attn_ln.bias', 'encoder.blocks.4.attn.query.weight', 'encoder.blocks.4.attn.query.bias', 'encoder.blocks.4.attn.key.weight', 'encoder.blocks.4.attn.value.weight', 'encoder.blocks.4.attn.value.bias', 'encoder.blocks.4.attn.out.weight', 'encoder.blocks.4.attn.out.bias', 'encoder.blocks.5.mlp_ln.weight', 'encoder.blocks.5.mlp_ln.bias', 'encoder.blocks.5.mlp.0.weight', 'encoder.blocks.5.mlp.0.bias', 'encoder.blocks.5.mlp.2.weight', 'encoder.blocks.5.mlp.2.bias', 'encoder.blocks.5.attn_ln.weight', 'encoder.blocks.5.attn_ln.bias', 'encoder.blocks.5.attn.query.weight', 'encoder.blocks.5.attn.query.bias', 'encoder.blocks.5.attn.key.weight', 'encoder.blocks.5.attn.value.weight', 'encoder.blocks.5.attn.value.bias', 'encoder.blocks.5.attn.out.weight', 'encoder.blocks.5.attn.out.bias', 'encoder.blocks.6.mlp_ln.weight', 'encoder.blocks.6.mlp_ln.bias', 
'encoder.blocks.6.mlp.0.weight', 'encoder.blocks.6.mlp.0.bias', 'encoder.blocks.6.mlp.2.weight', 'encoder.blocks.6.mlp.2.bias', 'encoder.blocks.6.attn_ln.weight', 'encoder.blocks.6.attn_ln.bias', 'encoder.blocks.6.attn.query.weight', 'encoder.blocks.6.attn.query.bias', 'encoder.blocks.6.attn.key.weight', 'encoder.blocks.6.attn.value.weight', 'encoder.blocks.6.attn.value.bias', 'encoder.blocks.6.attn.out.weight', 'encoder.blocks.6.attn.out.bias', 'encoder.blocks.7.mlp_ln.weight', 'encoder.blocks.7.mlp_ln.bias', 'encoder.blocks.7.mlp.0.weight', 'encoder.blocks.7.mlp.0.bias', 'encoder.blocks.7.mlp.2.weight', 'encoder.blocks.7.mlp.2.bias', 'encoder.blocks.7.attn_ln.weight', 'encoder.blocks.7.attn_ln.bias', 'encoder.blocks.7.attn.query.weight', 'encoder.blocks.7.attn.query.bias', 'encoder.blocks.7.attn.key.weight', 'encoder.blocks.7.attn.value.weight', 'encoder.blocks.7.attn.value.bias', 'encoder.blocks.7.attn.out.weight', 'encoder.blocks.7.attn.out.bias', 'encoder.blocks.8.mlp_ln.weight', 'encoder.blocks.8.mlp_ln.bias', 'encoder.blocks.8.mlp.0.weight', 'encoder.blocks.8.mlp.0.bias', 'encoder.blocks.8.mlp.2.weight', 'encoder.blocks.8.mlp.2.bias', 'encoder.blocks.8.attn_ln.weight', 'encoder.blocks.8.attn_ln.bias', 'encoder.blocks.8.attn.query.weight', 'encoder.blocks.8.attn.query.bias', 'encoder.blocks.8.attn.key.weight', 'encoder.blocks.8.attn.value.weight', 'encoder.blocks.8.attn.value.bias', 'encoder.blocks.8.attn.out.weight', 'encoder.blocks.8.attn.out.bias', 'encoder.blocks.9.mlp_ln.weight', 'encoder.blocks.9.mlp_ln.bias', 'encoder.blocks.9.mlp.0.weight', 'encoder.blocks.9.mlp.0.bias', 'encoder.blocks.9.mlp.2.weight', 'encoder.blocks.9.mlp.2.bias', 'encoder.blocks.9.attn_ln.weight', 'encoder.blocks.9.attn_ln.bias', 'encoder.blocks.9.attn.query.weight', 'encoder.blocks.9.attn.query.bias', 'encoder.blocks.9.attn.key.weight', 'encoder.blocks.9.attn.value.weight', 'encoder.blocks.9.attn.value.bias', 'encoder.blocks.9.attn.out.weight', 'encoder.blocks.9.attn.out.bias', 'encoder.blocks.10.mlp_ln.weight', 'encoder.blocks.10.mlp_ln.bias', 'encoder.blocks.10.mlp.0.weight', 'encoder.blocks.10.mlp.0.bias', 'encoder.blocks.10.mlp.2.weight', 'encoder.blocks.10.mlp.2.bias', 'encoder.blocks.10.attn_ln.weight', 'encoder.blocks.10.attn_ln.bias', 'encoder.blocks.10.attn.query.weight', 'encoder.blocks.10.attn.query.bias', 'encoder.blocks.10.attn.key.weight', 'encoder.blocks.10.attn.value.weight', 'encoder.blocks.10.attn.value.bias', 'encoder.blocks.10.attn.out.weight', 'encoder.blocks.10.attn.out.bias', 'encoder.blocks.11.mlp_ln.weight', 'encoder.blocks.11.mlp_ln.bias', 'encoder.blocks.11.mlp.0.weight', 'encoder.blocks.11.mlp.0.bias', 'encoder.blocks.11.mlp.2.weight', 'encoder.blocks.11.mlp.2.bias', 'encoder.blocks.11.attn_ln.weight', 'encoder.blocks.11.attn_ln.bias', 'encoder.blocks.11.attn.query.weight', 'encoder.blocks.11.attn.query.bias', 'encoder.blocks.11.attn.key.weight', 'encoder.blocks.11.attn.value.weight', 'encoder.blocks.11.attn.value.bias', 'encoder.blocks.11.attn.out.weight', 'encoder.blocks.11.attn.out.bias', 'encoder.ln_post.weight', 'encoder.ln_post.bias']
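The listings above can be reproduced with a short script (assumes small.pt was already downloaded by an earlier run):

import os
import torch

ckpt_path = os.path.expanduser("~/.cache/whisper/small.pt")
checkpoint = torch.load(ckpt_path, map_location="cpu")

print(type(checkpoint))                     # <class 'dict'>
print(list(checkpoint.keys()))              # ['dims', 'model_state_dict']
print(list(checkpoint["dims"].keys()))      # the ten dimension fields
print(len(checkpoint["model_state_dict"]))  # number of parameter tensors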
Creating a ModelDimensions instance from the checkpoint
The ModelDimensions class
@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int
dims = ModelDimensions(**checkpoint["dims"])
print(dims)
ModelDimensions(n_mels=80, n_audio_ctx=1500, n_audio_state=768, n_audio_head=12, n_audio_layer=12, n_vocab=51865, n_text_ctx=448, n_text_state=768, n_text_head=12, n_text_layer=12)
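These dimensions line up with the file size noted earlier. A rough check (reusing checkpoint from the inspection snippet; the small model is commonly cited as ~244M parameters):

n_params = sum(t.numel() for t in checkpoint["model_state_dict"].values())
print(f"{n_params / 1e6:.1f}M parameters")
# ~244M parameters; at 2 bytes each (fp16) that is roughly the 483.6 MB observed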
Creating the Whisper instance
The Whisper class
class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
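For the small multilingual model loaded here, the two vocabulary-based properties work out as follows (a worked example using n_vocab=51865 from dims above):

# is_multilingual: 51865 >= 51865          -> True
# num_languages:   51865 - 51765 - 1 = 99  -> 99 supported languages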
model = Whisper(dims)
print(model)
Whisper(
  (encoder): AudioEncoder(
    (conv1): Conv1d(80, 768, kernel_size=(3,), stride=(1,), padding=(1,))
    (conv2): Conv1d(768, 768, kernel_size=(3,), stride=(2,), padding=(1,))
    (blocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (mlp_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TextDecoder(
    (token_embedding): Embedding(51865, 768)
    (blocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (cross_attn): MultiHeadAttention(
          (query): Linear(in_features=768, out_features=768, bias=True)
          (key): Linear(in_features=768, out_features=768, bias=False)
          (value): Linear(in_features=768, out_features=768, bias=True)
          (out): Linear(in_features=768, out_features=768, bias=True)
        )
        (cross_attn_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
        (mlp_ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
    (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
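As a final sanity check that the constructed graph matches the printed structure, a shape-only smoke test (a sketch; 50258 is assumed to be the <|startoftranscript|> token of the multilingual vocabulary, and 3000 mel frames correspond to 30 s of audio):

import torch

mel = torch.zeros(1, 80, 3000)    # (batch, n_mels, frames)
tokens = torch.tensor([[50258]])  # a single start-of-transcript token
with torch.no_grad():
    audio_features = model.embed_audio(mel)       # -> (1, 1500, 768)
    logits = model.logits(tokens, audio_features)
print(logits.shape)               # torch.Size([1, 1, 51865])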