tenmenbot committed on
Commit
7eb64ea
·
verified ·
1 Parent(s): 5c389b4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import numpy as np
4
  import faiss
5
  from sentence_transformers import SentenceTransformer
6
- from transformers import pipeline
7
 
8
  # 記事フォルダ読み込み
9
  articles_dir = "articles"
@@ -31,8 +31,16 @@ for fname in os.listdir(articles_dir):
31
  index = faiss.IndexFlatL2(384)
32
  index.add(np.array(vectors))
33
 
34
- # 要約モデル(ken11/japanese-summary-model)
35
- summarizer = pipeline("summarization", model="ken11/japanese-summary-model")
 
 
 
 
 
 
 
 
36
 
37
  # チャットボット関数
38
  def chat(query):
@@ -42,10 +50,8 @@ def chat(query):
42
  retrieved_titles = [titles[i] for i in I[0]]
43
  retrieved_urls = [urls[i] for i in I[0]]
44
 
45
- context = "\n\n".join(retrieved_texts)[:1000] # BARTは長文に弱いので最大1000文字に制限
46
- prompt = f"{context}\n\n質問:{query}\nこの情報をもとに簡潔に回答してください。"
47
-
48
- summary = summarizer(prompt, max_length=128, min_length=30, do_sample=False)[0]["summary_text"]
49
 
50
  links = "\n".join([f"🔗 [{retrieved_titles[i]}]({retrieved_urls[i]})" for i in range(len(retrieved_titles))])
51
  return f"{summary}\n\n参考記事:\n{links}"
 
3
  import numpy as np
4
  import faiss
5
  from sentence_transformers import SentenceTransformer
6
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
7
 
8
  # 記事フォルダ読み込み
9
  articles_dir = "articles"
 
31
  index = faiss.IndexFlatL2(384)
32
  index.add(np.array(vectors))
33
 
34
+ # T5要約モデル
35
+ tokenizer = T5Tokenizer.from_pretrained("sonoisa/t5-base-japanese")
36
+ t5_model = T5ForConditionalGeneration.from_pretrained("sonoisa/t5-base-japanese")
37
+
38
def generate_summary(text):
    """Summarize Japanese text with the sonoisa/t5-base-japanese T5 model.

    The text is flattened to one line, prefixed with the T5
    "summarize: " task marker, truncated to at most 512 tokens,
    and summarized by deterministic (do_sample=False) generation
    into a 32-128 token output.

    NOTE(review): relies on module-level `tokenizer` and `t5_model`
    loaded elsewhere in this file.
    """
    # Newlines are collapsed so the task prefix and body form a single sequence.
    flattened = text.replace("\n", " ")
    encoded = tokenizer.encode(
        "summarize: " + flattened,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    generated = t5_model.generate(
        encoded,
        max_length=128,
        min_length=32,
        do_sample=False,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)
44
 
45
  # チャットボット関数
46
  def chat(query):
 
50
  retrieved_titles = [titles[i] for i in I[0]]
51
  retrieved_urls = [urls[i] for i in I[0]]
52
 
53
+ context = "\n\n".join(retrieved_texts)[:1000]
54
+ summary = generate_summary(context)
 
 
55
 
56
  links = "\n".join([f"🔗 [{retrieved_titles[i]}]({retrieved_urls[i]})" for i in range(len(retrieved_titles))])
57
  return f"{summary}\n\n参考記事:\n{links}"