# Tiktok-Talent-Info/utils/audio_transcription.py
import os
import torch
from pydub import AudioSegment


def extract_audio_from_video(video_path: str) -> str:
    """Extract the audio track from a video file and export it as WAV.

    Returns the path of the exported audio file.
    """
    audio = AudioSegment.from_file(video_path)
    audio_path = "/tmp/temp_audio.wav"
    audio.export(audio_path, format="wav")
    return audio_path
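
# A minimal usage sketch for extract_audio_from_video, commented out like the
# other examples in this file; the input video path is a placeholder, not a
# file shipped with this repo:
# wav_path = extract_audio_from_video("/tmp/sample_video.mp4")
# print(wav_path)  # -> /tmp/temp_audio.wav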
# Earlier CPU-only implementation using openai-whisper, kept for reference:
# def transcribe_audio(audio_path: str) -> str:
#     print("Loading model in transcribe_audio...")
#     from whisper import load_model
#     model = load_model("base", device="cpu")
#     # model = load_model("base")
#     print("Model loaded successfully.")
#     print(f"Model is running on: {next(model.parameters()).device}")
#     print("Model loaded successfully on CPU.")
#     result = model.transcribe(audio_path)
#     print(result)
#     return result["text"]

# Earlier multi-GPU implementation using openai-whisper, kept for reference:
# def transcribe_audio(audio_path: str) -> str:
#     print("Loading model in transcribe_audio...")
#     os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
#     from whisper import load_model
#     model = load_model("base")
#     if torch.cuda.device_count() > 1:
#         print(f"Using {torch.cuda.device_count()} GPUs!")
#         model = torch.nn.DataParallel(model)
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model.to(device)
#     print("Model loaded successfully.")
#     print(audio_path)
#     # Access the underlying model if using DataParallel
#     if isinstance(model, torch.nn.DataParallel):
#         result = model.module.transcribe(audio_path)
#     else:
#         result = model.transcribe(audio_path)
#     print(result)
#     return result["text"]
def transcribe_audio(audio_path: str) -> str:
    """Transcribe an audio file with the Hugging Face Whisper model.

    Loads openai/whisper-tiny, runs it on GPU(s) when available, and
    returns the decoded transcription text.
    """
    print("Loading model in transcribe_audio...")
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs!")
        model = torch.nn.DataParallel(model)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded successfully.")
    print(audio_path)

    # Load the audio at 16 kHz, the sampling rate Whisper was trained on
    import librosa
    audio_input, sr = librosa.load(audio_path, sr=16000)
    input_features = processor(
        audio_input, sampling_rate=sr, return_tensors="pt"
    ).input_features.to(device)

    # Generate transcription (unwrap DataParallel to reach .generate)
    with torch.no_grad():
        if isinstance(model, torch.nn.DataParallel):
            generated_ids = model.module.generate(input_features)
        else:
            generated_ids = model.generate(input_features)

    # Decode the generated token IDs back to text
    transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return transcription
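
# Note: the Whisper feature extractor pads/truncates input to the model's
# 30-second context window, so transcribe_audio above effectively covers only
# the first ~30 s of a file. Below is a hedged sketch of chunked transcription
# for longer audio; it is not part of the original module and reloads the same
# tiny checkpoint for self-containment:
# def transcribe_long_audio(audio_path: str, chunk_s: int = 30) -> str:
#     from transformers import WhisperProcessor, WhisperForConditionalGeneration
#     import librosa
#     processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
#     model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
#     audio, sr = librosa.load(audio_path, sr=16000)
#     step = chunk_s * sr  # samples per 30-second chunk
#     pieces = []
#     for start in range(0, len(audio), step):
#         feats = processor(audio[start:start + step], sampling_rate=sr,
#                           return_tensors="pt").input_features
#         ids = model.generate(feats)
#         pieces.append(processor.batch_decode(ids, skip_special_tokens=True)[0])
#     return " ".join(pieces)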
# Example benchmark, kept for reference:
# audio_path = "/tmp/temp_audio.wav"
# num_iterations = 5
# import time
# start_time = time.time()
# for i in range(num_iterations):
#     print(f"Processing iteration {i+1}...")
#     transcription = transcribe_audio(audio_path)
#     print(f"Transcription (iteration {i+1}): {transcription}")
#     end_time = time.time()
#     elapsed_time = end_time - start_time
#     print(f"Time taken for iteration {i+1}: {elapsed_time:.2f} seconds\n")
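
# End-to-end sketch tying the two helpers together; the video path is a
# placeholder, and the block is commented out (matching the examples above)
# so importing this module stays side-effect free:
# if __name__ == "__main__":
#     wav = extract_audio_from_video("/path/to/video.mp4")
#     text = transcribe_audio(wav)
#     print(text)
#     os.remove(wav)  # clean up the temporary WAV file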