fix: improve inline citation parsing (bump:patch)

This commit is contained in:
Tadashi 2024-11-26 20:52:03 +07:00
parent f3a2a293f2
commit f15abdbb23
No known key found for this signature in database
GPG Key ID: 399380A00CC9028D
2 changed files with 90 additions and 19 deletions

View File

@ -342,11 +342,11 @@ class AnswerWithContextPipeline(BaseComponent):
span_idx = span.get("idx", None)
if span_idx is not None:
to_highlight = f"{span_idx + 1}" + to_highlight
to_highlight = f"{span_idx}" + to_highlight
text += Render.highlight(
to_highlight,
elem_id=str(span_idx + 1) if span_idx is not None else None,
elem_id=str(span_idx) if span_idx is not None else None,
)
if idx < len(ss) - 1:
text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]

View File

@ -1,6 +1,7 @@
import re
import threading
from collections import defaultdict
from dataclasses import dataclass
from typing import Generator
import numpy as np
@ -8,7 +9,6 @@ import numpy as np
from kotaemon.base import AIMessage, Document, HumanMessage, SystemMessage
from kotaemon.llms import PromptTemplate
from .citation import CiteEvidence
from .citation_qa import CITATION_TIMEOUT, MAX_IMAGES, AnswerWithContextPipeline
from .format_context import EVIDENCE_MODE_FIGURE
from .utils import find_start_end_phrase
@ -61,12 +61,27 @@ START_PHRASE: Fixed-size Chunker This is our baseline chunker
END_PHRASE: this shows good retrieval quality.
FINAL ANSWER
An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient1. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance2.
An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient1. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance12.
QUESTION: {question}\n
ANSWER:
""" # noqa
# Sentinel headers emitted by the LLM: text after START_ANSWER is the answer
# body; text after START_CITATION is the trailing citation list.
START_ANSWER = "FINAL ANSWER"
START_CITATION = "CITATION LIST"
# Matches an inline citation marker such as "citation【3】" when applied to the
# lower-cased line; group 1 captures the citation index digits.
CITATION_PATTERN = r"citation【(\d+)】"
# Line prefixes (compared case-insensitively) introducing the start/end
# phrases of a quoted evidence span.
START_ANSWER_PATTERN = "start_phrase:"
END_ANSWER_PATTERN = "end_phrase:"
@dataclass
class InlineEvidence:
    """A single inline-citation evidence: one quoted answer span delimited
    by a start phrase and an end phrase, optionally tagged with the index
    of the citation it supports."""

    # First words of the quoted evidence span, as emitted by the LLM.
    start_phrase: str | None = None
    # Last words of the quoted evidence span.
    end_phrase: str | None = None
    # Citation index parsed from a "citation【N】" marker; None when the
    # model omitted the marker.
    idx: int | None = None
class AnswerWithInlineCitation(AnswerWithContextPipeline):
"""Answer the question based on the evidence with inline citation"""
@ -85,15 +100,54 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
return prompt, evidence
def answer_to_citations(self, answer) -> list[InlineEvidence]:
    """Parse the LLM's raw citation section into InlineEvidence records.

    The model emits blocks of the form::

        citation【N】
        start_phrase: <first words of the quoted span>
        end_phrase: <last words of the quoted span>

    A "citation【N】" line opens a new evidence (concluding any pending
    one); start/end phrase lines (matched case-insensitively) fill in the
    current evidence, opening an unindexed one if none is pending. An
    evidence is concluded as soon as both phrases are present, or when a
    new citation marker or the end of input is reached.

    Args:
        answer: raw text of the citation section produced by the LLM.

    Returns:
        Parsed evidences in order of appearance; ``idx`` is ``None`` when
        the model omitted the citation marker.
    """
    citations: list[InlineEvidence] = []
    lines = answer.split("\n")

    current_evidence: InlineEvidence | None = None

    for line in lines:
        # A lower-cased line starting with e.g. "citation【3】" opens a
        # new evidence record.
        match = re.match(CITATION_PATTERN, line.lower())
        if match:
            # group(1) matched \d+, so int() cannot raise ValueError here
            # (the previous unreachable try/except has been removed).
            parsed_citation_idx = int(match.group(1))

            # Conclude any half-filled evidence before starting a new one.
            if current_evidence:
                citations.append(current_evidence)
            current_evidence = InlineEvidence(idx=parsed_citation_idx)
        else:
            for keyword in [START_ANSWER_PATTERN, END_ANSWER_PATTERN]:
                if line.lower().startswith(keyword):
                    matched_phrase = line[len(keyword) :].strip()
                    # Smaller LLMs may emit phrase lines without a
                    # preceding citation marker; open an unindexed
                    # evidence on the fly.
                    if not current_evidence:
                        current_evidence = InlineEvidence(idx=None)

                    if keyword == START_ANSWER_PATTERN:
                        current_evidence.start_phrase = matched_phrase
                    else:
                        current_evidence.end_phrase = matched_phrase
                    break

        # As soon as both phrases are known the evidence is complete.
        if (
            current_evidence
            and current_evidence.end_phrase
            and current_evidence.start_phrase
        ):
            citations.append(current_evidence)
            current_evidence = None

    # Flush a trailing, possibly incomplete, evidence.
    if current_evidence:
        citations.append(current_evidence)

    return citations
def replace_citation_with_link(self, answer: str):
# Define the regex pattern to match 【number】
@ -114,6 +168,8 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
),
)
answer = answer.replace(START_CITATION, "")
return answer
def stream( # type: ignore
@ -178,8 +234,6 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
# append main prompt
messages.append(HumanMessage(content=prompt))
START_ANSWER = "FINAL ANSWER"
start_of_answer = True
final_answer = ""
try:
@ -187,12 +241,24 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
print("Trying LLM streaming")
for out_msg in self.llm.stream(messages):
if START_ANSWER in output:
if not final_answer:
try:
left_over_answer = output.split(START_ANSWER)[1].lstrip()
except IndexError:
left_over_answer = ""
if left_over_answer:
out_msg.text = left_over_answer + out_msg.text
final_answer += (
out_msg.text.lstrip() if start_of_answer else out_msg.text
out_msg.text.lstrip() if not final_answer else out_msg.text
)
start_of_answer = False
yield Document(channel="chat", content=out_msg.text)
# check for the edge case of citation list is repeated
# with smaller LLMs
if START_CITATION in out_msg.text:
break
output += out_msg.text
logprobs += out_msg.logprobs
except NotImplementedError:
@ -235,10 +301,15 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
if not answer.metadata["citation"]:
return spans
evidences = answer.metadata["citation"].evidences
evidences = answer.metadata["citation"]
for e_id, evidence in enumerate(evidences):
start_phrase, end_phrase = evidence.start_phrase, evidence.end_phrase
evidence_idx = evidence.idx
if evidence_idx is None:
evidence_idx = e_id + 1
for start_idx in range(0, len(evidences), 2):
start_phrase, end_phrase = evidences[start_idx : start_idx + 2]
best_match = None
best_match_length = 0
best_match_doc_idx = None
@ -259,7 +330,7 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
{
"start": best_match[0],
"end": best_match[1],
"idx": start_idx // 2, # implicitly set from the start_idx
"idx": evidence_idx,
}
)
return spans