Migrate the MVP into kotaemon (#108)

- Migrate the MVP into kotaemon.
- Include a preliminary version of the pipeline within the chatbot interface.
- Organize the MVP as an application.

Todo:

- Add an info panel to view agents' planning -> fix streaming of agents' output.

Resolve: #60
Resolve: #61 
Resolve: #62
Duc Nguyen (john) 2024-01-10 15:28:09 +07:00 committed by GitHub
parent 230328c62f
commit 5a9d6f75be
31 changed files with 273 additions and 92 deletions

View File

@ -45,16 +45,12 @@ repos:
- id: prettier
types_or: [markdown, yaml]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.5.1"
rev: "v1.7.1"
hooks:
- id: mypy
additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
args:
[
"--check-untyped-defs",
"--ignore-missing-imports",
"--new-type-inference",
]
additional_dependencies:
[types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
args: ["--check-untyped-defs", "--ignore-missing-imports"]
exclude: "^templates/"
- repo: https://github.com/codespell-project/codespell
rev: v2.2.4

View File

@ -269,4 +269,5 @@ class RewooAgent(BaseAgent):
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
metadata={"citation": citation},
)

View File

@ -41,7 +41,7 @@ class BaseTool(BaseComponent):
args_schema = self.args_schema
if isinstance(tool_input, str):
if args_schema is not None:
key_ = next(iter(args_schema.__fields__.keys()))
key_ = next(iter(args_schema.model_fields.keys()))
args_schema.validate({key_: tool_input})
return tool_input
else:
@ -121,9 +121,11 @@ class BaseTool(BaseComponent):
class ComponentTool(BaseTool):
"""
A Tool based on another pipeline / BaseComponent to be used
as its main entry point
"""Wrapper around other BaseComponent to use it as a tool
Args:
component: BaseComponent-based component to wrap
postprocessor: Optional postprocessor for the component output
"""
component: BaseComponent
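
The `__fields__` -> `model_fields` change above tracks the Pydantic v2 rename; a minimal standalone sketch of the new accessor, using a hypothetical `SearchArgs` schema:

```python
from pydantic import BaseModel, Field

class SearchArgs(BaseModel):  # hypothetical args schema
    query: str = Field(..., description="a search question")

# Pydantic v2 exposes declared fields as `model_fields` (a dict of
# field name -> FieldInfo); `__fields__` is the deprecated v1 spelling.
first_key = next(iter(SearchArgs.model_fields.keys()))
assert first_key == "query"
```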

View File

@ -1,13 +1,11 @@
from typing import AnyStr, Optional, Type, Union
from typing import AnyStr, Optional, Type
from pydantic import BaseModel, Field
from kotaemon.llms import LLM, AzureChatOpenAI, ChatLLM
from kotaemon.llms import BaseLLM
from .base import BaseTool, ToolException
BaseLLM = Union[ChatLLM, LLM]
class LLMArgs(BaseModel):
query: str = Field(..., description="a search question or prompt")
@ -21,7 +19,7 @@ class LLMTool(BaseTool):
"are confident in solving the problem "
"yourself. Input can be any instruction."
)
llm: BaseLLM = AzureChatOpenAI.withx()
llm: BaseLLM
args_schema: Optional[Type[BaseModel]] = LLMArgs
def _run_tool(self, query: AnyStr) -> str:

View File

@ -1,4 +1,5 @@
from abc import abstractmethod
from typing import Iterator
from theflow import Function, Node, Param, lazy
@ -32,7 +33,9 @@ class BaseComponent(Function):
return self.__call__(self.inflow.flow())
@abstractmethod
def run(self, *args, **kwargs) -> Document | list[Document] | None:
def run(
self, *args, **kwargs
) -> Document | list[Document] | Iterator[Document] | None:
"""Run the component."""
...
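
Widening the return annotation to include `Iterator[Document]` lets subclasses stream output instead of returning it all at once; a minimal sketch under the new signature (the `EchoStream` class is hypothetical):

```python
from typing import Iterator

from kotaemon.base import BaseComponent, Document

class EchoStream(BaseComponent):  # hypothetical example component
    def run(self, text: str) -> Iterator[Document]:
        # yield one Document per whitespace-separated token
        for token in text.split():
            yield Document(content=token)
```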

View File

@ -23,11 +23,13 @@ class Document(BaseDocument):
store the raw content of the document. If specified, the class will use
`content` to initialize the base llama_index class.
Args:
content: the raw content of the document.
Attributes:
content: raw content of the document, can be anything
source: id of the source of the Document. Optional.
"""
content: Any
source: Optional[str] = None
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
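
A short usage sketch of the documented attributes, assuming extra fields such as `source` are forwarded through `**kwargs`:

```python
from kotaemon.base import Document

doc = Document(content="hello world", source="sample-id")
assert doc.content == "hello world"
assert doc.source == "sample-id"
```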

View File

@ -121,9 +121,12 @@ class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
)
def _get_lc_class(self):
import langchain.embeddings
try:
from langchain_community.embeddings import OpenAIEmbeddings
except ImportError:
from langchain.embeddings import OpenAIEmbeddings
return langchain.emebddings.OpenAIEmbeddings
return OpenAIEmbeddings
class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
@ -148,9 +151,12 @@ class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
)
def _get_lc_class(self):
import langchain.embeddings
try:
from langchain_community.embeddings import AzureOpenAIEmbeddings
except ImportError:
from langchain.embeddings import AzureOpenAIEmbeddings
return langchain.embeddings.AzureOpenAIEmbeddings
return AzureOpenAIEmbeddings
class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
@ -173,9 +179,12 @@ class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
)
def _get_lc_class(self):
import langchain.embeddings
try:
from langchain_community.embeddings import CohereEmbeddings
except ImportError:
from langchain.embeddings import CohereEmbeddings
return langchain.embeddings.CohereEmbeddings
return CohereEmbeddings
class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
@ -192,6 +201,9 @@ class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
)
def _get_lc_class(self):
import langchain.embeddings
try:
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
except ImportError:
from langchain.embeddings import HuggingFaceBgeEmbeddings
return langchain.embeddings.HuggingFaceBgeEmbeddings
return HuggingFaceBgeEmbeddings
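
Each `_get_lc_class` above now follows the same fallback-import idiom introduced by the langchain 0.1 package split; a standalone illustration:

```python
# Prefer the split-out `langchain_community` namespace (langchain >= 0.1),
# falling back to the legacy path so older installs keep working.
try:
    from langchain_community.embeddings import OpenAIEmbeddings
except ImportError:
    from langchain.embeddings import OpenAIEmbeddings
```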

View File

@ -11,6 +11,7 @@ from kotaemon.loaders import (
MathpixPDFReader,
OCRReader,
PandasExcelReader,
UnstructuredReader,
)
@ -19,8 +20,16 @@ class DocumentIngestor(BaseComponent):
Document types:
- pdf
- xlsx
- docx
- xlsx, xls
- docx, doc
Args:
pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
- normal: parse pdf text
- mathpix: parse pdf text using mathpix
- ocr: parse pdf image using flax
doc_parsers: list of document parsers to parse the document
text_splitter: splitter to split the document into text nodes
"""
pdf_mode: str = "normal" # "normal", "mathpix", "ocr"
@ -34,6 +43,9 @@ class DocumentIngestor(BaseComponent):
"""Get appropriate readers for the input files based on file extension"""
file_extractor: dict[str, AutoReader | BaseReader] = {
".xlsx": PandasExcelReader(),
".docx": UnstructuredReader(),
".xls": UnstructuredReader(),
".doc": UnstructuredReader(),
}
if self.pdf_mode == "normal":
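
A usage sketch mirroring the docstring options, assuming the ingestor is invoked on a file path as in the new test further below (the path is a placeholder):

```python
from kotaemon.indices.ingests import DocumentIngestor

ingestor = DocumentIngestor(pdf_mode="ocr")  # one of "normal", "mathpix", "ocr"
nodes = ingestor("report.pdf")  # placeholder path
```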

View File

@ -64,11 +64,7 @@ class CitationPipeline(BaseComponent):
llm: BaseLLM
def run(
self,
context: str,
question: str,
) -> QuestionAnswer:
def run(self, context: str, question: str):
schema = QuestionAnswer.schema()
function = {
"name": schema["title"],

View File

@ -1,6 +1,6 @@
import os
from kotaemon.base import BaseComponent, Document, RetrievedDocument
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
from .citation import CitationPipeline
@ -21,6 +21,9 @@ class CitationQAPipeline(BaseComponent):
temperature=0,
request_timeout=60,
)
citation_pipeline: CitationPipeline = Node(
default_callback=lambda self: CitationPipeline(llm=self.llm)
)
def _format_doc_text(self, text: str) -> str:
"""Format the text of each document"""
@ -52,9 +55,7 @@ class CitationQAPipeline(BaseComponent):
self.log_progress(".prompt", prompt=prompt)
answer_text = self.llm(prompt).text
if use_citation:
# run citation pipeline
citation_pipeline = CitationPipeline(llm=self.llm)
citation = citation_pipeline(context=context, question=question)
citation = self.citation_pipeline(context=context, question=question)
else:
citation = None
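
The `Node(default_callback=...)` change builds the citation pipeline lazily from the parent's own `llm` instead of re-creating it on every call; a minimal sketch of the pattern with hypothetical classes:

```python
from kotaemon.base import BaseComponent, Node

class Child(BaseComponent):  # hypothetical subcomponent
    def run(self, x: str) -> str:
        return x.upper()

class Parent(BaseComponent):  # hypothetical parent
    # The subcomponent is constructed on first access; the callback receives
    # the parent instance, so it can read sibling fields (like `llm` above).
    child: Child = Node(default_callback=lambda self: Child())

    def run(self, x: str) -> str:
        return self.child(x)
```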

View File

@ -23,17 +23,18 @@ class CohereReranking(BaseReranking):
)
cohere_client = cohere.Client(self.cohere_api_key)
compressed_docs: list[Document] = []
# output documents
compressed_docs = []
if len(documents) > 0: # to avoid empty api call
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
if not documents: # to avoid empty api call
return compressed_docs
_docs = [d.content for d in documents]
results = cohere_client.rerank(
model=self.model_name, query=query, documents=_docs, top_n=self.top_k
)
for r in results:
doc = documents[r.index]
doc.metadata["relevance_score"] = r.relevance_score
compressed_docs.append(doc)
return compressed_docs

View File

@ -29,8 +29,19 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
def __init__(self, window_size: int = 3, **params):
super().__init__(window_size=window_size, **params)
def __init__(
self,
window_size: int = 3,
window_metadata_key: str = "window",
original_text_metadata_key: str = "original_text",
**params,
):
super().__init__(
window_size=window_size,
window_metadata_key=window_metadata_key,
original_text_metadata_key=original_text_metadata_key,
**params,
)
def _get_li_class(self):
from llama_index.node_parser import SentenceWindowNodeParser
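
A usage sketch with the newly exposed metadata-key parameters (import path assumed from the tests):

```python
from kotaemon.indices.splitters import SentenceWindowSplitter

splitter = SentenceWindowSplitter(
    window_size=3,                               # sentences per context window
    window_metadata_key="window",                # where the window text is stored
    original_text_metadata_key="original_text",  # where the original sentence is stored
)
```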

View File

@ -62,7 +62,7 @@ class VectorIndexing(BaseIndexing):
embeddings = self.embedding(input_)
self.vector_store.add(
embeddings=embeddings,
ids=[t.id_ for t in input_],
ids=[t.doc_id for t in input_],
)
if self.doc_store:
self.doc_store.add(input_)
@ -99,7 +99,7 @@ class VectorRetrieval(BaseRetrieval):
)
emb: list[float] = self.embedding(text)[0].embedding
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
_, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
docs = self.doc_store.get(ids)
result = [
RetrievedDocument(**doc.to_dict(), score=score)

View File

@ -15,15 +15,23 @@ class LCChatMixin:
"Please return the relevant Langchain class in in _get_lc_class"
)
def __init__(self, **params):
def __init__(self, stream: bool = False, **params):
self._lc_class = self._get_lc_class()
self._obj = self._lc_class(**params)
self._kwargs: dict = params
self._stream = stream
super().__init__()
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
if self._stream:
return self.stream(messages, **kwargs) # type: ignore
return self.invoke(messages, **kwargs)
def invoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
"""Generate response from messages
@ -68,6 +76,10 @@ class LCChatMixin:
logits=[],
)
def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
for response in self._obj.stream(input=messages, **kwargs):
yield LLMInterface(content=response.content)
def to_langchain_format(self):
return self._obj
@ -150,6 +162,9 @@ class AzureChatOpenAI(LCChatMixin, ChatLLM):
)
def _get_lc_class(self):
import langchain.chat_models
try:
from langchain_community.chat_models import AzureChatOpenAI
except ImportError:
from langchain.chat_models import AzureChatOpenAI
return langchain.chat_models.AzureChatOpenAI
return AzureChatOpenAI
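
With `stream=True`, `run` now returns an iterator of `LLMInterface` chunks instead of a single response; a hedged usage sketch (connection parameters are placeholders, and `__call__` is assumed to dispatch to `run`):

```python
from kotaemon.llms import AzureChatOpenAI

llm = AzureChatOpenAI(stream=True, azure_deployment="gpt-35-turbo")  # placeholders
for chunk in llm("Summarize the commit in one line"):
    print(chunk.content, end="", flush=True)
```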

View File

@ -186,6 +186,9 @@ class AzureOpenAI(LCCompletionMixin, LLM):
)
def _get_lc_class(self):
import langchain.llms as langchain_llms
try:
from langchain_community.llms import AzureOpenAI
except ImportError:
from langchain.llms import AzureOpenAI
return langchain_llms.AzureOpenAI
return AzureOpenAI

View File

@ -26,11 +26,7 @@ class OCRReader(BaseReader):
self.ocr_endpoint = endpoint
self.use_ocr = use_ocr
def load_data(
self,
file_path: Path,
**kwargs,
) -> List[Document]:
def load_data(self, file_path: Path, **kwargs) -> List[Document]:
"""Load data using OCR reader
Args:
@ -41,23 +37,24 @@ class OCRReader(BaseReader):
Returns:
List[Document]: list of documents extracted from the PDF file
"""
# create input params for the requests
content = open(file_path, "rb")
files = {"input": content}
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
file_path = Path(file_path).resolve()
with file_path.open("rb") as content:
files = {"input": content}
data = {"job_id": uuid4(), "table_only": not self.use_ocr}
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
debug_path = kwargs.pop("debug_path", None)
artifact_path = kwargs.pop("artifact_path", None)
# call the API from FullOCR endpoint
if "response_content" in kwargs:
# overriding response content if specified
ocr_results = kwargs["response_content"]
else:
# call original API
resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
ocr_results = resp.json()["result"]
# read PDF through normal reader (unstructured)
pdf_page_items = read_pdf_unstructured(file_path)
# merge PDF text output with OCR output
@ -77,6 +74,9 @@ class OCRReader(BaseReader):
"type": "table",
"page_label": page_id + 1,
"source": file_path.name,
"file_path": str(file_path),
"file_name": file_path.name,
"filename": str(file_path),
},
metadata_template="",
metadata_seperator="",
@ -91,6 +91,9 @@ class OCRReader(BaseReader):
metadata={
"page_label": page_id + 1,
"source": file_path.name,
"file_path": str(file_path),
"file_name": file_path.name,
"filename": str(file_path),
},
)
for page_id, non_table_text in texts

View File

@ -74,9 +74,10 @@ class UnstructuredReader(BaseReader):
""" Process elements """
docs = []
file_name = Path(file).name
file_path = str(Path(file).resolve())
if split_documents:
for node in elements:
metadata = {"file_name": file_name}
metadata = {"file_name": file_name, "file_path": file_path}
if hasattr(node, "metadata"):
"""Load metadata fields"""
for field, val in vars(node.metadata).items():
@ -99,7 +100,7 @@ class UnstructuredReader(BaseReader):
else:
text_chunks = [" ".join(str(el).split()) for el in elements]
metadata = {"file_name": file_name}
metadata = {"file_name": file_name, "file_path": file_path}
if additional_metadata is not None:
metadata.update(additional_metadata)

View File

@ -16,6 +16,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
elasticsearch_url: str = "http://localhost:9200",
k1: float = 2.0,
b: float = 0.75,
**kwargs,
):
try:
from elasticsearch import Elasticsearch
@ -31,7 +32,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self.b = b
# Create an Elasticsearch client instance
self.client = Elasticsearch(elasticsearch_url)
self.client = Elasticsearch(elasticsearch_url, **kwargs)
self.es_bulk = bulk
# Define the index settings and mappings
settings = {
@ -63,19 +64,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
self,
docs: Union[Document, List[Document]],
ids: Optional[Union[List[str], str]] = None,
**kwargs
refresh_indices: bool = True,
**kwargs,
):
"""Add document into document store
Args:
docs: list of documents to add
ids: specify the ids of documents to add or
use existing doc.doc_id
refresh_indices: request Elasticsearch to update
its index (default to True)
ids: specify the ids of documents to add or use existing doc.doc_id
refresh_indices: request Elasticsearch to update its index (default to True)
"""
refresh_indices = kwargs.pop("refresh_indices", True)
if ids and not isinstance(ids, list):
ids = [ids]
if not isinstance(docs, list):
@ -120,7 +118,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
)
return docs
def query(self, query: str, top_k: int = 10) -> List[Document]:
def query(
self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
) -> List[Document]:
"""Search Elasticsearch docstore using search query (BM25)
Args:
@ -131,7 +131,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
Returns:
List[Document]: List of result documents
"""
query_dict = {"query": {"match": {"content": query}}, "size": top_k}
query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
if doc_ids:
query_dict["query"]["match"]["_id"] = {"values": doc_ids}
return self.query_raw(query_dict)
def get(self, ids: Union[List[str], str]) -> List[Document]:
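
A usage sketch of the new `doc_ids` filter on BM25 queries (import path and constructor arguments assumed):

```python
from kotaemon.storages import ElasticsearchDocumentStore  # path assumed

store = ElasticsearchDocumentStore(elasticsearch_url="http://localhost:9200")
# BM25 search over `content`, optionally restricted to known document ids
hits = store.query("mass rapid transit", top_k=5, doc_ids=["doc-1", "doc-2"])
```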

View File

@ -74,6 +74,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
"""Load document store from path"""
with open(path) as f:
store = json.load(f)
# TODO: save and load aren't lossless. A Document-subclass will lose
# information. Need to edit the `to_dict` and `from_dict` methods in
# the Document class.
# For better query support, utilize SQLite as the default document store.
# Also, for portability, use SQLAlchemy for document store.
self._store = {key: Document.from_dict(value) for key, value in store.items()}
def __persist_flow__(self):

View File

@ -15,6 +15,18 @@ class SimpleFileDocumentStore(InMemoryDocumentStore):
if path is not None and Path(path).is_file():
self.load(path)
def get(self, ids: Union[List[str], str]) -> List[Document]:
"""Get document by id"""
if not isinstance(ids, list):
ids = [ids]
for doc_id in ids:
if doc_id not in self._store:
self.load(self._path)
break
return [self._store[doc_id] for doc_id in ids]
def add(
self,
docs: Union[Document, List[Document]],

View File

@ -76,8 +76,15 @@ class LlamaIndexVectorStore(BaseVectorStore):
"Require `_li_class` to set a VectorStore class from LlamarIndex"
)
from dataclasses import fields
self._client = self._li_class(*args, **kwargs)
self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
for key in ["query_embedding", "similarity_top_k", "node_ids"]:
if key in self._vsq_kwargs:
self._vsq_kwargs.remove(key)
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
return super().__setattr__(name, value)
@ -122,13 +129,35 @@ class LlamaIndexVectorStore(BaseVectorStore):
ids: Optional[list[str]] = None,
**kwargs,
) -> tuple[list[list[float]], list[float], list[str]]:
"""Return the top k most similar vector embeddings
Args:
embedding: List of embeddings
top_k: Number of most similar embeddings to return
ids: List of ids of the embeddings to be queried
kwargs: extra query parameters. Depending on the name, these parameters
will be used when constructing the VectorStoreQuery object or when
performing querying of the underlying vector store.
Returns:
the matched embeddings, the similarity scores, and the ids
"""
vsq_kwargs = {}
vs_kwargs = {}
for kwkey, kwvalue in kwargs.items():
if kwkey in self._vsq_kwargs:
vsq_kwargs[kwkey] = kwvalue
else:
vs_kwargs[kwkey] = kwvalue
output = self._client.query(
query=VectorStoreQuery(
query_embedding=embedding,
similarity_top_k=top_k,
node_ids=ids,
**kwargs,
**vsq_kwargs,
),
**vs_kwargs,
)
embeddings = []
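
Callers can now pass both `VectorStoreQuery` fields and store-specific options through one `query()` call; a hedged sketch, where `store` is a configured `LlamaIndexVectorStore` subclass and the kwargs are illustrative:

```python
# "mode" matches a VectorStoreQuery dataclass field, so it is routed into the
# query object; an unrecognised kwarg would instead be forwarded verbatim to
# the underlying client's query() call.
embeddings, scores, ids = store.query(
    embedding=[0.1] * 1536,  # illustrative embedding
    top_k=4,
    mode="default",
)
```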

View File

@ -64,7 +64,7 @@ class ChromaVectorStore(LlamaIndexVectorStore):
ids: List of ids of the embeddings to be deleted
kwargs: meant for vectorstore-specific parameters
"""
self._client._collection.delete(ids=ids)
self._client.client.delete(ids=ids)
def delete_collection(self, collection_name: Optional[str] = None):
"""Delete entire collection under specified name from vector stores

View File

@ -16,6 +16,7 @@ requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [
"langchain",
"langchain-community",
"theflow",
"llama-index>=0.9.0",
"llama-hub",
@ -56,7 +57,6 @@ dev = [
"python-dotenv",
"pytest-mock",
"unstructured[pdf]",
# "farm-haystack==1.22.1",
"sentence_transformers",
"cohere",
"elasticsearch",

View File

@ -47,6 +47,7 @@ def generate_chat_completion_obj(text):
"function_call": None,
"tool_calls": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

View File

@ -30,6 +30,7 @@ _openai_chat_completion_response = [
},
"tool_calls": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

View File

@ -30,6 +30,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
"finish_reason": "length",
"logprobs": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

View File

@ -23,6 +23,7 @@ _openai_chat_completion_response = [
"function_call": None,
"tool_calls": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

tests/test_ingestor.py (new file, +15)
View File

@ -0,0 +1,15 @@
from pathlib import Path
from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.splitters import TokenSplitter
def test_ingestor_include_src():
dirpath = Path(__file__).parent
ingestor = DocumentIngestor(
pdf_mode="normal",
text_splitter=TokenSplitter(chunk_size=50, chunk_overlap=10),
)
nodes = ingestor(dirpath / "resources" / "table.pdf")
assert type(nodes) is list
assert nodes[0].relationships

View File

@ -28,6 +28,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
"function_call": None,
"tool_calls": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

View File

@ -25,6 +25,7 @@ _openai_chat_completion_responses = [
"function_call": None,
"tool_calls": None,
},
"logprobs": None,
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},

tests/test_splitter.py (new file, +55)
View File

@ -0,0 +1,55 @@
from llama_index.schema import NodeRelationship
from kotaemon.base import Document
from kotaemon.indices.splitters import TokenSplitter
source1 = Document(
content="The City Hall and Raffles Place MRT stations are paired cross-platform "
"interchanges on the NorthSouth line (NSL) and EastWest line (EWL) of the "
"Singapore Mass Rapid Transit (MRT) system. Both are situated in the Downtown "
"Core district: City Hall station is near landmarks such as the former City Hall, "
"St Andrew's Cathedral and the Padang, while Raffles Place station serves Merlion "
"Park, The Fullerton Hotel and the Asian Civilisations Museum. The stations were "
"first announced in 1982. Constructing the tunnels between the City Hall and "
"Raffles Place stations required the draining of the Singapore River. The "
"stations opened on 12 December 1987 as part of the MRT extension to Outram Park "
"station. Cross-platform transfers between the NSL and EWL began on 28 October "
"1989, ahead of the split of the MRT network into two lines. Both stations are "
"designated Civil Defence shelters. City Hall station features a mural by Simon"
"Wong which depicts government buildings in the area, while two murals at Raffles "
"Place station by Lim Sew Yong and Thang Kiang How depict scenes of Singapore's "
"history"
)
source2 = Document(
content="The pink cockatoo (Cacatua leadbeateri) is a medium-sized cockatoo that "
"inhabits arid and semi-arid inland areas across Australia, with the exception of "
"the north east. The bird has a soft-textured white and salmon-pink plumage and "
"large, bright red and yellow crest. The sexes are quite similar, although males "
"are usually bigger while the female has a broader yellow stripe on the crest and "
"develops a red eye when mature. The pink cockatoo is usually found in pairs or "
"small groups, and feeds both on the ground and in trees. It is listed as an "
"endangered species by the Australian government. Formerly known as Major "
"Mitchell's cockatoo, after the explorer Thomas Mitchell, the species was "
"officially renamed the pink cockatoo in 2023 by BirdLife Australia in light of "
"Mitchell's involvement in the massacre of Aboriginal people at Mount Dispersion, "
"as well as a general trend to make Australian species names more culturally "
"inclusive. This pink cockatoo with a raised crest was photographed near Mount "
"Grenfell in New South Wales."
)
def test_split_token():
"""Test that it can split tokens successfully"""
splitter = TokenSplitter(chunk_size=30, chunk_overlap=10)
chunks = splitter([source1, source2])
assert isinstance(chunks, list), "Chunks should be a list"
assert isinstance(chunks[0], Document), "Chunks should be a list of Documents"
assert chunks[0].relationships[NodeRelationship.SOURCE].node_id == source1.doc_id
assert (
chunks[1].relationships[NodeRelationship.PREVIOUS].node_id == chunks[0].doc_id
)
assert chunks[1].relationships[NodeRelationship.NEXT].node_id == chunks[2].doc_id
assert chunks[-1].relationships[NodeRelationship.SOURCE].node_id == source2.doc_id