Skip to content

Commit 67e7038

Browse files
committed
fix: store tool_call_id and status for ToolMessage serialization (#51)
- Add ToolMessage import to chat_message_history - Store tool_call_id and status fields when adding ToolMessage to history - Rename redis_msg to common_data_to_store with proper type annotation - Add comprehensive test suite with 5 tests covering ToolMessage scenarios The issue occurred when using RunnableWithMessageHistory with tool-calling agents. On the second invocation with the same session_id, deserialization failed with KeyError: 'tool_call_id' because these required fields were not being stored. Based on PR #73 by @alienware Closes #51
1 parent 160c5d7 commit 67e7038

File tree

2 files changed

+250
-3
lines changed

2 files changed

+250
-3
lines changed

libs/redis/langchain_redis/chat_message_history.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Any, Dict, List, Optional
55

66
from langchain_core.chat_history import BaseChatMessageHistory
7-
from langchain_core.messages import BaseMessage, messages_from_dict
7+
from langchain_core.messages import BaseMessage, ToolMessage, messages_from_dict
88
from redis import Redis
99
from redis.exceptions import ConnectionError, ResponseError
1010
from redisvl.exceptions import RedisSearchError # type: ignore
@@ -316,7 +316,7 @@ def add_message(self, message: BaseMessage) -> None:
316316

317317
timestamp = datetime.now().timestamp()
318318
message_id = str(ULID())
319-
redis_msg = {
319+
common_data_to_store: Dict[str, Any] = {
320320
"type": message.type,
321321
"message_id": message_id,
322322
"data": {
@@ -327,10 +327,15 @@ def add_message(self, message: BaseMessage) -> None:
327327
"session_id": self.session_id,
328328
"timestamp": timestamp,
329329
}
330+
if isinstance(message, ToolMessage):
331+
common_data_to_store["data"]["tool_call_id"] = message.tool_call_id
332+
common_data_to_store["data"]["status"] = message.status
330333

331334
# Use RedisVL to load the data
332335
self.index.load(
333-
data=[redis_msg], keys=[self._message_key(message_id)], ttl=self.ttl
336+
data=[common_data_to_store],
337+
keys=[self._message_key(message_id)],
338+
ttl=self.ttl,
334339
)
335340

336341
def clear(self) -> None:
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
"""Tests for ToolMessage KeyError issue (#51)."""
2+
3+
from typing import Any
4+
from unittest.mock import MagicMock, patch
5+
6+
from langchain_core.messages import (
7+
AIMessage,
8+
HumanMessage,
9+
SystemMessage,
10+
ToolMessage,
11+
)
12+
13+
from langchain_redis import RedisChatMessageHistory
14+
15+
16+
class TestToolMessageIssue51:
17+
"""Test ToolMessage serialization and deserialization."""
18+
19+
@patch("langchain_redis.chat_message_history.SearchIndex")
20+
def test_add_tool_message_stores_tool_call_id(
21+
self, mock_search_index: MagicMock
22+
) -> None:
23+
"""Test that adding a ToolMessage stores tool_call_id.
24+
25+
This is the fix for issue #51 where ToolMessage.tool_call_id
26+
was not being stored, causing KeyError on deserialization.
27+
"""
28+
mock_redis_client = MagicMock()
29+
mock_redis_client.client_setinfo = MagicMock()
30+
mock_redis_client.ft = MagicMock()
31+
mock_index_instance = MagicMock()
32+
mock_search_index.from_dict.return_value = mock_index_instance
33+
34+
history = RedisChatMessageHistory(
35+
session_id="test_session",
36+
redis_client=mock_redis_client,
37+
)
38+
39+
# Add a ToolMessage with tool_call_id
40+
tool_message = ToolMessage(
41+
content="Tool result", tool_call_id="call_123", status="success"
42+
)
43+
history.add_message(tool_message)
44+
45+
# Verify that index.load was called
46+
assert mock_index_instance.load.called
47+
48+
# Get the data that was passed to index.load
49+
call_args = mock_index_instance.load.call_args
50+
data = call_args[1]["data"][0]
51+
52+
# Verify tool_call_id is in the stored data
53+
assert "tool_call_id" in data["data"], "tool_call_id must be stored"
54+
assert data["data"]["tool_call_id"] == "call_123"
55+
assert "status" in data["data"], "status must be stored"
56+
assert data["data"]["status"] == "success"
57+
58+
@patch("langchain_redis.chat_message_history.SearchIndex")
59+
def test_add_regular_messages_without_tool_call_id(
60+
self, mock_search_index: MagicMock
61+
) -> None:
62+
"""Test that regular messages don't have tool_call_id added."""
63+
mock_redis_client = MagicMock()
64+
mock_redis_client.client_setinfo = MagicMock()
65+
mock_redis_client.ft = MagicMock()
66+
mock_index_instance = MagicMock()
67+
mock_search_index.from_dict.return_value = mock_index_instance
68+
69+
history = RedisChatMessageHistory(
70+
session_id="test_session",
71+
redis_client=mock_redis_client,
72+
)
73+
74+
# Add various message types
75+
history.add_message(HumanMessage(content="Hello"))
76+
call_args = mock_index_instance.load.call_args
77+
data = call_args[1]["data"][0]
78+
assert "tool_call_id" not in data["data"]
79+
80+
history.add_message(AIMessage(content="Hi there"))
81+
call_args = mock_index_instance.load.call_args
82+
data = call_args[1]["data"][0]
83+
assert "tool_call_id" not in data["data"]
84+
85+
history.add_message(SystemMessage(content="System message"))
86+
call_args = mock_index_instance.load.call_args
87+
data = call_args[1]["data"][0]
88+
assert "tool_call_id" not in data["data"]
89+
90+
@patch("langchain_redis.chat_message_history.SearchIndex")
91+
def test_retrieve_tool_message_without_key_error(
92+
self, mock_search_index: MagicMock
93+
) -> None:
94+
"""Test that retrieving ToolMessage doesn't raise KeyError.
95+
96+
This reproduces the original issue #51 where messages_from_dict
97+
would fail with KeyError: 'tool_call_id' when retrieving ToolMessages.
98+
"""
99+
mock_redis_client = MagicMock()
100+
mock_redis_client.client_setinfo = MagicMock()
101+
mock_redis_client.ft = MagicMock()
102+
mock_index_instance = MagicMock()
103+
mock_search_index.from_dict.return_value = mock_index_instance
104+
105+
# Mock the query to return a ToolMessage
106+
mock_index_instance.query.return_value = [
107+
{
108+
"type": "tool",
109+
"$.data": (
110+
'{"content": "Tool result", "additional_kwargs": {}, '
111+
'"type": "tool", "tool_call_id": "call_123", '
112+
'"status": "success"}'
113+
),
114+
}
115+
]
116+
117+
history = RedisChatMessageHistory(
118+
session_id="test_session",
119+
redis_client=mock_redis_client,
120+
)
121+
122+
# This should not raise KeyError
123+
messages = history.messages
124+
125+
assert len(messages) == 1
126+
assert isinstance(messages[0], ToolMessage)
127+
assert messages[0].content == "Tool result"
128+
assert messages[0].tool_call_id == "call_123"
129+
assert messages[0].status == "success"
130+
131+
@patch("langchain_redis.chat_message_history.SearchIndex")
132+
def test_round_trip_tool_message(self, mock_search_index: MagicMock) -> None:
133+
"""Test complete round-trip: add ToolMessage and retrieve it.
134+
135+
This simulates the real-world scenario from issue #51 where
136+
a ToolMessage is added in one session and retrieved in a follow-up.
137+
"""
138+
mock_redis_client = MagicMock()
139+
mock_redis_client.client_setinfo = MagicMock()
140+
mock_redis_client.ft = MagicMock()
141+
mock_index_instance = MagicMock()
142+
mock_search_index.from_dict.return_value = mock_index_instance
143+
144+
# Create stored data that will be captured
145+
stored_data = []
146+
147+
def capture_load(**kwargs: Any) -> None:
148+
stored_data.append(kwargs["data"][0])
149+
150+
mock_index_instance.load.side_effect = capture_load
151+
152+
history = RedisChatMessageHistory(
153+
session_id="test_session",
154+
redis_client=mock_redis_client,
155+
)
156+
157+
# Add a ToolMessage
158+
tool_msg = ToolMessage(
159+
content="Search results", tool_call_id="call_456", status="success"
160+
)
161+
history.add_message(tool_msg)
162+
163+
# Verify data was stored
164+
assert len(stored_data) == 1
165+
stored = stored_data[0]
166+
167+
# Now mock the query to return what was stored
168+
import json
169+
170+
mock_index_instance.query.return_value = [
171+
{"type": stored["type"], "$.data": json.dumps(stored["data"])}
172+
]
173+
174+
# Retrieve messages - should not raise KeyError
175+
messages = history.messages
176+
177+
assert len(messages) == 1
178+
assert isinstance(messages[0], ToolMessage)
179+
assert messages[0].content == "Search results"
180+
assert messages[0].tool_call_id == "call_456"
181+
assert messages[0].status == "success"
182+
183+
@patch("langchain_redis.chat_message_history.SearchIndex")
184+
def test_mixed_message_types_with_tool_message(
185+
self, mock_search_index: MagicMock
186+
) -> None:
187+
"""Test conversation with mixed message types including ToolMessage."""
188+
mock_redis_client = MagicMock()
189+
mock_redis_client.client_setinfo = MagicMock()
190+
mock_redis_client.ft = MagicMock()
191+
mock_index_instance = MagicMock()
192+
mock_search_index.from_dict.return_value = mock_index_instance
193+
194+
stored_messages = []
195+
196+
def capture_load(**kwargs: Any) -> None:
197+
stored_messages.append(kwargs["data"][0])
198+
199+
mock_index_instance.load.side_effect = capture_load
200+
201+
history = RedisChatMessageHistory(
202+
session_id="test_session",
203+
redis_client=mock_redis_client,
204+
)
205+
206+
# Add a conversation with tool calls
207+
history.add_message(HumanMessage(content="Search for Python tutorials"))
208+
history.add_message(AIMessage(content="I'll search for that"))
209+
history.add_message(
210+
ToolMessage(
211+
content="Found 10 tutorials", tool_call_id="call_789", status="success"
212+
)
213+
)
214+
history.add_message(AIMessage(content="Here are the tutorials I found..."))
215+
216+
# Verify all messages were stored
217+
assert len(stored_messages) == 4
218+
219+
# Verify the ToolMessage has tool_call_id
220+
tool_msg_data = stored_messages[2]
221+
assert tool_msg_data["type"] == "tool"
222+
assert "tool_call_id" in tool_msg_data["data"]
223+
assert tool_msg_data["data"]["tool_call_id"] == "call_789"
224+
assert tool_msg_data["data"]["status"] == "success"
225+
226+
# Mock query to return all messages
227+
import json
228+
229+
mock_index_instance.query.return_value = [
230+
{"type": msg["type"], "$.data": json.dumps(msg["data"])}
231+
for msg in stored_messages
232+
]
233+
234+
# Retrieve all messages - should not raise KeyError
235+
messages = history.messages
236+
237+
assert len(messages) == 4
238+
assert isinstance(messages[0], HumanMessage)
239+
assert isinstance(messages[1], AIMessage)
240+
assert isinstance(messages[2], ToolMessage)
241+
assert isinstance(messages[3], AIMessage)
242+
assert messages[2].tool_call_id == "call_789"

0 commit comments

Comments
 (0)