| |
|
|
| import gradio as gr |
| import cv2 |
| import numpy as np |
| import os |
| from datetime import datetime |
|
|
| from scenedetect import open_video, SceneManager |
| from scenedetect.detectors import ContentDetector |
| from moviepy.editor import VideoFileClip |
|
|
| import random |
| from functools import partial |
|
|
| import clip |
| import decord |
| import nncore |
| import torch |
| import torchvision.transforms.functional as F |
| from decord import VideoReader |
| from nncore.engine import load_checkpoint |
| from nncore.nn import build_model |
|
|
| import pandas as pd |
|
|
def convert_time(seconds):
    """Format a duration given in seconds as a zero-padded ``MM:SS`` string.

    Negative values are clamped to zero and the result is rounded to the
    nearest whole second. Minutes are not folded into hours, so long
    durations simply produce a minute field wider than two digits.
    """
    whole_seconds = round(max(seconds, 0))
    mins = whole_seconds // 60
    secs = whole_seconds % 60
    return '{:02d}:{:02d}'.format(mins, secs)
|
|
| |
# R2-Tuning (QVHighlights) model config and the matching pretrained checkpoint.
# The checkpoint may be a local path or an http(s) URL; init_tuning_model
# downloads URLs before loading.
TUNING_CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
TUNING_WEIGHT = ('https://huggingface.co/yeliudev/R2-Tuning/resolve/main/'
                 'checkpoints/r2_tuning_qvhighlights-ed516355.pth')
|
|
| |
def init_tuning_model(config, checkpoint):
    """Build the R2-Tuning model from a config file and load its weights.

    Args:
        config: path to an nncore config file.
        checkpoint: local checkpoint path, or an ``http``-prefixed URL that is
            first downloaded into the ``checkpoints`` directory.

    Returns:
        ``(model, cfg)``: the model in eval mode with weights loaded, and the
        parsed config object.
    """
    cfg = nncore.Config.from_file(config)
    # NOTE(review): presumably makes the builder initialise pretrained
    # sub-modules (e.g. the CLIP backbone); confirm against R2-Tuning's model code.
    cfg.model.init = True

    local_ckpt = (nncore.download(checkpoint, out_dir='checkpoints')
                  if checkpoint.startswith('http') else checkpoint)

    network = build_model(cfg.model, dist=False).eval()
    return load_checkpoint(network, local_ckpt, warning=False), cfg
|
|
# Build the model once at import time so every request handled by
# find_scenes reuses the same weights. When TUNING_WEIGHT is a URL,
# init_tuning_model downloads it first.
# NOTE(review): no explicit .to(device) here; calculate_saliency follows
# whatever device the model's parameters end up on.
tuning_model, tuning_cfg = init_tuning_model(TUNING_CONFIG, TUNING_WEIGHT)
|
|
| |
def preprocess_video(video_path, cfg):
    """Decode, sub-sample, crop, resize and normalise a video for the model.

    Frames are sampled at ``cfg.data.val.fps`` frames per second, centre
    cropped to a square, resized to 336 px (if ``'336px'`` is in
    ``cfg.model.arch``) or 224 px, and normalised per channel.

    Returns:
        A float tensor of shape ``(1, T, C * size * size)``: one flattened row
        per sampled frame, with a leading batch dimension of 1.
    """
    decord.bridge.set_bridge('torch')
    reader = decord.VideoReader(video_path)
    last_frame = len(reader) - 1

    # Step between sampled source frames so that the output runs at the
    # config's fps; indices are rounded and clamped to the last frame.
    step = reader.get_avg_fps() / cfg.data.val.fps
    indices = [min(round(t), last_frame) for t in np.arange(0, len(reader), step).tolist()]

    # The permute implies decord yields (T, H, W, C); move channels first and scale to [0, 1].
    frames = reader.get_batch(indices).permute(0, 3, 1, 2).float() / 255

    side = 336 if '336px' in cfg.model.arch else 224
    height, width = frames.size(-2), frames.size(-1)
    crop = min(height, width)
    top = round((height - crop) / 2)
    left = round((width - crop) / 2)
    frames = frames[..., top:top + crop, left:left + crop]

    frames = F.resize(frames, size=(side, side))
    # These look like CLIP's image mean/std (rounded); confirm if the backbone changes.
    frames = F.normalize(frames, (0.481, 0.459, 0.408), (0.269, 0.261, 0.276))

    flat = frames.reshape(frames.size(0), -1)
    return flat.unsqueeze(0)
