"""Gradio web UI for the PixAI Tagger v0.9 image-tagging model.

Uploaded images are validated (file size, format, dimensions, pixel count)
before being passed to ``handler.EndpointHandler``. The resulting tags are
optionally translated with ``translator.translate_texts`` and rendered as
clickable HTML lists plus a copyable plain-text summary.
"""

import json
import os
import shutil
import tempfile
import time
import traceback
import warnings
from html import escape
from pathlib import Path
from typing import Optional

import gradio as gr
from huggingface_hub import snapshot_download
from PIL import Image, ImageFile

from handler import EndpointHandler
from translator import translate_texts

# ------------------------------------------------------------------
# Security configuration
# ------------------------------------------------------------------
# 1) Cap the raw upload size to reject disguised files, images with appended
#    payloads, or high-entropy noise that inflates the file size.
MAX_UPLOAD_BYTES = 8 * 1024 * 1024  # 8 MB
# 2) Cap each side length to reject abnormal resolutions.
MAX_IMAGE_SIDE = 4096
# 3) Cap the total pixel count to defend against "pixel bombs" and excessive
#    memory use after decoding.
MAX_IMAGE_PIXELS = 20_000_000  # 20 megapixels
# 4) Cap the estimated memory footprint of the decoded image.
MAX_DECOMPRESSED_BYTES = 160 * 1024 * 1024  # 160 MB
# 5) Only accept common image formats.
ALLOWED_IMAGE_FORMATS = {"PNG", "JPEG", "WEBP", "BMP", "GIF"}

# Pillow hardening. Pillow emits DecompressionBombWarning above
# MAX_IMAGE_PIXELS (and raises DecompressionBombError above twice that); the
# filter below promotes the warning to an exception so it can be caught.
Image.MAX_IMAGE_PIXELS = MAX_IMAGE_PIXELS
ImageFile.LOAD_TRUNCATED_IMAGES = False
warnings.simplefilter("error", Image.DecompressionBombWarning)

# Both decompression-bomb signals Pillow can produce while opening/decoding.
_BOMB_EXCEPTIONS = (Image.DecompressionBombWarning, Image.DecompressionBombError)


class ImageValidationError(ValueError):
    """Raised when an uploaded image fails validation."""


def _format_size(num_bytes: int) -> str:
    """Format a byte count as a human-readable B / KB / MB string."""
    if num_bytes < 1024:
        return f"{num_bytes} B"
    if num_bytes < 1024 * 1024:
        return f"{num_bytes / 1024:.2f} KB"
    return f"{num_bytes / (1024 * 1024):.2f} MB"


