Spaces:
Sleeping
Sleeping
| import re | |
def generate_reply(ctx, makePipeLine, user_msg):
    """Generate one bot reply to ``user_msg`` and record it in the history.

    Args:
        ctx: Conversation context. This function calls
            ``ctx.addHistory("bot", text)``; ``generate_valid_response``
            reads further values from it.
        makePipeLine: Model pipeline object, passed through unchanged to
            ``generate_valid_response``.
        user_msg: The user's latest message.

    Returns:
        The cleaned reply text, i.e. the same string that was appended to
        the history. (Previously ``None`` was returned; callers that ignore
        the return value are unaffected.)
    """
    response = generate_valid_response(ctx, makePipeLine, user_msg)
    ctx.addHistory("bot", response)
    # NOTE: a "generate a continuation if the reply looks truncated" step used
    # to follow here, disabled by wrapping it in a string literal. Per the
    # original (mis-encoded Korean) note, it was switched off because it
    # produced incomplete/unstable replies. A trailing unfinished sentence is
    # instead trimmed by clean_truncated_response(), which
    # generate_valid_response() reaches via clean_response().
    return response
def generate_valid_response(ctx, makePipeline, user_msg, max_retries: int = 5) -> str:
    """Generate one bot response, retrying while the model oversteps its turn.

    Args:
        ctx: Conversation context providing ``getUserName()``,
            ``getBotName()`` and ``getHistory()``.
        makePipeline: Object whose ``character_chat(prompt)`` returns the
            raw model output text.
        user_msg: The user's latest message.
        max_retries: Extra attempts allowed after the first one while the
            reply still contains the ``"<user_name>:"`` marker (i.e. the
            model started writing the user's next line). Negative values
            are treated as 0.

    Returns:
        The cleaned reply (see ``clean_response``). If every attempt was
        invalid, the last reply is cut just before its first
        ``"<user_name>:"`` marker before cleaning, so the result may be an
        empty string.
    """
    user_name = ctx.getUserName()
    bot_name = ctx.getBotName()

    # Nothing in this function changes the history, so the prompt (and the
    # template file read inside build_prompt) is built only once.
    prompt = build_prompt(ctx.getHistory(), user_msg, user_name, bot_name)
    print("\n==========[DEBUG: Prompt]==========")
    print(prompt)
    print("===================================\n")

    response = ""
    for _ in range(max(0, max_retries) + 1):
        full_text = makePipeline.character_chat(prompt)
        response = extract_response(full_text)
        if is_valid_response(response, user_name, bot_name):
            break
    else:
        # Why bounded: the previous `while True` loop never terminated if
        # the pipeline kept emitting the user marker (e.g. with deterministic
        # decoding every retry yields the same text). Salvage the part of
        # the last reply that precedes the user's turn.
        response = response.split(user_name + ":", 1)[0]
    return clean_response(response, bot_name)
def build_prompt(history, user_msg, user_name, bot_name,
                 max_turns: int = 16,
                 prompt_path: str = "assets/prompt/init.txt"):
    """Build the instruction-style prompt sent to the model.

    Resulting layout::

        ### Instruction:
        <system prompt read from prompt_path, stripped>
        <speaker>: <text>          (one line per recent history turn)
        <user_name>: <user_msg>

        ### Response:
        <bot_name>:

    The trailing "<bot_name>:" leaves the model to continue as the bot.

    Args:
        history: Slicable sequence of turn dicts with "role" and "text"
            keys. Role "user" is labelled ``user_name``; any other role is
            labelled ``bot_name``.
        user_msg: Latest user message, appended after the history.
            NOTE(review): presumably the caller has not already added it to
            ``history`` (it would then appear twice) -- confirm.
        user_name: Label used for user turns.
        bot_name: Label used for bot turns and the response primer.
        max_turns: Number of most-recent history turns to include;
            ``<= 0`` includes none. Defaults to the previous fixed 16.
        prompt_path: UTF-8 file holding the system prompt. A relative path
            resolves against the current working directory.

    Returns:
        The prompt string.
    """
    with open(prompt_path, "r", encoding="utf-8") as f:
        system_prompt = f.read().strip()

    # history[-0:] would return the WHOLE list, so the zero case is explicit.
    recent = history[-max_turns:] if max_turns > 0 else []

    # Rebuild the recent conversation as plain "Speaker: text" lines.
    lines = []
    for turn in recent:
        speaker = user_name if turn["role"] == "user" else bot_name
        lines.append(f"{speaker}: {turn['text']}")
    lines.append(f"{user_name}: {user_msg}")
    dialogue = "\n".join(lines) + "\n"

    # Compose the format the model expects.
    return (
        "### Instruction:\n"
        f"{system_prompt}\n"
        f"{dialogue}\n"
        "### Response:\n"
        f"{bot_name}:"
    )
