from fastapi import UploadFile, Form
from fastapi.responses import JSONResponse
from pipeline_setup import pipe
from utils.image_processing import encode_image_base64
from utils.video_processing import split_video_into_segments, extract_motion_key_frames, extract_audio_from_video
from utils.audio_transcription import transcribe_audio
import os
import torch
import json
import time
import asyncio
import mimetypes
from concurrent.futures import ThreadPoolExecutor


def save_checkpoint(video_id, checkpoint_data):
    """Persist per-video progress to /tmp so an interrupted run can resume."""
    checkpoint_path = f"/tmp/{video_id}_progress.json"
    with open(checkpoint_path, "w") as f:
        json.dump(checkpoint_data, f)


def load_checkpoint(video_id):
    """Return the saved progress dict for this video, or None if no checkpoint exists."""
    checkpoint_path = f"/tmp/{video_id}_progress.json"
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, "r") as f:
            return json.load(f)
    return None
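
# Illustrative sketch only (defined but never called): shows the checkpoint record shape that
# the commented-out FastAPI variant of video_query below writes and reads on resume. The key
# names mirror that code; the demo video_id is hypothetical.
def _checkpoint_usage_example():
    video_id = "demo_video"  # the real endpoint derives this from a hash of the uploaded bytes
    save_checkpoint(video_id, {
        "responses": [],
        "timings": [],
        "completed_segments": [],
        "preprocessed_segments": [0],
        "inference_completed_segments": []
    })
    # On resume, the JSON lists are rebuilt into sets for fast membership tests.
    checkpoint = load_checkpoint(video_id) or {}
    completed_segments = set(checkpoint.get("completed_segments", []))
    preprocessed_segments = set(checkpoint.get("preprocessed_segments", []))
    return completed_segments, preprocessed_segments  # (set(), {0})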
"preprocessed_segments": list(preprocessed_segments), # "inference_completed_segments": list(inference_completed_segments) # }) # segment_timings.append({ # "segment_index": i + 1, # "segment_processing_time": inference_time - segment_start_time, # "frame_extraction_time": frame_time - frame_start_time, # "audio_extraction_time": audio_time - audio_start_time, # "model_inference_time": inference_time - inference_start_time # }) # completed_segments.add(i) # save_checkpoint(video_id, { # "responses": aggregated_responses, # "timings": segment_timings, # "completed_segments": list(completed_segments), # "preprocessed_segments": list(preprocessed_segments), # "inference_completed_segments": list(inference_completed_segments) # }) # return JSONResponse({ # "question": question, # "responses": aggregated_responses, # "timings": { # "video_reading_time": video_reading_time - overall_start_time, # "total_segments": len(segments), # "total_processing_time": time.time() - overall_start_time, # "segment_details": segment_timings # } # }) # except Exception as e: # return JSONResponse({"query": question, "error": str(e)}) # gradio # async def video_query(video_path: str, question: str): # """ # API endpoint to process a video file with the user's query. # """ # try: # print("Processing video...") # if not video_path or not isinstance(video_path, str): # return {"query": question, "error": "No video file provided or invalid file input."} # # Determine the file type using the file extension # file_type, _ = mimetypes.guess_type(video_path) # if file_type is None or not file_type.startswith("video/"): # return {"query": question, "error": "Unsupported video file type."} # # Log the video path # print(f"Video path: {video_path}") # # Split the video into segments # print("Splitting video...") # segments = split_video_into_segments(video_path, segment_duration=30) # print(f"Video split into {len(segments)} segments.") # aggregated_responses = [] # segment_timings = [] # for i, segment_path in enumerate(segments): # print(f"Processing segment {i+1}/{len(segments)}: {segment_path}") # # Extract key frames # imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2) # # Extract audio and transcribe # audio_path = extract_audio_from_video(segment_path) # transcribed_text = transcribe_audio(audio_path) # # Combine transcribed text with the query # combined_query = f"Audio Transcript: {transcribed_text}\n{question}" # # Prepare content for the pipeline # question_with_frames = "" # for j, img in enumerate(imgs): # question_with_frames += f"Frame{j+1}: {{IMAGE_TOKEN}}\n" # question_with_frames += combined_query # content = [{"type": "text", "text": question_with_frames}] # for img in imgs: # content.append({ # "type": "image_url", # "image_url": { # "max_dynamic_patch": 1, # "url": f"data:image/jpeg;base64,{encode_image_base64(img)}" # } # }) # # Query the model # messages = [dict(role="user", content=content)] # response = await asyncio.to_thread(pipe, messages) # # Aggregate response # aggregated_responses.append(response.text) # return { # "question": question, # "responses": aggregated_responses, # } # except Exception as e: # return {"query": question, "error": str(e)} # def video_query(video_path: str, question: str): # """ # Processes a video file using the model. # Reads the video from disk, extracts key frames, transcribes audio, and queries the model. 
# """ # try: # print("Processing video...") # if not os.path.exists(video_path): # return {"query": question, "error": "Video file not found."} # # Determine the file type # file_type, _ = mimetypes.guess_type(video_path) # if file_type is None or not file_type.startswith("video/"): # return {"query": question, "error": "Unsupported video file type."} # # Split video into segments # print("Splitting video...") # segments = split_video_into_segments(video_path, segment_duration=30) # print(f"Video split into {len(segments)} segments.") # aggregated_responses = [] # segment_timings = [] # for i, segment_path in enumerate(segments): # print(f"Processing segment {i+1}/{len(segments)}: {segment_path}") # # Extract key frames # imgs = extract_motion_key_frames(segment_path, max_frames=50, sigma_multiplier=2) # # Extract audio and transcribe # audio_path = extract_audio_from_video(segment_path) # transcribed_text = transcribe_audio(audio_path) # # Combine transcribed text with the query # combined_query = f"Audio Transcript: {transcribed_text}\n{question}" # # Prepare content for the pipeline # question_with_frames = "".join([f"Frame{j+1}: {{IMAGE_TOKEN}}\n" for j in range(len(imgs))]) # question_with_frames += combined_query # content = [{"type": "text", "text": question_with_frames}] + [ # {"type": "image_url", "image_url": {"max_dynamic_patch": 1, "url": f"data:image/jpeg;base64,{encode_image_base64(img)}"}} # for img in imgs # ] # # Query the model # messages = [dict(role="user", content=content)] # response = pipe(messages) # # Aggregate response # aggregated_responses.append(response.text) # return { # "question": question, # "responses": aggregated_responses, # } # except Exception as e: # return {"query": question, "error": str(e)} # def run_video_inference(preprocessed_data): # """ # **Inference Step (Runs on GPU)** # - Takes preprocessed data (key frames + transcribed audio). # - Constructs a query for the model. # - Runs inference on the GPU. # - Returns the aggregated responses. # """ # import torch # torch.cuda.empty_cache() # Free up GPU memory before inference # try: # print("Starting video inference...") # question = preprocessed_data["question"] # segments = preprocessed_data["segments"] # aggregated_responses = [] # for segment in segments: # segment_index = segment["segment_index"] # transcribed_text = segment["transcription"] # encoded_imgs = segment["encoded_images"] # print(f"Running inference on segment {segment_index + 1}...") # # Prepare query content # question_with_frames = "".join( # [f"Frame{j+1}: {{IMAGE_TOKEN}}\n" for j in range(len(encoded_imgs))] # ) # combined_query = f"Audio Transcript: {transcribed_text}\n{question}" # question_with_frames += combined_query # content = [{"type": "text", "text": question_with_frames}] + [ # {"type": "image_url", "image_url": {"max_dynamic_patch": 1, "url": f"data:image/jpeg;base64,{img}"}} # for img in encoded_imgs # ] # # Query the model (GPU-heavy operation) # messages = [dict(role="user", content=content)] # response = pipe(messages) # # Collect responses # aggregated_responses.append(response.text) # return { # "question": question, # "responses": aggregated_responses, # } # except Exception as e: # return {"query": question, "error": str(e)}