def validate_and_open_image(image_path: str) -> Image.Image:
    """Safely open a user-uploaded image and return a detached RGB copy.

    Checks, in order: the path exists and is a file, the raw file size, the
    image format, width/height and total pixel count, and the estimated
    decoded size. Pillow decompression-bomb warnings/errors become rejections.

    Args:
        image_path: Local filesystem path of the uploaded file.

    Returns:
        A fully loaded ``PIL.Image.Image`` in RGB mode, independent of the
        source file (the file handle is closed before returning).

    Raises:
        ImageValidationError: If any check fails or decoding fails.
    """
    if not image_path:
        raise ImageValidationError("未检测到上传文件。")

    if not os.path.isfile(image_path):
        raise ImageValidationError("上传文件不存在或无法访问。")

    file_size = os.path.getsize(image_path)
    if file_size <= 0:
        raise ImageValidationError("上传文件为空。")
    if file_size > MAX_UPLOAD_BYTES:
        raise ImageValidationError(
            f"图片文件过大:{_format_size(file_size)},超过限制 {_format_size(MAX_UPLOAD_BYTES)}。"
        )

    # Pass 1: read the header and verify integrity without decoding pixels.
    try:
        with Image.open(image_path) as probe:
            img_format = (probe.format or "").upper()
            width, height = probe.size
            probe.verify()
    except _BOMB_EXCEPTIONS as e:
        raise ImageValidationError("图片疑似像素炸弹,已被拒绝处理。") from e
    except Exception as e:
        raise ImageValidationError(f"无法解析为有效图片文件:{e}") from e

    if img_format not in ALLOWED_IMAGE_FORMATS:
        raise ImageValidationError(
            f"不支持的图片格式:{img_format or '未知'}。仅允许:{', '.join(sorted(ALLOWED_IMAGE_FORMATS))}。"
        )

    if width <= 0 or height <= 0:
        raise ImageValidationError("图片尺寸非法。")

    if width > MAX_IMAGE_SIDE or height > MAX_IMAGE_SIDE:
        raise ImageValidationError(
            f"图片尺寸过大:{width}×{height},单边不得超过 {MAX_IMAGE_SIDE} 像素。"
        )

    total_pixels = width * height
    if total_pixels > MAX_IMAGE_PIXELS:
        raise ImageValidationError(
            f"图片总像素过大:{total_pixels:,},超过限制 {MAX_IMAGE_PIXELS:,}。"
        )

    # Estimate assumes 3 bytes/pixel (the RGB output). NOTE: with the current
    # limits (20 MP x 3 B = 60 MB < 160 MB) this check cannot trigger; it only
    # matters if MAX_IMAGE_PIXELS is raised.
    estimated_decompressed_bytes = total_pixels * 3
    if estimated_decompressed_bytes > MAX_DECOMPRESSED_BYTES:
        raise ImageValidationError(
            "图片解码后的内存占用过高,已拒绝处理。"
            f" 预计占用约 {_format_size(estimated_decompressed_bytes)},"
            f"超过限制 {_format_size(MAX_DECOMPRESSED_BYTES)}。"
        )

    # Pass 2: verify() leaves the image object unusable, so reopen to decode.
    try:
        with Image.open(image_path) as source:
            source.load()
            # copy()/convert() both return a new image that survives the
            # closing of ``source`` at the end of the with-block.
            img = source.convert("RGB") if source.mode != "RGB" else source.copy()
    except _BOMB_EXCEPTIONS as e:
        raise ImageValidationError("图片在解码阶段触发像素炸弹保护,已拒绝处理。") from e
    except Exception as e:
        raise ImageValidationError(f"图片加载失败:{e}") from e

    return img


# ------------------------------------------------------------------
# PixAI Tagger v0.9 model configuration
# ------------------------------------------------------------------
ASSETS_REPO_ID = os.environ.get("ASSETS_REPO_ID", "pixai-labs/pixai-tagger-v0.9")
ASSETS_REVISION = os.environ.get("ASSETS_REVISION")
MODEL_DIR = os.environ.get("MODEL_DIR", "./assets")

HF_TOKEN = (
    os.environ.get("HUGGINGFACE_HUB_TOKEN")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

REQUIRED_FILES = [
    "model_v0.9.pth",
    "tags_v0.9_13k.json",
    "char_ip_map.json",
]


def _atomic_copy(src: Path, dst: Path) -> None:
    """Copy ``src`` to ``dst`` through a temp file in ``dst``'s directory.

    The final ``os.replace`` is a rename within one directory, so an
    interrupted copy never leaves a truncated ``dst`` behind that a later
    ``exists()`` check would mistake for a complete asset.
    """
    fd, tmp_name = tempfile.mkstemp(dir=dst.parent, prefix=f".{dst.name}.", suffix=".tmp")
    os.close(fd)
    try:
        shutil.copyfile(src, tmp_name)
        os.replace(tmp_name, dst)
    except BaseException:
        try:
            os.unlink(tmp_name)
        except OSError:
            pass
        raise


def ensure_assets(repo_id: str, revision: Optional[str], target_dir: str) -> None:
    """Download the model assets and place them where the handler expects them.

    Nothing is downloaded when every file in ``REQUIRED_FILES`` already exists
    in ``target_dir``. Otherwise the files are fetched with
    ``snapshot_download`` and all of them are copied into ``target_dir``.

    Args:
        repo_id: Hugging Face Hub repository id.
        revision: Optional branch/tag/commit; ``None`` uses the default.
        target_dir: Local directory the handler loads from (created if needed).

    Raises:
        FileNotFoundError: If a required file is absent from the snapshot.
    """
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)

    missing = [fname for fname in REQUIRED_FILES if not (target / fname).exists()]
    if not missing:
        return

    snapshot_path = Path(
        snapshot_download(
            repo_id=repo_id,
            revision=revision,
            allow_patterns=REQUIRED_FILES,
            token=HF_TOKEN,
        )
    )

    for fname in REQUIRED_FILES:
        src = snapshot_path / fname
        dst = target / fname
        if not src.exists():
            raise FileNotFoundError(
                f"模型资源缺失:'{fname}' 未在 {repo_id} @ {revision or 'default'} 中找到。"
            )
        # Skip when target_dir already *is* the snapshot location.
        if src.resolve() != dst.resolve():
            _atomic_copy(src, dst)


