Tiktok-Talent-Info/test_audio.py

92 lines
3.3 KiB
Python
Raw Normal View History

2025-03-22 20:54:10 +08:00
import functools
import os

import torch
from pydub import AudioSegment
from whisper import load_model
def extract_audio_from_video(
    video_path: str, audio_path: str = "/tmp/temp_audio_test.wav"
) -> str:
    """Decode the audio track of *video_path* and write it out as a WAV file.

    Args:
        video_path: Path to the input media file, loaded with
            ``AudioSegment.from_file``. NOTE(review): container formats such
            as MP4 presumably need ffmpeg available to pydub — confirm.
        audio_path: Destination of the WAV file. Defaults to the fixed scratch
            path this script has always used, so existing callers keep
            working; pass a different path to avoid clobbering another run.

    Returns:
        The path the WAV file was written to (``audio_path``).
    """
    audio = AudioSegment.from_file(video_path)
    # export() hands back the open output file object; close it explicitly so
    # the WAV is flushed to disk and the descriptor is not leaked.
    audio.export(audio_path, format="wav").close()
    print("video extracted!")
    return audio_path
# def transcribe_audio(audio_path: str) -> str:
# print("Loading model in transcribe_audio...")
# from transformers import WhisperProcessor, WhisperForConditionalGeneration
# import torch
# # Load processor and model from transformers
# processor = WhisperProcessor.from_pretrained("openai/whisper-base")
# model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
# if torch.cuda.device_count() > 1:
# print(f"Using {torch.cuda.device_count()} GPUs!")
# model = torch.nn.DataParallel(model)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print("Model loaded successfully.")
# print(audio_path)
# # Load and process the audio file
# import librosa
# audio_input, sr = librosa.load(audio_path, sr=16000)
# input_features = processor(audio_input, sampling_rate=sr, return_tensors="pt").input_features.to(device)
# # Generate transcription
# with torch.no_grad():
# if isinstance(model, torch.nn.DataParallel):
# generated_ids = model.module.generate(input_features)
# else:
# generated_ids = model.generate(input_features)
# # Decode the generated tokens to text
# transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# return transcription
# Whisper checkpoints expect 16 kHz input, and the transformers feature
# extractor pads/truncates every example to 30 s. Longer recordings are
# therefore split into 30 s windows below instead of being silently cut off.
_WHISPER_SAMPLE_RATE = 16000
_WHISPER_WINDOW_SECONDS = 30


@functools.lru_cache(maxsize=1)
def _load_whisper(model_name: str):
    """Load the Whisper processor and model once; return (processor, model, device).

    Cached so repeated transcriptions in one process reuse the weights instead
    of reloading them on every call. ``maxsize=1`` keeps at most one model
    resident in memory.
    """
    from transformers import WhisperProcessor, WhisperForConditionalGeneration

    processor = WhisperProcessor.from_pretrained(model_name)
    model = WhisperForConditionalGeneration.from_pretrained(model_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return processor, model, device


def _split_into_windows(samples, window: int) -> list:
    """Slice *samples* into consecutive chunks of at most *window* items.

    Always returns at least one chunk (possibly empty), so an empty recording
    is handed to the processor as a single input, just like before.
    """
    return [
        samples[start:start + window]
        for start in range(0, max(len(samples), 1), window)
    ]


def transcribe_audio(audio_path: str, model_name: str = "openai/whisper-tiny") -> str:
    """Transcribe the audio file at *audio_path* with a Whisper model.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``. It is
            resampled to 16 kHz before feature extraction.
        model_name: Hugging Face model id passed to ``from_pretrained``.
            Defaults to the checkpoint this script has always used.

    Returns:
        The decoded text of every 30 s window, each stripped of surrounding
        whitespace and joined with single spaces.
    """
    print("Loading model in transcribe_audio...")
    # Restrict to GPUs 0 and 1 unless the caller already chose devices.
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet in this process.
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
    # No DataParallel wrapper: it only parallelises forward(), and generate()
    # had to be called on `.module` anyway. The wrapper therefore ran on a
    # single device and only added a misleading "Using N GPUs!" message.
    processor, model, device = _load_whisper(model_name)
    print("Model loaded successfully.")
    print(audio_path)

    import librosa

    samples, sr = librosa.load(audio_path, sr=_WHISPER_SAMPLE_RATE)
    windows = _split_into_windows(
        samples, _WHISPER_SAMPLE_RATE * _WHISPER_WINDOW_SECONDS
    )
    # One batch row per window; each row is padded to the model's 30 s input.
    input_features = processor(
        windows, sampling_rate=sr, return_tensors="pt"
    ).input_features.to(device)

    with torch.no_grad():
        generated_ids = model.generate(input_features)

    texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    return " ".join(text.strip() for text in texts)
if __name__ == "__main__":
    # Manual smoke test. "../video/1.mp4" is resolved against the current
    # working directory, not this file's location.
    # Use the path the extractor reports instead of repeating it here, so the
    # two can never drift apart.
    audio_file = extract_audio_from_video("../video/1.mp4")
    # Transcribe the same file several times — presumably to check that the
    # output is stable across repeated runs.
    for attempt in range(1, 4):
        print(f"\nTranscription attempt {attempt}:")
        transcription = transcribe_audio(audio_file)
        print("Transcription:")
        print(transcription)