Implement core speech-to-text pipeline

All major components: hotkey listener (rdev), audio capture (cpal), resampling (rubato), VAD (Silero ONNX), Parakeet v3 TDT transcription (ort), overlay window (winit+softbuffer), paste simulation (enigo+arboard), audio feedback (rodio), YAML config, CLI with clap, HuggingFace model download. ~2400 lines of Rust across 16 source files.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@@ -0,0 +1 @@
/target
@@ -0,0 +1,50 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

Mouth is a single-binary, offline speech-to-text tool. Press a global hotkey, speak, and transcribed text is pasted at your cursor. Configured via YAML, no UI. Primary target is Windows; Linux/macOS supported where possible.

Uses Parakeet TDT 0.6B v3 (ONNX, from `istupakov/parakeet-tdt-0.6b-v3-onnx`) for transcription, Silero VAD v4 for voice activity detection.

## Build & Run

```bash
cargo build                      # debug build
cargo build --release            # release build
cargo run                        # run daemon (default command)
cargo run -- config --show       # show current config
cargo run -- config              # interactive config TUI
cargo run -- config --reset      # reset to defaults
cargo run -- models              # list models
cargo run -- models --download   # download configured model
cargo run -- status              # daemon status
```

## Architecture

Single-binary Rust application. Core pipeline: hotkey capture (rdev) → audio recording (cpal) → resampling to 16kHz (rubato) → VAD (Silero ONNX) → mel spectrogram → transcription (Parakeet v3 TDT decoder via ort) → clipboard/paste (arboard + enigo). Minimal native overlay window (winit + softbuffer).

**Threading model:** Main thread owns the overlay window event loop (required by winit). Background threads: hotkey listener (rdev::listen is blocking), audio recorder (cpal stream), coordinator (state machine). All communicate via `std::sync::mpsc` channels.
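
As a rough illustration of the threading model above, the sketch below funnels events from two stand-in producer threads into one `std::sync::mpsc` channel. The `Event` type and `collect_events` function are hypothetical names chosen for this example; the actual code uses separate channels per component and may differ in other ways.

```rust
use std::sync::mpsc;
use std::thread;

/// Hypothetical event type for illustration only.
#[derive(Debug, PartialEq)]
enum Event {
    HotkeyPressed,
    HotkeyReleased,
    AudioReady(Vec<f32>),
}

/// Spawns stand-in "hotkey" and "recorder" threads and returns the events
/// a coordinator would receive, in arrival order.
fn collect_events() -> Vec<Event> {
    let (tx, rx) = mpsc::channel();

    // Stand-in for the blocking hotkey listener thread.
    let hotkey_tx = tx.clone();
    let hotkey = thread::spawn(move || {
        hotkey_tx.send(Event::HotkeyPressed).unwrap();
        hotkey_tx.send(Event::HotkeyReleased).unwrap();
    });
    // Joining here keeps the ordering deterministic for the example.
    hotkey.join().unwrap();

    // Stand-in for the recorder thread; it takes ownership of the last sender.
    let recorder = thread::spawn(move || {
        tx.send(Event::AudioReady(vec![0.0; 4])).unwrap();
    });
    recorder.join().unwrap();

    // Every sender has been dropped, so the iterator ends after the last event.
    rx.iter().collect()
}
```

In a long-running daemon the coordinator would loop on `rx.recv()` instead of collecting, since the senders stay alive for the life of the process.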

**Coordinator state machine:** Idle → Recording → Transcribing → (Pasting) → Idle. Cancel from Recording returns to Idle.
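
The transitions listed above can be written as a small pure function. This is only a sketch: the `State` and `Input` enums and the `next_state` function are illustrative names, not the coordinator's actual types, and the choice to ignore unexpected inputs is an assumption.

```rust
/// Illustrative states, mirroring the description above.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    Idle,
    Recording,
    Transcribing,
    Pasting,
}

/// Illustrative inputs that drive the state machine.
#[derive(Debug, Clone, Copy)]
enum Input {
    HotkeyStart,
    HotkeyStop,
    Cancel,
    TranscriptReady,
    PasteDone,
}

/// Returns the next state; inputs that do not apply leave the state unchanged.
fn next_state(state: State, input: Input) -> State {
    match (state, input) {
        (State::Idle, Input::HotkeyStart) => State::Recording,
        (State::Recording, Input::HotkeyStop) => State::Transcribing,
        (State::Recording, Input::Cancel) => State::Idle,
        (State::Transcribing, Input::TranscriptReady) => State::Pasting,
        (State::Pasting, Input::PasteDone) => State::Idle,
        (unchanged, _) => unchanged,
    }
}
```

Keeping the transition table free of side effects makes it easy to unit-test separately from audio and clipboard handling.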

**Parakeet v3 inference:** Two-stage ONNX model — encoder (FastConformer) produces features, decoder+joint (TDT transducer) greedily decodes tokens with duration predictions. Audio preprocessing: pre-emphasis → STFT → 128-band log-mel → per-utterance CMVN. Vocab is SentencePiece BPE with `▁` as word boundary marker.
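
Two of the preprocessing steps named above, pre-emphasis and per-utterance CMVN, are simple enough to sketch directly. The function names, the caller-supplied pre-emphasis coefficient, and the `1e-5` variance floor are assumptions made for this example; the transcriber's actual constants and data layout may differ. STFT and mel filtering are omitted.

```rust
/// Pre-emphasis filter: y[n] = x[n] - coeff * x[n-1], with x[-1] taken as 0.
/// The coefficient is supplied by the caller (0.97 is a common choice).
fn pre_emphasis(samples: &[f32], coeff: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(samples.len());
    let mut prev = 0.0f32;
    for &s in samples {
        out.push(s - coeff * prev);
        prev = s;
    }
    out
}

/// Per-utterance CMVN: normalizes each feature band to zero mean and unit
/// variance across all frames. `features[frame][band]`; all frames are
/// assumed to have the same number of bands.
fn cmvn(features: &mut [Vec<f32>]) {
    let frames = features.len();
    if frames == 0 {
        return;
    }
    let bands = features[0].len();
    for b in 0..bands {
        let mean = features.iter().map(|f| f[b]).sum::<f32>() / frames as f32;
        let var = features
            .iter()
            .map(|f| (f[b] - mean).powi(2))
            .sum::<f32>()
            / frames as f32;
        // Small floor (an assumed value) avoids dividing by zero on constant bands.
        let std = (var + 1e-5).sqrt();
        for f in features.iter_mut() {
            f[b] = (f[b] - mean) / std;
        }
    }
}
```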

**ort crate (v2.0.0-rc.12) notes:** Session::run needs `&mut self`. Input values must be converted to `Value::into_dyn()` before passing. Use `SessionInputValue::Owned(value.into_dyn())` pattern. `try_extract_tensor` returns `(&Shape, &[T])` tuple. `from_shape_vec` needs `[usize; N]` not `Vec<usize>`.

Config lives at `~/.config/mouth/config.yaml` (Linux/macOS) or `%APPDATA%\mouth\config.yaml` (Windows). Models cached via HuggingFace Hub standard cache (`~/.cache/huggingface/hub/`).

## Cross-Compilation

Developing on Ubuntu 24.04, targeting Windows:
```bash
cargo build --target x86_64-pc-windows-gnu
```

## System Dependencies (Ubuntu)

```bash
sudo apt-get install libssl-dev libasound2-dev libpulse-dev libx11-dev libxcb-shape0-dev libxcb-xfixes0-dev libxkbcommon-dev libwayland-dev libgtk-3-dev libxtst-dev libxdo-dev cmake
```
Generated
+4950
File diff suppressed because it is too large
+61
@@ -0,0 +1,61 @@
[package]
name = "mouth"
version = "0.1.0"
edition = "2024"
description = "Offline speech-to-text with global hotkey and paste"

[dependencies]
# CLI
clap = { version = "4", features = ["derive"] }

# Config
serde = { version = "1", features = ["derive"] }
serde_yaml = "0.9"
dirs = "6"

# Interactive config TUI
dialoguer = "0.11"

# Logging
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }

# Async
tokio = { version = "1", features = ["full"] }

# Global hotkey
rdev = "0.5"

# Audio capture
cpal = "0.15"

# Audio resampling
rubato = "0.16"

# ONNX inference (Parakeet v3 + Silero VAD)
ort = { version = "2.0.0-rc.12", features = ["download-binaries"] }
ndarray = "0.17"

# Model download from HuggingFace
hf-hub = "0.4"
indicatif = "0.17"

# Clipboard
arboard = "3"

# Keyboard simulation
enigo = { version = "0.3", features = ["serde"] }

# Overlay window
winit = "0.30"
softbuffer = "0.4"

# Audio feedback
rodio = "0.20"

# System info
num_cpus = "1"

# Error handling
anyhow = "1"
thiserror = "2"
@@ -0,0 +1,42 @@
# Mouth configuration
# Copy to ~/.config/mouth/config.yaml (Linux/macOS)
# or %APPDATA%\mouth\config.yaml (Windows)

# Hotkey to activate recording
hotkey: ctrl+space

# Recording mode: push_to_talk or toggle
mode: push_to_talk

# Cancel hotkey (only active while recording)
cancel_key: escape

# Speech-to-text model
model: parakeet-tdt-0.6b-v3

# Inference accelerator: auto, cpu, cuda, directml
accelerator: auto

# GPU device index (when accelerator is cuda/directml)
gpu_device: 0

# How to paste text: ctrl_v, shift_insert, ctrl_shift_v, clipboard_only
paste_method: ctrl_v

# Keep transcribed text on clipboard after pasting
copy_to_clipboard: true

# Overlay position: top, bottom, none
overlay_position: top

# Play audio feedback sounds
audio_feedback: true

# Audio input device name (null = system default)
input_device: null

# Voice activity detection (trim silence)
vad_enabled: true

# Language hint for model
language: en
@@ -0,0 +1,287 @@
# Mouth — Implementation Plan

## Overview

Mouth is a single-binary, offline speech-to-text tool for Windows (with Linux/macOS support where possible). Press a hotkey, speak, and transcribed text is pasted at your cursor. Configured entirely via YAML.

## Architecture

```
┌─────────────┐     ┌───────────┐     ┌─────────────┐     ┌────────────┐
│   Hotkey    │────▶│ Recorder  │────▶│ Transcriber │────▶│   Paste    │
│  Listener   │     │  (cpal)   │     │ (ort/ONNX)  │     │  (enigo)   │
│   (rdev)    │     │           │     │             │     │            │
└─────────────┘     └───────────┘     └─────────────┘     └────────────┘
       │                  │                  │                  │
       │                  ▼                  │                  │
       │            ┌───────────┐            │                  │
       │            │    VAD    │            │                  │
       │            │ (silero)  │            │                  │
       │            └───────────┘            │                  │
       │                                     │                  │
       ▼                                     ▼                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│                           Overlay (winit)                            │
│            State: idle → recording → transcribing → done             │
└──────────────────────────────────────────────────────────────────────┘
```

### Component Communication

All components communicate via channels (`std::sync::mpsc` or `tokio::sync`). The main thread owns the overlay window (required by most windowing systems). A coordinator task receives events from hotkey/recorder/transcriber and drives state transitions.

```
HotkeyEvent(Pressed/Released) ──┐
AudioReady(Vec<f32>) ───────────┼──▶ Coordinator ──▶ OverlayState
TranscriptionDone(String) ──────┤                ──▶ PasteAction
CancelRequested ────────────────┘
```

## Crate Dependencies

| Crate | Purpose | Notes |
|-------|---------|-------|
| `rdev` | Global hotkey capture | Cross-platform key events, no focus required |
| `cpal` | Audio capture | Cross-platform mic input |
| `rubato` | Audio resampling | Resample to 16kHz for Parakeet |
| `ort` | ONNX Runtime | Run Parakeet v3 + Silero VAD |
| `hf-hub` | Model download | Download from HuggingFace, standard cache dir |
| `enigo` | Keyboard simulation | Simulate Ctrl+V, Shift+Insert, etc. |
| `arboard` | Clipboard access | Read/write clipboard, save/restore |
| `winit` | Windowing | Minimal overlay window |
| `softbuffer` | Pixel rendering | Draw coloured overlay (no GPU needed for overlay) |
| `serde` + `serde_yaml` | Config | Deserialize YAML config |
| `clap` | CLI | Subcommands: `run`, `config`, `models` |
| `dialoguer` | Interactive TUI | `mouth config` interactive setup |
| `rodio` | Audio playback | Blip up/down sounds |
| `indicatif` | Progress bars | Model download progress |
| `dirs` | Platform dirs | Config/cache paths |
| `tracing` | Logging | Structured logging |

## Config File

Location: `~/.config/mouth/config.yaml` (Linux/macOS), `%APPDATA%\mouth\config.yaml` (Windows)

```yaml
# Hotkey to activate recording
hotkey: "ctrl+space"

# Recording mode: push_to_talk or toggle
mode: push_to_talk

# Cancel hotkey (only active while recording)
cancel_key: "escape"

# Speech-to-text model
model: "parakeet-tdt-0.6b-v3"

# Inference accelerator: auto, cpu, cuda, directml
accelerator: auto

# GPU device index (only used when accelerator is cuda/directml)
gpu_device: 0

# How to paste text
paste_method: ctrl_v  # ctrl_v | shift_insert | ctrl_shift_v | clipboard_only

# Also keep transcribed text on clipboard after pasting
copy_to_clipboard: true

# Overlay position on screen
overlay_position: top  # top | bottom | none

# Audio feedback
audio_feedback: true

# Audio input device (null = system default)
input_device: null

# VAD: trim silence from audio before transcription
vad_enabled: true

# Language (for model hint, if supported)
language: en
```

## CLI Interface

```
mouth run               # Start the daemon (default if no subcommand)
mouth config            # Interactive TUI to edit config
mouth config --show     # Print current config to stdout
mouth config --reset    # Reset config to defaults
mouth models            # List available/downloaded models
mouth models download   # Download configured model (if not cached)
mouth status            # Show daemon status, loaded model, app version
```

## Implementation Phases

### Phase 1: Project Skeleton + Config

- Cargo.toml with all dependencies
- Config struct with serde, defaults, load/save
- CLI with clap (run, config, models subcommands)
- `mouth config` interactive TUI with dialoguer
- Platform-aware config/cache directory resolution

### Phase 2: Hotkey Listener

- Global hotkey capture using rdev
- Support configurable key combinations (parse from string like "ctrl+space")
- Push-to-talk mode: record on press, stop on release
- Toggle mode: start on first press, stop on second press
- Cancel on Escape while recording
- Debounce rapid key events (~30ms)
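
The "parse from string" step in the list above could look roughly like the following. `HotkeyCombo` and `parse_hotkey_sketch` are hypothetical names for this example and are not the types or signature of the project's `hotkey::parse_hotkey`; the accepted modifier aliases are also assumptions.

```rust
/// Hypothetical parsed representation of a hotkey string such as "ctrl+space".
#[derive(Debug, Default, PartialEq, Eq)]
struct HotkeyCombo {
    ctrl: bool,
    shift: bool,
    alt: bool,
    meta: bool,
    /// The single non-modifier key, lowercased (e.g. "space", "escape").
    key: String,
}

/// Parses "mod+mod+key" strings, case-insensitively, ignoring spaces around '+'.
fn parse_hotkey_sketch(spec: &str) -> Result<HotkeyCombo, String> {
    let mut combo = HotkeyCombo::default();
    for part in spec.split('+') {
        let part = part.trim().to_lowercase();
        match part.as_str() {
            "" => return Err(format!("empty key name in hotkey '{spec}'")),
            "ctrl" | "control" => combo.ctrl = true,
            "shift" => combo.shift = true,
            "alt" => combo.alt = true,
            "meta" | "super" | "win" | "cmd" => combo.meta = true,
            key => {
                if !combo.key.is_empty() {
                    return Err(format!("more than one non-modifier key in '{spec}'"));
                }
                combo.key = key.to_string();
            }
        }
    }
    if combo.key.is_empty() {
        return Err(format!("hotkey '{spec}' has no non-modifier key"));
    }
    Ok(combo)
}
```

Mapping the `key` string onto the listener's key type (for example an `rdev` key enum) would be a separate lookup step and is left out here.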

### Phase 3: Audio Capture + VAD

- Open mic input via cpal (default device or configured)
- Convert to f32 mono
- Resample to 16kHz via rubato
- Buffer audio chunks during recording
- Run Silero VAD to trim leading/trailing silence
- Produce final `Vec<f32>` of clean speech at 16kHz
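
For the "convert to f32 mono" step above, a minimal sketch follows. It assumes interleaved input frames and simple channel averaging; both function names are made up for this example, and resampling and VAD are not covered.

```rust
/// Converts a signed 16-bit sample to f32 in the range [-1.0, 1.0).
fn i16_to_f32(sample: i16) -> f32 {
    sample as f32 / 32768.0
}

/// Averages interleaved multi-channel frames down to mono.
/// A trailing partial frame, if any, is dropped.
fn downmix_to_mono(interleaved: &[f32], channels: usize) -> Vec<f32> {
    assert!(channels > 0, "channel count must be non-zero");
    interleaved
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect()
}
```

Averaging is only one possible downmix; taking the first channel alone is another option when the extra channels carry nothing useful.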

### Phase 4: Model Management

- Use hf-hub to download Parakeet v3 ONNX model from HuggingFace
- Store in standard HF cache (`~/.cache/huggingface/hub/`)
- Show download progress with indicatif
- `mouth models` command to list/download models
- Auto-download on first run if model not cached

### Phase 5: Transcription

- Load Parakeet v3 ONNX model via ort
- Auto-detect GPU (DirectML on Windows, CUDA if available, CPU fallback)
- Respect accelerator override from config
- Run inference on captured audio
- Return transcribed text string

### Phase 6: Overlay

- Create a small always-on-top window using winit
- Render with softbuffer (simple coloured rectangle + text)
- States and colours:
  - Recording: red pulsing indicator
  - Transcribing: amber/yellow
  - Done: brief green flash, then hide
  - Error: brief red flash with error hint
- Window flags (Windows): `WS_EX_TOPMOST | WS_EX_TOOLWINDOW | WS_EX_NOACTIVATE`
- Position: centered horizontally at top or bottom of current monitor
- No focus steal, no taskbar entry

### Phase 7: Paste System

- Save current clipboard content (if preserving)
- Set transcribed text to clipboard via arboard
- Simulate keypress via enigo based on paste_method:
  - `ctrl_v`: Ctrl+V (Cmd+V on macOS)
  - `shift_insert`: Shift+Insert
  - `ctrl_shift_v`: Ctrl+Shift+V
  - `clipboard_only`: no keypress, just clipboard
- Restore previous clipboard content (unless copy_to_clipboard is true)
- Small delay between clipboard set and paste simulation (~50ms)
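
The paste-method rules in the list above reduce to a small lookup. In the sketch below the key names are plain strings rather than `enigo` key values, and `paste_chord` and `should_restore_clipboard` are illustrative helpers, not functions from the project.

```rust
/// Mirrors the four `paste_method` config values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PasteMethod {
    CtrlV,
    ShiftInsert,
    CtrlShiftV,
    ClipboardOnly,
}

/// Returns the keys to press together, or `None` when no keypress is simulated.
/// `is_macos` would typically come from `cfg!(target_os = "macos")`.
fn paste_chord(method: PasteMethod, is_macos: bool) -> Option<Vec<&'static str>> {
    match method {
        PasteMethod::CtrlV if is_macos => Some(vec!["cmd", "v"]),
        PasteMethod::CtrlV => Some(vec!["ctrl", "v"]),
        PasteMethod::ShiftInsert => Some(vec!["shift", "insert"]),
        PasteMethod::CtrlShiftV => Some(vec!["ctrl", "shift", "v"]),
        PasteMethod::ClipboardOnly => None,
    }
}

/// The previous clipboard content is restored only when the user does not
/// want the transcription left on the clipboard.
fn should_restore_clipboard(copy_to_clipboard: bool) -> bool {
    !copy_to_clipboard
}
```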

### Phase 8: Audio Feedback

- Bundle two short PCM blip sounds in the binary (via `include_bytes!`)
- "Blip up" on recording start
- "Blip down" on recording stop / transcription complete
- Play via rodio on a separate thread (non-blocking)
- Respect audio_feedback config flag

### Phase 9: Coordinator + Integration

- Wire all components together with channel-based message passing
- Main thread: overlay window event loop (winit requires this)
- Spawned threads/tasks: hotkey listener, audio recorder, transcriber
- Coordinator receives events, drives state machine:
  ```
  Idle ──[hotkey press]──▶ Recording
  Recording ──[hotkey release/press]──▶ Transcribing
  Recording ──[cancel]──▶ Idle
  Transcribing ──[result]──▶ Pasting ──▶ Idle
  Transcribing ──[error]──▶ Error ──▶ Idle
  ```
- Graceful shutdown on SIGINT / tray quit

### Phase 10: Daemon IPC + Status

- The running daemon listens on a local Unix domain socket (Linux/macOS) or named pipe (Windows) for status queries
- Socket/pipe path: `/tmp/mouth.sock` (Linux/macOS), `\\.\pipe\mouth` (Windows)
- `mouth status` connects and requests current state; daemon responds with JSON:
  ```json
  {
    "version": "0.1.0",
    "state": "idle",
    "model": "parakeet-tdt-0.6b-v3",
    "accelerator": "directml",
    "uptime_secs": 3420
  }
  ```
- If the daemon is not running, `mouth status` reports "Mouth is not running" and exits with code 1
- Also used internally to prevent launching a second daemon instance (lock check)
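
A std-only sketch of building the status payload described above is shown here. The `Status` struct and `status_json` function are hypothetical, and the hand-written formatting assumes field values contain no characters that need JSON escaping; a JSON serializer would be the more robust choice in practice.

```rust
use std::time::Duration;

/// Hypothetical snapshot of daemon state for the status reply.
struct Status<'a> {
    version: &'a str,
    state: &'a str,
    model: &'a str,
    accelerator: &'a str,
    uptime: Duration,
}

/// Formats the status as compact JSON with the field names from the plan.
/// Assumes no field contains quotes, backslashes, or control characters.
fn status_json(s: &Status) -> String {
    format!(
        r#"{{"version":"{}","state":"{}","model":"{}","accelerator":"{}","uptime_secs":{}}}"#,
        s.version,
        s.state,
        s.model,
        s.accelerator,
        s.uptime.as_secs()
    )
}
```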

### Phase 11: Polish + Distribution

- Error handling: user-friendly messages for common failures (no mic, model not found, etc.)
- Windows installer via `cargo-wix` or distribute as standalone .exe
- Test on Windows 10/11 primarily
- Test on Linux (X11 + Wayland) and macOS as secondary
- Update CLAUDE.md with build/run/test instructions
- Write user-facing README with setup instructions

## Risks & Mitigations

| Risk | Impact | Mitigation |
|------|--------|------------|
| Parakeet v3 ONNX model compatibility with `ort` | Blocks core functionality | Test early in Phase 5; Parakeet v2 as fallback |
| `rdev` hotkey reliability on Windows | Broken UX | Test early in Phase 2; fallback to Win32 `RegisterHotKey` |
| Overlay focus stealing | Annoying | Use proper window flags; test with various foreground apps |
| Audio resampling quality | Poor transcription | Use rubato SincInterpolation (high quality) |
| Binary size with bundled ONNX Runtime | Large download | ONNX Runtime is ~20-40MB; acceptable for a single-binary tool |
| winit event loop blocking | Unresponsive | All heavy work on background threads; overlay is lightweight |

## File Structure

```
mouth/
├── Cargo.toml
├── CLAUDE.md
├── README.md
├── plan.md
├── config.yaml.example
├── resources/
│   ├── blip_up.pcm          # bundled audio feedback
│   └── blip_down.pcm
└── src/
    ├── main.rs              # CLI entry, clap setup
    ├── config.rs            # Config struct, YAML load/save, defaults
    ├── hotkey.rs            # Global hotkey listener (rdev)
    ├── recorder.rs          # Audio capture (cpal + rubato + VAD)
    ├── vad.rs               # Silero VAD wrapper
    ├── transcriber.rs       # ONNX inference, model loading, GPU detection
    ├── model_cache.rs       # HuggingFace download, cache management
    ├── overlay.rs           # Minimal overlay window (winit + softbuffer)
    ├── paste.rs             # Clipboard + paste simulation
    ├── audio_feedback.rs    # Blip sounds via rodio
    ├── coordinator.rs       # State machine, channel hub
    └── cli/
        ├── mod.rs
        ├── run.rs           # `mouth run` handler
        ├── config_cmd.rs    # `mouth config` TUI
        ├── models_cmd.rs    # `mouth models` handler
        └── status_cmd.rs    # `mouth status` handler
```

## Not In Scope (v1)

- LLM post-processing of transcriptions
- Transcription history / database
- Multiple model support (v1 is Parakeet v3 only, architecture supports adding more later)
- Auto-submit (Enter after paste)
- Multi-language UI
- Tray icon / system tray integration
- Translate-to-English mode
@@ -0,0 +1,103 @@
use anyhow::Result;
use rodio::{OutputStream, Sink};
use std::io::Cursor;
use std::thread;
use std::time::Duration;
use tracing::{debug, warn};

/// Generate a simple sine wave blip sound.
fn generate_blip(freq_start: f32, freq_end: f32, duration_ms: u64) -> Vec<i16> {
    let sample_rate = 44100u32;
    let num_samples = (sample_rate as u64 * duration_ms / 1000) as usize;
    let mut samples = Vec::with_capacity(num_samples);

    for i in 0..num_samples {
        let t = i as f32 / sample_rate as f32;
        let progress = i as f32 / num_samples as f32;

        // Linear frequency sweep
        let freq = freq_start + (freq_end - freq_start) * progress;

        // Sine wave with envelope (fade in/out)
        let envelope = if progress < 0.1 {
            progress / 0.1
        } else if progress > 0.8 {
            (1.0 - progress) / 0.2
        } else {
            1.0
        };

        let sample = (envelope * 0.3 * (2.0 * std::f32::consts::PI * freq * t).sin()) * i16::MAX as f32;
        samples.push(sample as i16);
    }

    samples
}

/// Encode samples as a WAV file in memory.
fn encode_wav(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    let mut buf = Vec::new();
    let data_len = (samples.len() * 2) as u32;
    let file_len = 36 + data_len;

    // RIFF header
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&file_len.to_le_bytes());
    buf.extend_from_slice(b"WAVE");

    // fmt chunk
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
    buf.extend_from_slice(&1u16.to_le_bytes()); // PCM
    buf.extend_from_slice(&1u16.to_le_bytes()); // mono
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate
    buf.extend_from_slice(&2u16.to_le_bytes()); // block align
    buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample

    // data chunk
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_len.to_le_bytes());
    for sample in samples {
        buf.extend_from_slice(&sample.to_le_bytes());
    }

    buf
}

/// Play the "blip up" sound (recording started).
pub fn play_blip_up() {
    thread::spawn(|| {
        if let Err(e) = play_blip_internal(800.0, 1200.0, 100) {
            warn!("Failed to play blip up: {e}");
        }
    });
}

/// Play the "blip down" sound (recording stopped / transcription done).
pub fn play_blip_down() {
    thread::spawn(|| {
        if let Err(e) = play_blip_internal(1200.0, 800.0, 100) {
            warn!("Failed to play blip down: {e}");
        }
    });
}

fn play_blip_internal(freq_start: f32, freq_end: f32, duration_ms: u64) -> Result<()> {
    let samples = generate_blip(freq_start, freq_end, duration_ms);
    let wav_data = encode_wav(&samples, 44100);

    let (_stream, stream_handle) = OutputStream::try_default()?;
    let sink = Sink::try_new(&stream_handle)?;

    let cursor = Cursor::new(wav_data);
    let source = rodio::Decoder::new(cursor)?;
    sink.append(source);

    debug!("Playing blip ({freq_start}Hz -> {freq_end}Hz, {duration_ms}ms)");
    sink.sleep_until_end();

    // Keep stream alive briefly
    thread::sleep(Duration::from_millis(50));
    Ok(())
}
@@ -0,0 +1,127 @@
use anyhow::Result;
use dialoguer::{Input, Select};

use crate::config::{Accelerator, Config, OverlayPosition, PasteMethod, RecordingMode};

pub fn show() -> Result<()> {
    let config = Config::load()?;
    let yaml = serde_yaml::to_string(&config)?;
    println!("{yaml}");
    Ok(())
}

pub fn reset() -> Result<()> {
    let config = Config::default();
    config.save()?;
    println!("Config reset to defaults at {}", Config::path()?.display());
    Ok(())
}

pub fn interactive() -> Result<()> {
    let mut config = Config::load()?;

    config.hotkey = Input::new()
        .with_prompt("Hotkey")
        .default(config.hotkey)
        .interact_text()?;

    let mode_idx = Select::new()
        .with_prompt("Recording mode")
        .items(&["push_to_talk", "toggle"])
        .default(match config.mode {
            RecordingMode::PushToTalk => 0,
            RecordingMode::Toggle => 1,
        })
        .interact()?;
    config.mode = match mode_idx {
        0 => RecordingMode::PushToTalk,
        _ => RecordingMode::Toggle,
    };

    config.cancel_key = Input::new()
        .with_prompt("Cancel key")
        .default(config.cancel_key)
        .interact_text()?;

    config.model = Input::new()
        .with_prompt("Model")
        .default(config.model)
        .interact_text()?;

    let accel_idx = Select::new()
        .with_prompt("Accelerator")
        .items(&["auto", "cpu", "cuda", "directml"])
        .default(match config.accelerator {
            Accelerator::Auto => 0,
            Accelerator::Cpu => 1,
            Accelerator::Cuda => 2,
            Accelerator::DirectMl => 3,
        })
        .interact()?;
    config.accelerator = match accel_idx {
        0 => Accelerator::Auto,
        1 => Accelerator::Cpu,
        2 => Accelerator::Cuda,
        _ => Accelerator::DirectMl,
    };

    config.gpu_device = Input::new()
        .with_prompt("GPU device index")
        .default(config.gpu_device)
        .interact_text()?;

    let paste_idx = Select::new()
        .with_prompt("Paste method")
        .items(&["ctrl_v", "shift_insert", "ctrl_shift_v", "clipboard_only"])
        .default(match config.paste_method {
            PasteMethod::CtrlV => 0,
            PasteMethod::ShiftInsert => 1,
            PasteMethod::CtrlShiftV => 2,
            PasteMethod::ClipboardOnly => 3,
        })
        .interact()?;
    config.paste_method = match paste_idx {
        0 => PasteMethod::CtrlV,
        1 => PasteMethod::ShiftInsert,
        2 => PasteMethod::CtrlShiftV,
        _ => PasteMethod::ClipboardOnly,
    };

    let overlay_idx = Select::new()
        .with_prompt("Overlay position")
        .items(&["top", "bottom", "none"])
        .default(match config.overlay_position {
            OverlayPosition::Top => 0,
            OverlayPosition::Bottom => 1,
            OverlayPosition::None => 2,
        })
        .interact()?;
    config.overlay_position = match overlay_idx {
        0 => OverlayPosition::Top,
        1 => OverlayPosition::Bottom,
        _ => OverlayPosition::None,
    };

    let feedback_idx = Select::new()
        .with_prompt("Audio feedback")
        .items(&["yes", "no"])
        .default(if config.audio_feedback { 0 } else { 1 })
        .interact()?;
    config.audio_feedback = feedback_idx == 0;

    let vad_idx = Select::new()
        .with_prompt("VAD (voice activity detection)")
        .items(&["enabled", "disabled"])
        .default(if config.vad_enabled { 0 } else { 1 })
        .interact()?;
    config.vad_enabled = vad_idx == 0;

    config.language = Input::new()
        .with_prompt("Language")
        .default(config.language)
        .interact_text()?;

    config.save()?;
    println!("\nConfig saved to {}", Config::path()?.display());
    Ok(())
}
@@ -0,0 +1,4 @@
pub mod config_cmd;
pub mod models_cmd;
pub mod run_cmd;
pub mod status_cmd;
@@ -0,0 +1,30 @@
use anyhow::Result;

use crate::config::Config;
use crate::model_cache;

pub fn list() -> Result<()> {
    let config = Config::load()?;

    println!("Configured model: {}", config.model);
    println!();
    println!("Available models:");
    for (name, cached) in model_cache::list_models() {
        let status = if cached { "downloaded" } else { "not downloaded" };
        let marker = if name == config.model { " (active)" } else { "" };
        println!(" {name}{marker} [{status}]");
    }
    Ok(())
}

pub fn download() -> Result<()> {
    let config = Config::load()?;
    println!("Downloading model: {}...", config.model);

    let paths = model_cache::ensure_model(&config.model)?;
    println!("Model ready:");
    println!(" Encoder: {}", paths.encoder.display());
    println!(" Decoder: {}", paths.decoder.display());
    println!(" Vocab: {}", paths.vocab.display());
    Ok(())
}
@@ -0,0 +1,116 @@
use anyhow::{Context, Result};
use std::sync::mpsc;
use std::thread;
use tracing::info;

use crate::config::{Config, OverlayPosition};
use crate::coordinator::Coordinator;
use crate::hotkey;
use crate::model_cache;
use crate::overlay;
use crate::recorder;
use crate::transcriber::Transcriber;

pub fn run() -> Result<()> {
    let config = Config::load()?;
    info!("Mouth v{} starting", env!("CARGO_PKG_VERSION"));
    info!("Mode: {:?}", config.mode);
    info!("Hotkey: {}", config.hotkey);
    info!("Model: {}", config.model);
    info!("Accelerator: {:?}", config.accelerator);
    info!("Paste method: {:?}", config.paste_method);

    // Step 1: Ensure model is downloaded
    info!("Checking model...");
    let model_paths = model_cache::ensure_model(&config.model)
        .context("Failed to ensure model is available")?;

    // Step 2: Load transcriber
    info!("Loading transcription engine...");
    let transcriber = Transcriber::new(&model_paths, &config.accelerator, config.gpu_device)
        .context("Failed to load transcription engine")?;

    // Step 3: VAD (not yet bundled)
    let vad = if config.vad_enabled {
        info!("VAD enabled but Silero model not yet bundled — skipping");
        None
    } else {
        None
    };

    // Step 4: Parse hotkeys
    let hotkey_combo = hotkey::parse_hotkey(&config.hotkey)
        .with_context(|| format!("Invalid hotkey: {}", config.hotkey))?;
    let cancel_combo = hotkey::parse_hotkey(&config.cancel_key)
        .with_context(|| format!("Invalid cancel key: {}", config.cancel_key))?;

    // Step 5: Set up channels
    let (hotkey_tx, hotkey_rx) = mpsc::channel();
    let (recorder_cmd_tx, recorder_cmd_rx) = mpsc::channel();
    let (audio_tx, audio_rx) = mpsc::channel();

    // Step 6: Spawn background threads
    let device_name = config.input_device.clone();
    thread::Builder::new()
        .name("mouth-recorder".into())
        .spawn(move || {
            recorder::run(device_name, recorder_cmd_rx, audio_tx);
        })
        .context("Failed to spawn recorder thread")?;

    thread::Builder::new()
        .name("mouth-hotkey".into())
        .spawn(move || {
            hotkey::listen(hotkey_combo, cancel_combo, hotkey_tx);
        })
        .context("Failed to spawn hotkey thread")?;

    // Step 7: Start overlay + coordinator
    if config.overlay_position != OverlayPosition::None {
        let (event_loop, proxy) = overlay::create_event_loop()
            .map_err(|e| anyhow::anyhow!("Failed to create overlay event loop: {e}"))?;

        let overlay_position = config.overlay_position.clone();
        let coord_proxy = Some(proxy);

        // Coordinator runs on a background thread
        let coord_config = config.clone();
        thread::Builder::new()
            .name("mouth-coordinator".into())
            .spawn(move || {
                let mut coordinator = Coordinator::new(
                    coord_config,
                    transcriber,
                    vad,
                    recorder_cmd_tx,
                    audio_rx,
                    hotkey_rx,
                    coord_proxy,
                );
                coordinator.run();
            })
            .context("Failed to spawn coordinator thread")?;

        println!("Mouth is running. Press {} to record. Ctrl+C to quit.", config.hotkey);

        // Overlay event loop runs on main thread (blocking)
        overlay::run_event_loop(event_loop, overlay_position)
            .map_err(|e| anyhow::anyhow!("Overlay event loop error: {e}"))?;
    } else {
        // No overlay — coordinator runs on main thread
        println!("Mouth is running. Press {} to record. Ctrl+C to quit.", config.hotkey);

        let mut coordinator = Coordinator::new(
            config,
            transcriber,
            vad,
            recorder_cmd_tx,
            audio_rx,
            hotkey_rx,
            None,
        );
        coordinator.run();
    }

    Ok(())
}
@@ -0,0 +1,11 @@
use anyhow::Result;

pub fn status() -> Result<()> {
    let version = env!("CARGO_PKG_VERSION");

    // TODO: Phase 10 — connect to daemon IPC socket/pipe and query status
    // For now, just show version info
    println!("Mouth v{version}");
    println!("Status: not yet implemented (requires daemon IPC)");
    Ok(())
}
+184
@@ -0,0 +1,184 @@
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum RecordingMode {
    PushToTalk,
    Toggle,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum PasteMethod {
    CtrlV,
    ShiftInsert,
    CtrlShiftV,
    ClipboardOnly,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OverlayPosition {
    Top,
    Bottom,
    None,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum Accelerator {
    Auto,
    Cpu,
    Cuda,
    DirectMl,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Hotkey to activate recording
    #[serde(default = "defaults::hotkey")]
    pub hotkey: String,

    /// Recording mode
    #[serde(default = "defaults::mode")]
    pub mode: RecordingMode,

    /// Cancel hotkey (only active while recording)
    #[serde(default = "defaults::cancel_key")]
    pub cancel_key: String,

    /// Speech-to-text model identifier
    #[serde(default = "defaults::model")]
    pub model: String,

    /// Inference accelerator
    #[serde(default = "defaults::accelerator")]
    pub accelerator: Accelerator,

    /// GPU device index (when accelerator is cuda/directml)
    #[serde(default)]
    pub gpu_device: u32,

    /// How to paste transcribed text
    #[serde(default = "defaults::paste_method")]
    pub paste_method: PasteMethod,

    /// Keep transcribed text on clipboard after pasting
    #[serde(default = "defaults::yes")]
    pub copy_to_clipboard: bool,

    /// Overlay position on screen
    #[serde(default = "defaults::overlay_position")]
    pub overlay_position: OverlayPosition,

    /// Play audio feedback sounds
    #[serde(default = "defaults::yes")]
    pub audio_feedback: bool,

    /// Audio input device name (null = system default)
    #[serde(default)]
    pub input_device: Option<String>,

    /// Enable VAD to trim silence
    #[serde(default = "defaults::yes")]
    pub vad_enabled: bool,

    /// Language hint for model
    #[serde(default = "defaults::language")]
    pub language: String,
}

mod defaults {
    use super::*;

    pub fn hotkey() -> String {
        "ctrl+space".into()
    }
    pub fn mode() -> RecordingMode {
        RecordingMode::PushToTalk
    }
    pub fn cancel_key() -> String {
        "escape".into()
    }
    pub fn model() -> String {
        "parakeet-tdt-0.6b-v3".into()
    }
    pub fn accelerator() -> Accelerator {
        Accelerator::Auto
    }
    pub fn paste_method() -> PasteMethod {
        PasteMethod::CtrlV
    }
    pub fn overlay_position() -> OverlayPosition {
        OverlayPosition::Top
    }
    pub fn yes() -> bool {
        true
    }
    pub fn language() -> String {
        "en".into()
    }
}

impl Default for Config {
    fn default() -> Self {
        Self {
            hotkey: defaults::hotkey(),
            mode: defaults::mode(),
            cancel_key: defaults::cancel_key(),
            model: defaults::model(),
            accelerator: defaults::accelerator(),
            gpu_device: 0,
            paste_method: defaults::paste_method(),
            copy_to_clipboard: true,
            overlay_position: defaults::overlay_position(),
            audio_feedback: true,
            input_device: None,
            vad_enabled: true,
            language: defaults::language(),
        }
    }
}

impl Config {
    /// Returns the platform-appropriate config directory.
    pub fn dir() -> Result<PathBuf> {
        let dir = dirs::config_dir()
            .context("Could not determine config directory")?
            .join("mouth");
        Ok(dir)
    }

    /// Returns the path to the config file.
    pub fn path() -> Result<PathBuf> {
        Ok(Self::dir()?.join("config.yaml"))
    }

    /// Load config from disk, falling back to defaults if file doesn't exist.
    pub fn load() -> Result<Self> {
        let path = Self::path()?;
        if !path.exists() {
            return Ok(Self::default());
        }
        let contents = std::fs::read_to_string(&path)
            .with_context(|| format!("Failed to read config from {}", path.display()))?;
        let config: Config = serde_yaml::from_str(&contents)
            .with_context(|| format!("Failed to parse config from {}", path.display()))?;
        Ok(config)
    }

    /// Save config to disk, creating the directory if needed.
    pub fn save(&self) -> Result<()> {
        let path = Self::path()?;
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)
                .with_context(|| format!("Failed to create config directory {}", parent.display()))?;
        }
        let yaml = serde_yaml::to_string(self).context("Failed to serialize config")?;
        std::fs::write(&path, yaml)
            .with_context(|| format!("Failed to write config to {}", path.display()))?;
        Ok(())
    }
}
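Because every enum above carries `#[serde(rename_all = "snake_case")]`, the values written in `config.yaml` are the snake_case spellings of the variant names, not the Rust identifiers. The helper below is a standalone sketch of that renaming rule as I understand it (an underscore before each uppercase letter except the first, then lowercasing); `snake_case` is a hypothetical name used only for illustration and is not part of serde's public API.

```rust
/// Converts a variant name such as `PushToTalk` into the spelling expected in
/// the YAML file, following the snake_case renaming rule: an underscore is
/// inserted before every uppercase letter except the first, and each letter
/// is lowercased.
pub fn snake_case(variant: &str) -> String {
    let mut out = String::with_capacity(variant.len() + 4);
    for (i, ch) in variant.char_indices() {
        if i > 0 && ch.is_uppercase() {
            out.push('_');
        }
        out.push(ch.to_ascii_lowercase());
    }
    out
}
```

Under this rule a config would use, for example, `mode: push_to_talk`, `paste_method: ctrl_shift_v`, and `accelerator: direct_ml`.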
@@ -0,0 +1,255 @@
use std::sync::mpsc;
use std::thread;
use tracing::{debug, error, info, warn};
use winit::event_loop::EventLoopProxy;

use crate::audio_feedback;
use crate::config::{Config, RecordingMode};
use crate::hotkey::HotkeyEvent;
use crate::overlay::{OverlayEvent, OverlayState};
use crate::paste;
use crate::recorder::{AudioData, RecorderCommand};
use crate::transcriber::Transcriber;
use crate::vad::Vad;

/// The application state machine.
#[derive(Debug, Clone, Copy, PartialEq)]
enum State {
    Idle,
    Recording,
    Transcribing,
}

/// Central coordinator that wires all components together.
pub struct Coordinator {
    config: Config,
    state: State,
    transcriber: Transcriber,
    vad: Option<Vad>,
    recorder_tx: mpsc::Sender<RecorderCommand>,
    audio_rx: mpsc::Receiver<AudioData>,
    hotkey_rx: mpsc::Receiver<HotkeyEvent>,
    overlay_proxy: Option<EventLoopProxy<OverlayEvent>>,
}

impl Coordinator {
    pub fn new(
        config: Config,
        transcriber: Transcriber,
        vad: Option<Vad>,
        recorder_tx: mpsc::Sender<RecorderCommand>,
        audio_rx: mpsc::Receiver<AudioData>,
        hotkey_rx: mpsc::Receiver<HotkeyEvent>,
        overlay_proxy: Option<EventLoopProxy<OverlayEvent>>,
    ) -> Self {
        Self {
            config,
            state: State::Idle,
            transcriber,
            vad,
            recorder_tx,
            audio_rx,
            hotkey_rx,
            overlay_proxy,
        }
    }

    /// Run the coordinator loop. This blocks until shutdown.
    pub fn run(&mut self) {
        info!("Coordinator started");

        loop {
            // Wait for hotkey events
            let event = match self.hotkey_rx.recv() {
                Ok(e) => e,
                Err(_) => {
                    info!("Hotkey channel closed, shutting down");
                    break;
                }
            };

            self.handle_event(event);
        }

        self.shutdown();
    }

    fn handle_event(&mut self, event: HotkeyEvent) {
        debug!("Event: {:?}, State: {:?}", event, self.state);

        match (self.state, event) {
            // Start recording
            (State::Idle, HotkeyEvent::Pressed) => {
                self.start_recording();
            }

            // Push-to-talk: stop on release
            (State::Recording, HotkeyEvent::Released) => {
                if self.config.mode == RecordingMode::PushToTalk {
                    self.stop_recording();
                }
            }

            // Toggle mode: stop on second press
            (State::Recording, HotkeyEvent::Pressed) => {
                if self.config.mode == RecordingMode::Toggle {
                    self.stop_recording();
                }
            }

            // Cancel recording
            (State::Recording, HotkeyEvent::Cancel) => {
                self.cancel_recording();
            }

            // Ignore other combinations
            _ => {
                debug!("Ignoring event {:?} in state {:?}", event, self.state);
            }
        }
    }

    fn start_recording(&mut self) {
        info!("Recording started");
        self.state = State::Recording;
        self.set_overlay(OverlayState::Recording);

        if self.config.audio_feedback {
            audio_feedback::play_blip_up();
        }

        if self.recorder_tx.send(RecorderCommand::Start).is_err() {
            error!("Failed to send start command to recorder");
            self.state = State::Idle;
            self.set_overlay(OverlayState::Hidden);
        }
    }

    fn stop_recording(&mut self) {
        info!("Recording stopped, starting transcription");
        self.state = State::Transcribing;
        self.set_overlay(OverlayState::Transcribing);

        if self.config.audio_feedback {
            audio_feedback::play_blip_down();
        }

        if self.recorder_tx.send(RecorderCommand::Stop).is_err() {
            error!("Failed to send stop command to recorder");
            self.state = State::Idle;
            self.set_overlay(OverlayState::Hidden);
            return;
        }

        // Wait for audio data
        match self.audio_rx.recv() {
            Ok(audio) => {
                self.process_audio(audio);
            }
            Err(_) => {
                error!("Failed to receive audio data");
                self.state = State::Idle;
                self.set_overlay(OverlayState::Error);
                self.delayed_hide_overlay();
            }
        }
    }

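    fn cancel_recording(&mut self) {
        info!("Recording cancelled");
        self.state = State::Idle;
        self.set_overlay(OverlayState::Hidden);

        if self.recorder_tx.send(RecorderCommand::Stop).is_err() {
            warn!("Failed to send stop command to recorder");
        } else {
            // The recorder answers Stop with the captured audio (see
            // stop_recording), and that reply usually arrives after this point.
            // Wait briefly for it and discard it; a bare try_recv() drain would
            // miss it, and the next recording would then receive this cancelled
            // audio instead of its own.
            let _ = self
                .audio_rx
                .recv_timeout(std::time::Duration::from_millis(500));
        }

        // Drain any other pending audio
        while self.audio_rx.try_recv().is_ok() {}
    }
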
    fn process_audio(&mut self, audio: AudioData) {
        let samples = if self.config.vad_enabled {
            if let Some(vad) = &mut self.vad {
                match vad.filter_speech(&audio.samples) {
                    Ok(filtered) => {
                        if filtered.is_empty() {
                            info!("No speech detected by VAD");
                            self.state = State::Idle;
                            self.set_overlay(OverlayState::Hidden);
                            return;
                        }
                        filtered
                    }
                    Err(e) => {
                        warn!("VAD failed, using raw audio: {e}");
                        audio.samples
                    }
                }
            } else {
                audio.samples
            }
        } else {
            audio.samples
        };

        // Transcribe
        match self.transcriber.transcribe(&samples) {
            Ok(text) => {
                if text.is_empty() {
                    info!("Empty transcription");
                    self.state = State::Idle;
                    self.set_overlay(OverlayState::Hidden);
                    return;
                }

                info!("Transcribed: \"{text}\"");
                self.set_overlay(OverlayState::Done);

                // Paste the text
                if let Err(e) = paste::paste_text(
                    &text,
                    &self.config.paste_method,
                    self.config.copy_to_clipboard,
                ) {
                    error!("Failed to paste text: {e}");
                    self.set_overlay(OverlayState::Error);
                }

                self.delayed_hide_overlay();
                self.state = State::Idle;
            }
            Err(e) => {
                error!("Transcription failed: {e}");
                self.state = State::Idle;
                self.set_overlay(OverlayState::Error);
                self.delayed_hide_overlay();
            }
        }
    }

    fn set_overlay(&self, state: OverlayState) {
        if let Some(proxy) = &self.overlay_proxy {
            let _ = proxy.send_event(OverlayEvent::SetState(state));
        }
    }

    fn delayed_hide_overlay(&self) {
        if let Some(proxy) = &self.overlay_proxy {
            let proxy = proxy.clone();
            thread::spawn(move || {
                thread::sleep(std::time::Duration::from_millis(500));
                let _ = proxy.send_event(OverlayEvent::SetState(OverlayState::Hidden));
            });
        }
    }

    fn shutdown(&self) {
        info!("Coordinator shutting down");
        let _ = self.recorder_tx.send(RecorderCommand::Shutdown);
        if let Some(proxy) = &self.overlay_proxy {
            let _ = proxy.send_event(OverlayEvent::Shutdown);
        }
    }
}
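The transition rules in `handle_event` can be read as a small table keyed by state, event, and recording mode. The sketch below restates that table as a pure function so the combinations are easy to check in isolation. It is an illustration only: the enums are redeclared locally, and `Action` and `decide` are hypothetical names that do not exist in the commit.

```rust
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum State {
    Idle,
    Recording,
    Transcribing,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum HotkeyEvent {
    Pressed,
    Released,
    Cancel,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum RecordingMode {
    PushToTalk,
    Toggle,
}

/// Hypothetical summary of what the coordinator does for one event.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Action {
    StartRecording,
    StopRecording,
    CancelRecording,
    Ignore,
}

/// Maps (state, event, mode) to an action, mirroring the match arms above.
pub fn decide(state: State, event: HotkeyEvent, mode: RecordingMode) -> Action {
    match (state, event, mode) {
        // Any press while idle starts a recording, in both modes.
        (State::Idle, HotkeyEvent::Pressed, _) => Action::StartRecording,
        // Push-to-talk stops when the hotkey is released.
        (State::Recording, HotkeyEvent::Released, RecordingMode::PushToTalk) => {
            Action::StopRecording
        }
        // Toggle mode stops on the second press.
        (State::Recording, HotkeyEvent::Pressed, RecordingMode::Toggle) => Action::StopRecording,
        // The cancel key only matters while recording.
        (State::Recording, HotkeyEvent::Cancel, _) => Action::CancelRecording,
        // Everything else, including all events while transcribing, is ignored.
        _ => Action::Ignore,
    }
}
```

Keeping the decision separate from the side effects (overlay updates, recorder commands) is one way to make the table testable without audio hardware.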
@@ -0,0 +1,240 @@
use anyhow::{bail, Result};
use rdev::{self, Event, EventType, Key};
use std::sync::mpsc;
use std::time::{Duration, Instant};
use tracing::{debug, error, info};

/// Events sent from the hotkey listener to the coordinator.
#[derive(Debug, Clone, Copy)]
pub enum HotkeyEvent {
    /// Hotkey was pressed (start recording or toggle)
    Pressed,
    /// Hotkey was released (stop recording in push-to-talk mode)
    Released,
    /// Cancel key was pressed
    Cancel,
}

/// Parsed hotkey combination.
#[derive(Debug, Clone)]
pub struct HotkeyCombination {
    pub modifiers: Vec<Modifier>,
    pub key: Key,
}

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Modifier {
    Ctrl,
    Alt,
    Shift,
    Meta,
}

/// Tracks which modifier keys are currently held down.
#[derive(Debug, Default)]
struct ModifierState {
    ctrl: bool,
    alt: bool,
    shift: bool,
    meta: bool,
}

impl ModifierState {
    fn update(&mut self, key: &Key, pressed: bool) {
        match key {
            Key::ControlLeft | Key::ControlRight => self.ctrl = pressed,
            Key::Alt | Key::AltGr => self.alt = pressed,
            Key::ShiftLeft | Key::ShiftRight => self.shift = pressed,
            Key::MetaLeft | Key::MetaRight => self.meta = pressed,
            _ => {}
        }
    }

    fn is_held(&self, modifier: &Modifier) -> bool {
        match modifier {
            Modifier::Ctrl => self.ctrl,
            Modifier::Alt => self.alt,
            Modifier::Shift => self.shift,
            Modifier::Meta => self.meta,
        }
    }

    fn all_held(&self, modifiers: &[Modifier]) -> bool {
        modifiers.iter().all(|m| self.is_held(m))
    }
}

/// Parse a hotkey string like "ctrl+space" into a HotkeyCombination.
pub fn parse_hotkey(s: &str) -> Result<HotkeyCombination> {
    let lowered = s.to_lowercase();
    // `split` always yields at least one item, so test the input itself
    // rather than checking whether the list of parts is empty.
    if lowered.trim().is_empty() {
        bail!("Empty hotkey string");
    }
    let parts: Vec<&str> = lowered.split('+').map(|p| p.trim()).collect();

    let mut modifiers = Vec::new();
    let mut key = None;

    for part in &parts {
        match *part {
            "ctrl" | "control" => modifiers.push(Modifier::Ctrl),
            "alt" => modifiers.push(Modifier::Alt),
            "shift" => modifiers.push(Modifier::Shift),
            "meta" | "super" | "win" | "cmd" => modifiers.push(Modifier::Meta),
            _ => {
                if key.is_some() {
                    bail!("Multiple non-modifier keys in hotkey: {s}");
                }
                key = Some(parse_key(part)?);
            }
        }
    }

    let key = key.unwrap_or_else(|| {
        // No non-modifier key was given (e.g. "ctrl+shift"). Modifier-only
        // hotkeys are not supported, so fall back to a placeholder key that
        // is not expected to match any real key event.
        Key::Unknown(0)
    });

    Ok(HotkeyCombination { modifiers, key })
}

fn parse_key(s: &str) -> Result<Key> {
    let key = match s {
        "space" => Key::Space,
        "enter" | "return" => Key::Return,
        "escape" | "esc" => Key::Escape,
        "tab" => Key::Tab,
        "backspace" => Key::Backspace,
        "delete" | "del" => Key::Delete,
        "insert" => Key::Insert,
        "home" => Key::Home,
        "end" => Key::End,
        "pageup" => Key::PageUp,
        "pagedown" => Key::PageDown,
        "up" => Key::UpArrow,
        "down" => Key::DownArrow,
        "left" => Key::LeftArrow,
        "right" => Key::RightArrow,
        "f1" => Key::F1,
        "f2" => Key::F2,
        "f3" => Key::F3,
        "f4" => Key::F4,
        "f5" => Key::F5,
        "f6" => Key::F6,
        "f7" => Key::F7,
        "f8" => Key::F8,
        "f9" => Key::F9,
        "f10" => Key::F10,
        "f11" => Key::F11,
        "f12" => Key::F12,
        "a" => Key::KeyA,
        "b" => Key::KeyB,
        "c" => Key::KeyC,
        "d" => Key::KeyD,
        "e" => Key::KeyE,
        "f" => Key::KeyF,
        "g" => Key::KeyG,
        "h" => Key::KeyH,
        "i" => Key::KeyI,
        "j" => Key::KeyJ,
        "k" => Key::KeyK,
        "l" => Key::KeyL,
        "m" => Key::KeyM,
        "n" => Key::KeyN,
        "o" => Key::KeyO,
        "p" => Key::KeyP,
        "q" => Key::KeyQ,
        "r" => Key::KeyR,
        "s" => Key::KeyS,
        "t" => Key::KeyT,
        "u" => Key::KeyU,
        "v" => Key::KeyV,
        "w" => Key::KeyW,
        "x" => Key::KeyX,
        "y" => Key::KeyY,
        "z" => Key::KeyZ,
        "0" => Key::Num0,
        "1" => Key::Num1,
        "2" => Key::Num2,
        "3" => Key::Num3,
        "4" => Key::Num4,
        "5" => Key::Num5,
        "6" => Key::Num6,
        "7" => Key::Num7,
        "8" => Key::Num8,
        "9" => Key::Num9,
        _ => bail!("Unknown key: {s}"),
    };
    Ok(key)
}

/// Start the global hotkey listener on the current thread (blocking).
/// Sends HotkeyEvents to the provided channel.
pub fn listen(
    hotkey: HotkeyCombination,
    cancel_key: HotkeyCombination,
    tx: mpsc::Sender<HotkeyEvent>,
) {
    let debounce_duration = Duration::from_millis(30);
    // `Instant - Duration` can panic when the result is not representable,
    // so subtract with a fallback instead.
    let mut last_event_time = Instant::now()
        .checked_sub(debounce_duration)
        .unwrap_or_else(Instant::now);
    let mut modifier_state = ModifierState::default();
    let mut hotkey_held = false;

    info!("Hotkey listener started");
    debug!("Hotkey: {:?}", hotkey);
    debug!("Cancel: {:?}", cancel_key);

    let callback = move |event: Event| {
        let now = Instant::now();
        match event.event_type {
            EventType::KeyPress(key) => {
                modifier_state.update(&key, true);

                // Check cancel key
                if key == cancel_key.key && modifier_state.all_held(&cancel_key.modifiers) {
                    if now.duration_since(last_event_time) >= debounce_duration {
                        last_event_time = now;
                        debug!("Cancel key pressed");
                        if tx.send(HotkeyEvent::Cancel).is_err() {
                            error!("Failed to send cancel event");
                        }
                    }
                    return;
                }

                // Check hotkey
                if key == hotkey.key && modifier_state.all_held(&hotkey.modifiers) {
                    if now.duration_since(last_event_time) >= debounce_duration && !hotkey_held {
                        last_event_time = now;
                        hotkey_held = true;
                        debug!("Hotkey pressed");
                        if tx.send(HotkeyEvent::Pressed).is_err() {
                            error!("Failed to send pressed event");
                        }
                    }
                }
            }
            EventType::KeyRelease(key) => {
                modifier_state.update(&key, false);

                // Check hotkey release (for push-to-talk). A release is never
                // debounced: dropping it would leave `hotkey_held` stuck at
                // true after a tap shorter than the debounce window, and
                // push-to-talk would keep recording.
                if key == hotkey.key && hotkey_held {
                    last_event_time = now;
                    hotkey_held = false;
                    debug!("Hotkey released");
                    if tx.send(HotkeyEvent::Released).is_err() {
                        error!("Failed to send released event");
                    }
                }
            }
            _ => {}
        }
    };

    if let Err(e) = rdev::listen(callback) {
        error!("Hotkey listener error: {:?}", e);
    }
}
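The press/release bookkeeping in `listen` combines two ideas: a "held" flag that suppresses keyboard auto-repeat, and a debounce window that drops presses arriving too soon after the previous event. The sketch below isolates that logic with explicit timestamps so it can be exercised without a keyboard hook. `HotkeyTracker` and `Output` are hypothetical names, and treating releases as exempt from the debounce is a choice made in this sketch.

```rust
use std::time::Duration;

#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Output {
    Pressed,
    Released,
}

/// Tracks one hotkey. Timestamps are durations measured from any fixed
/// starting point, supplied by the caller.
pub struct HotkeyTracker {
    debounce: Duration,
    last_event: Option<Duration>,
    held: bool,
}

impl HotkeyTracker {
    pub fn new(debounce: Duration) -> Self {
        Self { debounce, last_event: None, held: false }
    }

    /// Key-down at time `now`. Auto-repeat (already held) and presses inside
    /// the debounce window produce no output.
    pub fn key_down(&mut self, now: Duration) -> Option<Output> {
        if self.held || self.within_debounce(now) {
            return None;
        }
        self.held = true;
        self.last_event = Some(now);
        Some(Output::Pressed)
    }

    /// Key-up at time `now`. Always reported if the key was held, so the
    /// held flag can never get stuck after a very short tap.
    pub fn key_up(&mut self, now: Duration) -> Option<Output> {
        if !self.held {
            return None;
        }
        self.held = false;
        self.last_event = Some(now);
        Some(Output::Released)
    }

    fn within_debounce(&self, now: Duration) -> bool {
        match self.last_event {
            Some(last) => now.saturating_sub(last) < self.debounce,
            None => false,
        }
    }
}
```

Passing time in as a parameter, rather than calling `Instant::now()` inside, keeps the logic deterministic and easy to test.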
@@ -0,0 +1,82 @@
mod audio_feedback;
mod cli;
mod config;
mod coordinator;
mod hotkey;
mod model_cache;
mod overlay;
mod paste;
mod recorder;
mod transcriber;
mod vad;

use clap::{Parser, Subcommand};

#[derive(Parser)]
#[command(name = "mouth", version, about = "Offline speech-to-text with global hotkey and paste")]
struct Cli {
    #[command(subcommand)]
    command: Option<Commands>,
}

#[derive(Subcommand)]
enum Commands {
    /// Start the mouth daemon
    Run,

    /// View or edit configuration
    Config {
        /// Print current config to stdout
        #[arg(long)]
        show: bool,

        /// Reset config to defaults
        #[arg(long)]
        reset: bool,
    },

    /// Manage speech-to-text models
    Models {
        /// Download the configured model
        #[arg(long)]
        download: bool,
    },

    /// Show daemon status, loaded model, and version
    Status,
}

fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    let cli = Cli::parse();

    match cli.command {
        None | Some(Commands::Run) => cli::run_cmd::run(),

        Some(Commands::Config { show, reset }) => {
            if show {
                cli::config_cmd::show()
            } else if reset {
                cli::config_cmd::reset()
            } else {
                cli::config_cmd::interactive()
            }
        }

        Some(Commands::Models { download }) => {
            if download {
                cli::models_cmd::download()
            } else {
                cli::models_cmd::list()
            }
        }

        Some(Commands::Status) => cli::status_cmd::status(),
    }
}
@@ -0,0 +1,97 @@
use anyhow::{Context, Result};
use hf_hub::api::sync::Api;
use std::path::PathBuf;
use tracing::{debug, info};

/// Known model definitions.
pub struct ModelInfo {
    pub repo_id: &'static str,
    pub encoder_file: &'static str,
    pub encoder_data_file: Option<&'static str>,
    pub decoder_file: &'static str,
    pub vocab_file: &'static str,
    #[allow(dead_code)]
    pub preprocessor_file: Option<&'static str>,
}

/// Get model info for a model name.
pub fn get_model_info(model_name: &str) -> Result<ModelInfo> {
    match model_name {
        "parakeet-tdt-0.6b-v3" => Ok(ModelInfo {
            repo_id: "istupakov/parakeet-tdt-0.6b-v3-onnx",
            encoder_file: "encoder-model.onnx",
            encoder_data_file: Some("encoder-model.onnx.data"),
            decoder_file: "decoder_joint-model.onnx",
            vocab_file: "vocab.txt",
            preprocessor_file: None,
        }),
        _ => anyhow::bail!("Unknown model: {model_name}. Supported models: parakeet-tdt-0.6b-v3"),
    }
}

/// Paths to downloaded model files.
pub struct ModelPaths {
    pub encoder: PathBuf,
    pub decoder: PathBuf,
    pub vocab: PathBuf,
}

/// Ensure model files are downloaded and return their paths.
/// Uses the standard HuggingFace cache directory.
pub fn ensure_model(model_name: &str) -> Result<ModelPaths> {
    let model_info = get_model_info(model_name)?;
    let api = Api::new().context("Failed to create HuggingFace Hub API")?;
    let repo = api.model(model_info.repo_id.to_string());

    info!("Ensuring model files for '{model_name}' from {}", model_info.repo_id);

    // Download encoder
    info!("Checking encoder: {}", model_info.encoder_file);
    let encoder = repo
        .get(model_info.encoder_file)
        .with_context(|| format!("Failed to download {}", model_info.encoder_file))?;
    debug!("Encoder: {}", encoder.display());

    // Download encoder data file if present
    if let Some(data_file) = model_info.encoder_data_file {
        info!("Checking encoder data: {data_file}");
        let data_path = repo
            .get(data_file)
            .with_context(|| format!("Failed to download {data_file}"))?;
        debug!("Encoder data: {}", data_path.display());
    }

    // Download decoder
    info!("Checking decoder: {}", model_info.decoder_file);
    let decoder = repo
        .get(model_info.decoder_file)
        .with_context(|| format!("Failed to download {}", model_info.decoder_file))?;
    debug!("Decoder: {}", decoder.display());

    // Download vocab
    info!("Checking vocab: {}", model_info.vocab_file);
    let vocab = repo
        .get(model_info.vocab_file)
        .with_context(|| format!("Failed to download {}", model_info.vocab_file))?;
    debug!("Vocab: {}", vocab.display());

    Ok(ModelPaths {
        encoder,
        decoder,
        vocab,
    })
}

/// Report whether the model files are available locally.
///
/// Note: this calls `ensure_model`, which downloads any missing file, so the
/// check can trigger a download (and needs network access) rather than only
/// inspecting the cache.
pub fn is_model_cached(model_name: &str) -> bool {
    ensure_model(model_name).is_ok()
}

/// List available models with their download status.
///
/// The status comes from `is_model_cached`, so listing may download models
/// that are missing.
pub fn list_models() -> Vec<(&'static str, bool)> {
    let models = ["parakeet-tdt-0.6b-v3"];
    models
        .iter()
        .map(|name| (*name, is_model_cached(name)))
        .collect()
}
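A cache check that never touches the network has to look only at the local filesystem. The sketch below shows the idea with the standard library, under a simplifying assumption: all model files sit directly in one directory. The real HuggingFace cache uses a nested layout, so `all_files_present` is a hypothetical helper for illustration, not a drop-in replacement.

```rust
use std::path::Path;

/// Returns true only if every name in `files` exists as a regular file
/// directly under `dir`. Purely local: nothing is downloaded or created.
pub fn all_files_present(dir: &Path, files: &[&str]) -> bool {
    files.iter().all(|name| dir.join(name).is_file())
}
```

A caller could pass the file names from `ModelInfo` and report "not downloaded" when any of them is missing, leaving the download itself to an explicit command.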
@@ -0,0 +1,201 @@
use std::num::NonZeroU32;
use tracing::{debug, error, info, warn};
use winit::application::ApplicationHandler;
use winit::dpi::{LogicalSize, PhysicalPosition};
use winit::event::WindowEvent;
use winit::event_loop::{ActiveEventLoop, EventLoop, EventLoopProxy};
use winit::window::{Window, WindowAttributes, WindowId, WindowLevel};

use crate::config::OverlayPosition;

const OVERLAY_WIDTH: u32 = 200;
const OVERLAY_HEIGHT: u32 = 36;

/// State of the overlay display.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum OverlayState {
    Hidden,
    Recording,
    Transcribing,
    Done,
    Error,
}

/// Events sent to the overlay from the coordinator.
#[derive(Debug, Clone)]
pub enum OverlayEvent {
    SetState(OverlayState),
    Shutdown,
}

/// The overlay application handler for winit.
struct OverlayApp {
    window: Option<std::rc::Rc<Window>>,
    surface: Option<softbuffer::Surface<std::rc::Rc<Window>, std::rc::Rc<Window>>>,
    state: OverlayState,
    position: OverlayPosition,
}

impl OverlayApp {
    fn draw(&mut self) {
        let Some(surface) = &mut self.surface else { return };
        let Some(window) = &self.window else { return };

        let size = window.inner_size();
        if size.width == 0 || size.height == 0 {
            return;
        }

        let Ok(w) = NonZeroU32::try_from(size.width) else { return };
        let Ok(h) = NonZeroU32::try_from(size.height) else { return };

        if surface.resize(w, h).is_err() {
            return;
        }

        let Ok(mut buffer) = surface.buffer_mut() else { return };

        let color = match self.state {
            OverlayState::Hidden => 0x00000000,
            OverlayState::Recording => 0xFFDD3333, // Red
            OverlayState::Transcribing => 0xFFDDAA33, // Amber
            OverlayState::Done => 0xFF33AA33, // Green
            OverlayState::Error => 0xFFDD3333, // Red
        };

        let width = size.width as usize;
        let height = size.height as usize;

        for y in 0..height {
            for x in 0..width {
                let radius = 8;
                let in_corner = (x < radius || x >= width - radius)
                    && (y < radius || y >= height - radius);

                let pixel = if in_corner {
                    let cx = if x < radius { radius } else { width - radius - 1 };
                    let cy = if y < radius { radius } else { height - radius - 1 };
                    let dx = x as i32 - cx as i32;
                    let dy = y as i32 - cy as i32;
                    if dx * dx + dy * dy <= (radius * radius) as i32 {
                        color
                    } else {
                        0x00000000
                    }
                } else {
                    color
                };

                buffer[y * width + x] = pixel;
            }
        }

        let _ = buffer.present();
    }

    fn update_visibility(&self) {
        if let Some(window) = &self.window {
            let visible = self.state != OverlayState::Hidden;
            window.set_visible(visible);
        }
    }
}

impl ApplicationHandler<OverlayEvent> for OverlayApp {
    fn resumed(&mut self, event_loop: &ActiveEventLoop) {
        if self.window.is_some() {
            return;
        }

        let attrs = WindowAttributes::default()
            .with_title("Mouth")
            .with_inner_size(LogicalSize::new(OVERLAY_WIDTH, OVERLAY_HEIGHT))
            .with_resizable(false)
            .with_decorations(false)
            .with_transparent(true)
            .with_window_level(WindowLevel::AlwaysOnTop)
            .with_visible(false);

        match event_loop.create_window(attrs) {
            Ok(window) => {
                let window = std::rc::Rc::new(window);

                // Place the overlay on the window's current monitor, horizontally
                // centered, at the configured edge. saturating_sub avoids a u32
                // underflow when the monitor is smaller than the overlay.
                if let Some(monitor) = window.current_monitor() {
                    let screen_size = monitor.size();
                    let pos = match self.position {
                        OverlayPosition::Top => PhysicalPosition::new(
                            screen_size.width.saturating_sub(OVERLAY_WIDTH) / 2,
                            10,
                        ),
                        OverlayPosition::Bottom => PhysicalPosition::new(
                            screen_size.width.saturating_sub(OVERLAY_WIDTH) / 2,
                            screen_size.height.saturating_sub(OVERLAY_HEIGHT + 50),
                        ),
                        OverlayPosition::None => PhysicalPosition::new(0, 0),
                    };
                    window.set_outer_position(pos);
                }

                let context = softbuffer::Context::new(window.clone()).ok();
                let surface = context.and_then(|ctx| {
                    softbuffer::Surface::new(&ctx, window.clone()).ok()
                });

                if surface.is_none() {
                    warn!("Could not create softbuffer surface; overlay rendering disabled");
                }

                self.surface = surface;
                self.window = Some(window);
                info!("Overlay window created");
            }
            Err(e) => {
                error!("Failed to create overlay window: {e}");
            }
        }
    }

    fn user_event(&mut self, event_loop: &ActiveEventLoop, event: OverlayEvent) {
        match event {
            OverlayEvent::SetState(state) => {
                debug!("Overlay state: {:?} -> {:?}", self.state, state);
                self.state = state;
                self.update_visibility();
                self.draw();
            }
            OverlayEvent::Shutdown => {
                info!("Overlay shutting down");
                event_loop.exit();
            }
        }
    }

    fn window_event(&mut self, _event_loop: &ActiveEventLoop, _id: WindowId, event: WindowEvent) {
        if let WindowEvent::RedrawRequested = event {
            self.draw();
        }
    }
}

/// Create an event loop and return the proxy for sending events.
pub fn create_event_loop() -> Result<(EventLoop<OverlayEvent>, EventLoopProxy<OverlayEvent>), winit::error::EventLoopError> {
    let event_loop: EventLoop<OverlayEvent> = EventLoop::with_user_event().build()?;
    let proxy = event_loop.create_proxy();
    Ok((event_loop, proxy))
}

/// Run the event loop with the given position config.
pub fn run_event_loop(
    event_loop: EventLoop<OverlayEvent>,
    position: OverlayPosition,
) -> Result<(), winit::error::EventLoopError> {
    let mut app = OverlayApp {
        window: None,
        surface: None,
        state: OverlayState::Hidden,
        position,
    };

    event_loop.run_app(&mut app)
}
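The overlay's rounded corners come from a per-pixel test: pixels outside the four corner squares are always filled, and pixels inside a corner square are filled only if they fall within a circle of the corner radius. The sketch below pulls that test into a standalone function, together with the horizontal centering arithmetic. Both `inside_rounded_rect` and `centered_x` are hypothetical helpers, and the radius clamp is an addition of this sketch to keep the subtraction safe for tiny rectangles.

```rust
/// Returns true if pixel (x, y) lies inside a `width` x `height` rectangle
/// whose corners are rounded with the given `radius`.
pub fn inside_rounded_rect(x: usize, y: usize, width: usize, height: usize, radius: usize) -> bool {
    if x >= width || y >= height {
        return false;
    }
    // Clamp the radius so the subtractions below cannot underflow.
    let r = radius.min(width / 2).min(height / 2);
    let in_corner_x = x < r || x >= width - r;
    let in_corner_y = y < r || y >= height - r;
    if !(in_corner_x && in_corner_y) {
        return true;
    }
    // Centre of the circle belonging to the nearest corner.
    let cx = if x < r { r } else { width - r - 1 };
    let cy = if y < r { r } else { height - r - 1 };
    let dx = x as i64 - cx as i64;
    let dy = y as i64 - cy as i64;
    dx * dx + dy * dy <= (r * r) as i64
}

/// Left edge that centres an overlay on a screen, clamped to 0 when the
/// overlay is wider than the screen.
pub fn centered_x(screen_width: u32, overlay_width: u32) -> u32 {
    screen_width.saturating_sub(overlay_width) / 2
}
```

For a 200 x 36 overlay with radius 8, the extreme corner pixels are transparent while pixels a few steps in from the corner are already filled.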
@@ -0,0 +1,71 @@
|
||||
use anyhow::{Context, Result};
|
||||
use arboard::Clipboard;
|
||||
use enigo::{Direction, Enigo, Key, Keyboard, Settings};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::config::PasteMethod;
|
||||
|
||||
/// Paste text using the configured method.
|
||||
pub fn paste_text(text: &str, method: &PasteMethod, keep_on_clipboard: bool) -> Result<()> {
|
||||
let mut clipboard = Clipboard::new().context("Failed to open clipboard")?;
|
||||
|
||||
// Save current clipboard content if we need to restore it
|
||||
let previous = if !keep_on_clipboard {
|
||||
clipboard.get_text().ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Set text to clipboard
|
||||
clipboard
|
||||
.set_text(text.to_string())
|
||||
.context("Failed to set clipboard text")?;
|
||||
debug!("Text set to clipboard ({} chars)", text.chars().count());
|
||||
|
||||
if *method == PasteMethod::ClipboardOnly {
|
||||
info!("Text copied to clipboard (clipboard_only mode)");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Small delay to ensure clipboard is ready
|
||||
thread::sleep(Duration::from_millis(50));
|
||||
|
||||
// Simulate paste keystroke
|
||||
let mut enigo = Enigo::new(&Settings::default()).context("Failed to create enigo instance")?;
|
||||
|
||||
match method {
|
||||
PasteMethod::CtrlV => {
|
||||
debug!("Pasting via Ctrl+V");
|
||||
enigo.key(Key::Control, Direction::Press)?;
|
||||
enigo.key(Key::Unicode('v'), Direction::Click)?;
|
||||
enigo.key(Key::Control, Direction::Release)?;
|
||||
}
|
||||
PasteMethod::ShiftInsert => {
|
||||
debug!("Pasting via Shift+Insert");
|
||||
enigo.key(Key::Shift, Direction::Press)?;
|
||||
enigo.key(Key::Other(0x2D), Direction::Click)?; // VK_INSERT on Windows
|
||||
enigo.key(Key::Shift, Direction::Release)?;
|
||||
}
|
||||
PasteMethod::CtrlShiftV => {
|
||||
debug!("Pasting via Ctrl+Shift+V");
|
||||
enigo.key(Key::Control, Direction::Press)?;
|
||||
enigo.key(Key::Shift, Direction::Press)?;
|
||||
enigo.key(Key::Unicode('v'), Direction::Click)?;
|
||||
enigo.key(Key::Shift, Direction::Release)?;
|
||||
enigo.key(Key::Control, Direction::Release)?;
|
||||
}
|
||||
PasteMethod::ClipboardOnly => unreachable!("ClipboardOnly returns before keystroke simulation"),
|
||||
}
|
||||
|
||||
// Restore previous clipboard content if needed
|
||||
if let Some(prev) = previous {
|
||||
thread::sleep(Duration::from_millis(100));
|
||||
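// Only report success when the restore actually succeeded.
match clipboard.set_text(prev) {
    Ok(()) => debug!("Previous clipboard content restored"),
    Err(e) => debug!("Failed to restore previous clipboard content: {e}"),
}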
|
||||
}
|
||||
|
||||
info!("Text pasted ({} chars)", text.chars().count());
|
||||
Ok(())
|
||||
}
|
||||
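The save, set, paste, restore ordering in `paste_text` can be sketched without touching a real clipboard. The `TextClipboard` trait, `MemoryClipboard` type, and `paste_with_restore` function below are hypothetical names introduced only for this illustration; they are not part of arboard or enigo, and the delays used by the real function are left out.

```rust
/// Hypothetical stand-in for a system clipboard (not the arboard API).
pub trait TextClipboard {
    fn get_text(&mut self) -> Option<String>;
    fn set_text(&mut self, text: String);
}

/// In-memory clipboard used only for this illustration.
#[derive(Default)]
pub struct MemoryClipboard {
    pub contents: Option<String>,
}

impl TextClipboard for MemoryClipboard {
    fn get_text(&mut self) -> Option<String> {
        self.contents.clone()
    }

    fn set_text(&mut self, text: String) {
        self.contents = Some(text);
    }
}

/// Put `text` on the clipboard, run `send_paste` (standing in for the
/// simulated keystroke), then restore the earlier contents unless
/// `keep_on_clipboard` is set.
pub fn paste_with_restore<C: TextClipboard>(
    clipboard: &mut C,
    text: &str,
    keep_on_clipboard: bool,
    mut send_paste: impl FnMut(&str),
) {
    // Remember the old contents only when they must be restored later.
    let previous = if keep_on_clipboard {
        None
    } else {
        clipboard.get_text()
    };

    clipboard.set_text(text.to_string());

    // The target application reads the clipboard while handling the paste.
    let visible = clipboard.get_text().unwrap_or_default();
    send_paste(&visible);

    if let Some(prev) = previous {
        clipboard.set_text(prev);
    }
}
```

If the clipboard held no text beforehand, nothing is restored and the transcribed text stays on the clipboard, which matches the `get_text().ok()` handling in the function above.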
@@ -0,0 +1,265 @@
|
||||
use anyhow::{Context, Result};
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use cpal::{Device, SampleFormat, SampleRate, StreamConfig};
|
||||
use rubato::{FftFixedIn, Resampler};
|
||||
use std::sync::mpsc;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
const TARGET_SAMPLE_RATE: u32 = 16000;
|
||||
|
||||
/// Commands sent to the recorder.
|
||||
#[derive(Debug)]
|
||||
pub enum RecorderCommand {
|
||||
Start,
|
||||
Stop,
|
||||
Shutdown,
|
||||
}
|
||||
|
||||
/// Audio data produced by the recorder.
|
||||
#[derive(Debug)]
|
||||
pub struct AudioData {
|
||||
/// Mono f32 samples at 16kHz
|
||||
pub samples: Vec<f32>,
|
||||
#[allow(dead_code)]
|
||||
pub sample_rate: u32,
|
||||
}
|
||||
|
||||
/// Find an input device whose name contains `name` (case-insensitive), or fall back to the default device.
|
||||
fn get_input_device(name: Option<&str>) -> Result<Device> {
|
||||
let host = cpal::default_host();
|
||||
|
||||
if let Some(name) = name {
|
||||
let devices = host.input_devices().context("Failed to enumerate input devices")?;
|
||||
for device in devices {
|
||||
if let Ok(dev_name) = device.name() {
|
||||
if dev_name.to_lowercase().contains(&name.to_lowercase()) {
|
||||
info!("Using audio input device: {dev_name}");
|
||||
return Ok(device);
|
||||
}
|
||||
}
|
||||
}
|
||||
warn!("Audio device '{name}' not found, falling back to default");
|
||||
}
|
||||
|
||||
host.default_input_device().context("No default input device available")
|
||||
}
|
||||
|
||||
/// Convert interleaved multi-channel samples to mono f32.
|
||||
fn to_mono_f32(data: &[f32], channels: u16) -> Vec<f32> {
|
||||
if channels == 1 {
|
||||
return data.to_vec();
|
||||
}
|
||||
data.chunks(channels as usize)
|
||||
.map(|frame| frame.iter().sum::<f32>() / channels as f32)
|
||||
.collect()
|
||||
}
|
||||
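A small worked example of the downmix step, as a sketch rather than the project's code. The helper names `downmix` and `i16_to_f32` are made up for illustration. Two choices here differ from the function above and are assumptions: `chunks_exact` drops a trailing partial frame instead of averaging it, and the 16-bit conversion divides by 32768 so that `i16::MIN` maps to exactly -1.0.

```rust
/// Average each interleaved frame down to one mono sample.
/// A trailing partial frame (fewer than `channels` samples) is dropped.
pub fn downmix(interleaved: &[f32], channels: usize) -> Vec<f32> {
    assert!(channels > 0, "channel count must be non-zero");
    interleaved
        .chunks_exact(channels)
        .map(|frame| frame.iter().sum::<f32>() / channels as f32)
        .collect()
}

/// Scale a signed 16-bit sample into the range [-1.0, 1.0).
pub fn i16_to_f32(sample: i16) -> f32 {
    sample as f32 / 32768.0
}
```

For stereo input `[1.0, 0.0, 0.5, 0.5, -1.0, 1.0]` the frames are `(1.0, 0.0)`, `(0.5, 0.5)` and `(-1.0, 1.0)`, so the mono result is `[0.5, 0.5, 0.0]`.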
|
||||
/// Resample audio from source rate to target rate.
|
||||
fn resample(samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
|
||||
if from_rate == to_rate {
|
||||
return Ok(samples.to_vec());
|
||||
}
|
||||
|
||||
let chunk_size = 1024;
|
||||
let mut resampler = FftFixedIn::<f32>::new(
|
||||
from_rate as usize,
|
||||
to_rate as usize,
|
||||
chunk_size,
|
||||
1, // sub-chunks
|
||||
1, // mono
|
||||
)?;
|
||||
|
||||
let mut output = Vec::new();
|
||||
let mut pos = 0;
|
||||
|
||||
while pos + chunk_size <= samples.len() {
|
||||
let chunk = &samples[pos..pos + chunk_size];
|
||||
let result = resampler.process(&[chunk], None)?;
|
||||
output.extend_from_slice(&result[0]);
|
||||
pos += chunk_size;
|
||||
}
|
||||
|
||||
// Zero-pad the final partial chunk, then keep only the fraction of the output that corresponds to real input samples
|
||||
if pos < samples.len() {
|
||||
let mut last_chunk = samples[pos..].to_vec();
|
||||
last_chunk.resize(chunk_size, 0.0);
|
||||
let result = resampler.process(&[&last_chunk], None)?;
|
||||
let remaining_ratio = (samples.len() - pos) as f32 / chunk_size as f32;
|
||||
let take = (result[0].len() as f32 * remaining_ratio) as usize;
|
||||
output.extend_from_slice(&result[0][..take]);
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Run the audio recorder in a loop, responding to commands.
|
||||
/// This should be called from a dedicated thread.
|
||||
pub fn run(
|
||||
device_name: Option<String>,
|
||||
cmd_rx: mpsc::Receiver<RecorderCommand>,
|
||||
audio_tx: mpsc::Sender<AudioData>,
|
||||
) {
|
||||
let device = match get_input_device(device_name.as_deref()) {
|
||||
Ok(d) => d,
|
||||
Err(e) => {
|
||||
error!("Failed to get audio input device: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let dev_name = device.name().unwrap_or_else(|_| "unknown".into());
|
||||
info!("Audio recorder using device: {dev_name}");
|
||||
|
||||
let config = match device.default_input_config() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to get default input config: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let source_sample_rate = config.sample_rate().0;
|
||||
let channels = config.channels();
|
||||
debug!("Input config: {source_sample_rate}Hz, {channels}ch, {:?}", config.sample_format());
|
||||
|
||||
loop {
|
||||
// Wait for a Start command
|
||||
match cmd_rx.recv() {
|
||||
Ok(RecorderCommand::Start) => {
|
||||
debug!("Recording started");
|
||||
}
|
||||
Ok(RecorderCommand::Shutdown) => {
|
||||
info!("Recorder shutting down");
|
||||
return;
|
||||
}
|
||||
Ok(RecorderCommand::Stop) => {
|
||||
// Ignore stop when not recording
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
info!("Recorder channel closed, shutting down");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Record until Stop or Shutdown
|
||||
let buffer: Arc<Mutex<Vec<f32>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
let buffer_clone = Arc::clone(&buffer);
|
||||
let ch = channels;
|
||||
|
||||
let stream_config = StreamConfig {
|
||||
channels,
|
||||
sample_rate: SampleRate(source_sample_rate),
|
||||
buffer_size: cpal::BufferSize::Default,
|
||||
};
|
||||
|
||||
let stream = match config.sample_format() {
|
||||
SampleFormat::F32 => device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mono = to_mono_f32(data, ch);
|
||||
if let Ok(mut buf) = buffer_clone.lock() {
|
||||
buf.extend_from_slice(&mono);
|
||||
}
|
||||
},
|
||||
|err| error!("Audio stream error: {err}"),
|
||||
None,
|
||||
),
|
||||
SampleFormat::I16 => {
|
||||
let buffer_clone = Arc::clone(&buffer);
|
||||
device.build_input_stream(
|
||||
&stream_config,
|
||||
move |data: &[i16], _: &cpal::InputCallbackInfo| {
|
||||
let f32_data: Vec<f32> = data.iter().map(|&s| s as f32 / i16::MAX as f32).collect();
|
||||
let mono = to_mono_f32(&f32_data, ch);
|
||||
if let Ok(mut buf) = buffer_clone.lock() {
|
||||
buf.extend_from_slice(&mono);
|
||||
}
|
||||
},
|
||||
|err| error!("Audio stream error: {err}"),
|
||||
None,
|
||||
)
|
||||
}
|
||||
format => {
|
||||
error!("Unsupported sample format: {format:?}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let stream = match stream {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to build input stream: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = stream.play() {
|
||||
error!("Failed to start audio stream: {e}");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Wait for Stop or Shutdown
|
||||
let should_shutdown = loop {
|
||||
match cmd_rx.recv() {
|
||||
Ok(RecorderCommand::Stop) => {
|
||||
debug!("Recording stopped");
|
||||
break false;
|
||||
}
|
||||
Ok(RecorderCommand::Shutdown) => {
|
||||
info!("Recorder shutting down during recording");
|
||||
break true;
|
||||
}
|
||||
Ok(RecorderCommand::Start) => {
|
||||
// Ignore duplicate start
|
||||
continue;
|
||||
}
|
||||
Err(_) => {
|
||||
break true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Stop the stream
|
||||
drop(stream);
|
||||
|
||||
if should_shutdown {
|
||||
return;
|
||||
}
|
||||
|
||||
// Get recorded audio
|
||||
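// Take the buffer instead of cloning it, and still recover the data if the
// audio callback panicked and poisoned the mutex.
let raw_samples = std::mem::take(&mut *buffer.lock().unwrap_or_else(|e| e.into_inner()));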
|
||||
|
||||
if raw_samples.is_empty() {
|
||||
warn!("No audio recorded");
|
||||
continue;
|
||||
}
|
||||
|
||||
debug!("Recorded {} samples at {}Hz", raw_samples.len(), source_sample_rate);
|
||||
|
||||
// Resample to 16kHz
|
||||
let samples = match resample(&raw_samples, source_sample_rate, TARGET_SAMPLE_RATE) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
error!("Failed to resample audio: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("Resampled to {} samples at {}Hz", samples.len(), TARGET_SAMPLE_RATE);
|
||||
|
||||
let audio = AudioData {
|
||||
samples,
|
||||
sample_rate: TARGET_SAMPLE_RATE,
|
||||
};
|
||||
|
||||
if audio_tx.send(audio).is_err() {
|
||||
error!("Failed to send audio data");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
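The recorder's control flow is a two-phase loop: idle until `Start`, then record until `Stop`, with `Shutdown` or a closed channel ending the thread from either phase. The sketch below reproduces only that command handling using `std::sync::mpsc`. The `Command` and `Event` enums and the `worker` and `run_script` functions are hypothetical names for this illustration; no audio is captured.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Mirrors the three recorder commands.
#[derive(Debug)]
pub enum Command {
    Start,
    Stop,
    Shutdown,
}

/// Events reported back by the worker (hypothetical, for illustration).
#[derive(Debug, PartialEq)]
pub enum Event {
    Started,
    Finished,
}

pub fn worker(cmd_rx: Receiver<Command>, event_tx: Sender<Event>) {
    loop {
        // Idle phase: only Start begins a recording; a stray Stop is ignored.
        match cmd_rx.recv() {
            Ok(Command::Start) => {}
            Ok(Command::Stop) => continue,
            Ok(Command::Shutdown) | Err(_) => return,
        }
        if event_tx.send(Event::Started).is_err() {
            return;
        }

        // Recording phase: a duplicate Start is ignored, Stop ends the take.
        loop {
            match cmd_rx.recv() {
                Ok(Command::Stop) => break,
                Ok(Command::Start) => continue,
                Ok(Command::Shutdown) | Err(_) => return,
            }
        }
        if event_tx.send(Event::Finished).is_err() {
            return;
        }
    }
}

/// Feed a fixed command script to the worker and collect its events.
pub fn run_script(commands: Vec<Command>) -> Vec<Event> {
    let (cmd_tx, cmd_rx) = mpsc::channel();
    let (event_tx, event_rx) = mpsc::channel();
    let handle = thread::spawn(move || worker(cmd_rx, event_tx));

    for command in commands {
        // The worker may already have exited after Shutdown; ignore that.
        let _ = cmd_tx.send(command);
    }
    drop(cmd_tx); // closing the channel ends the worker if it is still waiting
    handle.join().expect("worker thread panicked");

    // All senders are gone once the worker has returned, so this terminates.
    event_rx.iter().collect()
}
```

Sending `Stop, Start, Start, Stop, Start, Shutdown` yields `Started, Finished, Started`: the first `Stop` and the second `Start` are ignored, and the final take is abandoned on `Shutdown`, just as the recorder returns without sending audio in that case.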
@@ -0,0 +1,424 @@
|
||||
use anyhow::{Context, Result};
|
||||
use ndarray::{Array2, Array3};
|
||||
use ort::session::{Session, SessionInputValue};
|
||||
use ort::value::Value;
|
||||
use std::borrow::Cow;
|
||||
use std::path::Path;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::config::Accelerator;
|
||||
use crate::model_cache::ModelPaths;
|
||||
|
||||
// Audio preprocessing constants (NeMo-style)
|
||||
const SAMPLE_RATE: usize = 16000;
|
||||
const N_FFT: usize = 512;
|
||||
const WIN_LENGTH: usize = 400;
|
||||
const HOP_LENGTH: usize = 160;
|
||||
const N_MELS: usize = 128;
|
||||
const PRE_EMPHASIS: f32 = 0.97;
|
||||
|
||||
/// The transcription engine.
|
||||
pub struct Transcriber {
|
||||
encoder: Session,
|
||||
decoder: Session,
|
||||
vocab: Vec<String>,
|
||||
blank_id: i64,
|
||||
vocab_size: usize,
|
||||
}
|
||||
|
||||
fn make_input<'a, V: Into<Value>>(name: &'static str, value: V) -> (Cow<'a, str>, SessionInputValue<'a>) {
|
||||
(Cow::Borrowed(name), SessionInputValue::Owned(value.into().into_dyn()))
|
||||
}
|
||||
|
||||
impl Transcriber {
|
||||
/// Load the transcription model from the given paths.
|
||||
pub fn new(paths: &ModelPaths, accelerator: &Accelerator, gpu_device: u32) -> Result<Self> {
|
||||
info!("Loading transcription model...");
|
||||
|
||||
let encoder = build_session(&paths.encoder, accelerator, gpu_device)
|
||||
.context("Failed to load encoder")?;
|
||||
info!("Encoder loaded");
|
||||
|
||||
let decoder = build_session(&paths.decoder, accelerator, gpu_device)
|
||||
.context("Failed to load decoder")?;
|
||||
info!("Decoder loaded");
|
||||
|
||||
let vocab = load_vocab(&paths.vocab)?;
|
||||
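let vocab_size = vocab.len();
// Guard against an empty vocab file, which would underflow the blank_id computation below.
anyhow::ensure!(vocab_size > 0, "Vocabulary file is empty");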
|
||||
let blank_id = (vocab_size - 1) as i64; // <blk> is the last token
|
||||
info!("Vocab loaded: {vocab_size} tokens, blank_id={blank_id}");
|
||||
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
vocab,
|
||||
blank_id,
|
||||
vocab_size,
|
||||
})
|
||||
}
|
||||
|
||||
/// Transcribe audio samples (mono f32 at 16kHz).
|
||||
pub fn transcribe(&mut self, samples: &[f32]) -> Result<String> {
|
||||
if samples.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let duration = samples.len() as f32 / SAMPLE_RATE as f32;
|
||||
info!("Transcribing {duration:.1}s of audio...");
|
||||
|
||||
// Step 1: Compute mel spectrogram
|
||||
let features = compute_mel_spectrogram(samples);
|
||||
let num_frames = features.ncols();
|
||||
debug!("Mel spectrogram: {N_MELS}x{num_frames} frames");
|
||||
|
||||
// Step 2: Run encoder
|
||||
let (encoder_output, feat_dim, encoded_length) = self.run_encoder(&features)?;
|
||||
debug!("Encoded: feat_dim={feat_dim}, length={encoded_length}");
|
||||
|
||||
// Step 3: TDT greedy decode
|
||||
let tokens = self.tdt_greedy_decode(&encoder_output, feat_dim, encoded_length)?;
|
||||
debug!("Decoded {} tokens", tokens.len());
|
||||
|
||||
// Step 4: Convert tokens to text
|
||||
let text = self.tokens_to_text(&tokens);
|
||||
info!("Transcription: \"{text}\"");
|
||||
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
fn run_encoder(&mut self, features: &Array2<f32>) -> Result<(Vec<f32>, usize, usize)> {
|
||||
let num_frames = features.ncols();
|
||||
|
||||
// Shape: [batch=1, n_mels=128, time_frames]
|
||||
let input = features
|
||||
.clone()
|
||||
.into_shape_with_order((1, N_MELS, num_frames))?;
|
||||
|
||||
let length_data = ndarray::Array1::from_vec(vec![num_frames as i64]);
|
||||
|
||||
let input_value = Value::from_array(input)?;
|
||||
let length_value = Value::from_array(length_data)?;
|
||||
|
||||
let outputs = self.encoder.run(vec![
|
||||
make_input("audio_signal", input_value.into_dyn()),
|
||||
make_input("length", length_value.into_dyn()),
|
||||
])?;
|
||||
|
||||
let (enc_shape, enc_data) = outputs[0]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract encoder output: {e}"))?;
|
||||
|
||||
let (_, len_data) = outputs[1]
|
||||
.try_extract_tensor::<i64>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract encoded lengths: {e}"))?;
|
||||
|
||||
let encoded_length = len_data[0] as usize;
|
||||
// Shape: [1, feat_dim, encoded_time]
|
||||
let dims: Vec<usize> = enc_shape.iter().map(|&d| d as usize).collect();
|
||||
let feat_dim = if dims.len() == 3 { dims[1] } else { dims[dims.len() - 1] };
|
||||
|
||||
debug!("Encoder output shape: {:?}", dims);
|
||||
Ok((enc_data.to_vec(), feat_dim, encoded_length))
|
||||
}
|
||||
|
||||
fn tdt_greedy_decode(&mut self, encoder_output: &[f32], feat_dim: usize, encoded_length: usize) -> Result<Vec<i64>> {
|
||||
// Read the decoder LSTM state shape from the model's input metadata.
// Any dimension reported as dynamic keeps the fallback value from [1, 1, 640].
|
||||
let mut state_shape: [usize; 3] = [1, 1, 640];
|
||||
|
||||
for input in self.decoder.inputs() {
|
||||
let name = input.name();
|
||||
if name == "input_states_1" || name == "input_states_2" {
|
||||
let dtype = input.dtype();
|
||||
// Try to extract shape from the ValueType
|
||||
if let ort::value::ValueType::Tensor { shape, .. } = dtype {
|
||||
if shape.len() == 3 {
|
||||
// Use known dims (skip dynamic ones which are -1)
|
||||
for (i, &dim) in shape.iter().enumerate() {
|
||||
if dim > 0 {
|
||||
state_shape[i] = dim as usize;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
break; // both states have same shape
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Decoder state shape: {:?}", state_shape);
|
||||
|
||||
let mut state1 = Array3::<f32>::zeros(state_shape);
|
||||
let mut state2 = Array3::<f32>::zeros(state_shape);
|
||||
let mut prev_token = self.blank_id;
|
||||
let mut tokens = Vec::new();
|
||||
let mut t = 0;
|
||||
|
||||
let max_steps = encoded_length * 10;
|
||||
let mut step = 0;
|
||||
|
||||
while t < encoded_length && step < max_steps {
|
||||
step += 1;
|
||||
|
||||
// Extract current encoder frame: shape [1, feat_dim, 1]
|
||||
let mut frame_data = vec![0.0f32; feat_dim];
|
||||
for f in 0..feat_dim {
|
||||
// encoder_output layout: [1, feat_dim, time] -> flat index (assumes the time dimension equals encoded_length)
|
||||
frame_data[f] = encoder_output[f * encoded_length + t];
|
||||
}
|
||||
let frame = Array3::from_shape_vec([1, feat_dim, 1], frame_data)?;
|
||||
|
||||
let targets = ndarray::Array2::from_shape_vec((1, 1), vec![prev_token])?;
|
||||
let target_length = ndarray::Array1::from_vec(vec![1i64]);
|
||||
|
||||
let outputs = self.decoder.run(vec![
|
||||
make_input("encoder_outputs", Value::from_array(frame)?.into_dyn()),
|
||||
make_input("targets", Value::from_array(targets)?.into_dyn()),
|
||||
make_input("target_length", Value::from_array(target_length)?.into_dyn()),
|
||||
make_input("input_states_1", Value::from_array(state1.clone())?.into_dyn()),
|
||||
make_input("input_states_2", Value::from_array(state2.clone())?.into_dyn()),
|
||||
])?;
|
||||
|
||||
let (_, output_data) = outputs["outputs"]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract decoder output: {e}"))?;
|
||||
|
||||
// Split into token logits and duration logits
|
||||
let token_logits = &output_data[..self.vocab_size];
|
||||
let duration_logits = &output_data[self.vocab_size..];
|
||||
|
||||
let token_id = argmax(token_logits) as i64;
|
||||
let duration = if !duration_logits.is_empty() {
|
||||
argmax(duration_logits)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
if token_id != self.blank_id {
|
||||
tokens.push(token_id);
|
||||
prev_token = token_id;
|
||||
|
||||
// Update states
|
||||
let (s1_shape, s1_data) = outputs["output_states_1"]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract state 1: {e}"))?;
|
||||
let s1_dims: Vec<usize> = s1_shape.iter().map(|&d| d as usize).collect();
|
||||
if s1_dims.len() == 3 {
|
||||
state1 = Array3::from_shape_vec([s1_dims[0], s1_dims[1], s1_dims[2]], s1_data.to_vec())?;
|
||||
}
|
||||
|
||||
let (s2_shape, s2_data) = outputs["output_states_2"]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract state 2: {e}"))?;
|
||||
let s2_dims: Vec<usize> = s2_shape.iter().map(|&d| d as usize).collect();
|
||||
if s2_dims.len() == 3 {
|
||||
state2 = Array3::from_shape_vec([s2_dims[0], s2_dims[1], s2_dims[2]], s2_data.to_vec())?;
|
||||
}
|
||||
}
|
||||
|
||||
if duration > 0 {
|
||||
t += duration;
|
||||
} else if token_id == self.blank_id {
|
||||
t += 1;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tokens)
|
||||
}
|
||||
|
||||
fn tokens_to_text(&self, tokens: &[i64]) -> String {
|
||||
let mut text = String::new();
|
||||
for &token_id in tokens {
|
||||
if token_id >= 0 && (token_id as usize) < self.vocab.len() {
|
||||
let token = &self.vocab[token_id as usize];
|
||||
if token.starts_with('<') && token.ends_with('>') {
|
||||
continue;
|
||||
}
|
||||
text.push_str(token);
|
||||
}
|
||||
}
|
||||
|
||||
text = text.replace('\u{2581}', " ");
|
||||
let text = text.trim().to_string();
|
||||
|
||||
let mut result = String::with_capacity(text.len());
|
||||
let mut prev_space = false;
|
||||
for ch in text.chars() {
|
||||
if ch == ' ' {
|
||||
if !prev_space {
|
||||
result.push(ch);
|
||||
}
|
||||
prev_space = true;
|
||||
} else {
|
||||
result.push(ch);
|
||||
prev_space = false;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
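The time-advance rule in `tdt_greedy_decode` is easier to follow with the ONNX calls removed. The sketch below keeps the same loop structure but replaces the decoder session with a hypothetical `joint` closure that returns token logits and duration logits for a given frame index and previous token. LSTM state handling is omitted, and `tdt_greedy` and `argmax_first` are illustrative names, so treat this as an approximation of the control flow only.

```rust
/// Index of the first maximum; returns 0 for an empty slice.
fn argmax_first(xs: &[f32]) -> usize {
    let mut best = 0;
    for (i, &x) in xs.iter().enumerate() {
        if x > xs[best] {
            best = i;
        }
    }
    best
}

/// Greedy TDT-style decoding over `num_frames` encoder frames.
/// `joint(t, prev_token)` stands in for one decoder/joint network call.
pub fn tdt_greedy<F>(num_frames: usize, blank_id: usize, mut joint: F) -> Vec<usize>
where
    F: FnMut(usize, usize) -> (Vec<f32>, Vec<f32>),
{
    let mut tokens = Vec::new();
    let mut prev = blank_id;
    let mut t = 0;
    // Cap the number of steps so a model that never advances cannot loop forever.
    let max_steps = num_frames * 10;
    let mut step = 0;

    while t < num_frames && step < max_steps {
        step += 1;
        let (token_logits, duration_logits) = joint(t, prev);
        let token = argmax_first(&token_logits);
        let duration = if duration_logits.is_empty() {
            1
        } else {
            argmax_first(&duration_logits)
        };

        if token != blank_id {
            tokens.push(token);
            prev = token;
        }

        if duration > 0 {
            // The predicted duration says how many frames to skip.
            t += duration;
        } else if token == blank_id {
            // Blank with zero duration must still make progress.
            t += 1;
        }
        // Non-blank with zero duration: stay on this frame and emit again.
    }
    tokens
}
```

A non-blank token with duration 0 keeps `t` fixed so that several tokens can be emitted from one frame; the `max_steps` cap is what bounds that case.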
|
||||
fn build_session(model_path: &Path, accelerator: &Accelerator, _gpu_device: u32) -> Result<Session> {
|
||||
let mut builder = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create session builder: {e}"))?
|
||||
.with_intra_threads(num_cpus::get().min(4))
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set thread count: {e}"))?;
|
||||
|
||||
match accelerator {
|
||||
Accelerator::Auto | Accelerator::Cpu => {}
|
||||
Accelerator::Cuda => {
|
||||
info!("CUDA requested — falling back to CPU (CUDA EP not yet configured)");
|
||||
}
|
||||
Accelerator::DirectMl => {
|
||||
info!("DirectML requested — falling back to CPU (DirectML EP not yet configured)");
|
||||
}
|
||||
}
|
||||
|
||||
let session = builder
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load ONNX model from {}: {e}", model_path.display()))?;
|
||||
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
fn load_vocab(path: &Path) -> Result<Vec<String>> {
|
||||
let content = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read vocab from {}", path.display()))?;
|
||||
|
||||
let mut vocab = Vec::new();
|
||||
for line in content.lines() {
|
||||
let token = line.split_whitespace().next().unwrap_or("").to_string();
|
||||
vocab.push(token);
|
||||
}
|
||||
|
||||
Ok(vocab)
|
||||
}
|
||||
|
||||
fn argmax(slice: &[f32]) -> usize {
|
||||
slice
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(i, _)| i)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
// --- Mel spectrogram computation ---
|
||||
|
||||
fn compute_mel_spectrogram(samples: &[f32]) -> Array2<f32> {
|
||||
let mut emphasized = Vec::with_capacity(samples.len());
|
||||
emphasized.push(samples[0]);
|
||||
for i in 1..samples.len() {
|
||||
emphasized.push(samples[i] - PRE_EMPHASIS * samples[i - 1]);
|
||||
}
|
||||
|
||||
let num_frames = (emphasized.len().saturating_sub(WIN_LENGTH)) / HOP_LENGTH + 1;
|
||||
let fft_size = N_FFT / 2 + 1;
|
||||
|
||||
let window: Vec<f32> = (0..WIN_LENGTH)
|
||||
.map(|i| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * i as f32 / (WIN_LENGTH - 1) as f32).cos()))
|
||||
.collect();
|
||||
|
||||
let mel_bank = create_mel_filterbank(SAMPLE_RATE as f32, N_FFT, N_MELS);
|
||||
|
||||
let mut mel_spec = Array2::<f32>::zeros((N_MELS, num_frames));
|
||||
|
||||
for frame_idx in 0..num_frames {
|
||||
let start = frame_idx * HOP_LENGTH;
|
||||
|
||||
let mut windowed = vec![0.0f32; N_FFT];
|
||||
for i in 0..WIN_LENGTH {
|
||||
if start + i < emphasized.len() {
|
||||
windowed[i] = emphasized[start + i] * window[i];
|
||||
}
|
||||
}
|
||||
|
||||
let power_spectrum = compute_power_spectrum(&windowed, fft_size);
|
||||
|
||||
for mel_idx in 0..N_MELS {
|
||||
let mut energy = 0.0f32;
|
||||
for k in 0..fft_size {
|
||||
energy += mel_bank[mel_idx][k] * power_spectrum[k];
|
||||
}
|
||||
mel_spec[[mel_idx, frame_idx]] = (energy + 2.0f32.powi(-24)).ln();
|
||||
}
|
||||
}
|
||||
|
||||
// Per-utterance CMVN
|
||||
for mel_idx in 0..N_MELS {
|
||||
let row = mel_spec.row(mel_idx);
|
||||
let mean = row.mean().unwrap_or(0.0);
|
||||
let variance = row.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / num_frames as f32;
|
||||
let std = (variance + 1e-5).sqrt();
|
||||
|
||||
for frame_idx in 0..num_frames {
|
||||
mel_spec[[mel_idx, frame_idx]] = (mel_spec[[mel_idx, frame_idx]] - mean) / std;
|
||||
}
|
||||
}
|
||||
|
||||
mel_spec
|
||||
}
|
||||
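Two pieces of arithmetic in `compute_mel_spectrogram` are worth checking by hand: the pre-emphasis filter `y[i] = x[i] - 0.97 * x[i-1]` and the frame count `(len - WIN_LENGTH) / HOP_LENGTH + 1`. The sketch below isolates both. The function names `pre_emphasis` and `frame_count` are illustrative, and unlike the function above this version accepts an empty input.

```rust
/// First-order pre-emphasis: y[0] = x[0], y[i] = x[i] - coeff * x[i - 1].
pub fn pre_emphasis(samples: &[f32], coeff: f32) -> Vec<f32> {
    let mut out = Vec::with_capacity(samples.len());
    let mut prev = 0.0f32; // treating x[-1] as 0 gives y[0] = x[0]
    for &x in samples {
        out.push(x - coeff * prev);
        prev = x;
    }
    out
}

/// Number of analysis frames for a signal of `num_samples` samples.
/// Inputs shorter than one window still produce a single (zero-padded) frame.
pub fn frame_count(num_samples: usize, win_length: usize, hop_length: usize) -> usize {
    num_samples.saturating_sub(win_length) / hop_length + 1
}
```

With the constants used here (`WIN_LENGTH = 400`, `HOP_LENGTH = 160`), one second of 16 kHz audio gives `(16000 - 400) / 160 + 1 = 98` frames, roughly one frame per 10 ms.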
|
||||
fn compute_power_spectrum(signal: &[f32], output_size: usize) -> Vec<f32> {
|
||||
let n = signal.len();
|
||||
let mut spectrum = Vec::with_capacity(output_size);
|
||||
|
||||
for k in 0..output_size {
|
||||
let mut real = 0.0f32;
|
||||
let mut imag = 0.0f32;
|
||||
for (t, &sample) in signal.iter().enumerate() {
|
||||
let angle = -2.0 * std::f32::consts::PI * k as f32 * t as f32 / n as f32;
|
||||
real += sample * angle.cos();
|
||||
imag += sample * angle.sin();
|
||||
}
|
||||
spectrum.push(real * real + imag * imag);
|
||||
}
|
||||
|
||||
spectrum
|
||||
}
|
||||
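`compute_power_spectrum` evaluates the DFT directly, one full pass over the frame per output bin, so a 512-sample frame with 257 bins costs about 512 × 257 sine/cosine evaluations. A quick sanity check for this kind of routine is to feed it a pure cosine and confirm the energy lands in the expected bin. The sketch below restates the same direct DFT under the illustrative name `power_spectrum`; an FFT would be the usual faster alternative but is not shown.

```rust
/// Direct (O(N * bins)) DFT power spectrum: |X[k]|^2 for k in 0..bins.
pub fn power_spectrum(signal: &[f32], bins: usize) -> Vec<f32> {
    let n = signal.len() as f32;
    (0..bins)
        .map(|k| {
            let (mut re, mut im) = (0.0f32, 0.0f32);
            for (t, &x) in signal.iter().enumerate() {
                let angle = -2.0 * std::f32::consts::PI * k as f32 * t as f32 / n;
                re += x * angle.cos();
                im += x * angle.sin();
            }
            re * re + im * im
        })
        .collect()
}
```

For a cosine with exactly 3 cycles in 32 samples, the real part at bin 3 sums to N/2 = 16 and the imaginary part cancels, so the power at that bin is about 256 and every other bin in `0..17` is close to zero.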
|
||||
fn create_mel_filterbank(sample_rate: f32, n_fft: usize, n_mels: usize) -> Vec<Vec<f32>> {
|
||||
let fft_size = n_fft / 2 + 1;
|
||||
|
||||
let mel_low = hz_to_mel(0.0);
|
||||
let mel_high = hz_to_mel(sample_rate / 2.0);
|
||||
|
||||
let mel_points: Vec<f32> = (0..n_mels + 2)
|
||||
.map(|i| mel_low + (mel_high - mel_low) * i as f32 / (n_mels + 1) as f32)
|
||||
.collect();
|
||||
|
||||
let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
|
||||
let bin_points: Vec<usize> = hz_points
|
||||
.iter()
|
||||
.map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate).floor() as usize)
|
||||
.collect();
|
||||
|
||||
let mut filterbank = vec![vec![0.0f32; fft_size]; n_mels];
|
||||
|
||||
for m in 0..n_mels {
|
||||
let left = bin_points[m];
|
||||
let center = bin_points[m + 1];
|
||||
let right = bin_points[m + 2];
|
||||
|
||||
for k in left..center {
|
||||
if center > left {
|
||||
filterbank[m][k] = (k - left) as f32 / (center - left) as f32;
|
||||
}
|
||||
}
|
||||
for k in center..right {
|
||||
if right > center {
|
||||
filterbank[m][k] = (right - k) as f32 / (right - center) as f32;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filterbank
|
||||
}
|
||||
|
||||
fn hz_to_mel(hz: f32) -> f32 {
|
||||
2595.0 * (1.0 + hz / 700.0).log10()
|
||||
}
|
||||
|
||||
fn mel_to_hz(mel: f32) -> f32 {
|
||||
700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
|
||||
}
|
||||
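The two conversion functions above implement the common `2595 * log10(1 + hz / 700)` mel formula and its inverse. The sketch below repeats them and adds a hypothetical helper, `mel_band_edges_hz`, that produces the `n_mels + 2` band edges the filterbank is built from, which makes it easy to check that the edges run from 0 Hz to the Nyquist frequency and increase monotonically.

```rust
pub fn hz_to_mel(hz: f32) -> f32 {
    2595.0 * (1.0 + hz / 700.0).log10()
}

pub fn mel_to_hz(mel: f32) -> f32 {
    700.0 * (10.0f32.powf(mel / 2595.0) - 1.0)
}

/// Band edges in Hz: `n_mels + 2` points spaced evenly on the mel scale
/// between 0 Hz and `sample_rate / 2`.
pub fn mel_band_edges_hz(sample_rate: f32, n_mels: usize) -> Vec<f32> {
    let (lo, hi) = (hz_to_mel(0.0), hz_to_mel(sample_rate / 2.0));
    (0..n_mels + 2)
        .map(|i| mel_to_hz(lo + (hi - lo) * i as f32 / (n_mels + 1) as f32))
        .collect()
}
```

Under this formula 1000 Hz maps to roughly 1000 mel, which is a convenient spot check. Triangle `m` of the filterbank then spans edges `m`, `m + 1` and `m + 2` of this list.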
@@ -0,0 +1,149 @@
|
||||
use anyhow::Result;
|
||||
use ndarray::{Array1, Array2, Array3};
|
||||
use ort::session::{Session, SessionInputValue};
|
||||
use ort::value::Value;
|
||||
use std::borrow::Cow;
|
||||
use tracing::{debug, info};
|
||||
|
||||
const SILERO_SAMPLE_RATE: i64 = 16000;
|
||||
const WINDOW_SIZE: usize = 512; // 32ms at 16kHz
|
||||
const THRESHOLD: f32 = 0.5;
|
||||
const MIN_SPEECH_DURATION_MS: usize = 250;
|
||||
const MIN_SILENCE_DURATION_MS: usize = 300;
|
||||
const SPEECH_PAD_MS: usize = 100;
|
||||
|
||||
fn make_input<'a, V: Into<Value>>(name: &'static str, value: V) -> (Cow<'a, str>, SessionInputValue<'a>) {
|
||||
(Cow::Borrowed(name), SessionInputValue::Owned(value.into().into_dyn()))
|
||||
}
|
||||
|
||||
/// Voice Activity Detector using Silero VAD v4.
|
||||
pub struct Vad {
|
||||
session: Session,
|
||||
}
|
||||
|
||||
impl Vad {
|
||||
/// Create a new VAD instance from an ONNX model file.
|
||||
pub fn new(model_path: &str) -> Result<Self> {
|
||||
let session = Session::builder()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create VAD session builder: {e}"))?
|
||||
.with_intra_threads(1)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to set VAD threads: {e}"))?
|
||||
.commit_from_file(model_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to load VAD model from {model_path}: {e}"))?;
|
||||
|
||||
info!("Silero VAD loaded from {model_path}");
|
||||
Ok(Self { session })
|
||||
}
|
||||
|
||||
/// Filter audio to keep only speech segments.
|
||||
pub fn filter_speech(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
|
||||
if samples.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let probabilities = self.get_speech_probabilities(samples)?;
|
||||
let segments = self.find_speech_segments(&probabilities, samples.len());
|
||||
|
||||
if segments.is_empty() {
|
||||
debug!("VAD: no speech detected");
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut result = Vec::new();
|
||||
for (start, end) in &segments {
|
||||
debug!("VAD: speech segment {start}..{end} ({:.1}s..{:.1}s)",
|
||||
*start as f32 / SILERO_SAMPLE_RATE as f32,
|
||||
*end as f32 / SILERO_SAMPLE_RATE as f32);
|
||||
result.extend_from_slice(&samples[*start..*end]);
|
||||
}
|
||||
|
||||
let original_duration = samples.len() as f32 / SILERO_SAMPLE_RATE as f32;
|
||||
let filtered_duration = result.len() as f32 / SILERO_SAMPLE_RATE as f32;
|
||||
debug!("VAD: {original_duration:.1}s -> {filtered_duration:.1}s ({:.0}% kept)",
|
||||
filtered_duration / original_duration * 100.0);
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn get_speech_probabilities(&mut self, samples: &[f32]) -> Result<Vec<f32>> {
|
||||
let mut probabilities = Vec::new();
|
||||
|
||||
let mut state = Array3::<f32>::zeros((2, 1, 128));
|
||||
let sr = Array1::from_vec(vec![SILERO_SAMPLE_RATE]);
|
||||
|
||||
for chunk in samples.chunks(WINDOW_SIZE) {
|
||||
let mut window = chunk.to_vec();
|
||||
if window.len() < WINDOW_SIZE {
|
||||
window.resize(WINDOW_SIZE, 0.0);
|
||||
}
|
||||
|
||||
let input = Array2::from_shape_vec((1, WINDOW_SIZE), window)?;
|
||||
|
||||
let outputs = self.session.run(vec![
|
||||
make_input("input", Value::from_array(input)?.into_dyn()),
|
||||
make_input("state", Value::from_array(state.clone())?.into_dyn()),
|
||||
make_input("sr", Value::from_array(sr.clone())?.into_dyn()),
|
||||
])?;
|
||||
|
||||
let (_, output_data) = outputs["output"]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract VAD output: {e}"))?;
|
||||
probabilities.push(output_data[0]);
|
||||
|
||||
let (state_shape, state_data) = outputs["stateN"]
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to extract VAD state: {e}"))?;
|
||||
let dims: Vec<usize> = state_shape.iter().map(|&d| d as usize).collect();
|
||||
if dims.len() == 3 {
|
||||
state = Array3::from_shape_vec([dims[0], dims[1], dims[2]], state_data.to_vec())?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(probabilities)
|
||||
}
|
||||
|
||||
fn find_speech_segments(&self, probabilities: &[f32], total_samples: usize) -> Vec<(usize, usize)> {
|
||||
let samples_per_window = WINDOW_SIZE;
|
||||
let min_speech_windows = MIN_SPEECH_DURATION_MS * SILERO_SAMPLE_RATE as usize / 1000 / samples_per_window;
|
||||
let min_silence_windows = MIN_SILENCE_DURATION_MS * SILERO_SAMPLE_RATE as usize / 1000 / samples_per_window;
|
||||
let pad_samples = SPEECH_PAD_MS * SILERO_SAMPLE_RATE as usize / 1000;
|
||||
|
||||
let mut segments = Vec::new();
|
||||
let mut in_speech = false;
|
||||
let mut speech_start = 0;
|
||||
let mut silence_count = 0;
|
||||
let mut speech_count = 0;
|
||||
|
||||
for (i, &prob) in probabilities.iter().enumerate() {
|
||||
if prob >= THRESHOLD {
|
||||
if !in_speech {
|
||||
speech_start = i;
|
||||
speech_count = 0;
|
||||
}
|
||||
in_speech = true;
|
||||
speech_count += 1;
|
||||
silence_count = 0;
|
||||
} else if in_speech {
|
||||
silence_count += 1;
|
||||
if silence_count >= min_silence_windows {
|
||||
if speech_count >= min_speech_windows {
|
||||
let start = (speech_start * samples_per_window).saturating_sub(pad_samples);
|
||||
let end = ((i - silence_count + 1) * samples_per_window + pad_samples).min(total_samples);
|
||||
segments.push((start, end));
|
||||
}
|
||||
in_speech = false;
|
||||
silence_count = 0;
|
||||
speech_count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if in_speech && speech_count >= min_speech_windows {
|
||||
let start = (speech_start * samples_per_window).saturating_sub(pad_samples);
|
||||
let end = total_samples;
|
||||
segments.push((start, end));
|
||||
}
|
||||
|
||||
segments
|
||||
}
|
||||
}
|
||||
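The segmentation in `find_speech_segments` is a small hysteresis state machine over per-window probabilities. The sketch below restates it as a free function with the thresholds passed in, so it can be traced with small numbers. `SegmentParams` and `speech_segments` are illustrative names, not part of the VAD module. Note that with the constants above the integer division rounds down: `250 * 16000 / 1000 / 512 = 7` windows (224 ms) of speech and `300 * 16000 / 1000 / 512 = 9` windows (288 ms) of silence.

```rust
/// Tunables for the segmenter (hypothetical struct for this sketch).
pub struct SegmentParams {
    pub threshold: f32,
    pub window: usize,              // samples per probability window
    pub min_speech_windows: usize,  // shorter bursts are discarded
    pub min_silence_windows: usize, // silence needed to close a segment
    pub pad: usize,                 // samples of padding on each side
}

/// Turn per-window speech probabilities into (start, end) sample ranges.
pub fn speech_segments(
    probs: &[f32],
    total_samples: usize,
    p: &SegmentParams,
) -> Vec<(usize, usize)> {
    let mut segments = Vec::new();
    let mut in_speech = false;
    let (mut start_win, mut speech, mut silence) = (0usize, 0usize, 0usize);

    for (i, &prob) in probs.iter().enumerate() {
        if prob >= p.threshold {
            if !in_speech {
                start_win = i;
                speech = 0;
            }
            in_speech = true;
            speech += 1;
            silence = 0;
        } else if in_speech {
            silence += 1;
            if silence >= p.min_silence_windows {
                if speech >= p.min_speech_windows {
                    let start = (start_win * p.window).saturating_sub(p.pad);
                    // `i - silence + 1` is the first silent window of this run.
                    let end = ((i - silence + 1) * p.window + p.pad).min(total_samples);
                    segments.push((start, end));
                }
                in_speech = false;
                silence = 0;
                speech = 0;
            }
        }
    }

    // Speech that runs to the end of the recording is closed at the last sample.
    if in_speech && speech >= p.min_speech_windows {
        segments.push(((start_win * p.window).saturating_sub(p.pad), total_samples));
    }
    segments
}
```

With 10-sample windows, a 2-window minimum for both speech and silence, and 5 samples of padding, the probabilities `[0.1, 0.9, 0.9, 0.9, 0.1, 0.1, 0.1, 0.9, 0.1, 0.1]` produce one segment, `(5, 45)`: windows 1 to 3 cover samples 10..40, padded by 5 on each side, while the single-window burst at index 7 is too short to keep.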