# ------------------------------------------------------------------
# Tagger: thin wrapper around the v0.9 EndpointHandler
# ------------------------------------------------------------------
class Tagger:
    """Loads the PixAI Tagger assets and turns handler output into UI dicts."""

    def __init__(self):
        self.handler = None
        self.device = "unknown"
        self._load_model_and_labels()

    def _load_model_and_labels(self) -> None:
        """Ensure assets are present and construct the EndpointHandler.

        Raises:
            RuntimeError: If downloading or handler construction fails.
        """
        try:
            ensure_assets(ASSETS_REPO_ID, ASSETS_REVISION, MODEL_DIR)
            self.handler = EndpointHandler(MODEL_DIR)
            self.device = getattr(self.handler, "device", "unknown")
            print(f"✅ PixAI Tagger v0.9 加载成功,设备:{str(self.device).upper()}")
        except Exception as e:
            print(f"❌ PixAI Tagger v0.9 加载失败: {e}")
            raise RuntimeError(f"模型初始化失败: {e}") from e

    @staticmethod
    def _display_tag(tag: str) -> str:
        """Return the tag with underscores replaced by spaces for display."""
        return str(tag).replace("_", " ")

    @staticmethod
    def _get_score(scores: dict, tag: str) -> float:
        """Look up ``tag``'s score, tolerating underscore/space key variants.

        Returns 0.0 when ``scores`` is not a dict, no variant of the key is
        present, or the stored value cannot be converted to float.
        """
        if not isinstance(scores, dict):
            return 0.0

        candidates = [
            tag,
            str(tag).replace("_", " "),
            str(tag).replace(" ", "_"),
        ]
        for key in candidates:
            if key in scores:
                try:
                    return float(scores[key])
                except Exception:
                    return 0.0
        return 0.0

    def predict(self, img: Image.Image, gen_th: float = 0.30, char_th: float = 0.85):
        """Run the handler on ``img`` and reshape its output for the UI.

        Args:
            img: Decoded RGB image.
            gen_th: Threshold for general/feature tags.
            char_th: Threshold for character tags.

        Returns:
            ``(res, tag_categories, meta)`` where ``res`` maps ``"general"`` and
            ``"characters"`` to ``{display_tag: score}`` (sorted by score,
            descending) and ``"ips"`` to ``{display_tag: None}`` (the model
            returns no IP scores); ``tag_categories`` maps the same keys to the
            display-tag lists in that order; ``meta`` holds device, latency and
            the handler's ``_params``/``_timings``.

        Raises:
            RuntimeError: If the model is not loaded.
            ValueError: If ``img`` is None.
        """
        if self.handler is None:
            raise RuntimeError("模型未成功加载,无法进行预测。")
        if img is None:
            raise ValueError("输入图像不能为空。")

        params = {
            "general_threshold": float(gen_th),
            "character_threshold": float(char_th),
            "mode": "threshold",
            "topk_general": 25,
            "topk_character": 10,
            "include_scores": True,
        }
        data = {
            "inputs": img,
            "parameters": params,
        }

        # perf_counter is monotonic, unlike time.time(), so it suits latency.
        started = time.perf_counter()
        out = self.handler(data)
        latency = round(time.perf_counter() - started, 4)

        feature_tags = out.get("feature", []) or []
        character_tags = out.get("character", []) or []
        ip_tags = out.get("ip", []) or []

        feature_scores = out.get("feature_scores", {}) or {}
        character_scores = out.get("character_scores", {}) or {}

        general = {
            self._display_tag(tag): self._get_score(feature_scores, tag)
            for tag in feature_tags
        }
        characters = {
            self._display_tag(tag): self._get_score(character_scores, tag)
            for tag in character_tags
        }
        # IP tags carry no score; None tells the renderer not to show one.
        ips = {self._display_tag(tag): None for tag in ip_tags}

        general = dict(sorted(general.items(), key=lambda kv: kv[1], reverse=True))
        characters = dict(sorted(characters.items(), key=lambda kv: kv[1], reverse=True))

        res = {
            "general": general,
            "characters": characters,
            "ips": ips,
        }

        tag_categories_for_translation = {
            "general": list(general.keys()),
            "characters": list(characters.keys()),
            "ips": list(ips.keys()),
        }

        raw_meta = {
            "device": str(self.device),
            "latency_s_total": latency,
            "_params": out.get("_params", params),
            "_timings": out.get("_timings", {}),
        }

        return res, tag_categories_for_translation, raw_meta


