| |
| """BirdNET Audio Classification Script |
| |
| This script loads a WAV file and uses the BirdNET ONNX model to predict bird species. |
| The model expects audio input of shape [batch_size, 144000] (3 seconds at 48kHz). |
| |
| Created using Copilot. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import librosa |
| import onnxruntime as ort |
| import argparse |
| import os |
| from collections import defaultdict |
|
|
|
|
def load_audio(
    file_path: str, target_sr: int = 48000, duration: float = 3.0
) -> np.ndarray:
    """
    Load the first ``duration`` seconds of an audio file as a fixed-length clip.

    The file is decoded with ``librosa.load`` at ``target_sr``. The result is
    then zero-padded at the end, or truncated, so that it always holds exactly
    ``int(target_sr * duration)`` samples.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)
        duration (float): Duration in seconds (3.0 for BirdNET)

    Returns:
        np.ndarray: float32 audio array holding ``int(target_sr * duration)``
            samples (144000 with the defaults)

    Raises:
        RuntimeError: If the file cannot be loaded or processed. The original
            exception is attached as ``__cause__``.
    """
    try:
        # The sample rate returned by librosa is not needed afterwards.
        audio, _ = librosa.load(file_path, sr=target_sr, duration=duration)

        target_length = int(target_sr * duration)
        missing = target_length - len(audio)
        if missing > 0:
            # Short recording: append trailing zeros (silence).
            audio = np.pad(audio, (0, missing))
        elif missing < 0:
            # Too many samples: keep only the leading part.
            audio = audio[:target_length]

        return audio.astype(np.float32)
    except Exception as e:
        # Chain explicitly so the underlying decoder error stays inspectable.
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e
|
|
|
|
def load_labels(labels_path: str) -> list[str]:
    """
    Load BirdNET species labels from the labels file.

    The file is read as UTF-8, one label per line. Lines are stripped of
    surrounding whitespace and blank lines are skipped. When a line contains
    an underscore, only the text after the FIRST underscore is kept
    (presumably the layout is ``<scientific name>_<common name>`` - verify
    against the labels file in use); otherwise the whole line is the label.

    Args:
        labels_path (str): Path to the labels file

    Returns:
        list[str]: List of species names, in file order

    Raises:
        RuntimeError: If the file cannot be read or decoded. The original
            exception is attached as ``__cause__``.
    """
    try:
        labels = []
        with open(labels_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                # partition splits on the first underscore only, so any
                # later underscores stay inside the kept name.
                _, separator, common_name = line.partition("_")
                labels.append(common_name if separator else line)
        return labels
    except Exception as e:
        # Chain explicitly so the underlying I/O error stays inspectable.
        raise RuntimeError(f"Error loading labels file {labels_path}: {e}") from e
|
|
|
|
def load_audio_full(file_path: str, target_sr: int = 48000) -> np.ndarray:
    """
    Load full audio file for moving window analysis.

    The whole file is decoded with ``librosa.load`` at ``target_sr`` and the
    samples are returned as float32, with no padding or truncation.

    Args:
        file_path (str): Path to the audio file
        target_sr (int): Target sample rate (48kHz for BirdNET)

    Returns:
        np.ndarray: Full audio array (float32)

    Raises:
        RuntimeError: If the file cannot be loaded. The original exception is
            attached as ``__cause__``.
    """
    try:
        # The sample rate returned by librosa is not needed afterwards.
        audio, _ = librosa.load(file_path, sr=target_sr)
        return audio.astype(np.float32)
    except Exception as e:
        # Chain explicitly so the underlying decoder error stays inspectable.
        raise RuntimeError(f"Error loading audio file {file_path}: {e}") from e
|
|
|
|
def create_audio_windows(
    audio: np.ndarray,
    window_size: int = 144000,
    overlap: float = 0.5,
    sample_rate: int = 48000,
) -> tuple[np.ndarray, list[float]]:
    """
    Create overlapping windows from audio for analysis.

    Windows start every ``int(window_size * (1 - overlap))`` samples. Only
    complete windows are produced: trailing samples that do not fill a whole
    window are dropped, and audio shorter than ``window_size`` yields no
    windows at all.

    Args:
        audio (np.ndarray): Full audio array
        window_size (int): Size of each window (144000 for 3 seconds at 48kHz)
        overlap (float): Overlap ratio (0.5 = 50% overlap); must be below 1.0
        sample_rate (int): Sample rate of ``audio`` in Hz, used only to
            convert window start offsets into timestamps in seconds

    Returns:
        tuple[np.ndarray, list[float]]: (windows array, timestamps). The
            windows array has one row per window; when no complete window
            fits it is empty with shape ``(0, window_size)``. Timestamps are
            the window start times in seconds.

    Raises:
        ValueError: If ``window_size`` and ``overlap`` give a step of less
            than one sample (for example ``overlap >= 1.0``).
    """
    step_size = int(window_size * (1 - overlap))
    if step_size < 1:
        # A zero step would make range() fail with an unrelated message and a
        # negative step would silently produce no windows at all.
        raise ValueError(
            "window_size and overlap must give a step of at least one sample "
            f"(window_size={window_size}, overlap={overlap})"
        )

    audio = np.asarray(audio)
    # The stop bound guarantees every start leaves room for a full window.
    starts = range(0, len(audio) - window_size + 1, step_size)
    if not starts:
        # Keep the window axis so callers always receive a 2-D style array.
        empty_shape = (0, window_size) + audio.shape[1:]
        return np.empty(empty_shape, dtype=audio.dtype), []

    windows = np.array([audio[start : start + window_size] for start in starts])
    timestamps = [start / sample_rate for start in starts]
    return windows, timestamps
|
|
|
|
def load_onnx_model(model_path: str) -> ort.InferenceSession:
    """
    Load ONNX model for inference.

    Args:
        model_path (str): Path to the ONNX model file

    Returns:
        ort.InferenceSession: Loaded ONNX model session

    Raises:
        RuntimeError: If the session cannot be created. The original
            exception is attached as ``__cause__``.
    """
    try:
        return ort.InferenceSession(model_path)
    except Exception as e:
        # Chain explicitly so the underlying runtime error stays inspectable.
        raise RuntimeError(f"Error loading ONNX model {model_path}: {e}") from e
|
|
|
|
def predict_audio(session: ort.InferenceSession, audio_data: np.ndarray) -> np.ndarray:
    """
    Run inference on audio data using the ONNX model.

    A one-dimensional input is given a leading batch axis before it is fed to
    the model's first input; input with more dimensions is passed unchanged.

    Args:
        session (ort.InferenceSession): ONNX model session
        audio_data (np.ndarray): Audio data of shape [144000] or [batch, 144000]

    Returns:
        np.ndarray: Model predictions (the first output of the session)

    Raises:
        RuntimeError: If inference fails. The original exception is attached
            as ``__cause__``.
    """
    try:
        if audio_data.ndim == 1:
            # Single clip: add the batch axis the model input expects.
            input_data = audio_data[np.newaxis, :]
        else:
            input_data = audio_data

        input_name = session.get_inputs()[0].name
        # None requests every model output; only the first one is returned.
        outputs = session.run(None, {input_name: input_data})
        return outputs[0]
    except Exception as e:
        # Chain explicitly so the underlying runtime error stays inspectable.
        raise RuntimeError(f"Error during model inference: {e}") from e
|
|
|
|
def predict_audio_batch(
    session: ort.InferenceSession,
    windows_batch: np.ndarray,
    batch_size: int = 128,
    show_progress: bool = True,
) -> np.ndarray:
    """
    Run inference on batches of audio windows for better performance.

    The windows are fed to the model's first input in consecutive slices of at
    most ``batch_size`` rows, and the first model output of every slice is
    concatenated along axis 0.

    Args:
        session (ort.InferenceSession): ONNX model session
        windows_batch (np.ndarray): Array of windows, shape [num_windows, 144000]
        batch_size (int): Number of windows to process in each batch (>= 1)
        show_progress (bool): Whether to show progress updates

    Returns:
        np.ndarray: All predictions concatenated, shape [num_windows, num_classes].
            When ``windows_batch`` is empty the model is not run and an empty
            float32 array of shape ``(0, 0)`` is returned (the class count is
            not known without running the model).

    Raises:
        RuntimeError: If ``batch_size`` is below 1 or inference fails. The
            original exception is attached as ``__cause__``.
    """
    try:
        num_windows = len(windows_batch)
        if num_windows == 0:
            # np.concatenate rejects an empty list, so answer directly.
            return np.empty((0, 0), dtype=np.float32)
        if batch_size < 1:
            # Raised inside the try block so callers keep seeing RuntimeError.
            raise ValueError(f"batch_size must be at least 1, got {batch_size}")

        input_name = session.get_inputs()[0].name
        all_predictions = []

        batch_starts = range(0, num_windows, batch_size)
        for batch_num, start_idx in enumerate(batch_starts, start=1):
            end_idx = min(start_idx + batch_size, num_windows)

            # Report the first batch and then every fifth one.
            if show_progress and (batch_num % 5 == 0 or batch_num == 1):
                progress = (end_idx / num_windows) * 100
                print(
                    f" Batch {batch_num}: processing windows {start_idx + 1}-{end_idx} ({progress:.1f}%)"
                )

            outputs = session.run(None, {input_name: windows_batch[start_idx:end_idx]})
            all_predictions.append(outputs[0])

        return np.concatenate(all_predictions, axis=0)
    except Exception as e:
        # Chain explicitly so the underlying error stays inspectable.
        raise RuntimeError(f"Error during batch model inference: {e}") from e
|
|
|
|
def analyze_detections(
    all_predictions: np.ndarray,
    timestamps: list[float],
    labels: list[str],
    confidence_threshold: float = 0.1,
) -> dict[str, list[dict[str, float | int]]]:
    """
    Group every above-threshold score by species across all windows.

    Windows are paired with their timestamps in order. Within each window
    every class whose score is strictly greater than ``confidence_threshold``
    is recorded under its label, or under ``"Class <index>"`` when the index
    has no entry in ``labels``.

    Args:
        all_predictions (np.ndarray): Predictions from all windows, shape [num_windows, num_classes]
        timestamps (list[float]): Timestamps for each window
        labels (list[str]): Species labels
        confidence_threshold (float): Minimum confidence for detection

    Returns:
        dict[str, list[dict[str, float | int]]]: Species name mapped to its
            detections; each detection holds ``"timestamp"``, ``"confidence"``
            and ``"window"`` (the window index)
    """
    summary = {}
    label_count = len(labels)

    for window_index, (window_scores, window_start) in enumerate(
        zip(all_predictions, timestamps)
    ):
        # Indices of the classes that clear the threshold in this window.
        hit_indices = np.where(window_scores > confidence_threshold)[0]

        for class_index in hit_indices:
            if class_index < label_count:
                name = labels[class_index]
            else:
                name = f"Class {class_index}"

            record = {
                "timestamp": window_start,
                "confidence": float(window_scores[class_index]),
                "window": window_index,
            }
            summary.setdefault(name, []).append(record)

    return summary
|
|
|
|
def main() -> int:
    """
    Command-line entry point: classify bird species in an audio file.

    Parses and validates the command-line arguments, loads the labels and the
    ONNX model, then either scores only the first 3 seconds
    (``--single-window``) or slides overlapping 3-second windows over the
    whole file and prints a per-species detection summary.

    Returns:
        int: Process exit code; 0 on success, 1 on invalid arguments, a
            missing input file, audio too short for a full window, or any
            error raised during processing.
    """
    args = _build_arg_parser().parse_args()

    problem = _find_argument_problem(args)
    if problem is not None:
        print(f"Error: {problem}")
        return 1

    try:
        print(f"Loading labels from: {args.labels}")
        labels = load_labels(args.labels)
        print(f"Loaded {len(labels)} species labels")

        print(f"Loading ONNX model: {args.model}")
        session = load_onnx_model(args.model)

        input_info = session.get_inputs()[0]
        output_info = session.get_outputs()[0]
        print(f"Model input: {input_info.name}, shape: {input_info.shape}")
        print(f"Model output: {output_info.name}, shape: {output_info.shape}")

        if args.single_window:
            return _run_single_window(session, labels, args)
        return _run_moving_window(session, labels, args)

    except Exception as e:
        # Top-level boundary: report any failure as a message and exit code.
        print(f"Error: {str(e)}")
        return 1


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the script."""
    parser = argparse.ArgumentParser(
        description="BirdNET Audio Classification with Moving Window"
    )
    parser.add_argument("audio_file", help="Path to the WAV audio file")
    parser.add_argument(
        "--model", default="model.onnx", help="Path to the ONNX model file"
    )
    parser.add_argument(
        "--labels",
        default="BirdNET_GLOBAL_6K_V2.4_Labels.txt",
        help="Path to the labels file",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top predictions to show per window",
    )
    parser.add_argument(
        "--overlap", type=float, default=0.5, help="Window overlap ratio (0.0-1.0)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.1,
        help="Minimum confidence threshold for detections",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="Batch size for inference (default: 128)",
    )
    parser.add_argument(
        "--single-window",
        action="store_true",
        help="Analyze only first 3 seconds (single window)",
    )
    return parser