def extract_response(full_text):
    """Return the model's answer from its raw output (HyperCLOVAX format).

    Takes everything after the LAST "### Response:" marker and strips
    surrounding whitespace. When the marker is absent, the whole output is
    returned, stripped.
    """
    # rpartition gives ("", "", full_text) when the marker is missing, so the
    # tail element covers both the "found" and "not found" cases.
    _head, _marker, tail = full_text.rpartition("### Response:")
    return tail.strip()
def is_valid_response(text: str, user_name, bot_name) -> bool:
    """Check that a generated reply stayed within the bot's own turn.

    The reply is rejected when it contains "<user_name>:", i.e. the model
    went on to write the user's next line. ``bot_name`` is currently unused.
    """
    user_marker = user_name + ":"
    return user_marker not in text
# Characters that mark a finished sentence for clean_truncated_response().
# Emoji outside the Basic Multilingual Plane must be written with \U escapes:
# Python str stores them as ONE code point, so UTF-16 surrogate escapes such
# as "\uD83D\uDE0A" (used previously) can never match a real emoji.
_SENTENCE_TERMINATORS = ".?!~\u2026\u2639\u263A\u2764\U0001F60A\U0001F622"
_TERMINATOR_TUPLE = tuple(_SENTENCE_TERMINATORS)
# VARIATION SELECTOR-16, which commonly follows symbols such as U+2764.
_EMOJI_VARIATION_SELECTOR = "\ufe0f"
# Split after a sentence terminator (or a trailing VS16) followed by one
# whitespace character, or at any newline. The split set matches the
# completeness check below, so e.g. an ellipsis also ends a sentence.
_TERMINATOR_CLASS = re.escape(_SENTENCE_TERMINATORS + _EMOJI_VARIATION_SELECTOR)
_SENTENCE_SPLIT_RE = re.compile(rf"(?<=[{_TERMINATOR_CLASS}])\s|\n")


def clean_response(text: str, bot_name: str) -> str:
    """Clean a generated reply for display.

    Removes every "<bot_name>:" label (plus the whitespace after it)
    wherever it appears, strips the result, then drops an unfinished
    trailing sentence via ``clean_truncated_response``.

    ``bot_name`` is matched literally (``re.escape``). Names containing
    regex metacharacters such as "(", "+", "." or "[" therefore match
    exactly and cannot raise ``re.error``.
    """
    label_re = re.compile(re.escape(bot_name) + r":\s*")
    text = label_re.sub("", text).strip()
    return clean_truncated_response(text)


def clean_truncated_response(text: str) -> str:
    """Remove an unfinished trailing sentence from ``text``.

    The text is split into sentences after a sentence-final character
    followed by whitespace, and at newlines. Sentences are kept from the
    start up to, but not including, the first one that does not end in a
    sentence-final character. That sentence and everything after it is
    discarded. Kept sentences are joined with single spaces, so newlines
    become spaces.

    Sentence-final characters: ". ? ! ~", U+2026 (ellipsis), U+2639,
    U+263A, U+2764, U+1F60A and U+1F622. A trailing U+FE0F variation
    selector is ignored when checking, so a heart emoji written with it
    still counts.

    If not even the first sentence is complete, the stripped original text
    is returned rather than an empty string.
    """
    stripped = text.strip()
    kept = []
    for segment in _SENTENCE_SPLIT_RE.split(stripped):
        segment = segment.strip()
        if not segment:
            continue
        # Ignore an emoji presentation selector when looking at the last char.
        tail = segment.rstrip(_EMOJI_VARIATION_SELECTOR)
        if not tail.endswith(_TERMINATOR_TUPLE):
            break  # unfinished sentence: drop it and everything after it
        kept.append(segment)
    result = " ".join(kept)
    return result if result else stripped