OpenAIが提供する音声認識オープンソースWhisperとは(3)

AI

transcribe関数の確認

def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    clip_timestamps: Union[str, List[float]] = "0",
    hallucination_silence_threshold: Optional[float] = None,
    **decode_options,
):

    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES
    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(
        model.is_multilingual,
        num_languages=model.num_languages,
        language=language,
        task=task,
    )

    if isinstance(clip_timestamps, str):
        clip_timestamps = [
            float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
        ]
    seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
    if len(seek_points) == 0:
        seek_points.append(0)
    if len(seek_points) % 2 == 1:
        seek_points.append(content_frames)
    seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))

    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"

    if word_timestamps and task == "translate":
        warnings.warn("Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low
            if (
                no_speech_threshold is not None
                and decode_result.no_speech_prob > no_speech_threshold
            ):
                needs_fallback = False  # silence
            if not needs_fallback:
                break

        return decode_result

    clip_idx = 0
    seek = seek_clips[clip_idx][0]
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
    else:
        initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
        total=content_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        last_speech_timestamp = 0.0
        # NOTE: This loop is obscurely flattened to make the diff readable.
        # A later commit should turn this into a simpler nested loop.
        # for seek_clip_start, seek_clip_end in seek_clips:
        #     while seek < seek_clip_end
        while clip_idx < len(seek_clips):
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            if seek < seek_clip_start:
                seek = seek_clip_start
            if seek >= seek_clip_end:
                clip_idx += 1
                if clip_idx < len(seek_clips):
                    seek = seek_clips[clip_idx][0]
                continue
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
            mel_segment = mel[:, seek : seek + segment_size]
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)

            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []

            # anomalous words are very long/short/improbable
            def word_anomaly_score(word: dict) -> float:
                probability = word.get("probability", 0.0)
                duration = word["end"] - word["start"]
                score = 0.0
                if probability < 0.15:
                    score += 1.0
                if duration < 0.133:
                    score += (0.133 - duration) * 15
                if duration > 2.0:
                    score += duration - 2.0
                return score

            def is_segment_anomaly(segment: Optional[dict]) -> bool:
                if segment is None or not segment["words"]:
                    return False
                words = [w for w in segment["words"] if w["word"] not in punctuation]
                words = words[:8]
                score = sum(word_anomaly_score(w) for w in words)
                return score >= 3 or score + 0.01 >= len(words)

            def next_words_segment(segments: List[dict]) -> Optional[dict]:
                return next((s for s in segments if s["words"]), None)

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
                            end=time_offset + end_timestamp_pos * time_precision,
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result,
                    )
                )
                seek += segment_size

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    last_speech_timestamp=last_speech_timestamp,
                )

                if not single_timestamp_ending:
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset:
                        seek = round(last_word_end * FRAMES_PER_SECOND)

                # skip silence before possible hallucinations
                if hallucination_silence_threshold is not None:
                    threshold = hallucination_silence_threshold
                    if not single_timestamp_ending:
                        last_word_end = get_end(current_segments)
                        if last_word_end is not None and last_word_end > time_offset:
                            remaining_duration = window_end_time - last_word_end
                            if remaining_duration > threshold:
                                seek = round(last_word_end * FRAMES_PER_SECOND)
                            else:
                                seek = previous_seek + segment_size

                    # if first segment might be a hallucination, skip leading silence
                    first_segment = next_words_segment(current_segments)
                    if first_segment is not None and is_segment_anomaly(first_segment):
                        gap = first_segment["start"] - time_offset
                        if gap > threshold:
                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                            continue

                    # skip silence before any possible hallucination that is surrounded
                    # by silence or more hallucinations
                    hal_last_end = last_speech_timestamp
                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]:
                            continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(
                                current_segments[si + 1 :]
                            )
                            if next_segment is not None:
                                hal_next_start = next_segment["words"][0]["start"]
                            else:
                                hal_next_start = time_offset + segment_duration
                            silence_before = (
                                segment["start"] - hal_last_end > threshold
                                or segment["start"] < threshold
                                or segment["start"] - time_offset < 2.0
                            )
                            silence_after = (
                                hal_next_start - segment["end"] > threshold
                                or is_segment_anomaly(next_segment)
                                or window_end_time - segment["end"] < 2.0
                            )
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"])
                                    * FRAMES_PER_SECOND
                                )
                                if content_duration - segment["end"] < threshold:
                                    seek = content_frames
                                current_segments[si:] = []
                                break
                        hal_last_end = segment["end"]

                last_word_end = get_end(current_segments)
                if last_word_end is not None:
                    last_speech_timestamp = last_word_end

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
        segments=all_segments,
        language=language,
    )

