100 lines
4.1 KiB
Python
100 lines
4.1 KiB
Python
import asyncio
import logging

import gradio as gr
import torch
from gradio_image_prompter import ImagePrompter

from endpoints.image import image_query
from endpoints.text import text_query
from endpoints.video import video_query
|
|
# Report, at import time, how many CUDA devices torch can see and their names,
# so the startup output shows which GPUs the endpoints may run on.
print("Available GPUs:", torch.cuda.device_count())
print(
    "Visible Devices:",
    list(map(torch.cuda.get_device_name, range(torch.cuda.device_count()))),
)
|
|
|
|
def setup_ui():
|
|
with gr.Blocks() as ui:
|
|
gr.Markdown(
|
|
"""
|
|
# Multimodal Query Interface
|
|
Submit text, image, or video queries and get insights powered by APIs.
|
|
"""
|
|
)
|
|
|
|
# Tabbed layout
|
|
with gr.Tabs():
|
|
# Text Query Tab
|
|
with gr.Tab("Text Query"):
|
|
gr.Markdown("### Submit a Text Query")
|
|
with gr.Row():
|
|
text_input = gr.Textbox(label="Your Question", placeholder="Type your question here...")
|
|
text_button = gr.Button("Submit")
|
|
text_output = gr.Textbox(label="Response", interactive=False)
|
|
text_button.click(
|
|
fn=lambda q: asyncio.run(text_query(q)),
|
|
inputs=[text_input],
|
|
outputs=[text_output]
|
|
)
|
|
|
|
# Image Query Tab
|
|
with gr.Tab("Image Query"):
|
|
gr.Markdown("### Submit an Image Query")
|
|
with gr.Row():
|
|
image_prompter = ImagePrompter(show_label=False)
|
|
image_question_input = gr.Textbox(label="Your Question", placeholder="Type your question here...")
|
|
image_button = gr.Button("Submit")
|
|
image_output = gr.Textbox(label="Response", interactive=False)
|
|
|
|
# async def handle_image_query(prompts, question):
|
|
# response = await image_query(prompts["image"], question)
|
|
# return response["response"] if "response" in response else response["error"]
|
|
|
|
async def handle_image_query(prompts, question):
|
|
"""
|
|
Handles the image query and ensures that inputs are valid.
|
|
"""
|
|
try:
|
|
# Validate prompts
|
|
if prompts is None or "image" not in prompts:
|
|
return "No image provided. Please upload an image."
|
|
|
|
image_data = prompts["image"]
|
|
|
|
# Check if image_data is valid
|
|
if image_data is None:
|
|
return "Invalid image input. Please upload a valid image."
|
|
|
|
# Call the `image_query` function
|
|
response = await image_query(image_data, question)
|
|
return response["response"] if "response" in response else response["error"]
|
|
except Exception as e:
|
|
return str(e)
|
|
|
|
image_button.click(
|
|
fn=handle_image_query,
|
|
inputs=[image_prompter, image_question_input],
|
|
outputs=[image_output]
|
|
)
|
|
|
|
# Video Query Tab
|
|
with gr.Tab("Video Query"):
|
|
gr.Markdown("### Submit a Video Query")
|
|
with gr.Row():
|
|
video_input = gr.Video(label="Upload Video")
|
|
video_question_input = gr.Textbox(label="Your Question", placeholder="Type your question here...")
|
|
video_button = gr.Button("Submit")
|
|
video_output = gr.Textbox(label="Response", interactive=False)
|
|
|
|
async def handle_video_query(video, question):
|
|
response = await video_query(video, question)
|
|
return response.get("responses", response.get("error", "Error processing video."))
|
|
|
|
video_button.click(
|
|
fn=handle_video_query,
|
|
inputs=[video_input, video_question_input],
|
|
outputs=[video_output]
|
|
)
|
|
|
|
return ui
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the app, then serve it on all network
    # interfaces (0.0.0.0) at port 8002.
    ui = setup_ui()
    ui.launch(
        server_name="0.0.0.0",
        server_port=8002,
    )
|