# Tiktok-Talent-Info/utils/video_processing.py
import cv2
import os
import subprocess
import numpy as np
from PIL import Image
from pydub import AudioSegment
from decord import VideoReader, cpu
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import Pool, cpu_count
# def split_video_into_segments(video_path, segment_duration=30):
# """
# Splits a video into segments of a specified duration using FFmpeg.
# """
# output_dir = "/tmp/video_segments"
# os.makedirs(output_dir, exist_ok=True)
# # Calculate total duration of the video
# cap = cv2.VideoCapture(video_path)
# fps = int(cap.get(cv2.CAP_PROP_FPS))
# total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# total_duration = total_frames / fps
# cap.release()
# segments = []
# for start_time in range(0, int(total_duration), segment_duration):
# segment_file = os.path.join(output_dir, f"segment_{start_time}.mp4")
# command = [
# "ffmpeg", "-y",
# "-i", video_path,
# "-ss", str(start_time),
# "-t", str(segment_duration),
# "-c", "copy", segment_file
# ]
# subprocess.run(command, check=True)
# segments.append(segment_file)
# print(f"segments: \n", segments)
# return segments
# def split_video_into_segments(video_path, segment_duration=30): # slow
# """
# Splits a video into segments of a specified duration using FFmpeg's segment muxer.
# """
# output_dir = "/tmp/video_segments"
# os.makedirs(output_dir, exist_ok=True)
# segment_file_pattern = os.path.join(output_dir, "segment_%03d.mp4")
# command = [
# "ffmpeg", "-y",
# "-i", video_path,
# "-f", "segment",
# "-segment_time", str(segment_duration),
# "-c", "copy",
# "-reset_timestamps", "1",
# segment_file_pattern
# ]
# subprocess.run(command, check=True)
# segments = sorted([os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.startswith("segment_")])
# print(f"segments: \n", segments)
# return segments
def split_video_into_segments(video_path, segment_duration=30):
    """
    Splits a video into segments of a specified duration using FFmpeg,
    running one FFmpeg process per segment in a thread pool.
    """
    output_dir = "/tmp/video_segments"
    os.makedirs(output_dir, exist_ok=True)

    # Calculate the total duration of the video
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)  # keep fractional rates (e.g. 29.97) intact
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    if not fps:
        raise ValueError(f"Could not determine FPS for {video_path}")
    total_duration = total_frames / fps

    def extract_segment(start_time):
        segment_file = os.path.join(output_dir, f"segment_{start_time}.mp4")
        command = [
            "ffmpeg", "-y",
            "-i", video_path,
            "-ss", str(start_time),
            "-t", str(segment_duration),
            "-c", "copy", segment_file
        ]
        subprocess.run(command, check=True)
        return segment_file

    # FFmpeg does the heavy lifting in subprocesses, so threads parallelize
    # these I/O-bound calls well despite the GIL.
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(extract_segment, start_time)
            for start_time in range(0, int(total_duration), segment_duration)
        ]
        # Collect results in submission order so the segment list stays chronological
        # (as_completed would yield them in nondeterministic completion order).
        segments = [future.result() for future in futures]

    print("segments:", segments)
    return segments
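# Caveat (hedged, untested sketch): with "-ss" after "-i" and "-c copy", FFmpeg decodes up to the
# seek point and can only cut on keyframes, so segment boundaries are approximate and per-segment
# startup cost grows with start_time. An alternative is to seek on the input for speed, or drop
# "-c copy" and re-encode for frame-accurate cuts, e.g.:
# command = [
#     "ffmpeg", "-y",
#     "-ss", str(start_time),              # input-side seek: fast, lands near a keyframe
#     "-i", video_path,
#     "-t", str(segment_duration),
#     "-c:v", "libx264", "-c:a", "aac",    # re-encode for frame-accurate boundaries (slower than copy)
#     segment_file,
# ]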
def extract_audio_from_video(video_path):
    """
    Extract audio from video using pydub and save it as a temporary WAV file.
    """
    print("Audio extraction started...")
    audio = AudioSegment.from_file(video_path)
    audio_path = "/tmp/temp_audio.wav"
    audio.export(audio_path, format="wav")
    print(f"Audio extracted and saved to: {audio_path}")
    return audio_path
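# Note: pydub decodes through FFmpeg, so the ffmpeg binary must be on PATH. A minimal untested
# alternative that skips the pydub round-trip (16 kHz mono is an assumption, e.g. for ASR models):
# subprocess.run([
#     "ffmpeg", "-y", "-i", video_path,
#     "-vn",           # drop the video stream
#     "-ac", "1",      # downmix to mono
#     "-ar", "16000",  # resample to 16 kHz
#     audio_path,
# ], check=True)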
############################################################################################################
# optical-flow motion, multithreaded; computes motion between consecutive frames using dense optical flow (Farneback) only
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=1):
# """
# Extracts key frames from a video based on motion intensity.
# """
# def calculate_motion(frame_pair):
# """
# Calculates motion between two consecutive frames using optical flow.
# """
# prev_gray, current_frame = frame_pair
# current_gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)
# flow = cv2.calcOpticalFlowFarneback(prev_gray, current_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
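# # Farneback positional args above: pyr_scale=0.5, levels=3, winsize=15,
# # iterations=3, poly_n=5, poly_sigma=1.2, flags=0.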
# motion = np.sum(flow ** 2)
# return motion, current_gray
# # Load video frames using Decord
# video = VideoReader(video_path, ctx=cpu(0))
# frames_batch = video.get_batch(range(0, len(video), frame_interval)).asnumpy()
# # Resize frames for faster processing
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames_batch]
# # Initialize the first frame
# prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
# frame_pairs = [(prev_gray, frames[i]) for i in range(1, len(frames))]
# # Calculate motion statistics
# motion_values = []
# with ThreadPoolExecutor() as executor:
# motion_results = list(executor.map(calculate_motion, frame_pairs))
# motion_values = [motion for motion, _ in motion_results]
# # Calculate threshold statistically
# motion_mean = np.mean(motion_values)
# motion_std = np.std(motion_values)
# threshold = motion_mean + sigma_multiplier * motion_std
# # Extract key frames based on motion threshold
# key_frames = []
# for i, (motion, frame) in enumerate(zip(motion_values, frames[1:])):
# if motion > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
############################################################################################################
# multithreading with batches
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=1):
# """
# Extracts key frames from a video based on motion intensity.
# Optimized for speed and efficiency.
# """
# def calculate_motion(frame_pair):
# """
# Calculates motion between two consecutive frames using optical flow.
# """
# prev_gray, current_frame = frame_pair
# current_gray = cv2.cvtColor(current_frame, cv2.COLOR_BGR2GRAY)
# flow = cv2.calcOpticalFlowFarneback(prev_gray, current_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
# motion = np.sum(flow ** 2)
# return motion, current_gray
# # Load video frames using Decord with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Process frames in smaller batches to reduce memory usage
# batch_size = 100
# motion_values = []
# for batch_start in range(0, len(frame_indices), batch_size):
# batch_end = min(batch_start + batch_size, len(frame_indices))
# batch_indices = frame_indices[batch_start:batch_end]
# frames_batch = video.get_batch(batch_indices).asnumpy()
# # Resize frames for faster processing
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames_batch]
# # Initialize the first frame in the batch
# prev_gray = cv2.cvtColor(frames[0], cv2.COLOR_BGR2GRAY)
# frame_pairs = [(prev_gray, frames[i]) for i in range(1, len(frames))]
# # Calculate motion statistics for the batch
# with ThreadPoolExecutor() as executor:
# motion_results = list(executor.map(calculate_motion, frame_pairs))
# batch_motion_values = [motion for motion, _ in motion_results]
# motion_values.extend(batch_motion_values)
# # Update the previous frame for the next batch
# prev_gray = cv2.cvtColor(frames[-1], cv2.COLOR_BGR2GRAY)
# # Calculate threshold statistically
# motion_mean = np.mean(motion_values)
# motion_std = np.std(motion_values)
# threshold = motion_mean + sigma_multiplier * motion_std
# # Extract key frames based on motion threshold
# key_frames = []
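# # Caution: `frames` was reassigned inside the batch loop above, so at this point it
# # holds only the last batch; zipping it with all motion_values misaligns the indices.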
# for i, (motion, frame) in enumerate(zip(motion_values, frames[1:])):
# if motion > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
############################################################################################################
# multiprocessing
# def calculate_motion(frame_pair):
# """
# Calculates motion between two consecutive frames using optical flow.
# """
# prev_gray, current_gray = frame_pair
# flow = cv2.calcOpticalFlowFarneback(prev_gray, current_gray, None, 0.5, 3, 15, 3, 5, 1.2, 0)
# motion = np.sum(flow ** 2)
# return motion
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=1):
# """
# Extracts key frames from a video based on motion intensity.
# Optimized for speed and efficiency.
# """
# # Load video frames using Decord with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Read all frames and resize them for faster processing
# frames = video.get_batch(frame_indices).asnumpy()
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames]
# # Convert all frames to grayscale in one go
# grayscale_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in frames]
# # Calculate motion between consecutive frames using multiprocessing
# frame_pairs = list(zip(grayscale_frames[:-1], grayscale_frames[1:]))
# with Pool(cpu_count()) as pool:
# motion_values = pool.map(calculate_motion, frame_pairs)
# # Calculate threshold statistically
# motion_mean = np.mean(motion_values)
# motion_std = np.std(motion_values)
# threshold = motion_mean + sigma_multiplier * motion_std
# # Extract key frames based on motion threshold
# key_frames = []
# for i, motion in enumerate(motion_values):
# if motion > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frames[i + 1], cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
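# # Note: Pool.map pickles every grayscale frame pair into the worker processes, so
# # IPC overhead can dominate for large frames; the threaded variants above avoid the
# # copies because OpenCV releases the GIL inside calcOpticalFlowFarneback.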
############################################################################################################
# faster optical flow: more aggressive downscaling and frame skipping; computes dense optical flow (Farneback) and adds peak-detection logic to pick motion peaks
# def calculate_motion(frames):
# """
# Calculate motion metrics using dense optical flow (Farneback); the sparse LK/ShiTomasi setup below is unused
# Returns a list of motion intensity values
# """
# if len(frames) < 2:
# return []
# # Convert all frames to grayscale at once
# gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in frames]
# # Parameters for ShiTomasi corner detection and optical flow
# feature_params = dict(maxCorners=100, qualityLevel=0.3, minDistance=7, blockSize=7)
# lk_params = dict(winSize=(15,15), maxLevel=2,
# criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))
# motion_metrics = []
# prev_frame = gray_frames[0]
# prev_pts = cv2.goodFeaturesToTrack(prev_frame, mask=None, **feature_params)
# for i in range(1, len(gray_frames)):
# curr_frame = gray_frames[i]
# # Calculate dense optical flow (Farneback)
# flow = cv2.calcOpticalFlowFarneback(prev_frame, curr_frame, None, 0.5, 3, 15, 3, 5, 1.2, 0)
# magnitude = np.sqrt(flow[...,0]**2 + flow[...,1]**2)
# motion_metrics.append(np.mean(magnitude))
# prev_frame = curr_frame
# return motion_metrics
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=5):
# # Load video with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Read and resize all frames at once
# frames = video.get_batch(frame_indices).asnumpy()
# frames = np.array([cv2.resize(frame, (frame.shape[1]//4, frame.shape[0]//4)) for frame in frames])
# # Calculate motion metrics
# motion_values = calculate_motion(frames)
# if not motion_values:
# return []
# # Adaptive thresholding
# mean_motion = np.mean(motion_values)
# std_motion = np.std(motion_values)
# threshold = mean_motion + sigma_multiplier * std_motion
# # Find peaks in motion values
# key_frame_indices = []
# for i in range(1, len(motion_values)-1):
# if motion_values[i] > threshold and \
# motion_values[i] > motion_values[i-1] and \
# motion_values[i] > motion_values[i+1]:
# key_frame_indices.append(i+1) # +1 because motion is between frames
# # Select top frames by motion intensity
# if len(key_frame_indices) > max_frames:
# sorted_indices = sorted(key_frame_indices, key=lambda x: motion_values[x-1], reverse=True)
# key_frame_indices = sorted_indices[:max_frames]
# key_frame_indices.sort()
# # Convert to PIL Images
# key_frames = [Image.fromarray(cv2.cvtColor(frames[i], cv2.COLOR_BGR2RGB))
# for i in key_frame_indices]
# return key_frames
############################################################################################################
# RAFT Optical Flow
# import torch
# import torchvision.models.optical_flow as of
# from torch.nn.parallel import DataParallel
# def pad_to_multiple_of_8(frame):
# """
# Pads the frame dimensions to the nearest multiple of 8.
# """
# h, w, _ = frame.shape
# pad_h = (8 - h % 8) % 8
# pad_w = (8 - w % 8) % 8
# return cv2.copyMakeBorder(frame, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT, value=[0, 0, 0])
# def compute_raft_flow_batch(frame_batch, raft_model):
# """
# Computes optical flow for a batch of frames using the RAFT model.
# """
# # Pad frames to make dimensions divisible by 8
# frame_batch = [pad_to_multiple_of_8(frame) for frame in frame_batch]
# # Convert frames to tensors and normalize
# frame_tensors = torch.stack([torch.from_numpy(frame).permute(2, 0, 1).float().cuda() / 255.0 for frame in frame_batch])
# # Compute optical flow for the batch
# with torch.no_grad():
# flows = raft_model(frame_tensors[:-1], frame_tensors[1:])
# # Calculate motion magnitude for each flow
# motions = [np.sum(flow.cpu().numpy() ** 2) for flow in flows]
# return motions
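# # Note (hedged): newer torchvision versions replace pretrained=True with
# # weights=of.Raft_Large_Weights.DEFAULT, and the bundled transforms normalize
# # inputs to [-1, 1] rather than the [0, 1] scaling used above, so magnitudes may differ.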
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=1, batch_size=128):
# """
# Extracts key frames from a video based on motion intensity using RAFT for optical flow.
# Utilizes multiple GPUs and processes frames in batches.
# """
# # Load RAFT model and wrap it with DataParallel for multi-GPU support
# print("Loading RAFT model...")
# raft_model = of.raft_large(pretrained=True).cuda()
# if torch.cuda.device_count() > 1:
# print(f"Using {torch.cuda.device_count()} GPUs!")
# raft_model = DataParallel(raft_model)
# # Load video frames using Decord with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Read all frames and resize them for faster processing
# frames = video.get_batch(frame_indices).asnumpy()
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames]
# # Calculate motion between consecutive frames using RAFT in batches
# motion_values = []
# print(f"The total number of frames: {len(frames)}")
# for batch_start in range(1, len(frames), batch_size):
# batch_end = min(batch_start + batch_size, len(frames))
# batch_frames = frames[batch_start - 1:batch_end]
# batch_motions = compute_raft_flow_batch(batch_frames, raft_model)
# motion_values.extend(batch_motions)
# # Calculate threshold statistically
# motion_mean = np.mean(motion_values)
# motion_std = np.std(motion_values)
# threshold = motion_mean + sigma_multiplier * motion_std
# # Extract key frames based on motion threshold
# key_frames = []
# for i, motion in enumerate(motion_values):
# if motion > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frames[i + 1], cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
############################################################################################################
# Histogram Difference
# def calculate_histogram_difference(frame_pair):
# """
# Calculates the difference between two consecutive frames using color histograms.
# """
# frame1, frame2 = frame_pair
# # Calculate histograms for each frame
# hist1 = cv2.calcHist([frame1], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
# hist2 = cv2.calcHist([frame2], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256])
# # Normalize histograms
# cv2.normalize(hist1, hist1)
# cv2.normalize(hist2, hist2)
# # Calculate histogram difference using Chi-Squared distance
# difference = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR)
# return difference
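# # HISTCMP_CHISQR computes sum_i (H1_i - H2_i)^2 / H1_i over the bins, so it is
# # asymmetric and spikes when a bin that was populated in the first frame empties out.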
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=1):
# """
# Extracts key frames from a video based on histogram differences.
# Optimized for speed and efficiency.
# """
# # Load video frames using Decord with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Read all frames and resize them for faster processing
# frames = video.get_batch(frame_indices).asnumpy()
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames]
# # Calculate histogram differences between consecutive frames using multiprocessing
# frame_pairs = list(zip(frames[:-1], frames[1:]))
# with Pool(cpu_count()) as pool:
# histogram_differences = pool.map(calculate_histogram_difference, frame_pairs)
# # Calculate threshold statistically
# diff_mean = np.mean(histogram_differences)
# diff_std = np.std(histogram_differences)
# threshold = diff_mean + sigma_multiplier * diff_std
# # Extract key frames based on histogram difference threshold
# key_frames = []
# for i, difference in enumerate(histogram_differences):
# if difference > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frames[i + 1], cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
############################################################################################################
# faster histogram
# def calculate_histogram_difference(frame1, frame2):
# """
# Calculates the difference between two consecutive frames using grayscale histograms.
# """
# # Convert frames to grayscale
# gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
# gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
# # Calculate histograms with fewer bins (e.g., 16 bins)
# hist1 = cv2.calcHist([gray1], [0], None, [16], [0, 256])
# hist2 = cv2.calcHist([gray2], [0], None, [16], [0, 256])
# # Normalize histograms
# cv2.normalize(hist1, hist1)
# cv2.normalize(hist2, hist2)
# # Calculate histogram difference using Chi-Squared distance
# difference = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR)
# return difference
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=10):
# """
# Extracts key frames from a video based on histogram differences.
# Optimized for speed by reducing histogram complexity and skipping frames.
# """
# # Load video frames using Decord with reduced resolution
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# frame_indices = range(0, total_frames, frame_interval)
# # Read all frames and resize them for faster processing
# frames = video.get_batch(frame_indices).asnumpy()
# frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames]
# # Calculate histogram differences between consecutive frames
# histogram_differences = []
# for i in range(1, len(frames)):
# difference = calculate_histogram_difference(frames[i - 1], frames[i])
# histogram_differences.append(difference)
# # Calculate threshold statistically
# diff_mean = np.mean(histogram_differences)
# diff_std = np.std(histogram_differences)
# threshold = diff_mean + sigma_multiplier * diff_std
# # Extract key frames based on histogram difference threshold
# key_frames = []
# for i, difference in enumerate(histogram_differences):
# if difference > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frames[i + 1], cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# return key_frames
############################################################################################################
# faster histogram with batch
def calculate_histogram_difference_batch(frame_batch):
    """
    Calculates histogram differences between consecutive frames in a batch.
    """
    # Convert frames to grayscale (decord decodes to RGB, hence RGB2GRAY)
    gray_frames = [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frame_batch]
    # Calculate coarse 16-bin histograms for all frames in the batch
    histograms = [cv2.calcHist([gray], [0], None, [16], [0, 256]) for gray in gray_frames]
    for hist in histograms:
        cv2.normalize(hist, hist)
    # Calculate histogram differences between consecutive frames (Chi-Squared distance)
    differences = []
    for i in range(1, len(histograms)):
        difference = cv2.compareHist(histograms[i - 1], histograms[i], cv2.HISTCMP_CHISQR)
        differences.append(difference)
    return differences

def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=10, batch_size=16):
    """
    Extracts key frames from a video based on histogram differences.
    Uses batch processing for faster computation.
    """
    # Load video frames using Decord
    video = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(video)
    print(f"Total frames in video: {total_frames}")
    frame_indices = list(range(0, total_frames, frame_interval))

    # Read the sampled frames and downscale them for faster processing
    frames = video.get_batch(frame_indices).asnumpy()
    frames = [cv2.resize(frame, (frame.shape[1] // 2, frame.shape[0] // 2)) for frame in frames]
    print(f"Sampled frames: {len(frames)}")

    # Process frames in batches; each batch includes one extra frame so the
    # difference across the batch boundary is not lost
    histogram_differences = []
    for batch_start in range(0, len(frames), batch_size):
        batch_end = min(batch_start + batch_size + 1, len(frames))  # +1 overlaps the next frame
        batch_frames = frames[batch_start:batch_end]
        histogram_differences.extend(calculate_histogram_difference_batch(batch_frames))

    # Calculate the threshold statistically: mean + sigma_multiplier * std
    diff_mean = np.mean(histogram_differences)
    diff_std = np.std(histogram_differences)
    threshold = diff_mean + sigma_multiplier * diff_std

    # Extract key frames based on the histogram difference threshold.
    # Difference i is between frames i and i+1, so keep frame i+1; decord
    # frames are already RGB, so they can go straight to PIL.
    key_frames = []
    for i, difference in enumerate(histogram_differences):
        if difference > threshold and len(key_frames) < max_frames:
            key_frames.append(Image.fromarray(frames[i + 1]))
    return key_frames
############################################################################################################
# faster faster histogram
# def calculate_frame_difference(frame1, frame2):
# """
# Ultra-fast frame difference calculation using downscaled grayscale and absolute pixel differences.
# """
# # Convert to grayscale and downscale further
# gray1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)
# gray2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)
# # Downscale to tiny images (e.g., 16x16) for fast comparison
# tiny1 = cv2.resize(gray1, (16, 16))
# tiny2 = cv2.resize(gray2, (16, 16))
# # Calculate normalized absolute difference
# diff = cv2.absdiff(tiny1, tiny2)
# return np.mean(diff) / 255.0 # Normalize to [0,1]
# def save_key_frames(key_frames, output_dir="key_frames", prefix="frame"):
# """
# Saves key frames to disk as JPEG images.
# """
# if not os.path.exists(output_dir):
# os.makedirs(output_dir)
# saved_paths = []
# for i, frame in enumerate(key_frames):
# frame_path = os.path.join(output_dir, f"{prefix}_{i:04d}.jpg")
# frame.save(frame_path, quality=85) # Good quality with reasonable compression
# saved_paths.append(frame_path)
# return saved_paths
# def extract_motion_key_frames(video_path, max_frames=20, sigma_multiplier=2, frame_interval=15):
# # Load video with decord (faster than OpenCV)
# video = VideoReader(video_path, ctx=cpu(0))
# total_frames = len(video)
# # Pre-calculate frame indices to process
# frame_indices = range(0, total_frames, frame_interval)
# frames = video.get_batch(frame_indices).asnumpy()
# # Downscale all frames upfront (much faster than per-frame)
# frames = [cv2.resize(frame, (frame.shape[1]//4, frame.shape[0]//4)) for frame in frames]
# # Calculate differences (vectorized approach)
# differences = []
# prev_frame = frames[0]
# for frame in frames[1:]:
# diff = calculate_frame_difference(prev_frame, frame)
# differences.append(diff)
# prev_frame = frame
# # Adaptive thresholding
# diff_mean = np.mean(differences)
# diff_std = np.std(differences)
# threshold = diff_mean + sigma_multiplier * diff_std
# # Extract key frames
# key_frames = []
# for i, diff in enumerate(differences):
# if diff > threshold and len(key_frames) < max_frames:
# img = Image.fromarray(cv2.cvtColor(frames[i+1], cv2.COLOR_BGR2RGB))
# key_frames.append(img)
# saved_paths = save_key_frames(key_frames, '../video')
# return key_frames
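############################################################################################################
# Example usage (a hedged sketch; the input path is illustrative):
# if __name__ == "__main__":
#     video = "/tmp/input.mp4"
#     segments = split_video_into_segments(video, segment_duration=30)
#     audio_path = extract_audio_from_video(video)
#     key_frames = extract_motion_key_frames(video, max_frames=20, frame_interval=10)
#     for idx, img in enumerate(key_frames):
#         img.save(f"/tmp/key_frame_{idx:03d}.jpg")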