
The Complete torchaudio Guide — From Audio Processing to Speech Recognition, TTS, and Music Analysis

Introduction

torchaudio is the official audio processing library for PyTorch. It supports audio I/O, spectrogram and feature transforms, pretrained pipelines (Wav2Vec2, HuBERT, MMS forced alignment, Tacotron2 for TTS), and even real-time streaming.

pip install torch torchaudio

Part 1: Audio Fundamentals

Loading and Saving Audio

import torch
import torchaudio

# Load audio
waveform, sample_rate = torchaudio.load("speech.wav")
print(f"Shape: {waveform.shape}")    # [channels, samples]
print(f"Sample Rate: {sample_rate}")  # e.g. 16000 (depends on the file)
print(f"Duration: {waveform.shape[1] / sample_rate:.2f}s")

# Channels: mono (1) vs stereo (2)
if waveform.shape[0] == 2:
    waveform = waveform.mean(dim=0, keepdim=True)  # Stereo -> Mono (average the channels)

# Resampling (e.g. 44100Hz -> 16000Hz); orig_freq must be the file's actual rate
resampler = torchaudio.transforms.Resample(
    orig_freq=sample_rate, new_freq=16000
)
waveform_16k = resampler(waveform)

# Save
torchaudio.save("output.wav", waveform_16k, 16000)

# Supported formats depend on the installed backend: e.g. wav, flac, mp3, ogg, opus, sphere
# Backends: sox, soundfile, ffmpeg
print(torchaudio.list_audio_backends())
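torchaudio.load returns a channels-first tensor ([channels, samples]) and, by default, floats scaled to roughly [-1, 1]. A WAV file, by contrast, stores samples interleaved frame by frame (L, R, L, R, ...). The sketch below makes that layout explicit for 16-bit PCM using only Python's standard `wave` and `struct` modules; both helper names are hypothetical, and dividing int16 values by 32768 is an assumption about how the float scaling works.

```python
import os
import struct
import tempfile
import wave


def write_stereo_wav(path, left, right, sample_rate=16000):
    """Hypothetical helper: write two int16 channels as interleaved 16-bit PCM."""
    with wave.open(path, "wb") as f:
        f.setnchannels(2)
        f.setsampwidth(2)  # 2 bytes = 16-bit samples
        f.setframerate(sample_rate)
        # One frame = one left sample followed by one right sample
        f.writeframes(b"".join(struct.pack("<hh", ls, rs) for ls, rs in zip(left, right)))


def load_channels_first(path):
    """Hypothetical helper: return ([channel][sample] floats, sample_rate)."""
    with wave.open(path, "rb") as f:
        n_channels = f.getnchannels()
        sample_rate = f.getframerate()
        if f.getsampwidth() != 2:
            raise ValueError("this sketch only handles 16-bit PCM")
        raw = f.readframes(f.getnframes())
    ints = struct.unpack("<%dh" % (len(raw) // 2), raw)
    # De-interleave: channel c takes every n_channels-th sample starting at index c,
    # and int16 values are scaled into [-1, 1)
    channels = [[s / 32768.0 for s in ints[c::n_channels]] for c in range(n_channels)]
    return channels, sample_rate


path = os.path.join(tempfile.mkdtemp(), "demo.wav")
write_stereo_wav(path, left=[0, 16384, -32768], right=[32767, 0, 0])
channels, sr = load_channels_first(path)
mono = [sum(frame) / len(frame) for frame in zip(*channels)]  # stereo -> mono by averaging
print(len(channels), len(channels[0]), sr)  # 2 3 16000
```

The `[channels, samples]` list here plays the same role as the tensor torchaudio returns, which is why `waveform.shape[1] / sample_rate` gives the duration.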

Audio Visualization

import matplotlib.pyplot as plt

# One figure, three panels: waveform, spectrogram, Mel spectrogram
fig, axes = plt.subplots(3, 1, figsize=(12, 8))

# 1. Time domain (waveform)
time_axis = torch.arange(0, waveform.shape[1]) / sample_rate
axes[0].plot(time_axis, waveform[0])
axes[0].set_title("Waveform")
axes[0].set_xlabel("Time (s)")
axes[0].set_ylabel("Amplitude")

# 2. Spectrogram
spectrogram = torchaudio.transforms.Spectrogram(n_fft=1024)(waveform)
axes[1].imshow(
    (spectrogram[0] + 1e-10).log2().numpy(),  # epsilon avoids log2(0) = -inf on silent bins
    aspect='auto', origin='lower', cmap='magma'
)
axes[1].set_title("Spectrogram")

# 3. Mel Spectrogram
mel_spec = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate, n_fft=1024, n_mels=80
)(waveform)
axes[2].imshow(
    (mel_spec[0] + 1e-10).log2().numpy(),  # epsilon avoids log2(0) = -inf on silent bins
    aspect='auto', origin='lower', cmap='magma'
)
axes[2].set_title("Mel Spectrogram")

plt.tight_layout()
plt.savefig("audio_analysis.png", dpi=150)
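The imshow calls above label the axes in raw bin and frame indices. Assuming torchaudio's Spectrogram defaults (win_length = n_fft, hop_length = win_length // 2, center=True), bin k sits at k * sample_rate / n_fft Hz and frame t is centered at t * hop_length / sample_rate seconds. This sketch computes an `extent` in seconds and Hz; the helper name is hypothetical.

```python
def spectrogram_extent(num_samples, sample_rate, n_fft=1024, hop_length=None):
    """Hypothetical helper: [x_min, x_max, y_min, y_max] in seconds and Hz
    for imshow(..., extent=...), assuming center=True framing."""
    if hop_length is None:
        hop_length = n_fft // 2               # default when win_length == n_fft
    n_frames = 1 + num_samples // hop_length  # frame count with center=True padding
    n_bins = n_fft // 2 + 1                   # one-sided spectrum
    last_time = (n_frames - 1) * hop_length / sample_rate
    top_freq = (n_bins - 1) * sample_rate / n_fft  # equals sample_rate / 2 (Nyquist)
    return [0.0, last_time, 0.0, top_freq], n_frames, n_bins


extent, n_frames, n_bins = spectrogram_extent(num_samples=16000, sample_rate=16000)
print(extent, n_frames, n_bins)
```

Passing `extent=extent` to `axes[1].imshow(...)` and labeling the axes "Time (s)" and "Frequency (Hz)" turns the index grid into physical units.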

Part 2: Core Transforms

Spectrogram Family

# STFT (Short-Time Fourier Transform)
# Converts from time domain to time+frequency domain
spectrogram_transform = torchaudio.transforms.Spectrogram(
    n_fft=1024,        # FFT size; sets the number of frequency bins (frequency resolution)
    hop_length=256,    # Samples between successive frames (time step)
    win_length=1024,   # Window length (defaults to n_fft)
    power=2.0,         # 2.0 = power, 1.0 = magnitude, None = complex output
)

spec = spectrogram_transform(waveform)
# shape: [channels, n_freq_bins, time_frames]
# n_freq_bins = n_fft // 2 + 1 = 513
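
Trade-off between n_fft and hop_length:

n_fft up -> frequency resolution up, time resolution down (longer analysis window)
n_fft down -> frequency resolution down, time resolution up (shorter analysis window)
hop_length down -> more frames per second (denser time axis, more compute), but no extra frequency detail

Common settings:
+-- Speech: n_fft=400~512, hop=160 (at 16kHz; 25~32 ms window, 10 ms hop)
+-- Music: n_fft=2048, hop=512 (at 44.1kHz; ~46 ms window, ~11.6 ms hop)
+-- General: n_fft=1024, hop=256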
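The trade-off is easy to quantify: with win_length = n_fft (torchaudio's default), the bin spacing is sample_rate / n_fft Hz and the analysis window lasts n_fft / sample_rate seconds, while hop_length only sets how often a frame is taken. The sketch below tabulates the presets above; the 16 kHz rate for the "General" preset is an assumption, since no rate is given for it.

```python
def stft_resolution(sample_rate, n_fft, hop_length):
    """Resolution figures for an STFT with win_length == n_fft."""
    return {
        "bin_spacing_hz": sample_rate / n_fft,    # smaller = finer frequency detail
        "window_ms": 1000 * n_fft / sample_rate,  # longer = blurrier timing
        "hop_ms": 1000 * hop_length / sample_rate,
    }


presets = {
    "speech": (16000, 400, 160),
    "music": (44100, 2048, 512),
    "general": (16000, 1024, 256),  # sample rate assumed to be 16 kHz
}
results = {name: stft_resolution(*cfg) for name, cfg in presets.items()}
for name, r in results.items():
    print(f"{name:8s} {r['bin_spacing_hz']:5.1f} Hz/bin, "
          f"{r['window_ms']:5.1f} ms window, {r['hop_ms']:5.1f} ms hop")
```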

Mel Spectrogram — Why Mel?

# Human hearing tells small frequency differences apart well at low frequencies,
# but only coarsely at high frequencies. The Mel scale warps frequency to match that perception.

mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=1024,
    hop_length=256,
    n_mels=80,          # Number of Mel filters (typically 40~128)
    f_min=0,            # Minimum frequency
    f_max=8000,         # Maximum frequency (Nyquist = sample_rate / 2)
)

mel_spec = mel_transform(waveform)
# shape: [1, 80, time_frames]

# Convert to dB scale (log compression)
amplitude_to_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80)
mel_spec_db = amplitude_to_db(mel_spec)
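AmplitudeToDB(stype='power') amounts to 10 * log10 of each value, with a small floor so silent bins do not become -inf, and a top_db clip relative to the loudest value. The plain-Python sketch below reproduces that arithmetic. It is not torchaudio's code, and the floor of 1e-10 and reference of 1.0 are assumptions meant to mirror its defaults.

```python
import math


def power_to_db(values, top_db=80.0, amin=1e-10, ref=1.0):
    """Power -> dB: 10*log10(max(v, amin) / ref), then clip below (max - top_db)."""
    ref_db = 10.0 * math.log10(max(ref, amin))
    db = [10.0 * math.log10(max(v, amin)) - ref_db for v in values]
    if top_db is not None:
        floor = max(db) - top_db  # nothing may sit more than top_db below the peak
        db = [max(d, floor) for d in db]
    return db


mel_db = power_to_db([1.0, 0.01, 0.0])  # 0.0 would give -inf without the amin floor
print(mel_db)
```

For stype='magnitude' the multiplier is 20 instead of 10, since magnitude is the square root of power.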
Mel frequency conversion formula:
  mel = 2595 * log10(1 + freq / 700)

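Frequency -> Mel:
  100 Hz  ->  150 mel   (low frequencies: mel rises quickly per Hz)
  1000 Hz ->  1000 mel
  4000 Hz ->  2146 mel
  8000 Hz ->  2840 mel  (high frequencies: mel rises slowly per Hz)

-> Filters spaced evenly in mel are narrow at low frequencies and wide at high frequencies
-> Low frequencies are analyzed finely and high frequencies are grouped coarsely, much as human pitch perception works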
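The table can be reproduced directly from the formula. The sketch below also places n_mels filter centers evenly on the Mel scale between f_min and f_max, which is how the "narrow at low, wide at high" filter layout arises; torchaudio's MelSpectrogram uses this HTK-style formula by default (mel_scale="htk"), but the helper functions themselves are hypothetical.

```python
import math


def hz_to_mel(freq_hz):
    """HTK-style Mel scale: mel = 2595 * log10(1 + f / 700)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def mel_to_hz(mel):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


def mel_filter_centers(f_min, f_max, n_mels):
    """Center frequencies (Hz) of n_mels filters spaced evenly on the Mel scale."""
    lo, hi = hz_to_mel(f_min), hz_to_mel(f_max)
    step = (hi - lo) / (n_mels + 1)  # n_mels + 2 evenly spaced points, edges excluded
    return [mel_to_hz(lo + step * (i + 1)) for i in range(n_mels)]


for f in (100, 1000, 4000, 8000):
    print(f"{f:5d} Hz -> {hz_to_mel(f):7.1f} mel")

centers = mel_filter_centers(0, 8000, 80)
print(f"first gap: {centers[1] - centers[0]:.1f} Hz, last gap: {centers[-1] - centers[-2]:.1f} Hz")
```

With 80 filters up to 8 kHz, adjacent centers are a couple of dozen Hz apart at the bottom of the range but a few hundred Hz apart at the top.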

MFCC (Mel-Frequency Cepstral Coefficients)

# MFCC = DCT (Discrete Cosine Transform) of the log-scaled (dB) Mel Spectrogram
# The first few coefficients summarize the spectral envelope, the "shape" of speech

mfcc_transform = torchaudio.transforms.MFCC(
    sample_rate=16000,
    n_mfcc=13,          # Number of MFCC coefficients (typically 13~40)
    melkwargs={
        'n_fft': 1024,
        'n_mels': 80,
        'hop_length': 256,
    }
)

mfcc = mfcc_transform(waveform)
# shape: [1, 13, time_frames]

# Delta (1st derivative) + Delta-Delta (2nd derivative)
# -> Adds rate-of-change information for speech
delta = torchaudio.functional.compute_deltas(mfcc)
delta_delta = torchaudio.functional.compute_deltas(delta)

# Final features: concatenate [MFCC, Delta, Delta-Delta]
features = torch.cat([mfcc, delta, delta_delta], dim=1)
# shape: [1, 39, time_frames]
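MFCC's final step is a DCT-II over the Mel axis of each log-Mel frame, and compute_deltas is a least-squares slope over neighboring frames: d_t = sum_{n=1..N} n (c_{t+n} - c_{t-n}) / (2 sum_{n=1..N} n^2). The plain-Python sketch below assumes an orthonormal DCT-II (torchaudio's MFCC defaults to norm='ortho') and win_length=5 with edge replication (compute_deltas' defaults); the helper names are hypothetical.

```python
import math


def dct_ii_ortho(frame, n_coeffs):
    """Orthonormal DCT-II of one log-Mel frame, keeping the first n_coeffs values."""
    n = len(frame)
    out = []
    for k in range(n_coeffs):
        scale = math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        total = sum(frame[i] * math.cos(math.pi * k * (2 * i + 1) / (2 * n)) for i in range(n))
        out.append(scale * total)
    return out


def compute_deltas_1d(track, win_length=5):
    """Delta regression along time for one coefficient track, replicating edge frames."""
    half = (win_length - 1) // 2
    denom = 2 * sum(k * k for k in range(1, half + 1))
    last = len(track) - 1

    def at(t):
        return track[min(max(t, 0), last)]  # clamped index = replicate padding

    return [
        sum(k * (at(t + k) - at(t - k)) for k in range(1, half + 1)) / denom
        for t in range(len(track))
    ]


coeffs = dct_ii_ortho([2.0] * 80, n_coeffs=4)               # flat frame: energy lands in c0 only
deltas = compute_deltas_1d([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])  # ramp rising by 1 per frame
print([round(c, 6) for c in coeffs], [round(d, 6) for d in deltas])
```

On the ramp, interior deltas equal the true slope of 1.0, while the replicated edges pull the first and last values down.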

Where are these used?
+-- MFCC: Traditional speech recognition (HMM-GMM), speaker recognition
+-- Mel Spectrogram: Deep learning models that take log-Mel input (e.g. Whisper), CNN audio classifiers
+-- Spectrogram: Music analysis, environmental sound classification
+-- Raw Waveform: End-to-end models that learn their own front end (Wav2Vec2, HuBERT)

Part 3: Audio Augmentation

# Time masking (SpecAugment)
time_masking = torchaudio.transforms.TimeMasking(
    time_mask_param=30   # Mask up to 30 frames
)

# Frequency masking (SpecAugment)
freq_masking = torchaudio.transforms.FrequencyMasking(
    freq_mask_param=15   # Mask up to 15 channels
)

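# SpecAugment: apply both masks during training (a widely used augmentation for speech recognition)
augmented_spec = time_masking(freq_masking(mel_spec))

# Time stretching works on a *complex* spectrogram, so compute one first
complex_spec = torchaudio.transforms.Spectrogram(
    n_fft=1024, hop_length=256, power=None  # power=None keeps the complex STFT
)(waveform)
time_stretch = torchaudio.transforms.TimeStretch(hop_length=256, n_freq=1024 // 2 + 1)
stretched = time_stretch(complex_spec, overriding_rate=1.2)  # rate > 1: 20% faster (fewer frames)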

# Pitch shifting
pitch_shift = torchaudio.transforms.PitchShift(
    sample_rate=16000, n_steps=4  # Shift up by 4 semitones
)
shifted = pitch_shift(waveform)

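# Adding noise at a target SNR
# (torchaudio.functional.add_noise(waveform, noise, snr) does this for a noise tensor you supply)
def add_noise(waveform, snr_db=10):
    """Add white Gaussian noise so that the signal-to-noise ratio is snr_db."""
    noise = torch.randn_like(waveform)
    signal_norm = waveform.norm(p=2)      # L2 norms: their ratio is an amplitude ratio
    noise_norm = noise.norm(p=2)
    snr_amplitude = 10 ** (snr_db / 20)   # dB -> amplitude ratio (a power ratio would use /10)
    scale = signal_norm / (snr_amplitude * noise_norm)
    return waveform + scale * noise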
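One way to sanity-check add_noise is to measure the SNR it actually produces, 10 * log10(signal energy / noise energy). The sketch below repeats the same scaling on plain Python lists, with random.gauss standing in for torch.randn_like so it runs without torch; the list-based helpers are hypothetical.

```python
import math
import random


def l2(xs):
    return math.sqrt(sum(x * x for x in xs))


def add_noise_list(signal, snr_db, rng):
    """Same scaling as add_noise, on a list of floats; returns (noisy, scaled_noise)."""
    noise = [rng.gauss(0.0, 1.0) for _ in signal]
    scale = l2(signal) / (10 ** (snr_db / 20) * l2(noise))  # amplitude ratio uses /20
    scaled = [scale * n for n in noise]
    return [s + n for s, n in zip(signal, scaled)], scaled


def measured_snr_db(signal, noise):
    """SNR in dB from energies (a power ratio uses 10 * log10)."""
    return 10 * math.log10(sum(s * s for s in signal) / sum(n * n for n in noise))


rng = random.Random(0)
sr = 16000
tone = [0.5 * math.sin(2 * math.pi * 440 * t / sr) for t in range(sr)]  # 1 s at 440 Hz
noisy, noise = add_noise_list(tone, snr_db=10, rng=rng)
snr = measured_snr_db(tone, noise)
print(f"{snr:.2f} dB")
```

Because the norms are squared when converting to energies, the /20 in the amplitude ratio and the 10 * log10 in the measurement cancel exactly, so the measured value matches the requested snr_db up to floating-point error.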

Part 4: Pretrained Models

Wav2Vec 2.0 (Speech Recognition)

import torchaudio
from torchaudio.pipelines import WAV2VEC2_ASR_BASE_960H

# Load pipeline
bundle = WAV2VEC2_ASR_BASE_960H
model = bundle.get_model()
labels = bundle.get_labels()  # Token list

# Inference
waveform, sr = torchaudio.load("speech.wav")
if sr != bundle.sample_rate:
    waveform = torchaudio.transforms.Resample(sr, bundle.sample_rate)(waveform)

with torch.no_grad():
    emissions, _ = model(waveform)

# CTC Decoding (Greedy)
def greedy_decode(emissions, labels, blank=0):
    indices = torch.argmax(emissions, dim=-1)[0].tolist()  # best label per frame
    tokens = []
    prev = -1
    for idx in indices:
        if idx != prev and idx != blank:  # drop repeats, then blanks (index 0 = '-')
            tokens.append(labels[idx])
        prev = idx
    return "".join(tokens).replace("|", " ").strip()  # '|' marks word boundaries

text = greedy_decode(emissions, labels)
print(f"Recognition result: {text}")
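The collapse rule in greedy_decode is easiest to see on a toy frame sequence. The sketch below applies the same logic to hand-written label indices (the vocabulary is made up for illustration): a blank between two runs of the same letter keeps both letters, while one continuous run collapses to a single letter.

```python
def ctc_greedy_collapse(frame_ids, labels, blank=0):
    """Greedy CTC post-processing: merge repeats, drop blanks, '|' -> space."""
    out, prev = [], None
    for i in frame_ids:
        if i != prev and i != blank:
            out.append(labels[i])
        prev = i
    return "".join(out).replace("|", " ").strip()


labels = ["-", "|", "H", "E", "L", "O"]         # toy vocabulary; index 0 is the blank
with_blank = [2, 2, 3, 0, 4, 4, 0, 4, 5, 5, 1]  # blank between the two L runs
no_blank = [2, 2, 3, 4, 4, 4, 5, 5, 1]          # one continuous L run
print(ctc_greedy_collapse(with_blank, labels))  # HELLO
print(ctc_greedy_collapse(no_blank, labels))    # HELO
```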

HuBERT (Self-Supervised Speech Representations)

from torchaudio.pipelines import HUBERT_BASE

bundle = HUBERT_BASE
model = bundle.get_model()

with torch.no_grad():
    features, _ = model(waveform)
# features: [1, time_frames, 768]
# -> Semantic representation vectors of speech
# -> Used for speaker recognition, emotion analysis, speech classification
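The time_frames dimension comes from the convolutional feature encoder that HuBERT shares with Wav2Vec2 BASE: seven unpadded 1-D convolutions with (kernel, stride) pairs (10, 5), four times (3, 2), and twice (2, 2), for a total stride of 320 samples (20 ms at 16 kHz). The sketch below assumes that configuration and computes the frame count for a given number of samples.

```python
CONV_LAYERS = [(10, 5)] + [(3, 2)] * 4 + [(2, 2)] * 2  # (kernel, stride), assumed BASE config


def num_feature_frames(num_samples):
    """Output length of the unpadded conv stack: L -> (L - kernel) // stride + 1 per layer."""
    length = num_samples
    for kernel, stride in CONV_LAYERS:
        length = (length - kernel) // stride + 1
    return length


total_stride = 1
for _, stride in CONV_LAYERS:
    total_stride *= stride

print(num_feature_frames(16000), total_stride)  # 49 320
```

So one second of 16 kHz audio yields 49 vectors of 768 dimensions, and frame f starts roughly f * 0.02 s into the clip. For utterance-level tasks such as speaker or emotion classification, a common choice is to average these frames over time to get a single 768-dim vector.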

Forced Alignment (Subtitle Synchronization)

# Temporal alignment between speech and text!
# -> Essential for subtitle generation and lyrics synchronization

from torchaudio.pipelines import MMS_FA  # Multilingual!

bundle = MMS_FA
model = bundle.get_model()
tokenizer = bundle.get_tokenizer()
aligner = bundle.get_aligner()

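transcript = "hello nice to meet you"
words = transcript.split()           # the tokenizer expects a list of words
tokens = tokenizer(words)            # one list of token ids per word

with torch.no_grad():
    emissions, _ = model(waveform)   # waveform: mono, at bundle.sample_rate (16 kHz)

token_spans = aligner(emissions[0], tokens)
# Returns one list of TokenSpan (start/end in frame units) per word

# Convert frames to seconds via the samples-per-frame ratio
ratio = waveform.size(1) / emissions.size(1)
for word, spans in zip(words, token_spans):
    start_time = spans[0].start * ratio / bundle.sample_rate
    end_time = spans[-1].end * ratio / bundle.sample_rate
    print(f"  '{word}': {start_time:.3f}s ~ {end_time:.3f}s")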
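Conceptually, the aligner finds the single most likely CTC path that emits the transcript tokens in order. The plain-Python sketch below is a Viterbi search over the blank-interleaved token sequence on toy log-probabilities; it illustrates the idea only, is not torchaudio's implementation, and its function names are hypothetical.

```python
import math

NEG_INF = float("-inf")


def ctc_viterbi_align(log_probs, tokens, blank=0):
    """Best CTC path for a non-empty `tokens` list through a T x C log-prob matrix.
    Returns the state per frame: state 2k+1 is token k, even states are blanks."""
    ext = [blank]
    for tok in tokens:
        ext += [tok, blank]                  # blank, t1, blank, t2, ..., blank
    T, S = len(log_probs), len(ext)
    score = [[NEG_INF] * S for _ in range(T)]
    back = [[0] * S for _ in range(T)]
    score[0][0] = log_probs[0][ext[0]]       # start on the leading blank...
    score[0][1] = log_probs[0][ext[1]]       # ...or directly on the first token
    for t in range(1, T):
        for s in range(S):
            prev = [p for p in (s, s - 1) if p >= 0]  # stay, or advance one state
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                prev.append(s - 2)           # skip the blank between different tokens
            best = max(prev, key=lambda p: score[t - 1][p])
            if score[t - 1][best] > NEG_INF:
                score[t][s] = score[t - 1][best] + log_probs[t][ext[s]]
                back[t][s] = best
    # A valid path ends on the last token or on the trailing blank
    s = S - 1 if score[T - 1][S - 1] >= score[T - 1][S - 2] else S - 2
    path = [0] * T
    for t in range(T - 1, -1, -1):
        path[t] = s
        s = back[t][s]
    return path


def spans_from_path(path):
    """(token_index, start_frame, end_frame_exclusive) for each aligned token."""
    spans = []
    for t, s in enumerate(path):
        if s % 2 == 1:                       # odd states are tokens
            k = (s - 1) // 2
            if spans and spans[-1][0] == k and spans[-1][2] == t:
                spans[-1] = (k, spans[-1][1], t + 1)
            else:
                spans.append((k, t, t + 1))
    return spans


# Toy emissions for labels [blank, "a", "b"] over 5 frames
probs = [
    [0.1, 0.8, 0.1],
    [0.1, 0.8, 0.1],
    [0.8, 0.1, 0.1],
    [0.1, 0.1, 0.8],
    [0.1, 0.1, 0.8],
]
log_probs = [[math.log(p) for p in row] for row in probs]
path = ctc_viterbi_align(log_probs, tokens=[1, 2])
print(path, spans_from_path(path))  # [1, 1, 2, 3, 3] [(0, 0, 2), (1, 3, 5)]
```

The frame spans then convert to seconds exactly as in the loop above: multiply by the samples-per-frame ratio and divide by the sample rate.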

Part 5: Audio Effects

# torchaudio.functional — GPU-accelerated audio processing

import torchaudio.functional as F

# Volume adjustment
loud = F.gain(waveform, gain_db=6.0)    # +6dB
quiet = F.gain(waveform, gain_db=-6.0)  # -6dB

# High-pass / Low-pass filter
highpass = F.highpass_biquad(waveform, sample_rate, cutoff_freq=300)
lowpass = F.lowpass_biquad(waveform, sample_rate, cutoff_freq=3000)

# Equalizer
eq = F.equalizer_biquad(
    waveform, sample_rate,
    center_freq=1000,  # Around 1kHz
    gain=5.0,          # +5dB boost
    Q=0.707
)

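# Reverb: convolve with a room impulse response (RIR) recorded at the same sample rate
rir, _ = torchaudio.load("room_impulse_response.wav")  # mono RIR file
rir = rir / torch.linalg.vector_norm(rir, ord=2)       # normalize the RIR's energy
reverb = F.fftconvolve(waveform, rir)                  # length: len(waveform) + len(rir) - 1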

# Fade in/out
fade = torchaudio.transforms.Fade(
    fade_in_len=sample_rate,      # 1 second fade in
    fade_out_len=sample_rate * 3  # 3 second fade out
)
faded = fade(waveform)

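# VAD (Voice Activity Detection)
# Vad trims silence only from the *front*; flip, trim again, and flip back to trim both ends.
# Silences in the middle of the recording are not removed.
vad = torchaudio.transforms.Vad(sample_rate=16000)
trimmed = vad(waveform)
trimmed = torch.flip(vad(torch.flip(trimmed, [-1])), [-1])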
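lowpass_biquad and highpass_biquad apply a second-order IIR (biquad) filter. The sketch below designs a low-pass biquad from the widely used Audio EQ Cookbook (RBJ) formulas and runs it sample by sample in plain Python. Treating it as equivalent to torchaudio's internals is an assumption, and the function names are hypothetical. It checks the two properties a low-pass should have: unity gain at DC and near-zero gain at the Nyquist frequency.

```python
import math


def lowpass_biquad_coeffs(sample_rate, cutoff_freq, Q=0.707):
    """RBJ cookbook low-pass: returns (b0, b1, b2, a1, a2) normalized by a0."""
    w0 = 2 * math.pi * cutoff_freq / sample_rate
    alpha = math.sin(w0) / (2 * Q)
    cos_w0 = math.cos(w0)
    b0 = (1 - cos_w0) / 2
    b1 = 1 - cos_w0
    b2 = (1 - cos_w0) / 2
    a0 = 1 + alpha
    a1 = -2 * cos_w0
    a2 = 1 - alpha
    return b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0


def biquad(x, coeffs):
    """Direct Form I: y[n] = b0 x[n] + b1 x[n-1] + b2 x[n-2] - a1 y[n-1] - a2 y[n-2]."""
    b0, b1, b2, a1, a2 = coeffs
    x1 = x2 = y1 = y2 = 0.0
    y = []
    for xn in x:
        yn = b0 * xn + b1 * x1 + b2 * x2 - a1 * y1 - a2 * y2
        x2, x1 = x1, xn
        y2, y1 = y1, yn
        y.append(yn)
    return y


coeffs = lowpass_biquad_coeffs(16000, cutoff_freq=3000)
dc_out = biquad([1.0] * 2000, coeffs)                         # constant input
nyq_out = biquad([(-1.0) ** n for n in range(2000)], coeffs)  # +1, -1, ... (8 kHz at 16 kHz)
print(f"DC gain ~ {dc_out[-1]:.4f}, Nyquist residual ~ {abs(nyq_out[-1]):.2e}")
```

In the same cookbook, the high-pass variant keeps a0, a1, a2 and uses b0 = b2 = (1 + cos w0) / 2 and b1 = -(1 + cos w0).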

Part 6: Practical Projects

Environmental Sound Classification (Audio Classification)

import torch.nn as nn

class AudioClassifier(nn.Module):
    def __init__(self, n_classes=10):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=16000, n_fft=1024, n_mels=64
        )
        self.db = torchaudio.transforms.AmplitudeToDB()

        # Feed the Mel spectrogram into a CNN as if it were an "image"!
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(128, n_classes)

    def forward(self, waveform):
        # [B, 1, samples] -> [B, 1, n_mels, time]
        x = self.mel(waveform)
        x = self.db(x)
        x = self.cnn(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Mel Spectrogram = an "image" of audio
# -> Can be classified with CNNs (ResNet, EfficientNet)!
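To see why AdaptiveAvgPool2d matters, trace the tensor shapes through AudioClassifier for clips of different lengths. The sketch assumes MelSpectrogram's default hop_length of n_fft // 2 = 512 with center=True framing, so a clip of N samples gives 1 + N // 512 frames; the tracer is a hypothetical helper, not part of torchaudio.

```python
def trace_shapes(num_samples, n_mels=64, hop_length=512):
    """(C, H, W) after each stage of AudioClassifier.cnn for one clip."""
    h, w = n_mels, 1 + num_samples // hop_length  # Mel spectrogram: 1 x n_mels x frames
    shapes = [("mel", (1, h, w))]
    # Conv2d(k=3, padding=1) keeps H and W; MaxPool2d(2) halves them (floor)
    h, w = h // 2, w // 2
    shapes.append(("conv1+pool", (32, h, w)))
    h, w = h // 2, w // 2
    shapes.append(("conv2+pool", (64, h, w)))
    shapes.append(("conv3", (128, h, w)))
    shapes.append(("adaptive_avg_pool", (128, 1, 1)))  # independent of clip length
    return shapes


for seconds in (1, 3):
    print(seconds, trace_shapes(seconds * 16000))
```

The time width after conv3 grows with clip length (8 frames for 1 s, 23 for 3 s), but the adaptive pool always reduces it to 1 x 1, so nn.Linear(128, n_classes) sees a fixed-size input.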

Real-Time Streaming Processing

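from torchaudio.io import StreamReader

# Real-time microphone input processing
reader = StreamReader(src=":0", format="avfoundation")  # macOS; ":0" = audio device 0
reader.add_basic_audio_stream(
    frames_per_chunk=16000,  # 1 second chunks
    sample_rate=16000,
    num_channels=1,          # request mono
)

classifier = AudioClassifier(n_classes=10).eval()  # assumes trained weights have been loaded
class_names = [f"class_{i}" for i in range(10)]    # placeholder names; use your dataset's labels

for (chunk,) in reader.stream():
    # chunk: [16000, 1] (StreamReader returns audio as [frames, channels])
    x = chunk.T.unsqueeze(0)  # -> [1, 1, 16000], the [B, 1, samples] the classifier expects
    with torch.no_grad():
        logits = classifier(x)
    print(f"Detected: {class_names[logits.argmax().item()]}")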
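Classifying disjoint 1-second chunks can split a sound event across two chunks. A common remedy is a rolling buffer that yields overlapping windows, for example a 1 s window every 0.5 s. The sketch below uses collections.deque; the class name and window/hop values are assumptions for illustration. In the streaming loop above, you would push each chunk's samples and classify every window it returns.

```python
from collections import deque


class SlidingWindow:
    """Hypothetical helper: collects streamed samples and yields overlapping windows."""

    def __init__(self, window, hop):
        self.window, self.hop = window, hop
        self.buffer = deque(maxlen=window)  # keeps only the most recent `window` samples
        self.since_last = 0                 # samples received since the last emitted window

    def push(self, samples):
        """Append a chunk; return the list of full windows that became ready."""
        ready = []
        for s in samples:
            self.buffer.append(s)
            self.since_last += 1
            if len(self.buffer) == self.window and self.since_last >= self.hop:
                ready.append(list(self.buffer))
                self.since_last = 0
        return ready


sw = SlidingWindow(window=4, hop=2)  # tiny sizes for the demo; e.g. 16000 / 8000 at 16 kHz
windows = []
for chunk in ([0, 1, 2], [3, 4, 5], [6, 7]):
    windows.extend(sw.push(chunk))
print(windows)  # [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7]]
```

Chunk boundaries no longer dictate where windows start, so an event that straddles two chunks still appears whole in at least one window.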

Quiz: torchaudio

Q1. Why is the Mel scale needed? ||Human hearing distinguishes small frequency differences well at low frequencies but only coarsely at high frequencies. The Mel scale reflects this by analyzing low frequencies finely and grouping high frequencies coarsely, building this auditory characteristic into the features a deep learning model sees.||

Q2. When you increase n_fft, which resolution goes up and which goes down? ||n_fft up -> frequency resolution up (finer frequency discrimination), time resolution down (harder to track temporal changes). A trade-off similar to the uncertainty principle.||

Q3. What are the two types of masking in SpecAugment? ||Time Masking: masks consecutive frames along the time axis with zeros. Frequency Masking: masks consecutive channels along the frequency axis with zeros. As data augmentation, these significantly improve speech recognition accuracy.||

Q4. What is the difference between MFCC and Mel Spectrogram, and what are their use cases? ||MFCC: Applies a DCT to the log (dB) Mel Spectrogram and keeps the first coefficients (13~40 dimensions). Used in traditional speech recognition and speaker recognition. Mel Spectrogram: A 2D time-frequency representation, fed directly into deep learning models (the current trend).||

Q5. What are the applications of Forced Alignment? ||Temporal alignment of speech and text. Subtitle generation (accurate timing), lyrics synchronization (karaoke), pronunciation assessment (language learning apps).||

Q6. What role does the blank token play in CTC decoding for Wav2Vec 2.0? ||It separates consecutive identical tokens and represents time intervals with no output. In greedy decoding, blanks (index 0) and consecutive duplicates are removed to produce the final text.||

Q7. Why can a Mel Spectrogram be fed into a CNN? ||A Mel Spectrogram has the same structure as a 2D image (frequency axis x time axis). It can be treated as a 1-channel grayscale image, allowing direct use of image classification models like ResNet and EfficientNet.||
