cohere
__all__ = ['CoherePromptDriver']
module-attribute
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | str | Cohere API key. |
model | str | Cohere model name. |
client | ClientV2 | Custom cohere.Client . |
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@define(kw_only=True) class CoherePromptDriver(BasePromptDriver): """Cohere Prompt Driver. Attributes: api_key: Cohere API key. model: Cohere model name. client: Custom `cohere.Client`. """ api_key: str = field(default=None, metadata={"serializable": False}) model: str = field(metadata={"serializable": True}) force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _client: Optional[ClientV2] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) tokenizer: BaseTokenizer = field( default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), ) @lazy_property() def client(self) -> ClientV2: return import_optional_dependency("cohere").ClientV2(self.api_key, log_warning_experimental_features=False) @observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) result: ChatResponse = self.client.chat(**params) logger.debug(result.model_dump()) return Message( content=self.__to_prompt_stack_message_content(result.message), role=Message.ASSISTANT_ROLE, usage=Message.Usage( input_tokens=result.usage.tokens.input_tokens if result.usage and result.usage.tokens else 0, output_tokens=result.usage.tokens.output_tokens if result.usage and result.usage.tokens else 0, ), ) @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: params = self._base_params(prompt_stack) logger.debug(params) result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params) for event in result: if event.type == "stream-end": usage = event.response.meta.tokens yield DeltaMessage( usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"): yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event)) def _base_params(self, prompt_stack: PromptStack) -> dict: tool_results = [] messages = self.__to_cohere_messages(prompt_stack.messages) params = { "model": self.model, "messages": messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), **self.extra_params, } if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()} if prompt_stack.tools and self.use_native_tools: params["tools"] = self.__to_cohere_tools(prompt_stack.tools) return params def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] for message in messages: # If the message only contains textual content we can send it as a single content. if message.is_text(): cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()}) # Action results must be sent as separate messages. elif message.has_any_content_type(ActionResultMessageContent): cohere_messages.extend( { "role": self.__to_cohere_role(message, action_result), "content": self.__to_cohere_message_content(action_result), "tool_call_id": action_result.action.tag, } for action_result in message.get_content_type(ActionResultMessageContent) ) if message.has_any_content_type(TextMessageContent): cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()}) else: cohere_message = { "role": self.__to_cohere_role(message), "content": [ self.__to_cohere_message_content(content) for content in [ content for content in message.content if not isinstance(content, ActionCallMessageContent) ] ], } # Action calls must be attached to the message, not sent as content. action_call_content = [ content for content in message.content if isinstance(content, ActionCallMessageContent) ] if action_call_content: cohere_message["tool_calls"] = [ self.__to_cohere_message_content(action_call) for action_call in action_call_content ] cohere_messages.append(cohere_message) return cohere_messages def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]: if isinstance(content, ActionCallMessageContent): action = content.artifact.value return { "type": "function", "id": action.tag, "function": { "name": action.to_native_tool_name(), "arguments": json.dumps(action.input), }, } if isinstance(content, ActionResultMessageContent): artifact = content.artifact if isinstance(artifact, ListArtifact): message_content = [{"type": "text", "text": artifact.to_text()} for artifact in artifact.value] else: message_content = {"type": "text", "text": artifact.to_text()} return message_content return {"type": "text", "text": content.artifact.to_text()} def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str: if message.is_system(): return "system" if message.is_assistant(): return "assistant" if isinstance(message_content, ActionResultMessageContent): return "tool" return "user" def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]: return [ { "function": { "name": tool.to_native_tool_name(activity), "description": tool.activity_description(activity), "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"), }, "type": "function", } for tool in tools for activity in tool.activities() ] def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]: content = [] if response.content: content.extend([TextMessageContent(TextArtifact(content.text)) for content in response.content]) if response.tool_plan: content.append(TextMessageContent(TextArtifact(response.tool_plan))) if response.tool_calls is not None: content.extend( [ ActionCallMessageContent( ActionArtifact( ToolAction( tag=tool_call.id, name=ToolAction.from_native_tool_name(tool_call.function.name)[0], path=ToolAction.from_native_tool_name(tool_call.function.name)[1], input=json.loads(tool_call.function.arguments), ), ), ) for tool_call in response.tool_calls if tool_call.id is not None and tool_call.function is not None and tool_call.function.name is not None and tool_call.function.arguments is not None ], ) return content def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent: if event.type == "content-delta": return TextDeltaMessageContent(event.delta.message.content.text, index=0) if event.type == "tool-plan-delta": return TextDeltaMessageContent(event.delta.message["tool_plan"]) if event.type == "tool-call-start": tool_call_delta = event.delta.message["tool_calls"] name, path = ToolAction.from_native_tool_name(tool_call_delta["function"]["name"]) return ActionCallDeltaMessageContent(tag=tool_call_delta["id"], name=name, path=path) if event.type == "tool-call-delta": tool_call_delta = event.delta.message["tool_calls"]["function"] return ActionCallDeltaMessageContent(partial_input=tool_call_delta["arguments"]) raise ValueError(f"Unsupported event type: {event.type}")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attributeapi_key = field(default=None, metadata={'serializable': False})
class-attribute instance-attributeforce_single_step = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attributemodel = field(metadata={'serializable': True})
class-attribute instance-attributetokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True))
class-attribute instance-attributeuse_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
__to_cohere_message_content(content)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]: if isinstance(content, ActionCallMessageContent): action = content.artifact.value return { "type": "function", "id": action.tag, "function": { "name": action.to_native_tool_name(), "arguments": json.dumps(action.input), }, } if isinstance(content, ActionResultMessageContent): artifact = content.artifact if isinstance(artifact, ListArtifact): message_content = [{"type": "text", "text": artifact.to_text()} for artifact in artifact.value] else: message_content = {"type": "text", "text": artifact.to_text()} return message_content return {"type": "text", "text": content.artifact.to_text()}
__to_cohere_messages(messages)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_messages(self, messages: list[Message]) -> list[dict]: cohere_messages = [] for message in messages: # If the message only contains textual content we can send it as a single content. if message.is_text(): cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()}) # Action results must be sent as separate messages. elif message.has_any_content_type(ActionResultMessageContent): cohere_messages.extend( { "role": self.__to_cohere_role(message, action_result), "content": self.__to_cohere_message_content(action_result), "tool_call_id": action_result.action.tag, } for action_result in message.get_content_type(ActionResultMessageContent) ) if message.has_any_content_type(TextMessageContent): cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()}) else: cohere_message = { "role": self.__to_cohere_role(message), "content": [ self.__to_cohere_message_content(content) for content in [ content for content in message.content if not isinstance(content, ActionCallMessageContent) ] ], } # Action calls must be attached to the message, not sent as content. action_call_content = [ content for content in message.content if isinstance(content, ActionCallMessageContent) ] if action_call_content: cohere_message["tool_calls"] = [ self.__to_cohere_message_content(action_call) for action_call in action_call_content ] cohere_messages.append(cohere_message) return cohere_messages
__to_cohere_role(message, message_content=None)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str: if message.is_system(): return "system" if message.is_assistant(): return "assistant" if isinstance(message_content, ActionResultMessageContent): return "tool" return "user"
__to_cohere_tools(tools)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]: return [ { "function": { "name": tool.to_native_tool_name(activity), "description": tool.activity_description(activity), "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"), }, "type": "function", } for tool in tools for activity in tool.activities() ]
__to_prompt_stack_delta_message_content(event)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent: if event.type == "content-delta": return TextDeltaMessageContent(event.delta.message.content.text, index=0) if event.type == "tool-plan-delta": return TextDeltaMessageContent(event.delta.message["tool_plan"]) if event.type == "tool-call-start": tool_call_delta = event.delta.message["tool_calls"] name, path = ToolAction.from_native_tool_name(tool_call_delta["function"]["name"]) return ActionCallDeltaMessageContent(tag=tool_call_delta["id"], name=name, path=path) if event.type == "tool-call-delta": tool_call_delta = event.delta.message["tool_calls"]["function"] return ActionCallDeltaMessageContent(partial_input=tool_call_delta["arguments"]) raise ValueError(f"Unsupported event type: {event.type}")
__to_prompt_stack_message_content(response)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]: content = [] if response.content: content.extend([TextMessageContent(TextArtifact(content.text)) for content in response.content]) if response.tool_plan: content.append(TextMessageContent(TextArtifact(response.tool_plan))) if response.tool_calls is not None: content.extend( [ ActionCallMessageContent( ActionArtifact( ToolAction( tag=tool_call.id, name=ToolAction.from_native_tool_name(tool_call.function.name)[0], path=ToolAction.from_native_tool_name(tool_call.function.name)[1], input=json.loads(tool_call.function.arguments), ), ), ) for tool_call in response.tool_calls if tool_call.id is not None and tool_call.function is not None and tool_call.function.name is not None and tool_call.function.arguments is not None ], ) return content
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict: tool_results = [] messages = self.__to_cohere_messages(prompt_stack.messages) params = { "model": self.model, "messages": messages, "temperature": self.temperature, "stop_sequences": self.tokenizer.stop_sequences, "max_tokens": self.max_tokens, **({"tool_results": tool_results} if tool_results else {}), **self.extra_params, } if prompt_stack.output_schema is not None and self.structured_output_strategy == "native": params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()} if prompt_stack.tools and self.use_native_tools: params["tools"] = self.__to_cohere_tools(prompt_stack.tools) return params
client()
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@lazy_property() def client(self) -> ClientV2: return import_optional_dependency("cohere").ClientV2(self.api_key, log_warning_experimental_features=False)
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable def try_run(self, prompt_stack: PromptStack) -> Message: params = self._base_params(prompt_stack) logger.debug(params) result: ChatResponse = self.client.chat(**params) logger.debug(result.model_dump()) return Message( content=self.__to_prompt_stack_message_content(result.message), role=Message.ASSISTANT_ROLE, usage=Message.Usage( input_tokens=result.usage.tokens.input_tokens if result.usage and result.usage.tokens else 0, output_tokens=result.usage.tokens.output_tokens if result.usage and result.usage.tokens else 0, ), )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: params = self._base_params(prompt_stack) logger.debug(params) result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params) for event in result: if event.type == "stream-end": usage = event.response.meta.tokens yield DeltaMessage( usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens), ) elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"): yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
- On this page
- Attributes
- __to_cohere_message_content(content)
- __to_cohere_messages(messages)
- __to_cohere_role(message, message_content=None)
- __to_cohere_tools(tools)
- __to_prompt_stack_delta_message_content(event)
- __to_prompt_stack_message_content(response)
- _base_params(prompt_stack)
- client()
- try_run(prompt_stack)
- try_stream(prompt_stack)
Could this page be better? Report a problem or suggest an addition!