Implement core speech-to-text pipeline

All major components: hotkey listener (rdev), audio capture (cpal),
resampling (rubato), VAD (Silero ONNX), Parakeet v3 TDT transcription
(ort), overlay window (winit+softbuffer), paste simulation (enigo+arboard),
audio feedback (rodio), YAML config, CLI with clap, HuggingFace model
download. ~2400 lines of Rust across 16 source files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-10 16:47:46 +01:00
parent 6b737f92fe
commit 9b0bf7d9e3
22 changed files with 7750 additions and 0 deletions
+1
@@ -0,0 +1 @@
/target
+50
@@ -0,0 +1,50 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
Mouth is a single-binary, offline speech-to-text tool. Press a global hotkey, speak, and transcribed text is pasted at your cursor. Configured via YAML, no UI. Primary target is Windows; Linux/macOS supported where possible.
Uses Parakeet TDT 0.6B v3 (ONNX, from `istupakov/parakeet-tdt-0.6b-v3-onnx`) for transcription, Silero VAD v4 for voice activity detection.
## Build & Run
```bash
cargo build # debug build
cargo build --release # release build
cargo run # run daemon (default command)
cargo run -- config --show # show current config
cargo run -- config # interactive config TUI
cargo run -- config --reset # reset to defaults
cargo run -- models # list models
cargo run -- models --download # download configured model
cargo run -- status # daemon status
```
## Architecture
Single-binary Rust application. Core pipeline: hotkey capture (rdev) → audio recording (cpal) → resampling to 16kHz (rubato) → VAD (Silero ONNX) → mel spectrogram → transcription (Parakeet v3 TDT decoder via ort) → clipboard/paste (arboard + enigo). Minimal native overlay window (winit + softbuffer).
**Threading model:** Main thread owns the overlay window event loop (required by winit). Background threads: hotkey listener (rdev::listen is blocking), audio recorder (cpal stream), coordinator (state machine). All communicate via `std::sync::mpsc` channels.
**Coordinator state machine:** Idle → Recording → Transcribing → (Pasting) → Idle. Cancel from Recording returns to Idle.
**Parakeet v3 inference:** Two-stage ONNX model — encoder (FastConformer) produces features, decoder+joint (TDT transducer) greedily decodes tokens with duration predictions. Audio preprocessing: pre-emphasis → STFT → 128-band log-mel → per-utterance CMVN. Vocab is SentencePiece BPE with `▁` as word boundary marker.
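The first and last preprocessing stages named above (pre-emphasis and per-utterance CMVN) can be sketched with the standard library alone. This is a minimal illustration, not the repository's code: the function names, the `[frames][bands]` layout, the population variance, and the `1e-5` floor are all assumptions, and the STFT and mel filterbank stages are omitted.

```rust
/// Pre-emphasis filter: y[0] = x[0], y[n] = x[n] - coeff * x[n - 1].
/// A coefficient near 0.97 is a common choice; the value the model
/// actually expects is an assumption to verify against its config.
pub fn pre_emphasis(samples: &[f32], coeff: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(samples.len());
    let mut prev = 0.0_f32;
    for (i, &x) in samples.iter().enumerate() {
        out.push(if i == 0 { x } else { x - coeff * prev });
        prev = x;
    }
    out
}

/// Per-utterance CMVN over a `[frames][bands]` log-mel matrix (hypothetical
/// layout): each band is shifted to zero mean and scaled to unit variance
/// using statistics computed from this utterance only.
pub fn cmvn_per_utterance(features: &mut [Vec<f32>]) {
    let frames = features.len();
    if frames == 0 {
        return;
    }
    let bands = features[0].len();
    let n = frames as f32;
    for b in 0..bands {
        let mean = features.iter().map(|f| f[b]).sum::<f32>() / n;
        let var = features.iter().map(|f| (f[b] - mean).powi(2)).sum::<f32>() / n;
        // Floor the deviation so constant bands do not divide by zero.
        let std = var.sqrt().max(1e-5);
        for frame in features.iter_mut() {
            frame[b] = (frame[b] - mean) / std;
        }
    }
}
```

Normalising per utterance keeps the features independent of microphone gain, which matters for a tool that records from arbitrary input devices.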
**ort crate (v2.0.0-rc.12) notes:** Session::run needs `&mut self`. Input values must be converted to `Value::into_dyn()` before passing. Use `SessionInputValue::Owned(value.into_dyn())` pattern. `try_extract_tensor` returns `(&Shape, &[T])` tuple. `from_shape_vec` needs `[usize; N]` not `Vec<usize>`.
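Config lives at `~/.config/mouth/config.yaml` (Linux) or `%APPDATA%\mouth\config.yaml` (Windows). The path is resolved with `dirs::config_dir()`, so on macOS it is `~/Library/Application Support/mouth/config.yaml`. Models cached via HuggingFace Hub standard cache (`~/.cache/huggingface/hub/`).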
## Cross-Compilation
Developing on Ubuntu 24.04, targeting Windows:
```bash
cargo build --target x86_64-pc-windows-gnu
```
## System Dependencies (Ubuntu)
```bash
sudo apt-get install libssl-dev libasound2-dev libpulse-dev libx11-dev libxcb-shape0-dev libxcb-xfixes0-dev libxkbcommon-dev libwayland-dev libgtk-3-dev libxtst-dev libxdo-dev cmake
```
Generated
+4950
File diff suppressed because it is too large
+61
@@ -0,0 +1,61 @@
[package]
name = "mouth"
version = "0.1.0"
edition = "2024"
description = "Offline speech-to-text with global hotkey and paste"
[dependencies]
# CLI
clap = { version = "4", features = ["derive"] }
# Config
serde = { version = "1", features = ["derive"] }
serde_yaml = "0.9"
dirs = "6"
# Interactive config TUI
dialoguer = "0.11"
# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
# Async
tokio = { version = "1", features = ["full"] }
# Global hotkey
rdev = "0.5"
# Audio capture
cpal = "0.15"
# Audio resampling
rubato = "0.16"
# ONNX inference (Parakeet v3 + Silero VAD)
ort = { version = "2.0.0-rc.12", features = ["download-binaries"] }
ndarray = "0.17"
# Model download from HuggingFace
hf-hub = "0.4"
indicatif = "0.17"
# Clipboard
arboard = "3"
# Keyboard simulation
enigo = { version = "0.3", features = ["serde"] }
# Overlay window
winit = "0.30"
softbuffer = "0.4"
# Audio feedback
rodio = "0.20"
# System info
num_cpus = "1"
# Error handling
anyhow = "1"
thiserror = "2"
+42
@@ -0,0 +1,42 @@
# Mouth configuration
# Copy to ~/.config/mouth/config.yaml (Linux), ~/Library/Application Support/mouth/config.yaml (macOS),
# or %APPDATA%\mouth\config.yaml (Windows)
# Hotkey to activate recording
hotkey: ctrl+space
# Recording mode: push_to_talk or toggle
mode: push_to_talk
# Cancel hotkey (only active while recording)
cancel_key: escape
# Speech-to-text model
model: parakeet-tdt-0.6b-v3
# Inference accelerator: auto, cpu, cuda, directml
accelerator: auto
# GPU device index (when accelerator is cuda/directml)
gpu_device: 0
# How to paste text: ctrl_v, shift_insert, ctrl_shift_v, clipboard_only
paste_method: ctrl_v
# Keep transcribed text on clipboard after pasting
copy_to_clipboard: true
# Overlay position: top, bottom, none
overlay_position: top
# Play audio feedback sounds
audio_feedback: true
# Audio input device name (null = system default)
input_device: null
# Voice activity detection (trim silence)
vad_enabled: true
# Language hint for model
language: en
+287
@@ -0,0 +1,287 @@
# Mouth — Implementation Plan
## Overview
Mouth is a single-binary, offline speech-to-text tool for Windows (with Linux/macOS support where possible). Press a hotkey, speak, and transcribed text is pasted at your cursor. Configured entirely via YAML.
## Architecture
```
┌─────────────┐     ┌───────────┐     ┌─────────────┐     ┌────────────┐
│   Hotkey    │────▶│ Recorder  │────▶│ Transcriber │────▶│   Paste    │
│  Listener   │     │  (cpal)   │     │ (ort/ONNX)  │     │  (enigo)   │
│   (rdev)    │     │           │     │             │     │            │
└─────────────┘     └───────────┘     └─────────────┘     └────────────┘
       │                  │                  │                  │
       │                  ▼                  │                  │
       │            ┌───────────┐            │                  │
       │            │    VAD    │            │                  │
       │            │ (silero)  │            │                  │
       │            └───────────┘            │                  │
       │                                     │                  │
       ▼                                     ▼                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│                           Overlay (winit)                            │
│            State: idle → recording → transcribing → done             │
└──────────────────────────────────────────────────────────────────────┘
```
### Component Communication
All components communicate via channels (`std::sync::mpsc` or `tokio::sync`). The main thread owns the overlay window (required by most windowing systems). A coordinator task receives events from hotkey/recorder/transcriber and drives state transitions.
```
HotkeyEvent(Pressed/Released) ──┐
AudioReady(Vec<f32>) ───────────┼──▶ Coordinator ──▶ OverlayState
TranscriptionDone(String) ──────┘                ──▶ PasteAction
CancelRequested ────────────────┘
```
## Crate Dependencies
| Crate | Purpose | Notes |
|-------|---------|-------|
| `rdev` | Global hotkey capture | Cross-platform key events, no focus required |
| `cpal` | Audio capture | Cross-platform mic input |
| `rubato` | Audio resampling | Resample to 16kHz for Parakeet |
| `ort` | ONNX Runtime | Run Parakeet v3 + Silero VAD |
| `hf-hub` | Model download | Download from HuggingFace, standard cache dir |
| `enigo` | Keyboard simulation | Simulate Ctrl+V, Shift+Insert, etc. |
| `arboard` | Clipboard access | Read/write clipboard, save/restore |
| `winit` | Windowing | Minimal overlay window |
| `softbuffer` | Pixel rendering | Draw coloured overlay (no GPU needed for overlay) |
| `serde` + `serde_yaml` | Config | Deserialize YAML config |
| `clap` | CLI | Subcommands: `run`, `config`, `models`, `status` |
| `dialoguer` | Interactive TUI | `mouth config` interactive setup |
| `rodio` | Audio playback | Blip up/down sounds |
| `indicatif` | Progress bars | Model download progress |
| `dirs` | Platform dirs | Config/cache paths |
| `tracing` | Logging | Structured logging |
## Config File
Location: `~/.config/mouth/config.yaml` (Linux), `~/Library/Application Support/mouth/config.yaml` (macOS, as resolved by `dirs::config_dir()`), `%APPDATA%\mouth\config.yaml` (Windows)
```yaml
# Hotkey to activate recording
hotkey: "ctrl+space"
# Recording mode: push_to_talk or toggle
mode: push_to_talk
# Cancel hotkey (only active while recording)
cancel_key: "escape"
# Speech-to-text model
model: "parakeet-tdt-0.6b-v3"
# Inference accelerator: auto, cpu, cuda, directml
accelerator: auto
# GPU device index (only used when accelerator is cuda/directml)
gpu_device: 0
# How to paste text
paste_method: ctrl_v # ctrl_v | shift_insert | ctrl_shift_v | clipboard_only
# Also keep transcribed text on clipboard after pasting
copy_to_clipboard: true
# Overlay position on screen
overlay_position: top # top | bottom | none
# Audio feedback
audio_feedback: true
# Audio input device (null = system default)
input_device: null
# VAD: trim silence from audio before transcription
vad_enabled: true
# Language (for model hint, if supported)
language: en
```
## CLI Interface
```
mouth run # Start the daemon (default if no subcommand)
mouth config # Interactive TUI to edit config
mouth config --show # Print current config to stdout
mouth config --reset # Reset config to defaults
mouth models # List available/downloaded models
mouth models download # Download configured model (if not cached)
mouth status # Show daemon status, loaded model, app version
```
## Implementation Phases
### Phase 1: Project Skeleton + Config
- Cargo.toml with all dependencies
- Config struct with serde, defaults, load/save
- CLI with clap (run, config, models subcommands)
- `mouth config` interactive TUI with dialoguer
- Platform-aware config/cache directory resolution
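The platform-aware path resolution in the last item can be sketched without any crate. This is only an approximation of what `dirs::config_dir()` does, written so it can be tested on any host: the `Os` enum, the `config_path` name, and the injected `env` lookup are hypothetical, and the real crate uses platform APIs rather than `APPDATA` on Windows.

```rust
use std::path::PathBuf;

/// Target platform, passed explicitly so every branch is testable anywhere.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Os {
    Linux,
    MacOs,
    Windows,
}

/// Resolve `<config dir>/mouth/config.yaml`.
/// `env` is an injected environment lookup (hypothetical helper); pass
/// `|k| std::env::var(k).ok()` to read the real environment.
pub fn config_path(os: Os, env: impl Fn(&str) -> Option<String>) -> Option<PathBuf> {
    let base = match os {
        // Roaming application data on Windows.
        Os::Windows => PathBuf::from(env("APPDATA")?),
        // macOS keeps per-user config under Application Support.
        Os::MacOs => PathBuf::from(env("HOME")?)
            .join("Library")
            .join("Application Support"),
        // XDG base directory, falling back to ~/.config.
        Os::Linux => match env("XDG_CONFIG_HOME").filter(|v| !v.is_empty()) {
            Some(xdg) => PathBuf::from(xdg),
            None => PathBuf::from(env("HOME")?).join(".config"),
        },
    };
    Some(base.join("mouth").join("config.yaml"))
}
```

Returning `None` when a required variable is missing mirrors the `Option` that `dirs::config_dir()` returns, leaving the caller to turn it into a user-facing error.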
### Phase 2: Hotkey Listener
- Global hotkey capture using rdev
- Support configurable key combinations (parse from string like "ctrl+space")
- Push-to-talk mode: record on press, stop on release
- Toggle mode: start on first press, stop on second press
- Cancel on Escape while recording
- Debounce rapid key events (~30ms)
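A rough sketch of the two less obvious items in this phase, parsing a string such as "ctrl+space" and debouncing key events, might look like the following. The `HotkeyCombo` struct, the accepted modifier aliases, and the `Debouncer` type are assumptions for illustration; the repository's own `hotkey::parse_hotkey` maps onto rdev key types and may differ.

```rust
/// Parsed hotkey: modifier flags plus exactly one non-modifier key name.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct HotkeyCombo {
    pub ctrl: bool,
    pub shift: bool,
    pub alt: bool,
    pub meta: bool,
    pub key: String,
}

/// Parse a spec such as "ctrl+space" (case-insensitive, '+'-separated).
pub fn parse_hotkey_spec(spec: &str) -> Result<HotkeyCombo, String> {
    let mut combo = HotkeyCombo::default();
    let mut key: Option<String> = None;
    for part in spec.split('+') {
        let token = part.trim().to_ascii_lowercase();
        match token.as_str() {
            "" => return Err(format!("empty key name in hotkey '{spec}'")),
            "ctrl" | "control" => combo.ctrl = true,
            "shift" => combo.shift = true,
            "alt" => combo.alt = true,
            "meta" | "super" | "win" | "cmd" => combo.meta = true,
            other => {
                if key.is_some() {
                    return Err(format!("more than one non-modifier key in '{spec}'"));
                }
                key = Some(other.to_string());
            }
        }
    }
    combo.key = key.ok_or_else(|| format!("hotkey '{spec}' has no non-modifier key"))?;
    Ok(combo)
}

/// Drops events that arrive within `window_ms` of the last accepted one.
pub struct Debouncer {
    window_ms: u64,
    last_accepted_ms: Option<u64>,
}

impl Debouncer {
    pub fn new(window_ms: u64) -> Self {
        Self { window_ms, last_accepted_ms: None }
    }

    /// Returns true if an event at `now_ms` should be processed.
    pub fn accept(&mut self, now_ms: u64) -> bool {
        match self.last_accepted_ms {
            Some(last) if now_ms.saturating_sub(last) < self.window_ms => false,
            _ => {
                self.last_accepted_ms = Some(now_ms);
                true
            }
        }
    }
}
```

Taking the timestamp as a parameter rather than calling `Instant::now()` inside `accept` keeps the debounce logic deterministic under test.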
### Phase 3: Audio Capture + VAD
- Open mic input via cpal (default device or configured)
- Convert to f32 mono
- Resample to 16kHz via rubato
- Buffer audio chunks during recording
- Run Silero VAD to trim leading/trailing silence
- Produce final `Vec<f32>` of clean speech at 16kHz
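The mono conversion and silence trimming steps can be illustrated as below. Note the loud assumption: the trimming here uses a plain RMS energy threshold as a stand-in for Silero VAD, which is a neural model run through ONNX; only the "find first and last speech frame, keep what lies between" shape carries over. Function names and parameters are hypothetical.

```rust
/// Average interleaved multi-channel samples down to mono.
/// Trailing samples that do not fill a whole frame are dropped.
pub fn downmix_to_mono(interleaved: &[f32], channels: usize) -> Vec<f32> {
    assert!(channels > 0, "channel count must be non-zero");
    interleaved
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect()
}

/// Trim leading and trailing silence.
/// Stand-in for a real VAD: a frame counts as speech when its RMS level
/// reaches `threshold`. Returns an empty slice if no frame qualifies.
pub fn trim_silence(samples: &[f32], frame_len: usize, threshold: f32) -> &[f32] {
    assert!(frame_len > 0, "frame length must be non-zero");
    let is_speech = |frame: &[f32]| {
        let mean_sq = frame.iter().map(|s| s * s).sum::<f32>() / frame.len() as f32;
        mean_sq.sqrt() >= threshold
    };
    let frames: Vec<&[f32]> = samples.chunks(frame_len).collect();
    let first = frames.iter().position(|&f| is_speech(f));
    let last = frames.iter().rposition(|&f| is_speech(f));
    match (first, last) {
        (Some(a), Some(b)) => {
            let start = a * frame_len;
            let end = ((b + 1) * frame_len).min(samples.len());
            &samples[start..end]
        }
        _ => &samples[..0],
    }
}
```

Silence between two speech frames is kept on purpose, so pauses inside an utterance survive and only the edges are trimmed.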
### Phase 4: Model Management
- Use hf-hub to download Parakeet v3 ONNX model from HuggingFace
- Store in standard HF cache (`~/.cache/huggingface/hub/`)
- Show download progress with indicatif
- `mouth models` command to list/download models
- Auto-download on first run if model not cached
### Phase 5: Transcription
- Load Parakeet v3 ONNX model via ort
- Auto-detect GPU (DirectML on Windows, CUDA if available, CPU fallback)
- Respect accelerator override from config
- Run inference on captured audio
- Return transcribed text string
### Phase 6: Overlay
- Create a small always-on-top window using winit
- Render with softbuffer (simple coloured rectangle + text)
- States and colours:
  - Recording: red pulsing indicator
  - Transcribing: amber/yellow
  - Done: brief green flash, then hide
  - Error: brief red flash with error hint
- Window flags (Windows): `WS_EX_TOPMOST | WS_EX_TOOLWINDOW | WS_EX_NOACTIVATE`
- Position: centered horizontally at top or bottom of current monitor
- No focus steal, no taskbar entry
### Phase 7: Paste System
- Save current clipboard content (if preserving)
- Set transcribed text to clipboard via arboard
- Simulate keypress via enigo based on paste_method:
  - `ctrl_v`: Ctrl+V (Cmd+V on macOS)
  - `shift_insert`: Shift+Insert
  - `ctrl_shift_v`: Ctrl+Shift+V
  - `clipboard_only`: no keypress, just clipboard
- Restore previous clipboard content (unless copy_to_clipboard is true)
- Small delay between clipboard set and paste simulation (~50ms)
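The save, set, paste, restore sequence above can be sketched against a small trait so the ordering is testable without a real clipboard. Everything here is hypothetical scaffolding: the `Clipboard` trait stands in for arboard, `send_keys` stands in for enigo, the key names are plain strings, and skipping the restore for `clipboard_only` is an assumption about the intended behaviour.

```rust
/// Minimal clipboard abstraction (hypothetical; arboard would sit behind it).
pub trait Clipboard {
    fn get_text(&mut self) -> Option<String>;
    fn set_text(&mut self, text: &str);
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PasteMethod {
    CtrlV,
    ShiftInsert,
    CtrlShiftV,
    ClipboardOnly,
}

/// Keys to press for a paste method; empty means "no keypress".
pub fn paste_chord(method: PasteMethod, macos: bool) -> Vec<&'static str> {
    match method {
        PasteMethod::CtrlV if macos => vec!["cmd", "v"],
        PasteMethod::CtrlV => vec!["ctrl", "v"],
        PasteMethod::ShiftInsert => vec!["shift", "insert"],
        PasteMethod::CtrlShiftV => vec!["ctrl", "shift", "v"],
        PasteMethod::ClipboardOnly => Vec::new(),
    }
}

/// Save the clipboard, set the text, send the chord, then restore.
pub fn paste_text(
    clipboard: &mut impl Clipboard,
    text: &str,
    method: PasteMethod,
    copy_to_clipboard: bool,
    macos: bool,
    mut send_keys: impl FnMut(&[&str]),
) {
    // Only remember the old contents if they will be restored afterwards.
    // Assumption: clipboard_only never restores, or the text would be lost.
    let keep = copy_to_clipboard || method == PasteMethod::ClipboardOnly;
    let previous = if keep { None } else { clipboard.get_text() };

    clipboard.set_text(text);
    let chord = paste_chord(method, macos);
    if !chord.is_empty() {
        // A real implementation would wait ~50ms here before pasting, and
        // again before restoring, so the target app can read the clipboard.
        send_keys(chord.as_slice());
    }
    if let Some(old) = previous {
        clipboard.set_text(&old);
    }
}

/// In-memory clipboard for tests and examples.
pub struct FakeClipboard(pub Option<String>);

impl Clipboard for FakeClipboard {
    fn get_text(&mut self) -> Option<String> {
        self.0.clone()
    }
    fn set_text(&mut self, text: &str) {
        self.0 = Some(text.to_string());
    }
}
```

Passing `macos` and `send_keys` in as arguments is a testing convenience; production code would more likely decide the platform with `cfg!(target_os = "macos")`.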
### Phase 8: Audio Feedback
- Bundle two short PCM blip sounds in the binary (via `include_bytes!`)
- "Blip up" on recording start
- "Blip down" on recording stop / transcription complete
- Play via rodio on a separate thread (non-blocking)
- Respect audio_feedback config flag
### Phase 9: Coordinator + Integration
- Wire all components together with channel-based message passing
- Main thread: overlay window event loop (winit requires this)
- Spawned threads/tasks: hotkey listener, audio recorder, transcriber
- Coordinator receives events, drives state machine:
```
Idle ──[hotkey press]──▶ Recording
Recording ──[hotkey release/press]──▶ Transcribing
Recording ──[cancel]──▶ Idle
Transcribing ──[result]──▶ Pasting ──▶ Idle
Transcribing ──[error]──▶ Error ──▶ Idle
```
- Graceful shutdown on SIGINT / tray quit
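The transitions in the diagram above can be written as a pure function, which makes the table easy to unit test separately from channels and threads. This is a sketch under assumptions: the `Event` and `Mode` enums and the `Done` event that returns `Pasting` and `Error` to `Idle` are invented for illustration and do not mirror the repository's `coordinator.rs`.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    Idle,
    Recording,
    Transcribing,
    Pasting,
    Error,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Mode {
    PushToTalk,
    Toggle,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    HotkeyPressed,
    HotkeyReleased,
    Cancel,
    TranscriptionOk,
    TranscriptionFailed,
    /// Paste finished or error display timed out (hypothetical event).
    Done,
}

/// Pure transition function: unknown (state, event) pairs leave the state unchanged.
pub fn next_state(state: State, mode: Mode, event: Event) -> State {
    match (state, event) {
        (State::Idle, Event::HotkeyPressed) => State::Recording,
        // Push-to-talk stops on release, toggle stops on the second press.
        (State::Recording, Event::HotkeyReleased) if mode == Mode::PushToTalk => State::Transcribing,
        (State::Recording, Event::HotkeyPressed) if mode == Mode::Toggle => State::Transcribing,
        (State::Recording, Event::Cancel) => State::Idle,
        (State::Transcribing, Event::TranscriptionOk) => State::Pasting,
        (State::Transcribing, Event::TranscriptionFailed) => State::Error,
        (State::Pasting, Event::Done) | (State::Error, Event::Done) => State::Idle,
        (unchanged, _) => unchanged,
    }
}
```

Ignoring unmatched pairs means, for example, that a cancel key pressed while transcribing has no effect, which matches the diagram's lack of a cancel edge out of `Transcribing`.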
### Phase 10: Daemon IPC + Status
- The running daemon listens on a local Unix domain socket (Linux/macOS) or named pipe (Windows) for status queries
- Socket/pipe path: `/tmp/mouth.sock` (Linux/macOS), `\\.\pipe\mouth` (Windows)
- `mouth status` connects and requests current state; daemon responds with JSON:
```json
{
  "version": "0.1.0",
  "state": "idle",
  "model": "parakeet-tdt-0.6b-v3",
  "accelerator": "directml",
  "uptime_secs": 3420
}
```
- If the daemon is not running, `mouth status` reports "Mouth is not running" and exits with code 1
- Also used internally to prevent launching a second daemon instance (lock check)
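As a sketch of the response side of this protocol, the status payload can be produced with the standard library alone. The `DaemonStatus` struct and the hand-rolled escaping are assumptions for illustration; a real implementation would more likely derive `serde::Serialize` and use a JSON crate, and the socket or pipe transport is not shown.

```rust
/// Snapshot of daemon state returned to `mouth status` (hypothetical type).
pub struct DaemonStatus<'a> {
    pub version: &'a str,
    pub state: &'a str,
    pub model: &'a str,
    pub accelerator: &'a str,
    pub uptime_secs: u64,
}

/// Escape a string for use inside a JSON string literal.
fn escape_json(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            // Control characters must be written as \u00XX escapes.
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out
}

impl DaemonStatus<'_> {
    /// Serialise to a compact single-line JSON object.
    pub fn to_json(&self) -> String {
        format!(
            "{{\"version\":\"{}\",\"state\":\"{}\",\"model\":\"{}\",\"accelerator\":\"{}\",\"uptime_secs\":{}}}",
            escape_json(self.version),
            escape_json(self.state),
            escape_json(self.model),
            escape_json(self.accelerator),
            self.uptime_secs
        )
    }
}
```

Emitting one object per line keeps the client simple: `mouth status` can read until the first newline and parse exactly one message.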
### Phase 11: Polish + Distribution
- Error handling: user-friendly messages for common failures (no mic, model not found, etc.)
- Windows installer via `cargo-wix` or distribute as standalone .exe
- Test on Windows 10/11 primarily
- Test on Linux (X11 + Wayland) and macOS as secondary
- Update CLAUDE.md with build/run/test instructions
- Write user-facing README with setup instructions
## Risks & Mitigations
| Risk | Impact | Mitigation |
|------|--------|------------|
| Parakeet v3 ONNX model compatibility with `ort` | Blocks core functionality | Test early in Phase 5; Parakeet v2 as fallback |
| `rdev` hotkey reliability on Windows | Broken UX | Test early in Phase 2; fallback to Win32 `RegisterHotKey` |
| Overlay focus stealing | Annoying | Use proper window flags; test with various foreground apps |
| Audio resampling quality | Poor transcription | Use rubato SincInterpolation (high quality) |
| Binary size with bundled ONNX Runtime | Large download | ONNX Runtime is ~20-40MB; acceptable for a single-binary tool |
| winit event loop blocking | Unresponsive | All heavy work on background threads; overlay is lightweight |
## File Structure
```
mouth/
├── Cargo.toml
├── CLAUDE.md
├── README.md
├── plan.md
├── config.yaml.example
├── resources/
│   ├── blip_up.pcm          # bundled audio feedback
│   └── blip_down.pcm
└── src/
    ├── main.rs              # CLI entry, clap setup
    ├── config.rs            # Config struct, YAML load/save, defaults
    ├── hotkey.rs            # Global hotkey listener (rdev)
    ├── recorder.rs          # Audio capture (cpal + rubato + VAD)
    ├── vad.rs               # Silero VAD wrapper
    ├── transcriber.rs       # ONNX inference, model loading, GPU detection
    ├── model_cache.rs       # HuggingFace download, cache management
    ├── overlay.rs           # Minimal overlay window (winit + softbuffer)
    ├── paste.rs             # Clipboard + paste simulation
    ├── audio_feedback.rs    # Blip sounds via rodio
    ├── coordinator.rs       # State machine, channel hub
    └── cli/
        ├── mod.rs
        ├── run.rs           # `mouth run` handler
        ├── config_cmd.rs    # `mouth config` TUI
        ├── models_cmd.rs    # `mouth models` handler
        └── status_cmd.rs    # `mouth status` handler
## Not In Scope (v1)
- LLM post-processing of transcriptions
- Transcription history / database
- Multiple model support (v1 is Parakeet v3 only, architecture supports adding more later)
- Auto-submit (Enter after paste)
- Multi-language UI
- Tray icon / system tray integration
- Translate-to-English mode
+103
@@ -0,0 +1,103 @@
use anyhow::Result;
use rodio::{OutputStream, Sink};
use std::io::Cursor;
use std::thread;
use std::time::Duration;
use tracing::{debug, warn};

/// Generate a simple sine wave blip sound.
fn generate_blip(freq_start: f32, freq_end: f32, duration_ms: u64) -> Vec<i16> {
    let sample_rate = 44100u32;
    let num_samples = (sample_rate as u64 * duration_ms / 1000) as usize;
    let mut samples = Vec::with_capacity(num_samples);
    for i in 0..num_samples {
        let t = i as f32 / sample_rate as f32;
        let progress = i as f32 / num_samples as f32;
        // Linear frequency sweep.
        // Note: the phase below is freq * t, not the integral of freq over
        // time, so the instantaneous pitch moves about twice as far as
        // freq_end suggests. Harmless for a short blip.
        let freq = freq_start + (freq_end - freq_start) * progress;
        // Sine wave with envelope (fade in/out)
        let envelope = if progress < 0.1 {
            progress / 0.1
        } else if progress > 0.8 {
            (1.0 - progress) / 0.2
        } else {
            1.0
        };
        let sample = (envelope * 0.3 * (2.0 * std::f32::consts::PI * freq * t).sin()) * i16::MAX as f32;
        samples.push(sample as i16);
    }
    samples
}

/// Encode samples as a WAV file in memory (16-bit mono PCM).
fn encode_wav(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    let mut buf = Vec::new();
    let data_len = (samples.len() * 2) as u32;
    let file_len = 36 + data_len;
    // RIFF header
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&file_len.to_le_bytes());
    buf.extend_from_slice(b"WAVE");
    // fmt chunk
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM
    buf.extend_from_slice(&1u16.to_le_bytes()); // mono
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
    buf.extend_from_slice(&2u16.to_le_bytes()); // block align
    buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
    // data chunk
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_len.to_le_bytes());
    for sample in samples {
        buf.extend_from_slice(&sample.to_le_bytes());
    }
    buf
}

/// Play the "blip up" sound (recording started).
pub fn play_blip_up() {
    thread::spawn(|| {
        if let Err(e) = play_blip_internal(800.0, 1200.0, 100) {
            warn!("Failed to play blip up: {e}");
        }
    });
}

/// Play the "blip down" sound (recording stopped / transcription done).
pub fn play_blip_down() {
    thread::spawn(|| {
        if let Err(e) = play_blip_internal(1200.0, 800.0, 100) {
            warn!("Failed to play blip down: {e}");
        }
    });
}

fn play_blip_internal(freq_start: f32, freq_end: f32, duration_ms: u64) -> Result<()> {
    let samples = generate_blip(freq_start, freq_end, duration_ms);
    let wav_data = encode_wav(&samples, 44100);
    let (_stream, stream_handle) = OutputStream::try_default()?;
    let sink = Sink::try_new(&stream_handle)?;
    let cursor = Cursor::new(wav_data);
    let source = rodio::Decoder::new(cursor)?;
    sink.append(source);
    debug!("Playing blip ({freq_start}Hz -> {freq_end}Hz, {duration_ms}ms)");
    sink.sleep_until_end();
    // Keep stream alive briefly
    thread::sleep(Duration::from_millis(50));
    Ok(())
}
+127
@@ -0,0 +1,127 @@
use anyhow::Result;
use dialoguer::{Input, Select};

use crate::config::{Accelerator, Config, OverlayPosition, PasteMethod, RecordingMode};

pub fn show() -> Result<()> {
    let config = Config::load()?;
    let yaml = serde_yaml::to_string(&config)?;
    println!("{yaml}");
    Ok(())
}

pub fn reset() -> Result<()> {
    let config = Config::default();
    config.save()?;
    println!("Config reset to defaults at {}", Config::path()?.display());
    Ok(())
}

pub fn interactive() -> Result<()> {
    let mut config = Config::load()?;

    config.hotkey = Input::new()
        .with_prompt("Hotkey")
        .default(config.hotkey)
        .interact_text()?;

    let mode_idx = Select::new()
        .with_prompt("Recording mode")
        .items(&["push_to_talk", "toggle"])
        .default(match config.mode {
            RecordingMode::PushToTalk => 0,
            RecordingMode::Toggle => 1,
        })
        .interact()?;
    config.mode = match mode_idx {
        0 => RecordingMode::PushToTalk,
        _ => RecordingMode::Toggle,
    };

    config.cancel_key = Input::new()
        .with_prompt("Cancel key")
        .default(config.cancel_key)
        .interact_text()?;

    config.model = Input::new()
        .with_prompt("Model")
        .default(config.model)
        .interact_text()?;

    let accel_idx = Select::new()
        .with_prompt("Accelerator")
        .items(&["auto", "cpu", "cuda", "directml"])
        .default(match config.accelerator {
            Accelerator::Auto => 0,
            Accelerator::Cpu => 1,
            Accelerator::Cuda => 2,
            Accelerator::DirectMl => 3,
        })
        .interact()?;
    config.accelerator = match accel_idx {
        0 => Accelerator::Auto,
        1 => Accelerator::Cpu,
        2 => Accelerator::Cuda,
        _ => Accelerator::DirectMl,
    };

    config.gpu_device = Input::new()
        .with_prompt("GPU device index")
        .default(config.gpu_device)
        .interact_text()?;

    let paste_idx = Select::new()
        .with_prompt("Paste method")
        .items(&["ctrl_v", "shift_insert", "ctrl_shift_v", "clipboard_only"])
        .default(match config.paste_method {
            PasteMethod::CtrlV => 0,
            PasteMethod::ShiftInsert => 1,
            PasteMethod::CtrlShiftV => 2,
            PasteMethod::ClipboardOnly => 3,
        })
        .interact()?;
    config.paste_method = match paste_idx {
        0 => PasteMethod::CtrlV,
        1 => PasteMethod::ShiftInsert,
        2 => PasteMethod::CtrlShiftV,
        _ => PasteMethod::ClipboardOnly,
    };

    let overlay_idx = Select::new()
        .with_prompt("Overlay position")
        .items(&["top", "bottom", "none"])
        .default(match config.overlay_position {
            OverlayPosition::Top => 0,
            OverlayPosition::Bottom => 1,
            OverlayPosition::None => 2,
        })
        .interact()?;
    config.overlay_position = match overlay_idx {
        0 => OverlayPosition::Top,
        1 => OverlayPosition::Bottom,
        _ => OverlayPosition::None,
    };

    let feedback_idx = Select::new()
        .with_prompt("Audio feedback")
        .items(&["yes", "no"])
        .default(if config.audio_feedback { 0 } else { 1 })
        .interact()?;
    config.audio_feedback = feedback_idx == 0;

    let vad_idx = Select::new()
        .with_prompt("VAD (voice activity detection)")
        .items(&["enabled", "disabled"])
        .default(if config.vad_enabled { 0 } else { 1 })
        .interact()?;
    config.vad_enabled = vad_idx == 0;

    config.language = Input::new()
        .with_prompt("Language")
        .default(config.language)
        .interact_text()?;

    config.save()?;
    println!("\nConfig saved to {}", Config::path()?.display());
    Ok(())
}
+4
@@ -0,0 +1,4 @@
pub mod config_cmd;
pub mod models_cmd;
pub mod run_cmd;
pub mod status_cmd;
+30
@@ -0,0 +1,30 @@
use anyhow::Result;

use crate::config::Config;
use crate::model_cache;

pub fn list() -> Result<()> {
    let config = Config::load()?;
    println!("Configured model: {}", config.model);
    println!();
    println!("Available models:");
    for (name, cached) in model_cache::list_models() {
        let status = if cached { "downloaded" } else { "not downloaded" };
        let marker = if name == config.model { " (active)" } else { "" };
        println!(" {name}{marker} [{status}]");
    }
    Ok(())
}

pub fn download() -> Result<()> {
    let config = Config::load()?;
    println!("Downloading model: {}...", config.model);
    let paths = model_cache::ensure_model(&config.model)?;
    println!("Model ready:");
    println!(" Encoder: {}", paths.encoder.display());
    println!(" Decoder: {}", paths.decoder.display());
    println!(" Vocab: {}", paths.vocab.display());
    Ok(())
}
+116
@@ -0,0 +1,116 @@
use anyhow::{Context, Result};
use std::sync::mpsc;
use std::thread;
use tracing::info;

use crate::config::{Config, OverlayPosition};
use crate::coordinator::Coordinator;
use crate::hotkey;
use crate::model_cache;
use crate::overlay;
use crate::recorder;
use crate::transcriber::Transcriber;

pub fn run() -> Result<()> {
    let config = Config::load()?;
    info!("Mouth v{} starting", env!("CARGO_PKG_VERSION"));
    info!("Mode: {:?}", config.mode);
    info!("Hotkey: {}", config.hotkey);
    info!("Model: {}", config.model);
    info!("Accelerator: {:?}", config.accelerator);
    info!("Paste method: {:?}", config.paste_method);

    // Step 1: Ensure model is downloaded
    info!("Checking model...");
    let model_paths = model_cache::ensure_model(&config.model)
        .context("Failed to ensure model is available")?;

    // Step 2: Load transcriber
    info!("Loading transcription engine...");
    let transcriber = Transcriber::new(&model_paths, &config.accelerator, config.gpu_device)
        .context("Failed to load transcription engine")?;

    // Step 3: VAD (not yet bundled)
    let vad = if config.vad_enabled {
        info!("VAD enabled but Silero model not yet bundled; skipping");
        None
    } else {
        None
    };

    // Step 4: Parse hotkeys
    let hotkey_combo = hotkey::parse_hotkey(&config.hotkey)
        .with_context(|| format!("Invalid hotkey: {}", config.hotkey))?;
    let cancel_combo = hotkey::parse_hotkey(&config.cancel_key)
        .with_context(|| format!("Invalid cancel key: {}", config.cancel_key))?;

    // Step 5: Set up channels
    let (hotkey_tx, hotkey_rx) = mpsc::channel();
    let (recorder_cmd_tx, recorder_cmd_rx) = mpsc::channel();
    let (audio_tx, audio_rx) = mpsc::channel();

    // Step 6: Spawn background threads
    let device_name = config.input_device.clone();
    thread::Builder::new()
        .name("mouth-recorder".into())
        .spawn(move || {
            recorder::run(device_name, recorder_cmd_rx, audio_tx);
        })
        .context("Failed to spawn recorder thread")?;
    thread::Builder::new()
        .name("mouth-hotkey".into())
        .spawn(move || {
            hotkey::listen(hotkey_combo, cancel_combo, hotkey_tx);
        })
        .context("Failed to spawn hotkey thread")?;

    // Step 7: Start overlay + coordinator
    if config.overlay_position != OverlayPosition::None {
        let (event_loop, proxy) = overlay::create_event_loop()
            .map_err(|e| anyhow::anyhow!("Failed to create overlay event loop: {e}"))?;
        let overlay_position = config.overlay_position.clone();
        let coord_proxy = Some(proxy);

        // Coordinator runs on a background thread
        let coord_config = config.clone();
        thread::Builder::new()
            .name("mouth-coordinator".into())
            .spawn(move || {
                let mut coordinator = Coordinator::new(
                    coord_config,
                    transcriber,
                    vad,
                    recorder_cmd_tx,
                    audio_rx,
                    hotkey_rx,
                    coord_proxy,
                );
                coordinator.run();
            })
            .context("Failed to spawn coordinator thread")?;

        println!("Mouth is running. Press {} to record. Ctrl+C to quit.", config.hotkey);

        // Overlay event loop runs on main thread (blocking)
        overlay::run_event_loop(event_loop, overlay_position)
            .map_err(|e| anyhow::anyhow!("Overlay event loop error: {e}"))?;
    } else {
        // No overlay: coordinator runs on main thread
        println!("Mouth is running. Press {} to record. Ctrl+C to quit.", config.hotkey);
        let mut coordinator = Coordinator::new(
            config,
            transcriber,
            vad,
            recorder_cmd_tx,
            audio_rx,
            hotkey_rx,
            None,
        );
        coordinator.run();
    }
    Ok(())
}
+11
@@ -0,0 +1,11 @@
use anyhow::Result;

pub fn status() -> Result<()> {
    let version = env!("CARGO_PKG_VERSION");
    // TODO: Phase 10: connect to daemon IPC socket/pipe and query status
    // For now, just show version info
    println!("Mouth v{version}");
    println!("Status: not yet implemented (requires daemon IPC)");
    Ok(())
}
+184
@@ -0,0 +1,184 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RecordingMode {
    PushToTalk,
    Toggle,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum PasteMethod {
    CtrlV,
    ShiftInsert,
    CtrlShiftV,
    ClipboardOnly,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OverlayPosition {
    Top,
    Bottom,
    None,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Accelerator {
    Auto,
    Cpu,
    Cuda,
    DirectMl,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Hotkey to activate recording
    #[serde(default = "defaults::hotkey")]
    pub hotkey: String,
    /// Recording mode
    #[serde(default = "defaults::mode")]
    pub mode: RecordingMode,
    /// Cancel hotkey (only active while recording)
    #[serde(default = "defaults::cancel_key")]
    pub cancel_key: String,
    /// Speech-to-text model identifier
    #[serde(default = "defaults::model")]
    pub model: String,
    /// Inference accelerator
    #[serde(default = "defaults::accelerator")]
    pub accelerator: Accelerator,
    /// GPU device index (when accelerator is cuda/directml)
    #[serde(default)]
    pub gpu_device: u32,
    /// How to paste transcribed text
    #[serde(default = "defaults::paste_method")]
    pub paste_method: PasteMethod,
    /// Keep transcribed text on clipboard after pasting
    #[serde(default = "defaults::yes")]
    pub copy_to_clipboard: bool,
    /// Overlay position on screen
    #[serde(default = "defaults::overlay_position")]
    pub overlay_position: OverlayPosition,
    /// Play audio feedback sounds
    #[serde(default = "defaults::yes")]
    pub audio_feedback: bool,
    /// Audio input device name (null = system default)
    #[serde(default)]
    pub input_device: Option<String>,
    /// Enable VAD to trim silence
    #[serde(default = "defaults::yes")]
    pub vad_enabled: bool,
    /// Language hint for model
    #[serde(default = "defaults::language")]
    pub language: String,
}

mod defaults {
    use super::*;

    pub fn hotkey() -> String {
        "ctrl+space".into()
    }
    pub fn mode() -> RecordingMode {
        RecordingMode::PushToTalk
    }
    pub fn cancel_key() -> String {
        "escape".into()
    }
    pub fn model() -> String {
        "parakeet-tdt-0.6b-v3".into()
    }
    pub fn accelerator() -> Accelerator {
        Accelerator::Auto
    }
    pub fn paste_method() -> PasteMethod {
        PasteMethod::CtrlV
    }
    pub fn overlay_position() -> OverlayPosition {
        OverlayPosition::Top
    }
    pub fn yes() -> bool {
        true
    }
    pub fn language() -> String {
        "en".into()
    }
}

impl Default for Config {
    fn default() -> Self {
        Self {
            hotkey: defaults::hotkey(),
            mode: defaults::mode(),
            cancel_key: defaults::cancel_key(),
            model: defaults::model(),
            accelerator: defaults::accelerator(),
            gpu_device: 0,
            paste_method: defaults::paste_method(),
            copy_to_clipboard: true,
            overlay_position: defaults::overlay_position(),
            audio_feedback: true,
            input_device: None,
            vad_enabled: true,
            language: defaults::language(),
        }
    }
}

impl Config {
    /// Returns the platform-appropriate config directory.
    pub fn dir() -> Result<PathBuf> {
        let dir = dirs::config_dir()
            .context("Could not determine config directory")?
            .join("mouth");
        Ok(dir)
    }

    /// Returns the path to the config file.
    pub fn path() -> Result<PathBuf> {
        Ok(Self::dir()?.join("config.yaml"))
    }

    /// Load config from disk, falling back to defaults if file doesn't exist.
    pub fn load() -> Result<Self> {
        let path = Self::path()?;
        if !path.exists() {
            return Ok(Self::default());
        }
        let contents = std::fs::read_to_string(&path)
            .with_context(|| format!("Failed to read config from {}", path.display()))?;
        let config: Config = serde_yaml::from_str(&contents)
            .with_context(|| format!("Failed to parse config from {}", path.display()))?;
        Ok(config)
    }

    /// Save config to disk, creating the directory if needed.
    pub fn save(&self) -> Result<()> {
        let path = Self::path()?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("Failed to create config directory {}", parent.display()))?;
        }
        let yaml = serde_yaml::to_string(self).context("Failed to serialize config")?;
        std::fs::write(&path, yaml)
            .with_context(|| format!("Failed to write config to {}", path.display()))?;
        Ok(())
    }
}
+255
@@ -0,0 +1,255 @@
use std::sync::mpsc;
use std::thread;
use tracing::{debug, error, info, warn};
use winit::event_loop::EventLoopProxy;
use crate::audio_feedback;
use crate::config::{Config, RecordingMode};
use crate::hotkey::HotkeyEvent;
use crate::overlay::{OverlayEvent, OverlayState};
use crate::paste;
use crate::recorder::{AudioData, RecorderCommand};
use crate::transcriber::Transcriber;
use crate::vad::Vad;
/// The application state machine.
#[derive(Debug, Clone, Copy, PartialEq)]
enum State {
Idle,
Recording,
Transcribing,
}
/// Central coordinator that wires all components together.
pub struct Coordinator {
config: Config,
state: State,
transcriber: Transcriber,
vad: Option<Vad>,
recorder_tx: mpsc::Sender<RecorderCommand>,
audio_rx: mpsc::Receiver<AudioData>,
hotkey_rx: mpsc::Receiver<HotkeyEvent>,
overlay_proxy: Option<EventLoopProxy<OverlayEvent>>,
}
impl Coordinator {
pub fn new(
config: Config,
transcriber: Transcriber,
vad: Option<Vad>,
recorder_tx: mpsc::Sender<RecorderCommand>,
audio_rx: mpsc::Receiver<AudioData>,
hotkey_rx: mpsc::Receiver<HotkeyEvent>,
overlay_proxy: Option<EventLoopProxy<OverlayEvent>>,
) -> Self {
Self {
config,
state: State::Idle,
transcriber,
vad,
recorder_tx,
audio_rx,
hotkey_rx,
overlay_proxy,
}
}
/// Run the coordinator loop. This blocks until shutdown.
pub fn run(&mut self) {
info!("Coordinator started");
loop {
// Wait for hotkey events
let event = match self.hotkey_rx.recv() {
Ok(e) => e,
Err(_) => {
info!("Hotkey channel closed, shutting down");
break;
}
};
self.handle_event(event);
}
self.shutdown();
}
fn handle_event(&mut self, event: HotkeyEvent) {
debug!("Event: {:?}, State: {:?}", event, self.state);
match (self.state, event) {
// Start recording
(State::Idle, HotkeyEvent::Pressed) => {
self.start_recording();
}
// Push-to-talk: stop on release
(State::Recording, HotkeyEvent::Released) => {
if self.config.mode == RecordingMode::PushToTalk {
self.stop_recording();
}
}
// Toggle mode: stop on second press
(State::Recording, HotkeyEvent::Pressed) => {
if self.config.mode == RecordingMode::Toggle {
self.stop_recording();
}
}
// Cancel recording
(State::Recording, HotkeyEvent::Cancel) => {
self.cancel_recording();
}
// Ignore other combinations
_ => {
debug!("Ignoring event {:?} in state {:?}", event, self.state);
}
}
}
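fn start_recording(&mut self) {
// Discard audio left over from a cancelled or failed recording so it
// cannot be mistaken for the result of this one.
while self.audio_rx.try_recv().is_ok() {}
info!("Recording started");
self.state = State::Recording;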
self.set_overlay(OverlayState::Recording);
if self.config.audio_feedback {
audio_feedback::play_blip_up();
}
if self.recorder_tx.send(RecorderCommand::Start).is_err() {
error!("Failed to send start command to recorder");
self.state = State::Idle;
self.set_overlay(OverlayState::Hidden);
}
}
fn stop_recording(&mut self) {
info!("Recording stopped, starting transcription");
self.state = State::Transcribing;
self.set_overlay(OverlayState::Transcribing);
if self.config.audio_feedback {
audio_feedback::play_blip_down();
}
if self.recorder_tx.send(RecorderCommand::Stop).is_err() {
error!("Failed to send stop command to recorder");
self.state = State::Idle;
self.set_overlay(OverlayState::Hidden);
return;
}
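// Wait for audio data. The recorder sends nothing when the capture was
// empty or the stream failed to start, so bound the wait instead of
// blocking the coordinator forever.
match self.audio_rx.recv_timeout(std::time::Duration::from_secs(10)) {
Ok(audio) => {
self.process_audio(audio);
}
Err(e) => {
error!("Failed to receive audio data: {e}");
self.state = State::Idle;
self.set_overlay(OverlayState::Error);
self.delayed_hide_overlay();
}
}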
}
fn cancel_recording(&mut self) {
info!("Recording cancelled");
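self.state = State::Idle;
self.set_overlay(OverlayState::Hidden);
if self.recorder_tx.send(RecorderCommand::Stop).is_err() {
warn!("Failed to send stop command to recorder");
} else {
// The recorder delivers the captured audio only after it has handled
// Stop, so an immediate try_recv() would usually miss it. Wait briefly
// and discard whatever arrives.
let _ = self.audio_rx.recv_timeout(std::time::Duration::from_secs(2));
}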
}
fn process_audio(&mut self, audio: AudioData) {
let samples = if self.config.vad_enabled {
if let Some(vad) = &mut self.vad {
match vad.filter_speech(&audio.samples) {
Ok(filtered) => {
if filtered.is_empty() {
info!("No speech detected by VAD");
self.state = State::Idle;
self.set_overlay(OverlayState::Hidden);
return;
}
filtered
}
Err(e) => {
warn!("VAD failed, using raw audio: {e}");
audio.samples
}
}
} else {
audio.samples
}
} else {
audio.samples
};
// Transcribe
match self.transcriber.transcribe(&samples) {
Ok(text) => {
if text.is_empty() {
info!("Empty transcription");
self.state = State::Idle;
self.set_overlay(OverlayState::Hidden);
return;
}
info!("Transcribed: \"{text}\"");
self.set_overlay(OverlayState::Done);
// Paste the text
if let Err(e) = paste::paste_text(
&text,
&self.config.paste_method,
self.config.copy_to_clipboard,
) {
error!("Failed to paste text: {e}");
self.set_overlay(OverlayState::Error);
}
self.delayed_hide_overlay();
self.state = State::Idle;
}
Err(e) => {
error!("Transcription failed: {e}");
self.state = State::Idle;
self.set_overlay(OverlayState::Error);
self.delayed_hide_overlay();
}
}
}
fn set_overlay(&self, state: OverlayState) {
if let Some(proxy) = &self.overlay_proxy {
let _ = proxy.send_event(OverlayEvent::SetState(state));
}
}
fn delayed_hide_overlay(&self) {
if let Some(proxy) = &self.overlay_proxy {
let proxy = proxy.clone();
thread::spawn(move || {
thread::sleep(std::time::Duration::from_millis(500));
let _ = proxy.send_event(OverlayEvent::SetState(OverlayState::Hidden));
});
}
}
fn shutdown(&self) {
info!("Coordinator shutting down");
let _ = self.recorder_tx.send(RecorderCommand::Shutdown);
if let Some(proxy) = &self.overlay_proxy {
let _ = proxy.send_event(OverlayEvent::Shutdown);
}
}
}
@@ -0,0 +1,240 @@
use anyhow::{bail, Result};
use rdev::{self, Event, EventType, Key};
use std::sync::mpsc;
use std::time::{Duration, Instant};
use tracing::{debug, error, info};
/// Events sent from the hotkey listener to the coordinator.
#[derive(Debug, Clone, Copy)]
pub enum HotkeyEvent {
/// Hotkey was pressed (start recording or toggle)
Pressed,
/// Hotkey was released (stop recording in push-to-talk mode)
Released,
/// Cancel key was pressed
Cancel,
}
/// Parsed hotkey combination.
#[derive(Debug, Clone)]
pub struct HotkeyCombination {
pub modifiers: Vec<Modifier>,
pub key: Key,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Modifier {
Ctrl,
Alt,
Shift,
Meta,
}
/// Tracks which modifier keys are currently held down.
#[derive(Debug, Default)]
struct ModifierState {
ctrl: bool,
alt: bool,
shift: bool,
meta: bool,
}
impl ModifierState {
fn update(&mut self, key: &Key, pressed: bool) {
match key {
Key::ControlLeft | Key::ControlRight => self.ctrl = pressed,
Key::Alt | Key::AltGr => self.alt = pressed,
Key::ShiftLeft | Key::ShiftRight => self.shift = pressed,
Key::MetaLeft | Key::MetaRight => self.meta = pressed,
_ => {}
}
}
fn is_held(&self, modifier: &Modifier) -> bool {
match modifier {
Modifier::Ctrl => self.ctrl,
Modifier::Alt => self.alt,
Modifier::Shift => self.shift,
Modifier::Meta => self.meta,
}
}
fn all_held(&self, modifiers: &[Modifier]) -> bool {
modifiers.iter().all(|m| self.is_held(m))
}
}
/// Parse a hotkey string like "ctrl+space" into a HotkeyCombination.
pub fn parse_hotkey(s: &str) -> Result<HotkeyCombination> {
let lowered = s.to_lowercase();
// split() always yields at least one element, so test the input itself.
if lowered.trim().is_empty() {
bail!("Empty hotkey string");
}
let parts: Vec<&str> = lowered.split('+').map(|p| p.trim()).collect();
let mut modifiers = Vec::new();
let mut key = None;
for part in &parts {
match *part {
"ctrl" | "control" => modifiers.push(Modifier::Ctrl),
"alt" => modifiers.push(Modifier::Alt),
"shift" => modifiers.push(Modifier::Shift),
"meta" | "super" | "win" | "cmd" => modifiers.push(Modifier::Meta),
_ => {
if key.is_some() {
bail!("Multiple non-modifier keys in hotkey: {s}");
}
key = Some(parse_key(part)?);
}
}
}
// A combination made only of modifiers can never be matched by the
// listener, so reject it instead of returning a hotkey that never fires.
let Some(key) = key else {
bail!("Hotkey must include a non-modifier key: {s}");
};
Ok(HotkeyCombination { modifiers, key })
}
fn parse_key(s: &str) -> Result<Key> {
let key = match s {
"space" => Key::Space,
"enter" | "return" => Key::Return,
"escape" | "esc" => Key::Escape,
"tab" => Key::Tab,
"backspace" => Key::Backspace,
"delete" | "del" => Key::Delete,
"insert" => Key::Insert,
"home" => Key::Home,
"end" => Key::End,
"pageup" => Key::PageUp,
"pagedown" => Key::PageDown,
"up" => Key::UpArrow,
"down" => Key::DownArrow,
"left" => Key::LeftArrow,
"right" => Key::RightArrow,
"f1" => Key::F1,
"f2" => Key::F2,
"f3" => Key::F3,
"f4" => Key::F4,
"f5" => Key::F5,
"f6" => Key::F6,
"f7" => Key::F7,
"f8" => Key::F8,
"f9" => Key::F9,
"f10" => Key::F10,
"f11" => Key::F11,
"f12" => Key::F12,
"a" => Key::KeyA,
"b" => Key::KeyB,
"c" => Key::KeyC,
"d" => Key::KeyD,
"e" => Key::KeyE,
"f" => Key::KeyF,
"g" => Key::KeyG,
"h" => Key::KeyH,
"i" => Key::KeyI,
"j" => Key::KeyJ,
"k" => Key::KeyK,
"l" => Key::KeyL,
"m" => Key::KeyM,
"n" => Key::KeyN,
"o" => Key::KeyO,
"p" => Key::KeyP,
"q" => Key::KeyQ,
"r" => Key::KeyR,
"s" => Key::KeyS,
"t" => Key::KeyT,
"u" => Key::KeyU,
"v" => Key::KeyV,
"w" => Key::KeyW,
"x" => Key::KeyX,
"y" => Key::KeyY,
"z" => Key::KeyZ,
"0" => Key::Num0,
"1" => Key::Num1,
"2" => Key::Num2,
"3" => Key::Num3,
"4" => Key::Num4,
"5" => Key::Num5,
"6" => Key::Num6,
"7" => Key::Num7,
"8" => Key::Num8,
"9" => Key::Num9,
_ => bail!("Unknown key: {s}"),
};
Ok(key)
}
/// Start the global hotkey listener on the current thread (blocking).
/// Sends HotkeyEvents to the provided channel.
pub fn listen(
hotkey: HotkeyCombination,
cancel_key: HotkeyCombination,
tx: mpsc::Sender<HotkeyEvent>,
) {
let debounce_duration = Duration::from_millis(30);
let mut last_event_time = Instant::now() - debounce_duration;
let mut modifier_state = ModifierState::default();
let mut hotkey_held = false;
info!("Hotkey listener started");
debug!("Hotkey: {:?}", hotkey);
debug!("Cancel: {:?}", cancel_key);
let callback = move |event: Event| {
let now = Instant::now();
match event.event_type {
EventType::KeyPress(key) => {
modifier_state.update(&key, true);
// Check cancel key
if key == cancel_key.key && modifier_state.all_held(&cancel_key.modifiers) {
if now.duration_since(last_event_time) >= debounce_duration {
last_event_time = now;
debug!("Cancel key pressed");
if tx.send(HotkeyEvent::Cancel).is_err() {
error!("Failed to send cancel event");
}
}
return;
}
// Check hotkey
if key == hotkey.key && modifier_state.all_held(&hotkey.modifiers) {
if now.duration_since(last_event_time) >= debounce_duration && !hotkey_held {
last_event_time = now;
hotkey_held = true;
debug!("Hotkey pressed");
if tx.send(HotkeyEvent::Pressed).is_err() {
error!("Failed to send pressed event");
}
}
}
}
EventType::KeyRelease(key) => {
modifier_state.update(&key, false);
// Check hotkey release (for push-to-talk). A release is never debounced:
// dropping it would leave hotkey_held set and the coordinator recording.
if key == hotkey.key && hotkey_held {
last_event_time = now;
hotkey_held = false;
debug!("Hotkey released");
if tx.send(HotkeyEvent::Released).is_err() {
error!("Failed to send released event");
}
}
}
_ => {}
}
};
if let Err(e) = rdev::listen(callback) {
error!("Hotkey listener error: {:?}", e);
}
}
@@ -0,0 +1,82 @@
mod audio_feedback;
mod cli;
mod config;
mod coordinator;
mod hotkey;
mod model_cache;
mod overlay;
mod paste;
mod recorder;
mod transcriber;
mod vad;
use clap::{Parser, Subcommand};
#[derive(Parser)]
#[command(name = "mouth", version, about = "Offline speech-to-text with global hotkey and paste")]
struct Cli {
#[command(subcommand)]
command: Option<Commands>,
}
#[derive(Subcommand)]
enum Commands {
/// Start the mouth daemon
Run,
/// View or edit configuration
Config {
/// Print current config to stdout
#[arg(long)]
show: bool,
/// Reset config to defaults
#[arg(long)]
reset: bool,
},
/// Manage speech-to-text models
Models {
/// Download the configured model
#[arg(long)]
download: bool,
},
/// Show daemon status, loaded model, and version
Status,
}
fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
)
.init();
let cli = Cli::parse();
match cli.command {
None | Some(Commands::Run) => cli::run_cmd::run(),
Some(Commands::Config { show, reset }) => {
if show {
cli::config_cmd::show()
} else if reset {
cli::config_cmd::reset()
} else {
cli::config_cmd::interactive()
}
}
Some(Commands::Models { download }) => {
if download {
cli::models_cmd::download()
} else {
cli::models_cmd::list()
}
}
Some(Commands::Status) => cli::status_cmd::status(),
}
}
@@ -0,0 +1,97 @@
use anyhow::{Context, Result};
use hf_hub::api::sync::Api;
use std::path::PathBuf;
use tracing::{debug, info};
/// Known model definitions.
pub struct ModelInfo {
pub repo_id: &'static str,
pub encoder_file: &'static str,
pub encoder_data_file: Option<&'static str>,
pub decoder_file: &'static str,
pub vocab_file: &'static str,
#[allow(dead_code)]
pub preprocessor_file: Option<&'static str>,
}
/// Get model info for a model name.
pub fn get_model_info(model_name: &str) -> Result<ModelInfo> {
match model_name {
"parakeet-tdt-0.6b-v3" => Ok(ModelInfo {
repo_id: "istupakov/parakeet-tdt-0.6b-v3-onnx",
encoder_file: "encoder-model.onnx",
encoder_data_file: Some("encoder-model.onnx.data"),
decoder_file: "decoder_joint-model.onnx",
vocab_file: "vocab.txt",
preprocessor_file: None,
}),
_ => anyhow::bail!("Unknown model: {model_name}. Supported models: parakeet-tdt-0.6b-v3"),
}
}
/// Paths to downloaded model files.
pub struct ModelPaths {
pub encoder: PathBuf,
pub decoder: PathBuf,
pub vocab: PathBuf,
}
/// Ensure model files are downloaded and return their paths.
/// Uses the standard HuggingFace cache directory.
pub fn ensure_model(model_name: &str) -> Result<ModelPaths> {
let model_info = get_model_info(model_name)?;
let api = Api::new().context("Failed to create HuggingFace Hub API")?;
let repo = api.model(model_info.repo_id.to_string());
info!("Ensuring model files for '{model_name}' from {}", model_info.repo_id);
// Download encoder
info!("Checking encoder: {}", model_info.encoder_file);
let encoder = repo
.get(model_info.encoder_file)
.with_context(|| format!("Failed to download {}", model_info.encoder_file))?;
debug!("Encoder: {}", encoder.display());
// Download encoder data file if present
if let Some(data_file) = model_info.encoder_data_file {
info!("Checking encoder data: {data_file}");
let data_path = repo
.get(data_file)
.with_context(|| format!("Failed to download {data_file}"))?;
debug!("Encoder data: {}", data_path.display());
}
// Download decoder
info!("Checking decoder: {}", model_info.decoder_file);
let decoder = repo
.get(model_info.decoder_file)
.with_context(|| format!("Failed to download {}", model_info.decoder_file))?;
debug!("Decoder: {}", decoder.display());
// Download vocab
info!("Checking vocab: {}", model_info.vocab_file);
let vocab = repo
.get(model_info.vocab_file)
.with_context(|| format!("Failed to download {}", model_info.vocab_file))?;
debug!("Vocab: {}", vocab.display());
Ok(ModelPaths {
encoder,
decoder,
vocab,
})
}
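/// Check if model files are already cached, without downloading anything.
pub fn is_model_cached(model_name: &str) -> bool {
let Ok(info) = get_model_info(model_name) else {
return false;
};
// Look only in the local HuggingFace cache. `ensure_model` would download
// missing files, which a status check must not do.
let repo = hf_hub::Cache::default().model(info.repo_id.to_string());
[info.encoder_file, info.decoder_file, info.vocab_file]
.into_iter()
.chain(info.encoder_data_file)
.all(|file| repo.get(file).is_some())
}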
/// List available models with their download status.
pub fn list_models() -> Vec<(&'static str, bool)> {
let models = ["parakeet-tdt-0.6b-v3"];
models
.iter()
.map(|name| (*name, is_model_cached(name)))
.collect()
}
@@ -0,0 +1,201 @@
use std::num::NonZeroU32;
use tracing::{debug, error, info, warn};
use winit::application::ApplicationHandler;
use winit::dpi::{LogicalSize, PhysicalPosition};
use winit::event::WindowEvent;
use winit::event_loop::{ActiveEventLoop, EventLoop, EventLoopProxy};
use winit::window::{Window, WindowAttributes, WindowId, WindowLevel};
use crate::config::OverlayPosition;
const OVERLAY_WIDTH: u32 = 200;
const OVERLAY_HEIGHT: u32 = 36;
/// State of the overlay display.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OverlayState {
Hidden,
Recording,
Transcribing,
Done,
Error,
}
/// Events sent to the overlay from the coordinator.
#[derive(Debug, Clone)]
pub enum OverlayEvent {
SetState(OverlayState),
Shutdown,
}
/// The overlay application handler for winit.
struct OverlayApp {
window: Option<std::rc::Rc<Window>>,
surface: Option<softbuffer::Surface<std::rc::Rc<Window>, std::rc::Rc<Window>>>,
state: OverlayState,
position: OverlayPosition,
}
impl OverlayApp {
fn draw(&mut self) {
let Some(surface) = &mut self.surface else { return };
let Some(window) = &self.window else { return };
let size = window.inner_size();
if size.width == 0 || size.height == 0 {
return;
}
let Ok(w) = NonZeroU32::try_from(size.width) else { return };
let Ok(h) = NonZeroU32::try_from(size.height) else { return };
if surface.resize(w, h).is_err() {
return;
}
let Ok(mut buffer) = surface.buffer_mut() else { return };
let color = match self.state {
OverlayState::Hidden => 0x00000000,
OverlayState::Recording => 0xFFDD3333, // Red
OverlayState::Transcribing => 0xFFDDAA33, // Amber
OverlayState::Done => 0xFF33AA33, // Green
OverlayState::Error => 0xFFDD3333, // Red
};
let width = size.width as usize;
let height = size.height as usize;
for y in 0..height {
for x in 0..width {
let radius = 8;
let in_corner = (x < radius || x >= width - radius)
&& (y < radius || y >= height - radius);
let pixel = if in_corner {
let cx = if x < radius { radius } else { width - radius - 1 };
let cy = if y < radius { radius } else { height - radius - 1 };
let dx = x as i32 - cx as i32;
let dy = y as i32 - cy as i32;
if dx * dx + dy * dy <= (radius * radius) as i32 {
color
} else {
0x00000000
}
} else {
color
};
buffer[y * width + x] = pixel;
}
}
let _ = buffer.present();
}
fn update_visibility(&self) {
if let Some(window) = &self.window {
let visible = self.state != OverlayState::Hidden;
window.set_visible(visible);
}
}
}
impl ApplicationHandler<OverlayEvent> for OverlayApp {
fn resumed(&mut self, event_loop: &ActiveEventLoop) {
if self.window.is_some() {
return;
}
let attrs = WindowAttributes::default()
.with_title("Mouth")
.with_inner_size(LogicalSize::new(OVERLAY_WIDTH, OVERLAY_HEIGHT))
.with_resizable(false)
.with_decorations(false)
.with_transparent(true)
.with_window_level(WindowLevel::AlwaysOnTop)
.with_visible(false);
match event_loop.create_window(attrs) {
Ok(window) => {
let window = std::rc::Rc::new(window);
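// Position the overlay on the monitor the window is on. Monitor size
// and position are in physical pixels, so scale the logical overlay
// size before centering and offset by the monitor's origin.
if let Some(monitor) = window.current_monitor() {
let screen_size = monitor.size();
let origin = monitor.position();
let scale = monitor.scale_factor();
let overlay_w = (OVERLAY_WIDTH as f64 * scale).round() as u32;
let overlay_h = (OVERLAY_HEIGHT as f64 * scale).round() as u32;
let x = origin.x + (screen_size.width.saturating_sub(overlay_w) / 2) as i32;
let pos = match self.position {
OverlayPosition::Top => PhysicalPosition::new(x, origin.y + 10),
OverlayPosition::Bottom => PhysicalPosition::new(
x,
origin.y + screen_size.height.saturating_sub(overlay_h + 50) as i32,
),
OverlayPosition::None => PhysicalPosition::new(0, 0),
};
window.set_outer_position(pos);
}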
let context = softbuffer::Context::new(window.clone()).ok();
let surface = context.and_then(|ctx| {
softbuffer::Surface::new(&ctx, window.clone()).ok()
});
if surface.is_none() {
warn!("Could not create softbuffer surface — overlay rendering disabled");
}
self.surface = surface;
self.window = Some(window);
info!("Overlay window created");
}
Err(e) => {
error!("Failed to create overlay window: {e}");
}
}
}
fn user_event(&mut self, event_loop: &ActiveEventLoop, event: OverlayEvent) {
match event {
OverlayEvent::SetState(state) => {
debug!("Overlay state: {:?} -> {:?}", self.state, state);
self.state = state;
self.update_visibility();
self.draw();
}
OverlayEvent::Shutdown => {
info!("Overlay shutting down");
event_loop.exit();
}
}
}
fn window_event(&mut self, _event_loop: &ActiveEventLoop, _id: WindowId, event: WindowEvent) {
if let WindowEvent::RedrawRequested = event {
self.draw();
}
}
}
/// Create an event loop and return the proxy for sending events.
pub fn create_event_loop() -> Result<(EventLoop<OverlayEvent>, EventLoopProxy<OverlayEvent>), winit::error::EventLoopError> {
let event_loop: EventLoop<OverlayEvent> = EventLoop::with_user_event().build()?;
let proxy = event_loop.create_proxy();
Ok((event_loop, proxy))
}
/// Run the event loop with the given position config.
pub fn run_event_loop(
event_loop: EventLoop<OverlayEvent>,
position: OverlayPosition,
) -> Result<(), winit::error::EventLoopError> {
let mut app = OverlayApp {
window: None,
surface: None,
state: OverlayState::Hidden,
position,
};
event_loop.run_app(&mut app)
}
@@ -0,0 +1,71 @@
use anyhow::{Context, Result};
use arboard::Clipboard;
use enigo::{Direction, Enigo, Key, Keyboard, Settings};
use std::thread;
use std::time::Duration;
use tracing::{debug, info};
use crate::config::PasteMethod;
/// Paste text using the configured method.
pub fn paste_text(text: &str, method: &PasteMethod, keep_on_clipboard: bool) -> Result<()> {
let mut clipboard = Clipboard::new().context("Failed to open clipboard")?;
// Save current clipboard content if we need to restore it
let previous = if !keep_on_clipboard {
clipboard.get_text().ok()
} else {
None
};
// Set text to clipboard
clipboard
.set_text(text.to_string())
.context("Failed to set clipboard text")?;
debug!("Text set to clipboard ({} chars)", text.chars().count());
if *method == PasteMethod::ClipboardOnly {
info!("Text copied to clipboard (clipboard_only mode)");
return Ok(());
}
// Small delay to ensure clipboard is ready
thread::sleep(Duration::from_millis(50));
// Simulate paste keystroke
let mut enigo = Enigo::new(&Settings::default()).context("Failed to create enigo instance")?;
match method {
PasteMethod::CtrlV => {
debug!("Pasting via Ctrl+V");
enigo.key(Key::Control, Direction::Press)?;
enigo.key(Key::Unicode('v'), Direction::Click)?;
enigo.key(Key::Control, Direction::Release)?;
}
PasteMethod::ShiftInsert => {
debug!("Pasting via Shift+Insert");
enigo.key(Key::Shift, Direction::Press)?;
enigo.key(Key::Other(0x2D), Direction::Click)?; // VK_INSERT on Windows
enigo.key(Key::Shift, Direction::Release)?;
}
PasteMethod::CtrlShiftV => {
debug!("Pasting via Ctrl+Shift+V");
enigo.key(Key::Control, Direction::Press)?;
enigo.key(Key::Shift, Direction::Press)?;
enigo.key(Key::Unicode('v'), Direction::Click)?;
enigo.key(Key::Shift, Direction::Release)?;
enigo.key(Key::Control, Direction::Release)?;
}
PasteMethod::ClipboardOnly => unreachable!(),
}
// Restore previous clipboard content if needed
if let Some(prev) = previous {
thread::sleep(Duration::from_millis(100));
let _ = clipboard.set_text(prev);
debug!("Previous clipboard content restored");
}
info!("Text pasted ({} chars)", text.chars().count());
Ok(())
}
@@ -0,0 +1,265 @@
use anyhow::{Context, Result};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use cpal::{Device, SampleFormat, SampleRate, StreamConfig};
use rubato::{FftFixedIn, Resampler};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use tracing::{debug, error, info, warn};
const TARGET_SAMPLE_RATE: u32 = 16000;
/// Commands sent to the recorder.
#[derive(Debug)]
pub enum RecorderCommand {
Start,
Stop,
Shutdown,
}
/// Audio data produced by the recorder.
#[derive(Debug)]
pub struct AudioData {
/// Mono f32 samples at 16kHz
pub samples: Vec<f32>,
#[allow(dead_code)]
pub sample_rate: u32,
}
/// Find the audio input device by name, or use the default.
fn get_input_device(name: Option<&str>) -> Result<Device> {
let host = cpal::default_host();
if let Some(name) = name {
let devices = host.input_devices().context("Failed to enumerate input devices")?;
for device in devices {
if let Ok(dev_name) = device.name() {
if dev_name.to_lowercase().contains(&name.to_lowercase()) {
info!("Using audio input device: {dev_name}");
return Ok(device);
}
}
}
warn!("Audio device '{name}' not found, falling back to default");
}
host.default_input_device().context("No default input device available")
}
/// Convert interleaved multi-channel samples to mono f32.
fn to_mono_f32(data: &[f32], channels: u16) -> Vec<f32> {
if channels == 1 {
return data.to_vec();
}
data.chunks(channels as usize)
.map(|frame| frame.iter().sum::<f32>() / channels as f32)
.collect()
}
/// Resample audio from source rate to target rate.
fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
if from_rate == to_rate {
return Ok(samples.to_vec());
}
let chunk_size = 1024;
let mut resampler = FftFixedIn::<f32>::new(
from_rate as usize,
to_rate as usize,
chunk_size,
1, // sub-chunks
1, // mono
)?;
let mut output = Vec::new();
let mut pos = 0;
while pos + chunk_size <= samples.len() {
let chunk = &samples[pos..pos + chunk_size];
let result = resampler.process(&[chunk], None)?;
output.extend_from_slice(&result[0]);
pos += chunk_size;
}
// Handle remaining samples by padding with zeros
if pos < samples.len() {
let mut last_chunk = samples[pos..].to_vec();
last_chunk.resize(chunk_size, 0.0);
let result = resampler.process(&[&last_chunk], None)?;
let remaining_ratio = (samples.len() - pos) as f32 / chunk_size as f32;
let take = (result[0].len() as f32 * remaining_ratio) as usize;
output.extend_from_slice(&result[0][..take]);
}
Ok(output)
}
/// Run the audio recorder in a loop, responding to commands.
/// This should be called from a dedicated thread.
pub fn run(
device_name: Option<String>,
cmd_rx: mpsc::Receiver<RecorderCommand>,
audio_tx: mpsc::Sender<AudioData>,
) {
let device = match get_input_device(device_name.as_deref()) {
Ok(d) => d,
Err(e) => {
error!("Failed to get audio input device: {e}");
return;
}
};
let dev_name = device.name().unwrap_or_else(|_| "unknown".into());
info!("Audio recorder using device: {dev_name}");
let config = match device.default_input_config() {
Ok(c) => c,
Err(e) => {
error!("Failed to get default input config: {e}");
return;
}
};
let source_sample_rate = config.sample_rate().0;
let channels = config.channels();
debug!("Input config: {source_sample_rate}Hz, {channels}ch, {:?}", config.sample_format());
loop {
// Wait for a Start command
match cmd_rx.recv() {
Ok(RecorderCommand::Start) => {
debug!("Recording started");
}
Ok(RecorderCommand::Shutdown) => {
info!("Recorder shutting down");
return;
}
Ok(RecorderCommand::Stop) => {
// Ignore stop when not recording
continue;
}
Err(_) => {
info!("Recorder channel closed, shutting down");
return;
}
}
// Record until Stop or Shutdown
let buffer: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
let buffer_clone = Arc::clone(&buffer);
let ch = channels;
let stream_config = StreamConfig {
channels,
sample_rate: SampleRate(source_sample_rate),
buffer_size: cpal::BufferSize::Default,
};
let stream = match config.sample_format() {
SampleFormat::F32 => device.build_input_stream(
&stream_config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
let mono = to_mono_f32(data, ch);
if let Ok(mut buf) = buffer_clone.lock() {
buf.extend_from_slice(&mono);
}
},
|err| error!("Audio stream error: {err}"),
None,
),
SampleFormat::I16 => {
let buffer_clone = Arc::clone(&buffer);
device.build_input_stream(
&stream_config,
move |data: &[i16], _: &cpal::InputCallbackInfo| {
let f32_data: Vec<f32> = data.iter().map(|&s| s as f32 / i16::MAX as f32).collect();
let mono = to_mono_f32(&f32_data, ch);
if let Ok(mut buf) = buffer_clone.lock() {
buf.extend_from_slice(&mono);
}
},
|err| error!("Audio stream error: {err}"),
None,
)
}
format => {
error!("Unsupported sample format: {format:?}");
continue;
}
};
let stream = match stream {
Ok(s) => s,
Err(e) => {
error!("Failed to build input stream: {e}");
continue;
}
};
if let Err(e) = stream.play() {
error!("Failed to start audio stream: {e}");
continue;
}
// Wait for Stop or Shutdown
let should_shutdown = loop {
match cmd_rx.recv() {
Ok(RecorderCommand::Stop) => {
debug!("Recording stopped");
break false;
}
Ok(RecorderCommand::Shutdown) => {
info!("Recorder shutting down during recording");
break true;
}
Ok(RecorderCommand::Start) => {
// Ignore duplicate start
continue;
}
Err(_) => {
break true;
}
}
};
// Stop the stream
drop(stream);
if should_shutdown {
return;
}
// Get recorded audio
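// Take the recorded audio out of the shared buffer without copying it.
// The stream is already dropped, so the callback no longer writes here.
let raw_samples = std::mem::take(&mut *buffer.lock().unwrap_or_else(|e| e.into_inner()));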
if raw_samples.is_empty() {
warn!("No audio recorded");
continue;
}
debug!("Recorded {} samples at {}Hz", raw_samples.len(), source_sample_rate);
// Resample to 16kHz
let samples = match resample(&raw_samples, source_sample_rate, TARGET_SAMPLE_RATE) {
Ok(s) => s,
Err(e) => {
error!("Failed to resample audio: {e}");
continue;
}
};
debug!("Resampled to {} samples at {}Hz", samples.len(), TARGET_SAMPLE_RATE);
let audio = AudioData {
samples,
sample_rate: TARGET_SAMPLE_RATE,
};
if audio_tx.send(audio).is_err() {
error!("Failed to send audio data");
return;
}
}
}
@@ -0,0 +1,424 @@
use anyhow::{Context, Result};
use ndarray::{Array2, Array3};
use ort::session::{Session, SessionInputValue};
use ort::value::Value;
use std::borrow::Cow;
use std::path::Path;
use tracing::{debug, info};
use crate::config::Accelerator;
use crate::model_cache::ModelPaths;
// Audio preprocessing constants (NeMo-style)
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 512;
const WIN_LENGTH: usize = 400;
const HOP_LENGTH: usize = 160;
const N_MELS: usize = 128;
const PRE_EMPHASIS: f32 = 0.97;
/// The transcription engine.
pub struct Transcriber {
encoder: Session,
decoder: Session,
vocab: Vec<String>,
blank_id: i64,
vocab_size: usize,
}
fn make_input<'a, V: Into<Value>>(name: &'static str, value: V) -> (Cow<'a, str>, SessionInputValue<'a>) {
(Cow::Borrowed(name), SessionInputValue::Owned(value.into().into_dyn()))
}
impl Transcriber {
/// Load the transcription model from the given paths.
pub fn new(paths: &ModelPaths, accelerator: &Accelerator, gpu_device: u32) -> Result<Self> {
info!("Loading transcription model...");
let encoder = build_session(&paths.encoder, accelerator, gpu_device)
.context("Failed to load encoder")?;
info!("Encoder loaded");
let decoder = build_session(&paths.decoder, accelerator, gpu_device)
.context("Failed to load decoder")?;
info!("Decoder loaded");
let vocab = load_vocab(&paths.vocab)?;
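let vocab_size = vocab.len();
// Guard the subtraction below against an empty vocabulary file.
anyhow::ensure!(vocab_size > 0, "Vocabulary file is empty");
let blank_id = (vocab_size - 1) as i64; // <blk> is the last token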
info!("Vocab loaded: {vocab_size} tokens, blank_id={blank_id}");
Ok(Self {
encoder,
decoder,
vocab,
blank_id,
vocab_size,
})
}
/// Transcribe audio samples (mono f32 at 16kHz).
pub fn transcribe(&mut self, samples: &[f32]) -> Result<String> {
if samples.is_empty() {
return Ok(String::new());
}
let duration = samples.len() as f32 / SAMPLE_RATE as f32;
info!("Transcribing {duration:.1}s of audio...");
// Step 1: Compute mel spectrogram
let features = compute_mel_spectrogram(samples);
let num_frames = features.ncols();
debug!("Mel spectrogram: {N_MELS}x{num_frames} frames");
// Step 2: Run encoder
let (encoder_output, feat_dim, encoded_length) = self.run_encoder(&features)?;
debug!("Encoded: feat_dim={feat_dim}, length={encoded_length}");
// Step 3: TDT greedy decode
let tokens = self.tdt_greedy_decode(&encoder_output, feat_dim, encoded_length)?;
debug!("Decoded {} tokens", tokens.len());
// Step 4: Convert tokens to text
let text = self.tokens_to_text(&tokens);
info!("Transcription: \"{text}\"");
Ok(text)
}
fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Vec<f32>, usize, usize)> {
let num_frames = features.ncols();
// Shape: [batch=1, n_mels=128, time_frames]
let input = features
.clone()
.into_shape_with_order((1, N_MELS, num_frames))?;
let length_data = ndarray::Array1::from_vec(vec![num_frames as i64]);
let input_value = Value::from_array(input)?;
let length_value = Value::from_array(length_data)?;
let outputs = self.encoder.run(vec![
make_input("audio_signal", input_value.into_dyn()),
make_input("length", length_value.into_dyn()),
])?;
let (enc_shape, enc_data) = outputs[0]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract encoder output: {e}"))?;
let (_, len_data) = outputs[1]
.try_extract_tensor::<i64>()
.map_err(|e| anyhow::anyhow!("Failed to extract encoded lengths: {e}"))?;
let encoded_length = len_data[0] as usize;
// Shape: [1, feat_dim, encoded_time]
let dims: Vec<usize> = enc_shape.iter().map(|&d| d as usize).collect();
let feat_dim = if dims.len() == 3 { dims[1] } else { dims[dims.len() - 1] };
debug!("Encoder output shape: {:?}", dims);
Ok((enc_data.to_vec(), feat_dim, encoded_length))
}
fn tdt_greedy_decode(&mut self, encoder_output: &[f32], feat_dim: usize, encoded_length: usize) -> Result<Vec<i64>> {
// Determine decoder LSTM state dimensions by inspecting input metadata
// Default fallback values
let mut state_shape: [usize; 3] = [1, 1, 640];
for input in self.decoder.inputs() {
let name = input.name();
if name == "input_states_1" || name == "input_states_2" {
let dtype = input.dtype();
// Try to extract shape from the ValueType
if let ort::value::ValueType::Tensor { shape, .. } = dtype {
if shape.len() == 3 {
// Use known dims (skip dynamic ones which are -1)
for (i, &dim) in shape.iter().enumerate() {
if dim > 0 {
state_shape[i] = dim as usize;
}
}
}
}
break; // both states have same shape
}
}
debug!("Decoder state shape: {:?}", state_shape);
let mut state1 = Array3::<f32>::zeros(state_shape);
let mut state2 = Array3::<f32>::zeros(state_shape);
let mut prev_token = self.blank_id;
let mut tokens = Vec::new();
let mut t = 0;
let max_steps = encoded_length * 10;
let mut step = 0;
while t < encoded_length && step < max_steps {
step += 1;
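// Extract current encoder frame: shape [1, feat_dim, 1].
// encoder_output is laid out as [1, feat_dim, time]. Derive the time
// stride from the buffer itself, since the tensor's time dimension may
// be padded beyond encoded_length.
let time_stride = encoder_output.len() / feat_dim;
let frame_data: Vec<f32> = (0..feat_dim)
.map(|f| encoder_output[f * time_stride + t])
.collect();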
let frame = Array3::from_shape_vec([1, feat_dim, 1], frame_data)?;
let targets = ndarray::Array2::from_shape_vec((1, 1), vec![prev_token])?;
let target_length = ndarray::Array1::from_vec(vec![1i64]);
let outputs = self.decoder.run(vec![
make_input("encoder_outputs", Value::from_array(frame)?.into_dyn()),
make_input("targets", Value::from_array(targets)?.into_dyn()),
make_input("target_length", Value::from_array(target_length)?.into_dyn()),
make_input("input_states_1", Value::from_array(state1.clone())?.into_dyn()),
make_input("input_states_2", Value::from_array(state2.clone())?.into_dyn()),
])?;
let (_, output_data) = outputs["outputs"]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract decoder output: {e}"))?;
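// Split into token logits and duration logits. Check the length first so a
// mismatched vocab_size yields an error instead of a slice-index panic.
anyhow::ensure!(
output_data.len() >= self.vocab_size,
"Decoder output has {} values, expected at least vocab_size = {}",
output_data.len(),
self.vocab_size
);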
let token_logits = &output_data[..self.vocab_size];
let duration_logits = &output_data[self.vocab_size..];
let token_id = argmax(token_logits) as i64;
let duration = if !duration_logits.is_empty() {
argmax(duration_logits)
} else {
1
};
if token_id != self.blank_id {
tokens.push(token_id);
prev_token = token_id;
// Update states
let (s1_shape, s1_data) = outputs["output_states_1"]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract state 1: {e}"))?;
let s1_dims: Vec<usize> = s1_shape.iter().map(|&d| d as usize).collect();
if s1_dims.len() == 3 {
state1 = Array3::from_shape_vec([s1_dims[0], s1_dims[1], s1_dims[2]], s1_data.to_vec())?;
}
let (s2_shape, s2_data) = outputs["output_states_2"]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract state 2: {e}"))?;
let s2_dims: Vec<usize> = s2_shape.iter().map(|&d| d as usize).collect();
if s2_dims.len() == 3 {
state2 = Array3::from_shape_vec([s2_dims[0], s2_dims[1], s2_dims[2]], s2_data.to_vec())?;
}
}
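// Advance along the time axis by the predicted duration. A zero duration
// with a non-blank token stays on the same frame so that more tokens can be
// emitted there (bounded by max_steps). A blank with zero duration must
// still advance one frame, otherwise decoding would never make progress.
if duration > 0 {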
t += duration;
} else if token_id == self.blank_id {
t += 1;
}
}
Ok(tokens)
}
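/// Convert token ids to text: skip out-of-range ids and special tokens such
/// as `<blank>`, join the pieces, and collapse repeated spaces.
fn tokens_to_text(&self, tokens: &[i64]) -> String {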
let mut text = String::new();
for &token_id in tokens {
if token_id >= 0 && (token_id as usize) < self.vocab.len() {
let token = &self.vocab[token_id as usize];
if token.starts_with('<') && token.ends_with('>') {
continue;
}
text.push_str(token);
}
}
// U+2581 (LOWER ONE EIGHTH BLOCK) is the SentencePiece word-boundary marker.
text = text.replace('\u{2581}', " ");
let text = text.trim().to_string();
let mut result = String::with_capacity(text.len());
let mut prev_space = false;
for ch in text.chars() {
if ch == ' ' {
if !prev_space {
result.push(ch);
}
prev_space = true;
} else {
result.push(ch);
prev_space = false;
}
}
result
}
}
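/// Build an ONNX Runtime session for `model_path`.
/// Only the CPU execution provider is wired up so far: CUDA and DirectML
/// requests are logged and fall back to CPU, and `_gpu_device` is unused.
fn build_session(model_path: &Path, accelerator: &Accelerator, _gpu_device: u32) -> Result<Session> {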
let mut builder = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {e}"))?
.with_intra_threads(num_cpus::get().min(4))
.map_err(|e| anyhow::anyhow!("Failed to set thread count: {e}"))?;
match accelerator {
Accelerator::Auto | Accelerator::Cpu => {}
Accelerator::Cuda => {
info!("CUDA requested — falling back to CPU (CUDA EP not yet configured)");
}
Accelerator::DirectMl => {
info!("DirectML requested — falling back to CPU (DirectML EP not yet configured)");
}
}
let session = builder
.commit_from_file(model_path)
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model from {}: {e}", model_path.display()))?;
Ok(session)
}
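/// Load the vocabulary file: one token per line, where the token id is the
/// zero-based line index. Only the first whitespace-separated field of each
/// line is kept, so a trailing id column is ignored.
fn load_vocab(path: &Path) -> Result<Vec<String>> {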
let content = std::fs::read_to_string(path)
.with_context(|| format!("Failed to read vocab from {}", path.display()))?;
let mut vocab = Vec::new();
for line in content.lines() {
let token = line.split_whitespace().next().unwrap_or("").to_string();
vocab.push(token);
}
Ok(vocab)
}
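/// Index of the largest element. Returns 0 for an empty slice; incomparable
/// values (NaN) are treated as equal.
fn argmax(slice: &[f32]) -> usize {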
slice
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
}
// --- Mel spectrogram computation ---
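/// Compute a normalized log-mel spectrogram of shape [N_MELS, num_frames].
fn compute_mel_spectrogram(samples: &[f32]) -> Array2<f32> {
// Guard: indexing samples[0] below would panic on empty input.
if samples.is_empty() {
return Array2::<f32>::zeros((N_MELS, 0));
}
// Pre-emphasis: y[n] = x[n] - PRE_EMPHASIS * x[n - 1].
let mut emphasized = Vec::with_capacity(samples.len());
emphasized.push(samples[0]);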
for i in 1..samples.len() {
emphasized.push(samples[i] - PRE_EMPHASIS * samples[i - 1]);
}
let num_frames = (emphasized.len().saturating_sub(WIN_LENGTH)) / HOP_LENGTH + 1;
let fft_size = N_FFT / 2 + 1;
let window: Vec<f32> = (0..WIN_LENGTH)
.map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (WIN_LENGTH - 1) as f32).cos()))
.collect();
let mel_bank = create_mel_filterbank(SAMPLE_RATE as f32, N_FFT, N_MELS);
let mut mel_spec = Array2::<f32>::zeros((N_MELS, num_frames));
for frame_idx in 0..num_frames {
let start = frame_idx * HOP_LENGTH;
let mut windowed = vec![0.0f32; N_FFT];
for i in 0..WIN_LENGTH {
if start + i < emphasized.len() {
windowed[i] = emphasized[start + i] * window[i];
}
}
let power_spectrum = compute_power_spectrum(&windowed, fft_size);
for mel_idx in 0..N_MELS {
let mut energy = 0.0f32;
for k in 0..fft_size {
energy += mel_bank[mel_idx][k] * power_spectrum[k];
}
mel_spec[[mel_idx, frame_idx]] = (energy + 2.0f32.powi(-24)).ln();
}
}
// Per-utterance CMVN: normalize each mel bin to zero mean and unit variance
// across all frames of this utterance.
for mel_idx in 0..N_MELS {
let row = mel_spec.row(mel_idx);
let mean = row.mean().unwrap_or(0.0);
let variance = row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32;
let std = (variance + 1e-5).sqrt();
for frame_idx in 0..num_frames {
mel_spec[[mel_idx, frame_idx]] = (mel_spec[[mel_idx, frame_idx]] - mean) / std;
}
}
mel_spec
}
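/// Naive DFT returning |X[k]|^2 for k in 0..output_size.
/// Cost is O(signal.len() * output_size) per frame.
fn compute_power_spectrum(signal: &[f32], output_size: usize) -> Vec<f32> {
let n = signal.len();
let mut spectrum = Vec::with_capacity(output_size);
for k in 0..output_size {
let mut real = 0.0f32;
let mut imag = 0.0f32;
for (t, &sample) in signal.iter().enumerate() {
// Reduce k * t modulo n in integer arithmetic before converting to f32.
// The phase is periodic in n, and this keeps |angle| below 2*pi so that
// cos/sin do not lose precision on large f32 arguments.
let phase = (k * t) % n;
let angle = -2.0 * std::f32::consts::PI * phase as f32 / n as f32;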
real += sample * angle.cos();
imag += sample * angle.sin();
}
spectrum.push(real * real + imag * imag);
}
spectrum
}
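/// Build `n_mels` triangular filters over the `n_fft / 2 + 1` FFT bins, with
/// centers spaced evenly on the mel scale between 0 Hz and the Nyquist frequency.
fn create_mel_filterbank(sample_rate: f32, n_fft: usize, n_mels: usize) -> Vec<Vec<f32>> {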
let fft_size = n_fft / 2 + 1;
let mel_low = hz_to_mel(0.0);
let mel_high = hz_to_mel(sample_rate / 2.0);
let mel_points: Vec<f32> = (0..n_mels + 2)
.map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
.collect();
let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
let bin_points: Vec<usize> = hz_points
.iter()
.map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate).floor() as usize)
.collect();
let mut filterbank = vec![vec![0.0f32; fft_size]; n_mels];
for m in 0..n_mels {
let left = bin_points[m];
let center = bin_points[m + 1];
let right = bin_points[m + 2];
for k in left..center {
if center > left {
filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
}
}
for k in center..right {
if right > center {
filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
}
}
}
filterbank
}
// Hz <-> mel conversion using the HTK formula: mel = 2595 * log10(1 + hz / 700).
fn hz_to_mel(hz: f32) -> f32 {
2595.0 * (1.0 + hz / 700.0).log10()
}
fn mel_to_hz(mel: f32) -> f32 {
700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
}
+149
@@ -0,0 +1,149 @@
use anyhow::Result;
use ndarray::{Array1, Array2, Array3};
use ort::session::{Session, SessionInputValue};
use ort::value::Value;
use std::borrow::Cow;
use tracing::{debug, info};
const SILERO_SAMPLE_RATE: i64 = 16000;
const WINDOW_SIZE: usize = 512; // 32ms at 16kHz
const THRESHOLD: f32 = 0.5;
const MIN_SPEECH_DURATION_MS: usize = 250;
const MIN_SILENCE_DURATION_MS: usize = 300;
const SPEECH_PAD_MS: usize = 100;
fn make_input<'a, V: Into<Value>>(name: &'static str, value: V) -> (Cow<'a, str>, SessionInputValue<'a>) {
(Cow::Borrowed(name), SessionInputValue::Owned(value.into().into_dyn()))
}
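/// Voice Activity Detector using Silero VAD.
///
/// NOTE: the tensor names used below (`input`, `state`, `sr` -> `output`,
/// `stateN`) and the [2, 1, 128] state match the combined-state ONNX interface
/// of Silero VAD v5. The v4 export uses separate `h`/`c` state tensors instead,
/// so the model file must match this interface.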
pub struct Vad {
session: Session,
}
impl Vad {
/// Create a new VAD instance from an ONNX model file.
pub fn new(model_path: &str) -> Result<Self> {
let session = Session::builder()
.map_err(|e| anyhow::anyhow!("Failed to create VAD session builder: {e}"))?
.with_intra_threads(1)
.map_err(|e| anyhow::anyhow!("Failed to set VAD threads: {e}"))?
.commit_from_file(model_path)
.map_err(|e| anyhow::anyhow!("Failed to load VAD model from {model_path}: {e}"))?;
info!("Silero VAD loaded from {model_path}");
Ok(Self { session })
}
/// Filter audio to keep only speech segments.
pub fn filter_speech(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
if samples.is_empty() {
return Ok(Vec::new());
}
let probabilities = self.get_speech_probabilities(samples)?;
let segments = self.find_speech_segments(&probabilities, samples.len());
if segments.is_empty() {
debug!("VAD: no speech detected");
return Ok(Vec::new());
}
let mut result = Vec::new();
for (start, end) in &segments {
debug!("VAD: speech segment {start}..{end} ({:.1}s..{:.1}s)",
*start as f32 / SILERO_SAMPLE_RATE as f32,
*end as f32 / SILERO_SAMPLE_RATE as f32);
result.extend_from_slice(&samples[*start..*end]);
}
let original_duration = samples.len() as f32 / SILERO_SAMPLE_RATE as f32;
let filtered_duration = result.len() as f32 / SILERO_SAMPLE_RATE as f32;
debug!("VAD: {original_duration:.1}s -> {filtered_duration:.1}s ({:.0}% kept)",
filtered_duration / original_duration * 100.0);
Ok(result)
}
fn get_speech_probabilities(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
let mut probabilities = Vec::new();
let mut state = Array3::<f32>::zeros((2, 1, 128));
let sr = Array1::from_vec(vec![SILERO_SAMPLE_RATE]);
for chunk in samples.chunks(WINDOW_SIZE) {
let mut window = chunk.to_vec();
if window.len() < WINDOW_SIZE {
window.resize(WINDOW_SIZE, 0.0);
}
let input = Array2::from_shape_vec((1, WINDOW_SIZE), window)?;
let outputs = self.session.run(vec![
make_input("input", Value::from_array(input)?.into_dyn()),
make_input("state", Value::from_array(state.clone())?.into_dyn()),
make_input("sr", Value::from_array(sr.clone())?.into_dyn()),
])?;
let (_, output_data) = outputs["output"]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract VAD output: {e}"))?;
probabilities.push(output_data[0]);
let (state_shape, state_data) = outputs["stateN"]
.try_extract_tensor::<f32>()
.map_err(|e| anyhow::anyhow!("Failed to extract VAD state: {e}"))?;
let dims: Vec<usize> = state_shape.iter().map(|&d| d as usize).collect();
if dims.len() == 3 {
state = Array3::from_shape_vec([dims[0], dims[1], dims[2]], state_data.to_vec())?;
}
}
Ok(probabilities)
}
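/// Turn per-window speech probabilities into (start, end) sample ranges.
/// A segment closes after MIN_SILENCE_DURATION_MS of consecutive silence, is
/// dropped if it has fewer than MIN_SPEECH_DURATION_MS worth of speech windows,
/// and is padded by SPEECH_PAD_MS on both sides (clamped to the buffer).
fn find_speech_segments(&self, probabilities: &[f32], total_samples: usize) -> Vec<(usize, usize)> {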
let samples_per_window = WINDOW_SIZE;
let min_speech_windows = MIN_SPEECH_DURATION_MS * SILERO_SAMPLE_RATE as usize / 1000 / samples_per_window;
let min_silence_windows = MIN_SILENCE_DURATION_MS * SILERO_SAMPLE_RATE as usize / 1000 / samples_per_window;
let pad_samples = SPEECH_PAD_MS * SILERO_SAMPLE_RATE as usize / 1000;
let mut segments = Vec::new();
let mut in_speech = false;
let mut speech_start = 0;
let mut silence_count = 0;
let mut speech_count = 0;
for (i, &prob) in probabilities.iter().enumerate() {
if prob >= THRESHOLD {
if !in_speech {
speech_start = i;
speech_count = 0;
}
in_speech = true;
speech_count += 1;
silence_count = 0;
} else if in_speech {
silence_count += 1;
if silence_count >= min_silence_windows {
if speech_count >= min_speech_windows {
let start = (speech_start * samples_per_window).saturating_sub(pad_samples);
let end = ((i - silence_count + 1) * samples_per_window + pad_samples).min(total_samples);
segments.push((start, end));
}
in_speech = false;
silence_count = 0;
speech_count = 0;
}
}
}
if in_speech && speech_count >= min_speech_windows {
let start = (speech_start * samples_per_window).saturating_sub(pad_samples);
let end = total_samples;
segments.push((start, end));
}
segments
}
}