diff --git a/libs/kotaemon/kotaemon/indices/qa/citation_qa.py b/libs/kotaemon/kotaemon/indices/qa/citation_qa.py
index a727246..37d8ced 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation_qa.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation_qa.py
@@ -342,11 +342,11 @@ class AnswerWithContextPipeline(BaseComponent):
 
                 span_idx = span.get("idx", None)
                 if span_idx is not None:
-                    to_highlight = f"【{span_idx + 1}】" + to_highlight
+                    to_highlight = f"【{span_idx}】" + to_highlight
 
                 text += Render.highlight(
                     to_highlight,
-                    elem_id=str(span_idx + 1) if span_idx is not None else None,
+                    elem_id=str(span_idx) if span_idx is not None else None,
                 )
                 if idx < len(ss) - 1:
                     text += cur_doc.text[span["end"] : ss[idx + 1]["start"]]
diff --git a/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py b/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py
index 5b0bf99..9770b90 100644
--- a/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py
+++ b/libs/kotaemon/kotaemon/indices/qa/citation_qa_inline.py
@@ -1,6 +1,7 @@
 import re
 import threading
 from collections import defaultdict
+from dataclasses import dataclass
 from typing import Generator
 
 import numpy as np
@@ -8,7 +9,6 @@ import numpy as np
 from kotaemon.base import AIMessage, Document, HumanMessage, SystemMessage
 from kotaemon.llms import PromptTemplate
 
-from .citation import CiteEvidence
 from .citation_qa import CITATION_TIMEOUT, MAX_IMAGES, AnswerWithContextPipeline
 from .format_context import EVIDENCE_MODE_FIGURE
 from .utils import find_start_end_phrase
@@ -61,12 +61,27 @@ START_PHRASE: Fixed-size Chunker
 This is our baseline chunker
 END_PHRASE: this shows good retrieval quality.
 
 FINAL ANSWER
-An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【2】.
+An alternative to semantic chunking is fixed-size chunking. This traditional method involves splitting documents into chunks of a predetermined or user-specified size, regardless of semantic content, which is computationally efficient【1】. However, it may result in the fragmentation of semantically related content, thereby potentially degrading retrieval performance【1】【2】.
 
 QUESTION: {question}\n
 ANSWER: """  # noqa
 
+START_ANSWER = "FINAL ANSWER"
+START_CITATION = "CITATION LIST"
+CITATION_PATTERN = r"citation【(\d+)】"
+START_ANSWER_PATTERN = "start_phrase:"
+END_ANSWER_PATTERN = "end_phrase:"
+
+
+@dataclass
+class InlineEvidence:
+    """A single piece of evidence (one citation) supporting the answer."""
+
+    start_phrase: str | None = None
+    end_phrase: str | None = None
+    idx: int | None = None
+
 
 class AnswerWithInlineCitation(AnswerWithContextPipeline):
     """Answer the question based on the evidence with inline citation"""
@@ -85,15 +100,54 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
 
         return prompt, evidence
 
-    def answer_to_citations(self, answer):
-        evidences = []
+    def answer_to_citations(self, answer) -> list[InlineEvidence]:
+        citations: list[InlineEvidence] = []
         lines = answer.split("\n")
-        for line in lines:
-            for keyword in ["START_PHRASE:", "END_PHRASE:"]:
-                if line.startswith(keyword):
-                    evidences.append(line[len(keyword) :].strip())
-        return CiteEvidence(evidences=evidences)
+        current_evidence = None
+
+        for line in lines:
+            # check for a citation index header using regex
+            match = re.match(CITATION_PATTERN, line.lower())
+
+            if match:
+                try:
+                    parsed_citation_idx = int(match.group(1))
+                except ValueError:
+                    parsed_citation_idx = None
+
+                # conclude the current evidence entry, if one exists
+                if current_evidence:
+                    citations.append(current_evidence)
+                    current_evidence = None
+
+                current_evidence = InlineEvidence(idx=parsed_citation_idx)
+            else:
+                for keyword in [START_ANSWER_PATTERN, END_ANSWER_PATTERN]:
+                    if line.lower().startswith(keyword):
+                        matched_phrase = line[len(keyword) :].strip()
+
+                        if not current_evidence:
+                            current_evidence = InlineEvidence(idx=None)
+
+                        if keyword == START_ANSWER_PATTERN:
+                            current_evidence.start_phrase = matched_phrase
+                        else:
+                            current_evidence.end_phrase = matched_phrase
+
+                        break
+
+            if (
+                current_evidence
+                and current_evidence.end_phrase
+                and current_evidence.start_phrase
+            ):
+                citations.append(current_evidence)
+                current_evidence = None
+
+        if current_evidence:
+            citations.append(current_evidence)
+
+        return citations
 
     def replace_citation_with_link(self, answer: str):
         # Define the regex pattern to match 【number】
@@ -114,6 +168,8 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
             ),
         )
 
+        answer = answer.replace(START_CITATION, "")
+
         return answer
 
     def stream(  # type: ignore
@@ -178,8 +234,6 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
 
         # append main prompt
         messages.append(HumanMessage(content=prompt))
 
-        START_ANSWER = "FINAL ANSWER"
-        start_of_answer = True
         final_answer = ""
         try:
@@ -187,12 +241,24 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
             print("Trying LLM streaming")
            for out_msg in self.llm.stream(messages):
                 if START_ANSWER in output:
+                    if not final_answer:
+                        try:
+                            left_over_answer = output.split(START_ANSWER)[1].lstrip()
+                        except IndexError:
+                            left_over_answer = ""
+                        if left_over_answer:
+                            out_msg.text = left_over_answer + out_msg.text
+
                     final_answer += (
-                        out_msg.text.lstrip() if start_of_answer else out_msg.text
+                        out_msg.text.lstrip() if not final_answer else out_msg.text
                     )
-                    start_of_answer = False
                     yield Document(channel="chat", content=out_msg.text)
 
+                    # edge case: smaller LLMs sometimes repeat the citation
+                    # list in the answer; stop streaming once it shows up
+                    if START_CITATION in out_msg.text:
+                        break
+
                 output += out_msg.text
                 logprobs += out_msg.logprobs
         except NotImplementedError:
@@ -235,10 +301,15 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
         if not answer.metadata["citation"]:
             return spans
 
-        evidences = answer.metadata["citation"].evidences
+        evidences = answer.metadata["citation"]
+
+        for e_id, evidence in enumerate(evidences):
+            start_phrase, end_phrase = evidence.start_phrase, evidence.end_phrase
+            evidence_idx = evidence.idx
+
+            if evidence_idx is None:
+                evidence_idx = e_id + 1
 
-        for start_idx in range(0, len(evidences), 2):
-            start_phrase, end_phrase = evidences[start_idx : start_idx + 2]
             best_match = None
             best_match_length = 0
             best_match_doc_idx = None
@@ -259,7 +330,7 @@ class AnswerWithInlineCitation(AnswerWithContextPipeline):
                     {
                         "start": best_match[0],
                         "end": best_match[1],
-                        "idx": start_idx // 2,  # implicitly set from the start_idx
+                        "idx": evidence_idx,
                     }
                 )
         return spans
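
Editorial note on the new parsing flow (not part of the patch): the sketch below re-implements answer_to_citations() as a standalone script so the CITATION【n】 / START_PHRASE / END_PHRASE grouping is easy to try outside kotaemon. parse_inline_citations and sample are illustrative names only, the ValueError guard from the patch is elided since \d+ always parses as int, and Python 3.10+ is assumed for the str | None annotations.

import re
from dataclasses import dataclass

# Constants copied from the patch above.
CITATION_PATTERN = r"citation【(\d+)】"
START_ANSWER_PATTERN = "start_phrase:"
END_ANSWER_PATTERN = "end_phrase:"


@dataclass
class InlineEvidence:
    """One citation: a start/end phrase pair plus an optional index."""

    start_phrase: str | None = None
    end_phrase: str | None = None
    idx: int | None = None


def parse_inline_citations(answer: str) -> list[InlineEvidence]:
    """Compressed mirror of the patched answer_to_citations()."""
    citations: list[InlineEvidence] = []
    current: InlineEvidence | None = None

    for line in answer.split("\n"):
        match = re.match(CITATION_PATTERN, line.lower())
        if match:
            # A new CITATION【n】 header closes any half-built entry.
            if current:
                citations.append(current)
            current = InlineEvidence(idx=int(match.group(1)))
        else:
            for keyword in (START_ANSWER_PATTERN, END_ANSWER_PATTERN):
                if line.lower().startswith(keyword):
                    phrase = line[len(keyword):].strip()
                    current = current or InlineEvidence()
                    if keyword == START_ANSWER_PATTERN:
                        current.start_phrase = phrase
                    else:
                        current.end_phrase = phrase
                    break

        # An entry is complete once both phrases have been seen.
        if current and current.start_phrase and current.end_phrase:
            citations.append(current)
            current = None

    if current:  # keep a trailing, partially filled entry
        citations.append(current)
    return citations


sample = (
    "CITATION【1】\n"
    "START_PHRASE: Fixed-size Chunker\n"
    "END_PHRASE: this shows good retrieval quality."
)
print(parse_inline_citations(sample))
# -> [InlineEvidence(start_phrase='Fixed-size Chunker',
#        end_phrase='this shows good retrieval quality.', idx=1)]

The printed list has the shape the span-matching code at the end of the diff now consumes: each entry's idx (or its 1-based position when idx is None) becomes the span's "idx", which citation_qa.py then renders verbatim as 【idx】 instead of recomputing span_idx + 1.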