openai_assistant_driver
Bases:
BaseAssistantDriver
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
@define
class OpenAiAssistantDriver(BaseAssistantDriver):
    """Assistant driver backed by an OpenAI Assistant (Assistants beta API).

    Each run posts the input artifacts as one user message to a thread, streams the
    assistant's response through ``event_handler``, and returns the text of the run's
    final messages as a ``TextArtifact``.
    """

    class EventHandler(AssistantEventHandler):
        """Publishes streamed assistant output to the EventBus as ``TextChunkEvent``s."""

        @override
        def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
            if delta.value is not None:
                EventBus.publish_event(TextChunkEvent(token=delta.value))

        @override
        def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
            # Only code-interpreter tool calls are echoed; other tool-call types are ignored.
            if delta.type == "code_interpreter" and delta.code_interpreter is not None:
                if delta.code_interpreter.input:
                    EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
                if delta.code_interpreter.outputs:
                    EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
                    for output in delta.code_interpreter.outputs:
                        # Only textual log outputs are forwarded.
                        if output.type == "logs" and output.logs:
                            EventBus.publish_event(TextChunkEvent(token=output.logs))

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Thread to run in; filled in by try_run when absent and auto_create_thread is True.
    thread_id: Optional[str] = field(default=None, kw_only=True)
    assistant_id: str = field(kw_only=True)
    # NOTE(review): this one handler instance is passed to every run's stream; recent
    # openai SDK versions refuse to attach a single AssistantEventHandler to more than
    # one stream — confirm repeated runs on one driver work with the pinned SDK.
    event_handler: AssistantEventHandler = field(
        default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={"serializable": False}
    )
    auto_create_thread: bool = field(default=True, kw_only=True)
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """OpenAI SDK client built from ``base_url``, ``api_key`` and ``organization``."""
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        """Run the assistant on ``args`` in this driver's thread.

        Raises:
            ValueError: If ``thread_id`` is unset and ``auto_create_thread`` is False.
        """
        if self.thread_id is None:
            if self.auto_create_thread:
                thread_id = self.client.beta.threads.create().id
                # Store it so later calls reuse the same thread.
                self.thread_id = thread_id
            else:
                raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
        else:
            thread_id = self.thread_id

        response = self._create_run(thread_id, *args)
        response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})

        return response

    def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
        """Post ``args`` as one user message to ``thread_id``, stream a run, and collect its text.

        Args:
            thread_id: Thread to post the message to and run the assistant in.
            *args: Artifacts whose ``value``s are newline-joined into the user message.

        Returns:
            A TextArtifact with the text of the run's final messages (one line group per
            message) and ``assistant_id``/``thread_id``/``message_id`` in ``meta``.
        """
        content = "\n".join(arg.value for arg in args)
        # messages.create returns the full Message object; keep only its id for meta.
        message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content).id
        with self.client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.event_handler,
        ) as stream:
            stream.until_done()
            final_messages = stream.get_final_messages()

        # Text blocks report type "text" (TextContentBlock is only the class name);
        # non-text blocks such as images are skipped.
        message_text = "\n".join(
            "".join(block.text.value for block in message.content if block.type == "text")
            for message in final_messages
        )

        response = TextArtifact(message_text)
        response.meta.update({"assistant_id": self.assistant_id, "thread_id": thread_id, "message_id": message_id})

        return response
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute
assistant_id = field(kw_only=True)
class-attribute instance-attribute
auto_create_thread = field(default=True, kw_only=True)
class-attribute instance-attribute
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
event_handler = field(default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute
organization = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
thread_id = field(default=None, kw_only=True)
class-attribute instance-attribute
EventHandler
Bases:
AssistantEventHandler
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
class EventHandler(AssistantEventHandler):
    """Relay streamed assistant output to the EventBus as ``TextChunkEvent``s."""

    @override
    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Publish each text fragment whose value is not None."""
        fragment = delta.value
        if fragment is None:
            return
        EventBus.publish_event(TextChunkEvent(token=fragment))

    @override
    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Publish code-interpreter input, then an output marker and any log outputs."""
        # Other tool-call types (and deltas without interpreter data) are ignored.
        if delta.type != "code_interpreter":
            return
        interpreter = delta.code_interpreter
        if interpreter is None:
            return

        if interpreter.input:
            EventBus.publish_event(TextChunkEvent(token=interpreter.input))

        if not interpreter.outputs:
            return
        EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
        for item in interpreter.outputs:
            # Only textual log outputs are forwarded.
            if item.type == "logs" and item.logs:
                EventBus.publish_event(TextChunkEvent(token=item.logs))
_create_run(thread_id, *args)
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
    """Post ``args`` as one user message to ``thread_id``, stream a run, and collect its text.

    Args:
        thread_id: Thread to post the message to and run the assistant in.
        *args: Artifacts whose ``value``s are newline-joined into the user message.

    Returns:
        A TextArtifact with the text of the run's final messages (joined by newlines)
        and ``assistant_id``/``thread_id``/``message_id`` recorded in ``meta``.
    """
    content = "\n".join(arg.value for arg in args)
    # messages.create returns the full Message object; keep only its id for meta.
    message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content).id
    # NOTE(review): self.event_handler is the same instance on every call; recent openai
    # SDK versions refuse to attach one AssistantEventHandler to multiple streams — confirm.
    with self.client.beta.threads.runs.stream(
        thread_id=thread_id,
        assistant_id=self.assistant_id,
        event_handler=self.event_handler,
    ) as stream:
        stream.until_done()
        final_messages = stream.get_final_messages()

    # Text blocks report type "text" (TextContentBlock is only the class name);
    # non-text blocks such as images are skipped.
    message_text = "\n".join(
        "".join(block.text.value for block in message.content if block.type == "text")
        for message in final_messages
    )

    response = TextArtifact(message_text)
    # Record the thread this run actually used, not whatever self.thread_id holds.
    response.meta.update({"assistant_id": self.assistant_id, "thread_id": thread_id, "message_id": message_id})

    return response
client()
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Lazily construct the OpenAI SDK client from this driver's connection settings."""
    connection = {
        "base_url": self.base_url,
        "api_key": self.api_key,
        "organization": self.organization,
    }
    return openai.OpenAI(**connection)
try_run(*args)
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    """Run the assistant on ``args`` in this driver's thread, creating one if permitted.

    Args:
        *args: Input artifacts forwarded unchanged to ``_create_run``.

    Returns:
        The TextArtifact from ``_create_run`` with ``assistant_id`` and ``thread_id`` in ``meta``.

    Raises:
        ValueError: If ``thread_id`` is unset and ``auto_create_thread`` is False.
    """
    active_thread = self.thread_id
    if active_thread is None:
        if not self.auto_create_thread:
            raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
        active_thread = self.client.beta.threads.create().id
        # Store the new id so later calls reuse this thread.
        self.thread_id = active_thread

    result = self._create_run(active_thread, *args)
    result.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})
    return result
Could this page be better? Report a problem or suggest an addition!