__all__ = ['GooglePromptDriver']
module-attribute
Bases: BasePromptDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| api_key | Optional[str] | Google API key. |
| model | str | Google model name. |
| client | GenerativeModel | Custom GenerativeModel client. |
| top_p | Optional[float] | Optional value for top_p. |
| top_k | Optional[int] | Optional value for top_k. |
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
@define
class GooglePromptDriver(BasePromptDriver):
    """Google Prompt Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        client: Custom `GenerativeModel` client.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
    _client: Optional[GenerativeModel] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def client(self) -> GenerativeModel:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        return genai.GenerativeModel(self.model)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        messages = self.__to_google_messages(prompt_stack)
        params = self._base_params(prompt_stack)
        logger.debug((messages, params["generation_config"].__dict__))
        response: GenerateContentResponse = self.client.generate_content(messages, **params)
        logger.debug(response.to_dict())

        usage_metadata = response.usage_metadata

        return Message(
            content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=usage_metadata.prompt_token_count,
                output_tokens=usage_metadata.candidates_token_count,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        messages = self.__to_google_messages(prompt_stack)
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug((messages, params))
        response: GenerateContentResponse = self.client.generate_content(
            messages,
            **params,
        )

        prompt_token_count = None
        for chunk in response:
            logger.debug(chunk.to_dict())
            usage_metadata = chunk.usage_metadata

            content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
            # Only want to output the prompt token count once since it is static each chunk
            if prompt_token_count is None:
                prompt_token_count = usage_metadata.prompt_token_count
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(
                        input_tokens=usage_metadata.prompt_token_count,
                        output_tokens=usage_metadata.candidates_token_count,
                    ),
                )
            else:
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
                )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        system_messages = prompt_stack.system_messages
        if system_messages:
            self.client._system_instruction = types.ContentDict(
                role="system",
                parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
            )

        params = {
            "generation_config": types.GenerationConfig(
                **{
                    # For some reason, providing stop sequences when streaming breaks native functions
                    # https://github.com/google-gemini/generative-ai-python/issues/446
                    "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                    "max_output_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": self.top_k,
                    **self.extra_params,
                },
            ),
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_config"]["function_calling_config"]["mode"] = "auto"

            params["tools"] = self.__to_google_tools(prompt_stack.tools)

        return params

    def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
        types = import_optional_dependency("google.generativeai.types")

        return [
            types.ContentDict(
                {
                    "role": self.__to_google_role(message),
                    "parts": [self.__to_google_message_content(content) for content in message.content],
                },
            )
            for message in prompt_stack.messages
            if not message.is_system()
        ]

    def __to_google_role(self, message: Message) -> str:
        if message.is_assistant():
            return "model"
        return "user"

    def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
        types = import_optional_dependency("google.generativeai.types")

        tool_declarations = []
        for tool in tools:
            for activity in tool.activities():
                schema = tool.to_activity_json_schema(activity, "Parameters Schema")

                if "values" in schema["properties"]:
                    schema = schema["properties"]["values"]

                schema = remove_key_in_dict_recursively(schema, "additionalProperties")
                tool_declaration = types.FunctionDeclaration(
                    name=tool.to_native_tool_name(activity),
                    description=tool.activity_description(activity),
                    **(
                        {
                            "parameters": {
                                "type": schema["type"],
                                "properties": schema["properties"],
                                "required": schema.get("required", []),
                            }
                        }
                        if schema.get("properties")
                        else {}
                    ),
                )

                tool_declarations.append(tool_declaration)

        return tool_declarations

    def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
            # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
            # Can be worked around by using GenericMessageContent, similar to videos.
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input))
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            return protos.Part(
                function_response=protos.FunctionResponse(
                    name=content.action.to_native_tool_name(),
                    response=artifact.to_dict(),
                ),
            )
        if isinstance(content, GenericMessageContent):
            return content.artifact.value
        raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

    def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextMessageContent(TextArtifact(content.text))
        if content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallMessageContent(
                artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)),
            )
        raise ValueError(f"Unsupported message content type {content}")

    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextDeltaMessageContent(content.text)
        if content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallDeltaMessageContent(
                tag=function_call.name,
                name=name,
                path=path,
                partial_input=json.dumps(args),
            )
        raise ValueError(f"Unsupported message content type {content}")
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute

api_key = field(default=None, kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute

model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True)
class-attribute instance-attribute

tool_choice = field(default='auto', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

top_k = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

top_p = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
__to_google_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()
    if isinstance(content, ImageMessageContent):
        if isinstance(content.artifact, ImageArtifact):
            return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
        # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
        # Can be worked around by using GenericMessageContent, similar to videos.
        raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input))
    if isinstance(content, ActionResultMessageContent):
        artifact = content.artifact

        return protos.Part(
            function_response=protos.FunctionResponse(
                name=content.action.to_native_tool_name(),
                response=artifact.to_dict(),
            ),
        )
    if isinstance(content, GenericMessageContent):
        return content.artifact.value
    raise ValueError(f"Unsupported prompt stack content type: {type(content)}")
```
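As the TODO above notes, media that cannot be inlined can be routed through `GenericMessageContent`, whose artifact value is forwarded to Gemini untouched. A hedged sketch of preparing such an artifact; the file path is a placeholder and `genai.upload_file` is the SDK's files-endpoint helper:

```python
import google.generativeai as genai

from griptape.artifacts import GenericArtifact

genai.configure(api_key="...")  # elided; supply your own key

# upload_file sends the media to Google's files endpoint and returns a File
# handle Gemini accepts directly as a content part. Wrapped in a
# GenericArtifact, it passes through GenericMessageContent as-is.
video_file = genai.upload_file(path="clip.mp4")
video_artifact = GenericArtifact(video_file)
```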
__to_google_messages(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
    types = import_optional_dependency("google.generativeai.types")

    return [
        types.ContentDict(
            {
                "role": self.__to_google_role(message),
                "parts": [self.__to_google_message_content(content) for content in message.content],
            },
        )
        for message in prompt_stack.messages
        if not message.is_system()
    ]
```
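System messages are excluded here because `_base_params` attaches them through the client's system instruction. A hedged sketch of the structure this produces for a short user/assistant history:

```python
# Approximate output for a two-message history; each element is a
# google.generativeai types.ContentDict. Text parts are plain strings
# (see __to_google_message_content).
[
    {"role": "user", "parts": ["What is 2 + 2?"]},
    {"role": "model", "parts": ["2 + 2 = 4."]},  # assistant role maps to "model"
]
```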
__to_google_role(message)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_google_role(self, message: Message) -> str:
    if message.is_assistant():
        return "model"
    return "user"
```
__to_google_tools(tools)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
    types = import_optional_dependency("google.generativeai.types")

    tool_declarations = []
    for tool in tools:
        for activity in tool.activities():
            schema = tool.to_activity_json_schema(activity, "Parameters Schema")

            if "values" in schema["properties"]:
                schema = schema["properties"]["values"]

            schema = remove_key_in_dict_recursively(schema, "additionalProperties")
            tool_declaration = types.FunctionDeclaration(
                name=tool.to_native_tool_name(activity),
                description=tool.activity_description(activity),
                **(
                    {
                        "parameters": {
                            "type": schema["type"],
                            "properties": schema["properties"],
                            "required": schema.get("required", []),
                        }
                    }
                    if schema.get("properties")
                    else {}
                ),
            )

            tool_declarations.append(tool_declaration)

    return tool_declarations
```
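Griptape activity schemas nest their parameters under a top-level `values` key, which this method unwraps; `additionalProperties` is stripped, apparently because Gemini's function-declaration schema subset rejects it. A sketch with a hypothetical schema:

```python
# Hypothetical schema, shaped like tool.to_activity_json_schema output.
schema = {
    "type": "object",
    "properties": {
        "values": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
            "additionalProperties": False,
        }
    },
}

# Unwrap the "values" envelope, as the method does.
if "values" in schema["properties"]:
    schema = schema["properties"]["values"]

# After stripping additionalProperties, the declaration's parameters are:
parameters = {
    "type": schema["type"],                  # "object"
    "properties": schema["properties"],      # {"query": {"type": "string"}}
    "required": schema.get("required", []),  # ["query"]
}
```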
__to_prompt_stack_delta_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextDeltaMessageContent(content.text)
    if content.function_call:
        function_call = content.function_call

        name, path = ToolAction.from_native_tool_name(function_call.name)

        args = json_format.MessageToDict(function_call._pb).get("args", {})
        return ActionCallDeltaMessageContent(
            tag=function_call.name,
            name=name,
            path=path,
            partial_input=json.dumps(args),
        )
    raise ValueError(f"Unsupported message content type {content}")
```
__to_prompt_stack_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextMessageContent(TextArtifact(content.text))
    if content.function_call:
        function_call = content.function_call

        name, path = ToolAction.from_native_tool_name(function_call.name)

        args = json_format.MessageToDict(function_call._pb).get("args", {})
        return ActionCallMessageContent(
            artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)),
        )
    raise ValueError(f"Unsupported message content type {content}")
```
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    system_messages = prompt_stack.system_messages
    if system_messages:
        self.client._system_instruction = types.ContentDict(
            role="system",
            parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
        )

    params = {
        "generation_config": types.GenerationConfig(
            **{
                # For some reason, providing stop sequences when streaming breaks native functions
                # https://github.com/google-gemini/generative-ai-python/issues/446
                "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                "max_output_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                **self.extra_params,
            },
        ),
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            params["tool_config"]["function_calling_config"]["mode"] = "auto"

        params["tools"] = self.__to_google_tools(prompt_stack.tools)

    return params
```
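Two details worth noting: stop sequences are dropped entirely when streaming with native tools enabled (per the linked upstream issue), and `extra_params` is spread last into `GenerationConfig`, so it can add or override generation settings. A hedged construction sketch; `candidate_count` is an existing `GenerationConfig` field, and `max_tokens` and `temperature` come from `BasePromptDriver`:

```python
from griptape.drivers.prompt.google_prompt_driver import GooglePromptDriver

driver = GooglePromptDriver(
    model="gemini-1.5-flash",  # placeholder model name
    max_tokens=512,
    temperature=0.2,
    # Merged last into GenerationConfig, so these extend or override defaults.
    extra_params={"candidate_count": 1},
)
```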
client()
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
@lazy_property()
def client(self) -> GenerativeModel:
    genai = import_optional_dependency("google.generativeai")
    genai.configure(api_key=self.api_key)

    return genai.GenerativeModel(self.model)
```
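Because `_client` is aliased to `client`, a preconfigured `GenerativeModel` can be passed at construction time and this lazy property returns it instead of building a new one. A sketch, assuming the `google-generativeai` package is installed:

```python
import google.generativeai as genai

from griptape.drivers.prompt.google_prompt_driver import GooglePromptDriver

genai.configure(api_key="...")  # elided; supply your own key
custom_model = genai.GenerativeModel("gemini-1.5-flash")  # placeholder model name

# The `client` alias populates _client, so the lazy property skips creation.
driver = GooglePromptDriver(model="gemini-1.5-flash", client=custom_model)
```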
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    messages = self.__to_google_messages(prompt_stack)
    params = self._base_params(prompt_stack)
    logger.debug((messages, params["generation_config"].__dict__))
    response: GenerateContentResponse = self.client.generate_content(messages, **params)
    logger.debug(response.to_dict())

    usage_metadata = response.usage_metadata

    return Message(
        content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(
            input_tokens=usage_metadata.prompt_token_count,
            output_tokens=usage_metadata.candidates_token_count,
        ),
    )
```
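The returned `Message` carries one content element per response part: text parts become `TextMessageContent`, function calls become `ActionCallMessageContent`. A hedged inspection sketch, reusing the `driver` and `prompt_stack` from the earlier example:

```python
message = driver.try_run(prompt_stack)

for content in message.content:
    # TextMessageContent wraps a TextArtifact; ActionCallMessageContent wraps
    # an ActionArtifact holding a ToolAction (tag, name, path, input).
    print(type(content).__name__, content.artifact.value)

print(message.usage.input_tokens, message.usage.output_tokens)
```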
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    messages = self.__to_google_messages(prompt_stack)
    params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug((messages, params))
    response: GenerateContentResponse = self.client.generate_content(
        messages,
        **params,
    )

    prompt_token_count = None
    for chunk in response:
        logger.debug(chunk.to_dict())
        usage_metadata = chunk.usage_metadata

        content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
        # Only want to output the prompt token count once since it is static each chunk
        if prompt_token_count is None:
            prompt_token_count = usage_metadata.prompt_token_count
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(
                    input_tokens=usage_metadata.prompt_token_count,
                    output_tokens=usage_metadata.candidates_token_count,
                ),
            )
        else:
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
            )
```
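Only the first yielded `DeltaMessage` includes `input_tokens`; later chunks report only `output_tokens`. A hedged consumption sketch (the `TextDeltaMessageContent` import path is assumed from griptape's common module):

```python
from griptape.common import TextDeltaMessageContent

for delta in driver.try_stream(prompt_stack):  # driver/prompt_stack as above
    if isinstance(delta.content, TextDeltaMessageContent):
        print(delta.content.text, end="", flush=True)
    if delta.usage is not None and delta.usage.input_tokens is not None:
        prompt_tokens = delta.usage.input_tokens  # present only on the first chunk
```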
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
```python
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value == "native":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    return value
```
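Since attrs validators run during `__init__`, passing `native` fails at construction time; `tool` (the default) is the supported strategy:

```python
from griptape.drivers.prompt.google_prompt_driver import GooglePromptDriver

try:
    GooglePromptDriver(model="gemini-1.5-flash", structured_output_strategy="native")
except ValueError as error:
    print(error)  # GooglePromptDriver does not support `native` structured output strategy.
```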