|
|
| |
def calculate_saliency(video_path, query, model, cfg):
    """Score how relevant each sampled video frame is to a text query.

    Args:
        video_path: path to a video file readable by decord.
        query: free-text description of the content to look for. ``None``,
            ``""`` or a whitespace-only string count as "no query".
        model: R2-Tuning model (as returned by ``init_tuning_model``).
        cfg: the model's config; ``cfg.data.val.fps`` is the rate at which
            ``preprocess_video`` samples frames.

    Returns:
        ``(scores, times, duration)``:
        - ``scores``: numpy array of per-frame saliency, min-max rescaled
          into ``[0.05, 0.95]``. If every raw score is equal, every value
          is 0.5 instead of NaN.
        - ``times``: timestamp in seconds of each sampled frame.
        - ``duration``: video length in seconds.
        For an empty query the result is ``(None, None, 0)``.
    """
    # A whitespace-only query would otherwise reach CLIP as a meaningless prompt.
    if not query or (isinstance(query, str) and not query.strip()):
        return None, None, 0

    video = preprocess_video(video_path, cfg)
    tokens = clip.tokenize(query, truncate=True)

    device = next(model.parameters()).device
    data = dict(video=video.to(device), query=tokens.to(device), fps=[cfg.data.val.fps])

    with torch.inference_mode():
        pred = model(data)

    scores = pred['_out']['saliency'].float().cpu().numpy()
    lo, hi = scores.min(), scores.max()
    if hi > lo:
        scores = (scores - lo) / (hi - lo) * 0.9 + 0.05
    else:
        # Constant saliency: the min-max division would be 0/0 and yield NaNs.
        scores = np.full_like(scores, 0.5)

    # preprocess_video samples frame i at i / fps seconds, so derive the axis
    # from the config rather than hard-coding a fixed step.
    time_axis = np.arange(len(scores)) / cfg.data.val.fps

    vr = decord.VideoReader(video_path)
    duration = len(vr) / vr.get_avg_fps()
    return scores, time_axis, duration
|
|
| |
def _score_scenes(scene_list, new_time, interp_scores):
    """Give every detected scene its mean saliency and sort best-first.

    Args:
        scene_list: ``(start, end)`` timecode pairs as returned by
            ``SceneManager.get_scene_list()``.
        new_time: sorted, evenly spaced, non-empty sample times in seconds.
        interp_scores: saliency values at ``new_time``.

    Returns:
        List of dicts with ``start``/``end`` (``MM:SS``), ``score`` (rounded
        to 3 decimals) and raw ``start_sec``/``end_sec``, sorted by
        descending score.
    """
    scenes = []
    for start_tc, end_tc in scene_list:
        start_sec = start_tc.get_seconds()
        end_sec = end_tc.get_seconds()

        # Samples whose timestamp falls inside [start_sec, end_sec].
        start_idx = np.searchsorted(new_time, start_sec, side='left')
        end_idx = np.searchsorted(new_time, end_sec, side='right')
        window = interp_scores[start_idx:end_idx]
        window = window[~np.isnan(window)]

        if len(window) > 0:
            scene_score = float(window.mean())
        else:
            # Scene shorter than the sample spacing: read the curve at its midpoint
            # instead of scoring it 0.
            scene_score = float(np.interp((start_sec + end_sec) / 2, new_time, interp_scores))
            if np.isnan(scene_score):
                scene_score = 0.0

        scenes.append({
            "start": convert_time(start_sec),
            "end": convert_time(end_sec),
            "score": round(scene_score, 3),
            "start_sec": start_sec,
            "end_sec": end_sec,
        })

    scenes.sort(key=lambda s: s['score'], reverse=True)
    return scenes


def _export_shots(video_path, scenes, filename, output_dir):
    """Cut each scene to an MP4 and save its first frame as a PNG still.

    Returns:
        ``(rows, shots, stills)``:
        - ``rows``: per-shot timecode dicts for the JSON output.
        - ``shots``: paths of the written MP4 files.
        - ``stills``: ``(image_path, caption)`` pairs for the gallery.
    """
    rows, shots, stills = [], [], []
    capture = cv2.VideoCapture(video_path)
    try:
        # Open the source once and cut every subclip from it, instead of
        # re-opening the file for each scene.
        with VideoFileClip(video_path) as source_clip:
            for idx, scene in enumerate(scenes):
                shot_name = f"shot_{idx+1}_{filename}"
                target_name = os.path.join(output_dir, f"{shot_name}.mp4")

                # PySceneDetect's end time can slightly exceed moviepy's measured
                # duration; reading past the end is a known source of reader errors.
                end_sec = min(scene['end_sec'], source_clip.duration)
                subclip = source_clip.subclip(scene['start_sec'], end_sec)
                subclip.write_videofile(target_name,
                                        codec="libx264",
                                        audio_codec="aac",
                                        threads=4,
                                        preset="fast",
                                        ffmpeg_params=["-crf", "23"])

                capture.set(cv2.CAP_PROP_POS_MSEC, scene['start_sec'] * 1000)
                ok, frame = capture.read()
                if ok:
                    img_path = os.path.join(output_dir, f"{shot_name}_screenshot.png")
                    cv2.imwrite(img_path, frame)
                    stills.append((img_path, f'{shot_name}\nScore: {scene["score"]:.3f}'))
                else:
                    # A failed seek/read returns frame=None, which imwrite cannot save.
                    gr.Warning(f"Could not read a thumbnail frame for {shot_name}")

                rows.append({
                    "tc_in": scene['start'],
                    "tc_out": scene['end'],
                    "score": scene['score'],
                    "shot_name": shot_name,
                })
                shots.append(target_name)
    finally:
        capture.release()
    return rows, shots, stills