def _find_argument_problem(args: argparse.Namespace) -> str | None:
    """Return a message for the first invalid argument, or None if all are usable."""
    # A top-k of 0 would select EVERY class ([-0:] is the whole array) and a
    # negative value would slice from the wrong end.
    if args.top_k < 1:
        return "--top-k must be at least 1."
    # An overlap of 1.0 or more leaves no positive step between windows.
    # Written as a negated chain so NaN is rejected too.
    if not 0.0 <= args.overlap < 1.0:
        return "--overlap must be at least 0.0 and less than 1.0."
    if args.batch_size < 1:
        return "--batch-size must be at least 1."

    required_files = (
        ("Audio", args.audio_file),
        ("Model", args.model),
        ("Labels", args.labels),
    )
    for description, path in required_files:
        if not os.path.exists(path):
            return f"{description} file '{path}' not found."
    return None


def _run_single_window(
    session: ort.InferenceSession, labels: list[str], args: argparse.Namespace
) -> int:
    """Score the first 3 seconds of the audio file and print the top-k classes."""
    print(f"Loading first 3 seconds of audio file: {args.audio_file}")
    audio_data = load_audio(args.audio_file)
    print(f"Audio loaded successfully. Shape: {audio_data.shape}")

    print("Running inference on single window...")
    predictions = np.array(predict_audio(session, audio_data))
    # Batched output carries a leading axis; the single clip is its first row.
    scores = predictions[0] if predictions.ndim > 1 else predictions

    # argsort is ascending: take the last top_k entries, then reverse them.
    top_indices = np.argsort(scores)[-args.top_k :][::-1]

    print(f"\nTop {args.top_k} predictions for first 3 seconds:")
    for rank, idx in enumerate(top_indices, start=1):
        confidence = float(scores[idx])
        species_name = labels[idx] if idx < len(labels) else f"Class {idx}"
        print(f"{rank:2d}. {species_name}: {confidence:.6f}")

    return 0


