openai[patch]: default to invoke on o1 stream() (#27983)
baskaryan authored Nov 9, 2024
1 parent 503f248 commit 33dbfba
Showing 2 changed files with 46 additions and 36 deletions.
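In short: for o1 models (and for a Pydantic response_format), stream() and astream() now fall back to a single blocking call instead of hitting the unsupported streaming API. A rough usage sketch of the intended effect, assuming a configured OpenAI API key; the single-chunk behavior is inferred from the commit title and the new _should_stream() override in the diff below, not asserted by it:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="o1-mini")
# o1 models do not support streaming yet; with this change stream() quietly
# falls back to a blocking call and yields the full response as one chunk
# instead of raising an error from the API.
chunks = list(llm.stream("how are you"))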
58 changes: 22 additions & 36 deletions libs/partners/openai/langchain_openai/chat_models/base.py
@@ -632,24 +632,6 @@ def _stream(
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}

-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
         if self.include_response_headers:
             raw_response = self.client.with_raw_response.create(**payload)
             response = raw_response.parse()
@@ -783,24 +765,6 @@ async def _astream(
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
         base_generation_info = {}
-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = await self._agenerate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
-            )
-            return
         if self.include_response_headers:
             raw_response = await self.async_client.with_raw_response.create(**payload)
             response = raw_response.parse()
@@ -999,6 +963,28 @@ def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
         num_tokens += 3
         return num_tokens

+    def _should_stream(
+        self,
+        *,
+        async_api: bool,
+        run_manager: Optional[
+            Union[CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun]
+        ] = None,
+        response_format: Optional[Union[dict, type]] = None,
+        **kwargs: Any,
+    ) -> bool:
+        if isinstance(response_format, type) and is_basemodel_subclass(response_format):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            return False
+        if self.model_name.startswith("o1"):
+            # TODO: Add support for streaming with o1 once supported.
+            return False
+
+        return super()._should_stream(
+            async_api=async_api, run_manager=run_manager, **kwargs
+        )
+
     @deprecated(
         since="0.2.1",
         alternative="langchain_openai.chat_models.base.ChatOpenAI.bind_tools",
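Conceptually, the streaming entry points consult _should_stream() before opening a streaming connection; when it returns False they fall back to a blocking call and yield the result as a single chunk. A minimal, self-contained sketch of that pattern (illustrative only, not the langchain-core BaseChatModel source; every name besides _should_stream, stream, and invoke is made up):

from typing import Any, Iterator


class FakeChatModel:
    """Toy model whose stream() defers to _should_stream(), mirroring the override above."""

    model_name = "o1-mini"

    def invoke(self, prompt: str, **kwargs: Any) -> str:
        # Stands in for a normal, non-streaming chat completion call.
        return f"full response to: {prompt}"

    def _should_stream(self, **kwargs: Any) -> bool:
        # The real override returns False for o1 models and for a Pydantic
        # response_format; everything else defers to the parent class.
        return not self.model_name.startswith("o1")

    def stream(self, prompt: str, **kwargs: Any) -> Iterator[str]:
        if not self._should_stream(**kwargs):
            # Fallback path: one blocking call, yielded as a single chunk.
            yield self.invoke(prompt, **kwargs)
            return
        # Real streaming would yield incremental chunks from the API here.
        for token in self.invoke(prompt, **kwargs).split():
            yield token


print(list(FakeChatModel().stream("how are you")))  # one element: the full response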
@@ -1061,3 +1061,27 @@ def test_prediction_tokens() -> None:
     ]
     assert output_token_details["accepted_prediction_tokens"] > 0
     assert output_token_details["rejected_prediction_tokens"] > 0
+
+
+def test_stream_o1() -> None:
+    list(ChatOpenAI(model="o1-mini").stream("how are you"))
+
+
+async def test_astream_o1() -> None:
+    async for _ in ChatOpenAI(model="o1-mini").astream("how are you"):
+        pass
+
+
+class Foo(BaseModel):
+    response: str
+
+
+def test_stream_response_format() -> None:
+    list(ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Foo))
+
+
+async def test_astream_response_format() -> None:
+    async for _ in ChatOpenAI(model="gpt-4o-mini").astream(
+        "how are ya", response_format=Foo
+    ):
+        pass
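For the response_format case, the warning-and-fallback now happens in _should_stream() rather than inside _stream()/_astream(). A hedged usage sketch; the additional_kwargs["parsed"] location is an assumption carried over from the code removed above, not something these tests assert:

from pydantic import BaseModel

from langchain_openai import ChatOpenAI


class Answer(BaseModel):  # hypothetical schema, mirroring Foo in the tests
    response: str


# Expected to warn "Streaming with Pydantic response_format not yet supported.",
# fall back to a non-streaming call, and yield the result as a single chunk.
chunks = list(
    ChatOpenAI(model="gpt-4o-mini").stream("how are ya", response_format=Answer)
)
parsed = chunks[0].additional_kwargs.get("parsed")  # assumed location of the Answer instance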