この関数は、Whisperモデルを使用してオーディオファイルを転記するためのものです。以下は、この関数の主なパラメータと機能の概要です。

  • model: Whisperモデルのインスタンスです。
  • audio: 転記するオーディオファイルのパス、またはオーディオ波形です。
  • verbose: 出力の詳細さを制御するブール値です。
  • temperature: サンプリングの温度を制御します。
  • compression_ratio_threshold: 圧縮率のしきい値です。
  • logprob_threshold: ログ確率のしきい値です。
  • no_speech_threshold: 音声がないとみなされる確率のしきい値です。
  • condition_on_previous_text: 前のテキストに基づいて次のテキストを生成するかどうかを制御します。
  • initial_prompt: 最初のウィンドウのプロンプトとして使用するオプションのテキストです。
  • word_timestamps: 単語レベルのタイムスタンプを含めるかどうかを制御します。
  • prepend_punctuations: 単語レベルのタイムスタンプに先行する句読点です。
  • append_punctuations: 単語レベルのタイムスタンプに後続する句読点です。
  • clip_timestamps: 処理するクリップのタイムスタンプです。
  • hallucination_silence_threshold: 幻覚とみなされる場合にスキップする無音のしきい値です。
  • decode_options: デコードオプションの追加引数です。

関数内の処理の概要は以下の通りです:

  1. 引数の検証と前処理を行います。
  2. オーディオファイルを30秒ごとのセグメントに分割します。
  3. 各セグメントに対して、Whisperモデルを使用してテキストをデコードします。
  4. テキストのデコード結果と、各セグメントの情報を収集します。
  5. 最終的なテキストとセグメント情報を辞書として返します。

この関数は、指定されたオーディオファイルの転写を行う際に、Whisperモデルを使用し、オプションのパラメータに従ってテキストのデコードを行います。

log_mel_spectrogram関数定義

def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec

この関数は、与えられたオーディオデータから対数メルスペクトログラムを計算するためのものです。対数メルスペクトログラムは音声データの周波数成分を表す特徴量であり、音声処理や音声認識などのタスクで広く使用されます。

関数の主な引数と処理は以下の通りです:

  • audio: オーディオデータを表す引数で、文字列(ファイルパス)、NumPy配列、またはPyTorchテンソルのいずれかを受け取ります。
  • n_mels: Melフィルターの数を指定します。通常は80が使用されます。
  • padding: ゼロパディングするサンプル数を指定します。ゼロパディングは、オーディオデータの末尾に追加されます。
  • device: デバイス(CPUまたはGPU)を指定します。指定された場合、オーディオデータがそのデバイスに移動されます。

関数の主な処理は以下の通りです:

  1. 入力のオーディオデータがPyTorchテンソルでない場合は、適切な形式に変換します。
  2. オーディオデータが指定されたデバイスに移動されます。
  3. ゼロパディングが適用されます。
  4. ハニング窓を用いて短時間フーリエ変換(STFT)が行われます。
  5. スペクトログラムを計算するためのメルフィルタバンクが適用されます。
  6. 対数が取られ、対数メルスペクトログラムが得られます。
  7. 最終的に、対数メルスペクトログラムが返されます。

この関数は、音声データから抽出された特徴量である対数メルスペクトログラムを返すことで、後続の音声処理や機械学習モデルの入力として使用されることが想定されています。

torch.from_numpy関数とは?

torch.from_numpy(audio)は、NumPy配列をPyTorchのテンソルに変換するための関数です。具体的には、与えられたNumPy配列 audio を元にして、PyTorchのテンソルを作成します。

