import asyncio
import json
import mimetypes
import os
import time
from concurrent.futures import ThreadPoolExecutor

import torch
from fastapi import UploadFile, Form
from fastapi.responses import JSONResponse

from pipeline_setup import pipe
from utils.audio_transcription import transcribe_audio
from utils.image_processing import encode_image_base64
from utils.video_processing import split_video_into_segments, extract_motion_key_frames, extract_audio_from_video

def save_checkpoint(video_id, checkpoint_data):
    """Write per-video progress to /tmp so an interrupted run can resume."""
    checkpoint_path = f"/tmp/{video_id}_progress.json"
    with open(checkpoint_path, "w") as f:
        json.dump(checkpoint_data, f)


def load_checkpoint(video_id):
    """Return the saved progress dict for a video, or None if no checkpoint exists."""
    checkpoint_path = f"/tmp/{video_id}_progress.json"
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "r") as f:
            return json.load(f)
    return None

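# A minimal usage sketch of the two helpers above (the function name and the
# "demo" id are illustrative only, not part of the original module). The key
# set mirrors the checkpoint payload written by the commented-out upload-based
# video_query draft below.
def _checkpoint_roundtrip_example(video_id="demo"):
    state = {
        "responses": [],
        "timings": [],
        "completed_segments": [],
        "preprocessed_segments": [],
        "inference_completed_segments": [],
    }
    save_checkpoint(video_id, state)
    assert load_checkpoint(video_id) == state
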
# async def video_query(file: UploadFile, question: str = Form(...)):
#     try:
#         print("Processing video...")
#
#         if file.content_type not in ["video/mp4", "video/avi", "video/mkv"]:
#             return JSONResponse({"query": question, "error": "Unsupported video file type."})
#
#         overall_start_time = time.time()
#
#         video_data = await file.read()
#         video_id = str(hash(video_data))  # Unique ID for checkpointing
#         temp_video_path = f"/tmp/{video_id}.mp4"
#         with open(temp_video_path, "wb") as temp_video_file:
#             temp_video_file.write(video_data)
#
#         video_reading_time = time.time()
#         segments = split_video_into_segments(temp_video_path, segment_duration=30)
#
#         checkpoint = load_checkpoint(video_id) or {}
#         aggregated_responses = checkpoint.get("responses", [])
#         segment_timings = checkpoint.get("timings", [])
#         completed_segments = set(checkpoint.get("completed_segments", []))
#         preprocessed_segments = set(checkpoint.get("preprocessed_segments", []))
#         inference_completed_segments = set(checkpoint.get("inference_completed_segments", []))
#
#         for i, segment_path in enumerate(segments):
#             if i in completed_segments:
#                 print(f"Skipping already processed segment {i+1}")
#                 continue
#
#             segment_start_time = time.time()
#
#             if i not in preprocessed_segments:
#                 frame_start_time = time.time()
#                 imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2)
#                 frame_time = time.time()
#
#                 audio_start_time = time.time()
#                 audio_path = extract_audio_from_video(segment_path)
#                 transcribed_text = transcribe_audio(audio_path)
#                 audio_time = time.time()
#
#                 preprocessed_segments.add(i)
#                 save_checkpoint(video_id, {
#                     "responses": aggregated_responses,
#                     "timings": segment_timings,
#                     "completed_segments": list(completed_segments),
#                     "preprocessed_segments": list(preprocessed_segments),
#                     "inference_completed_segments": list(inference_completed_segments)
#                 })
#
#             if i not in inference_completed_segments:
#                 combined_query = f"Audio Transcript: {transcribed_text}\n{question}"
#                 question_with_frames = "".join([f"Frame{j+1}: {{IMAGE_TOKEN}}\n" for j, _ in enumerate(imgs)])
#                 question_with_frames += combined_query
#
#                 content = [{"type": "text", "text": question_with_frames}] + [
#                     {"type": "image_url", "image_url": {"max_dynamic_patch": 1, "url": f"data:image/jpeg;base64,{encode_image_base64(img)}"}}
#                     for img in imgs
#                 ]
#
#                 inference_start_time = time.time()
#                 messages = [dict(role="user", content=content)]
#                 response = await asyncio.to_thread(pipe, messages)
#                 inference_time = time.time()
#
#                 aggregated_responses.append(response.text)
#
#                 inference_completed_segments.add(i)
#                 save_checkpoint(video_id, {
#                     "responses": aggregated_responses,
#                     "timings": segment_timings,
#                     "completed_segments": list(completed_segments),
#                     "preprocessed_segments": list(preprocessed_segments),
#                     "inference_completed_segments": list(inference_completed_segments)
#                 })
#
#             segment_timings.append({
#                 "segment_index": i + 1,
#                 "segment_processing_time": inference_time - segment_start_time,
#                 "frame_extraction_time": frame_time - frame_start_time,
#                 "audio_extraction_time": audio_time - audio_start_time,
#                 "model_inference_time": inference_time - inference_start_time
#             })
#
#             completed_segments.add(i)
#             save_checkpoint(video_id, {
#                 "responses": aggregated_responses,
#                 "timings": segment_timings,
#                 "completed_segments": list(completed_segments),
#                 "preprocessed_segments": list(preprocessed_segments),
#                 "inference_completed_segments": list(inference_completed_segments)
#             })
#
#         return JSONResponse({
#             "question": question,
#             "responses": aggregated_responses,
#             "timings": {
#                 "video_reading_time": video_reading_time - overall_start_time,
#                 "total_segments": len(segments),
#                 "total_processing_time": time.time() - overall_start_time,
#                 "segment_details": segment_timings
#             }
#         })
#     except Exception as e:
#         return JSONResponse({"query": question, "error": str(e)})

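# NOTE: the draft above checkpoints segment *indices* only; the extracted
# frames and transcript are not persisted, so a run resumed between a
# segment's preprocessing and its inference would hit undefined
# `imgs`/`transcribed_text`. The later drafts below drop checkpointing
# entirely.
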
# async def video_query(video_path: str, question: str):
#     """
#     API endpoint to process a video file with the user's query.
#     """
#     try:
#         print("Processing video...")
#
#         if not video_path or not isinstance(video_path, str):
#             return {"query": question, "error": "No video file provided or invalid file input."}
#
#         # Determine the file type using the file extension
#         file_type, _ = mimetypes.guess_type(video_path)
#         if file_type is None or not file_type.startswith("video/"):
#             return {"query": question, "error": "Unsupported video file type."}
#
#         # Log the video path
#         print(f"Video path: {video_path}")
#
#         # Split the video into segments
#         print("Splitting video...")
#         segments = split_video_into_segments(video_path, segment_duration=30)
#         print(f"Video split into {len(segments)} segments.")
#
#         aggregated_responses = []
#         segment_timings = []
#
#         for i, segment_path in enumerate(segments):
#             print(f"Processing segment {i+1}/{len(segments)}: {segment_path}")
#
#             # Extract key frames
#             imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2)
#
#             # Extract audio and transcribe
#             audio_path = extract_audio_from_video(segment_path)
#             transcribed_text = transcribe_audio(audio_path)
#
#             # Combine transcribed text with the query
#             combined_query = f"Audio Transcript: {transcribed_text}\n{question}"
#
#             # Prepare content for the pipeline
#             question_with_frames = ""
#             for j, img in enumerate(imgs):
#                 question_with_frames += f"Frame{j+1}: {{IMAGE_TOKEN}}\n"
#             question_with_frames += combined_query
#
#             content = [{"type": "text", "text": question_with_frames}]
#             for img in imgs:
#                 content.append({
#                     "type": "image_url",
#                     "image_url": {
#                         "max_dynamic_patch": 1,
#                         "url": f"data:image/jpeg;base64,{encode_image_base64(img)}"
#                     }
#                 })
#
#             # Query the model
#             messages = [dict(role="user", content=content)]
#             response = await asyncio.to_thread(pipe, messages)
#
#             # Aggregate response
#             aggregated_responses.append(response.text)
#
#         return {
#             "question": question,
#             "responses": aggregated_responses,
#         }
#     except Exception as e:
#         return {"query": question, "error": str(e)}

# def video_query(video_path: str, question: str):
#     """
#     Processes a video file using the model.
#     Reads the video from disk, extracts key frames, transcribes audio, and queries the model.
#     """
#     try:
#         print("Processing video...")
#
#         if not os.path.exists(video_path):
#             return {"query": question, "error": "Video file not found."}
#
#         # Determine the file type
#         file_type, _ = mimetypes.guess_type(video_path)
#         if file_type is None or not file_type.startswith("video/"):
#             return {"query": question, "error": "Unsupported video file type."}
#
#         # Split video into segments
#         print("Splitting video...")
#         segments = split_video_into_segments(video_path, segment_duration=30)
#         print(f"Video split into {len(segments)} segments.")
#
#         aggregated_responses = []
#         segment_timings = []
#
#         for i, segment_path in enumerate(segments):
#             print(f"Processing segment {i+1}/{len(segments)}: {segment_path}")
#
#             # Extract key frames
#             imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2)
#
#             # Extract audio and transcribe
#             audio_path = extract_audio_from_video(segment_path)
#             transcribed_text = transcribe_audio(audio_path)
#
#             # Combine transcribed text with the query
#             combined_query = f"Audio Transcript: {transcribed_text}\n{question}"
#
#             # Prepare content for the pipeline
#             question_with_frames = "".join([f"Frame{j+1}: {{IMAGE_TOKEN}}\n" for j in range(len(imgs))])
#             question_with_frames += combined_query
#
#             content = [{"type": "text", "text": question_with_frames}] + [
#                 {"type": "image_url", "image_url": {"max_dynamic_patch": 1, "url": f"data:image/jpeg;base64,{encode_image_base64(img)}"}}
#                 for img in imgs
#             ]
#
#             # Query the model
#             messages = [dict(role="user", content=content)]
#             response = pipe(messages)
#
#             # Aggregate response
#             aggregated_responses.append(response.text)
#
#         return {
#             "question": question,
#             "responses": aggregated_responses,
#         }
#
#     except Exception as e:
#         return {"query": question, "error": str(e)}

# def run_video_inference(preprocessed_data):
#     """
#     **Inference Step (Runs on GPU)**
#     - Takes preprocessed data (key frames + transcribed audio).
#     - Constructs a query for the model.
#     - Runs inference on the GPU.
#     - Returns the aggregated responses.
#     """
#     import torch
#     torch.cuda.empty_cache()  # Free up GPU memory before inference
#
#     try:
#         print("Starting video inference...")
#
#         question = preprocessed_data["question"]
#         segments = preprocessed_data["segments"]
#         aggregated_responses = []
#
#         for segment in segments:
#             segment_index = segment["segment_index"]
#             transcribed_text = segment["transcription"]
#             encoded_imgs = segment["encoded_images"]
#
#             print(f"Running inference on segment {segment_index + 1}...")
#
#             # Prepare query content
#             question_with_frames = "".join(
#                 [f"Frame{j+1}: {{IMAGE_TOKEN}}\n" for j in range(len(encoded_imgs))]
#             )
#             combined_query = f"Audio Transcript: {transcribed_text}\n{question}"
#             question_with_frames += combined_query
#
#             content = [{"type": "text", "text": question_with_frames}] + [
#                 {"type": "image_url", "image_url": {"max_dynamic_patch": 1, "url": f"data:image/jpeg;base64,{img}"}}
#                 for img in encoded_imgs
#             ]
#
#             # Query the model (GPU-heavy operation)
#             messages = [dict(role="user", content=content)]
#             response = pipe(messages)
#
#             # Collect responses
#             aggregated_responses.append(response.text)
#
#         return {
#             "question": question,
#             "responses": aggregated_responses,
#         }
#
#     except Exception as e:
#         return {"query": question, "error": str(e)}

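# run_video_inference above expects a `preprocessed_data` dict of the form
# {"question": ..., "segments": [{"segment_index": ..., "transcription": ...,
# "encoded_images": [...]}, ...]}, but this section does not show its producer.
# Below is a minimal sketch of the CPU-side preprocessing step that would build
# that dict, parallelized across segments with the ThreadPoolExecutor imported
# above. The function name and max_workers default are illustrative, not part
# of the original pipeline.
def preprocess_video_for_inference(video_path, question, max_workers=4):
    def preprocess_segment(args):
        index, segment_path = args
        # Same per-segment work as the drafts above: key frames + transcript.
        imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2)
        audio_path = extract_audio_from_video(segment_path)
        return {
            "segment_index": index,
            "transcription": transcribe_audio(audio_path),
            "encoded_images": [encode_image_base64(img) for img in imgs],
        }

    segment_paths = split_video_into_segments(video_path, segment_duration=30)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map preserves segment order, matching segment_index.
        segments = list(executor.map(preprocess_segment, enumerate(segment_paths)))
    return {"question": question, "segments": segments}

# Usage (with run_video_inference uncommented):
#     result = run_video_inference(preprocess_video_for_inference("/tmp/video.mp4", "What happens in this video?"))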