def _run_moving_window(
    session: ort.InferenceSession, labels: list[str], args: argparse.Namespace
) -> int:
    """Score overlapping windows over the whole file and print a detection summary."""
    print(f"Loading full audio file: {args.audio_file}")
    full_audio = load_audio_full(args.audio_file)
    # 48000 matches the default target sample rate used by load_audio_full.
    audio_duration = len(full_audio) / 48000.0
    print(f"Audio loaded successfully. Duration: {audio_duration:.2f} seconds")

    print(f"Creating windows with {args.overlap * 100:.0f}% overlap...")
    windows, timestamps = create_audio_windows(full_audio, overlap=args.overlap)
    print(f"Created {len(windows)} windows of 3 seconds each")

    if len(windows) == 0:
        # Without this guard, batch inference fails with an obscure error.
        print(
            "Error: Audio is shorter than one 3-second window; "
            "use --single-window to analyze it."
        )
        return 1

    print(
        f"Running batch inference on {len(windows)} windows (batch size: {args.batch_size})..."
    )
    # Ceiling division: a final partial batch still counts as a batch.
    num_batches = (len(windows) + args.batch_size - 1) // args.batch_size
    print(f"Processing {num_batches} batches...")

    all_predictions = predict_audio_batch(session, windows, args.batch_size)
    print(f"Completed batch inference on {len(windows)} windows")

    print(f"Analyzing detections with confidence threshold {args.confidence}...")
    detections = analyze_detections(
        all_predictions, timestamps, labels, args.confidence
    )

    # Rank species by their single strongest detection.
    sorted_species = sorted(
        detections.items(),
        key=lambda item: max(det["confidence"] for det in item[1]),
        reverse=True,
    )

    print("\n=== DETECTION SUMMARY ===")
    print(f"Audio duration: {audio_duration:.2f} seconds")
    print(f"Windows analyzed: {len(windows)}")
    print(
        f"Species detected (>{args.confidence:.2f} confidence): {len(sorted_species)}"
    )

    if not sorted_species:
        print(f"No detections found above confidence threshold {args.confidence}")
        return 0

    print("\nTop detections:")
    for species, detections_list in sorted_species[: args.top_k]:
        max_conf = max(det["confidence"] for det in detections_list)
        num_detections = len(detections_list)
        first_detection = min(det["timestamp"] for det in detections_list)
        last_detection = max(det["timestamp"] for det in detections_list)

        print(f"\n{species}")
        print(f" Max confidence: {max_conf:.6f}")
        print(f" Detections: {num_detections}")
        print(f" Time range: {first_detection:.1f}s - {last_detection:.1f}s")

        # List the three strongest individual detections for this species.
        strong_detections = sorted(
            detections_list, key=lambda det: det["confidence"], reverse=True
        )[:3]
        for det in strong_detections:
            print(f" {det['timestamp']:6.1f}s: {det['confidence']:.6f}")

    return 0
|
|
|
|
if __name__ == "__main__":
    # The bare ``exit`` helper is injected by the ``site`` module and is not
    # guaranteed to exist (for example under ``python -S``). Raising
    # SystemExit gives the same exit code without depending on it.
    raise SystemExit(main())
|
|