この関数は、NumPy配列を受け取り、その配列のデータを共有するPyTorchのテンソルを作成します。つまり、NumPy配列とPyTorchテンソルは同じメモリを共有し、一方の変更が他方にも反映されることになります。この方法により、データのコピーが不要になり、効率的にデータを共有できます。

この関数は、PyTorchでNumPyデータを扱う際に非常に便利であり、特に既存のNumPyコードをPyTorchに移行する場合に役立ちます。

PyTorchテンソルとは?

PyTorchのテンソル(Tensor)は、多次元の配列を表現するためのデータ構造です。テンソルは、NumPyの配列に似た操作が可能であり、数値計算や機械学習などの科学技術計算に広く使用されています。PyTorchのテンソルは、計算グラフや自動微分機能を持つ動的な計算フレームワークであるPyTorchの基本的なデータ構造の一つです。

PyTorchのテンソルは以下の特徴を持ちます:

  1. 多次元配列: テンソルは0次元(スカラー)、1次元(ベクトル)、2次元(行列)、3次元以上の多次元配列を表現することができます。
  2. GPUサポート: PyTorchのテンソルはCPUとGPUの両方で動作し、GPUを利用して高速な並列計算を実行することができます。
  3. 自動微分: PyTorchは自動微分機能を提供し、計算過程での微分を自動的に計算することができます。これにより、ニューラルネットワークの学習や最適化が容易になります。
  4. 計算グラフ: PyTorchのテンソルは計算グラフを構築し、それを利用して計算履歴を追跡することができます。これにより、誤差逆伝播法を実装する際に便利です。

PyTorchのテンソルは、科学技術計算や機械学習のさまざまなタスクで広く使用されており、その柔軟性とパフォーマンスの高さから、多くの研究者や開発者によって好まれています。

torch.deviceとは?

torch.device は、PyTorchでテンソルをどのデバイス(CPUまたはGPU)に配置するかを指定するためのクラスです。このクラスを使用することで、計算を特定のデバイス上で実行することができます。

torch.device を使用することで、以下のようなことが可能です:

  1. デバイスの指定: テンソルを特定のデバイス(例:torch.device("cuda") または torch.device("cuda:0"))に移動させることができます。”cuda” はGPUを指し、”:0″ はGPUの番号を示します。”cuda:0″ は複数のGPUを使用している場合に、最初のGPUを指します。
  2. デバイスのチェック: テンソルがどのデバイスに配置されているかを確認するために使用します。たとえば、tensor.device 属性を使用して、テンソルがどのデバイスにあるかを取得できます。
  3. デバイス間のデータ移動: テンソルを異なるデバイス間で移動させるために使用します。tensor.to(device) メソッドを使用して、テンソルを特定のデバイスに移動させることができます。

PyTorchでは、torch.device を使用してテンソルを適切なデバイスに配置することで、GPUを活用して高速な計算を実行することができます。また、マルチGPU環境での並列処理を行う際にも便利です。

torch.hann_window関数とは?

torch.hann_window は、ハン窓(Hann window)を生成するPyTorchの関数です。ハン窓は、信号処理やスペクトル解析などの分野で使用される一般的な窓関数の一つです。

ハン窓は、次の数式で定義されます:

$$
w(n) = 0.5 – 0.5 \cdot \cos\left(\frac{2\pi n}{N – 1}\right)
$$

ここで、(n) は窓関数のインデックスであり、(N) は窓の長さ(サンプル数)です。この窓関数は、主に信号の端に存在する周波数成分のエネルギーを減少させるために使用されます。

torch.hann_window 関数は、指定されたサイズのハン窓を生成し、PyTorchのテンソルとして返します。この関数を使用することで、信号処理やスペクトル解析のために窓関数を生成し、それを入力信号に適用することができます。

torch.stft 関数とは?

torch.stft 関数は、PyTorchでの短時間フーリエ変換(Short-Time Fourier Transform, STFT)を計算するために使用されます。STFTは、時間領域の信号を周波数領域に変換する手法であり、信号の周波数成分が時間によってどのように変化するかを分析するのに役立ちます。

