Sound to Vector Embeddings

Audio embeddings convert sound signals (speech, music, environmental sounds) into dense vector representations that capture acoustic and semantic features, enabling similarity search, classification, and retrieval.

Core Idea

Audio embeddings map audio waveforms or spectrograms to fixed-size vectors in a high-dimensional space where semantically similar sounds are close together. The embedding function $E: \mathcal{A} \rightarrow \mathbb{R}^d$ transforms an audio signal $A \in \mathcal{A}$ into a vector $\mathbf{v} \in \mathbb{R}^d$.

Mathematical Foundation

Preprocessing: Audio to Spectrogram

Raw audio waveform $x(t)$ is converted to a spectrogram: $$S(t, f) = |\text{STFT}(x(t))|^2$$

where:

  • $\text{STFT}$ is the Short-Time Fourier Transform
  • $S(t, f) \in \mathbb{R}^{T \times F}$ is the time-frequency representation
  • $T$ is the number of time frames
  • $F$ is the number of frequency bins
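
A minimal sketch of this preprocessing step with torch.stft (the FFT size, hop length, and window choice here are illustrative assumptions; torchaudio's MelSpectrogram used later wraps the same computation):

import torch
import torchaudio

# Load audio and mix down to mono; "audio.wav" is a placeholder path
waveform, sample_rate = torchaudio.load("audio.wav")
x = waveform.mean(dim=0)  # [L]

# STFT -> complex spectrum; squared magnitude gives the power spectrogram S(t, f)
stft = torch.stft(
    x, n_fft=1024, hop_length=512,
    window=torch.hann_window(1024), return_complex=True
)                               # [F, T] = [n_fft // 2 + 1, num_frames]
power_spec = stft.abs() ** 2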

Inference (Forward Pass):

For a CNN-based encoder: $$\mathbf{v} = E(A) = \text{GlobalPool}(\text{CNN}(\text{MelSpectrogram}(A)))$$

For a Transformer-based encoder (e.g., Wav2Vec2): $$\mathbf{v} = E(A) = \text{MeanPool}(\text{Transformer}(\text{CNN}(A)))$$

For sequence-to-vector aggregation: $$\mathbf{v} = \frac{1}{T}\sum_{t=1}^{T} \mathbf{h}_t$$

where:

  • $A \in \mathbb{R}^{L}$ is the raw audio waveform (length $L$ samples)
  • $\mathbf{h}_t \in \mathbb{R}^d$ are hidden states at each time step
  • $\mathbf{v} \in \mathbb{R}^d$ is the output embedding vector

Training Objective:

Self-supervised learning with contrastive loss (Wav2Vec2): $$\mathcal{L} = -\log \frac{\exp(\text{sim}(\mathbf{c}_t, \mathbf{q}_t) / \tau)}{\sum_{\tilde{\mathbf{q}} \in Q_t} \exp(\text{sim}(\mathbf{c}_t, \tilde{\mathbf{q}}) / \tau)}$$

where:

  • $\mathbf{c}_t$ is the context vector at time $t$
  • $\mathbf{q}_t$ is the quantized target vector
  • $Q_t$ is the set of quantized vectors (positive + negatives)
  • $\tau$ is the temperature parameter
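
A minimal sketch of this objective as a generic InfoNCE loss over one positive and K sampled negatives (tensor shapes and the temperature value are assumptions; Wav2Vec2's actual implementation additionally uses masked time steps and a diversity term):

import torch
import torch.nn.functional as F

def contrastive_loss(context, positive, negatives, tau=0.1):
    # context:   [B, d]    context vectors c_t
    # positive:  [B, d]    quantized targets q_t
    # negatives: [B, K, d] K distractor quantized vectors per example
    pos_sim = F.cosine_similarity(context, positive, dim=-1)                # [B]
    neg_sim = F.cosine_similarity(context.unsqueeze(1), negatives, dim=-1)  # [B, K]
    logits = torch.cat([pos_sim.unsqueeze(1), neg_sim], dim=1) / tau        # [B, 1 + K]
    targets = torch.zeros(logits.size(0), dtype=torch.long)                 # positive sits at index 0
    return F.cross_entropy(logits, targets)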

Supervised Fine-tuning: $$\mathcal{L} = -\sum_{i=1}^{N} \log P(y_i | \mathbf{v}_i)$$

where $y_i$ is the label (e.g., speaker ID, emotion class) and $\mathbf{v}_i = E(A_i)$.
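
A minimal sketch of this fine-tuning objective using a linear classification head on top of precomputed embeddings (the class count and the random batch are made-up placeholders):

import torch
import torch.nn as nn

# Hypothetical fine-tuning head: embeddings v_i = E(A_i) -> linear classifier -> cross-entropy
embedding_dim, num_classes = 512, 10              # num_classes is an assumption (e.g., emotion labels)
classifier = nn.Linear(embedding_dim, num_classes)
criterion = nn.CrossEntropyLoss()                 # implements -log P(y_i | v_i)

embeddings = torch.randn(8, embedding_dim)        # stand-in for a batch of encoder outputs E(A_i)
labels = torch.randint(0, num_classes, (8,))

loss = criterion(classifier(embeddings), labels)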


Architecture Overview

A typical audio embedding pipeline runs: raw waveform → time-frequency representation (or a learned convolutional front end) → encoder (CNN or Transformer) → temporal pooling → projection head → L2-normalized embedding vector. The implementations below follow this pattern.

PyTorch Implementation

CNN-based Audio Encoder

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T

class AudioEncoder(nn.Module):
    def __init__(self, embedding_dim=512, sample_rate=16000):
        super().__init__()
        self.sample_rate = sample_rate

        # Mel spectrogram extractor
        self.mel_spectrogram = T.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=2048,
            hop_length=512,
            n_mels=128
        )

        # CNN feature extractor
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        # Projection head
        self.projection = nn.Sequential(
            nn.Linear(256, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

    def forward(self, waveform):
        # waveform: [B, L] where L is audio length
        # Convert to mel spectrogram: [B, 1, n_mels, T]
        mel_spec = self.mel_spectrogram(waveform).unsqueeze(1)

        # Extract features: [B, 256, 1, 1]
        features = self.feature_extractor(mel_spec)
        features = features.view(features.size(0), -1)

        # Project to embedding: [B, embedding_dim]
        embedding = self.projection(features)
        return nn.functional.normalize(embedding, p=2, dim=1)

# Usage
encoder = AudioEncoder(embedding_dim=512)
encoder.eval()

# Load audio and mix down to mono so the output batch size is 1
waveform, sample_rate = torchaudio.load("audio.wav")
waveform = waveform.mean(dim=0, keepdim=True)
# Resample if needed
if sample_rate != 16000:
    resampler = T.Resample(sample_rate, 16000)
    waveform = resampler(waveform)

# Ensure fixed length (e.g., 3 seconds)
target_length = 16000 * 3
if waveform.size(1) > target_length:
    waveform = waveform[:, :target_length]
else:
    padding = target_length - waveform.size(1)
    waveform = torch.nn.functional.pad(waveform, (0, padding))

with torch.no_grad():
    embedding = encoder(waveform)
    # embedding shape: [1, 512]

Wav2Vec2-based Encoder

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor

class Wav2Vec2AudioEncoder(nn.Module):
    def __init__(self, model_name="facebook/wav2vec2-base"):
        super().__init__()
        self.wav2vec2 = Wav2Vec2Model.from_pretrained(model_name)
        self.embedding_dim = self.wav2vec2.config.hidden_size
        # The pretrained-only checkpoint ships no tokenizer, so use the
        # feature extractor rather than Wav2Vec2Processor
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)

    def forward(self, input_values):
        # input_values: [B, L] raw waveform
        outputs = self.wav2vec2(input_values=input_values)
        # Mean pooling over time: [B, T, d] -> [B, d]
        hidden_states = outputs.last_hidden_state
        embedding = hidden_states.mean(dim=1)
        return nn.functional.normalize(embedding, p=2, dim=1)

# Usage
encoder = Wav2Vec2AudioEncoder()
encoder.eval()

# Load audio and mix down to mono
waveform, sample_rate = torchaudio.load("audio.wav")
waveform = waveform.mean(dim=0, keepdim=True)
# Wav2Vec2 expects 16kHz
if sample_rate != 16000:
    resampler = T.Resample(sample_rate, 16000)
    waveform = resampler(waveform)

# Normalize the raw waveform the way the model expects
inputs = encoder.feature_extractor(
    waveform.squeeze().numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)

with torch.no_grad():
    embedding = encoder(inputs.input_values)
    # embedding shape: [1, 768]

Whisper-based Encoder

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import WhisperModel, WhisperProcessor

class WhisperAudioEncoder(nn.Module):
    def __init__(self, model_name="openai/whisper-base"):
        super().__init__()
        self.whisper = WhisperModel.from_pretrained(model_name)
        self.embedding_dim = self.whisper.config.d_model
        self.processor = WhisperProcessor.from_pretrained(model_name)

    def forward(self, input_features):
        # input_features: [B, n_mels, T] log-mel spectrogram features
        outputs = self.whisper.encoder(input_features=input_features)
        # Mean pooling over time
        hidden_states = outputs.last_hidden_state
        embedding = hidden_states.mean(dim=1)
        return nn.functional.normalize(embedding, p=2, dim=1)

# Usage
encoder = WhisperAudioEncoder()
encoder.eval()

# Load audio, mix down to mono, and resample to the 16kHz Whisper expects
waveform, sample_rate = torchaudio.load("audio.wav")
waveform = waveform.mean(dim=0, keepdim=True)
if sample_rate != 16000:
    waveform = T.Resample(sample_rate, 16000)(waveform)

inputs = encoder.processor(
    waveform.squeeze().numpy(),
    sampling_rate=16000,
    return_tensors="pt"
)

with torch.no_grad():
    embedding = encoder(input_features=inputs.input_features)
    # embedding shape: [1, 512]

Training with Triplet Loss

A triplet margin loss is a simple metric-learning alternative to the InfoNCE objective above: it pulls an anchor embedding toward a positive example and pushes it away from a negative one.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioTripletLoss(nn.Module):
    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = F.pairwise_distance(anchor, positive)
        neg_dist = F.pairwise_distance(anchor, negative)
        loss = F.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()

# Training loop; dataloader is assumed to yield (anchor, positive, negative) waveform triplets
encoder = AudioEncoder(embedding_dim=512)
encoder.train()
criterion = AudioTripletLoss(margin=0.5)
optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-4)

for anchor_audio, positive_audio, negative_audio in dataloader:
    optimizer.zero_grad()

    anchor_emb = encoder(anchor_audio)
    positive_emb = encoder(positive_audio)
    negative_emb = encoder(negative_audio)

    loss = criterion(anchor_emb, positive_emb, negative_emb)
    loss.backward()
    optimizer.step()

LangChain Integration

import torch
import torchaudio
import torchaudio.transforms as T
from langchain_core.embeddings import Embeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor

# LangChain has no built-in audio embeddings, so wrap a custom model in the
# Embeddings interface, treating each "text" as an audio file path
class AudioEmbeddings(Embeddings):
    def __init__(self, model_name="facebook/wav2vec2-base"):
        self.model = Wav2Vec2Model.from_pretrained(model_name)
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        self.model.eval()

    def embed_audio(self, audio_path):
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.mean(dim=0, keepdim=True)  # mono
        if sample_rate != 16000:
            resampler = T.Resample(sample_rate, 16000)
            waveform = resampler(waveform)

        inputs = self.feature_extractor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt"
        )

        with torch.no_grad():
            outputs = self.model(**inputs)
            embedding = outputs.last_hidden_state.mean(dim=1)
            return embedding.squeeze().tolist()

    def embed_documents(self, texts):
        return [self.embed_audio(path) for path in texts]

    def embed_query(self, text):
        return self.embed_audio(text)

# Create embeddings
embeddings_model = AudioEmbeddings()
audio_paths = [f"audio_{i}.wav" for i in range(100)]
audio_embeddings = embeddings_model.embed_documents(audio_paths)

documents = [Document(page_content=path) for path in audio_paths]

# Store in vector database; FAISS.from_embeddings expects (text, vector) pairs
vectorstore = FAISS.from_embeddings(
    text_embeddings=list(zip([doc.page_content for doc in documents], audio_embeddings)),
    embedding=embeddings_model
)

# Similarity search
query_emb = embeddings_model.embed_query("query.wav")
results = vectorstore.similarity_search_by_vector(query_emb, k=5)

Key Concepts

Mel Spectrogram: Converts linear frequency scale to mel scale (perceptually uniform): $$m = 2595 \log_{10}(1 + f/700)$$

where $f$ is frequency in Hz and $m$ is mel frequency.
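
As a quick numeric check of the formula (hz_to_mel below is a tiny helper written for this page, not a library function), 1 kHz maps to roughly 1000 mel by construction, while high frequencies are compressed:

import math

def hz_to_mel(f):
    # m = 2595 * log10(1 + f / 700)
    return 2595 * math.log10(1 + f / 700)

print(hz_to_mel(1000))   # ≈ 1000 mel
print(hz_to_mel(8000))   # ≈ 2840 mel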

Temporal Pooling: Aggregates variable-length sequences:

  • Mean Pooling: $\mathbf{v} = \frac{1}{T}\sum_{t=1}^{T} \mathbf{h}_t$
  • Max Pooling: $\mathbf{v} = \max_{t} \mathbf{h}_t$
  • Attention Pooling: $\mathbf{v} = \sum_{t} \alpha_t \mathbf{h}_t$ where $\boldsymbol{\alpha} = \text{softmax}(\mathbf{w}^\top \mathbf{H})$ and $\mathbf{H} = [\mathbf{h}_1, \dots, \mathbf{h}_T]$ (see the sketch below)
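
Mean pooling already appears in the encoders above; here is a minimal sketch of attention pooling as a PyTorch module (the single linear scoring layer and the sizes are assumptions):

import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.score = nn.Linear(hidden_dim, 1)          # learned scoring vector w

    def forward(self, hidden_states):                  # [B, T, d]
        weights = torch.softmax(self.score(hidden_states), dim=1)  # alpha_t, [B, T, 1]
        return (weights * hidden_states).sum(dim=1)    # [B, d]

pool = AttentionPooling(hidden_dim=768)
h = torch.randn(2, 100, 768)                           # e.g., Wav2Vec2 hidden states
v = pool(h)                                            # [2, 768]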

Preprocessing Steps:

  1. Resampling: Normalize to target sample rate (typically 16kHz)
  2. Normalization: Zero-mean, unit-variance: $x' = \frac{x - \mu}{\sigma}$
  3. Padding/Truncation: Fixed-length sequences for batch processing
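
A minimal helper combining these three steps (the target rate and clip length follow the examples above; the normalization epsilon is an arbitrary small constant):

import torch
import torchaudio
import torchaudio.transforms as T

def preprocess(path, target_sr=16000, seconds=3):
    waveform, sr = torchaudio.load(path)                # [channels, L]
    waveform = waveform.mean(dim=0, keepdim=True)       # mono
    if sr != target_sr:
        waveform = T.Resample(sr, target_sr)(waveform)  # 1. resample
    waveform = (waveform - waveform.mean()) / (waveform.std() + 1e-8)  # 2. zero-mean, unit-variance
    target_len = target_sr * seconds                    # 3. pad or truncate to fixed length
    if waveform.size(1) > target_len:
        waveform = waveform[:, :target_len]
    else:
        waveform = torch.nn.functional.pad(waveform, (0, target_len - waveform.size(1)))
    return waveform                                     # [1, target_sr * seconds]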

Similarity Metrics:

  • Cosine Similarity: $\text{sim}(\mathbf{v}_1, \mathbf{v}_2) = \frac{\mathbf{v}_1 \cdot \mathbf{v}_2}{||\mathbf{v}_1|| \cdot ||\mathbf{v}_2||}$
  • Euclidean Distance: $d(\mathbf{v}_1, \mathbf{v}_2) = ||\mathbf{v}_1 - \mathbf{v}_2||_2$
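
Both metrics are available directly in PyTorch; note that for the L2-normalized embeddings produced by the encoders above, ranking by cosine similarity and by Euclidean distance is equivalent:

import torch
import torch.nn.functional as F

v1 = F.normalize(torch.randn(1, 512), dim=1)
v2 = F.normalize(torch.randn(1, 512), dim=1)

cos = F.cosine_similarity(v1, v2, dim=1)   # in [-1, 1]
dist = torch.cdist(v1, v2)                 # Euclidean distance
# For unit vectors: ||v1 - v2||^2 = 2 - 2 * cos(v1, v2)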