# Global Tagger instance (None if initialisation failed; the UI reports it).
try:
    tagger_instance = Tagger()
except RuntimeError as e:
    print(f"应用启动时 Tagger 初始化失败: {e}")
    tagger_instance = None

DEVICE_LABEL = (
    f"设备:{str(tagger_instance.device).upper()}"
    if tagger_instance is not None
    else "设备:UNKNOWN"
)

# ------------------------------------------------------------------
# Gradio UI assets
# ------------------------------------------------------------------
custom_css = """
.label-container {
    max-height: 300px;
    overflow-y: auto;
    border: 1px solid #ddd;
    padding: 10px;
    border-radius: 5px;
    background-color: #f9f9f9;
}
.tag-item {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin: 2px 0;
    padding: 2px 5px;
    border-radius: 3px;
    background-color: #fff;
    transition: background-color 0.2s;
}
.tag-item:hover {
    background-color: #f0f0f0;
}
.tag-en {
    font-weight: bold;
    color: #333;
    cursor: pointer;
}
.tag-zh {
    color: #666;
    margin-left: 10px;
}
.tag-score {
    color: #999;
    font-size: 0.9em;
    white-space: nowrap;
}
.btn-analyze-container {
    margin-top: 15px;
    margin-bottom: 15px;
}
"""

# Gradio evaluates the ``js`` string as a function and calls it once on page
# load. A plain ``function copyToClipboard(...)`` declaration would therefore
# only be invoked once (with no argument) and never become a global, so the
# inline onclick handlers could not find it. Register it on ``window`` instead.
_js_functions = """
() => {
    function showFeedback(message, backgroundColor, visibleMs) {
        const box = document.createElement('div');
        box.textContent = message;
        box.style.position = 'fixed';
        box.style.bottom = '20px';
        box.style.left = '50%';
        box.style.transform = 'translateX(-50%)';
        box.style.backgroundColor = backgroundColor;
        box.style.color = 'white';
        box.style.padding = '10px 20px';
        box.style.borderRadius = '5px';
        box.style.zIndex = '10000';
        box.style.transition = 'opacity 0.5s ease-out';
        document.body.appendChild(box);
        setTimeout(() => {
            box.style.opacity = '0';
            setTimeout(() => {
                if (document.body.contains(box)) {
                    document.body.removeChild(box);
                }
            }, 500);
        }, visibleMs);
    }

    // Fallback for non-secure origins (plain HTTP), where navigator.clipboard
    // is undefined.
    function legacyCopy(value) {
        const area = document.createElement('textarea');
        area.value = value;
        area.setAttribute('readonly', '');
        area.style.position = 'fixed';
        area.style.opacity = '0';
        document.body.appendChild(area);
        area.select();
        try {
            return document.execCommand('copy');
        } catch (err) {
            return false;
        } finally {
            document.body.removeChild(area);
        }
    }

    function copyToClipboard(text) {
        if (typeof text === 'undefined' || text === null) {
            console.warn('copyToClipboard was called with undefined or null text. Aborting this specific copy operation.');
            return;
        }
        const value = String(text);
        const onSuccess = () => {
            const preview = value.substring(0, 30) + (value.length > 30 ? '...' : '');
            showFeedback('已复制: ' + preview, '#4CAF50', 1500);
        };
        const onFailure = (err) => {
            console.error('Failed to copy tag. Error:', err, 'Attempted to copy text:', text);
            showFeedback('复制操作失败!', '#D32F2F', 2500);
        };
        if (navigator.clipboard && navigator.clipboard.writeText) {
            navigator.clipboard.writeText(value).then(onSuccess).catch(onFailure);
        } else if (legacyCopy(value)) {
            onSuccess();
        } else {
            onFailure(new Error('Clipboard API unavailable'));
        }
    }

    window.copyToClipboard = copyToClipboard;
}
"""

# Result categories, in display/translation order.
_CATEGORY_KEYS = ("general", "characters", "ips")

_SEPARATORS = {"逗号": ", ", "换行": "\n", "空格": " "}