具体的には、torch.stft 関数は次のようになります:

  • audio: STFTを計算する入力信号です。通常、オーディオ信号が与えられます。この信号は、PyTorchのテンソルとして表現されます。
  • n_fft: STFTにおけるフーリエ変換の窓のサイズ(フレームサイズ)です。これは、入力信号をフーリエ変換する際のウィンドウのサイズを指定します。
  • hop_length: フレーム間のサンプル数です。これは、連続するフレームのオーバーラップを制御します。通常、STFTにおいて、隣接するフレームは一部が重なり合います。このパラメーターは、重なり具合を制御します。
  • window: 窓関数を指定します。窓関数は、STFTの各フレームに適用され、周波数解析の精度やスプリンクルの軽減に影響を与えます。
  • return_complex: 出力を複素数形式で返すかどうかを指定します。Trueに設定すると、STFTの結果が複素数形式で返されます。Falseに設定すると、結果は実部と虚部が別々のテンソルとして返されます。

torch.stft 関数の出力は、入力信号をSTFTに変換した結果であり、通常は複素数形式のテンソルとして返されます。この出力を用いて、信号の周波数特性や時間的変化を分析することができます。

mel_filters関数定義

def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)

次に、計算されたSTFT(短時間フーリエ変換)からメルスペクトログラムを計算するための処理を行っています。メルスペクトログラムは、音声信号の周波数成分を解析するために一般的に使用される手法の一つであり、周波数を人間の聴覚特性に合わせて変換したスペクトログラムです。

  1. magnitudes = stft[..., :-1].abs() ** 2:
    まず、計算されたSTFT(短時間フーリエ変換)から振幅スペクトログラム(magnitude spectrogram)を計算しています。ここでは、複素数形式のSTFTの結果から、実部と虚部を取り出し、絶対値を取って振幅を得ています。その後、絶対値を2乗して、振幅の二乗を計算しています。これにより、各時間フレームごとの周波数成分のパワースペクトルが得られます。
  2. filters = mel_filters(audio.device, n_mels):
    次に、メルフィルタバンク(mel filter bank)を作成しています。メルフィルタバンクは、周波数をメル尺度に変換するためのフィルタの集合であり、音声信号の周波数特性を解析するために使用されます。
  3. mel_spec = filters @ magnitudes:
    最後に、計算された振幅スペクトログラムとメルフィルタバンクを乗算して、メルスペクトログラムを計算しています。この操作により、各時間フレームにおける音声信号のメル尺度に変換された周波数成分が得られます。

入力オーディオ信号からメルスペクトログラムが計算されます。メルスペクトログラムは、音声信号の周波数特性を解析するために広く使用され、音声処理や音声認識などのアプリケーションで重要な役割を果たします。

最後に、

メルスペクトログラムを対数スケールに変換し、値を正規化する処理を行っています。

  1. log_spec = torch.clamp(mel_spec, min=1e-10).log10():
    まず、メルスペクトログラムを取得した後、torch.clamp()関数を使用して、メルスペクトログラムの値を下限値1e-10にクリップします。これは、対数関数が負の値に定義されないためです。その後、torch.log10()関数を使用して、対数変換を行います。これにより、各要素の値が対数スケールに変換されます。
  2. log_spec = torch.maximum(log_spec, log_spec.max() - 8.0):
    次に、対数変換されたメルスペクトログラムの値を正規化します。ここでは、torch.maximum()関数を使用して、対数変換されたメルスペクトログラムの値と、その最大値から8.0を引いた値の間で要素ごとの最大値を取得します。これにより、スペクトログラムの値の範囲が8.0に制限されます。
  3. log_spec = (log_spec + 4.0) / 4.0:
    最後に、正規化された対数スペクトログラムの値をさらにスケーリングします。各要素の値に4.0を加え、その結果を4.0で除算することで、値の範囲が[0, 1]にスケーリングされます。これにより、対数スペクトログラムが適切な範囲に正規化され、後続の処理や可視化に適した形式になります。

関連記事

カテゴリー

アーカイブ

Lang »