def find_scenes(video_path, threshold, query):
    """Detect scene cuts, rank scenes by relevance to a text query, export them.

    Args:
        video_path: path of the uploaded video.
        threshold: ``ContentDetector`` cut threshold.
        query: text describing the content to look for.

    Returns:
        ``(timecodes, shots, stills, plot_data)`` for the JSON, File, Gallery
        and LinePlot outputs, or four ``None`` values when no cut is detected.

    Raises:
        gr.Error: when no video is provided or the query is empty.
    """
    if not video_path:
        raise gr.Error("请上传视频文件")

    saliency_scores, time_points, total_duration = calculate_saliency(video_path, query, tuning_model, tuning_cfg)
    if saliency_scores is None:
        raise gr.Error("请输入有效的文本查询")

    # Resample the coarse per-frame saliency onto a ~10 samples/second grid.
    # Use at least two points so that very short clips still get a usable grid.
    new_time = np.linspace(0, total_duration, num=max(int(total_duration * 10), 2))
    interp_scores = np.interp(new_time, time_points, saliency_scores)

    filename = os.path.splitext(os.path.basename(video_path))[0]
    video = open_video(video_path)
    scene_manager = SceneManager()
    scene_manager.add_detector(ContentDetector(threshold=threshold))
    scene_manager.detect_scenes(video, show_progress=True)
    scene_list = scene_manager.get_scene_list()

    if not scene_list:
        gr.Warning("No scenes detected in this video")
        return None, None, None, None

    # Create the output folder only once there is something to write, so failed
    # runs leave no empty folders. Microseconds keep back-to-back runs apart.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
    output_dir = f"output_{timestamp}"
    os.makedirs(output_dir, exist_ok=True)

    scenes = _score_scenes(scene_list, new_time, interp_scores)
    rows, shots, stills = _export_shots(video_path, scenes, filename, output_dir)

    timecodes = [{"title": filename + ".mp4", "fps": scene_list[0][0].get_framerate()}]
    timecodes.extend(rows)

    plot_data = pd.DataFrame({
        'x': new_time,
        'y': interp_scores,
    })
    return timecodes, shots, stills, plot_data
|
|
| |
# Gradio UI: the inputs (video, query, cut threshold) feed find_scenes, whose
# four return values fill the JSON, File, Gallery and LinePlot outputs.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("""
        # 增强版场景编辑检测
        新增功能:
        1. 输入文本查询分析视频内容相关性
        2. 显示相关性时序折线图
        3. 按相关性得分排序输出片段:与查询文本较相关的片段排在 分割片段下载 和 场景缩略图 的前面若干个
        """)

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(sources="upload", format="mp4", label="视频输入")
                query_input = gr.Textbox(label="文本查询", placeholder="输入描述视频内容的文本(5-15单词为佳)")
                threshold = gr.Slider(label="场景切换检测阈值", minimum=15.0, maximum=40.0, value=27.0)
                with gr.Row():
                    clear_button = gr.Button("清除")
                    run_button = gr.Button("开始处理", variant="primary")
                plot_output = gr.LinePlot(x='x', y='y', x_title='时间(秒)',
                                          y_title='相关性得分', label='时序相关性分析')
            with gr.Column():
                json_output = gr.JSON(label="场景分析结果(按得分排序)")
                file_output = gr.File(label="分割片段下载")
                gallery_output = gr.Gallery(label="场景缩略图", object_fit="cover", columns=3)

        run_button.click(
            fn=find_scenes,
            inputs=[video_input, threshold, query_input],
            outputs=[json_output, file_output, gallery_output, plot_output],
        )
        # Reset every input to its default and clear every output, including
        # the relevance plot (previously left showing the last result).
        clear_button.click(
            fn=lambda: [None, 27, None, None, None, None, None],
            inputs=None,
            outputs=[video_input, threshold, query_input,
                     json_output, file_output, gallery_output, plot_output],
        )

        gr.Examples(
            examples=[
                ["anime_kiss.mp4", 27, "A romantic kiss scene between two characters"],
                ["anime_tear.mp4", 30, "An anime character is crying."],
                ["[P1]《原神》5.4版本PV:「梦间见月明」.mp4", 27, "A lady with a fan."],
            ],
            inputs=[video_input, threshold, query_input],
            outputs=[json_output, file_output, gallery_output, plot_output],
            fn=find_scenes,
            cache_examples=False,
        )


# Only start the server when run as a script, so importing this module
# (e.g. by `gradio` reload mode or tests) does not block on launch().
if __name__ == "__main__":
    demo.queue().launch(debug=True, share=True)