# ------------------------------------------------------------------
# UI helper functions
# ------------------------------------------------------------------
def _message_html(text: str, color: str = "#999") -> str:
    """Wrap a short status message in the tag-list container markup."""
    return (
        "<div class='label-container'>"
        f"<p style='color: {color}; text-align: center; margin: 10px 0;'>{escape(text)}</p>"
        "</div>"
    )


def format_tags_html(tags_dict, translations_list, category_name, show_scores=True, show_translation_in_list=True):
    """Render ``{tag: score}`` as a clickable HTML tag list.

    Args:
        tags_dict: Ordered mapping of display tag -> score (or None for no score).
        translations_list: Translations aligned by position with ``tags_dict``;
            any non-list value is treated as "no translations".
        category_name: Category key, emitted as a ``data-category`` attribute.
        show_scores: Show numeric scores (None scores are never shown).
        show_translation_in_list: Show the translation next to each tag.

    Returns:
        An HTML string; clicking a tag calls the page's ``copyToClipboard``.
    """
    if not tags_dict:
        return _message_html("暂无标签")

    if not isinstance(translations_list, list):
        translations_list = []

    parts = [f'<div class="label-container" data-category="{escape(str(category_name))}">']
    for i, (tag, score) in enumerate(tags_dict.items()):
        tag_text = str(tag)
        # json.dumps yields a valid JS string literal; escape() then makes it
        # safe inside the double-quoted HTML attribute (the browser unescapes
        # entities before running the handler).
        js_arg = escape(json.dumps(tag_text, ensure_ascii=False), quote=True)
        tag_display_html = (
            f'<span class="tag-en" onclick="copyToClipboard({js_arg})">{escape(tag_text)}</span>'
        )
        translation = translations_list[i] if i < len(translations_list) else None
        if show_translation_in_list and translation:
            tag_display_html += f'<span class="tag-zh">({escape(str(translation))})</span>'

        parts.append('<div class="tag-item">')
        parts.append(f"<div>{tag_display_html}</div>")
        if show_scores and isinstance(score, (int, float)):
            parts.append(f'<span class="tag-score">{score:.3f}</span>')
        parts.append("</div>")
    parts.append("</div>")
    return "".join(parts)


def generate_summary_text_content(
    current_res,
    current_translations_dict,
    s_gen,
    s_char,
    s_ip,
    s_sep_type,
    s_show_zh,
):
    """Build the plain-text tag summary for the selected categories.

    Tags within a category are joined by the chosen separator. Categories are
    separated by a blank line, except with the newline separator, where a
    single newline keeps one tag per line. With ``s_show_zh`` each tag with a
    translation is written as ``tag/*译文*/``.

    Returns:
        The summary text, or a Chinese hint message when nothing can be shown.
    """
    if not current_res:
        return "请先分析图像或选择要汇总的标签类别。"

    separator = _SEPARATORS.get(s_sep_type, ", ")
    selected = [
        cat_key
        for cat_key, enabled in (("general", s_gen), ("characters", s_char), ("ips", s_ip))
        if enabled
    ]
    if not selected:
        return "请至少选择一个标签类别进行汇总。"

    translations_by_cat = current_translations_dict or {}
    summary_parts = []
    for cat_key in selected:
        cat_tags_en = list((current_res.get(cat_key) or {}).keys())
        if not cat_tags_en:
            continue
        cat_translations = translations_by_cat.get(cat_key, []) or []
        entries = []
        for i, en_tag in enumerate(cat_tags_en):
            zh = cat_translations[i] if s_show_zh and i < len(cat_translations) else None
            entries.append(f"{en_tag}/*{zh}*/" if zh else en_tag)
        summary_parts.append(separator.join(entries))

    joiner = "\n" if separator == "\n" else "\n\n"
    final_summary = joiner.join(summary_parts)
    return final_summary if final_summary else "选定的类别中没有找到标签。"


