From 6ab18545324a1bdf4240adb3683a9d5ec304ee8b Mon Sep 17 00:00:00 2001 From: "Nguyen Trung Duc (john)" Date: Wed, 4 Oct 2023 02:16:33 +0700 Subject: [PATCH] feat: Add chain-of-thought (#37) * Add chain-of-thought * Use BasePromptComponent * Add terminate callback for the chain-of-thought --- .env.secret | Bin 1837 -> 1911 bytes .gitignore | 1 + .gitsecret/paths/mapping.cfg | 2 +- knowledgehub/pipelines/cot.py | 169 ++++++++++++++++++++++++++++++++++ tests/test_cot.py | 120 ++++++++++++++++++++++++ 5 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 knowledgehub/pipelines/cot.py create mode 100644 tests/test_cot.py diff --git a/.env.secret b/.env.secret index 3d6cc8e0d0f7fb33d0505cdec6460b4020f99cca..00e529c822b46aa118bffaf69aeb623b7c792476 100644 GIT binary patch literal 1911 zcmV--2Z;EE0gMB9>i0uNB-FA23;Yk!rf-D8B>1`tB#Z*mWgpm@7w4xFC){rMQsU&q z*>8KL7m|{}&}cxi4AlUw;d0uH=4@s&{_}}(tDK^d?R+raeILNbA9{)cb_cUtT8DU1~NE$gO^)H}ql&zsh@Kw6sGr_Q{ekBOJpB4XGkj|?i z#9@#bCsHV6^n3)&#eo*?mFbL%Olsl!ze`z#e6<}iy*nZ~^xbSqV+F6Enb(&$nTjOs zwz7-r)`sAz%EH^5q2RF@v$YA!H^%=KpiW`UHus-|6wX!iAr>P;KEYN;&o-y(Wq4%- zaU^c0;&ZpSyv7mAG8&`c)kH*~_*1K2*bGXt$w>sMq>E4`7E9ZTWL!o^^W=msJ`juI zzm`Wn2)MOwcxBj|3!w0K8TU|h&g9fm?#tyCS>@65i@q7Ib?Sd%;M-&9EmjCk2!=)>Rid^4ryi1&&uv-X=zGQR{u`4%{PJgdg+7lXOuFl<25FG0WEyE{ znFo5IrQYFb!|@i33$|2nQM0LPD0Jt|f6Ej&1d>`f){1%B-HBZHYM@w2WsROL2l=!s z*gsvqMBi1Yo`uj5DubzMZ44foTNtT$XFdY@%Bxb`e{r7q_}yI=sv;YrxmTbISc2`A z@2{1-_k)XQJ>Rh99@9G1c)fljX^Ac5jLd08FObBvfW#c4KeLQDsaPP0i+rcPBtZhMh^?l=vs&>y)krlq~Cc$Pv_cYHJPJ`P-y>os_S4?=+CP? zE-3K^aTe+V==V!6ZgNqvp8n2(-m>n*Uiv*v9N6wf3wVN{YD)(qH^Uoc3G3?y3fjERxs! 
z|2S^7{QZOEHL>DFB-oBV8yp3`o>R_iJuaE~x>&a1A-6z@Vd6`&!rA#NIGwdjZON_9 zI2k@*{E@uiL@gFWdJw|WxdE&ZqmByy3qbphi|;TTLPb){#sB>mH)~Jj~R(4T$(D$`ho0Ovxz6-+$QG6V= zNRU5UbWKT+*{E{b)Wb{q-B-+@G)UXXZ6~_&Xu^^iob4+5^{F5zOA5(LFILmn0~r8R x=XX09kagN-ek0)@5azVQGEQsPA3OFMS7o9k=9si0uNB-FA23;$?X+ZHg0$OH@lirPi40OyYGE#?vYOIuFcdtJpE z7aaHaSLvMd$zjNhu-Bb48RnLdzZ!J5fzBD#Y>{q6|T<1A8@`=X(=W z@)GewY6g}EX<#Q-;-pnBDN0`qs0AXd0ypYwR&Pe>(!nH1f$s>u(w=y9htW)>O%kH- z4m;4NhTU!xsV&e-Z%a7}&S3n==`h+6c=wXyyCR|Sd|sz=BlGg^>QDhBe6#YzR~Gch zQCK3r{_GS^Z~mu*%eSHx1vT@W%11;oUPsOZ_8{F3bq8h!2`SS%h_iQ_mfte5#$*uo zWgG#X<-z)$>3f%BANAVaSCZfq25cof2^)Yv10YV;5nrq5oBit7Y{0nmiNEFak>oq1 z&|@5Fm)I>rx!-lalj%_x>NM98lj?LCN33zy|9qPzytsBoIauZGJv7pIFx$Jze0m!& zEoexhpIaLyYVR#g<`84H(77A4u8hbsYZ6^`<{%5+k0I|FoR0XJ*W+Bq%&PjxnPfXsAWO=NV+O$g1R2T!MCkCBmpqwY7yV~V^_N(At zdfN#X+jO=CFp!Fi|B^tepu9>3&2(oGCfN@9aYf(Puad>9H6f3*lP0?N_Kmd_r`<9p zZKOu!k70mXE;z!o9a2p00#mmao7w@DMFLG#TzqG_x5yY$a3U{aN;_UdEcVmy?8!a< z>pAE7-XG>Y`c8H?m6!@BrTwv_;#d-7V3oxxchel|C~(Zqq4`)T7ygXZKiV0uyuIVu z_5TXCofL~`6-7@jR7~jdOdq4^P#7DWW|k+2RoA>OHwXgOv7*g-aK0JUnDP?&t28%d zNUM2y969*T3>J`xBS*3WUgDUfO4OIcjY^60T3}pgudhpHW{&|DRzuJS9xL8bdx#E4 z{dw=E##GeIP}%$W%h|6ypi*64!tz3Z_0@WDG+%SzrL*c(X~wLuH3<9{K%h z(I+s_?+c1>$=-S+_WtQ_8^w>^%;v`S0}brc0H;-dkii&L%!|YrgR`csUNcQ-)Gfud z{r9V>I&Z?!xJW!G(caCk^W^!Vl0g#nBMw)PfJ7lLub0So_oZ2hKo zSzGi(h39umV!Fj1M9#UNO4Dpvt4ZF6bnc}tbG8uhk^!fY5lV!QPj)4|ylMsVdy5VCKnF!m{Hj?2y5aZ{$TvMVTAZ_P-vsfzp z_~^Sl4_Spj{HFI bHk8>x95yk1y>58U-tpPPfBa@a3c^e92XckK diff --git a/.gitignore b/.gitignore index 1f9bef6..a8af57d 100644 --- a/.gitignore +++ b/.gitignore @@ -454,6 +454,7 @@ logs/ .gitsecret/keys/random_seed !*.secret .env +.envrc S.gpg-agent* .vscode/settings.json diff --git a/.gitsecret/paths/mapping.cfg b/.gitsecret/paths/mapping.cfg index 5af535f..ae4cf29 100644 --- a/.gitsecret/paths/mapping.cfg +++ b/.gitsecret/paths/mapping.cfg @@ -1 +1 @@ -.env:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5 
+.env:555d804179d7207ad6784a84afb88d2ec44f90ea3b7a061d0e38f9dd53fe7211 diff --git a/knowledgehub/pipelines/cot.py b/knowledgehub/pipelines/cot.py new file mode 100644 index 0000000..4a79768 --- /dev/null +++ b/knowledgehub/pipelines/cot.py @@ -0,0 +1,169 @@ +from copy import deepcopy +from typing import List + +from theflow import Compose, Node, Param + +from kotaemon.base import BaseComponent +from kotaemon.llms.chats.openai import AzureChatOpenAI +from kotaemon.prompt.base import BasePromptComponent + + +class Thought(BaseComponent): + """A thought in the chain of thought + + - Input: `**kwargs` pairs, where key is the placeholder in the prompt, and + value is the value. + - Output: an output dictionary + + ##### Usage: + + Create and run a thought: + + ```python + >> from kotaemon.pipelines.cot import Thought + >> thought = Thought( + prompt="How to {action} {object}?", + llm=AzureChatOpenAI(...), + post_process=lambda string: {"tutorial": string}, + ) + >> output = thought(action="install", object="python") + >> print(output) + {'tutorial': 'As an AI language model,...'} + ``` + + Basically, when a thought is run, it will: + + 1. Populate the prompt template with the input `**kwargs`. + 2. Run the LLM model with the populated prompt. + 3. Post-process the LLM output with the post-processor. + + This `Thought` allows chaining sequentially with the + operator. For example: + + ```python + >> llm = AzureChatOpenAI(...) 
+ >> thought1 = Thought( + prompt="Word {word} in {language} is ", + llm=llm, + post_process=lambda string: {"translated": string}, + ) + >> thought2 = Thought( + prompt="Translate {translated} to Japanese", + llm=llm, + post_process=lambda string: {"output": string}, + ) + + >> thought = thought1 + thought2 + >> thought(word="hello", language="French") + {'word': 'hello', + 'language': 'French', + 'translated': '"Bonjour"', + 'output': 'こんにちは (Konnichiwa)'} + ``` + + Under the hood, when the `+` operator is used, a `ManualSequentialChainOfThought` + is created. + """ + + prompt: Param[str] = Param( + help="The prompt template string. This prompt template has Python-like " + "variable placeholders, that then will be subsituted with real values when " + "this component is executed" + ) + llm = Node( + default=AzureChatOpenAI, help="The LLM model to execute the input prompt" + ) + post_process: Node[Compose] = Node( + help="The function post-processor that post-processes LLM output prediction ." + "It should take a string as input (this is the LLM output text) and return " + "a dictionary, where the key should" + ) + + @Node.decorate(depends_on="prompt") + def prompt_template(self): + return BasePromptComponent(self.prompt) + + def run(self, **kwargs) -> dict: + """Run the chain of thought""" + prompt = self.prompt_template(**kwargs).text + response = self.llm(prompt).text + return self.post_process(response) + + def get_variables(self) -> List[str]: + return [] + + def __add__(self, next_thought: "Thought") -> "ManualSequentialChainOfThought": + return ManualSequentialChainOfThought( + thoughts=[self, next_thought], llm=self.llm + ) + + +class ManualSequentialChainOfThought(BaseComponent): + """Perform sequential chain-of-thought with manual pre-defined prompts + + This method supports variable number of steps. Each step corresponds to a + `kotaemon.pipelines.cot.Thought`. Please refer that section for + Thought's detail. 
This section is about chaining thoughts together. + + ##### Usage: + + **Create and run a chain of thought without "+" operator:** + + ```python + >> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought + + >> llm = AzureChatOpenAI(...) + >> thought1 = Thought( + prompt="Word {word} in {language} is ", + post_process=lambda string: {"translated": string}, + ) + >> thought2 = Thought( + prompt="Translate {translated} to Japanese", + post_process=lambda string: {"output": string}, + ) + >> thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm) + >> thought(word="hello", language="French") + {'word': 'hello', + 'language': 'French', + 'translated': '"Bonjour"', + 'output': 'こんにちは (Konnichiwa)'} + ``` + + **Create and run a chain of thought with "+" operator:** Please refer to the + `kotaemon.pipelines.cot.Thought` section for examples. + + This chain-of-thought optionally takes a termination check callback function. + This function will be called after each thought is executed. It takes in a + dictionary of all thought outputs so far, and it returns True or False. If + True, the chain-of-thought will terminate. If unset, the default callback always + returns False. + """ + + thoughts: Param[List[Thought]] = Param( + default_callback=lambda *_: [], help="List of Thought" + ) + llm: Param = Param(help="The LLM model to use (base of kotaemon.llms.LLM)") + terminate: Param = Param( + default=lambda _: False, + help="Callback on terminate condition. 
Default to always return False", + ) + + def run(self, **kwargs) -> dict: + """Run the manual chain of thought""" + + inputs = deepcopy(kwargs) + for idx, thought in enumerate(self.thoughts): + if self.llm: + thought.llm = self.llm + self._prepare_child(thought, f"thought{idx}") + + output = thought(**inputs) + inputs.update(output) + if self.terminate(inputs): + break + + return inputs + + def __add__(self, next_thought: Thought) -> "ManualSequentialChainOfThought": + return ManualSequentialChainOfThought( + thoughts=self.thoughts + [next_thought], llm=self.llm + ) diff --git a/tests/test_cot.py b/tests/test_cot.py new file mode 100644 index 0000000..f9423af --- /dev/null +++ b/tests/test_cot.py @@ -0,0 +1,120 @@ +from unittest.mock import patch + +from kotaemon.llms.chats.openai import AzureChatOpenAI +from kotaemon.pipelines.cot import ManualSequentialChainOfThought, Thought + +_openai_chat_completion_response = [ + { + "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", + "object": "chat.completion", + "created": 1692338378, + "model": "gpt-35-turbo", + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": text, + }, + } + ], + "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, + } + for text in ["Bonjour", "こんにちは (Konnichiwa)"] +] + + +@patch( + "openai.api_resources.chat_completion.ChatCompletion.create", + side_effect=_openai_chat_completion_response, +) +def test_cot_plus_operator(openai_completion): + llm = AzureChatOpenAI( + openai_api_base="https://dummy.openai.azure.com/", + openai_api_key="dummy", + openai_api_version="2023-03-15-preview", + deployment_name="dummy-q2", + temperature=0, + ) + thought1 = Thought( + prompt="Word {word} in {language} is ", + llm=llm, + post_process=lambda string: {"translated": string}, + ) + thought2 = Thought( + prompt="Translate {translated} to Japanese", + llm=llm, + post_process=lambda string: {"output": string}, + ) + thought = thought1 + 
thought2 + output = thought(word="hello", language="French") + assert output == { + "word": "hello", + "language": "French", + "translated": "Bonjour", + "output": "こんにちは (Konnichiwa)", + } + + +@patch( + "openai.api_resources.chat_completion.ChatCompletion.create", + side_effect=_openai_chat_completion_response, +) +def test_cot_manual(openai_completion): + llm = AzureChatOpenAI( + openai_api_base="https://dummy.openai.azure.com/", + openai_api_key="dummy", + openai_api_version="2023-03-15-preview", + deployment_name="dummy-q2", + temperature=0, + ) + thought1 = Thought( + prompt="Word {word} in {language} is ", + post_process=lambda string: {"translated": string}, + ) + thought2 = Thought( + prompt="Translate {translated} to Japanese", + post_process=lambda string: {"output": string}, + ) + thought = ManualSequentialChainOfThought(thoughts=[thought1, thought2], llm=llm) + output = thought(word="hello", language="French") + assert output == { + "word": "hello", + "language": "French", + "translated": "Bonjour", + "output": "こんにちは (Konnichiwa)", + } + + +@patch( + "openai.api_resources.chat_completion.ChatCompletion.create", + side_effect=_openai_chat_completion_response, +) +def test_cot_with_termination_callback(openai_completion): + llm = AzureChatOpenAI( + openai_api_base="https://dummy.openai.azure.com/", + openai_api_key="dummy", + openai_api_version="2023-03-15-preview", + deployment_name="dummy-q2", + temperature=0, + ) + thought1 = Thought( + prompt="Word {word} in {language} is ", + post_process=lambda string: {"translated": string}, + ) + thought2 = Thought( + prompt="Translate {translated} to Japanese", + post_process=lambda string: {"output": string}, + ) + thought = ManualSequentialChainOfThought( + thoughts=[thought1, thought2], + llm=llm, + terminate=lambda d: True if d.get("translated", "") == "Bonjour" else False, + ) + output = thought(word="hallo", language="French") + assert output == { + "word": "hallo", + "language": "French", + "translated": 
"Bonjour", + }