def _translate_by_category(tag_categories: dict) -> dict:
    """Translate all tags in one batch and split the result per category.

    Falls back to empty translations when the translator raises or returns a
    result whose length does not match the input (positions would otherwise
    be misaligned across categories).
    """
    all_tags = [tag for key in _CATEGORY_KEYS for tag in tag_categories.get(key, [])]

    flat = []
    if all_tags:
        try:
            flat = list(translate_texts(all_tags, src_lang="auto", tgt_lang="zh") or [])
        except Exception as translate_error:
            print(f"⚠️ 标签翻译失败,将仅显示英文标签:{translate_error}")
            flat = []
        if len(flat) != len(all_tags):
            if flat:
                print(f"⚠️ 翻译结果数量 ({len(flat)}) 与标签数量 ({len(all_tags)}) 不一致,将仅显示英文标签。")
            flat = [""] * len(all_tags)

    translations = {}
    offset = 0
    for key in _CATEGORY_KEYS:
        count = len(tag_categories.get(key, []))
        translations[key] = flat[offset: offset + count]
        offset += count
    return translations


def _idle_button():
    """Update that re-enables the analyze button with its default label."""
    return gr.update(interactive=True, value="🚀 开始分析")


def _failure_outputs(info_message, tag_html, summary=""):
    """Build the 10-tuple of outputs for an aborted or failed analysis run.

    All three tag panels get ``tag_html`` and all four state/meta outputs are
    reset to empty dicts.
    """
    return (
        _idle_button(),
        gr.update(visible=True, value=info_message),
        tag_html,
        tag_html,
        tag_html,
        summary,
        {},
        {},
        {},
        {},
    )


def process_image_and_generate_outputs(
    image_path,
    g_th,
    c_th,
    s_scores,
    s_gen,
    s_char,
    s_ip,
    s_sep,
    s_zh_in_sum,
):
    """Gradio generator handler for the analyze button.

    Yields a "processing" state first, then either the rendered results or an
    error state. Each yield is a 10-tuple matching ``btn.click``'s outputs:
    button, info, general/character/IP HTML, summary, result state,
    translation state, tag-category state, metadata.
    """
    if image_path is None:
        yield _failure_outputs("❌ 请先上传图片。", "")
        return

    if tagger_instance is None:
        yield _failure_outputs("❌ 分析器未成功初始化,请检查控制台错误。", "")
        return

    loading_html = _message_html("分析中...")
    yield (
        gr.update(interactive=False, value="🔄 处理中..."),
        gr.update(visible=True, value="🔄 正在校验并分析图像,请稍候..."),
        loading_html,
        loading_html,
        loading_html,
        gr.update(value="分析中,请稍候..."),
        {},
        {},
        {},
        {},
    )

    try:
        img = validate_and_open_image(image_path)
        res, tag_categories_original_order, meta = tagger_instance.predict(img, g_th, c_th)
        current_translations_dict = _translate_by_category(tag_categories_original_order)

        general_html, char_html, ip_html = (
            format_tags_html(
                res.get(key, {}),
                current_translations_dict.get(key, []),
                key,
                s_scores,
                True,
            )
            for key in _CATEGORY_KEYS
        )

        summary_text = generate_summary_text_content(
            res,
            current_translations_dict,
            s_gen,
            s_char,
            s_ip,
            s_sep,
            s_zh_in_sum,
        )
    except ImageValidationError as e:
        yield _failure_outputs(
            f"❌ 上传图片未通过安全校验:{e}",
            _message_html("图片已被安全策略拒绝", color="#D32F2F"),
            gr.update(value=f"错误: {e}", placeholder="上传图片未通过安全校验..."),
        )
        return
    except Exception as e:
        print(f"处理时发生错误: {e}\n{traceback.format_exc()}")
        yield _failure_outputs(
            f"❌ 处理失败: {e}",
            _message_html("处理出错", color="#D32F2F"),
            gr.update(value=f"错误: {e}", placeholder="分析失败..."),
        )
        return

    yield (
        _idle_button(),
        gr.update(visible=True, value="✅ 分析完成!"),
        general_html,
        char_html,
        ip_html,
        gr.update(value=summary_text),
        res,
        current_translations_dict,
        tag_categories_original_order,
        meta,
    )


def update_summary_display(
    s_gen,
    s_char,
    s_ip,
    s_sep,
    s_zh_in_sum,
    current_res_from_state,
    current_translations_from_state,
):
    """Rebuild the summary textbox when a summary control changes."""
    if not current_res_from_state:
        return gr.update(placeholder="请先完成一次图像分析以生成汇总。", value="")

    new_summary_text = generate_summary_text_content(
        current_res_from_state,
        current_translations_from_state,
        s_gen,
        s_char,
        s_ip,
        s_sep,
        s_zh_in_sum,
    )
    return gr.update(value=new_summary_text)


# ------------------------------------------------------------------
# Gradio layout and event wiring
# ------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="AI 图像标签分析器", css=custom_css, js=_js_functions) as demo:
    gr.Markdown("# 🖼️ AI 图像标签分析器")
    gr.Markdown(
        "上传图片自动识别标签,支持中英文显示和一键复制。"
        "[NovelAI在线绘画](https://nai.idlecloud.cc/)\n\n"
        f"**当前模型:pixai-labs/pixai-tagger-v0.9** | **{DEVICE_LABEL}**\n\n"
        "说明:新版模型不再返回评分标签,本页面已将原“评分标签”区域改为“IP 标签”。"
    )

    state_res = gr.State({})
    state_translations_dict = gr.State({})
    state_tag_categories_for_translation = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            img_in = gr.Image(type="filepath", label="上传图片", height=300)
            btn = gr.Button("🚀 开始分析", variant="primary", elem_classes=["btn-analyze-container"])

            with gr.Accordion("⚙️ 高级设置", open=False):
                gen_slider = gr.Slider(
                    0,
                    1,
                    value=0.30,
                    step=0.01,
                    label="通用标签阈值",
                    info="越高 → 标签更少更准",
                )
                char_slider = gr.Slider(
                    0,
                    1,
                    value=0.85,
                    step=0.01,
                    label="角色标签阈值",
                    info="推荐保持较高阈值",
                )
                show_tag_scores = gr.Checkbox(
                    True,
                    label="在列表中显示标签置信度",
                    info="IP 标签不返回置信度,因此不会显示分数。",
                )

            with gr.Accordion("📊 标签汇总设置", open=True):
                gr.Markdown("选择要包含在下方汇总文本框中的标签类别:")
                with gr.Row():
                    sum_general = gr.Checkbox(True, label="通用标签", min_width=50)
                    sum_char = gr.Checkbox(True, label="角色标签", min_width=50)
                    sum_ip = gr.Checkbox(False, label="IP 标签", min_width=50)

                sum_sep = gr.Dropdown(["逗号", "换行", "空格"], value="逗号", label="标签之间的分隔符")
                sum_show_zh = gr.Checkbox(False, label="在汇总中显示中文翻译")

            processing_info = gr.Markdown("", visible=False)

        with gr.Column(scale=2):
            with gr.Tabs():
                with gr.TabItem("🏷️ 通用标签"):
                    out_general = gr.HTML(label="General Tags")
                with gr.TabItem("👤 角色标签"):
                    gr.Markdown(
                        "<p style='color: #666; font-size: 0.9em;'>"
                        "提示:角色标签由模型推断,建议保持较高阈值。"
                        "</p>"
                    )
                    out_char = gr.HTML(label="Character Tags")
                with gr.TabItem("🌐 IP 标签"):
                    gr.Markdown(
                        "<p style='color: #666; font-size: 0.9em;'>"
                        "提示:新版模型输出 IP 标签,但不返回评分标签/评分置信度。"
                        "</p>"
                    )
                    out_ip = gr.HTML(label="IP Tags")

            gr.Markdown("### 标签汇总结果")
            out_summary = gr.Textbox(
                label="标签汇总",
                placeholder="分析完成后,此处将显示汇总的英文标签...",
                lines=5,
                show_copy_button=True,
            )

            with gr.Accordion("🧾 推理元数据", open=False):
                out_meta = gr.JSON(label="Metadata")

    btn.click(
        process_image_and_generate_outputs,
        inputs=[
            img_in,
            gen_slider,
            char_slider,
            show_tag_scores,
            sum_general,
            sum_char,
            sum_ip,
            sum_sep,
            sum_show_zh,
        ],
        outputs=[
            btn,
            processing_info,
            out_general,
            out_char,
            out_ip,
            out_summary,
            state_res,
            state_translations_dict,
            state_tag_categories_for_translation,
            out_meta,
        ],
    )

    summary_controls = [sum_general, sum_char, sum_ip, sum_sep, sum_show_zh]
    for ctrl in summary_controls:
        ctrl.change(
            fn=update_summary_display,
            inputs=summary_controls + [state_res, state_translations_dict],
            outputs=[out_summary],
        )


if __name__ == "__main__":
    if tagger_instance is None:
        print("CRITICAL: Tagger 未能初始化,应用功能将受限。请检查之前的错误信息。")
    demo.queue(max_size=8).launch(server_name="0.0.0.0", server_port=7860)