drivers
__all__ = ['AmazonBedrockCohereEmbeddingDriver', 'AmazonBedrockImageGenerationDriver', 'AmazonBedrockPromptDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonDynamoDbConversationMemoryDriver', 'AmazonOpenSearchVectorStoreDriver', 'AmazonRedshiftSqlDriver', 'AmazonS3FileManagerDriver', 'AmazonSageMakerJumpstartEmbeddingDriver', 'AmazonSageMakerJumpstartPromptDriver', 'AmazonSqsEventListenerDriver', 'AnthropicPromptDriver', 'AstraDbVectorStoreDriver', 'AwsIotCoreEventListenerDriver', 'AzureMongoDbVectorStoreDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiEmbeddingDriver', 'AzureOpenAiImageGenerationDriver', 'AzureOpenAiTextToSpeechDriver', 'BaseAssistantDriver', 'BaseAudioTranscriptionDriver', 'BaseConversationMemoryDriver', 'BaseDiffusionImageGenerationPipelineDriver', 'BaseEmbeddingDriver', 'BaseEventListenerDriver', 'BaseFileManagerDriver', 'BaseImageGenerationDriver', 'BaseImageGenerationModelDriver', 'BaseMultiModelImageGenerationDriver', 'BaseObservabilityDriver', 'BasePromptDriver', 'BaseRerankDriver', 'BaseRulesetDriver', 'BaseSqlDriver', 'BaseStructureRunDriver', 'BaseTextToSpeechDriver', 'BaseVectorStoreDriver', 'BaseWebScraperDriver', 'BaseWebSearchDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'CohereEmbeddingDriver', 'CoherePromptDriver', 'CohereRerankDriver', 'DatadogObservabilityDriver', 'DuckDuckGoWebSearchDriver', 'DummyAudioTranscriptionDriver', 'DummyEmbeddingDriver', 'DummyImageGenerationDriver', 'DummyPromptDriver', 'DummyTextToSpeechDriver', 'DummyVectorStoreDriver', 'ElevenLabsTextToSpeechDriver', 'ExaWebSearchDriver', 'GoogleEmbeddingDriver', 'GooglePromptDriver', 'GoogleWebSearchDriver', 'GriptapeCloudAssistantDriver', 'GriptapeCloudConversationMemoryDriver', 'GriptapeCloudEventListenerDriver', 'GriptapeCloudFileManagerDriver', 'GriptapeCloudImageGenerationDriver', 'GriptapeCloudObservabilityDriver', 'GriptapeCloudPromptDriver', 'GriptapeCloudRulesetDriver', 'GriptapeCloudStructureRunDriver', 'GriptapeCloudVectorStoreDriver', 'GrokPromptDriver', 'HuggingFaceHubEmbeddingDriver', 'HuggingFaceHubPromptDriver', 'HuggingFacePipelineImageGenerationDriver', 'HuggingFacePipelinePromptDriver', 'LeonardoImageGenerationDriver', 'LocalConversationMemoryDriver', 'LocalFileManagerDriver', 'LocalRerankDriver', 'LocalRulesetDriver', 'LocalStructureRunDriver', 'LocalVectorStoreDriver', 'MarkdownifyWebScraperDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'NoOpObservabilityDriver', 'OllamaEmbeddingDriver', 'OllamaPromptDriver', 'OpenAiAssistantDriver', 'OpenAiAudioTranscriptionDriver', 'OpenAiChatPromptDriver', 'OpenAiEmbeddingDriver', 'OpenAiImageGenerationDriver', 'OpenAiTextToSpeechDriver', 'OpenSearchVectorStoreDriver', 'OpenTelemetryObservabilityDriver', 'PerplexityPromptDriver', 'PerplexityWebSearchDriver', 'PgAiKnowledgeBaseVectorStoreDriver', 'PgVectorVectorStoreDriver', 'PineconeVectorStoreDriver', 'ProxyWebScraperDriver', 'PusherEventListenerDriver', 'QdrantVectorStoreDriver', 'RedisConversationMemoryDriver', 'RedisVectorStoreDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'StableDiffusion3ControlNetImageGenerationPipelineDriver', 'StableDiffusion3ImageGenerationPipelineDriver', 'StableDiffusion3Img2ImgImageGenerationPipelineDriver', 'TavilyWebSearchDriver', 'TrafilaturaWebScraperDriver', 'VoyageAiEmbeddingDriver', 'WebhookEventListenerDriver']
AmazonBedrockCohereEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| model | str | Embedding model name. Defaults to DEFAULT_MODEL. |
| input_type | str | Defaults to `search_query`. Prepends special tokens to differentiate each type from one another: `search_document` when you encode documents for embeddings that you store in a vector database; `search_query` when querying your vector DB to find relevant documents. |
| session | Session | Optionally provide a custom `boto3.Session`. |
| tokenizer | BaseTokenizer | Optionally provide a custom `BedrockCohereTokenizer`. |
| client | BedrockRuntimeClient | Optionally provide a custom `bedrock-runtime` client. |
Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
```python
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Cohere Embedding Driver.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide custom `BedrockCohereTokenizer`.
        client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        return self.session.client("bedrock-runtime")

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        payload = {"input_type": self.input_type, "texts": [chunk]}

        response = self.client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="*/*",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        return response_body.get("embeddings")[0]
```
DEFAULT_MODEL = 'cohere.embed-english-v3'
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
input_type = field(default='search_query', kw_only=True)
model = field(default=DEFAULT_MODEL, kw_only=True)
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True)
client()
Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
```python
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    return self.session.client("bedrock-runtime")
```
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
```python
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    payload = {"input_type": self.input_type, "texts": [chunk]}

    response = self.client.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        accept="*/*",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())

    return response_body.get("embeddings")[0]
```
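For reference, a minimal usage sketch (not part of the library source): it assumes AWS credentials are available to the default `boto3.Session` and that the `cohere.embed-english-v3` model is enabled in your Bedrock account.

```python
from griptape.drivers import AmazonBedrockCohereEmbeddingDriver

# Use `search_document` when indexing documents, `search_query` when querying.
driver = AmazonBedrockCohereEmbeddingDriver(input_type="search_document")

# Returns a list[float], per try_embed_chunk() above.
embedding = driver.try_embed_chunk("Griptape is a framework for building AI agents.")
print(len(embedding))
```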
AmazonBedrockImageGenerationDriver
Bases:
BaseMultiModelImageGenerationDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| model | str | Bedrock model ID. |
| session | Session | boto3 session. |
| client | BedrockRuntimeClient | Bedrock runtime client. |
| image_width | int | Width of output images. Defaults to 512 and must be a multiple of 64. |
| image_height | int | Height of output images. Defaults to 512 and must be a multiple of 64. |
| seed | Optional[int] | Optionally provide a consistent seed to generation requests, increasing consistency in output. |
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        client: Bedrock runtime client.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        return self.session.client("bedrock-runtime")

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts,
            self.image_width,
            self.image_height,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=self.image_width,
            height=self.image_height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts,
            image=image,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def _make_request(self, request: dict) -> bytes:
        response = self.client.invoke_model(
            body=json.dumps(request),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            raise ValueError(f"Inpainting generation failed: {e}") from e

        return image_bytes
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
image_height = field(default=512, kw_only=True, metadata={'serializable': True})
image_width = field(default=512, kw_only=True, metadata={'serializable': True})
seed = field(default=None, kw_only=True, metadata={'serializable': True})
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
_make_request(request)
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
def _make_request(self, request: dict) -> bytes:
    response = self.client.invoke_model(
        body=json.dumps(request),
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )

    response_body = json.loads(response.get("body").read())

    try:
        image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
    except Exception as e:
        raise ValueError(f"Inpainting generation failed: {e}") from e

    return image_bytes
```
client()
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    return self.session.client("bedrock-runtime")
```
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_inpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta={"prompt": ", ".join(prompts), "model": self.model},
    )
```
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_outpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta={"prompt": ", ".join(prompts), "model": self.model},
    )
```
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    request = self.image_generation_model_driver.image_variation_request_parameters(
        prompts,
        image=image,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        value=image_bytes,
        format="png",
        width=image.width,
        height=image.height,
        meta={"prompt": ", ".join(prompts), "model": self.model},
    )
```
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
```python
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    request = self.image_generation_model_driver.text_to_image_request_parameters(
        prompts,
        self.image_width,
        self.image_height,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )

    image_bytes = self._make_request(request)

    return ImageArtifact(
        value=image_bytes,
        format="png",
        width=self.image_width,
        height=self.image_height,
        meta={"prompt": ", ".join(prompts), "model": self.model},
    )
```
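A minimal usage sketch (not part of the library source): the model ID below is illustrative, and the driver delegates request building to a model driver such as `BedrockTitanImageGenerationModelDriver` from this module.

```python
from griptape.drivers import (
    AmazonBedrockImageGenerationDriver,
    BedrockTitanImageGenerationModelDriver,
)

driver = AmazonBedrockImageGenerationDriver(
    model="amazon.titan-image-generator-v1",  # assumed model ID; use one enabled in your account
    image_generation_model_driver=BedrockTitanImageGenerationModelDriver(),
    image_width=512,   # must be a multiple of 64
    image_height=512,  # must be a multiple of 64
    seed=42,           # optional, for more consistent output
)

image = driver.try_text_to_image(["a lighthouse at dusk"], negative_prompts=["text"])
with open("lighthouse.png", "wb") as f:
    f.write(image.value)  # ImageArtifact.value holds the PNG bytes
```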
AmazonBedrockPromptDriver
Bases:
BasePromptDriver
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
@define
class AmazonBedrockPromptDriver(BasePromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
    _client: Optional[Any] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def client(self) -> Any:
        return self.session.client("bedrock-runtime")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.converse(**params)
        logger.debug(response)

        usage = response["usage"]
        output_message = response["output"]["message"]
        message_content = output_message.get("content", [])

        # Move reasoning content to the beginning of the content to have it appear first in output
        reasoning_content_index = next(
            (i for i, content in enumerate(message_content) if "reasoningContent" in content),
            None,
        )
        if reasoning_content_index is not None:
            message_content.insert(0, message_content.pop(reasoning_content_index))

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in message_content],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.converse_stream(**params)

        stream = response.get("stream")
        if stream is not None:
            for event in stream:
                logger.debug(event)
                if "contentBlockDelta" in event or "contentBlockStart" in event:
                    yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
                elif "metadata" in event:
                    usage = event["metadata"]["usage"]
                    yield DeltaMessage(
                        usage=DeltaMessage.Usage(
                            input_tokens=usage["inputTokens"],
                            output_tokens=usage["outputTokens"],
                        ),
                    )
        else:
            raise Exception("model response is empty")

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
        messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

        params = {
            "modelId": self.model,
            "messages": messages,
            "system": system_messages,
            "inferenceConfig": {
                "temperature": self.temperature,
                **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
            },
            "additionalModelRequestFields": self.additional_model_request_fields,
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["toolConfig"] = {
                "tools": [],
                "toolChoice": self.tool_choice,
            }

            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["toolConfig"]["toolChoice"] = {"any": {}}

            params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

        return params

    def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
        return [
            {
                "role": self.__to_bedrock_role(message),
                "content": [self.__to_bedrock_message_content(content) for content in message.content],
            }
            for message in messages
        ]

    def __to_bedrock_role(self, message: Message) -> str:
        if message.is_assistant():
            return "assistant"
        return "user"

    def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
        return [
            {
                "toolSpec": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "inputSchema": {
                        "json": tool.to_activity_json_schema(activity, "http://json-schema.org/draft-07/schema#"),
                    },
                },
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
        if isinstance(content, TextMessageContent):
            return {"text": content.artifact.to_text()}
        if isinstance(content, ImageMessageContent):
            artifact = content.artifact
            if isinstance(artifact, ImageArtifact):
                return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
            if isinstance(artifact, ImageUrlArtifact):
                return {"image": {"format": "png", "source": {"s3Location": {"uri": artifact.value}}}}
            raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action_call = content.artifact.value

            return {
                "toolUse": {
                    "toolUseId": action_call.tag,
                    "name": f"{action_call.name}_{action_call.path}",
                    "input": action_call.input,
                },
            }
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_bedrock_tool_use_content(artifact) for artifact in artifact.value]
            else:
                message_content = [self.__to_bedrock_tool_use_content(artifact)]

            return {
                "toolResult": {
                    "toolUseId": content.action.tag,
                    "content": message_content,
                    "status": "error" if isinstance(artifact, ErrorArtifact) else "success",
                },
            }
        return content.artifact.value

    def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
        if isinstance(artifact, ImageArtifact):
            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
        return {"text": artifact.to_text()}

    def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
        if "text" in content:
            return TextMessageContent(TextArtifact(content["text"]))
        if "toolUse" in content:
            name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])

            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(
                        tag=content["toolUse"]["toolUseId"],
                        name=name,
                        path=path,
                        input=content["toolUse"]["input"],
                    ),
                ),
            )
        if "reasoningContent" in content:
            return TextMessageContent(TextArtifact(content["reasoningContent"]["reasoningText"]["text"]))
        raise ValueError(f"Unsupported message content type: {content}")

    def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
        if "contentBlockStart" in event:
            content_block = event["contentBlockStart"]["start"]

            if "toolUse" in content_block:
                name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

                return ActionCallDeltaMessageContent(
                    index=event["contentBlockStart"]["contentBlockIndex"],
                    tag=content_block["toolUse"]["toolUseId"],
                    name=name,
                    path=path,
                )
            if "text" in content_block:
                return TextDeltaMessageContent(
                    content_block["text"],
                    index=event["contentBlockStart"]["contentBlockIndex"],
                )
            raise ValueError(f"Unsupported message content type: {event}")
        if "contentBlockDelta" in event:
            content_block_delta = event["contentBlockDelta"]

            if "text" in content_block_delta["delta"]:
                return TextDeltaMessageContent(
                    content_block_delta["delta"]["text"],
                    index=content_block_delta["contentBlockIndex"],
                )
            if "toolUse" in content_block_delta["delta"]:
                return ActionCallDeltaMessageContent(
                    index=content_block_delta["contentBlockIndex"],
                    partial_input=content_block_delta["delta"]["toolUse"]["input"],
                )
            raise ValueError(f"Unsupported message content type: {event}")
        raise ValueError(f"Unsupported message content type: {event}")
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
additional_model_request_fields = field(default=Factory(dict), kw_only=True)
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True)
tool_choice = field(default=Factory(lambda: {'auto': {}}), kw_only=True, metadata={'serializable': True})
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
__to_bedrock_message_content(content)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
    if isinstance(content, TextMessageContent):
        return {"text": content.artifact.to_text()}
    if isinstance(content, ImageMessageContent):
        artifact = content.artifact
        if isinstance(artifact, ImageArtifact):
            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
        if isinstance(artifact, ImageUrlArtifact):
            return {"image": {"format": "png", "source": {"s3Location": {"uri": artifact.value}}}}
        raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
    if isinstance(content, ActionCallMessageContent):
        action_call = content.artifact.value

        return {
            "toolUse": {
                "toolUseId": action_call.tag,
                "name": f"{action_call.name}_{action_call.path}",
                "input": action_call.input,
            },
        }
    if isinstance(content, ActionResultMessageContent):
        artifact = content.artifact

        if isinstance(artifact, ListArtifact):
            message_content = [self.__to_bedrock_tool_use_content(artifact) for artifact in artifact.value]
        else:
            message_content = [self.__to_bedrock_tool_use_content(artifact)]

        return {
            "toolResult": {
                "toolUseId": content.action.tag,
                "content": message_content,
                "status": "error" if isinstance(artifact, ErrorArtifact) else "success",
            },
        }
    return content.artifact.value
```
__to_bedrock_messages(messages)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
    return [
        {
            "role": self.__to_bedrock_role(message),
            "content": [self.__to_bedrock_message_content(content) for content in message.content],
        }
        for message in messages
    ]
```
__to_bedrock_role(message)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_bedrock_role(self, message: Message) -> str:
    if message.is_assistant():
        return "assistant"
    return "user"
```
__to_bedrock_tool_use_content(artifact)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
    if isinstance(artifact, ImageArtifact):
        return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
    return {"text": artifact.to_text()}
```
__to_bedrock_tools(tools)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
    return [
        {
            "toolSpec": {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "inputSchema": {
                    "json": tool.to_activity_json_schema(activity, "http://json-schema.org/draft-07/schema#"),
                },
            },
        }
        for tool in tools
        for activity in tool.activities()
    ]
```
__to_prompt_stack_delta_message_content(event)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
    if "contentBlockStart" in event:
        content_block = event["contentBlockStart"]["start"]

        if "toolUse" in content_block:
            name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

            return ActionCallDeltaMessageContent(
                index=event["contentBlockStart"]["contentBlockIndex"],
                tag=content_block["toolUse"]["toolUseId"],
                name=name,
                path=path,
            )
        if "text" in content_block:
            return TextDeltaMessageContent(
                content_block["text"],
                index=event["contentBlockStart"]["contentBlockIndex"],
            )
        raise ValueError(f"Unsupported message content type: {event}")
    if "contentBlockDelta" in event:
        content_block_delta = event["contentBlockDelta"]

        if "text" in content_block_delta["delta"]:
            return TextDeltaMessageContent(
                content_block_delta["delta"]["text"],
                index=content_block_delta["contentBlockIndex"],
            )
        if "toolUse" in content_block_delta["delta"]:
            return ActionCallDeltaMessageContent(
                index=content_block_delta["contentBlockIndex"],
                partial_input=content_block_delta["delta"]["toolUse"]["input"],
            )
        raise ValueError(f"Unsupported message content type: {event}")
    raise ValueError(f"Unsupported message content type: {event}")
```
__to_prompt_stack_message_content(content)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
    if "text" in content:
        return TextMessageContent(TextArtifact(content["text"]))
    if "toolUse" in content:
        name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])

        return ActionCallMessageContent(
            artifact=ActionArtifact(
                value=ToolAction(
                    tag=content["toolUse"]["toolUseId"],
                    name=name,
                    path=path,
                    input=content["toolUse"]["input"],
                ),
            ),
        )
    if "reasoningContent" in content:
        return TextMessageContent(TextArtifact(content["reasoningContent"]["reasoningText"]["text"]))
    raise ValueError(f"Unsupported message content type: {content}")
```
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
    messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

    params = {
        "modelId": self.model,
        "messages": messages,
        "system": system_messages,
        "inferenceConfig": {
            "temperature": self.temperature,
            **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
        },
        "additionalModelRequestFields": self.additional_model_request_fields,
        **self.extra_params,
    }

    if prompt_stack.tools and self.use_native_tools:
        params["toolConfig"] = {
            "tools": [],
            "toolChoice": self.tool_choice,
        }

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            params["toolConfig"]["toolChoice"] = {"any": {}}

        params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

    return params
```
client()
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
@lazy_property()
def client(self) -> Any:
    return self.session.client("bedrock-runtime")
```
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.converse(**params)
    logger.debug(response)

    usage = response["usage"]
    output_message = response["output"]["message"]
    message_content = output_message.get("content", [])

    # Move reasoning content to the beginning of the content to have it appear first in output
    reasoning_content_index = next(
        (i for i, content in enumerate(message_content) if "reasoningContent" in content),
        None,
    )
    if reasoning_content_index is not None:
        message_content.insert(0, message_content.pop(reasoning_content_index))

    return Message(
        content=[self.__to_prompt_stack_message_content(content) for content in message_content],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
    )
```
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.converse_stream(**params)

    stream = response.get("stream")
    if stream is not None:
        for event in stream:
            logger.debug(event)
            if "contentBlockDelta" in event or "contentBlockStart" in event:
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            elif "metadata" in event:
                usage = event["metadata"]["usage"]
                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=usage["inputTokens"],
                        output_tokens=usage["outputTokens"],
                    ),
                )
    else:
        raise Exception("model response is empty")
```
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
```python
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value == "native":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    return value
```
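A minimal usage sketch (not part of the library source): the model ID is illustrative, and the `PromptStack` construction shown assumes the `add_user_message` helper available elsewhere in Griptape.

```python
from griptape.common import PromptStack
from griptape.drivers import AmazonBedrockPromptDriver

# Any Bedrock Converse-compatible model ID should work here.
driver = AmazonBedrockPromptDriver(model="anthropic.claude-3-5-sonnet-20240620-v1:0")

prompt_stack = PromptStack()
prompt_stack.add_user_message("Summarize the AWS shared responsibility model in one sentence.")

message = driver.try_run(prompt_stack)  # returns a Message with token usage attached
print(message.to_text())
```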
AmazonBedrockTitanEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| model | str | Embedding model name. Defaults to DEFAULT_MODEL. |
| tokenizer | BaseTokenizer | Optionally provide a custom `BedrockTitanTokenizer`. |
| session | Session | Optionally provide a custom `boto3.Session`. |
| client | BedrockRuntimeClient | Optionally provide a custom `bedrock-runtime` client. |
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
```python
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Titan Embedding Driver.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide custom `BedrockTitanTokenizer`.
        session: Optionally provide custom `boto3.Session`.
        client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        return self.session.client("bedrock-runtime")

    def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value)
        return self._invoke_model({"inputImage": base64.b64encode(artifact.value).decode()})["embedding"]

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        return self._invoke_model(
            {
                "inputText": chunk,
            }
        )["embedding"]

    def _invoke_model(self, payload: dict) -> dict[str, Any]:
        response = self.client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        return json.loads(response.get("body").read())
```
DEFAULT_MODEL = 'amazon.titan-embed-text-v1'
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True)
_invoke_model(payload)
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
```python
def _invoke_model(self, payload: dict) -> dict[str, Any]:
    response = self.client.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )

    return json.loads(response.get("body").read())
```
client()
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
```python
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    return self.session.client("bedrock-runtime")
```
try_embed_artifact(artifact, **kwargs)
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
```python
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
    if isinstance(artifact, TextArtifact):
        return self.try_embed_chunk(artifact.value)
    return self._invoke_model({"inputImage": base64.b64encode(artifact.value).decode()})["embedding"]
```
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
```python
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    return self._invoke_model(
        {
            "inputText": chunk,
        }
    )["embedding"]
```
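A minimal usage sketch (not part of the library source), assuming AWS credentials are available to the default `boto3.Session`:

```python
from griptape.drivers import AmazonBedrockTitanEmbeddingDriver

driver = AmazonBedrockTitanEmbeddingDriver()  # defaults to amazon.titan-embed-text-v1

# Text chunks go through try_embed_chunk(); ImageArtifacts are base64-encoded
# and sent as `inputImage` by try_embed_artifact(), per the source above.
embedding = driver.try_embed_chunk("Hello from Bedrock Titan.")
print(len(embedding))
```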
AmazonDynamoDbConversationMemoryDriver
Bases:
BaseConversationMemoryDriver
Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
```python
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})
    sort_key: Optional[str] = field(default=None, metadata={"serializable": True})
    sort_key_value: Optional[str | int] = field(default=None, metadata={"serializable": True})
    _table: Optional[Table] = field(default=None, kw_only=True, alias="table", metadata={"serializable": False})

    @lazy_property()
    def table(self) -> Table:
        return self.session.resource("dynamodb").Table(self.table_name)

    def store(self, runs: list[Run], metadata: dict) -> None:
        self.table.update_item(
            Key=self._get_key(),
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames={"#attr": self.value_attribute_key},
            ExpressionAttributeValues={
                ":value": json.dumps(self._to_params_dict(runs, metadata)),
            },
        )

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        response = self.table.get_item(Key=self._get_key())

        if "Item" in response and self.value_attribute_key in response["Item"]:
            memory_dict = json.loads(str(response["Item"][self.value_attribute_key]))
            return self._from_params_dict(memory_dict)
        return [], {}

    def _get_key(self) -> dict[str, str | int]:
        key: dict[str, str | int] = {self.partition_key: self.partition_key_value}

        if self.sort_key is not None and self.sort_key_value is not None:
            key[self.sort_key] = self.sort_key_value

        return key
```
_table = field(default=None, kw_only=True, alias='table', metadata={'serializable': False})
partition_key = field(kw_only=True, metadata={'serializable': True})
partition_key_value = field(kw_only=True, metadata={'serializable': True})
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
sort_key = field(default=None, metadata={'serializable': True})
sort_key_value = field(default=None, metadata={'serializable': True})
table_name = field(kw_only=True, metadata={'serializable': True})
value_attribute_key = field(kw_only=True, metadata={'serializable': True})
_get_key()
Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
```python
def _get_key(self) -> dict[str, str | int]:
    key: dict[str, str | int] = {self.partition_key: self.partition_key_value}

    if self.sort_key is not None and self.sort_key_value is not None:
        key[self.sort_key] = self.sort_key_value

    return key
```
load()
Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
```python
def load(self) -> tuple[list[Run], dict[str, Any]]:
    response = self.table.get_item(Key=self._get_key())

    if "Item" in response and self.value_attribute_key in response["Item"]:
        memory_dict = json.loads(str(response["Item"][self.value_attribute_key]))
        return self._from_params_dict(memory_dict)
    return [], {}
```
store(runs, metadata)
Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
```python
def store(self, runs: list[Run], metadata: dict) -> None:
    self.table.update_item(
        Key=self._get_key(),
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames={"#attr": self.value_attribute_key},
        ExpressionAttributeValues={
            ":value": json.dumps(self._to_params_dict(runs, metadata)),
        },
    )
```
table()
Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
```python
@lazy_property()
def table(self) -> Table:
    return self.session.resource("dynamodb").Table(self.table_name)
```
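A minimal usage sketch (not part of the library source): the table and key names below are placeholders for your own DynamoDB schema.

```python
from griptape.drivers import AmazonDynamoDbConversationMemoryDriver

driver = AmazonDynamoDbConversationMemoryDriver(
    table_name="griptape-memory",      # placeholder table
    partition_key="id",                # placeholder partition key attribute
    partition_key_value="session-123", # one item per conversation
    value_attribute_key="memory",      # attribute that stores the serialized runs
)

runs, metadata = driver.load()  # returns ([], {}) when no item exists yet
```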
AmazonOpenSearchVectorStoreDriver
Bases:
OpenSearchVectorStoreDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| session | Session | The boto3 session to use. |
| service | str | Service name for AWS Signature v4: 'es' for OpenSearch, or 'aoss' for OpenSearch Serverless. Defaults to 'es'. |
| http_auth | str \| tuple[str, str] | The HTTP authentication credentials to use. Defaults to using the credentials in the boto3 session. |
| client | OpenSearch | An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes. |
Source Code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
```python
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4. Values can be 'es' or 'aoss' for OpenSearch Serverless.
            Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using
            the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(),
                self.session.region_name,
                self.service,
            ),
            takes_self=True,
        ),
    )

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """
        vector_id = vector_id or str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        if self.service == "aoss":
            response = self.client.index(index=self.index_name, body=doc)
        else:
            response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]
```
http_auth = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True))
service = field(default='es', kw_only=True)
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
```python
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in OpenSearch.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """
    vector_id = vector_id or str_to_hash(str(vector))
    doc = {"vector": vector, "namespace": namespace, "metadata": meta}
    doc.update(kwargs)
    if self.service == "aoss":
        response = self.client.index(index=self.index_name, body=doc)
    else:
        response = self.client.index(index=self.index_name, id=vector_id, body=doc)

    return response["_id"]
```
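A minimal usage sketch (not part of the library source): `host` and `index_name` are inherited from `OpenSearchVectorStoreDriver`, the endpoint below is a placeholder, and passing an `embedding_driver` is assumed from the base vector store driver API.

```python
from griptape.drivers import (
    AmazonBedrockTitanEmbeddingDriver,
    AmazonOpenSearchVectorStoreDriver,
)

driver = AmazonOpenSearchVectorStoreDriver(
    host="search-my-domain.us-east-1.es.amazonaws.com",  # placeholder endpoint
    index_name="griptape-index",
    embedding_driver=AmazonBedrockTitanEmbeddingDriver(),
    # service="aoss" would target OpenSearch Serverless instead of a managed domain.
)

# The toy 3-dimensional vector is for illustration; a real index expects
# vectors matching its configured dimension.
vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"source": "readme"})
```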
AmazonRedshiftSqlDriver
Bases:
BaseSqlDriver
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    _client: Optional[RedshiftDataAPIServiceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> RedshiftDataAPIServiceClient:
        return self.session.client("redshift-data")

    @workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        if self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records: list) -> list[list]:
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta: dict) -> list:
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta: dict, records: list) -> list[dict[str, Any]]:
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)

        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        function_kwargs = {"Sql": query, "Database": self.database}
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)

        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] == "FINISHED":
            statement_result = self.client.get_statement_result(Id=response_id)
            results = statement_result.get("Records", [])

            while "NextToken" in statement_result:
                statement_result = self.client.get_statement_result(
                    Id=response_id,
                    NextToken=statement_result["NextToken"],
                )
                results = results + response.get("Records", [])

            return self._post_process(statement_result["ColumnMetadata"], results)  # pyright: ignore[reportArgumentType]

        if statement["Status"] in ["FAILED", "ABORTED"]:
            return None
        return None

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        if self.workgroup_name:
            function_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            function_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            function_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            function_kwargs["SecretArn"] = self.database_credentials_secret_arn

        response = self.client.describe_table(**function_kwargs)  # pyright: ignore[reportArgumentType]

        return str([col["name"] for col in response["ColumnList"] if "name" in col])
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
cluster_identifier = field(default=None, kw_only=True)
database = field(kw_only=True)
database_credentials_secret_arn = field(default=None, kw_only=True)
db_user = field(default=None, kw_only=True)
session = field(kw_only=True)
wait_for_query_completion_sec = field(default=0.3, kw_only=True)
workgroup_name = field(default=None, kw_only=True)
_post_process(meta, records) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@classmethod
def _post_process(cls, meta: dict, records: list) -> list[dict[str, Any]]:
    columns = cls._process_columns_from_column_metadata(meta)
    rows = cls._process_rows_from_records(records)

    return cls._process_cells_from_rows_and_columns(columns, rows)
```
_process_cells_from_rows_and_columns(columns, rows) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@classmethod
def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
    return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]
```
_process_columns_from_column_metadata(meta) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@classmethod
def _process_columns_from_column_metadata(cls, meta: dict) -> list:
    return [k["name"] for k in meta]
```
_process_rows_from_records(records) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@classmethod
def _process_rows_from_records(cls, records: list) -> list[list]:
    return [[c[list(c.keys())[0]] for c in r] for r in records]
```
client()
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@lazy_property()
def client(self) -> RedshiftDataAPIServiceClient:
    return self.session.client("redshift-data")
```
execute_query(query)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)

    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    return None
```
execute_query_raw(query)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)

    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] == "FINISHED":
        statement_result = self.client.get_statement_result(Id=response_id)
        results = statement_result.get("Records", [])

        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id,
                NextToken=statement_result["NextToken"],
            )
            results = results + response.get("Records", [])

        return self._post_process(statement_result["ColumnMetadata"], results)  # pyright: ignore[reportArgumentType]

    if statement["Status"] in ["FAILED", "ABORTED"]:
        return None
    return None
```
get_table_schema(table_name, schema=None)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    function_kwargs = {"Database": self.database, "Table": table_name}
    if schema:
        function_kwargs["Schema"] = schema
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.describe_table(**function_kwargs)  # pyright: ignore[reportArgumentType]

    return str([col["name"] for col in response["ColumnList"] if "name" in col])
```
validate_params(_, workgroup_name)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
```python
@workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
    if not self.cluster_identifier and not self.workgroup_name:
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    if self.cluster_identifier and self.workgroup_name:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")
```
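A minimal usage sketch (not part of the library source): exactly one of `cluster_identifier` or `workgroup_name` must be set, per `validate_params()` above; the names below are placeholders, and `RowResult` is assumed to expose the row dict as `cells`.

```python
import boto3

from griptape.drivers import AmazonRedshiftSqlDriver

driver = AmazonRedshiftSqlDriver(
    database="dev",
    session=boto3.Session(),
    workgroup_name="default-workgroup",  # or cluster_identifier=..., but not both
)

rows = driver.execute_query("SELECT 1 AS one")
if rows is not None:
    for row in rows:
        print(row.cells)
```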
AmazonS3FileManagerDriver
Bases:
BaseFileManagerDriver
Attributes
| Name | Type | Description |
| --- | --- | --- |
| session | Session | The boto3 session to use for S3 operations. |
| bucket | str | The name of the S3 bucket. |
| workdir | str | The absolute working directory (must start with "/"). List, load, and save operations will be performed relative to this directory. |
| client | S3Client | The S3 client to use for S3 operations. |
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    _workdir: str = field(default="/", kw_only=True, alias="workdir")
    _client: Optional[S3Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @property
    def workdir(self) -> str:
        if self._workdir.startswith("/"):
            return self._workdir
        return f"/{self._workdir}"

    @workdir.setter
    def workdir(self, value: str) -> None:
        self._workdir = value

    @lazy_property()
    def client(self) -> S3Client:
        return self.session.client("s3")

    def try_list_files(self, path: str) -> list[str]:
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if len(files_and_dirs) == 0:
            if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
                raise NotADirectoryError
            raise FileNotFoundError
        return files_and_dirs

    def try_load_file(self, path: str) -> bytes:
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        try:
            response = self.client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError from e
            raise e

    def try_save_file(self, path: str, value: bytes) -> str:
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value)
        return f"s3://{self.bucket}/{full_key}"

    def _to_full_key(self, path: str) -> str:
        path = path.lstrip("/")
        full_key = f"{self.workdir}/{path}"
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")
        full_key = self._normpath(full_key)
        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        max_items = kwargs.get("max_items")
        pagination_config: PaginatorConfigTypeDef = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items
        paginator = self.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket,
            Prefix=full_key,
            Delimiter="/",
            PaginationConfig=pagination_config,
        )
        files_and_dirs = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix", "")
                directory = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(directory)
            for obj in page.get("Contents", []):
                key = obj.get("Key", "")
                file = key[len(full_key) :]
                files_and_dirs.append(file)
        return files_and_dirs

    def _is_a_directory(self, full_key: str) -> bool:
        botocore = import_optional_dependency("botocore")
        if full_key == "" or full_key.endswith("/"):
            return True
        try:
            self.client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                return len(self._list_files_and_dirs(full_key, max_items=1)) > 0
            raise e
        return False

    def _normpath(self, path: str) -> str:
        unix_path = path.replace("\\", "/")
        parts = unix_path.split("/")
        stack = []
        for part in parts:
            if part == "" or part == ".":
                continue
            if part == "..":
                if stack:
                    stack.pop()
            else:
                stack.append(part)
        return "/".join(stack)
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
_workdir = field(default='/', kw_only=True, alias='workdir')
class-attribute instance-attribute
bucket = field(kw_only=True)
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
workdir
property writable
_is_a_directory(full_key)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _is_a_directory(self, full_key: str) -> bool:
    botocore = import_optional_dependency("botocore")
    if full_key == "" or full_key.endswith("/"):
        return True
    try:
        self.client.head_object(Bucket=self.bucket, Key=full_key)
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
            return len(self._list_files_and_dirs(full_key, max_items=1)) > 0
        raise e
    return False
_list_files_and_dirs(full_key, **kwargs)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
    max_items = kwargs.get("max_items")
    pagination_config: PaginatorConfigTypeDef = {}
    if max_items is not None:
        pagination_config["MaxItems"] = max_items
    paginator = self.client.get_paginator("list_objects_v2")
    pages = paginator.paginate(
        Bucket=self.bucket,
        Prefix=full_key,
        Delimiter="/",
        PaginationConfig=pagination_config,
    )
    files_and_dirs = []
    for page in pages:
        for obj in page.get("CommonPrefixes", []):
            prefix = obj.get("Prefix", "")
            directory = prefix[len(full_key) :].rstrip("/")
            files_and_dirs.append(directory)
        for obj in page.get("Contents", []):
            key = obj.get("Key", "")
            file = key[len(full_key) :]
            files_and_dirs.append(file)
    return files_and_dirs
_normpath(path)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _normpath(self, path: str) -> str:
    unix_path = path.replace("\\", "/")
    parts = unix_path.split("/")
    stack = []
    for part in parts:
        if part == "" or part == ".":
            continue
        if part == "..":
            if stack:
                stack.pop()
        else:
            stack.append(part)
    return "/".join(stack)
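To make the normalization rules concrete, here is a standalone re-implementation with a few worked cases (illustration only; the driver itself uses the private _normpath method above):

def normpath(path: str) -> str:
    # Backslashes are treated as separators, then "", "." and ".." segments are resolved.
    stack: list[str] = []
    for part in path.replace("\\", "/").split("/"):
        if part in ("", "."):
            continue
        if part == "..":
            if stack:
                stack.pop()
        else:
            stack.append(part)
    return "/".join(stack)

assert normpath("/a/b/../c/./d/") == "a/c/d"  # ".." pops one level; "." and "" vanish
assert normpath("..\\a") == "a"               # cannot climb above the root
assert normpath("../..") == ""                # the root maps to the empty key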
_to_dir_full_key(path)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _to_dir_full_key(self, path: str) -> str:
    full_key = self._to_full_key(path)
    # S3 "directories" always end with a slash, except for the root.
    if full_key != "" and not full_key.endswith("/"):
        full_key += "/"
    return full_key
_to_full_key(path)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _to_full_key(self, path: str) -> str:
    path = path.lstrip("/")
    full_key = f"{self.workdir}/{path}"
    # Need to keep the trailing slash if it was there,
    # because it means the path is a directory.
    ended_with_slash = path.endswith("/")
    full_key = self._normpath(full_key)
    if ended_with_slash:
        full_key += "/"
    return full_key.lstrip("/")
client()
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@lazy_property()
def client(self) -> S3Client:
    return self.session.client("s3")
try_list_files(path)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    full_key = self._to_dir_full_key(path)
    files_and_dirs = self._list_files_and_dirs(full_key)
    if len(files_and_dirs) == 0:
        if len(self._list_files_and_dirs(full_key.rstrip("/"), max_items=1)) > 0:
            raise NotADirectoryError
        raise FileNotFoundError
    return files_and_dirs
try_load_file(path)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    botocore = import_optional_dependency("botocore")
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError
    try:
        response = self.client.get_object(Bucket=self.bucket, Key=full_key)
        return response["Body"].read()
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
            raise FileNotFoundError from e
        raise e
try_save_file(path, value)
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> str:
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError
    self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value)
    return f"s3://{self.bucket}/{full_key}"
AmazonSageMakerJumpstartEmbeddingDriver
Bases:
BaseEmbeddingDriver
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[SageMakerRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> SageMakerRuntimeClient:
        return self.session.client("sagemaker-runtime")

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        payload = {"text_inputs": chunk, "mode": "embedding"}

        endpoint_response = self.client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload).encode("utf-8"),
            CustomAttributes=self.custom_attributes,
            **(
                {"InferenceComponentName": self.inference_component_name}
                if self.inference_component_name is not None
                else {}
            ),
        )

        response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

        if "embedding" in response:
            embedding = response["embedding"]
            if embedding:
                if isinstance(embedding[0], list):
                    return embedding[0]
                return embedding
            raise ValueError("model response is empty")
        raise ValueError("invalid response from model")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
custom_attributes = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
inference_component_name = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@lazy_property()
def client(self) -> SageMakerRuntimeClient:
    return self.session.client("sagemaker-runtime")
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    payload = {"text_inputs": chunk, "mode": "embedding"}

    endpoint_response = self.client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload).encode("utf-8"),
        CustomAttributes=self.custom_attributes,
        **(
            {"InferenceComponentName": self.inference_component_name}
            if self.inference_component_name is not None
            else {}
        ),
    )

    response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

    if "embedding" in response:
        embedding = response["embedding"]
        if embedding:
            if isinstance(embedding[0], list):
                return embedding[0]
            return embedding
        raise ValueError("model response is empty")
    raise ValueError("invalid response from model")
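A hypothetical usage sketch: the endpoint and model names below are placeholders for a JumpStart embedding endpoint deployed in your own account (the model field is inherited from BaseEmbeddingDriver).

from griptape.drivers import AmazonSageMakerJumpstartEmbeddingDriver

embedder = AmazonSageMakerJumpstartEmbeddingDriver(
    endpoint="my-jumpstart-embedding-endpoint",  # placeholder
    model="my-embedding-model",                  # placeholder
)

vector = embedder.try_embed_chunk("Griptape drivers wrap provider APIs.")
print(len(vector))  # dimensionality depends on the deployed model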
AmazonSageMakerJumpstartPromptDriver
Bases:
BasePromptDriver
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    _client: Optional[Any] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value != "rule":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
        return value

    @lazy_property()
    def client(self) -> Any:
        return self.session.client("sagemaker-runtime")

    @stream.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
        if stream:
            raise ValueError("streaming is not supported")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        payload = {
            "inputs": self.prompt_stack_to_string(prompt_stack),
            "parameters": {**self._base_params(prompt_stack)},
        }
        logger.debug(payload)

        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
            **(
                {"InferenceComponentName": self.inference_component_name}
                if self.inference_component_name is not None
                else {}
            ),
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))
        logger.debug(decoded_body)

        if isinstance(decoded_body, list):
            if decoded_body:
                generated_text = decoded_body[0]["generated_text"]
            else:
                raise ValueError("model response is empty")
        else:
            generated_text = decoded_body["generated_text"]

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]

        return Message(
            content=[TextMessageContent(TextArtifact(generated_text))],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        raise NotImplementedError("streaming is not supported")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "do_sample": True,
            "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
            "stop_strings": self.tokenizer.stop_sequences,
            "return_full_text": False,
            **self.extra_params,
        }

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        messages = []
        for message in prompt_stack.messages:
            messages.append({"role": message.role, "content": message.to_text()})
        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
        if isinstance(tokens, list):
            # According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template),
            # the return type is List[int].
            return tokens  # pyright: ignore[reportReturnType]
        raise ValueError("Invalid output type.")
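For reference, the request body that try_run posts to the endpoint has the following shape. The parameter values are illustrative only: temperature is inherited from BasePromptDriver (its default is an assumption here), and eos_token_id and stop_strings are supplied by the Hugging Face tokenizer at runtime.

payload = {
    "inputs": "<prompt rendered by prompt_stack_to_string()>",
    "parameters": {
        "temperature": 0.1,       # inherited default; assumed value
        "max_new_tokens": 250,    # this driver's max_tokens default
        "do_sample": True,
        "eos_token_id": 2,        # placeholder; taken from the tokenizer
        "stop_strings": [],       # placeholder; taken from the tokenizer
        "return_full_text": False,
    },
}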
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
custom_attributes = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
inference_component_name = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
max_tokens = field(default=250, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
stream = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
class-attribute instance-attribute
__prompt_stack_to_tokens(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    messages = self._prompt_stack_to_messages(prompt_stack)
    tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
    if isinstance(tokens, list):
        # According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template),
        # the return type is List[int].
        return tokens  # pyright: ignore[reportReturnType]
    raise ValueError("Invalid output type.")
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    return {
        "temperature": self.temperature,
        "max_new_tokens": self.max_tokens,
        "do_sample": True,
        "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
        "stop_strings": self.tokenizer.stop_sequences,
        "return_full_text": False,
        **self.extra_params,
    }
_prompt_stack_to_messages(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    messages = []
    for message in prompt_stack.messages:
        messages.append({"role": message.role, "content": message.to_text()})
    return messages
client()
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@lazy_property()
def client(self) -> Any:
    return self.session.client("sagemaker-runtime")
prompt_stack_to_string(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    payload = {
        "inputs": self.prompt_stack_to_string(prompt_stack),
        "parameters": {**self._base_params(prompt_stack)},
    }
    logger.debug(payload)

    response = self.client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
        **(
            {"InferenceComponentName": self.inference_component_name}
            if self.inference_component_name is not None
            else {}
        ),
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))
    logger.debug(decoded_body)

    if isinstance(decoded_body, list):
        if decoded_body:
            generated_text = decoded_body[0]["generated_text"]
        else:
            raise ValueError("model response is empty")
    else:
        generated_text = decoded_body["generated_text"]

    input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
    output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]

    return Message(
        content=[TextMessageContent(TextArtifact(generated_text))],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    raise NotImplementedError("streaming is not supported")
validate_stream(_, stream)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@stream.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
    if stream:
        raise ValueError("streaming is not supported")
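Because attrs validators run during __init__, enabling streaming fails at construction time. A quick sketch (endpoint and model names are placeholders; boto3 must be installed for the default session factory):

from griptape.drivers import AmazonSageMakerJumpstartPromptDriver

try:
    AmazonSageMakerJumpstartPromptDriver(endpoint="my-endpoint", model="my-model", stream=True)
except ValueError as e:
    print(e)  # "streaming is not supported"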
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value != "rule":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
    return value
AmazonSqsEventListenerDriver
Bases:
BaseEventListenerDriver
Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@define
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
    queue_url: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    _client: Optional[SQSClient] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> SQSClient:
        return self.session.client("sqs")

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
            {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
            for event_payload in event_payload_batch
        ]
        self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
queue_url = field(kw_only=True)
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@lazy_property()
def client(self) -> SQSClient:
    return self.session.client("sqs")
try_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))
try_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
        {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
        for event_payload in event_payload_batch
    ]
    self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries)
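A usage sketch with a placeholder queue URL and illustrative payloads. Note that the batch method reads each payload's "id" field to build the SQS batch-entry Id, so every payload in a batch must carry one:

from griptape.drivers import AmazonSqsEventListenerDriver

events = AmazonSqsEventListenerDriver(
    queue_url="https://sqs.us-east-1.amazonaws.com/123456789012/my-queue",  # placeholder
)

events.try_publish_event_payload({"id": "evt-1", "type": "StartTaskEvent"})
events.try_publish_event_payload_batch(
    [
        {"id": "evt-2", "type": "FinishTaskEvent"},
        {"id": "evt-3", "type": "FinishStructureRunEvent"},
    ]
)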
AnthropicPromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | Optional[str] | Anthropic API key. |
model | str | Anthropic model name. |
client | Client | Custom Anthropic client. |
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """Anthropic Prompt Driver.

    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.messages.create(**params)
        logger.debug(response.model_dump())

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in response.content],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(params)
        events = self.client.messages.create(**params)

        for event in events:
            logger.debug(event)
            if event.type == "content_block_delta" or event.type == "content_block_start":
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            elif event.type == "message_start":
                yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens))
            elif event.type == "message_delta":
                yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        messages = self.__to_anthropic_messages([i for i in prompt_stack.messages if not i.is_system()])

        system_messages = prompt_stack.system_messages
        system_message = system_messages[0].to_text() if system_messages else None

        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "messages": messages,
            **({"system": system_message} if system_message else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice

            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_choice"] = {"type": "any"}

            params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

        return params

    def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
        return [
            {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
            for message in messages
        ]

    def __to_anthropic_role(self, message: Message) -> str:
        if message.is_assistant():
            return "assistant"
        return "user"

    def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
        tool_schemas = [
            {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
            }
            for tool in tools
            for activity in tool.activities()
        ]

        # Anthropic doesn't support $schema and $id
        for tool_schema in tool_schemas:
            del tool_schema["input_schema"]["$schema"]
            del tool_schema["input_schema"]["$id"]

        return tool_schemas

    def __to_anthropic_content(self, message: Message) -> str | list[dict]:
        if message.has_all_content_type(TextMessageContent):
            return message.to_text()
        return [self.__to_anthropic_message_content(content) for content in message.content]

    def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
        if isinstance(content, TextMessageContent):
            return {"type": "text", "text": content.artifact.value}
        if isinstance(content, ImageMessageContent):
            artifact = content.artifact
            if isinstance(artifact, ImageArtifact):
                return {
                    "type": "image",
                    "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
                }
            if isinstance(artifact, ImageUrlArtifact):
                return {
                    "type": "image",
                    "source": {"type": "url", "url": artifact.value},
                }
            raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value
            return {"type": "tool_use", "id": action.tag, "name": action.to_native_tool_name(), "input": action.input}
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact
            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_anthropic_tool_result_content(artifact) for artifact in artifact.value]
            else:
                message_content = [self.__to_anthropic_tool_result_content(artifact)]
            return {
                "type": "tool_result",
                "tool_use_id": content.action.tag,
                "content": message_content,
                "is_error": isinstance(artifact, ErrorArtifact),
            }
        return content.artifact.value

    def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
        if isinstance(artifact, ImageArtifact):
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
            }
        if isinstance(artifact, (TextArtifact, ErrorArtifact, InfoArtifact)):
            return {"type": "text", "text": artifact.to_text()}
        raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")

    def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
        if content.type == "text":
            return TextMessageContent(TextArtifact(content.text))
        if content.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(content.name)
            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(tag=content.id, name=name, path=path, input=content.input),  # pyright: ignore[reportArgumentType]
                ),
            )
        raise ValueError(f"Unsupported message content type: {content.type}")

    def __to_prompt_stack_delta_message_content(
        self,
        event: ContentBlockDeltaEvent | ContentBlockStartEvent,
    ) -> BaseDeltaMessageContent:
        if event.type == "content_block_start":
            content_block = event.content_block
            if content_block.type == "tool_use":
                name, path = ToolAction.from_native_tool_name(content_block.name)
                return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
            if content_block.type == "text":
                return TextDeltaMessageContent(content_block.text, index=event.index)
            raise ValueError(f"Unsupported content block type: {content_block.type}")
        if event.type == "content_block_delta":
            content_block_delta = event.delta
            if content_block_delta.type == "text_delta":
                return TextDeltaMessageContent(content_block_delta.text, index=event.index)
            if content_block_delta.type == "input_json_delta":
                return ActionCallDeltaMessageContent(index=event.index, partial_input=content_block_delta.partial_json)
            raise ValueError(f"Unsupported message content type: {event}")
        raise ValueError(f"Unsupported message content type: {event}")
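A construction sketch: the model id below is a placeholder for any Anthropic messages-API model. If api_key is left as None, the underlying Anthropic SDK client typically falls back to the ANTHROPIC_API_KEY environment variable.

import os

from griptape.drivers import AnthropicPromptDriver

driver = AnthropicPromptDriver(
    model="claude-3-5-sonnet-latest",            # placeholder model id
    api_key=os.environ.get("ANTHROPIC_API_KEY"),
    max_tokens=1024,
)

Passing structured_output_strategy="native" raises a ValueError at construction time, per the validator shown further below.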
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_key = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
max_tokens = field(default=1000, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True)
class-attribute instance-attribute
tool_choice = field(default=Factory(lambda: {'type': 'auto'}), kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute
top_k = field(default=250, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
top_p = field(default=0.999, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
__to_anthropic_content(message)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_content(self, message: Message) -> str | list[dict]:
    if message.has_all_content_type(TextMessageContent):
        return message.to_text()
    return [self.__to_anthropic_message_content(content) for content in message.content]
__to_anthropic_message_content(content)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
    if isinstance(content, TextMessageContent):
        return {"type": "text", "text": content.artifact.value}
    if isinstance(content, ImageMessageContent):
        artifact = content.artifact
        if isinstance(artifact, ImageArtifact):
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
            }
        if isinstance(artifact, ImageUrlArtifact):
            return {
                "type": "image",
                "source": {"type": "url", "url": artifact.value},
            }
        raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value
        return {"type": "tool_use", "id": action.tag, "name": action.to_native_tool_name(), "input": action.input}
    if isinstance(content, ActionResultMessageContent):
        artifact = content.artifact
        if isinstance(artifact, ListArtifact):
            message_content = [self.__to_anthropic_tool_result_content(artifact) for artifact in artifact.value]
        else:
            message_content = [self.__to_anthropic_tool_result_content(artifact)]
        return {
            "type": "tool_result",
            "tool_use_id": content.action.tag,
            "content": message_content,
            "is_error": isinstance(artifact, ErrorArtifact),
        }
    return content.artifact.value
__to_anthropic_messages(messages)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
    return [
        {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
        for message in messages
    ]
__to_anthropic_role(message)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, message: Message) -> str:
    if message.is_assistant():
        return "assistant"
    return "user"
__to_anthropic_tool_result_content(artifact)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
    if isinstance(artifact, ImageArtifact):
        return {
            "type": "image",
            "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
        }
    if isinstance(artifact, (TextArtifact, ErrorArtifact, InfoArtifact)):
        return {"type": "text", "text": artifact.to_text()}
    raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")
__to_anthropic_tools(tools)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
    tool_schemas = [
        {
            "name": tool.to_native_tool_name(activity),
            "description": tool.activity_description(activity),
            "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
        }
        for tool in tools
        for activity in tool.activities()
    ]

    # Anthropic doesn't support $schema and $id
    for tool_schema in tool_schemas:
        del tool_schema["input_schema"]["$schema"]
        del tool_schema["input_schema"]["$id"]

    return tool_schemas
__to_prompt_stack_delta_message_content(event)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_delta_message_content(
    self,
    event: ContentBlockDeltaEvent | ContentBlockStartEvent,
) -> BaseDeltaMessageContent:
    if event.type == "content_block_start":
        content_block = event.content_block
        if content_block.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(content_block.name)
            return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
        if content_block.type == "text":
            return TextDeltaMessageContent(content_block.text, index=event.index)
        raise ValueError(f"Unsupported content block type: {content_block.type}")
    if event.type == "content_block_delta":
        content_block_delta = event.delta
        if content_block_delta.type == "text_delta":
            return TextDeltaMessageContent(content_block_delta.text, index=event.index)
        if content_block_delta.type == "input_json_delta":
            return ActionCallDeltaMessageContent(index=event.index, partial_input=content_block_delta.partial_json)
        raise ValueError(f"Unsupported message content type: {event}")
    raise ValueError(f"Unsupported message content type: {event}")
__to_prompt_stack_message_content(content)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
    if content.type == "text":
        return TextMessageContent(TextArtifact(content.text))
    if content.type == "tool_use":
        name, path = ToolAction.from_native_tool_name(content.name)
        return ActionCallMessageContent(
            artifact=ActionArtifact(
                value=ToolAction(tag=content.id, name=name, path=path, input=content.input),  # pyright: ignore[reportArgumentType]
            ),
        )
    raise ValueError(f"Unsupported message content type: {content.type}")
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    messages = self.__to_anthropic_messages([i for i in prompt_stack.messages if not i.is_system()])

    system_messages = prompt_stack.system_messages
    system_message = system_messages[0].to_text() if system_messages else None

    params = {
        "model": self.model,
        "temperature": self.temperature,
        "stop_sequences": self.tokenizer.stop_sequences,
        "top_p": self.top_p,
        "top_k": self.top_k,
        "max_tokens": self.max_tokens,
        "messages": messages,
        **({"system": system_message} if system_message else {}),
        **self.extra_params,
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_choice"] = self.tool_choice

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            params["tool_choice"] = {"type": "any"}

        params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

    return params
client()
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.messages.create(**params)
    logger.debug(response.model_dump())

    return Message(
        content=[self.__to_prompt_stack_message_content(content) for content in response.content],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
    )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug(params)
    events = self.client.messages.create(**params)

    for event in events:
        logger.debug(event)
        if event.type == "content_block_delta" or event.type == "content_block_start":
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
        elif event.type == "message_start":
            yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens))
        elif event.type == "message_delta":
            yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens))
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value == "native":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
    return value
AstraDbVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
embedding_driver | BaseEmbeddingDriver | a griptape.drivers.BaseEmbeddingDriver for embedding computations within the store |
api_endpoint | str | the "API Endpoint" for the Astra DB instance. |
token | Optional[str | TokenProvider] | a Database Token ("AstraCS:...") secret to access Astra DB. An instance of astrapy.authentication.TokenProvider is also accepted. |
collection_name | str | the name of the collection on Astra DB. The collection must have been created beforehand, and support vectors with a vector dimension matching the embeddings being used by this driver. |
environment | Optional[str] | the environment ("prod", "hcd", ...) hosting the target Data API. It can be omitted for production Astra DB targets. See astrapy.constants.Environment for allowed values. |
astra_db_namespace | Optional[str] | optional specification of the namespace (in the Astra database) for the data. Note: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store. |
caller_name | str | the name of the caller for the Astra DB client. Defaults to "griptape". |
client | DataAPIClient | an instance of astrapy.DataAPIClient for the Astra DB. |
collection | Collection | an instance of astrapy.Collection for the Astra DB. |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@define
class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Astra DB.

    Attributes:
        embedding_driver: a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store
        api_endpoint: the "API Endpoint" for the Astra DB instance.
        token: a Database Token ("AstraCS:...") secret to access Astra DB. An instance of
            `astrapy.authentication.TokenProvider` is also accepted.
        collection_name: the name of the collection on Astra DB. The collection must have been created beforehand,
            and support vectors with a vector dimension matching the embeddings being used by this driver.
        environment: the environment ("prod", "hcd", ...) hosting the target Data API. It can be omitted for
            production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
        astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
            *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping
            within this vector store.
        caller_name: the name of the caller for the Astra DB client. Defaults to "griptape".
        client: an instance of `astrapy.DataAPIClient` for the Astra DB.
        collection: an instance of `astrapy.Collection` for the Astra DB.
    """

    api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    token: Optional[str | astrapy.authentication.TokenProvider] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False})
    _client: Optional[astrapy.DataAPIClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )
    _collection: Optional[astrapy.Collection] = field(
        default=None, kw_only=True, alias="collection", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> astrapy.DataAPIClient:
        astrapy = import_optional_dependency("astrapy")
        return astrapy.DataAPIClient(
            callers=[(self.caller_name, None)],
            environment=self.environment or astrapy.utils.unset.UnsetType(),
        )

    @lazy_property()
    def collection(self) -> astrapy.Collection:
        if self.token is None:
            raise ValueError("Astra DB token must be provided.")
        return self.client.get_database(
            self.api_endpoint, token=self.token, keyspace=self.astra_db_namespace
        ).get_collection(self.collection_name)

    def delete_vector(self, vector_id: str) -> None:
        """Delete a vector from Astra DB store.

        The method succeeds regardless of whether a vector with the provided ID
        was actually stored or not in the first place.

        Args:
            vector_id: ID of the vector to delete.
        """
        self.collection.delete_one({"_id": vector_id})

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Write a vector to the Astra DB store.

        In case the provided ID exists already, an overwrite will take place.

        Args:
            vector: the vector to be upserted.
            vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
            namespace: a namespace (a grouping within the vector store) to assign the vector to.
            meta: a metadata dictionary associated to the vector.
            kwargs: additional keyword arguments. Currently none is used: if they are passed,
                they will be ignored with a warning.

        Returns:
            the ID of the written vector (str).
        """
        document = {
            k: v
            for k, v in {"$vector": vector, "_id": vector_id, "keyspace": namespace, "meta": meta}.items()
            if v is not None
        }
        if vector_id is not None:
            self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
            return vector_id
        insert_result = self.collection.insert_one(document)
        return insert_result.inserted_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a single vector entry from the Astra DB store given its ID.

        Args:
            vector_id: the ID of the required vector.
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
        """
        find_filter = {k: v for k, v in {"_id": vector_id, "keyspace": namespace}.items() if v is not None}
        match = self.collection.find_one(filter=find_filter, projection={"*": 1})
        if match is not None:
            return BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load entries from the Astra DB store.

        Args:
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries.
        """
        find_filter: dict[str, str] = {} if namespace is None else {"keyspace": namespace}
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
            )
            for match in self.collection.find(filter=find_filter, projection={"*": 1})
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Run a similarity search on the Astra DB store, based on a vector list.

        Args:
            vector: the vector to be queried.
            count: the maximum number of results to return. If omitted, defaults will apply.
            namespace: the namespace to filter results by.
            include_vectors: whether to include vector data in the results.
            kwargs: additional keyword arguments. Currently only the free-form dict `filter`
                is recognized (and goes straight to the Data API query);
                others will generate a warning and be ignored.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries,
            with their `score` attribute set to the vector similarity to the query.
        """
        query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
        find_filter_ns: dict[str, Any] = {} if namespace is None else {"keyspace": namespace}
        find_filter = {**(query_filter or {}), **find_filter_ns}
        find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
        ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        matches = self.collection.find(
            filter=find_filter,
            sort={"$vector": vector},
            limit=ann_limit,
            projection=find_projection,
            include_similarity=True,
        )
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"],
                vector=match.get("$vector"),
                score=match["$similarity"],
                meta=match.get("meta"),
                namespace=match.get("keyspace"),
            )
            for match in matches
        ]
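A usage sketch with placeholder credentials. The collection must already exist and its vector dimension must match the vectors you write; the toy 3-dimensional vectors below assume a dimension-3 collection, and the embedding driver is only one possible choice.

from griptape.drivers import AstraDbVectorStoreDriver, OpenAiEmbeddingDriver

vector_store = AstraDbVectorStoreDriver(
    embedding_driver=OpenAiEmbeddingDriver(),  # any BaseEmbeddingDriver works
    api_endpoint="https://<db-id>-<region>.apps.astra.datastax.com",  # placeholder
    token="AstraCS:placeholder",
    collection_name="my_vectors",
)

vector_store.upsert_vector([0.1, 0.2, 0.3], vector_id="v1", namespace="docs")
hits = vector_store.query_vector([0.1, 0.2, 0.3], count=5, namespace="docs")
for entry in hits:
    print(entry.id, entry.score)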
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
_collection = field(default=None, kw_only=True, alias='collection', metadata={'serializable': False})
class-attribute instance-attribute
api_endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
astra_db_namespace = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
caller_name = field(default='griptape', kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute
collection_name = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
environment = field(kw_only=True, default=None, metadata={'serializable': True})
class-attribute instance-attribute
token = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def client(self) -> astrapy.DataAPIClient:
    astrapy = import_optional_dependency("astrapy")
    return astrapy.DataAPIClient(
        callers=[(self.caller_name, None)],
        environment=self.environment or astrapy.utils.unset.UnsetType(),
    )
collection()
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def collection(self) -> astrapy.Collection:
    if self.token is None:
        raise ValueError("Astra DB token must be provided.")
    return self.client.get_database(
        self.api_endpoint, token=self.token, keyspace=self.astra_db_namespace
    ).get_collection(self.collection_name)
delete_vector(vector_id)
Delete a vector from Astra DB store.
The method succeeds regardless of whether a vector with the provided ID was actually stored or not in the first place.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector_id | str | ID of the vector to delete. | required |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Delete a vector from Astra DB store.

    The method succeeds regardless of whether a vector with the provided ID
    was actually stored or not in the first place.

    Args:
        vector_id: ID of the vector to delete.
    """
    self.collection.delete_one({"_id": vector_id})
load_entries(*, namespace=None)
Load entries from the Astra DB store.
Parameters
Name | Type | Description | Default |
---|---|---|---|
namespace | Optional[str] | a namespace, within the vector store, to constrain the search. | None |
Returns
Type | Description |
---|---|
list[Entry] | A list of vector (BaseVectorStoreDriver.Entry ) entries. |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load entries from the Astra DB store.

    Args:
        namespace: a namespace, within the vector store, to constrain the search.

    Returns:
        A list of vector (`BaseVectorStoreDriver.Entry`) entries.
    """
    find_filter: dict[str, str] = {} if namespace is None else {"keyspace": namespace}
    return [
        BaseVectorStoreDriver.Entry(
            id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
        )
        for match in self.collection.find(filter=find_filter, projection={"*": 1})
    ]
load_entry(vector_id, *, namespace=None)
Load a single vector entry from the Astra DB store given its ID.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector_id | str | the ID of the required vector. | required |
namespace | Optional[str] | a namespace, within the vector store, to constrain the search. | None |
Returns
Type | Description |
---|---|
Optional[Entry] | The vector entry (a BaseVectorStoreDriver.Entry ) if found, otherwise None. |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Load a single vector entry from the Astra DB store given its ID.

    Args:
        vector_id: the ID of the required vector.
        namespace: a namespace, within the vector store, to constrain the search.

    Returns:
        The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
    """
    find_filter = {k: v for k, v in {"_id": vector_id, "keyspace": namespace}.items() if v is not None}
    match = self.collection.find_one(filter=find_filter, projection={"*": 1})
    if match is not None:
        return BaseVectorStoreDriver.Entry(
            id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
        )
    return None
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Run a similarity search on the Astra DB store, based on a vector list.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | the vector to be queried. | required |
count | Optional[int] | the maximum number of results to return. If omitted, defaults will apply. | None |
namespace | Optional[str] | the namespace to filter results by. | None |
include_vectors | bool | whether to include vector data in the results. | False |
kwargs | Any | additional keyword arguments. Currently only the free-form dict filter is recognized (and goes straight to the Data API query); others will generate a warning and be ignored. | {} |
Returns
Type | Description |
---|---|
list[Entry] | A list of vector (BaseVectorStoreDriver.Entry) entries, with their score attribute set to the vector similarity to the query. |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
    """Run a similarity search on the Astra DB store, based on a vector list.

    Args:
        vector: the vector to be queried.
        count: the maximum number of results to return. If omitted, defaults will apply.
        namespace: the namespace to filter results by.
        include_vectors: whether to include vector data in the results.
        kwargs: additional keyword arguments. Currently only the free-form dict `filter`
            is recognized (and goes straight to the Data API query);
            others will generate a warning and be ignored.

    Returns:
        A list of vector (`BaseVectorStoreDriver.Entry`) entries,
        with their `score` attribute set to the vector similarity to the query.
    """
    query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
    find_filter_ns: dict[str, Any] = {} if namespace is None else {"keyspace": namespace}
    find_filter = {**(query_filter or {}), **find_filter_ns}
    find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
    ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    matches = self.collection.find(
        filter=find_filter,
        sort={"$vector": vector},
        limit=ann_limit,
        projection=find_projection,
        include_similarity=True,
    )
    return [
        BaseVectorStoreDriver.Entry(
            id=match["_id"],
            vector=match.get("$vector"),
            score=match["$similarity"],
            meta=match.get("meta"),
            namespace=match.get("keyspace"),
        )
        for match in matches
    ]
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Write a vector to the Astra DB store.
In case the provided ID exists already, an overwrite will take place.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | the vector to be upserted. | required |
vector_id | Optional[str] | the ID for the vector to store. If omitted, a server-provided new ID will be employed. | None |
namespace | Optional[str] | a namespace (a grouping within the vector store) to assign the vector to. | None |
meta | Optional[dict] | a metadata dictionary associated to the vector. | None |
kwargs | Any | additional keyword arguments. Currently none is used: if they are passed, they will be ignored with a warning. | {} |
Returns
Type | Description |
---|---|
str | the ID of the written vector (str). |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs: Any,
) -> str:
    """Write a vector to the Astra DB store.

    In case the provided ID exists already, an overwrite will take place.

    Args:
        vector: the vector to be upserted.
        vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
        namespace: a namespace (a grouping within the vector store) to assign the vector to.
        meta: a metadata dictionary associated to the vector.
        kwargs: additional keyword arguments. Currently none is used: if they are passed,
            they will be ignored with a warning.

    Returns:
        the ID of the written vector (str).
    """
    document = {
        k: v
        for k, v in {"$vector": vector, "_id": vector_id, "keyspace": namespace, "meta": meta}.items()
        if v is not None
    }
    if vector_id is not None:
        self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
        return vector_id
    insert_result = self.collection.insert_one(document)
    return insert_result.inserted_id
AwsIotCoreEventListenerDriver
Bases:
BaseEventListenerDriver
Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@define
class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
    iot_endpoint: str = field(kw_only=True)
    topic: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    _client: Optional[IoTDataPlaneClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> IoTDataPlaneClient:
        return self.session.client("iot-data")

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self.client.publish(topic=self.topic, payload=json.dumps(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
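A sketch with placeholder endpoint and topic values. Both publish methods write to the same MQTT topic; the batch variant serializes the whole list as a single message:

from griptape.drivers import AwsIotCoreEventListenerDriver

iot_events = AwsIotCoreEventListenerDriver(
    iot_endpoint="abc123-ats.iot.us-east-1.amazonaws.com",  # placeholder
    topic="griptape/events",
)

iot_events.try_publish_event_payload({"type": "FinishStructureRunEvent"})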
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
iot_endpoint = field(kw_only=True)
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
topic = field(kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@lazy_property()
def client(self) -> IoTDataPlaneClient:
    return self.session.client("iot-data")
try_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.client.publish(topic=self.topic, payload=json.dumps(event_payload))
try_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
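A usage sketch; the endpoint and topic values are placeholders, and AWS credentials are assumed to come from the default boto3 session.

from griptape.drivers import AwsIotCoreEventListenerDriver

driver = AwsIotCoreEventListenerDriver(
    iot_endpoint="abc123-ats.iot.us-east-1.amazonaws.com",  # placeholder endpoint
    topic="griptape/events",
)

# Publishes the payload as JSON to the configured MQTT topic.
driver.try_publish_event_payload({"type": "FinishStructureRunEvent"})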
AzureMongoDbVectorStoreDriver
Bases:
MongoDbAtlasVectorStoreDriver
Source Code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@define
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
    """A Vector Store Driver for CosmosDB with MongoDB vCore API."""

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Queries the MongoDB collection for documents that match the provided vector list.

        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
        """
        collection = self.get_collection()

        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        offset = offset or 0

        pipeline = []
        pipeline.append(
            {
                "$search": {
                    "cosmosSearch": {
                        "vector": vector,
                        "path": self.vector_path,
                        "k": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    },
                    "returnStoredSource": True,
                },
            },
        )

        if namespace:
            pipeline.append({"$match": {"namespace": namespace}})

        pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

        return [
            BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]),
                vector=doc[self.vector_path] if include_vectors else [],
                score=doc["similarityScore"],
                meta=doc["document"]["meta"],
                namespace=namespace,
            )
            for doc in collection.aggregate(pipeline)
        ]
query_vector(vector, *, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)
Source Code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Queries the MongoDB collection for documents that match the provided vector list.

    Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.
    """
    collection = self.get_collection()

    count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset or 0

    pipeline = []
    pipeline.append(
        {
            "$search": {
                "cosmosSearch": {
                    "vector": vector,
                    "path": self.vector_path,
                    "k": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                },
                "returnStoredSource": True,
            },
        },
    )

    if namespace:
        pipeline.append({"$match": {"namespace": namespace}})

    pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

    return [
        BaseVectorStoreDriver.Entry(
            id=str(doc["_id"]),
            vector=doc[self.vector_path] if include_vectors else [],
            score=doc["similarityScore"],
            meta=doc["document"]["meta"],
            namespace=namespace,
        )
        for doc in collection.aggregate(pipeline)
    ]
AzureOpenAiChatPromptDriver
Bases:
OpenAiChatPromptDriver
Attributes
Name | Type | Description |
---|---|---|
azure_deployment | str | An optional Azure OpenAi deployment id. Defaults to the model name. |
azure_endpoint | str | An Azure OpenAi endpoint. |
azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
api_version | str | An Azure OpenAi API version. |
client | AzureOpenAI | An openai.AzureOpenAI client. |
Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """Azure OpenAi Chat Prompt Driver.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        kw_only=True,
        default=Factory(lambda self: self.model, takes_self=True),
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        return openai.AzureOpenAI(
            organization=self.organization,
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = super()._base_params(prompt_stack)

        if self.api_version < "2024-02-01" and "seed" in params:
            del params["seed"]

        if self.api_version < "2024-10-21":
            if "stream_options" in params:
                del params["stream_options"]
            if "parallel_tool_calls" in params:
                del params["parallel_tool_calls"]

        # TODO: Add once Azure supports modalities
        if "modalities" in params:
            del params["modalities"]

        return params
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_version = field(default='2024-10-21', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True})
class-attribute instance-attribute
azure_endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    params = super()._base_params(prompt_stack)

    if self.api_version < "2024-02-01" and "seed" in params:
        del params["seed"]

    if self.api_version < "2024-10-21":
        if "stream_options" in params:
            del params["stream_options"]
        if "parallel_tool_calls" in params:
            del params["parallel_tool_calls"]

    # TODO: Add once Azure supports modalities
    if "modalities" in params:
        del params["modalities"]

    return params
client()
Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    return openai.AzureOpenAI(
        organization=self.organization,
        api_key=self.api_key,
        api_version=self.api_version,
        azure_endpoint=self.azure_endpoint,
        azure_deployment=self.azure_deployment,
        azure_ad_token=self.azure_ad_token,
        azure_ad_token_provider=self.azure_ad_token_provider,
    )
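A construction sketch, assuming environment variable names for the credentials. Note the interplay with _base_params above: an api_version older than 2024-02-01 silently drops seed from requests, and versions older than 2024-10-21 also drop stream_options and parallel_tool_calls.

import os
from griptape.drivers import AzureOpenAiChatPromptDriver

driver = AzureOpenAiChatPromptDriver(
    model="gpt-4o",  # azure_deployment defaults to the model name
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],  # assumed env var
    api_key=os.environ["AZURE_OPENAI_API_KEY"],          # assumed env var
    api_version="2024-06-01",  # example: seed kept, stream_options dropped
)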
AzureOpenAiEmbeddingDriver
Bases:
OpenAiEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
azure_deployment | str | An optional Azure OpenAi deployment id. Defaults to the model name. |
azure_endpoint | str | An Azure OpenAi endpoint. |
azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
api_version | str | An Azure OpenAi API version. |
tokenizer | OpenAiTokenizer | An OpenAiTokenizer . |
client | AzureOpenAI | An openai.AzureOpenAI client. |
Source Code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """Azure OpenAi Embedding Driver.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        tokenizer: An `OpenAiTokenizer`.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        kw_only=True,
        default=Factory(lambda self: self.model, takes_self=True),
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2024-10-21", kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[openai.AzureOpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        return openai.AzureOpenAI(
            organization=self.organization,
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_version = field(default='2024-10-21', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True})
class-attribute instance-attribute
azure_endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    return openai.AzureOpenAI(
        organization=self.organization,
        api_key=self.api_key,
        api_version=self.api_version,
        azure_endpoint=self.azure_endpoint,
        azure_deployment=self.azure_deployment,
        azure_ad_token=self.azure_ad_token,
        azure_ad_token_provider=self.azure_ad_token_provider,
    )
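The embedding variant is constructed the same way. This sketch (env var names assumed) embeds a single string via the inherited embed method documented under BaseEmbeddingDriver below.

import os
from griptape.drivers import AzureOpenAiEmbeddingDriver

embedder = AzureOpenAiEmbeddingDriver(
    model="text-embedding-3-small",
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],  # assumed env var
    api_key=os.environ["AZURE_OPENAI_API_KEY"],          # assumed env var
)
vector = embedder.embed("Hello, world!")  # -> list[float]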
AzureOpenAiImageGenerationDriver
Bases:
OpenAiImageGenerationDriver
Attributes
Name | Type | Description |
---|---|---|
azure_deployment | str | An optional Azure OpenAi deployment id. Defaults to the model name. |
azure_endpoint | str | An Azure OpenAi endpoint. |
azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
api_version | str | An Azure OpenAi API version. |
client | AzureOpenAI | An openai.AzureOpenAI client. |
Source Code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Driver for Azure-hosted OpenAI image generation API.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        kw_only=True,
        default=Factory(lambda self: self.model, takes_self=True),
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2024-02-01", kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        return openai.AzureOpenAI(
            organization=self.organization,
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_version = field(default='2024-02-01', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True})
class-attribute instance-attribute
azure_endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    return openai.AzureOpenAI(
        organization=self.organization,
        api_key=self.api_key,
        api_version=self.api_version,
        azure_endpoint=self.azure_endpoint,
        azure_deployment=self.azure_deployment,
        azure_ad_token=self.azure_ad_token,
        azure_ad_token_provider=self.azure_ad_token_provider,
    )
AzureOpenAiTextToSpeechDriver
Bases:
OpenAiTextToSpeechDriver
Attributes
Name | Type | Description |
---|---|---|
azure_deployment | str | An optional Azure OpenAi deployment id. Defaults to the model name. |
azure_endpoint | str | An Azure OpenAi endpoint. |
azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
api_version | str | An Azure OpenAi API version. |
client | AzureOpenAI | An openai.AzureOpenAI client. |
Source Code in griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py
@define
class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver):
    """Azure OpenAi Text to Speech Driver.

    Attributes:
        azure_deployment: An optional Azure OpenAi deployment id. Defaults to the model name.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
        client: An `openai.AzureOpenAI` client.
    """

    model: str = field(default="tts", kw_only=True, metadata={"serializable": True})
    azure_deployment: str = field(
        kw_only=True,
        default=Factory(lambda self: self.model, takes_self=True),
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    azure_ad_token: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        kw_only=True,
        default=None,
        metadata={"serializable": False},
    )
    api_version: str = field(default="2024-07-01-preview", kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        return openai.AzureOpenAI(
            organization=self.organization,
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            azure_ad_token=self.azure_ad_token,
            azure_ad_token_provider=self.azure_ad_token_provider,
        )
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_version = field(default='2024-07-01-preview', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False})
class-attribute instance-attribute
azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True})
class-attribute instance-attribute
azure_endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
model = field(default='tts', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    return openai.AzureOpenAI(
        organization=self.organization,
        api_key=self.api_key,
        api_version=self.api_version,
        azure_endpoint=self.azure_endpoint,
        azure_deployment=self.azure_deployment,
        azure_ad_token=self.azure_ad_token,
        azure_ad_token_provider=self.azure_ad_token_provider,
    )
BaseAssistantDriver
Bases:
ABC
Source Code in griptape/drivers/assistant/base_assistant_driver.py
@define
class BaseAssistantDriver(ABC):
    """Base class for AssistantDrivers."""

    def run(self, *args: BaseArtifact) -> TextArtifact:
        return self.try_run(*args)

    @abstractmethod
    def try_run(self, *args: BaseArtifact) -> TextArtifact: ...
run(*args)
Source Code in griptape/drivers/assistant/base_assistant_driver.py
def run(self, *args: BaseArtifact) -> TextArtifact:
    return self.try_run(*args)
try_run(*args)
abstractmethod
Source Code in griptape/drivers/assistant/base_assistant_driver.py
@abstractmethod
def try_run(self, *args: BaseArtifact) -> TextArtifact: ...
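Since try_run is the only abstract member, a custom assistant driver can be very small. A toy subclass sketch (the echo behavior is illustrative only):

from attrs import define
from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.drivers import BaseAssistantDriver

@define
class EchoAssistantDriver(BaseAssistantDriver):
    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        # Concatenate the text of all input artifacts.
        return TextArtifact(" ".join(arg.to_text() for arg in args))

assistant = EchoAssistantDriver()
print(assistant.run(TextArtifact("hello")).value)  # "hello"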
BaseAudioTranscriptionDriver
Bases:
SerializableMixin
, ExponentialBackoffMixin
, ABC
Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
@define
class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    model: str = field(kw_only=True, metadata={"serializable": True})

    def before_run(self) -> None:
        EventBus.publish_event(StartAudioTranscriptionEvent())

    def after_run(self) -> None:
        EventBus.publish_event(FinishAudioTranscriptionEvent())

    def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run()
                result = self.try_run(audio, prompts)
                self.after_run()
                return result
        raise Exception("Failed to run audio transcription")

    @abstractmethod
    def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: ...
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
after_run()
Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def after_run(self) -> None:
    EventBus.publish_event(FinishAudioTranscriptionEvent())
before_run()
Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def before_run(self) -> None:
    EventBus.publish_event(StartAudioTranscriptionEvent())
run(audio, prompts=None)
Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run()
            result = self.try_run(audio, prompts)
            self.after_run()
            return result
    raise Exception("Failed to run audio transcription")
try_run(audio, prompts=None)
abstractmethod
Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
@abstractmethod
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact: ...
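Subclasses only implement try_run; run supplies the retry loop and the start/finish events shown above. A stub sketch (the transcript value is a placeholder):

from typing import Optional
from attrs import define
from griptape.artifacts import AudioArtifact, TextArtifact
from griptape.drivers import BaseAudioTranscriptionDriver

@define
class StubAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
    def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        # A real driver would send `audio.value` to a transcription backend here.
        return TextArtifact("<transcript placeholder>")

driver = StubAudioTranscriptionDriver(model="stub")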
BaseConversationMemoryDriver
Bases:
SerializableMixin
, ABC
Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(SerializableMixin, ABC):
    @abstractmethod
    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: ...

    @abstractmethod
    def load(self) -> tuple[list[Run], dict[str, Any]]: ...

    def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict:
        return {"runs": [run.to_dict() for run in runs], "metadata": metadata}

    def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]:
        from griptape.memory.structure import Run

        return [Run.from_dict(run) for run in params_dict.get("runs", [])], params_dict.get("metadata", {})
_from_params_dict(params_dict)
Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]:
    from griptape.memory.structure import Run

    return [Run.from_dict(run) for run in params_dict.get("runs", [])], params_dict.get("metadata", {})
_to_params_dict(runs, metadata)
Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict:
    return {"runs": [run.to_dict() for run in runs], "metadata": metadata}
load()
abstractmethod
Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self) -> tuple[list[Run], dict[str, Any]]: ...
store(runs, metadata)
abstractmethod
Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: ...
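A minimal in-memory implementation sketch; the two private helpers above handle (de)serialization, so store and load reduce to bookkeeping:

from typing import Any
from attrs import define, field
from griptape.drivers import BaseConversationMemoryDriver
from griptape.memory.structure import Run

@define
class InMemoryConversationMemoryDriver(BaseConversationMemoryDriver):
    data: dict = field(factory=dict)

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        # Serialize runs + metadata into a plain dict.
        self.data = self._to_params_dict(runs, metadata)

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        # Returns ([], {}) when nothing has been stored yet.
        return self._from_params_dict(self.data)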
BaseDiffusionImageGenerationPipelineDriver
Bases:
ABC
Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@define
class BaseDiffusionImageGenerationPipelineDriver(ABC):
    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...

    @abstractmethod
    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ...

    @abstractmethod
    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict: ...

    @property
    @abstractmethod
    def output_image_dimensions(self) -> tuple[int, int]: ...
output_image_dimensions
abstractmethod property
make_additional_params(negative_prompts, device)
abstractmethod
Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict: ...
make_image_param(image)
abstractmethod
Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]: ...
prepare_pipeline(model, device)
abstractmethod
Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any: ...
BaseEmbeddingDriver
Bases:
SerializableMixin
, ExponentialBackoffMixin
, ABC
Attributes
Name | Type | Description |
---|---|---|
model | str | The name of the model to use. |
tokenizer | Optional[BaseTokenizer] | An instance of BaseTokenizer to use when calculating tokens. |
Source Code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base Embedding Driver.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: Optional[BaseChunker] = field(init=False)

    def __attrs_post_init__(self) -> None:
        self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None

    def embed_text_artifact(
        self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        warnings.warn(
            "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(artifact, vector_operation=vector_operation)

    def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        warnings.warn(
            "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(string, vector_operation=vector_operation)

    def embed(
        self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        for attempt in self.retrying():
            with attempt:
                if isinstance(value, str):
                    if (
                        self.tokenizer is not None
                        and self.tokenizer.count_tokens(value) > self.tokenizer.max_input_tokens
                    ):
                        return self._embed_long_string(value, vector_operation=vector_operation)
                    return self.try_embed_chunk(value, vector_operation=vector_operation)
                if isinstance(value, TextArtifact):
                    return self.embed(value.to_text(), vector_operation=vector_operation)
                if isinstance(value, ImageArtifact):
                    return self.try_embed_artifact(value, vector_operation=vector_operation)
        raise RuntimeError("Failed to embed string.")

    def try_embed_artifact(
        self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        # TODO: Mark as abstract method for griptape 2.0
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value, vector_operation=vector_operation)
        raise ValueError(f"{self.__class__.__name__} does not support embedding images.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead
        ...

    def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.embed(chunk.value, vector_operation=vector_operation))
            length_chunks.append(len(chunk))

        # generate weighted averages
        embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # normalize length to 1
        embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)

        return embedding_chunks.tolist()
chunker = field(init=False)
class-attribute instance-attribute
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=None, kw_only=True)
class-attribute instance-attribute
__attrs_post_init__()
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None
_embed_long_string(string, *, vector_operation=None)
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Embeds a string that is too long to embed in one go.

    Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
    """
    chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

    embedding_chunks = []
    length_chunks = []
    for chunk in chunks:
        embedding_chunks.append(self.embed(chunk.value, vector_operation=vector_operation))
        length_chunks.append(len(chunk))

    # generate weighted averages
    embedding_chunks = np.average(embedding_chunks, axis=0, weights=length_chunks)

    # normalize length to 1
    embedding_chunks = embedding_chunks / np.linalg.norm(embedding_chunks)

    return embedding_chunks.tolist()
embed(value, *, vector_operation=None)
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed(
    self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    for attempt in self.retrying():
        with attempt:
            if isinstance(value, str):
                if (
                    self.tokenizer is not None
                    and self.tokenizer.count_tokens(value) > self.tokenizer.max_input_tokens
                ):
                    return self._embed_long_string(value, vector_operation=vector_operation)
                return self.try_embed_chunk(value, vector_operation=vector_operation)
            if isinstance(value, TextArtifact):
                return self.embed(value.to_text(), vector_operation=vector_operation)
            if isinstance(value, ImageArtifact):
                return self.try_embed_artifact(value, vector_operation=vector_operation)
    raise RuntimeError("Failed to embed string.")
embed_string(string, *, vector_operation=None)
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    warnings.warn(
        "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.embed(string, vector_operation=vector_operation)
embed_text_artifact(artifact, *, vector_operation=None)
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(
    self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    warnings.warn(
        "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.embed(artifact, vector_operation=vector_operation)
try_embed_artifact(artifact, *, vector_operation=None)
Source Code in griptape/drivers/embedding/base_embedding_driver.py
def try_embed_artifact(
    self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    # TODO: Mark as abstract method for griptape 2.0
    if isinstance(artifact, TextArtifact):
        return self.try_embed_chunk(artifact.value, vector_operation=vector_operation)
    raise ValueError(f"{self.__class__.__name__} does not support embedding images.")
try_embed_chunk(chunk, *, vector_operation=None)
abstractmethod
Source Code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead
    ...
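Only try_embed_chunk is abstract today (with try_embed_artifact slated to replace it, per the TODOs above). A toy driver sketch that hashes text into a fixed-size vector, purely for illustration:

import hashlib
from attrs import define
from griptape.drivers import BaseEmbeddingDriver

@define
class HashEmbeddingDriver(BaseEmbeddingDriver):
    def try_embed_chunk(self, chunk: str, *, vector_operation=None) -> list[float]:
        # Deterministic 16-dim toy vector; not a meaningful embedding.
        digest = hashlib.sha256(chunk.encode()).digest()
        return [byte / 255 for byte in digest[:16]]

driver = HashEmbeddingDriver(model="hash-demo")
print(len(driver.embed("some text")))  # 16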
BaseEventListenerDriver
Bases:
FuturesExecutorMixin
, ExponentialBackoffMixin
, ABC
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@define
class BaseEventListenerDriver(FuturesExecutorMixin, ExponentialBackoffMixin, ABC):
    batched: bool = field(default=True, kw_only=True)
    batch_size: int = field(default=10, kw_only=True)

    _batch: list[dict] = field(default=Factory(list), kw_only=True)

    @property
    def batch(self) -> list[dict]:
        return self._batch

    def publish_event(self, event: BaseEvent | dict) -> None:
        event_payload = event if isinstance(event, dict) else event.to_dict()

        with self.create_futures_executor() as futures_executor:
            if self.batched:
                self._batch.append(event_payload)
                if len(self.batch) >= self.batch_size:
                    futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
                    self._batch = []
            else:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)

    def flush_events(self) -> None:
        if self.batch:
            with self.create_futures_executor() as futures_executor:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
            self._batch = []

    @abstractmethod
    def try_publish_event_payload(self, event_payload: dict) -> None: ...

    @abstractmethod
    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...

    def _safe_publish_event_payload(self, event_payload: dict) -> None:
        try:
            for attempt in self.retrying():
                with attempt:
                    self.try_publish_event_payload(event_payload)
        except Exception:
            logger.warning("Failed to publish event after %s attempts", self.max_attempts, exc_info=True)

    def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        try:
            for attempt in self.retrying():
                with attempt:
                    self.try_publish_event_payload_batch(event_payload_batch)
        except Exception:
            logger.warning("Failed to publish event batch after %s attempts", self.max_attempts, exc_info=True)
_batch = field(default=Factory(list), kw_only=True)
class-attribute instance-attribute
batch
property
batch_size = field(default=10, kw_only=True)
class-attribute instance-attribute
batched = field(default=True, kw_only=True)
class-attribute instance-attribute
_safe_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def _safe_publish_event_payload(self, event_payload: dict) -> None:
    try:
        for attempt in self.retrying():
            with attempt:
                self.try_publish_event_payload(event_payload)
    except Exception:
        logger.warning("Failed to publish event after %s attempts", self.max_attempts, exc_info=True)
_safe_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    try:
        for attempt in self.retrying():
            with attempt:
                self.try_publish_event_payload_batch(event_payload_batch)
    except Exception:
        logger.warning("Failed to publish event batch after %s attempts", self.max_attempts, exc_info=True)
flush_events()
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def flush_events(self) -> None:
    if self.batch:
        with self.create_futures_executor() as futures_executor:
            futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
        self._batch = []
publish_event(event)
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def publish_event(self, event: BaseEvent | dict) -> None:
    event_payload = event if isinstance(event, dict) else event.to_dict()

    with self.create_futures_executor() as futures_executor:
        if self.batched:
            self._batch.append(event_payload)
            if len(self.batch) >= self.batch_size:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
                self._batch = []
        else:
            futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
try_publish_event_payload(event_payload)
abstractmethod
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None: ...
try_publish_event_payload_batch(event_payload_batch)
abstractmethod
Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: ...
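A behavior sketch of the batching logic in publish_event: with batched=True (the default), payloads accumulate until batch_size is reached and are then published as one batch on a worker thread; flush_events drains any partial batch, for example at shutdown.

from attrs import define
from griptape.drivers import BaseEventListenerDriver

@define
class PrintEventListenerDriver(BaseEventListenerDriver):
    def try_publish_event_payload(self, event_payload: dict) -> None:
        print("event:", event_payload)

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        print(f"batch of {len(event_payload_batch)} events")

driver = PrintEventListenerDriver(batch_size=2)
driver.publish_event({"type": "a"})
driver.publish_event({"type": "b"})  # batch_size reached: batch publish submitted
driver.flush_events()                # drains any remaining partial batch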
BaseFileManagerDriver
Bases:
ABC
Attributes
Name | Type | Description |
---|---|---|
default_loader | | The default loader to use for loading file contents into artifacts. |
loaders | | Dictionary of file extension specific loaders to use for loading file contents into artifacts. |
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@define
class BaseFileManagerDriver(ABC):
    """BaseFileManagerDriver can be used to list, load, and save files.

    Attributes:
        default_loader: The default loader to use for loading file contents into artifacts.
        loaders: Dictionary of file extension specific loaders to use for loading file contents into artifacts.
    """

    _workdir: str = field(kw_only=True, alias="workdir")
    encoding: Optional[str] = field(default=None, kw_only=True)

    @property
    @abstractmethod
    def workdir(self) -> str: ...

    @workdir.setter
    @abstractmethod
    def workdir(self, value: str) -> None: ...

    def list_files(self, path: str) -> TextArtifact:
        entries = self.try_list_files(path)

        return TextArtifact("\n".join(list(entries)))

    @abstractmethod
    def try_list_files(self, path: str) -> list[str]: ...

    def load_file(self, path: str) -> BlobArtifact | TextArtifact:
        if self.encoding is None:
            return BlobArtifact(self.try_load_file(path))
        return TextArtifact(self.try_load_file(path).decode(encoding=self.encoding), encoding=self.encoding)

    @abstractmethod
    def try_load_file(self, path: str) -> bytes: ...

    def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
        if isinstance(value, str):
            value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding)
        elif isinstance(value, (bytearray, memoryview)):
            raise ValueError(f"Unsupported type: {type(value)}")

        location = self.try_save_file(path, value)

        return InfoArtifact(f"Successfully saved file at: {location}")

    @abstractmethod
    def try_save_file(self, path: str, value: bytes) -> str: ...

    def load_artifact(self, path: str) -> BaseArtifact:
        response = self.try_load_file(path)

        return BaseArtifact.from_json(
            response.decode() if self.encoding is None else response.decode(encoding=self.encoding)
        )

    def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
        artifact_json = artifact.to_json()
        value = artifact_json.encode() if self.encoding is None else artifact_json.encode(encoding=self.encoding)

        location = self.try_save_file(path, value)

        return InfoArtifact(f"Successfully saved artifact at: {location}")
_workdir = field(kw_only=True, alias='workdir')
class-attribute instance-attribute
encoding = field(default=None, kw_only=True)
class-attribute instance-attribute
workdir
abstractmethod property writable
list_files(path)
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def list_files(self, path: str) -> TextArtifact:
    entries = self.try_list_files(path)

    return TextArtifact("\n".join(list(entries)))
load_artifact(path)
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def load_artifact(self, path: str) -> BaseArtifact:
    response = self.try_load_file(path)

    return BaseArtifact.from_json(
        response.decode() if self.encoding is None else response.decode(encoding=self.encoding)
    )
load_file(path)
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def load_file(self, path: str) -> BlobArtifact | TextArtifact:
    if self.encoding is None:
        return BlobArtifact(self.try_load_file(path))
    return TextArtifact(self.try_load_file(path).decode(encoding=self.encoding), encoding=self.encoding)
save_artifact(path, artifact)
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
    artifact_json = artifact.to_json()
    value = artifact_json.encode() if self.encoding is None else artifact_json.encode(encoding=self.encoding)

    location = self.try_save_file(path, value)

    return InfoArtifact(f"Successfully saved artifact at: {location}")
save_file(path, value)
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
    if isinstance(value, str):
        value = value.encode() if self.encoding is None else value.encode(encoding=self.encoding)
    elif isinstance(value, (bytearray, memoryview)):
        raise ValueError(f"Unsupported type: {type(value)}")

    location = self.try_save_file(path, value)

    return InfoArtifact(f"Successfully saved file at: {location}")
try_list_files(path)
abstractmethod
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_list_files(self, path: str) -> list[str]: ...
try_load_file(path)
abstractmethod
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_load_file(self, path: str) -> bytes: ...
try_save_file(path, value)
abstractmethod
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_save_file(self, path: str, value: bytes) -> str: ...
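A behavior sketch using LocalFileManagerDriver as the concrete subclass (that class and its workdir handling are documented elsewhere, not in this section; the path is a placeholder). With encoding unset, loads yield BlobArtifacts; with an encoding, TextArtifacts.

from griptape.drivers import LocalFileManagerDriver

fm = LocalFileManagerDriver(workdir="/tmp/griptape-demo")  # placeholder path
fm.save_file("notes.txt", "hello")   # str values are encoded before try_save_file
print(fm.load_file("notes.txt"))     # BlobArtifact, since encoding is None
print(fm.list_files(".").value)      # newline-joined directory entries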
BaseImageGenerationDriver
Bases:
SerializableMixin
, ExponentialBackoffMixin
, ABC
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@define
class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    model: str = field(kw_only=True, metadata={"serializable": True})

    def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
        EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))

    def after_run(self) -> None:
        EventBus.publish_event(FinishImageGenerationEvent())

    def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_text_to_image(prompts, negative_prompts)
                self.after_run()
                return result
        raise Exception("Failed to run text to image generation")

    def run_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_variation(prompts, image, negative_prompts)
                self.after_run()
                return result
        raise Exception("Failed to generate image variations")

    def run_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
                self.after_run()
                return result
        raise Exception("Failed to run image inpainting")

    def run_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts, negative_prompts)
                result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
                self.after_run()
                return result
        raise Exception("Failed to run image outpainting")

    @abstractmethod
    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: ...

    @abstractmethod
    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact: ...

    @abstractmethod
    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact: ...

    @abstractmethod
    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact: ...
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
after_run()
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def after_run(self) -> None:
    EventBus.publish_event(FinishImageGenerationEvent())
before_run(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
    EventBus.publish_event(StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts))
run_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_inpainting(prompts, image, mask, negative_prompts)
            self.after_run()
            return result
    raise Exception("Failed to run image inpainting")
run_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_outpainting(prompts, image, mask, negative_prompts)
            self.after_run()
            return result
    raise Exception("Failed to run image outpainting")
run_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_image_variation(prompts, image, negative_prompts)
            self.after_run()
            return result
    raise Exception("Failed to generate image variations")
run_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts, negative_prompts)
            result = self.try_text_to_image(prompts, negative_prompts)
            self.after_run()
            return result
    raise Exception("Failed to run text to image generation")
try_image_inpainting(prompts, image, mask, negative_prompts=None)
abstractmethod
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact: ...
try_image_outpainting(prompts, image, mask, negative_prompts=None)
abstractmethod
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact: ...
try_image_variation(prompts, image, negative_prompts=None)
abstractmethod
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact: ...
try_text_to_image(prompts, negative_prompts=None)
abstractmethod
Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: ...
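Each public run_* method wraps the matching try_* hook with retries plus start/finish events, so a subclass only fills in the hooks it supports. A sketch that stubs text-to-image and declines the rest, a reasonable pattern for single-capability backends (the pixel bytes and the ImageArtifact keyword arguments are illustrative assumptions):

from typing import Optional
from attrs import define
from griptape.artifacts import ImageArtifact
from griptape.drivers import BaseImageGenerationDriver

@define
class StubImageGenerationDriver(BaseImageGenerationDriver):
    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        # Placeholder pixels; a real driver would call its backend here.
        return ImageArtifact(value=b"\x00\x00\x00", format="png", width=1, height=1)

    def try_image_variation(self, prompts, image, negative_prompts=None) -> ImageArtifact:
        raise NotImplementedError("variation not supported by this stub")

    def try_image_inpainting(self, prompts, image, mask, negative_prompts=None) -> ImageArtifact:
        raise NotImplementedError("inpainting not supported by this stub")

    def try_image_outpainting(self, prompts, image, mask, negative_prompts=None) -> ImageArtifact:
        raise NotImplementedError("outpainting not supported by this stub")

driver = StubImageGenerationDriver(model="stub")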
BaseImageGenerationModelDriver
Bases:
SerializableMixin
, ABC
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@define
class BaseImageGenerationModelDriver(SerializableMixin, ABC):
    @abstractmethod
    def get_generated_image(self, response: dict) -> bytes: ...

    @abstractmethod
    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]: ...

    @abstractmethod
    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]: ...

    @abstractmethod
    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]: ...

    @abstractmethod
    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]: ...
get_generated_image(response)
abstractmethod
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def get_generated_image(self, response: dict) -> bytes: ...
image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
abstractmethod
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]: ...
image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
abstractmethod
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]: ...
image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)
abstractmethod
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]: ...
text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)
abstractmethod
Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]: ...
BaseMultiModelImageGenerationDriver
Bases:
BaseImageGenerationDriver
, ABC
Attributes
Name | Type | Description |
---|---|---|
image_generation_model_driver | BaseImageGenerationModelDriver | Image Model Driver to use. |
Source Code in griptape/drivers/image_generation/base_multi_model_image_generation_driver.py
@define
class BaseMultiModelImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image Generation Driver for platforms like Amazon Bedrock that host many LLM models.

    Instances of this Image Generation Driver require a Image Generation Model Driver which is used to structure the
    image generation request in the format required by the model and to process the output.

    Attributes:
        image_generation_model_driver: Image Model Driver to use.
    """

    image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True, metadata={"serializable": True})
image_generation_model_driver = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
BaseObservabilityDriver
Bases:
ABC
Source Code in griptape/drivers/observability/base_observability_driver.py
@define
class BaseObservabilityDriver(ABC):
    def __enter__(self) -> None:  # noqa: B027
        pass

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> bool:
        return False

    @abstractmethod
    def observe(self, call: Observable.Call) -> Any: ...

    @abstractmethod
    def get_span_id(self) -> Optional[str]: ...
__enter__()
Source Code in griptape/drivers/observability/base_observability_driver.py
def __enter__(self) -> None:  # noqa: B027
    pass
__exit__(exc_type, exc_value, exc_traceback)
Source Code in griptape/drivers/observability/base_observability_driver.py
def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[TracebackType],
) -> bool:
    return False
get_span_id()
abstractmethod
Source Code in griptape/drivers/observability/base_observability_driver.py
@abstractmethod
def get_span_id(self) -> Optional[str]: ...
observe(call)
abstractmethod
Source Code in griptape/drivers/observability/base_observability_driver.py
@abstractmethod
def observe(self, call: Observable.Call) -> Any: ...
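A pass-through sketch of the contract: the driver is entered as a context manager around the instrumented span, observe decides how to record and execute each Observable.Call, and get_span_id reports the active span, if any. That Observable.Call can simply be invoked to run the wrapped function is an assumption here, modeled on a no-op tracer.

from typing import Any, Optional
from attrs import define
from griptape.drivers import BaseObservabilityDriver

@define
class PassThroughObservabilityDriver(BaseObservabilityDriver):
    def observe(self, call) -> Any:
        # Assumed: Observable.Call bundles the wrapped function and its
        # arguments, so invoking it runs the underlying call unchanged.
        return call()

    def get_span_id(self) -> Optional[str]:
        return None  # no tracing backend, so no span ID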
BasePromptDriver
Bases:
SerializableMixin
, ExponentialBackoffMixin
, ABC
Attributes
Name | Type | Description |
---|---|---|
temperature | float | The temperature to use for the completion. |
max_tokens | Optional[int] | The maximum number of tokens to generate. If not specified, the value is determined automatically based on the tokenizer. |
prompt_stack_to_string | str | A function that converts a PromptStack to a string. |
ignored_exception_types | tuple[type[Exception], ...] | A tuple of exception types to ignore. |
model | str | The model name. |
tokenizer | BaseTokenizer | An instance of BaseTokenizer to use when calculating tokens. |
stream | bool | Whether to stream the completion or not. CompletionChunkEvents will be published to the Structure if one is provided. |
use_native_tools | bool | Whether to use the LLM's native function-calling capabilities. Must be supported by the model. |
extra_params | dict | Extra parameters to pass to the model. |
Source Code in griptape/drivers/prompt/base_prompt_driver.py
@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for the Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based by the tokenizer.
        prompt_stack_to_string: A function that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to when calculating tokens.
        stream: Whether to stream the completion or not. `CompletionChunkEvent`s will be published to the `Structure` if one is provided.
        use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model.
        extra_params: Extra parameters to pass to the model.
    """

    temperature: float = field(default=0.1, metadata={"serializable": True})
    max_tokens: Optional[int] = field(default=None, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError, ValueError)))
    model: str = field(metadata={"serializable": True})
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

    def before_run(self, prompt_stack: PromptStack) -> None:
        self._init_structured_output(prompt_stack)
        EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

    def after_run(self, result: Message) -> None:
        EventBus.publish_event(
            FinishPromptEvent(
                model=self.model,
                result=result.value,
                input_token_count=result.usage.input_tokens,
                output_token_count=result.usage.output_tokens,
            ),
        )

    @observable(tags=["PromptDriver.run()"])
    def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
        if isinstance(prompt_input, BaseArtifact):
            prompt_stack = PromptStack.from_artifact(prompt_input)
        else:
            prompt_stack = prompt_input

        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)

                self.after_run(result)

                return result
        raise Exception("prompt driver failed after all retry attempts")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Converts a Prompt Stack to a string for token counting or model prompt_input.

        This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

        Args:
            prompt_stack: The Prompt Stack to convert to a string.

        Returns:
            A single string representation of the Prompt Stack.
        """
        prompt_lines = []

        for i in prompt_stack.messages:
            content = i.to_text()
            if i.is_user():
                prompt_lines.append(f"User: {content}")
            elif i.is_assistant():
                prompt_lines.append(f"Assistant: {content}")
            else:
                prompt_lines.append(content)

        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> Message: ...

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...

    def _init_structured_output(self, prompt_stack: PromptStack) -> None:
        from griptape.tools import StructuredOutputTool

        if (output_schema := prompt_stack.output_schema) is not None:
            if self.structured_output_strategy == "tool":
                structured_output_tool = StructuredOutputTool(output_schema=output_schema)
                if structured_output_tool not in prompt_stack.tools:
                    prompt_stack.tools.append(structured_output_tool)
            elif self.structured_output_strategy == "rule":
                output_artifact = TextArtifact(JsonSchemaRule(prompt_stack.to_output_json_schema()).to_text())
                system_messages = prompt_stack.system_messages
                if system_messages:
                    last_system_message = prompt_stack.system_messages[-1]
                    last_system_message.content.extend(
                        [
                            TextMessageContent(TextArtifact("\n\n")),
                            TextMessageContent(output_artifact),
                        ]
                    )
                else:
                    prompt_stack.messages.insert(
                        0,
                        Message(
                            content=[TextMessageContent(output_artifact)],
                            role=Message.SYSTEM_ROLE,
                        ),
                    )

    def __process_run(self, prompt_stack: PromptStack) -> Message:
        return self.try_run(prompt_stack)

    def __process_stream(self, prompt_stack: PromptStack) -> Message:
        delta_contents: dict[int, list[BaseDeltaMessageContent]] = {}
        usage = DeltaMessage.Usage()

        # Aggregate all content deltas from the stream
        message_deltas = self.try_stream(prompt_stack)
        for message_delta in message_deltas:
            usage += message_delta.usage
            content = message_delta.content

            if content is not None:
                if content.index in delta_contents:
                    delta_contents[content.index].append(content)
                else:
                    delta_contents[content.index] = [content]
                if isinstance(content, TextDeltaMessageContent):
                    EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
                elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
                    EventBus.publish_event(AudioChunkEvent(data=content.data))
                elif isinstance(content, ActionCallDeltaMessageContent):
                    EventBus.publish_event(
                        ActionChunkEvent(
                            partial_input=content.partial_input,
                            tag=content.tag,
                            name=content.name,
                            path=content.path,
                            index=content.index,
                        ),
                    )

        # Build a complete content from the content deltas
        return self.__build_message(list(delta_contents.values()), usage)

    def __build_message(
        self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage
    ) -> Message:
        content = []
        for delta_content in delta_contents:
            text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
            audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
            action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]

            if text_deltas:
                content.append(TextMessageContent.from_deltas(text_deltas))
            if audio_deltas:
                content.append(AudioMessageContent.from_deltas(audio_deltas))
            if action_deltas:
                content.append(ActionCallMessageContent.from_deltas(action_deltas))

        return Message(
            content=content,
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
        )
extra_params = field(factory=dict, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
ignored_exception_types = field(default=Factory(lambda: (ImportError, ValueError)))
class-attribute instance-attribute
max_tokens = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
model = field(metadata={'serializable': True})
class-attribute instance-attribute
stream = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
temperature = field(default=0.1, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer
instance-attribute
use_native_tools = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
__build_message(delta_contents, usage)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __build_message(
    self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage
) -> Message:
    content = []
    for delta_content in delta_contents:
        text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
        audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
        action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]

        if text_deltas:
            content.append(TextMessageContent.from_deltas(text_deltas))
        if audio_deltas:
            content.append(AudioMessageContent.from_deltas(audio_deltas))
        if action_deltas:
            content.append(ActionCallMessageContent.from_deltas(action_deltas))

    return Message(
        content=content,
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
    )
__process_run(prompt_stack)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __process_run(self, prompt_stack: PromptStack) -> Message:
    return self.try_run(prompt_stack)
__process_stream(prompt_stack)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __process_stream(self, prompt_stack: PromptStack) -> Message:
    delta_contents: dict[int, list[BaseDeltaMessageContent]] = {}
    usage = DeltaMessage.Usage()

    # Aggregate all content deltas from the stream
    message_deltas = self.try_stream(prompt_stack)
    for message_delta in message_deltas:
        usage += message_delta.usage
        content = message_delta.content

        if content is not None:
            if content.index in delta_contents:
                delta_contents[content.index].append(content)
            else:
                delta_contents[content.index] = [content]
            if isinstance(content, TextDeltaMessageContent):
                EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
            elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
                EventBus.publish_event(AudioChunkEvent(data=content.data))
            elif isinstance(content, ActionCallDeltaMessageContent):
                EventBus.publish_event(
                    ActionChunkEvent(
                        partial_input=content.partial_input,
                        tag=content.tag,
                        name=content.name,
                        path=content.path,
                        index=content.index,
                    ),
                )

    # Build a complete content from the content deltas
    return self.__build_message(list(delta_contents.values()), usage)
_init_structured_output(prompt_stack)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def _init_structured_output(self, prompt_stack: PromptStack) -> None:
    from griptape.tools import StructuredOutputTool

    if (output_schema := prompt_stack.output_schema) is not None:
        if self.structured_output_strategy == "tool":
            structured_output_tool = StructuredOutputTool(output_schema=output_schema)
            if structured_output_tool not in prompt_stack.tools:
                prompt_stack.tools.append(structured_output_tool)
        elif self.structured_output_strategy == "rule":
            output_artifact = TextArtifact(JsonSchemaRule(prompt_stack.to_output_json_schema()).to_text())
            system_messages = prompt_stack.system_messages
            if system_messages:
                last_system_message = prompt_stack.system_messages[-1]
                last_system_message.content.extend(
                    [
                        TextMessageContent(TextArtifact("\n\n")),
                        TextMessageContent(output_artifact),
                    ]
                )
            else:
                prompt_stack.messages.insert(
                    0,
                    Message(
                        content=[TextMessageContent(output_artifact)],
                        role=Message.SYSTEM_ROLE,
                    ),
                )
after_run(result)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: Message) -> None:
    EventBus.publish_event(
        FinishPromptEvent(
            model=self.model,
            result=result.value,
            input_token_count=result.usage.input_tokens,
            output_token_count=result.usage.output_tokens,
        ),
    )
before_run(prompt_stack)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    self._init_structured_output(prompt_stack)
    EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))
prompt_stack_to_string(prompt_stack)
Converts a Prompt Stack to a string for token counting or model input.
This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.
Parameters
Name | Type | Description | Default |
---|---|---|---|
prompt_stack | PromptStack | The Prompt Stack to convert to a string. | required |
Returns
Type | Description |
---|---|
str | A single string representation of the Prompt Stack. |
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Converts a Prompt Stack to a string for token counting or model input.

    This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

    Args:
        prompt_stack: The Prompt Stack to convert to a string.

    Returns:
        A single string representation of the Prompt Stack.
    """
    prompt_lines = []

    for i in prompt_stack.messages:
        content = i.to_text()
        if i.is_user():
            prompt_lines.append(f"User: {content}")
        elif i.is_assistant():
            prompt_lines.append(f"Assistant: {content}")
        else:
            prompt_lines.append(content)

    prompt_lines.append("Assistant:")

    return "\n\n".join(prompt_lines)
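For example, a short sketch of what this base implementation produces (assuming `DummyPromptDriver` can be constructed with no arguments; any concrete prompt driver instance behaves the same here):

from griptape.common import PromptStack
from griptape.drivers import DummyPromptDriver

driver = DummyPromptDriver()

stack = PromptStack()
stack.add_user_message("What is a tokenizer?")
stack.add_assistant_message("A tokenizer splits text into tokens.")
stack.add_user_message("Thanks!")

# Prints "User: ..." and "Assistant: ..." turns joined by blank lines,
# ending with the trailing "Assistant:" completion cue.
print(driver.prompt_stack_to_string(stack))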
run(prompt_input)
Source Code in griptape/drivers/prompt/base_prompt_driver.py
@observable(tags=["PromptDriver.run()"]) def run(self, prompt_input: PromptStack | BaseArtifact) -> Message: if isinstance(prompt_input, BaseArtifact): prompt_stack = PromptStack.from_artifact(prompt_input) else: prompt_stack = prompt_input for attempt in self.retrying(): with attempt: self.before_run(prompt_stack) result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack) self.after_run(result) return result raise Exception("prompt driver failed after all retry attempts")
try_run(prompt_stack) abstractmethod
Source Code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> Message: ...
try_stream(prompt_stack) abstractmethod
Source Code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: ...
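To illustrate the contract that `try_run` and `try_stream` define, here is a minimal hypothetical subclass that simply echoes the prompt back. It is a sketch for offline testing, not part of the library; the `EchoPromptDriver` name and the exact import paths are assumptions.

from collections.abc import Iterator

from attrs import define
from griptape.artifacts import TextArtifact
from griptape.common import DeltaMessage, Message, PromptStack, TextDeltaMessageContent, TextMessageContent
from griptape.drivers import BasePromptDriver


@define(kw_only=True)
class EchoPromptDriver(BasePromptDriver):
    """Hypothetical driver that echoes the prompt back; handy for offline tests."""

    def try_run(self, prompt_stack: PromptStack) -> Message:
        text = self.prompt_stack_to_string(prompt_stack)

        return Message(
            content=[TextMessageContent(TextArtifact(text))],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=0, output_tokens=0),
        )

    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        # Emit the echoed text one whitespace-delimited token at a time.
        for token in self.prompt_stack_to_string(prompt_stack).split():
            yield DeltaMessage(content=TextDeltaMessageContent(token + " ", index=0))

Instantiating it still requires the usual `model` and `tokenizer` arguments inherited from `BasePromptDriver`; `run()` then handles retries, event publishing, and stream aggregation on top of these two methods.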
BaseRerankDriver
Bases:
ABC
Source Code in griptape/drivers/rerank/base_rerank_driver.py
@define(kw_only=True)
class BaseRerankDriver(ABC):
    @abstractmethod
    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]: ...
run(query, artifacts) abstractmethod
Source Code in griptape/drivers/rerank/base_rerank_driver.py
@abstractmethod
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]: ...
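A complete rerank driver only has to implement `run`. A minimal hypothetical example that orders artifacts by naive query-term overlap (purely illustrative; real drivers such as `CohereRerankDriver` call a rerank model):

from attrs import define
from griptape.artifacts import TextArtifact
from griptape.drivers import BaseRerankDriver


@define(kw_only=True)
class KeywordRerankDriver(BaseRerankDriver):
    """Hypothetical driver: rank artifacts by shared terms with the query."""

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        terms = set(query.lower().split())

        def overlap(artifact: TextArtifact) -> int:
            return len(terms & set(artifact.value.lower().split()))

        return sorted(artifacts, key=overlap, reverse=True)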
BaseRulesetDriver
Bases:
SerializableMixin
, ABC
Attributes
Name | Type | Description |
---|---|---|
raise_not_found | bool | Whether to raise an error if the ruleset is not found. Defaults to True. |
Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
@define
class BaseRulesetDriver(SerializableMixin, ABC):
    """Base class for ruleset drivers.

    Attributes:
        raise_not_found: Whether to raise an error if the ruleset is not found. Defaults to True.
    """

    raise_not_found: bool = field(default=True, kw_only=True, metadata={"serializable": True})

    @abstractmethod
    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]: ...

    def _from_ruleset_dict(self, params_dict: dict[str, Any]) -> tuple[list[BaseRule], dict[str, Any]]:
        return [
            self._get_rule(rule["value"], rule.get("meta", {})) for rule in params_dict.get("rules", [])
        ], params_dict.get("meta", {})

    def _get_rule(self, value: Any, meta: dict[str, Any]) -> BaseRule:
        from griptape.rules import JsonSchemaRule, Rule

        return JsonSchemaRule(value=value, meta=meta) if isinstance(value, dict) else Rule(value=str(value), meta=meta)
raise_not_found = field(default=True, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
_from_ruleset_dict(params_dict)
Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
def _from_ruleset_dict(self, params_dict: dict[str, Any]) -> tuple[list[BaseRule], dict[str, Any]]:
    return [
        self._get_rule(rule["value"], rule.get("meta", {})) for rule in params_dict.get("rules", [])
    ], params_dict.get("meta", {})
_get_rule(value, meta)
Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
def _get_rule(self, value: Any, meta: dict[str, Any]) -> BaseRule:
    from griptape.rules import JsonSchemaRule, Rule

    return JsonSchemaRule(value=value, meta=meta) if isinstance(value, dict) else Rule(value=str(value), meta=meta)
load(ruleset_name) abstractmethod
Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
@abstractmethod
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]: ...
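A sketch of a hypothetical in-memory implementation, reusing the `_from_ruleset_dict` helper shown above (the `DictRulesetDriver` name and its `rulesets` field are illustrative, not library API):

from typing import Any

from attrs import define, field
from griptape.drivers import BaseRulesetDriver
from griptape.rules import BaseRule


@define
class DictRulesetDriver(BaseRulesetDriver):
    """Hypothetical driver that resolves rulesets from a plain dict."""

    rulesets: dict[str, dict[str, Any]] = field(factory=dict, kw_only=True)

    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
        if ruleset_name not in self.rulesets:
            if self.raise_not_found:
                raise ValueError(f"Ruleset not found: {ruleset_name}")
            return [], {}
        # _from_ruleset_dict builds Rule/JsonSchemaRule objects from the raw dict.
        return self._from_ruleset_dict(self.rulesets[ruleset_name])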
BaseSqlDriver
Bases:
ABC
Source Code in griptape/drivers/sql/base_sql_driver.py
@define
class BaseSqlDriver(ABC):
    @dataclass
    class RowResult:
        cells: dict[str, Any]

    @abstractmethod
    def execute_query(self, query: str) -> Optional[list[RowResult]]: ...

    @abstractmethod
    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]: ...

    @abstractmethod
    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: ...
RowResult dataclass
Source Code in griptape/drivers/sql/base_sql_driver.py
@dataclass
class RowResult:
    cells: dict[str, Any]
cells
instance-attribute
execute_query(query) abstractmethod
Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query(self, query: str) -> Optional[list[RowResult]]: ...
execute_query_raw(query) abstractmethod
Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]: ...
get_table_schema(table_name, schema=None) abstractmethod
Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: ...
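A minimal sketch of a concrete implementation backed by the standard-library sqlite3 module (the `SqliteDriver` name is hypothetical; `SqlDriver` in the library covers SQLAlchemy-supported databases):

import sqlite3
from typing import Any, Optional

from attrs import define, field
from griptape.drivers import BaseSqlDriver


@define
class SqliteDriver(BaseSqlDriver):
    """Hypothetical minimal driver backed by sqlite3."""

    database: str = field(default=":memory:", kw_only=True)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)
        return [BaseSqlDriver.RowResult(cells=row) for row in rows] if rows is not None else None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        with sqlite3.connect(self.database) as connection:
            connection.row_factory = sqlite3.Row
            cursor = connection.execute(query)
            return [dict(row) for row in cursor.fetchall()]

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        # PRAGMA table_info returns one row per column; stringify for the LLM.
        rows = self.execute_query_raw(f"PRAGMA table_info({table_name})")
        return str(rows) if rows else None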
BaseStructureRunDriver
Bases:
ABC
Attributes
Name | Type | Description |
---|---|---|
env | dict[str, str] | Environment variables to set before running the Structure. |
Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
@define
class BaseStructureRunDriver(ABC):
    """Base class for Structure Run Drivers.

    Attributes:
        env: Environment variables to set before running the Structure.
    """

    env: dict[str, str] = field(default=Factory(dict), kw_only=True)

    def run(self, *args: BaseArtifact) -> BaseArtifact:
        return self.try_run(*args)

    @abstractmethod
    def try_run(self, *args: BaseArtifact) -> BaseArtifact: ...
env = field(default=Factory(dict), kw_only=True)
class-attribute instance-attribute
run(*args)
Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
def run(self, *args: BaseArtifact) -> BaseArtifact:
    return self.try_run(*args)
try_run(*args) abstractmethod
Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
@abstractmethod
def try_run(self, *args: BaseArtifact) -> BaseArtifact: ...
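Only `try_run` needs to be supplied; `run` simply delegates to it. A hypothetical stand-in that upper-cases its input artifacts (illustrative only; real drivers like `LocalStructureRunDriver` run an actual Structure):

from attrs import define
from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.drivers import BaseStructureRunDriver


@define
class UppercaseStructureRunDriver(BaseStructureRunDriver):
    """Hypothetical driver standing in for a real Structure."""

    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        return TextArtifact(" ".join(a.to_text().upper() for a in args))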
BaseTextToSpeechDriver
Bases:
SerializableMixin
, ExponentialBackoffMixin
, ABC
Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
@define
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    model: str = field(kw_only=True, metadata={"serializable": True})

    def before_run(self, prompts: list[str]) -> None:
        EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts))

    def after_run(self) -> None:
        EventBus.publish_event(FinishTextToSpeechEvent())

    def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts)
                result = self.try_text_to_audio(prompts)
                self.after_run()
                return result
        raise Exception("Failed to run text to audio generation")

    @abstractmethod
    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: ...
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
after_run()
Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def after_run(self) -> None:
    EventBus.publish_event(FinishTextToSpeechEvent())
before_run(prompts)
Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def before_run(self, prompts: list[str]) -> None:
    EventBus.publish_event(StartTextToSpeechEvent(prompts=prompts))
run_text_to_audio(prompts)
Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts)
            result = self.try_text_to_audio(prompts)
            self.after_run()
            return result
    raise Exception("Failed to run text to audio generation")
try_text_to_audio(prompts) abstractmethod
Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
@abstractmethod
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact: ...
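A no-op sketch showing the shape of a subclass (hypothetical name; it assumes `AudioArtifact` accepts raw bytes plus a `format` keyword, which matches how the image drivers below handle binary artifacts):

from attrs import define, field
from griptape.artifacts import AudioArtifact
from griptape.drivers import BaseTextToSpeechDriver


@define
class SilentTextToSpeechDriver(BaseTextToSpeechDriver):
    """Hypothetical stub driver: returns an empty audio artifact."""

    model: str = field(default="silence", kw_only=True)

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        # A real driver would call a TTS API here; this stub returns no audio data.
        return AudioArtifact(b"", format="wav")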
BaseVectorStoreDriver
Bases:
SerializableMixin
, FuturesExecutorMixin
, ABC
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@define
class BaseVectorStoreDriver(SerializableMixin, FuturesExecutorMixin, ABC):
    DEFAULT_QUERY_COUNT = 5

    @define
    class Entry(SerializableMixin):
        id: str = field(metadata={"serializable": True})
        vector: Optional[list[float]] = field(default=None, metadata={"serializable": True})
        score: Optional[float] = field(default=None, metadata={"serializable": True})
        meta: Optional[dict] = field(default=None, metadata={"serializable": True})
        namespace: Optional[str] = field(default=None, metadata={"serializable": True})

        def to_artifact(self) -> BaseArtifact:
            return BaseArtifact.from_json(self.meta["artifact"])  # pyright: ignore[reportOptionalSubscript]

    embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})

    def upsert_text_artifacts(
        self,
        artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> list[str] | dict[str, list[str]]:
        warnings.warn(
            "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert_collection` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert_collection(artifacts, meta=meta, **kwargs)

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        warnings.warn(
            "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

    def upsert_text(
        self,
        string: str,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        warnings.warn(
            "`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

    @overload
    def upsert_collection(
        self,
        artifacts: list[TextArtifact] | list[ImageArtifact],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> list[str]: ...

    @overload
    def upsert_collection(
        self,
        artifacts: dict[str, list[TextArtifact]] | dict[str, list[ImageArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> dict[str, list[str]]: ...

    def upsert_collection(
        self,
        artifacts: list[TextArtifact]
        | list[ImageArtifact]
        | dict[str, list[TextArtifact]]
        | dict[str, list[ImageArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ):
        with self.create_futures_executor() as futures_executor:
            if isinstance(artifacts, list):
                return utils.execute_futures_list(
                    [
                        futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs)
                        for a in artifacts
                    ],
                )
            futures_dict = {}

            for namespace, artifact_list in artifacts.items():
                for a in artifact_list:
                    if not futures_dict.get(namespace):
                        futures_dict[namespace] = []
                    futures_dict[namespace].append(
                        futures_executor.submit(
                            with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs
                        )
                    )

            return utils.execute_futures_list_dict(futures_dict)

    def upsert(
        self,
        value: str | TextArtifact | ImageArtifact,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        artifact = TextArtifact(value) if isinstance(value, str) else value
        meta = {} if meta is None else meta

        if vector_id is None:
            value = artifact.to_text() if artifact.reference is None else artifact.to_text() + str(artifact.reference)
            vector_id = self._get_default_vector_id(value)

        if self.does_entry_exist(vector_id, namespace=namespace):
            return vector_id
        meta = {**meta, "artifact": artifact.to_json()}

        vector = self.embedding_driver.embed(artifact, vector_operation="upsert")

        return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)

    def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
        try:
            return self.load_entry(vector_id, namespace=namespace) is not None
        except Exception:
            return False

    def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
        result = self.load_entries(namespace=namespace)
        artifacts = [r.to_artifact() for r in result]

        return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])

    @abstractmethod
    def delete_vector(self, vector_id: str) -> None: ...

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str: ...

    @abstractmethod
    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[Entry]: ...

    @abstractmethod
    def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[Entry]:
        # TODO: Mark as abstract method for griptape 2.0
        raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[Entry]:
        try:
            vector = self.embedding_driver.embed(query, vector_operation="query")
        except ValueError as e:
            raise ValueError(
                "The Embedding Driver, %s, used by the Vector Store does not support embedding the %s type."
                "To resolve, provide an Embedding Driver that supports this type.",
                self.embedding_driver.__class__.__name__,
                type(query),
            ) from e
        return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

    def _get_default_vector_id(self, value: str) -> str:
        return str(uuid.uuid5(uuid.NAMESPACE_OID, value))
DEFAULT_QUERY_COUNT = 5
class-attribute instance-attribute
embedding_driver = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
Entry
Bases:
SerializableMixin
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@define
class Entry(SerializableMixin):
    id: str = field(metadata={"serializable": True})
    vector: Optional[list[float]] = field(default=None, metadata={"serializable": True})
    score: Optional[float] = field(default=None, metadata={"serializable": True})
    meta: Optional[dict] = field(default=None, metadata={"serializable": True})
    namespace: Optional[str] = field(default=None, metadata={"serializable": True})

    def to_artifact(self) -> BaseArtifact:
        return BaseArtifact.from_json(self.meta["artifact"])  # pyright: ignore[reportOptionalSubscript]
id = field(metadata={'serializable': True})
class-attribute instance-attribute
meta = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
namespace = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
score = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
vector = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
_get_default_vector_id(value)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def _get_default_vector_id(self, value: str) -> str:
    return str(uuid.uuid5(uuid.NAMESPACE_OID, value))
delete_vector(vector_id) abstractmethod
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def delete_vector(self, vector_id: str) -> None: ...
does_entry_exist(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
    try:
        return self.load_entry(vector_id, namespace=namespace) is not None
    except Exception:
        return False
load_artifacts(*, namespace=None)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
    result = self.load_entries(namespace=namespace)
    artifacts = [r.to_artifact() for r in result]

    return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])
load_entries(*, namespace=None) abstractmethod
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...
load_entry(vector_id, *, namespace=None) abstractmethod
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[Entry]: ...
query(query, *, count=None, namespace=None, include_vectors=False, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[Entry]:
    try:
        vector = self.embedding_driver.embed(query, vector_operation="query")
    except ValueError as e:
        raise ValueError(
            "The Embedding Driver, %s, used by the Vector Store does not support embedding the %s type."
            "To resolve, provide an Embedding Driver that supports this type.",
            self.embedding_driver.__class__.__name__,
            type(query),
        ) from e
    return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[Entry]:
    # TODO: Mark as abstract method for griptape 2.0
    raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")
upsert(value, *, namespace=None, meta=None, vector_id=None, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert(
    self,
    value: str | TextArtifact | ImageArtifact,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    artifact = TextArtifact(value) if isinstance(value, str) else value
    meta = {} if meta is None else meta

    if vector_id is None:
        value = artifact.to_text() if artifact.reference is None else artifact.to_text() + str(artifact.reference)
        vector_id = self._get_default_vector_id(value)

    if self.does_entry_exist(vector_id, namespace=namespace):
        return vector_id
    meta = {**meta, "artifact": artifact.to_json()}

    vector = self.embedding_driver.embed(artifact, vector_operation="upsert")

    return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)
upsert_collection(artifacts, *, meta=None, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_collection(
    self,
    artifacts: list[TextArtifact]
    | list[ImageArtifact]
    | dict[str, list[TextArtifact]]
    | dict[str, list[ImageArtifact]],
    *,
    meta: Optional[dict] = None,
    **kwargs,
):
    with self.create_futures_executor() as futures_executor:
        if isinstance(artifacts, list):
            return utils.execute_futures_list(
                [
                    futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs)
                    for a in artifacts
                ],
            )
        futures_dict = {}

        for namespace, artifact_list in artifacts.items():
            for a in artifact_list:
                if not futures_dict.get(namespace):
                    futures_dict[namespace] = []
                futures_dict[namespace].append(
                    futures_executor.submit(
                        with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs
                    )
                )

        return utils.execute_futures_list_dict(futures_dict)
upsert_text(string, *, namespace=None, meta=None, vector_id=None, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    warnings.warn(
        "`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)
upsert_text_artifact(artifact, *, namespace=None, meta=None, vector_id=None, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifact(
    self,
    artifact: TextArtifact,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    warnings.warn(
        "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)
upsert_text_artifacts(artifacts, *, meta=None, **kwargs)
Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifacts(
    self,
    artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
    *,
    meta: Optional[dict] = None,
    **kwargs,
) -> list[str] | dict[str, list[str]]:
    warnings.warn(
        "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert_collection` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.upsert_collection(artifacts, meta=meta, **kwargs)
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs) abstractmethod
Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str: ...
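Putting the high-level API together, a short usage sketch with the in-memory `LocalVectorStoreDriver` (both drivers are exported from `griptape.drivers`; an `OPENAI_API_KEY` environment variable is assumed for the embedding driver):

from griptape.artifacts import TextArtifact
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver

vector_store = LocalVectorStoreDriver(embedding_driver=OpenAiEmbeddingDriver())

# upsert() accepts artifacts or plain strings; strings are wrapped in a TextArtifact.
vector_store.upsert(TextArtifact("Griptape ships many vector store drivers."), namespace="docs")
vector_store.upsert("Deterministic vector ids come from uuid5 over the content.", namespace="docs")

for entry in vector_store.query("vector store drivers", namespace="docs", count=2):
    print(entry.score, entry.to_artifact().to_text())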
BaseWebScraperDriver
Bases:
ABC
Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
class BaseWebScraperDriver(ABC):
    def scrape_url(self, url: str) -> TextArtifact:
        source = self.fetch_url(url)

        return self.extract_page(source)

    @abstractmethod
    def fetch_url(self, url: str) -> str: ...

    @abstractmethod
    def extract_page(self, page: str) -> TextArtifact: ...
extract_page(page) abstractmethod
Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
@abstractmethod
def extract_page(self, page: str) -> TextArtifact: ...
fetch_url(url) abstractmethod
Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
@abstractmethod
def fetch_url(self, url: str) -> str: ...
scrape_url(url)
Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
def scrape_url(self, url: str) -> TextArtifact:
    source = self.fetch_url(url)

    return self.extract_page(source)
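Because `scrape_url` is just `fetch_url` piped into `extract_page`, a subclass only supplies those two steps. A deliberately naive hypothetical example (real drivers such as `TrafilaturaWebScraperDriver` do proper content extraction):

import re
from urllib.request import urlopen

from griptape.artifacts import TextArtifact
from griptape.drivers import BaseWebScraperDriver


class PlainTextWebScraperDriver(BaseWebScraperDriver):
    """Hypothetical scraper: fetches raw HTML and strips tags very naively."""

    def fetch_url(self, url: str) -> str:
        with urlopen(url) as response:
            return response.read().decode("utf-8", errors="replace")

    def extract_page(self, page: str) -> TextArtifact:
        # Crude tag removal; fine for a sketch, not for production scraping.
        return TextArtifact(re.sub(r"<[^>]+>", " ", page))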
BaseWebSearchDriver
Bases:
ABC
Source Code in griptape/drivers/web_search/base_web_search_driver.py
@define
class BaseWebSearchDriver(ABC):
    results_count: int = field(default=5, kw_only=True)

    @abstractmethod
    def search(self, query: str, **kwargs) -> ListArtifact: ...
results_count = field(default=5, kw_only=True)
class-attribute instance-attribute
search(query, **kwargs) abstractmethod
Source Code in griptape/drivers/web_search/base_web_search_driver.py
@abstractmethod
def search(self, query: str, **kwargs) -> ListArtifact: ...
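A hypothetical fixed-corpus implementation, useful for tests; note how `results_count` caps the hits (the class name and `corpus` field are illustrative, not library API):

from attrs import define, field
from griptape.artifacts import ListArtifact, TextArtifact
from griptape.drivers import BaseWebSearchDriver


@define
class StaticWebSearchDriver(BaseWebSearchDriver):
    """Hypothetical driver that 'searches' a fixed url-to-text corpus."""

    corpus: dict[str, str] = field(factory=dict, kw_only=True)

    def search(self, query: str, **kwargs) -> ListArtifact:
        hits = [
            TextArtifact(f"{url}: {text}")
            for url, text in self.corpus.items()
            if query.lower() in text.lower()
        ]
        return ListArtifact(hits[: self.results_count])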
BedrockStableDiffusionImageGenerationModelDriver
Bases:
BaseImageGenerationModelDriver
Attributes
Name | Type | Description |
---|---|---|
cfg_scale | int | Specifies how strictly image generation follows the provided prompt. Defaults to 7. |
mask_source | str | Specifies mask image configuration for image-to-image generations. Defaults to "MASK_IMAGE_BLACK". |
style_preset | Optional[str] | If provided, specifies a specific image generation style preset. |
clip_guidance_preset | Optional[str] | If provided, requests a specific clip guidance preset to be used in the diffusion process. |
sampler | Optional[str] | If provided, requests a specific sampler to be used in the diffusion process. |
steps | Optional[int] | If provided, specifies the number of diffusion steps to use in the image generation. |
start_schedule | Optional[float] | If provided, specifies the start_schedule parameter used to determine the influence of the input image in image-to-image generation. |
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
@define
class BedrockStableDiffusionImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image generation model driver for Stable Diffusion models on Amazon Bedrock.

    For more information on all supported parameters, see the Stable Diffusion documentation:
    https://platform.stability.ai/docs/api-reference#tag/v1generation

    Attributes:
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7.
        mask_source: Specifies mask image configuration for image-to-image generations. Defaults to "MASK_IMAGE_BLACK".
        style_preset: If provided, specifies a specific image generation style preset.
        clip_guidance_preset: If provided, requests a specific clip guidance preset to be used in the diffusion process.
        sampler: If provided, requests a specific sampler to be used in the diffusion process.
        steps: If provided, specifies the number of diffusion steps to use in the image generation.
        start_schedule: If provided, specifies the start_schedule parameter used to determine the influence of the input image in image-to-image generation.
    """

    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    style_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    clip_guidance_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    sampler: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    start_schedule: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts,
            width=image_width,
            height=image_height,
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(prompts, image=image, negative_prompts=negative_prompts, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_BLACK",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_WHITE",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def _request_parameters(
        self,
        prompts: list[str],
        width: Optional[int] = None,
        height: Optional[int] = None,
        image: Optional[ImageArtifact] = None,
        mask: Optional[ImageArtifact] = None,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
        mask_source: Optional[str] = None,
    ) -> dict:
        if negative_prompts is None:
            negative_prompts = []

        text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
        text_prompts += [{"text": negative_prompt, "weight": -1.0} for negative_prompt in negative_prompts]

        request = {
            "text_prompts": text_prompts,
            "cfg_scale": self.cfg_scale,
            "style_preset": self.style_preset,
            "clip_guidance_preset": self.clip_guidance_preset,
            "sampler": self.sampler,
            "steps": self.steps,
            "seed": seed,
            "start_schedule": self.start_schedule,
        }

        if image is not None:
            request["init_image"] = image.base64
            request["width"] = image.width
            request["height"] = image.height
        else:
            request["width"] = width
            request["height"] = height

        if mask is not None:
            if not mask_source:
                raise ValueError("mask_source must be provided when mask is provided")

            request["mask_source"] = mask_source
            request["mask_image"] = mask.base64

        return {k: v for k, v in request.items() if v is not None}

    def get_generated_image(self, response: dict) -> bytes:
        image_response = response["artifacts"][0]

        # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
        if image_response.get("finishReason") == "ERROR":
            raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
        if image_response.get("finishReason") == "CONTENT_FILTERED":
            logging.warning("Image generation triggered content filter and may be blurred")

        return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))
cfg_scale = field(default=7, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
clip_guidance_preset = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
sampler = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
start_schedule = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
steps = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
style_preset = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
_request_parameters(prompts, width=None, height=None, image=None, mask=None, negative_prompts=None, seed=None, mask_source=None)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def _request_parameters(
    self,
    prompts: list[str],
    width: Optional[int] = None,
    height: Optional[int] = None,
    image: Optional[ImageArtifact] = None,
    mask: Optional[ImageArtifact] = None,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
    mask_source: Optional[str] = None,
) -> dict:
    if negative_prompts is None:
        negative_prompts = []

    text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
    text_prompts += [{"text": negative_prompt, "weight": -1.0} for negative_prompt in negative_prompts]

    request = {
        "text_prompts": text_prompts,
        "cfg_scale": self.cfg_scale,
        "style_preset": self.style_preset,
        "clip_guidance_preset": self.clip_guidance_preset,
        "sampler": self.sampler,
        "steps": self.steps,
        "seed": seed,
        "start_schedule": self.start_schedule,
    }

    if image is not None:
        request["init_image"] = image.base64
        request["width"] = image.width
        request["height"] = image.height
    else:
        request["width"] = width
        request["height"] = height

    if mask is not None:
        if not mask_source:
            raise ValueError("mask_source must be provided when mask is provided")

        request["mask_source"] = mask_source
        request["mask_image"] = mask.base64

    return {k: v for k, v in request.items() if v is not None}
get_generated_image(response)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    image_response = response["artifacts"][0]

    # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
    if image_response.get("finishReason") == "ERROR":
        raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
    if image_response.get("finishReason") == "CONTENT_FILTERED":
        logging.warning("Image generation triggered content filter and may be blurred")

    return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))
image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts,
        image=image,
        mask=mask,
        mask_source="MASK_IMAGE_BLACK",
        negative_prompts=negative_prompts,
        seed=seed,
    )
image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts,
        image=image,
        mask=mask,
        mask_source="MASK_IMAGE_WHITE",
        negative_prompts=negative_prompts,
        seed=seed,
    )
image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(prompts, image=image, negative_prompts=negative_prompts, seed=seed)
text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    return self._request_parameters(
        prompts,
        width=image_width,
        height=image_height,
        negative_prompts=negative_prompts,
        seed=seed,
    )
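Since these methods only build request dictionaries, they can be inspected without touching Bedrock. A small sketch based on the source above:

from griptape.drivers import BedrockStableDiffusionImageGenerationModelDriver

driver = BedrockStableDiffusionImageGenerationModelDriver(cfg_scale=9, steps=30)

params = driver.text_to_image_request_parameters(
    ["a lighthouse at dawn"],
    image_width=512,
    image_height=512,
    negative_prompts=["blurry"],
    seed=1234,
)
# Unset optional fields (style_preset, sampler, ...) are dropped entirely,
# so the request only carries the keys Bedrock should see.
print(params["text_prompts"])  # positive prompt weighted 1.0, negative weighted -1.0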
BedrockTitanImageGenerationModelDriver
Bases:
BaseImageGenerationModelDriver
Attributes
Name | Type | Description |
---|---|---|
quality | str | The quality of the generated image, defaults to standard. |
cfg_scale | int | Specifies how strictly image generation follows the provided prompt. Defaults to 7; valid values range from 1.0 (exclusive) to 10.0 (inclusive). |
outpainting_mode | str | Specifies the outpainting mode, defaults to PRECISE. |
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
@define
class BedrockTitanImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image Generation Model Driver for Amazon Bedrock Titan Image Generator.

    Attributes:
        quality: The quality of the generated image, defaults to standard.
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7, (1.0 to 10.0].
        outpainting_mode: Specifies the outpainting mode, defaults to PRECISE.
    """

    quality: str = field(default="standard", kw_only=True, metadata={"serializable": True})
    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    outpainting_mode: str = field(default="PRECISE", kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "TEXT_IMAGE",
            "textToImageParams": {"text": prompt},
            "imageGenerationConfig": {
                "numberOfImages": 1,
                "quality": self.quality,
                "width": image_width,
                "height": image_height,
                "cfgScale": self.cfg_scale,
            },
        }

        if negative_prompts:
            request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

        if seed:
            request["imageGenerationConfig"]["seed"] = seed

        return self._add_common_params(request, image_width, image_height, seed=seed)

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "IMAGE_VARIATION",
            "imageVariationParams": {"text": prompt, "images": [image.base64]},
            "imageGenerationConfig": {
                "numberOfImages": 1,
                "quality": self.quality,
                "width": image.width,
                "height": image.height,
                "cfgScale": self.cfg_scale,
            },
        }

        if negative_prompts:
            request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

        if seed:
            request["imageGenerationConfig"]["seed"] = seed

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "INPAINTING",
            "inPaintingParams": {"text": prompt, "image": image.base64, "maskImage": mask.base64},
        }

        if negative_prompts:
            request["inPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        prompt = ", ".join(prompts)

        request = {
            "taskType": "OUTPAINTING",
            "outPaintingParams": {
                "text": prompt,
                "image": image.base64,
                "maskImage": mask.base64,
                "outPaintingMode": self.outpainting_mode,
            },
        }

        if negative_prompts:
            request["outPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def get_generated_image(self, response: dict) -> bytes:
        b64_image_data = response["images"][0]

        return base64.decodebytes(bytes(b64_image_data, "utf-8"))

    def _add_common_params(self, request: dict[str, Any], width: int, height: int, seed: Optional[int] = None) -> dict:
        request["imageGenerationConfig"] = {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": width,
            "height": height,
            "cfgScale": self.cfg_scale,
        }

        if seed:
            request["imageGenerationConfig"]["seed"] = seed

        return request
cfg_scale = field(default=7, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
outpainting_mode = field(default='PRECISE', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
quality = field(default='standard', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
_add_common_params(request, width, height, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def _add_common_params(self, request: dict[str, Any], width: int, height: int, seed: Optional[int] = None) -> dict:
    request["imageGenerationConfig"] = {
        "numberOfImages": 1,
        "quality": self.quality,
        "width": width,
        "height": height,
        "cfgScale": self.cfg_scale,
    }

    if seed:
        request["imageGenerationConfig"]["seed"] = seed

    return request
get_generated_image(response)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    b64_image_data = response["images"][0]

    return base64.decodebytes(bytes(b64_image_data, "utf-8"))
image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "INPAINTING",
        "inPaintingParams": {"text": prompt, "image": image.base64, "maskImage": mask.base64},
    }

    if negative_prompts:
        request["inPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image.width, image.height, seed=seed)
image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "OUTPAINTING",
        "outPaintingParams": {
            "text": prompt,
            "image": image.base64,
            "maskImage": mask.base64,
            "outPaintingMode": self.outpainting_mode,
        },
    }

    if negative_prompts:
        request["outPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image.width, image.height, seed=seed)
image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "IMAGE_VARIATION",
        "imageVariationParams": {"text": prompt, "images": [image.base64]},
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": image.width,
            "height": image.height,
            "cfgScale": self.cfg_scale,
        },
    }

    if negative_prompts:
        request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

    if seed:
        request["imageGenerationConfig"]["seed"] = seed

    return self._add_common_params(request, image.width, image.height, seed=seed)
text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    prompt = ", ".join(prompts)

    request = {
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {"text": prompt},
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": image_width,
            "height": image_height,
            "cfgScale": self.cfg_scale,
        },
    }

    if negative_prompts:
        request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

    if seed:
        request["imageGenerationConfig"]["seed"] = seed

    return self._add_common_params(request, image_width, image_height, seed=seed)
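As with the Stable Diffusion driver, the request builders are pure functions of the driver's fields and can be checked offline. A sketch based on the source above:

from griptape.drivers import BedrockTitanImageGenerationModelDriver

driver = BedrockTitanImageGenerationModelDriver(quality="premium", cfg_scale=8)

request = driver.text_to_image_request_parameters(
    ["a red bicycle", "studio lighting"], image_width=1024, image_height=1024, seed=42
)
assert request["taskType"] == "TEXT_IMAGE"
# Prompts are joined into a single comma-separated string for Titan.
assert request["textToImageParams"]["text"] == "a red bicycle, studio lighting"
assert request["imageGenerationConfig"]["seed"] == 42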
CohereEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | str | Cohere API key. |
model | str | Cohere model name. |
client | Client | Custom cohere.Client . |
tokenizer | CohereTokenizer | Custom CohereTokenizer . |
input_type | str | Cohere embedding input type. |
Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Cohere Embedding Driver.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
        input_type: Cohere embedding input type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    input_type: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("cohere").Client(self.api_key)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

        if isinstance(result.embeddings, list):
            return result.embeddings[0]
        raise ValueError("Non-float embeddings are not supported.")
DEFAULT_MODEL = 'models/embedding-001'
class-attribute instance-attribute
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_key = field(kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute
input_type = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    return import_optional_dependency("cohere").Client(self.api_key)
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

    if isinstance(result.embeddings, list):
        return result.embeddings[0]
    raise ValueError("Non-float embeddings are not supported.")
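A short usage sketch; the model name and `search_document` input type are illustrative values, so check Cohere's model list for what your account supports:

import os

from griptape.drivers import CohereEmbeddingDriver

embedding_driver = CohereEmbeddingDriver(
    model="embed-english-v3.0",
    api_key=os.environ["COHERE_API_KEY"],
    input_type="search_document",
)

vector = embedding_driver.try_embed_chunk("Embedding drivers return list[float].")
print(len(vector))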
CoherePromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | str | Cohere API key. |
model | str | Cohere model name. |
client | ClientV2 | Custom cohere.Client . |
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@define(kw_only=True)
class CoherePromptDriver(BasePromptDriver):
    """Cohere Prompt Driver.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
    """

    api_key: str = field(default=None, metadata={"serializable": False})
    model: str = field(metadata={"serializable": True})
    force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[ClientV2] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
    )

    @lazy_property()
    def client(self) -> ClientV2:
        return import_optional_dependency("cohere").ClientV2(self.api_key, log_warning_experimental_features=False)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        result: ChatResponse = self.client.chat(**params)
        logger.debug(result.model_dump())

        return Message(
            content=self.__to_prompt_stack_message_content(result.message),
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=result.usage.tokens.input_tokens if result.usage and result.usage.tokens else 0,
                output_tokens=result.usage.tokens.output_tokens if result.usage and result.usage.tokens else 0,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params)

        for event in result:
            if event.type == "stream-end":
                usage = event.response.meta.tokens

                yield DeltaMessage(
                    usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
                )
            elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"):
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        tool_results = []

        messages = self.__to_cohere_messages(prompt_stack.messages)

        params = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_tokens,
            **({"tool_results": tool_results} if tool_results else {}),
            **self.extra_params,
        }

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
            params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()}

        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

        return params

    def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
        cohere_messages = []

        for message in messages:
            # If the message only contains textual content we can send it as a single content.
            if message.is_text():
                cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            # Action results must be sent as separate messages.
            elif message.has_any_content_type(ActionResultMessageContent):
                cohere_messages.extend(
                    {
                        "role": self.__to_cohere_role(message, action_result),
                        "content": self.__to_cohere_message_content(action_result),
                        "tool_call_id": action_result.action.tag,
                    }
                    for action_result in message.get_content_type(ActionResultMessageContent)
                )

                if message.has_any_content_type(TextMessageContent):
                    cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            else:
                cohere_message = {
                    "role": self.__to_cohere_role(message),
                    "content": [
                        self.__to_cohere_message_content(content)
                        for content in [
                            content for content in message.content if not isinstance(content, ActionCallMessageContent)
                        ]
                    ],
                }

                # Action calls must be attached to the message, not sent as content.
                action_call_content = [
                    content for content in message.content if isinstance(content, ActionCallMessageContent)
                ]
                if action_call_content:
                    cohere_message["tool_calls"] = [
                        self.__to_cohere_message_content(action_call) for action_call in action_call_content
                    ]

                cohere_messages.append(cohere_message)

        return cohere_messages

    def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]:
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {
                    "name": action.to_native_tool_name(),
                    "arguments": json.dumps(action.input),
                },
            }
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            if isinstance(artifact, ListArtifact):
                message_content = [{"type": "text", "text": artifact.to_text()} for artifact in artifact.value]
            else:
                message_content = {"type": "text", "text": artifact.to_text()}

            return message_content
        return {"type": "text", "text": content.artifact.to_text()}

    def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        if message.is_system():
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
        return [
            {
                "function": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
                },
                "type": "function",
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]:
        content = []

        if response.content:
            content.extend([TextMessageContent(TextArtifact(content.text)) for content in response.content])
        if response.tool_plan:
            content.append(TextMessageContent(TextArtifact(response.tool_plan)))
        if response.tool_calls is not None:
            content.extend(
                [
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=tool_call.id,
                                name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
                                path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
                                input=json.loads(tool_call.function.arguments),
                            ),
                        ),
                    )
                    for tool_call in response.tool_calls
                    if tool_call.id is not None
                    and tool_call.function is not None
                    and tool_call.function.name is not None
                    and tool_call.function.arguments is not None
                ],
            )

        return content

    def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
        if event.type == "content-delta":
            return TextDeltaMessageContent(event.delta.message.content.text, index=0)
        if event.type == "tool-plan-delta":
            return TextDeltaMessageContent(event.delta.message["tool_plan"])
        if event.type == "tool-call-start":
            tool_call_delta = event.delta.message["tool_calls"]
            name, path = ToolAction.from_native_tool_name(tool_call_delta["function"]["name"])

            return ActionCallDeltaMessageContent(tag=tool_call_delta["id"], name=name, path=path)
        if event.type == "tool-call-delta":
            tool_call_delta = event.delta.message["tool_calls"]["function"]

            return ActionCallDeltaMessageContent(partial_input=tool_call_delta["arguments"])
        raise ValueError(f"Unsupported event type: {event.type}")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, metadata={'serializable': False})
force_single_step = field(default=False, kw_only=True, metadata={'serializable': True})
model = field(metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True))
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
__to_cohere_message_content(content)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]:
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return {
            "type": "function",
            "id": action.tag,
            "function": {
                "name": action.to_native_tool_name(),
                "arguments": json.dumps(action.input),
            },
        }
    if isinstance(content, ActionResultMessageContent):
        artifact = content.artifact

        if isinstance(artifact, ListArtifact):
            message_content = [{"type": "text", "text": artifact.to_text()} for artifact in artifact.value]
        else:
            message_content = {"type": "text", "text": artifact.to_text()}

        return message_content
    return {"type": "text", "text": content.artifact.to_text()}
__to_cohere_messages(messages)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
    cohere_messages = []

    for message in messages:
        # If the message only contains textual content we can send it as a single content.
        if message.is_text():
            cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
        # Action results must be sent as separate messages.
        elif message.has_any_content_type(ActionResultMessageContent):
            cohere_messages.extend(
                {
                    "role": self.__to_cohere_role(message, action_result),
                    "content": self.__to_cohere_message_content(action_result),
                    "tool_call_id": action_result.action.tag,
                }
                for action_result in message.get_content_type(ActionResultMessageContent)
            )

            if message.has_any_content_type(TextMessageContent):
                cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
        else:
            cohere_message = {
                "role": self.__to_cohere_role(message),
                "content": [
                    self.__to_cohere_message_content(content)
                    for content in [
                        content for content in message.content if not isinstance(content, ActionCallMessageContent)
                    ]
                ],
            }

            # Action calls must be attached to the message, not sent as content.
            action_call_content = [
                content for content in message.content if isinstance(content, ActionCallMessageContent)
            ]
            if action_call_content:
                cohere_message["tool_calls"] = [
                    self.__to_cohere_message_content(action_call) for action_call in action_call_content
                ]

            cohere_messages.append(cohere_message)

    return cohere_messages
__to_cohere_role(message, message_content=None)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    if message.is_system():
        return "system"
    if message.is_assistant():
        return "assistant"
    if isinstance(message_content, ActionResultMessageContent):
        return "tool"
    return "user"
__to_cohere_tools(tools)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
    return [
        {
            "function": {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
            },
            "type": "function",
        }
        for tool in tools
        for activity in tool.activities()
    ]
__to_prompt_stack_delta_message_content(event)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
    if event.type == "content-delta":
        return TextDeltaMessageContent(event.delta.message.content.text, index=0)
    if event.type == "tool-plan-delta":
        return TextDeltaMessageContent(event.delta.message["tool_plan"])
    if event.type == "tool-call-start":
        tool_call_delta = event.delta.message["tool_calls"]
        name, path = ToolAction.from_native_tool_name(tool_call_delta["function"]["name"])

        return ActionCallDeltaMessageContent(tag=tool_call_delta["id"], name=name, path=path)
    if event.type == "tool-call-delta":
        tool_call_delta = event.delta.message["tool_calls"]["function"]

        return ActionCallDeltaMessageContent(partial_input=tool_call_delta["arguments"])
    raise ValueError(f"Unsupported event type: {event.type}")
__to_prompt_stack_message_content(response)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]:
    content = []

    if response.content:
        content.extend([TextMessageContent(TextArtifact(content.text)) for content in response.content])
    if response.tool_plan:
        content.append(TextMessageContent(TextArtifact(response.tool_plan)))
    if response.tool_calls is not None:
        content.extend(
            [
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=tool_call.id,
                            name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
                            path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
                            input=json.loads(tool_call.function.arguments),
                        ),
                    ),
                )
                for tool_call in response.tool_calls
                if tool_call.id is not None
                and tool_call.function is not None
                and tool_call.function.name is not None
                and tool_call.function.arguments is not None
            ],
        )

    return content
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    tool_results = []

    messages = self.__to_cohere_messages(prompt_stack.messages)

    params = {
        "model": self.model,
        "messages": messages,
        "temperature": self.temperature,
        "stop_sequences": self.tokenizer.stop_sequences,
        "max_tokens": self.max_tokens,
        **({"tool_results": tool_results} if tool_results else {}),
        **self.extra_params,
    }

    if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
        params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()}

    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

    return params
client()
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@lazy_property()
def client(self) -> ClientV2:
    return import_optional_dependency("cohere").ClientV2(self.api_key, log_warning_experimental_features=False)
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    result: ChatResponse = self.client.chat(**params)
    logger.debug(result.model_dump())

    return Message(
        content=self.__to_prompt_stack_message_content(result.message),
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(
            input_tokens=result.usage.tokens.input_tokens if result.usage and result.usage.tokens else 0,
            output_tokens=result.usage.tokens.output_tokens if result.usage and result.usage.tokens else 0,
        ),
    )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params)

    for event in result:
        if event.type == "stream-end":
            usage = event.response.meta.tokens

            yield DeltaMessage(
                usage=DeltaMessage.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
            )
        elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"):
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
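Putting the pieces together, a minimal sketch of running the driver directly; the model name and API key are placeholders:

from griptape.common import PromptStack
from griptape.drivers import CoherePromptDriver

driver = CoherePromptDriver(api_key="YOUR_COHERE_API_KEY", model="command-r-plus")

prompt_stack = PromptStack()
prompt_stack.add_user_message("Summarize the benefits of reranking in two sentences.")

message = driver.try_run(prompt_stack)
print(message.to_text())
print(message.usage.input_tokens, message.usage.output_tokens)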
CohereRerankDriver
Bases:
BaseRerankDriver
Source Code in griptape/drivers/rerank/cohere_rerank_driver.py
@define(kw_only=True)
class CohereRerankDriver(BaseRerankDriver):
    model: str = field(default="rerank-english-v3.0", metadata={"serializable": True})
    top_n: Optional[int] = field(default=None)
    api_key: str = field(metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        # Cohere errors out if passed "empty" documents or no documents at all
        artifacts_dict = {str(hash(a.to_text())): a for a in artifacts if a}

        if artifacts_dict:
            response = self.client.rerank(
                model=self.model,
                query=query,
                documents=[a.to_text() for a in artifacts_dict.values()],
                return_documents=True,
                top_n=self.top_n,
            )
            return [artifacts_dict[str(hash(r.document.text))] for r in response.results if r.document is not None]
        return []
api_key = field(metadata={'serializable': True})
client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True))
model = field(default='rerank-english-v3.0', metadata={'serializable': True})
top_n = field(default=None)
run(query, artifacts)
Source Code in griptape/drivers/rerank/cohere_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    # Cohere errors out if passed "empty" documents or no documents at all
    artifacts_dict = {str(hash(a.to_text())): a for a in artifacts if a}

    if artifacts_dict:
        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=[a.to_text() for a in artifacts_dict.values()],
            return_documents=True,
            top_n=self.top_n,
        )
        return [artifacts_dict[str(hash(r.document.text))] for r in response.results if r.document is not None]
    return []
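A short sketch of run(); the driver drops empty artifacts and deduplicates by text before calling Cohere, so only the reranked originals come back. The API key is a placeholder, and the model falls back to the rerank-english-v3.0 default:

from griptape.artifacts import TextArtifact
from griptape.drivers import CohereRerankDriver

driver = CohereRerankDriver(api_key="YOUR_COHERE_API_KEY", top_n=2)

documents = [
    TextArtifact("Reranking reorders retrieved documents by query relevance."),
    TextArtifact("Bananas are rich in potassium."),
    TextArtifact("Cohere exposes dedicated rerank models."),
]

for artifact in driver.run("What does reranking do?", documents):
    print(artifact.value)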
DatadogObservabilityDriver
Bases:
OpenTelemetryObservabilityDriver
Source Code in griptape/drivers/observability/datadog_observability_driver.py
@define
class DatadogObservabilityDriver(OpenTelemetryObservabilityDriver):
    datadog_agent_endpoint: str = field(
        default=Factory(lambda: os.getenv("DD_AGENT_ENDPOINT", "http://localhost:4318")), kw_only=True
    )
    span_processor: SpanProcessor = field(
        default=Factory(
            lambda self: import_optional_dependency("opentelemetry.sdk.trace.export").BatchSpanProcessor(
                import_optional_dependency("opentelemetry.exporter.otlp.proto.http.trace_exporter").OTLPSpanExporter(
                    endpoint=f"{self.datadog_agent_endpoint}/v1/traces"
                )
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
datadog_agent_endpoint = field(default=Factory(lambda: os.getenv('DD_AGENT_ENDPOINT', 'http://localhost:4318')), kw_only=True)
span_processor = field(default=Factory(lambda self: import_optional_dependency('opentelemetry.sdk.trace.export').BatchSpanProcessor(import_optional_dependency('opentelemetry.exporter.otlp.proto.http.trace_exporter').OTLPSpanExporter(endpoint=f'{self.datadog_agent_endpoint}/v1/traces')), takes_self=True), kw_only=True)
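Because the endpoint is read from DD_AGENT_ENDPOINT by default, pointing the driver at a specific agent only requires overriding one field. A hedged sketch, assuming the service_name attribute inherited from OpenTelemetryObservabilityDriver and a Datadog agent listening on the default OTLP HTTP port:

from griptape.drivers import DatadogObservabilityDriver
from griptape.observability import Observability
from griptape.structures import Agent

# "my-agent" is an illustrative service name; the endpoint overrides DD_AGENT_ENDPOINT.
driver = DatadogObservabilityDriver(
    service_name="my-agent",
    datadog_agent_endpoint="http://localhost:4318",
)

with Observability(observability_driver=driver):
    Agent().run("Hello!")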
DuckDuckGoWebSearchDriver
Bases:
BaseWebSearchDriver
Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
@define
class DuckDuckGoWebSearchDriver(BaseWebSearchDriver):
    language: str = field(default="en", kw_only=True)
    country: str = field(default="us", kw_only=True)
    _client: Optional[DDGS] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> DDGS:
        return import_optional_dependency("duckduckgo_search").DDGS()

    def search(self, query: str, **kwargs) -> ListArtifact:
        try:
            results = self.client.text(
                query, region=f"{self.language}-{self.country}", max_results=self.results_count, **kwargs
            )

            return ListArtifact(
                [
                    TextArtifact(
                        json.dumps({"title": result["title"], "url": result["href"], "description": result["body"]}),
                    )
                    for result in results
                ],
            )
        except Exception as e:
            raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}") from e
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
country = field(default='us', kw_only=True)
language = field(default='en', kw_only=True)
client()
Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
@lazy_property()
def client(self) -> DDGS:
    return import_optional_dependency("duckduckgo_search").DDGS()
search(query, **kwargs)
Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    try:
        results = self.client.text(
            query, region=f"{self.language}-{self.country}", max_results=self.results_count, **kwargs
        )

        return ListArtifact(
            [
                TextArtifact(
                    json.dumps({"title": result["title"], "url": result["href"], "description": result["body"]}),
                )
                for result in results
            ],
        )
    except Exception as e:
        raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}") from e
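A minimal sketch; no API key is needed, and each result artifact wraps a JSON string with title, url, and description, per the search() implementation above:

from griptape.drivers import DuckDuckGoWebSearchDriver

driver = DuckDuckGoWebSearchDriver(language="en", country="us")

for artifact in driver.search("open source agent frameworks"):
    print(artifact.value)  # JSON string with title, url, and description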
DummyAudioTranscriptionDriver
Bases:
BaseAudioTranscriptionDriver
Source Code in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py
@define
class DummyAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
    model: str = field(init=False)

    def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
        raise DummyError(__class__.__name__, "try_transcription")
model = field(init=False)
try_run(audio, prompts=None)
Source Code in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py
def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
    raise DummyError(__class__.__name__, "try_transcription")
DummyEmbeddingDriver
Bases:
BaseEmbeddingDriver
Source Code in griptape/drivers/embedding/dummy_embedding_driver.py
@define
class DummyEmbeddingDriver(BaseEmbeddingDriver):
    model: None = field(init=False, default=None, kw_only=True)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        raise DummyError(__class__.__name__, "try_embed_chunk")
model = field(init=False, default=None, kw_only=True)
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/dummy_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    raise DummyError(__class__.__name__, "try_embed_chunk")
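The dummy drivers exist so that unconfigured components fail loudly rather than silently. A sketch of the resulting error, assuming DummyError is importable from griptape.exceptions:

from griptape.drivers import DummyEmbeddingDriver
from griptape.exceptions import DummyError

driver = DummyEmbeddingDriver()

try:
    driver.try_embed_chunk("this will not be embedded")
except DummyError as e:
    print(e)  # explains that DummyEmbeddingDriver cannot perform try_embed_chunk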
DummyImageGenerationDriver
Bases:
BaseImageGenerationDriver
Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
@define
class DummyImageGenerationDriver(BaseImageGenerationDriver):
    model: None = field(init=False, default=None, kw_only=True)

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        raise DummyError(__class__.__name__, "try_text_to_image")

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise DummyError(__class__.__name__, "try_image_variation")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise DummyError(__class__.__name__, "try_image_inpainting")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise DummyError(__class__.__name__, "try_image_outpainting")
model = field(init=False, default=None, kw_only=True)
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise DummyError(__class__.__name__, "try_image_inpainting")
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise DummyError(__class__.__name__, "try_image_outpainting")
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise DummyError(__class__.__name__, "try_image_variation")
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    raise DummyError(__class__.__name__, "try_text_to_image")
DummyPromptDriver
Bases:
BasePromptDriver
Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@define
class DummyPromptDriver(BasePromptDriver):
    model: None = field(init=False, default=None, kw_only=True)
    tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        raise DummyError(__class__.__name__, "try_run")

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        raise DummyError(__class__.__name__, "try_stream")
model = field(init=False, default=None, kw_only=True)
tokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    raise DummyError(__class__.__name__, "try_run")
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    raise DummyError(__class__.__name__, "try_stream")
DummyTextToSpeechDriver
Bases:
BaseTextToSpeechDriver
Source Code in griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py
@define
class DummyTextToSpeechDriver(BaseTextToSpeechDriver):
    model: None = field(init=False, default=None, kw_only=True)

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        raise DummyError(__class__.__name__, "try_text_to_audio")
model = field(init=False, default=None, kw_only=True)
try_text_to_audio(prompts)
Source Code in griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    raise DummyError(__class__.__name__, "try_text_to_audio")
DummyVectorStoreDriver
Bases:
BaseVectorStoreDriver
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
@define()
class DummyVectorStoreDriver(BaseVectorStoreDriver):
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True,
        default=Factory(lambda: DummyEmbeddingDriver()),
        metadata={"serializable": True},
    )

    def delete_vector(self, vector_id: str) -> None:
        raise DummyError(__class__.__name__, "delete_vector")

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise DummyError(__class__.__name__, "upsert_vector")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        raise DummyError(__class__.__name__, "load_entry")

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        raise DummyError(__class__.__name__, "load_entries")

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        raise DummyError(__class__.__name__, "query_vector")

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        raise DummyError(__class__.__name__, "query")
embedding_driver = field(kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True})
delete_vector(vector_id)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    raise DummyError(__class__.__name__, "delete_vector")
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    raise DummyError(__class__.__name__, "load_entries")
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    raise DummyError(__class__.__name__, "load_entry")
query(query, *, count=None, namespace=None, include_vectors=False, **kwargs)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    raise DummyError(__class__.__name__, "query")
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    raise DummyError(__class__.__name__, "query_vector")
upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    raise DummyError(__class__.__name__, "upsert_vector")
ElevenLabsTextToSpeechDriver
Bases:
BaseTextToSpeechDriver
Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
@define
class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver):
    api_key: str = field(kw_only=True, metadata={"serializable": True})
    voice: str = field(kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True})
    _client: Optional[ElevenLabs] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> ElevenLabs:
        return import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key)

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        audio = self.client.generate(
            text=". ".join(prompts),
            voice=self.voice,
            model=self.model,
            output_format=self.output_format,
        )

        content = b""
        for chunk in audio:
            content += chunk

        # All ElevenLabs audio format strings have the following structure:
        # {format}_{sample_rate}_{bitrate}
        artifact_format = self.output_format.split("_")[0]

        return AudioArtifact(value=content, format=artifact_format)
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(kw_only=True, metadata={'serializable': True})
output_format = field(default='mp3_44100_128', kw_only=True, metadata={'serializable': True})
voice = field(kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
@lazy_property()
def client(self) -> ElevenLabs:
    return import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key)
try_text_to_audio(prompts)
Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    audio = self.client.generate(
        text=". ".join(prompts),
        voice=self.voice,
        model=self.model,
        output_format=self.output_format,
    )

    content = b""
    for chunk in audio:
        content += chunk

    # All ElevenLabs audio format strings have the following structure:
    # {format}_{sample_rate}_{bitrate}
    artifact_format = self.output_format.split("_")[0]

    return AudioArtifact(value=content, format=artifact_format)
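A minimal sketch; the voice and model values are placeholders that must match your ElevenLabs account (model is inherited from BaseTextToSpeechDriver), and the artifact format is derived from output_format as shown above:

from griptape.drivers import ElevenLabsTextToSpeechDriver

driver = ElevenLabsTextToSpeechDriver(
    api_key="YOUR_ELEVENLABS_API_KEY",
    model="eleven_multilingual_v2",  # placeholder model name
    voice="Rachel",  # placeholder voice
)

audio = driver.try_text_to_audio(["Hello from Griptape."])
with open("speech.mp3", "wb") as f:
    f.write(audio.value)  # raw mp3 bytes, per the default output_format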
ExaWebSearchDriver
Bases:
BaseWebSearchDriver
Source Code in griptape/drivers/web_search/exa_web_search_driver.py
@define
class ExaWebSearchDriver(BaseWebSearchDriver):
    api_key: str = field(kw_only=True, default=None)
    highlights: bool = field(default=False, kw_only=True)
    use_autoprompt: bool = field(default=False, kw_only=True)
    params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
    _client: Optional[Exa] = field(default=None, kw_only=True, alias="client")

    @lazy_property()
    def client(self) -> Exa:
        return import_optional_dependency("exa_py").Exa(api_key=self.api_key)

    def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]:
        response = self.client.search_and_contents(  # pyright: ignore[reportCallIssue]
            highlights=self.highlights,  # pyright: ignore[reportArgumentType]
            use_autoprompt=self.use_autoprompt,
            query=query,
            num_results=self.results_count,
            text=True,
            **self.params,
            **kwargs,
        )
        results = [
            {"title": result.title, "url": result.url, "highlights": result.highlights, "text": result.text}
            for result in response.results
        ]
        return ListArtifact([JsonArtifact(result) for result in results])
_client = field(default=None, kw_only=True, alias='client')
api_key = field(kw_only=True, default=None)
highlights = field(default=False, kw_only=True)
params = field(factory=dict, kw_only=True, metadata={'serializable': True})
use_autoprompt = field(default=False, kw_only=True)
client()
Source Code in griptape/drivers/web_search/exa_web_search_driver.py
@lazy_property()
def client(self) -> Exa:
    return import_optional_dependency("exa_py").Exa(api_key=self.api_key)
search(query, **kwargs)
Source Code in griptape/drivers/web_search/exa_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]:
    response = self.client.search_and_contents(  # pyright: ignore[reportCallIssue]
        highlights=self.highlights,  # pyright: ignore[reportArgumentType]
        use_autoprompt=self.use_autoprompt,
        query=query,
        num_results=self.results_count,
        text=True,
        **self.params,
        **kwargs,
    )
    results = [
        {"title": result.title, "url": result.url, "highlights": result.highlights, "text": result.text}
        for result in response.results
    ]
    return ListArtifact([JsonArtifact(result) for result in results])
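A minimal sketch; each returned JsonArtifact wraps a dict with title, url, highlights, and text keys, per search() above (the API key is a placeholder):

from griptape.drivers import ExaWebSearchDriver

driver = ExaWebSearchDriver(api_key="YOUR_EXA_API_KEY", highlights=True)

for artifact in driver.search("retrieval augmented generation"):
    print(artifact.value["title"], artifact.value["url"])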
GoogleEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | Optional[str] | Google API key. |
model | str | Google model name. |
task_type | str | Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to retrieval_document. |
title | Optional[str] | Optional title for the content. Only works with retrieval_document task type. |
Source Code in griptape/drivers/embedding/google_embedding_driver.py
@define
class GoogleEmbeddingDriver(BaseEmbeddingDriver):
    """Google Embedding Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        task_type: Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to `retrieval_document`.
        title: Optional title for the content. Only works with `retrieval_document` task type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
    title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        result = genai.embed_content(model=self.model, content=chunk, task_type=self.task_type, title=self.title)

        return result["embedding"]

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}
DEFAULT_MODEL = 'models/embedding-001'
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
task_type = field(default='retrieval_document', kw_only=True, metadata={'serializable': True})
title = field(default=None, kw_only=True, metadata={'serializable': True})
_params(chunk)
Source Code in griptape/drivers/embedding/google_embedding_driver.py
def _params(self, chunk: str) -> dict:
    return {"input": chunk, "model": self.model}
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/google_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    genai = import_optional_dependency("google.generativeai")
    genai.configure(api_key=self.api_key)

    result = genai.embed_content(model=self.model, content=chunk, task_type=self.task_type, title=self.title)

    return result["embedding"]
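A minimal sketch; note that title is only honored with the retrieval_document task type, and the API key is a placeholder:

from griptape.drivers import GoogleEmbeddingDriver

driver = GoogleEmbeddingDriver(
    api_key="YOUR_GOOGLE_API_KEY",
    task_type="retrieval_document",
    title="Company FAQ",
)

vector = driver.try_embed_chunk("Our support hours are 9am to 5pm.")
print(len(vector))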
GooglePromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | Optional[str] | Google API key. |
model | str | Google model name. |
client | GenerativeModel | Custom GenerativeModel client. |
top_p | Optional[float] | Optional value for top_p. |
top_k | Optional[int] | Optional value for top_k. |
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """Google Prompt Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        client: Custom `GenerativeModel` client.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
    _client: Optional[GenerativeModel] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
        return value

    @lazy_property()
    def client(self) -> GenerativeModel:
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)
        return genai.GenerativeModel(self.model)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        messages = self.__to_google_messages(prompt_stack)
        params = self._base_params(prompt_stack)
        logger.debug((messages, params["generation_config"].__dict__))
        response: GenerateContentResponse = self.client.generate_content(messages, **params)
        logger.debug(response.to_dict())

        usage_metadata = response.usage_metadata

        return Message(
            content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=usage_metadata.prompt_token_count,
                output_tokens=usage_metadata.candidates_token_count,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        messages = self.__to_google_messages(prompt_stack)
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug((messages, params))
        response: GenerateContentResponse = self.client.generate_content(messages, **params)

        prompt_token_count = None
        for chunk in response:
            logger.debug(chunk.to_dict())
            usage_metadata = chunk.usage_metadata

            content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
            # Only want to output the prompt token count once since it is static each chunk
            if prompt_token_count is None:
                prompt_token_count = usage_metadata.prompt_token_count
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(
                        input_tokens=usage_metadata.prompt_token_count,
                        output_tokens=usage_metadata.candidates_token_count,
                    ),
                )
            else:
                yield DeltaMessage(
                    content=content,
                    usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
                )

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        system_messages = prompt_stack.system_messages
        if system_messages:
            self.client._system_instruction = types.ContentDict(
                role="system",
                parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
            )

        params = {
            "generation_config": types.GenerationConfig(
                **{
                    # For some reason, providing stop sequences when streaming breaks native functions
                    # https://github.com/google-gemini/generative-ai-python/issues/446
                    "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                    "max_output_tokens": self.max_tokens,
                    "temperature": self.temperature,
                    "top_p": self.top_p,
                    "top_k": self.top_k,
                    **self.extra_params,
                },
            ),
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_config"]["function_calling_config"]["mode"] = "auto"

            params["tools"] = self.__to_google_tools(prompt_stack.tools)

        return params

    def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
        types = import_optional_dependency("google.generativeai.types")

        return [
            types.ContentDict(
                {
                    "role": self.__to_google_role(message),
                    "parts": [self.__to_google_message_content(content) for content in message.content],
                },
            )
            for message in prompt_stack.messages
            if not message.is_system()
        ]

    def __to_google_role(self, message: Message) -> str:
        if message.is_assistant():
            return "model"
        return "user"

    def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
        types = import_optional_dependency("google.generativeai.types")

        tool_declarations = []
        for tool in tools:
            for activity in tool.activities():
                schema = tool.to_activity_json_schema(activity, "Parameters Schema")

                if "values" in schema["properties"]:
                    schema = schema["properties"]["values"]

                schema = remove_key_in_dict_recursively(schema, "additionalProperties")
                tool_declaration = types.FunctionDeclaration(
                    name=tool.to_native_tool_name(activity),
                    description=tool.activity_description(activity),
                    **(
                        {
                            "parameters": {
                                "type": schema["type"],
                                "properties": schema["properties"],
                                "required": schema.get("required", []),
                            }
                        }
                        if schema.get("properties")
                        else {}
                    ),
                )
                tool_declarations.append(tool_declaration)

        return tool_declarations

    def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
            # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
            # Can be worked around by using GenericMessageContent, similar to videos.
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input))
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            return protos.Part(
                function_response=protos.FunctionResponse(
                    name=content.action.to_native_tool_name(),
                    response=artifact.to_dict(),
                ),
            )
        if isinstance(content, GenericMessageContent):
            return content.artifact.value
        raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

    def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextMessageContent(TextArtifact(content.text))
        if content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallMessageContent(
                artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)),
            )
        raise ValueError(f"Unsupported message content type {content}")

    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextDeltaMessageContent(content.text)
        if content.function_call:
            function_call = content.function_call

            name, path = ToolAction.from_native_tool_name(function_call.name)

            args = json_format.MessageToDict(function_call._pb).get("args", {})
            return ActionCallDeltaMessageContent(
                tag=function_call.name,
                name=name,
                path=path,
                partial_input=json.dumps(args),
            )
        raise ValueError(f"Unsupported message content type {content}")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
model = field(kw_only=True, metadata={'serializable': True})
structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True)
tool_choice = field(default='auto', kw_only=True, metadata={'serializable': True})
top_k = field(default=None, kw_only=True, metadata={'serializable': True})
top_p = field(default=None, kw_only=True, metadata={'serializable': True})
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
__to_google_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()
    if isinstance(content, ImageMessageContent):
        if isinstance(content.artifact, ImageArtifact):
            return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value)
        # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
        # Can be worked around by using GenericMessageContent, similar to videos.
        raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input))
    if isinstance(content, ActionResultMessageContent):
        artifact = content.artifact

        return protos.Part(
            function_response=protos.FunctionResponse(
                name=content.action.to_native_tool_name(),
                response=artifact.to_dict(),
            ),
        )
    if isinstance(content, GenericMessageContent):
        return content.artifact.value
    raise ValueError(f"Unsupported prompt stack content type: {type(content)}")
__to_google_messages(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
    types = import_optional_dependency("google.generativeai.types")

    return [
        types.ContentDict(
            {
                "role": self.__to_google_role(message),
                "parts": [self.__to_google_message_content(content) for content in message.content],
            },
        )
        for message in prompt_stack.messages
        if not message.is_system()
    ]
__to_google_role(message)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, message: Message) -> str:
    if message.is_assistant():
        return "model"
    return "user"
__to_google_tools(tools)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
    types = import_optional_dependency("google.generativeai.types")

    tool_declarations = []
    for tool in tools:
        for activity in tool.activities():
            schema = tool.to_activity_json_schema(activity, "Parameters Schema")

            if "values" in schema["properties"]:
                schema = schema["properties"]["values"]

            schema = remove_key_in_dict_recursively(schema, "additionalProperties")
            tool_declaration = types.FunctionDeclaration(
                name=tool.to_native_tool_name(activity),
                description=tool.activity_description(activity),
                **(
                    {
                        "parameters": {
                            "type": schema["type"],
                            "properties": schema["properties"],
                            "required": schema.get("required", []),
                        }
                    }
                    if schema.get("properties")
                    else {}
                ),
            )
            tool_declarations.append(tool_declaration)

    return tool_declarations
__to_prompt_stack_delta_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextDeltaMessageContent(content.text)
    if content.function_call:
        function_call = content.function_call

        name, path = ToolAction.from_native_tool_name(function_call.name)

        args = json_format.MessageToDict(function_call._pb).get("args", {})
        return ActionCallDeltaMessageContent(
            tag=function_call.name,
            name=name,
            path=path,
            partial_input=json.dumps(args),
        )
    raise ValueError(f"Unsupported message content type {content}")
__to_prompt_stack_message_content(content)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextMessageContent(TextArtifact(content.text))
    if content.function_call:
        function_call = content.function_call

        name, path = ToolAction.from_native_tool_name(function_call.name)

        args = json_format.MessageToDict(function_call._pb).get("args", {})
        return ActionCallMessageContent(
            artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)),
        )
    raise ValueError(f"Unsupported message content type {content}")
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    system_messages = prompt_stack.system_messages
    if system_messages:
        self.client._system_instruction = types.ContentDict(
            role="system",
            parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
        )

    params = {
        "generation_config": types.GenerationConfig(
            **{
                # For some reason, providing stop sequences when streaming breaks native functions
                # https://github.com/google-gemini/generative-ai-python/issues/446
                "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                "max_output_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                **self.extra_params,
            },
        ),
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            params["tool_config"]["function_calling_config"]["mode"] = "auto"

        params["tools"] = self.__to_google_tools(prompt_stack.tools)

    return params
client()
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@lazy_property()
def client(self) -> GenerativeModel:
    genai = import_optional_dependency("google.generativeai")
    genai.configure(api_key=self.api_key)
    return genai.GenerativeModel(self.model)
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    messages = self.__to_google_messages(prompt_stack)
    params = self._base_params(prompt_stack)
    logger.debug((messages, params["generation_config"].__dict__))
    response: GenerateContentResponse = self.client.generate_content(messages, **params)
    logger.debug(response.to_dict())

    usage_metadata = response.usage_metadata

    return Message(
        content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(
            input_tokens=usage_metadata.prompt_token_count,
            output_tokens=usage_metadata.candidates_token_count,
        ),
    )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    messages = self.__to_google_messages(prompt_stack)
    params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug((messages, params))
    response: GenerateContentResponse = self.client.generate_content(messages, **params)

    prompt_token_count = None
    for chunk in response:
        logger.debug(chunk.to_dict())
        usage_metadata = chunk.usage_metadata

        content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
        # Only want to output the prompt token count once since it is static each chunk
        if prompt_token_count is None:
            prompt_token_count = usage_metadata.prompt_token_count
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(
                    input_tokens=usage_metadata.prompt_token_count,
                    output_tokens=usage_metadata.candidates_token_count,
                ),
            )
        else:
            yield DeltaMessage(
                content=content,
                usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
            )
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value == "native":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
    return value
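A minimal sketch of driving a Gemini model directly; the model name and API key are placeholders:

from griptape.common import PromptStack
from griptape.drivers import GooglePromptDriver

driver = GooglePromptDriver(
    api_key="YOUR_GOOGLE_API_KEY",
    model="gemini-1.5-flash",  # placeholder model name
    top_p=0.9,
    top_k=40,
)

prompt_stack = PromptStack()
prompt_stack.add_system_message("Answer in one sentence.")
prompt_stack.add_user_message("What is a prompt driver?")

print(driver.try_run(prompt_stack).to_text())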
GoogleWebSearchDriver
Bases:
BaseWebSearchDriver
Source Code in griptape/drivers/web_search/google_web_search_driver.py
@define
class GoogleWebSearchDriver(BaseWebSearchDriver):
    api_key: str = field(kw_only=True)
    search_id: str = field(kw_only=True)
    language: str = field(default="en", kw_only=True)
    country: str = field(default="us", kw_only=True)

    def search(self, query: str, **kwargs) -> ListArtifact:
        return ListArtifact([TextArtifact(json.dumps(result)) for result in self._search_google(query, **kwargs)])

    def _search_google(self, query: str, **kwargs) -> list[dict]:
        query_params = {
            "key": self.api_key,
            "cx": self.search_id,
            "q": query,
            "start": 0,
            "lr": f"lang_{self.language}",
            "num": self.results_count,
            "gl": self.country,
            **kwargs,
        }
        response = requests.get("https://www.googleapis.com/customsearch/v1", params=query_params)

        if response.status_code == 200:
            data = response.json()

            return [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]]

        raise Exception(
            f"Google Search API returned an error with status code "
            f"{response.status_code} and reason '{response.reason}'",
        )
api_key = field(kw_only=True)
country = field(default='us', kw_only=True)
language = field(default='en', kw_only=True)
search_id = field(kw_only=True)
_search_google(query, **kwargs)
Source Code in griptape/drivers/web_search/google_web_search_driver.py
def _search_google(self, query: str, **kwargs) -> list[dict]:
    query_params = {
        "key": self.api_key,
        "cx": self.search_id,
        "q": query,
        "start": 0,
        "lr": f"lang_{self.language}",
        "num": self.results_count,
        "gl": self.country,
        **kwargs,
    }
    response = requests.get("https://www.googleapis.com/customsearch/v1", params=query_params)

    if response.status_code == 200:
        data = response.json()

        return [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data["items"]]

    raise Exception(
        f"Google Search API returned an error with status code "
        f"{response.status_code} and reason '{response.reason}'",
    )
search(query, **kwargs)
Source Code in griptape/drivers/web_search/google_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    return ListArtifact([TextArtifact(json.dumps(result)) for result in self._search_google(query, **kwargs)])
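A minimal sketch; api_key and search_id come from a Google Programmable Search Engine setup, and both values below are placeholders:

from griptape.drivers import GoogleWebSearchDriver

driver = GoogleWebSearchDriver(
    api_key="YOUR_GOOGLE_API_KEY",
    search_id="YOUR_SEARCH_ENGINE_ID",
)

for artifact in driver.search("griptape framework"):
    print(artifact.value)  # JSON string with url, title, and description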
GriptapeCloudAssistantDriver
Bases:
BaseAssistantDriver
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
@define
class GriptapeCloudAssistantDriver(BaseAssistantDriver):
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    input: Optional[str] = field(default=None, kw_only=True)
    assistant_id: str = field(kw_only=True)
    thread_id: Optional[str] = field(default=None, kw_only=True)
    thread_alias: Optional[str] = field(default=None, kw_only=True)
    ruleset_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_ruleset_ids: list[str] = field(factory=list, kw_only=True)
    knowledge_base_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_knowledge_base_ids: list[str] = field(factory=list, kw_only=True)
    structure_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_structure_ids: list[str] = field(factory=list, kw_only=True)
    tool_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_tool_ids: list[str] = field(factory=list, kw_only=True)
    stream: bool = field(default=False, kw_only=True)
    poll_interval: int = field(default=1, kw_only=True)
    max_attempts: int = field(default=20, kw_only=True)
    auto_create_thread: bool = field(default=True, kw_only=True)

    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        if self.thread_id is None and self.auto_create_thread:
            self._create_or_find_thread(self.thread_alias)

        assistant_run_id = self._create_run(*args)
        run_result = self._get_run_result(assistant_run_id)
        run_result.meta.update(
            {"assistant_id": self.assistant_id, "assistant_run_id": assistant_run_id, "thread_id": self.thread_id}
        )
        return run_result

    def _create_or_find_thread(self, thread_alias: Optional[str] = None) -> None:
        if thread_alias is None:
            self.thread_id = self._create_thread()
        else:
            thread = self._find_thread_by_alias(thread_alias)
            if thread is None:
                self.thread_id = self._create_thread(thread_alias)
            else:
                self.thread_id = thread["thread_id"]

    def _create_thread(self, thread_alias: Optional[str] = None) -> str:
        url = griptape_cloud_url(self.base_url, "api/threads")

        body = {"name": uuid.uuid4().hex}
        if thread_alias is not None:
            body["alias"] = thread_alias

        response = requests.post(url, json=body, headers=self.headers)
        response.raise_for_status()
        return response.json()["thread_id"]

    def _create_run(self, *args: BaseArtifact) -> str:
        url = griptape_cloud_url(self.base_url, f"api/assistants/{self.assistant_id}/runs")

        response = requests.post(
            url,
            json={
                "args": [arg.value for arg in args],
                "stream": self.stream,
                "thread_id": self.thread_id,
                "input": self.input,
                **({"ruleset_ids": self.ruleset_ids} if self.ruleset_ids is not None else {}),
                "additional_ruleset_ids": self.additional_ruleset_ids,
                **({"knowledge_base_ids": self.knowledge_base_ids} if self.knowledge_base_ids is not None else {}),
                "additional_knowledge_base_ids": self.additional_knowledge_base_ids,
                **({"structure_ids": self.structure_ids} if self.structure_ids is not None else {}),
                "additional_structure_ids": self.additional_structure_ids,
                **({"tool_ids": self.tool_ids} if self.tool_ids is not None else {}),
                "additional_tool_ids": self.additional_tool_ids,
            },
            headers=self.headers,
        )
        response.raise_for_status()
        return response.json()["assistant_run_id"]

    def _get_run_result(self, assistant_run_id: str) -> TextArtifact:
        events = self._get_run_events(assistant_run_id)
        output = None

        for event in events:
            if event["origin"] == "ASSISTANT":
                event_payload = event["payload"]
                try:
                    EventBus.publish_event(BaseEvent.from_dict(event_payload))
                except ValueError as e:
                    logger.warning("Failed to deserialize event: %s", e)
                if event["type"] == "FinishStructureRunEvent":
                    output = TextArtifact.from_dict(event_payload["output_task_output"])

        if output is None:
            raise ValueError("Output not found.")

        return output

    def _get_run_events(self, assistant_run_id: str) -> Iterator[dict]:
        url = griptape_cloud_url(self.base_url, f"api/assistant-runs/{assistant_run_id}/events/stream")

        with requests.get(url, headers=self.headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data:"):
                        yield json.loads(decoded_line.removeprefix("data:").strip())

    def _find_thread_by_alias(self, thread_alias: str) -> Optional[dict]:
        url = griptape_cloud_url(self.base_url, "api/threads")

        response = requests.get(url, params={"alias": thread_alias}, headers=self.headers)
        response.raise_for_status()

        threads = response.json()["threads"]
        return next((thread for thread in threads if thread["alias"] == thread_alias), None)
additional_knowledge_base_ids = field(factory=list, kw_only=True)
additional_ruleset_ids = field(factory=list, kw_only=True)
additional_structure_ids = field(factory=list, kw_only=True)
additional_tool_ids = field(factory=list, kw_only=True)
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
assistant_id = field(kw_only=True)
auto_create_thread = field(default=True, kw_only=True)
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
input = field(default=None, kw_only=True)
knowledge_base_ids = field(default=None, kw_only=True)
max_attempts = field(default=20, kw_only=True)
poll_interval = field(default=1, kw_only=True)
ruleset_ids = field(default=None, kw_only=True)
stream = field(default=False, kw_only=True)
structure_ids = field(default=None, kw_only=True)
thread_alias = field(default=None, kw_only=True)
thread_id = field(default=None, kw_only=True)
tool_ids = field(default=None, kw_only=True)
_create_or_find_thread(thread_alias=None)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _create_or_find_thread(self, thread_alias: Optional[str] = None) -> None:
    if thread_alias is None:
        self.thread_id = self._create_thread()
    else:
        thread = self._find_thread_by_alias(thread_alias)
        if thread is None:
            self.thread_id = self._create_thread(thread_alias)
        else:
            self.thread_id = thread["thread_id"]
```
_create_run(*args)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _create_run(self, *args: BaseArtifact) -> str:
    url = griptape_cloud_url(self.base_url, f"api/assistants/{self.assistant_id}/runs")
    response = requests.post(
        url,
        json={
            "args": [arg.value for arg in args],
            "stream": self.stream,
            "thread_id": self.thread_id,
            "input": self.input,
            **({"ruleset_ids": self.ruleset_ids} if self.ruleset_ids is not None else {}),
            "additional_ruleset_ids": self.additional_ruleset_ids,
            **({"knowledge_base_ids": self.knowledge_base_ids} if self.knowledge_base_ids is not None else {}),
            "additional_knowledge_base_ids": self.additional_knowledge_base_ids,
            **({"structure_ids": self.structure_ids} if self.structure_ids is not None else {}),
            "additional_structure_ids": self.additional_structure_ids,
            **({"tool_ids": self.tool_ids} if self.tool_ids is not None else {}),
            "additional_tool_ids": self.additional_tool_ids,
        },
        headers=self.headers,
    )
    response.raise_for_status()
    return response.json()["assistant_run_id"]
```
_create_thread(thread_alias=None)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _create_thread(self, thread_alias: Optional[str] = None) -> str:
    url = griptape_cloud_url(self.base_url, "api/threads")
    body = {"name": uuid.uuid4().hex}
    if thread_alias is not None:
        body["alias"] = thread_alias
    response = requests.post(url, json=body, headers=self.headers)
    response.raise_for_status()
    return response.json()["thread_id"]
```
_find_thread_by_alias(thread_alias)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _find_thread_by_alias(self, thread_alias: str) -> Optional[dict]:
    url = griptape_cloud_url(self.base_url, "api/threads")
    response = requests.get(url, params={"alias": thread_alias}, headers=self.headers)
    response.raise_for_status()
    threads = response.json()["threads"]
    return next((thread for thread in threads if thread["alias"] == thread_alias), None)
```
_get_run_events(assistant_run_id)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _get_run_events(self, assistant_run_id: str) -> Iterator[dict]:
    url = griptape_cloud_url(self.base_url, f"api/assistant-runs/{assistant_run_id}/events/stream")
    with requests.get(url, headers=self.headers, stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data:"):
                    yield json.loads(decoded_line.removeprefix("data:").strip())
```
_get_run_result(assistant_run_id)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def _get_run_result(self, assistant_run_id: str) -> TextArtifact:
    events = self._get_run_events(assistant_run_id)
    output = None
    for event in events:
        if event["origin"] == "ASSISTANT":
            event_payload = event["payload"]
            try:
                EventBus.publish_event(BaseEvent.from_dict(event_payload))
            except ValueError as e:
                logger.warning("Failed to deserialize event: %s", e)
            if event["type"] == "FinishStructureRunEvent":
                output = TextArtifact.from_dict(event_payload["output_task_output"])
    if output is None:
        raise ValueError("Output not found.")
    return output
```
try_run(*args)
Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
```python
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    if self.thread_id is None and self.auto_create_thread:
        self._create_or_find_thread(self.thread_alias)
    assistant_run_id = self._create_run(*args)
    run_result = self._get_run_result(assistant_run_id)
    run_result.meta.update(
        {"assistant_id": self.assistant_id, "assistant_run_id": assistant_run_id, "thread_id": self.thread_id}
    )
    return run_result
```
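A minimal usage sketch (the assistant ID and thread alias are placeholders, and `GT_CLOUD_API_KEY` must be set in the environment): `try_run` auto-creates or looks up a Thread, starts a run, and returns a `TextArtifact` whose `meta` carries the assistant, run, and thread IDs.

```python
from griptape.artifacts import TextArtifact
from griptape.drivers import GriptapeCloudAssistantDriver

# Placeholder assistant ID; requires GT_CLOUD_API_KEY in the environment.
driver = GriptapeCloudAssistantDriver(
    assistant_id="<your-assistant-id>",
    thread_alias="support-session",  # reuse the same Thread across runs
)

result = driver.try_run(TextArtifact("Summarize my open tickets."))
print(result.value)
print(result.meta["thread_id"])  # set by try_run once the Thread is resolved
```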
GriptapeCloudConversationMemoryDriver
Bases:
BaseConversationMemoryDriver
Attributes
Name | Type | Description |
---|---|---|
thread_id | Optional[str] | The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to retrieve the ID from the environment variable GT_CLOUD_THREAD_ID . |
alias | Optional[str] | The alias of the Thread to store the conversation memory in. |
base_url | str | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable GT_CLOUD_BASE_URL or https://cloud.griptape.ai . |
api_key | str | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable GT_CLOUD_API_KEY . |
Raises
Type | Description |
---|---|
ValueError | If api_key is not provided. |
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
    """A driver for storing conversation memory in the Gen AI Builder.

    Attributes:
        thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver
            will attempt to retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
        alias: The alias of the Thread to store the conversation memory in.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the
            driver will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is not provided.
    """

    thread_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    alias: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    _thread: Optional[dict] = field(default=None, init=False)

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    @property
    def thread(self) -> dict:
        """Try to get the Thread by ID, alias, or create a new one."""
        if self._thread is None:
            thread = None
            if self.thread_id is None:
                self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")
            if self.thread_id is not None:
                res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
                if res.status_code == 200:
                    thread = res.json()
            # use name as 'alias' to get thread
            if thread is None and self.alias is not None:
                res = self._call_api("get", f"/threads?alias={self.alias}").json()
                if res.get("threads"):
                    thread = res["threads"][0]
                    self.thread_id = thread.get("thread_id")
            # no thread by name or thread_id
            if thread is None:
                data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
                thread = self._call_api("post", "/threads", data).json()
                self.thread_id = thread["thread_id"]
                self.alias = thread.get("alias")
            self._thread = thread
        return self._thread  # pyright: ignore[reportReturnType]

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        # serialize the run artifacts to json strings
        messages = [
            dict_merge(
                {
                    "input": run.input.to_json(),
                    "output": run.output.to_json(),
                    "metadata": {"run_id": run.id},
                },
                run.meta,
            )
            for run in runs
        ]
        body = dict_merge(
            {
                "messages": messages,
            },
            metadata,
        )
        # patch the Thread with the new messages and metadata
        # all old Messages are replaced with the new ones
        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
        self._call_api("patch", f"/threads/{thread_id}", body)
        self._thread = None

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        from griptape.memory.structure import Run

        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
        # get the Messages from the Thread
        messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()
        # retrieve the Thread to get the metadata
        thread_response = self._call_api("get", f"/threads/{thread_id}").json()
        runs = [
            Run(
                **({"id": message["metadata"].pop("run_id", None)} if "run_id" in message.get("metadata") else {}),
                meta=message["metadata"],
                input=BaseArtifact.from_json(message["input"]),
                output=BaseArtifact.from_json(message["output"]),
            )
            for message in messages_response.get("messages", [])
        ]
        return runs, thread_response.get("metadata", {})

    def _get_url(self, path: str) -> str:
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
    ) -> requests.Response:
        res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
```
_thread = field(default=None, init=False)
alias = field(default=None, metadata={'serializable': True})
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False)
thread (property)
thread_id = field(default=None, metadata={'serializable': True})
_call_api(method, path, json=None, *, raise_for_status=True)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
def _call_api(
    self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
    res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
    if raise_for_status:
        res.raise_for_status()
    return res
```
_get_url(path)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
def _get_url(self, path: str) -> str:
    path = path.lstrip("/")
    return griptape_cloud_url(self.base_url, f"api/{path}")
```
load()
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
def load(self) -> tuple[list[Run], dict[str, Any]]:
    from griptape.memory.structure import Run

    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
    # get the Messages from the Thread
    messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()
    # retrieve the Thread to get the metadata
    thread_response = self._call_api("get", f"/threads/{thread_id}").json()
    runs = [
        Run(
            **({"id": message["metadata"].pop("run_id", None)} if "run_id" in message.get("metadata") else {}),
            meta=message["metadata"],
            input=BaseArtifact.from_json(message["input"]),
            output=BaseArtifact.from_json(message["output"]),
        )
        for message in messages_response.get("messages", [])
    ]
    return runs, thread_response.get("metadata", {})
```
store(runs, metadata)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    # serialize the run artifacts to json strings
    messages = [
        dict_merge(
            {
                "input": run.input.to_json(),
                "output": run.output.to_json(),
                "metadata": {"run_id": run.id},
            },
            run.meta,
        )
        for run in runs
    ]
    body = dict_merge(
        {
            "messages": messages,
        },
        metadata,
    )
    # patch the Thread with the new messages and metadata
    # all old Messages are replaced with the new ones
    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
    self._call_api("patch", f"/threads/{thread_id}", body)
    self._thread = None
```
validate_api_key(_, value)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
```python
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
    if value is None:
        raise ValueError(f"{self.__class__.__name__} requires an API key")
    return value
```
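To show where this driver plugs in, here is a hedged sketch of attaching it to a structure's conversation memory (the alias is a placeholder, and the field names follow current Griptape releases):

```python
from griptape.drivers import GriptapeCloudConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

# Requires GT_CLOUD_API_KEY; "my-thread" is a placeholder Thread alias.
memory_driver = GriptapeCloudConversationMemoryDriver(alias="my-thread")

agent = Agent(
    conversation_memory=ConversationMemory(conversation_memory_driver=memory_driver),
)
agent.run("Remember that my favorite color is green.")  # runs are persisted to the Thread
```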
GriptapeCloudEventListenerDriver
Bases:
BaseEventListenerDriver
Attributes
Name | Type | Description |
---|---|---|
base_url | str | The base URL of Gen AI Builder. Defaults to the GT_CLOUD_BASE_URL environment variable. |
api_key | Optional[str] | The API key to authenticate with Gen AI Builder. |
headers | dict | The headers to use when making requests to Gen AI Builder. Defaults to include the Authorization header. |
structure_run_id | Optional[str] | The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable. |
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
@define
class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
    """Driver for publishing events to Gen AI Builder.

    Attributes:
        base_url: The base URL of Gen AI Builder. Defaults to the GT_CLOUD_BASE_URL environment variable.
        api_key: The API key to authenticate with Gen AI Builder.
        headers: The headers to use when making requests to Gen AI Builder. Defaults to include the Authorization header.
        structure_run_id: The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable.
    """

    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
        kw_only=True,
    )
    api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    structure_run_id: Optional[str] = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
    )

    @structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
        if structure_run_id is None:
            raise ValueError(
                "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
            )

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_api_key(self, _: Attribute, api_key: Optional[str]) -> None:
        if api_key is None:
            raise ValueError(
                "No value was found for the 'GT_CLOUD_API_KEY' environment variable. "
                "This environment variable is required when running in Gen AI Builder for authorization. "
                "You can generate a Gen AI Builder API Key by visiting https://cloud.griptape.ai/keys . "
                "Specify it as an environment variable when creating a Managed Structure in Gen AI Builder."
            )

    def publish_event(self, event: BaseEvent | dict) -> None:
        from griptape.observability.observability import Observability

        event_payload = event.to_dict() if isinstance(event, BaseEvent) else event
        span_id = Observability.get_span_id()
        if span_id is not None:
            event_payload["span_id"] = span_id
        super().publish_event(event_payload)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self._post_event(self._get_event_request(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        self._post_event([self._get_event_request(event_payload) for event_payload in event_payload_batch])

    def _get_event_request(self, event_payload: dict) -> dict:
        return {
            "payload": event_payload,
            "timestamp": event_payload.get("timestamp", time.time()),
            "type": event_payload.get("type", "UserEvent"),
        }

    def _post_event(self, json: list[dict] | dict) -> None:
        requests.post(
            url=griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/events"),
            json=json,
            headers=self.headers,
        ).raise_for_status()
```
api_key = field(default=Factory(lambda: os.getenv('GT_CLOUD_API_KEY')), kw_only=True)
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')), kw_only=True)
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
structure_run_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_STRUCTURE_RUN_ID')), kw_only=True)
_get_event_request(event_payload)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
def _get_event_request(self, event_payload: dict) -> dict:
    return {
        "payload": event_payload,
        "timestamp": event_payload.get("timestamp", time.time()),
        "type": event_payload.get("type", "UserEvent"),
    }
```
_post_event(json)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
def _post_event(self, json: list[dict] | dict) -> None:
    requests.post(
        url=griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/events"),
        json=json,
        headers=self.headers,
    ).raise_for_status()
```
publish_event(event)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
def publish_event(self, event: BaseEvent | dict) -> None:
    from griptape.observability.observability import Observability

    event_payload = event.to_dict() if isinstance(event, BaseEvent) else event
    span_id = Observability.get_span_id()
    if span_id is not None:
        event_payload["span_id"] = span_id
    super().publish_event(event_payload)
```
try_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
def try_publish_event_payload(self, event_payload: dict) -> None:
    self._post_event(self._get_event_request(event_payload))
```
try_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    self._post_event([self._get_event_request(event_payload) for event_payload in event_payload_batch])
```
validate_api_key(_, api_key)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_api_key(self, _: Attribute, api_key: Optional[str]) -> None:
    if api_key is None:
        raise ValueError(
            "No value was found for the 'GT_CLOUD_API_KEY' environment variable. "
            "This environment variable is required when running in Gen AI Builder for authorization. "
            "You can generate a Gen AI Builder API Key by visiting https://cloud.griptape.ai/keys . "
            "Specify it as an environment variable when creating a Managed Structure in Gen AI Builder."
        )
```
validate_run_id(_, structure_run_id)
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
```python
@structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
    if structure_run_id is None:
        raise ValueError(
            "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
        )
```
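A sketch of registering the driver on the global `EventBus` (assumes the `GT_CLOUD_API_KEY` and `GT_CLOUD_STRUCTURE_RUN_ID` environment variables, which Gen AI Builder injects into Managed Structures; set them yourself when testing locally):

```python
from griptape.drivers import GriptapeCloudEventListenerDriver
from griptape.events import EventBus, EventListener, FinishStructureRunEvent

# Both validators run at construction time, so missing env vars fail fast here.
EventBus.add_event_listener(
    EventListener(
        event_types=[FinishStructureRunEvent],  # forward only run-completion events
        event_listener_driver=GriptapeCloudEventListenerDriver(),
    )
)
```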
GriptapeCloudFileManagerDriver
Bases:
BaseFileManagerDriver
Attributes
Name | Type | Description |
---|---|---|
bucket_id | Optional[str] | The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to retrieve the ID from the environment variable GT_CLOUD_BUCKET_ID . |
workdir | str | The working directory. List, load, and save operations will be performed relative to this directory. |
base_url | str | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable GT_CLOUD_BASE_URL or https://cloud.griptape.ai . |
api_key | str | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable GT_CLOUD_API_KEY . |
Raises
Type | Description |
---|---|
ValueError | If api_key is not provided, if workdir does not start with "/", or if invalid bucket_id and/or bucket_name values are provided. |
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
@define
class GriptapeCloudFileManagerDriver(BaseFileManagerDriver):
    """GriptapeCloudFileManagerDriver can be used to list, load, and save files as Assets in Gen AI Builder Buckets.

    Attributes:
        bucket_id: The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will
            attempt to retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`.
        workdir: The working directory. List, load, and save operations will be performed relative to this directory.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver
            will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is not provided, if `workdir` does not start with "/", or if invalid
            `bucket_id` and/or `bucket_name` values are provided.
    """

    bucket_id: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_BUCKET_ID")), kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    _workdir: str = field(default="/", kw_only=True, alias="workdir")

    @property
    def workdir(self) -> str:
        if self._workdir.startswith("/"):
            return self._workdir
        return f"/{self._workdir}"

    @workdir.setter
    def workdir(self, value: str) -> None:
        self._workdir = value

    @bucket_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str:
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires a Bucket ID")
        return value

    def __attrs_post_init__(self) -> None:
        try:
            self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 404:
                raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e
            raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e

    def try_list_files(self, path: str, postfix: str = "") -> list[str]:
        full_key = self._to_full_key(path)
        data = {"prefix": full_key}
        if postfix:
            data["postfix"] = postfix
        list_assets_response = self._call_api(
            method="get", path=f"/buckets/{self.bucket_id}/assets", params=data, raise_for_status=False
        ).json()
        return [asset["name"] for asset in list_assets_response.get("assets", [])]

    def try_load_file(self, path: str) -> bytes:
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        try:
            sas_url, headers = self._get_asset_url(full_key)
            response = requests.get(sas_url, headers=headers)
            response.raise_for_status()
            return response.content
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 404:
                raise FileNotFoundError from e
            raise e

    def try_save_file(self, path: str, value: bytes) -> str:
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self._call_api(
            method="put",
            path=f"/buckets/{self.bucket_id}/assets",
            json={"name": full_key},
            raise_for_status=True,
        )
        sas_url, headers = self._get_asset_url(full_key)
        response = requests.put(sas_url, data=value, headers=headers)
        response.raise_for_status()
        return f"buckets/{self.bucket_id}/assets/{full_key}"

    def _get_asset_url(self, full_key: str) -> tuple[str, dict]:
        url_response = self._call_api(
            method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True
        ).json()
        return url_response["url"], url_response.get("headers", {})

    def _get_url(self, path: str) -> str:
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self,
        method: str,
        path: str,
        json: Optional[dict] = None,
        params: Optional[dict] = None,
        *,
        raise_for_status: bool = True,
    ) -> requests.Response:
        res = requests.request(method, self._get_url(path), json=json, params=params, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res

    def _is_a_directory(self, path: str) -> bool:
        return path == "" or path.endswith("/")

    def _to_full_key(self, path: str) -> str:
        path = path.lstrip("/")
        full_key = f"{self.workdir}/{path}"
        return full_key.lstrip("/")
```
_workdir = field(default='/', kw_only=True, alias='workdir')
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
bucket_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_BUCKET_ID')), kw_only=True)
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False)
workdir (writable property)
__attrs_post_init__()
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def __attrs_post_init__(self) -> None:
    try:
        self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json()
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 404:
            raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e
        raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e
```
_call_api(method, path, json=None, params=None, *, raise_for_status=True)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def _call_api(
    self,
    method: str,
    path: str,
    json: Optional[dict] = None,
    params: Optional[dict] = None,
    *,
    raise_for_status: bool = True,
) -> requests.Response:
    res = requests.request(method, self._get_url(path), json=json, params=params, headers=self.headers)
    if raise_for_status:
        res.raise_for_status()
    return res
```
_get_asset_url(full_key)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def _get_asset_url(self, full_key: str) -> tuple[str, dict]:
    url_response = self._call_api(
        method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True
    ).json()
    return url_response["url"], url_response.get("headers", {})
```
_get_url(path)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def _get_url(self, path: str) -> str:
    path = path.lstrip("/")
    return griptape_cloud_url(self.base_url, f"api/{path}")
```
_is_a_directory(path)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def _is_a_directory(self, path: str) -> bool:
    return path == "" or path.endswith("/")
```
_to_full_key(path)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def _to_full_key(self, path: str) -> str:
    path = path.lstrip("/")
    full_key = f"{self.workdir}/{path}"
    return full_key.lstrip("/")
```
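To make the key mapping concrete, here is the same transformation in isolation (the workdir and path values are hypothetical):

```python
workdir = "/reports"             # the workdir property guarantees a leading "/"
path = "/2024/q1.txt"

full_key = f"{workdir}/{path.lstrip('/')}".lstrip("/")
assert full_key == "reports/2024/q1.txt"  # Asset names never carry a leading "/"
```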
try_list_files(path, postfix='')
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def try_list_files(self, path: str, postfix: str = "") -> list[str]:
    full_key = self._to_full_key(path)
    data = {"prefix": full_key}
    if postfix:
        data["postfix"] = postfix
    list_assets_response = self._call_api(
        method="get", path=f"/buckets/{self.bucket_id}/assets", params=data, raise_for_status=False
    ).json()
    return [asset["name"] for asset in list_assets_response.get("assets", [])]
```
try_load_file(path)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def try_load_file(self, path: str) -> bytes:
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError
    try:
        sas_url, headers = self._get_asset_url(full_key)
        response = requests.get(sas_url, headers=headers)
        response.raise_for_status()
        return response.content
    except requests.exceptions.HTTPError as e:
        if e.response.status_code == 404:
            raise FileNotFoundError from e
        raise e
```
try_save_file(path, value)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
def try_save_file(self, path: str, value: bytes) -> str:
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError
    self._call_api(
        method="put",
        path=f"/buckets/{self.bucket_id}/assets",
        json={"name": full_key},
        raise_for_status=True,
    )
    sas_url, headers = self._get_asset_url(full_key)
    response = requests.put(sas_url, data=value, headers=headers)
    response.raise_for_status()
    return f"buckets/{self.bucket_id}/assets/{full_key}"
```
validate_bucket_id(_, value)
Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
```python
@bucket_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str:
    if value is None:
        raise ValueError(f"{self.__class__.__name__} requires a Bucket ID")
    return value
```
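A hedged end-to-end sketch using the `try_*` methods documented above (the bucket ID and API key come from the environment, and the workdir is a placeholder; note that `__attrs_post_init__` verifies the Bucket exists at construction time):

```python
from griptape.drivers import GriptapeCloudFileManagerDriver

# Requires GT_CLOUD_API_KEY and GT_CLOUD_BUCKET_ID in the environment.
file_manager = GriptapeCloudFileManagerDriver(workdir="/reports")

file_manager.try_save_file("q1.txt", b"Q1 summary")   # stored as Asset "reports/q1.txt"
print(file_manager.try_load_file("q1.txt"))           # -> b"Q1 summary"
print(file_manager.try_list_files("/"))               # Asset names under the workdir prefix
```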
GriptapeCloudImageGenerationDriver
Bases:
BaseImageGenerationDriver
Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
```python
@define
class GriptapeCloudImageGenerationDriver(BaseImageGenerationDriver):
    model: Optional[str] = field(default=None, kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    quality: Literal["standard", "hd"] = field(default="standard", kw_only=True, metadata={"serializable": True})
    image_size: Literal["1024x1024", "1024x1792", "1792x1024"] = field(
        default="1024x1024", kw_only=True, metadata={"serializable": True}
    )

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        url = griptape_cloud_url(self.base_url, "api/images/generations")
        response = requests.post(
            url,
            headers=self.headers,
            json={
                "prompts": prompts,
                "driver_configuration": {
                    "model": self.model,
                    "image_size": self.image_size,
                    "quality": self.quality,
                    "style": self.style,
                },
            },
        )
        response.raise_for_status()
        response = response.json()
        return ImageArtifact.from_dict(response["artifact"])

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support image variation")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")
```
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
image_size = field(default='1024x1024', kw_only=True, metadata={'serializable': True})
model = field(default=None, kw_only=True)
quality = field(default='standard', kw_only=True, metadata={'serializable': True})
style = field(default=None, kw_only=True, metadata={'serializable': True})
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
```python
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")
```
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
```python
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")
```
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
```python
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support image variation")
```
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
```python
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    url = griptape_cloud_url(self.base_url, "api/images/generations")
    response = requests.post(
        url,
        headers=self.headers,
        json={
            "prompts": prompts,
            "driver_configuration": {
                "model": self.model,
                "image_size": self.image_size,
                "quality": self.quality,
                "style": self.style,
            },
        },
    )
    response.raise_for_status()
    response = response.json()
    return ImageArtifact.from_dict(response["artifact"])
```
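A minimal text-to-image sketch (assumes `GT_CLOUD_API_KEY` is set; the prompt and output path are placeholders). `try_text_to_image` returns an `ImageArtifact` whose `value` holds the raw image bytes:

```python
from griptape.drivers import GriptapeCloudImageGenerationDriver

driver = GriptapeCloudImageGenerationDriver(image_size="1024x1024", quality="hd")

image = driver.try_text_to_image(["A watercolor lighthouse at dusk"])
with open("lighthouse.png", "wb") as f:
    f.write(image.value)  # raw bytes of the generated image
```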
GriptapeCloudObservabilityDriver
Bases:
OpenTelemetryObservabilityDriver
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
@define
class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver):
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]), kw_only=True)
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    structure_run_id: Optional[str] = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
    )
    span_processor: SpanProcessor = field(
        default=Factory(
            lambda self: import_optional_dependency("opentelemetry.sdk.trace.export").BatchSpanProcessor(
                GriptapeCloudObservabilityDriver.build_span_exporter(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    headers=self.headers,
                    structure_run_id=self.structure_run_id,
                )
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    @structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
        if structure_run_id is None:
            raise ValueError(
                "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
            )

    @staticmethod
    def format_trace_id(trace_id: int) -> str:
        return str(UUID(int=trace_id))

    @staticmethod
    def format_span_id(span_id: int) -> str:
        return str(UUID(int=span_id))

    @staticmethod
    def build_span_exporter(base_url: str, api_key: str, headers: dict, structure_run_id: str) -> SpanExporter:
        @define
        class SpanExporter(import_optional_dependency("opentelemetry.sdk.trace.export").SpanExporter):
            base_url: str = field(kw_only=True)
            api_key: str = field(kw_only=True)
            headers: dict = field(kw_only=True)
            structure_run_id: str = field(kw_only=True)

            def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
                opentelemetry_util = import_optional_dependency("opentelemetry.sdk.util")
                opentelemetry_trace_export = import_optional_dependency("opentelemetry.sdk.trace.export")
                url = griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/spans")
                payload = [
                    {
                        "trace_id": GriptapeCloudObservabilityDriver.format_trace_id(span.context.trace_id),
                        "span_id": GriptapeCloudObservabilityDriver.format_span_id(span.context.span_id),
                        "parent_id": GriptapeCloudObservabilityDriver.format_span_id(span.parent.span_id)
                        if span.parent
                        else None,
                        "name": span.name,
                        "start_time": opentelemetry_util.ns_to_iso_str(span.start_time) if span.start_time else None,
                        "end_time": opentelemetry_util.ns_to_iso_str(span.end_time) if span.end_time else None,
                        "status": span.status.status_code.name,
                        "attributes": {**span.attributes} if span.attributes else {},
                        "events": [
                            {
                                "name": event.name,
                                "timestamp": opentelemetry_util.ns_to_iso_str(event.timestamp)
                                if event.timestamp
                                else None,
                                "attributes": {**event.attributes} if event.attributes else {},
                            }
                            for event in span.events
                        ],
                    }
                    for span in spans
                    if span.context is not None
                ]
                response = requests.post(url=url, json=payload, headers=self.headers)
                return (
                    opentelemetry_trace_export.SpanExportResult.SUCCESS
                    if response.status_code == 200
                    else opentelemetry_trace_export.SpanExportResult.FAILURE
                )

        return SpanExporter(
            base_url=base_url,
            api_key=api_key,
            headers=headers,
            structure_run_id=structure_run_id,
        )

    def get_span_id(self) -> Optional[str]:
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        span = opentelemetry_trace.get_current_span()
        if span is opentelemetry_trace.INVALID_SPAN:
            return None
        return GriptapeCloudObservabilityDriver.format_span_id(span.get_span_context().span_id)
```
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']), kw_only=True)
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')), kw_only=True)
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
span_processor = field(default=Factory(lambda self: import_optional_dependency('opentelemetry.sdk.trace.export').BatchSpanProcessor(GriptapeCloudObservabilityDriver.build_span_exporter(base_url=self.base_url, api_key=self.api_key, headers=self.headers, structure_run_id=self.structure_run_id)), takes_self=True), kw_only=True)
structure_run_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_STRUCTURE_RUN_ID')), kw_only=True)
build_span_exporter(base_url, api_key, headers, structure_run_id) staticmethod
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
@staticmethod
def build_span_exporter(base_url: str, api_key: str, headers: dict, structure_run_id: str) -> SpanExporter:
    @define
    class SpanExporter(import_optional_dependency("opentelemetry.sdk.trace.export").SpanExporter):
        base_url: str = field(kw_only=True)
        api_key: str = field(kw_only=True)
        headers: dict = field(kw_only=True)
        structure_run_id: str = field(kw_only=True)

        def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
            opentelemetry_util = import_optional_dependency("opentelemetry.sdk.util")
            opentelemetry_trace_export = import_optional_dependency("opentelemetry.sdk.trace.export")
            url = griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/spans")
            payload = [
                {
                    "trace_id": GriptapeCloudObservabilityDriver.format_trace_id(span.context.trace_id),
                    "span_id": GriptapeCloudObservabilityDriver.format_span_id(span.context.span_id),
                    "parent_id": GriptapeCloudObservabilityDriver.format_span_id(span.parent.span_id)
                    if span.parent
                    else None,
                    "name": span.name,
                    "start_time": opentelemetry_util.ns_to_iso_str(span.start_time) if span.start_time else None,
                    "end_time": opentelemetry_util.ns_to_iso_str(span.end_time) if span.end_time else None,
                    "status": span.status.status_code.name,
                    "attributes": {**span.attributes} if span.attributes else {},
                    "events": [
                        {
                            "name": event.name,
                            "timestamp": opentelemetry_util.ns_to_iso_str(event.timestamp)
                            if event.timestamp
                            else None,
                            "attributes": {**event.attributes} if event.attributes else {},
                        }
                        for event in span.events
                    ],
                }
                for span in spans
                if span.context is not None
            ]
            response = requests.post(url=url, json=payload, headers=self.headers)
            return (
                opentelemetry_trace_export.SpanExportResult.SUCCESS
                if response.status_code == 200
                else opentelemetry_trace_export.SpanExportResult.FAILURE
            )

    return SpanExporter(
        base_url=base_url,
        api_key=api_key,
        headers=headers,
        structure_run_id=structure_run_id,
    )
```
format_span_id(span_id) staticmethod
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
@staticmethod
def format_span_id(span_id: int) -> str:
    return str(UUID(int=span_id))
```
format_trace_id(trace_id) staticmethod
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
@staticmethod
def format_trace_id(trace_id: int) -> str:
    return str(UUID(int=trace_id))
```
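Both formatters widen an OpenTelemetry integer ID into a 128-bit UUID string, which is what the spans endpoint expects. A quick illustration (the hex values are arbitrary):

```python
from uuid import UUID

trace_id = 0x0123456789ABCDEF0123456789ABCDEF  # 128-bit OTel trace ID
print(str(UUID(int=trace_id)))  # 01234567-89ab-cdef-0123-456789abcdef

span_id = 0x0123456789ABCDEF  # 64-bit OTel span ID, zero-padded into the UUID
print(str(UUID(int=span_id)))  # 00000000-0000-0000-0123-456789abcdef
```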
get_span_id()
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
def get_span_id(self) -> Optional[str]:
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
    span = opentelemetry_trace.get_current_span()
    if span is opentelemetry_trace.INVALID_SPAN:
        return None
    return GriptapeCloudObservabilityDriver.format_span_id(span.get_span_context().span_id)
```
validate_run_id(_, structure_run_id)
Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
```python
@structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
    if structure_run_id is None:
        raise ValueError(
            "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
        )
```
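A sketch of enabling the driver for a structure run (requires `GT_CLOUD_API_KEY` and `GT_CLOUD_STRUCTURE_RUN_ID`, plus the OpenTelemetry optional dependency):

```python
from griptape.drivers import GriptapeCloudObservabilityDriver
from griptape.observability import Observability
from griptape.structures import Agent

# Spans are batched and exported to the Structure Run while the context is active.
with Observability(observability_driver=GriptapeCloudObservabilityDriver()):
    Agent().run("Hello!")
```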
GriptapeCloudPromptDriver
Bases:
BasePromptDriver
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
```python
@define
class GriptapeCloudPromptDriver(BasePromptDriver):
    model: Optional[str] = field(default=None, kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: SimpleTokenizer(
                characters_per_token=4,
                max_input_tokens=2000,
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True)
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        url = griptape_cloud_url(self.base_url, "api/chat/messages")
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = requests.post(url, headers=self.headers, json=params)
        response.raise_for_status()
        response_json = response.json()
        logger.debug(response_json)
        return Message.from_dict(response_json)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        url = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
        params = self._base_params(prompt_stack)
        logger.debug(params)
        with requests.post(url, headers=self.headers, json=params, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data:"):
                        delta_message_payload = decoded_line.removeprefix("data:").strip()
                        logger.debug(delta_message_payload)
                        yield DeltaMessage.from_json(delta_message_payload)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        return {
            "messages": prompt_stack.to_dict()["messages"],
            "tools": self.__to_griptape_tools(prompt_stack.tools),
            **({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}),
            "driver_configuration": {
                **({"model": self.model} if self.model else {}),
                "max_tokens": self.max_tokens,
                "use_native_tools": self.use_native_tools,
                "temperature": self.temperature,
                "structured_output_strategy": self.structured_output_strategy,
                "extra_params": self.extra_params,
            },
        }

    def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
        return [
            {
                "name": tool.name,
                "activities": [
                    {
                        "name": activity.__name__,
                        "description": tool.activity_description(activity),
                        "json_schema": tool.to_activity_json_schema(activity, "Schema"),
                    }
                    for activity in tool.activities()
                ],
            }
            for tool in tools
        ]
```
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
model = field(default=None, kw_only=True)
structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
use_native_tools = field(default=True, kw_only=True)
__to_griptape_tools(tools)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
```python
def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
    return [
        {
            "name": tool.name,
            "activities": [
                {
                    "name": activity.__name__,
                    "description": tool.activity_description(activity),
                    "json_schema": tool.to_activity_json_schema(activity, "Schema"),
                }
                for activity in tool.activities()
            ],
        }
        for tool in tools
    ]
```
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    return {
        "messages": prompt_stack.to_dict()["messages"],
        "tools": self.__to_griptape_tools(prompt_stack.tools),
        **({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}),
        "driver_configuration": {
            **({"model": self.model} if self.model else {}),
            "max_tokens": self.max_tokens,
            "use_native_tools": self.use_native_tools,
            "temperature": self.temperature,
            "structured_output_strategy": self.structured_output_strategy,
            "extra_params": self.extra_params,
        },
    }
```
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    url = griptape_cloud_url(self.base_url, "api/chat/messages")
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = requests.post(url, headers=self.headers, json=params)
    response.raise_for_status()
    response_json = response.json()
    logger.debug(response_json)
    return Message.from_dict(response_json)
```
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    url = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
    params = self._base_params(prompt_stack)
    logger.debug(params)
    with requests.post(url, headers=self.headers, json=params, stream=True) as response:
        response.raise_for_status()
        for line in response.iter_lines():
            if line:
                decoded_line = line.decode("utf-8")
                if decoded_line.startswith("data:"):
                    delta_message_payload = decoded_line.removeprefix("data:").strip()
                    logger.debug(delta_message_payload)
                    yield DeltaMessage.from_json(delta_message_payload)
```
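A minimal sketch of wiring the driver into an agent (assumes `GT_CLOUD_API_KEY` is set; omitting `model` defers model selection to the service):

```python
from griptape.drivers import GriptapeCloudPromptDriver
from griptape.structures import Agent

agent = Agent(
    prompt_driver=GriptapeCloudPromptDriver(stream=True),  # stream=True routes through try_stream
)
agent.run("Write a haiku about retries.")
```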
GriptapeCloudRulesetDriver
Bases:
BaseRulesetDriver
Attributes
Name | Type | Description |
---|---|---|
ruleset_id | Optional[str] | The ID of the Ruleset to load. If not provided, the driver will attempt to find a Ruleset whose alias matches the ruleset name passed to load() . |
base_url | str | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable GT_CLOUD_BASE_URL or https://cloud.griptape.ai . |
api_key | Optional[str] | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable GT_CLOUD_API_KEY . |
Raises
Type | Description |
---|---|
ValueError | If api_key is not provided. |
Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
```python
@define(kw_only=True)
class GriptapeCloudRulesetDriver(BaseRulesetDriver):
    """A driver for loading Rulesets from the Gen AI Builder.

    Attributes:
        ruleset_id: The ID of the Ruleset to load. If not provided, the driver will attempt to find a
            Ruleset whose alias matches the ruleset name passed to `load`.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the
            driver will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is not provided.
    """

    ruleset_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
        """Load the ruleset from Gen AI Builder, using the ruleset name as an alias if ruleset_id is not provided."""
        ruleset = None
        if self.ruleset_id is not None:
            res = self._call_api("get", f"/rulesets/{self.ruleset_id}", raise_for_status=False)
            if res.status_code == 200:
                ruleset = res.json()
        # use name as 'alias' to get ruleset
        if ruleset is None:
            res = self._call_api("get", f"/rulesets?alias={ruleset_name}").json()
            if res.get("rulesets"):
                ruleset = res["rulesets"][0]
        # no ruleset by name or ruleset_id
        if ruleset is None:
            if self.raise_not_found:
                raise ValueError(f"No ruleset found with alias: {ruleset_name} or ruleset_id: {self.ruleset_id}")
            return [], {}
        rules = self._call_api("get", f"/rules?ruleset_id={ruleset['ruleset_id']}").json().get("rules", [])
        for rule in rules:
            rule["metadata"] = dict_merge(rule.get("metadata", {}), {"griptape_cloud_rule_id": rule["rule_id"]})
        return [self._get_rule(rule["rule"], rule["metadata"]) for rule in rules], ruleset.get("metadata", {})

    def _get_url(self, path: str) -> str:
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(self, method: str, path: str, *, raise_for_status: bool = True) -> requests.Response:
        res = requests.request(method, self._get_url(path), headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
```
api_key = field(default=Factory(lambda: os.getenv('GT_CLOUD_API_KEY')))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False)
ruleset_id = field(default=None, metadata={'serializable': True})
_call_api(method, path, *, raise_for_status=True)
Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
```python
def _call_api(self, method: str, path: str, *, raise_for_status: bool = True) -> requests.Response:
    res = requests.request(method, self._get_url(path), headers=self.headers)
    if raise_for_status:
        res.raise_for_status()
    return res
```
_get_url(path)
Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
```python
def _get_url(self, path: str) -> str:
    path = path.lstrip("/")
    return griptape_cloud_url(self.base_url, f"api/{path}")
```
load(ruleset_name)
Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
```python
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
    """Load the ruleset from Gen AI Builder, using the ruleset name as an alias if ruleset_id is not provided."""
    ruleset = None
    if self.ruleset_id is not None:
        res = self._call_api("get", f"/rulesets/{self.ruleset_id}", raise_for_status=False)
        if res.status_code == 200:
            ruleset = res.json()
    # use name as 'alias' to get ruleset
    if ruleset is None:
        res = self._call_api("get", f"/rulesets?alias={ruleset_name}").json()
        if res.get("rulesets"):
            ruleset = res["rulesets"][0]
    # no ruleset by name or ruleset_id
    if ruleset is None:
        if self.raise_not_found:
            raise ValueError(f"No ruleset found with alias: {ruleset_name} or ruleset_id: {self.ruleset_id}")
        return [], {}
    rules = self._call_api("get", f"/rules?ruleset_id={ruleset['ruleset_id']}").json().get("rules", [])
    for rule in rules:
        rule["metadata"] = dict_merge(rule.get("metadata", {}), {"griptape_cloud_rule_id": rule["rule_id"]})
    return [self._get_rule(rule["rule"], rule["metadata"]) for rule in rules], ruleset.get("metadata", {})
```
validate_api_key(_, value)
Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
@api_key.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: if value is None: raise ValueError(f"{self.__class__.__name__} requires an API key") return value
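For orientation, here is a minimal usage sketch. It assumes the driver is importable from `griptape.drivers` (consistent with this module's `__all__`) and that a ruleset with the given alias exists in Gen AI Builder; the alias `my-ruleset` is a hypothetical placeholder.

```python
import os

from griptape.drivers import GriptapeCloudRulesetDriver

# api_key falls back to the GT_CLOUD_API_KEY environment variable via the
# field Factory, so passing it explicitly is optional when that is set.
driver = GriptapeCloudRulesetDriver(api_key=os.environ["GT_CLOUD_API_KEY"])

# With no ruleset_id configured, load() looks the ruleset up by alias.
rules, metadata = driver.load("my-ruleset")  # hypothetical alias
```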
GriptapeCloudStructureRunDriver
Bases:
BaseStructureRunDriver
Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
@define class GriptapeCloudStructureRunDriver(BaseStructureRunDriver): base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) api_key: str = field(kw_only=True) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True, ) structure_id: str = field(kw_only=True) structure_run_wait_time_interval: int = field(default=2, kw_only=True) structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) async_run: bool = field(default=False, kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact: structure_run_id = self._create_run(*args) if self.async_run: return InfoArtifact("Run started successfully") return self._get_run_result(structure_run_id) def _create_run(self, *args: BaseArtifact) -> str: url = griptape_cloud_url(self.base_url, f"api/structures/{self.structure_id}/runs") env_vars = [{"name": key, "value": value, "source": "manual"} for key, value in self.env.items()] response = requests.post( url, json={"args": [arg.value for arg in args], "env_vars": env_vars}, headers=self.headers, ) response.raise_for_status() response_json = response.json() return response_json["structure_run_id"] def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact: events = self._get_run_events(structure_run_id) output = None for event in events: event_type = event["type"] event_payload = event.get("payload", {}) if event["origin"] == "USER": try: if "span_id" in event_payload: span_id = event_payload.pop("span_id") if "meta" in event_payload: event_payload["meta"]["span_id"] = span_id else: event_payload["meta"] = {"span_id": span_id} EventBus.publish_event(BaseEvent.from_dict(event_payload)) except ValueError as e: logger.warning("Failed to deserialize event: %s", e) if event["type"] == "FinishStructureRunEvent": output = BaseArtifact.from_dict(event_payload["output_task_output"]) elif event["origin"] == "SYSTEM": if event_type == "StructureRunError": output = ErrorArtifact(event_payload["status_detail"]["error"]) if output is None: raise ValueError("Output not found.") return output def _get_run_events(self, structure_run_id: str) -> Iterator[dict]: url = griptape_cloud_url(self.base_url, f"api/structure-runs/{structure_run_id}/events/stream") with requests.get(url, headers=self.headers, stream=True) as response: response.raise_for_status() for line in response.iter_lines(): if line: decoded_line = line.decode("utf-8") if decoded_line.startswith("data:"): yield json.loads(decoded_line.removeprefix("data:").strip())
api_key = field(kw_only=True)
async_run = field(default=False, kw_only=True)
base_url = field(default='https://cloud.griptape.ai', kw_only=True)
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
structure_id = field(kw_only=True)
structure_run_max_wait_time_attempts = field(default=20, kw_only=True)
structure_run_wait_time_interval = field(default=2, kw_only=True)
_create_run(*args)
Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _create_run(self, *args: BaseArtifact) -> str: url = griptape_cloud_url(self.base_url, f"api/structures/{self.structure_id}/runs") env_vars = [{"name": key, "value": value, "source": "manual"} for key, value in self.env.items()] response = requests.post( url, json={"args": [arg.value for arg in args], "env_vars": env_vars}, headers=self.headers, ) response.raise_for_status() response_json = response.json() return response_json["structure_run_id"]
_get_run_events(structure_run_id)
Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_events(self, structure_run_id: str) -> Iterator[dict]: url = griptape_cloud_url(self.base_url, f"api/structure-runs/{structure_run_id}/events/stream") with requests.get(url, headers=self.headers, stream=True) as response: response.raise_for_status() for line in response.iter_lines(): if line: decoded_line = line.decode("utf-8") if decoded_line.startswith("data:"): yield json.loads(decoded_line.removeprefix("data:").strip())
_get_run_result(structure_run_id)
Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact: events = self._get_run_events(structure_run_id) output = None for event in events: event_type = event["type"] event_payload = event.get("payload", {}) if event["origin"] == "USER": try: if "span_id" in event_payload: span_id = event_payload.pop("span_id") if "meta" in event_payload: event_payload["meta"]["span_id"] = span_id else: event_payload["meta"] = {"span_id": span_id} EventBus.publish_event(BaseEvent.from_dict(event_payload)) except ValueError as e: logger.warning("Failed to deserialize event: %s", e) if event["type"] == "FinishStructureRunEvent": output = BaseArtifact.from_dict(event_payload["output_task_output"]) elif event["origin"] == "SYSTEM": if event_type == "StructureRunError": output = ErrorArtifact(event_payload["status_detail"]["error"]) if output is None: raise ValueError("Output not found.") return output
try_run(*args)
Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact: structure_run_id = self._create_run(*args) if self.async_run: return InfoArtifact("Run started successfully") return self._get_run_result(structure_run_id)
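A minimal usage sketch, assuming a Structure has already been deployed to Gen AI Builder; the `structure_id` value is a hypothetical placeholder.

```python
import os

from griptape.artifacts import TextArtifact
from griptape.drivers import GriptapeCloudStructureRunDriver

driver = GriptapeCloudStructureRunDriver(
    api_key=os.environ["GT_CLOUD_API_KEY"],
    structure_id="my-structure-id",  # hypothetical placeholder
    async_run=False,  # block and consume the event stream until the run finishes
)

# Positional artifacts become the run's args. With async_run=True this would
# return an InfoArtifact immediately instead of the run's output.
result = driver.try_run(TextArtifact("hello"))
print(result.value)
```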
GriptapeCloudVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | str | API Key for Gen AI Builder. |
knowledge_base_id | str | Knowledge Base ID for Gen AI Builder. |
base_url | str | Base URL for Gen AI Builder. |
headers | dict | Headers for Gen AI Builder. |
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
@define class GriptapeCloudVectorStoreDriver(BaseVectorStoreDriver): """A vector store driver for Gen AI Builder Knowledge Bases. Attributes: api_key: API Key for Gen AI Builder. knowledge_base_id: Knowledge Base ID for Gen AI Builder. base_url: Base URL for Gen AI Builder. headers: Headers for Gen AI Builder. """ base_url: str = field( default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), ) api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) knowledge_base_id: str = field(kw_only=True, metadata={"serializable": True}) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True, ) embedding_driver: BaseEmbeddingDriver = field( default=Factory(lambda: DummyEmbeddingDriver()), metadata={"serializable": True}, kw_only=True, init=False, ) def upsert_vector( self, vector: list[float], vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.") def upsert_text_artifact( self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.") def upsert_text( self, string: str, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.") def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry: raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.") def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.") def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.") def query( self, query: str | TextArtifact | ImageArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: Optional[bool] = None, distance_metric: Optional[str] = None, # GriptapeCloudVectorStoreDriver-specific params: filter: Optional[dict] = None, # noqa: A002 **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: """Performs a query on the Knowledge Base. Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s). """ if isinstance(query, ImageArtifact): raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.") url = griptape_cloud_url(self.base_url, f"api/knowledge-bases/{self.knowledge_base_id}/query") query_args = { "count": count, "distance_metric": distance_metric, "filter": filter, "include_vectors": include_vectors, } query_args = {k: v for k, v in query_args.items() if v is not None} request: dict[str, Any] = { "query": str(query), "query_args": query_args, } response = requests.post(url, json=request, headers=self.headers).json() entries = response.get("entries", []) return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries] def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
embedding_driver = field(default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}, kw_only=True, init=False)
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
knowledge_base_id = field(kw_only=True, metadata={'serializable': True})
delete_vector(vector_id)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
load_artifacts(*, namespace=None)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry: raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
query(query, *, count=None, namespace=None, include_vectors=None, distance_metric=None, filter=None, **kwargs)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def query( self, query: str | TextArtifact | ImageArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: Optional[bool] = None, distance_metric: Optional[str] = None, # GriptapeCloudVectorStoreDriver-specific params: filter: Optional[dict] = None, # noqa: A002 **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: """Performs a query on the Knowledge Base. Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s). """ if isinstance(query, ImageArtifact): raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.") url = griptape_cloud_url(self.base_url, f"api/knowledge-bases/{self.knowledge_base_id}/query") query_args = { "count": count, "distance_metric": distance_metric, "filter": filter, "include_vectors": include_vectors, } query_args = {k: v for k, v in query_args.items() if v is not None} request: dict[str, Any] = { "query": str(query), "query_args": query_args, } response = requests.post(url, json=request, headers=self.headers).json() entries = response.get("entries", []) return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries]
upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_text( self, string: str, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")
upsert_text_artifact(artifact, namespace=None, meta=None, vector_id=None, **kwargs)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_text_artifact( self, artifact: TextArtifact, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")
upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_vector( self, vector: list[float], vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")
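A minimal query sketch. The `knowledge_base_id` is a hypothetical placeholder, and the printed fields follow `BaseVectorStoreDriver.Entry` as returned by `query()` above. Note this driver is query-only: all upsert and load methods raise `NotImplementedError`, as the stubs above show.

```python
import os

from griptape.drivers import GriptapeCloudVectorStoreDriver

driver = GriptapeCloudVectorStoreDriver(
    api_key=os.environ["GT_CLOUD_API_KEY"],
    knowledge_base_id="my-kb-id",  # hypothetical placeholder
)

# Only non-None query args are forwarded to the Knowledge Base query API.
entries = driver.query("What is Griptape?", count=5)
for entry in entries:
    print(entry.score, entry.meta)
```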
GrokPromptDriver
Bases:
OpenAiChatPromptDriver
Source Code in griptape/drivers/prompt/grok_prompt_driver.py
@define class GrokPromptDriver(OpenAiChatPromptDriver): base_url: str = field(default="https://api.x.ai/v1", kw_only=True, metadata={"serializable": True}) tokenizer: GrokTokenizer = field( default=Factory( lambda self: GrokTokenizer(base_url=self.base_url, api_key=self.api_key, model=self.model), takes_self=True ), kw_only=True, metadata={"serializable": True}, )
base_url = field(default='https://api.x.ai/v1', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: GrokTokenizer(base_url=self.base_url, api_key=self.api_key, model=self.model), takes_self=True), kw_only=True, metadata={'serializable': True})
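Since `GrokPromptDriver` only overrides `base_url` and the tokenizer, configuration otherwise matches `OpenAiChatPromptDriver`. A minimal sketch, with a hypothetical model name and environment variable:

```python
import os

from griptape.drivers import GrokPromptDriver
from griptape.structures import Agent

driver = GrokPromptDriver(
    model="grok-2-latest",  # hypothetical model ID; use any xAI chat model
    api_key=os.environ["XAI_API_KEY"],  # assumed variable name
)

# Drops in anywhere a prompt driver is accepted.
agent = Agent(prompt_driver=driver)
agent.run("Say hello in one sentence.")
```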
HuggingFaceHubEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
api_token | str | Hugging Face Hub API token. |
model | str | Hugging Face Hub model name. |
client | InferenceClient | Custom InferenceClient. |
Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@define class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver): """Hugging Face Hub Embedding Driver. Attributes: api_token: Hugging Face Hub API token. model: Hugging Face Hub model name. client: Custom `InferenceApi`. """ api_token: str = field(kw_only=True, metadata={"serializable": True}) _client: Optional[InferenceClient] = field( default=None, kw_only=True, alias="client", metadata={"serializable": False} ) @lazy_property() def client(self) -> InferenceClient: return import_optional_dependency("huggingface_hub").InferenceClient( model=self.model, token=self.api_token, ) def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]: response = self.client.feature_extraction(chunk) return [float(val) for val in response.flatten().tolist()]
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_token = field(kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@lazy_property() def client(self) -> InferenceClient: return import_optional_dependency("huggingface_hub").InferenceClient( model=self.model, token=self.api_token, )
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]: response = self.client.feature_extraction(chunk) return [float(val) for val in response.flatten().tolist()]
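A minimal embedding sketch; the model ID is illustrative (any Hub model exposing feature extraction should work), and the token variable name is an assumption.

```python
import os

from griptape.drivers import HuggingFaceHubEmbeddingDriver

driver = HuggingFaceHubEmbeddingDriver(
    api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],  # assumed variable name
    model="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model ID
)

# try_embed_chunk flattens the feature-extraction response into list[float].
vector = driver.try_embed_chunk("Hello, world!")
print(len(vector))
```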
HuggingFaceHubPromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
api_token | str | Hugging Face Hub API token. |
use_gpu | str | Use GPU during model run. |
model | str | Hugging Face Hub model name. |
client | InferenceClient | Custom InferenceClient. |
tokenizer | HuggingFaceTokenizer | Custom HuggingFaceTokenizer. |
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@define class HuggingFaceHubPromptDriver(BasePromptDriver): """Hugging Face Hub Prompt Driver. Attributes: api_token: Hugging Face Hub API token. use_gpu: Use GPU during model run. model: Hugging Face Hub model name. client: Custom `InferenceApi`. tokenizer: Custom `HuggingFaceTokenizer`. """ api_token: str = field(kw_only=True, metadata={"serializable": True}) max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) structured_output_strategy: StructuredOutputStrategy = field( default="native", kw_only=True, metadata={"serializable": True} ) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True, ), kw_only=True, ) _client: Optional[InferenceClient] = field( default=None, kw_only=True, alias="client", metadata={"serializable": False} ) @lazy_property() def client(self) -> InferenceClient: return import_optional_dependency("huggingface_hub").InferenceClient( model=self.model, token=self.api_token, ) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "tool": raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) logger.debug( { "prompt": prompt, **full_params, } ) response = self.client.text_generation( prompt, **full_params, ) logger.debug(response) input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(response)) # pyright: ignore[reportArgumentType] return Message( content=response, role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} logger.debug( { "prompt": prompt, **full_params, } ) response = self.client.text_generation(prompt, **full_params) input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) full_text = "" for token in response: logger.debug(token) full_text += token yield DeltaMessage(content=TextDeltaMessageContent(token, index=0)) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) # pyright: ignore[reportArgumentType] yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens)) def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) # pyright: ignore[reportArgumentType] def _base_params(self, prompt_stack: PromptStack) -> dict: params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } if prompt_stack.output_schema and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.to_output_json_schema() # Grammar does not support $schema and $id del output_schema["$schema"] del output_schema["$id"] params["grammar"] = {"type": "json", "value": output_schema} return params def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: if len(message.content) == 1: messages.append({"role": message.role, "content": message.to_text()}) else: raise ValueError("Invalid input content length.") return messages def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int]. raise ValueError("Invalid output type.")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_token = field(kw_only=True, metadata={'serializable': True})
max_tokens = field(default=250, kw_only=True, metadata={'serializable': True})
model = field(kw_only=True, metadata={'serializable': True})
structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
__prompt_stack_to_tokens(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int]. raise ValueError("Invalid output type.")
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict: params = { "return_full_text": False, "max_new_tokens": self.max_tokens, **self.extra_params, } if prompt_stack.output_schema and self.structured_output_strategy == "native": # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding output_schema = prompt_stack.to_output_json_schema() # Grammar does not support $schema and $id del output_schema["$schema"] del output_schema["$id"] params["grammar"] = {"type": "json", "value": output_schema} return params
_prompt_stack_to_messages(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: if len(message.content) == 1: messages.append({"role": message.role, "content": message.to_text()}) else: raise ValueError("Invalid input content length.") return messages
client()
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@lazy_property() def client(self) -> InferenceClient: return import_optional_dependency("huggingface_hub").InferenceClient( model=self.model, token=self.api_token, )
prompt_stack_to_string(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) # pyright: ignore[reportArgumentType]
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable def try_run(self, prompt_stack: PromptStack) -> Message: prompt = self.prompt_stack_to_string(prompt_stack) full_params = self._base_params(prompt_stack) logger.debug( { "prompt": prompt, **full_params, } ) response = self.client.text_generation( prompt, **full_params, ) logger.debug(response) input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(response)) # pyright: ignore[reportArgumentType] return Message( content=response, role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), )
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: prompt = self.prompt_stack_to_string(prompt_stack) full_params = {**self._base_params(prompt_stack), "stream": True} logger.debug( { "prompt": prompt, **full_params, } ) response = self.client.text_generation(prompt, **full_params) input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) full_text = "" for token in response: logger.debug(token) full_text += token yield DeltaMessage(content=TextDeltaMessageContent(token, index=0)) output_tokens = len(self.tokenizer.tokenizer.encode(full_text)) # pyright: ignore[reportArgumentType] yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens))
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value == "tool": raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value
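A minimal sketch driving `try_run` directly with a `PromptStack`; the model ID and token variable name are illustrative assumptions.

```python
import os

from griptape.common import PromptStack
from griptape.drivers import HuggingFaceHubPromptDriver

driver = HuggingFaceHubPromptDriver(
    api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],  # assumed variable name
    model="HuggingFaceH4/zephyr-7b-beta",  # illustrative chat model
)

prompt_stack = PromptStack()
prompt_stack.add_user_message("Write a haiku about tokenizers.")

# try_run renders the stack through the model's chat template, calls
# text_generation, and counts tokens with the local tokenizer.
message = driver.try_run(prompt_stack)
print(message.to_text())
```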
HuggingFacePipelineImageGenerationDriver
Bases:
BaseImageGenerationDriver, ABC
Attributes
Name | Type | Description |
---|---|---|
pipeline_driver | BaseDiffusionImageGenerationPipelineDriver | A pipeline image generation model driver typed for the specific pipeline required by the model. |
device | Optional[str] | The hardware device used for inference. For example, "cpu", "cuda", or "mps". |
output_format | str | The format the generated image is returned in. Defaults to "png". |
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
@define class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC): """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline. For more information, see the HuggingFace documentation for Diffusers: https://huggingface.co/docs/diffusers/en/index Attributes: pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model. device: The hardware device used for inference. For example, "cpu", "cuda", or "mps". output_format: The format the generated image is returned in. Defaults to "png". """ pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True}) device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) output_format: str = field(default="png", kw_only=True, metadata={"serializable": True}) def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) prompt = ", ".join(prompts) output_image = pipeline( prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device) ).images[0] buffer = io.BytesIO() output_image.save(buffer, format=self.output_format.upper()) return ImageArtifact( value=buffer.getvalue(), format=self.output_format.lower(), height=output_image.height, width=output_image.width, meta={"prompt": prompt}, ) def try_image_variation( self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None ) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) prompt = ", ".join(prompts) input_image = pil_image.open(io.BytesIO(image.value)) # The size of the input image drives the size of the output image. # Resize the input image to the configured dimensions. output_width, output_height = self.pipeline_driver.output_image_dimensions if input_image.height != output_height or input_image.width != output_width: input_image = input_image.resize((output_width, output_height)) output_image = pipeline( prompt, **self.pipeline_driver.make_image_param(input_image), **self.pipeline_driver.make_additional_params(negative_prompts, self.device), ).images[0] buffer = io.BytesIO() output_image.save(buffer, format=self.output_format.upper()) return ImageArtifact( value=buffer.getvalue(), format=self.output_format.lower(), height=output_image.height, width=output_image.width, meta={"prompt": prompt}, ) def try_image_inpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError("Inpainting is not supported by this driver.") def try_image_outpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError("Outpainting is not supported by this driver.")
device = field(default=None, kw_only=True, metadata={'serializable': True})
output_format = field(default='png', kw_only=True, metadata={'serializable': True})
pipeline_driver = field(kw_only=True, metadata={'serializable': True})
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_inpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError("Inpainting is not supported by this driver.")
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_outpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError("Outpainting is not supported by this driver.")
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_variation( self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None ) -> ImageArtifact: pil_image = import_optional_dependency("PIL.Image") pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) prompt = ", ".join(prompts) input_image = pil_image.open(io.BytesIO(image.value)) # The size of the input image drives the size of the output image. # Resize the input image to the configured dimensions. output_width, output_height = self.pipeline_driver.output_image_dimensions if input_image.height != output_height or input_image.width != output_width: input_image = input_image.resize((output_width, output_height)) output_image = pipeline( prompt, **self.pipeline_driver.make_image_param(input_image), **self.pipeline_driver.make_additional_params(negative_prompts, self.device), ).images[0] buffer = io.BytesIO() output_image.save(buffer, format=self.output_format.upper()) return ImageArtifact( value=buffer.getvalue(), format=self.output_format.lower(), height=output_image.height, width=output_image.width, meta={"prompt": prompt}, )
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device) prompt = ", ".join(prompts) output_image = pipeline( prompt, **self.pipeline_driver.make_additional_params(negative_prompts, self.device) ).images[0] buffer = io.BytesIO() output_image.save(buffer, format=self.output_format.upper()) return ImageArtifact( value=buffer.getvalue(), format=self.output_format.lower(), height=output_image.height, width=output_image.width, meta={"prompt": prompt}, )
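A minimal sketch pairing the driver with one of the pipeline drivers listed in this module. The model ID, device, and default pipeline-driver construction are assumptions; real pipelines may need additional options (dimensions, dtype, authentication).

```python
from griptape.drivers import (
    HuggingFacePipelineImageGenerationDriver,
    StableDiffusion3ImageGenerationPipelineDriver,
)

driver = HuggingFacePipelineImageGenerationDriver(
    model="stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative
    device="cuda",  # or "cpu" / "mps"
    pipeline_driver=StableDiffusion3ImageGenerationPipelineDriver(),  # defaults assumed
)

# Prompts are joined with ", " and the result is returned as an ImageArtifact.
image = driver.try_text_to_image(["a lighthouse at dusk", "watercolor"])
with open("lighthouse.png", "wb") as f:
    f.write(image.value)
```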
HuggingFacePipelinePromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | Hugging Face Hub model name. |
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@define class HuggingFacePipelinePromptDriver(BasePromptDriver): """Hugging Face Pipeline Prompt Driver. Attributes: model: Hugging Face Hub model name. """ max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True}) model: str = field(kw_only=True, metadata={"serializable": True}) tokenizer: HuggingFaceTokenizer = field( default=Factory( lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True, ), kw_only=True, ) structured_output_strategy: StructuredOutputStrategy = field( default="rule", kw_only=True, metadata={"serializable": True} ) _pipeline: Optional[TextGenerationPipeline] = field( default=None, kw_only=True, alias="pipeline", metadata={"serializable": False} ) @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value in ("native", "tool"): raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value @lazy_property() def pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( task="text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer, ) @observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) full_params = self._base_params(prompt_stack) logger.debug( ( messages, full_params, ) ) result = self.pipeline(messages, **full_params) logger.debug(result) if isinstance(result, list): if len(result) == 1: generated_text = result[0]["generated_text"][-1]["content"] input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) # pyright: ignore[reportArgumentType] return Message( content=[TextMessageContent(TextArtifact(generated_text))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) raise Exception("completion with more than one choice is not supported yet") raise Exception("invalid output format") @observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported") def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) # pyright: ignore[reportArgumentType] def _base_params(self, prompt_stack: PromptStack) -> dict: return { "max_new_tokens": self.max_tokens, "temperature": self.temperature, "do_sample": True, **self.extra_params, } def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: messages.append({"role": message.role, "content": message.to_text()}) return messages def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int]. raise ValueError("Invalid output type.")
_pipeline = field(default=None, kw_only=True, alias='pipeline', metadata={'serializable': False})
max_tokens = field(default=250, kw_only=True, metadata={'serializable': True})
model = field(kw_only=True, metadata={'serializable': True})
structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
__prompt_stack_to_tokens(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]: messages = self._prompt_stack_to_messages(prompt_stack) tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True) if isinstance(tokens, list): return tokens # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int]. raise ValueError("Invalid output type.")
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict: return { "max_new_tokens": self.max_tokens, "temperature": self.temperature, "do_sample": True, **self.extra_params, }
_prompt_stack_to_messages(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]: messages = [] for message in prompt_stack.messages: messages.append({"role": message.role, "content": message.to_text()}) return messages
pipeline()
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@lazy_property() def pipeline(self) -> TextGenerationPipeline: return import_optional_dependency("transformers").pipeline( task="text-generation", model=self.model, max_new_tokens=self.max_tokens, tokenizer=self.tokenizer.tokenizer, )
prompt_stack_to_string(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str: return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack)) # pyright: ignore[reportArgumentType]
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable def try_run(self, prompt_stack: PromptStack) -> Message: messages = self._prompt_stack_to_messages(prompt_stack) full_params = self._base_params(prompt_stack) logger.debug( ( messages, full_params, ) ) result = self.pipeline(messages, **full_params) logger.debug(result) if isinstance(result, list): if len(result) == 1: generated_text = result[0]["generated_text"][-1]["content"] input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack)) output_tokens = len(self.tokenizer.tokenizer.encode(generated_text)) # pyright: ignore[reportArgumentType] return Message( content=[TextMessageContent(TextArtifact(generated_text))], role=Message.ASSISTANT_ROLE, usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens), ) raise Exception("completion with more than one choice is not supported yet") raise Exception("invalid output format")
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: raise NotImplementedError("streaming is not supported")
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] def validate_structured_output_strategy(self, _: Attribute, value: str) -> str: if value in ("native", "tool"): raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.") return value
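A minimal local-inference sketch; the model ID is illustrative and is downloaded by transformers on first use.

```python
from griptape.common import PromptStack
from griptape.drivers import HuggingFacePipelinePromptDriver

driver = HuggingFacePipelinePromptDriver(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # illustrative model ID
    max_tokens=100,
)

prompt_stack = PromptStack()
prompt_stack.add_user_message("Name three prime numbers.")

# Note that try_stream raises NotImplementedError for this driver.
print(driver.try_run(prompt_stack).to_text())
```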
LeonardoImageGenerationDriver
Bases:
BaseImageGenerationDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | The ID of the model to use when generating images. |
api_key | str | The API key to use when making requests to the Leonardo API. |
requests_session | Session | The requests session to use when making requests to the Leonardo API. |
api_base | str | The base URL of the Leonardo API. |
max_attempts | int | The maximum number of times to poll the Leonardo API for a completed image. |
image_width | int | The width of the generated image in the range [32, 1024] and divisible by 8. |
image_height | int | The height of the generated image in the range [32, 1024] and divisible by 8. |
steps | Optional[int] | Optionally specify the number of inference steps to run for each image generation request, [30, 60]. |
seed | Optional[int] | Optionally provide a consistent seed to generation requests, increasing consistency in output. |
init_strength | Optional[float] | Optionally specify the strength of the initial image, [0.0, 1.0]. |
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
@define class LeonardoImageGenerationDriver(BaseImageGenerationDriver): """Driver for the Leonardo image generation API. Details on Leonardo image generation parameters can be found here: https://docs.leonardo.ai/reference/creategeneration Attributes: model: The ID of the model to use when generating images. api_key: The API key to use when making requests to the Leonardo API. requests_session: The requests session to use when making requests to the Leonardo API. api_base: The base URL of the Leonardo API. max_attempts: The maximum number of times to poll the Leonardo API for a completed image. image_width: The width of the generated image in the range [32, 1024] and divisible by 8. image_height: The height of the generated image in the range [32, 1024] and divisible by 8. steps: Optionally specify the number of inference steps to run for each image generation request, [30, 60]. seed: Optionally provide a consistent seed to generation requests, increasing consistency in output. init_strength: Optionally specify the strength of the initial image, [0.0, 1.0]. """ api_key: str = field(kw_only=True, metadata={"serializable": True}) requests_session: requests.Session = field(default=Factory(lambda: requests.Session()), kw_only=True) api_base: str = "https://cloud.leonardo.ai/api/rest/v1" max_attempts: int = field(default=10, kw_only=True, metadata={"serializable": True}) image_width: int = field(default=512, kw_only=True, metadata={"serializable": True}) image_height: int = field(default=512, kw_only=True, metadata={"serializable": True}) steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) init_strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True}) control_net: bool = field(default=False, kw_only=True, metadata={"serializable": True}) control_net_type: Optional[Literal["POSE", "CANNY", "DEPTH"]] = field( default=None, kw_only=True, metadata={"serializable": True}, ) def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: if negative_prompts is None: negative_prompts = [] generation_id = self._create_generation(prompts=prompts, negative_prompts=negative_prompts) image_url = self._get_image_url(generation_id=generation_id) image_data = self._download_image(url=image_url) return ImageArtifact( value=image_data, format="png", width=self.image_width, height=self.image_height, meta={ "model": self.model, "prompt": ", ".join(prompts), }, ) def try_image_variation( self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: if negative_prompts is None: negative_prompts = [] init_image_id = self._upload_init_image(image) generation_id = self._create_generation( prompts=prompts, negative_prompts=negative_prompts, init_image_id=init_image_id, ) image_url = self._get_image_url(generation_id=generation_id) image_data = self._download_image(url=image_url) return ImageArtifact( value=image_data, format="png", width=self.image_width, height=self.image_height, meta={ "model": self.model, "prompt": ", ".join(prompts), }, ) def try_image_outpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting") def try_image_inpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting") def _upload_init_image(self, image: ImageArtifact) -> str: request = {"extension": image.mime_type.split("/")[1]} prep_response = self._make_api_request("/init-image", request=request) if prep_response is None or prep_response["uploadInitImage"] is None: raise Exception(f"failed to prepare init image: {prep_response}") fields = json.loads(prep_response["uploadInitImage"]["fields"]) pre_signed_url = prep_response["uploadInitImage"]["url"] init_image_id = prep_response["uploadInitImage"]["id"] files = {"file": image.value} upload_response = requests.post(pre_signed_url, data=fields, files=files) if not upload_response.ok: raise Exception(f"failed to upload init image: {upload_response.text}") return init_image_id def _create_generation( self, prompts: list[str], negative_prompts: list[str], init_image_id: Optional[str] = None, ) -> str: prompt = ", ".join(prompts) negative_prompt = ", ".join(negative_prompts) request = { "prompt": prompt, "negative_prompt": negative_prompt, "width": self.image_width, "height": self.image_height, "num_images": 1, "modelId": self.model, } if init_image_id is not None: request["init_image_id"] = init_image_id if self.init_strength is not None: request["init_strength"] = self.init_strength if self.steps: request["num_inference_steps"] = self.steps if self.seed is not None: request["seed"] = self.seed if self.control_net: request["controlNet"] = self.control_net request["controlNetType"] = self.control_net_type response = self._make_api_request("/generations", request=request) if response is None or response["sdGenerationJob"] is None: raise Exception(f"failed to create generation: {response}") return response["sdGenerationJob"]["generationId"] def _make_api_request(self, endpoint: str, request: dict, method: str = "POST") -> dict: url = f"{self.api_base}{endpoint}" headers = {"Authorization": f"Bearer {self.api_key}"} response = self.requests_session.request(url=url, method=method, json=request, headers=headers) if not response.ok: raise Exception(f"failed to make API request: {response.text}") return response.json() def _get_image_url(self, generation_id: str) -> str: for attempt in range(self.max_attempts): response = self.requests_session.get( url=f"{self.api_base}/generations/{generation_id}", headers={"Authorization": f"Bearer {self.api_key}"}, ).json() if response["generations_by_pk"]["status"] == "PENDING": time.sleep(attempt + 1) continue return response["generations_by_pk"]["generated_images"][0]["url"] raise Exception("image generation failed to complete") def _download_image(self, url: str) -> bytes: response = self.requests_session.get(url=url, headers={"Authorization": f"Bearer {self.api_key}"}) return response.content
api_base = 'https://cloud.leonardo.ai/api/rest/v1'
api_key = field(kw_only=True, metadata={'serializable': True})
control_net = field(default=False, kw_only=True, metadata={'serializable': True})
control_net_type = field(default=None, kw_only=True, metadata={'serializable': True})
image_height = field(default=512, kw_only=True, metadata={'serializable': True})
image_width = field(default=512, kw_only=True, metadata={'serializable': True})
init_strength = field(default=None, kw_only=True, metadata={'serializable': True})
max_attempts = field(default=10, kw_only=True, metadata={'serializable': True})
requests_session = field(default=Factory(lambda: requests.Session()), kw_only=True)
seed = field(default=None, kw_only=True, metadata={'serializable': True})
steps = field(default=None, kw_only=True, metadata={'serializable': True})
_create_generation(prompts, negative_prompts, init_image_id=None)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _create_generation( self, prompts: list[str], negative_prompts: list[str], init_image_id: Optional[str] = None, ) -> str: prompt = ", ".join(prompts) negative_prompt = ", ".join(negative_prompts) request = { "prompt": prompt, "negative_prompt": negative_prompt, "width": self.image_width, "height": self.image_height, "num_images": 1, "modelId": self.model, } if init_image_id is not None: request["init_image_id"] = init_image_id if self.init_strength is not None: request["init_strength"] = self.init_strength if self.steps: request["num_inference_steps"] = self.steps if self.seed is not None: request["seed"] = self.seed if self.control_net: request["controlNet"] = self.control_net request["controlNetType"] = self.control_net_type response = self._make_api_request("/generations", request=request) if response is None or response["sdGenerationJob"] is None: raise Exception(f"failed to create generation: {response}") return response["sdGenerationJob"]["generationId"]
_download_image(url)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _download_image(self, url: str) -> bytes: response = self.requests_session.get(url=url, headers={"Authorization": f"Bearer {self.api_key}"}) return response.content
_get_image_url(generation_id)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _get_image_url(self, generation_id: str) -> str: for attempt in range(self.max_attempts): response = self.requests_session.get( url=f"{self.api_base}/generations/{generation_id}", headers={"Authorization": f"Bearer {self.api_key}"}, ).json() if response["generations_by_pk"]["status"] == "PENDING": time.sleep(attempt + 1) continue return response["generations_by_pk"]["generated_images"][0]["url"] raise Exception("image generation failed to complete")
_make_api_request(endpoint, request, method='POST')
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _make_api_request(self, endpoint: str, request: dict, method: str = "POST") -> dict: url = f"{self.api_base}{endpoint}" headers = {"Authorization": f"Bearer {self.api_key}"} response = self.requests_session.request(url=url, method=method, json=request, headers=headers) if not response.ok: raise Exception(f"failed to make API request: {response.text}") return response.json()
_upload_init_image(image)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _upload_init_image(self, image: ImageArtifact) -> str: request = {"extension": image.mime_type.split("/")[1]} prep_response = self._make_api_request("/init-image", request=request) if prep_response is None or prep_response["uploadInitImage"] is None: raise Exception(f"failed to prepare init image: {prep_response}") fields = json.loads(prep_response["uploadInitImage"]["fields"]) pre_signed_url = prep_response["uploadInitImage"]["url"] init_image_id = prep_response["uploadInitImage"]["id"] files = {"file": image.value} upload_response = requests.post(pre_signed_url, data=fields, files=files) if not upload_response.ok: raise Exception(f"failed to upload init image: {upload_response.text}") return init_image_id
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_inpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_outpainting( self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_variation( self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None, ) -> ImageArtifact: if negative_prompts is None: negative_prompts = [] init_image_id = self._upload_init_image(image) generation_id = self._create_generation( prompts=prompts, negative_prompts=negative_prompts, init_image_id=init_image_id, ) image_url = self._get_image_url(generation_id=generation_id) image_data = self._download_image(url=image_url) return ImageArtifact( value=image_data, format="png", width=self.image_width, height=self.image_height, meta={ "model": self.model, "prompt": ", ".join(prompts), }, )
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: if negative_prompts is None: negative_prompts = [] generation_id = self._create_generation(prompts=prompts, negative_prompts=negative_prompts) image_url = self._get_image_url(generation_id=generation_id) image_data = self._download_image(url=image_url) return ImageArtifact( value=image_data, format="png", width=self.image_width, height=self.image_height, meta={ "model": self.model, "prompt": ", ".join(prompts), }, )
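A minimal text-to-image sketch; the model ID and environment variable name are hypothetical placeholders. Note that `_get_image_url` polls the API up to `max_attempts` times with linearly increasing sleeps before giving up.

```python
import os

from griptape.drivers import LeonardoImageGenerationDriver

driver = LeonardoImageGenerationDriver(
    api_key=os.environ["LEONARDO_API_KEY"],  # assumed variable name
    model="<leonardo-model-id>",  # hypothetical platform model ID
    image_width=512,  # must be in [32, 1024] and divisible by 8
    image_height=512,
    steps=30,  # optional, [30, 60]
)

image = driver.try_text_to_image(["a watercolor fox"])
with open("fox.png", "wb") as f:
    f.write(image.value)
```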
LocalConversationMemoryDriver
Bases:
BaseConversationMemoryDriver
Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
@define(kw_only=True) class LocalConversationMemoryDriver(BaseConversationMemoryDriver): persist_file: Optional[str] = field(default=None, metadata={"serializable": True}) def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: if self.persist_file is not None: Path(self.persist_file).write_text(json.dumps(self._to_params_dict(runs, metadata))) def load(self) -> tuple[list[Run], dict[str, Any]]: if ( self.persist_file is not None and os.path.exists(self.persist_file) and (loaded_str := Path(self.persist_file).read_text()) is not None ): try: return self._from_params_dict(json.loads(loaded_str)) except Exception as e: raise ValueError(f"Unable to load data from {self.persist_file}") from e return [], {}
persist_file = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
load()
Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]: if ( self.persist_file is not None and os.path.exists(self.persist_file) and (loaded_str := Path(self.persist_file).read_text()) is not None ): try: return self._from_params_dict(json.loads(loaded_str)) except Exception as e: raise ValueError(f"Unable to load data from {self.persist_file}") from e return [], {}
store(runs, metadata)
Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None: if self.persist_file is not None: Path(self.persist_file).write_text(json.dumps(self._to_params_dict(runs, metadata)))
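A minimal usage sketch, assuming the usual griptape wiring where a `ConversationMemory` accepts a `conversation_memory_driver`:

```python
from griptape.drivers import LocalConversationMemoryDriver
from griptape.memory.structure import ConversationMemory
from griptape.structures import Agent

# Runs are serialized to conversation.json on store() and reloaded on load().
memory = ConversationMemory(
    conversation_memory_driver=LocalConversationMemoryDriver(persist_file="conversation.json"),
)
agent = Agent(conversation_memory=memory)
agent.run("Remember that my favorite color is teal.")
```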
LocalFileManagerDriver
Bases:
BaseFileManagerDriver
Attributes
Name | Type | Description |
---|---|---|
workdir | str | The working directory as an absolute path. List, load, and save operations will be performed relative to this directory. Defaults to the current working directory. Setting this to None will disable the working directory and all paths will be treated as absolute paths. |
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
@define class LocalFileManagerDriver(BaseFileManagerDriver): """LocalFileManagerDriver can be used to list, load, and save files on the local file system. Attributes: workdir: The working directory as an absolute path. List, load, and save operations will be performed relative to this directory. Defaults to the current working directory. Setting this to None will disable the working directory and all paths will be treated as absolute paths. """ _workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True, alias="workdir") @property def workdir(self) -> str: if os.path.isabs(self._workdir): return self._workdir return os.path.join(os.getcwd(), self._workdir) @workdir.setter def workdir(self, value: str) -> None: self._workdir = value def try_list_files(self, path: str) -> list[str]: full_path = self._full_path(path) return os.listdir(full_path) def try_load_file(self, path: str) -> bytes: full_path = self._full_path(path) if self._is_dir(full_path): raise IsADirectoryError return Path(full_path).read_bytes() def try_save_file(self, path: str, value: bytes) -> str: full_path = self._full_path(path) if self._is_dir(full_path): raise IsADirectoryError os.makedirs(os.path.dirname(full_path), exist_ok=True) Path(full_path).write_bytes(value) return full_path def _full_path(self, path: str) -> str: full_path = path if os.path.isabs(path) else os.path.join(self.workdir, path.lstrip("/")) # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_sep = path.endswith("/") full_path = os.path.normpath(full_path) if ended_with_sep: full_path = full_path.rstrip("/") + "/" return full_path def _is_dir(self, full_path: str) -> bool: return full_path.endswith("/") or Path(full_path).is_dir()
_workdir = field(default=Factory(lambda: os.getcwd()), kw_only=True, alias='workdir')
class-attribute instance-attributeworkdir
property writable
_full_path(path)
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def _full_path(self, path: str) -> str: full_path = path if os.path.isabs(path) else os.path.join(self.workdir, path.lstrip("/")) # Need to keep the trailing slash if it was there, # because it means the path is a directory. ended_with_sep = path.endswith("/") full_path = os.path.normpath(full_path) if ended_with_sep: full_path = full_path.rstrip("/") + "/" return full_path
_is_dir(full_path)
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def _is_dir(self, full_path: str) -> bool: return full_path.endswith("/") or Path(full_path).is_dir()
try_list_files(path)
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]: full_path = self._full_path(path) return os.listdir(full_path)
try_load_file(path)
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_load_file(self, path: str) -> bytes: full_path = self._full_path(path) if self._is_dir(full_path): raise IsADirectoryError return Path(full_path).read_bytes()
try_save_file(path, value)
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> str: full_path = self._full_path(path) if self._is_dir(full_path): raise IsADirectoryError os.makedirs(os.path.dirname(full_path), exist_ok=True) Path(full_path).write_bytes(value) return full_path
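A minimal sketch of the save/list/load round trip; the workdir below is an arbitrary example path:

```python
from griptape.drivers import LocalFileManagerDriver

# Relative paths resolve against workdir; absolute paths are used as-is.
file_manager = LocalFileManagerDriver(workdir="/tmp/griptape-demo")
file_manager.try_save_file("notes/hello.txt", b"hello world")  # creates parent dirs
print(file_manager.try_list_files("notes"))
print(file_manager.try_load_file("notes/hello.txt"))
```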
LocalRerankDriver
Bases:
BaseRerankDriver
, FuturesExecutorMixin
Source Code in griptape/drivers/rerank/local_rerank_driver.py
@define(kw_only=True) class LocalRerankDriver(BaseRerankDriver, FuturesExecutorMixin): calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) embedding_driver: BaseEmbeddingDriver = field( kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={"serializable": True} ) def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]: query_embedding = self.embedding_driver.embed(query, vector_operation="query") with self.create_futures_executor() as futures_executor: artifact_embeddings = execute_futures_list( [ futures_executor.submit( with_contextvars(self.embedding_driver.embed_text_artifact), a, vector_operation="upsert" ) for a in artifacts ], ) artifacts_and_relatednesses = [ (artifact, self.calculate_relatedness(query_embedding, artifact_embedding)) for artifact, artifact_embedding in zip(artifacts, artifact_embeddings) ] artifacts_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) return [artifact for artifact, _ in artifacts_and_relatednesses]
calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
class-attribute instance-attribute
embedding_driver = field(kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={'serializable': True})
class-attribute instance-attribute
run(query, artifacts)
Source Code in griptape/drivers/rerank/local_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]: query_embedding = self.embedding_driver.embed(query, vector_operation="query") with self.create_futures_executor() as futures_executor: artifact_embeddings = execute_futures_list( [ futures_executor.submit( with_contextvars(self.embedding_driver.embed_text_artifact), a, vector_operation="upsert" ) for a in artifacts ], ) artifacts_and_relatednesses = [ (artifact, self.calculate_relatedness(query_embedding, artifact_embedding)) for artifact, artifact_embedding in zip(artifacts, artifact_embeddings) ] artifacts_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) return [artifact for artifact, _ in artifacts_and_relatednesses]
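A minimal sketch with an explicit embedding driver (`OpenAiEmbeddingDriver` here, which assumes an `OPENAI_API_KEY` in the environment); by default the driver falls back to `Defaults.drivers_config.embedding_driver`:

```python
from griptape.artifacts import TextArtifact
from griptape.drivers import LocalRerankDriver, OpenAiEmbeddingDriver

# Scores each artifact by cosine similarity to the query embedding.
reranker = LocalRerankDriver(embedding_driver=OpenAiEmbeddingDriver())
ranked = reranker.run(
    "capital of France",
    [TextArtifact("Bordeaux is known for wine."), TextArtifact("Paris is the capital of France.")],
)
print([a.value for a in ranked])  # most related first
```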
LocalRulesetDriver
Bases:
BaseRulesetDriver
Source Code in griptape/drivers/ruleset/local_ruleset_driver.py
@define(kw_only=True) class LocalRulesetDriver(BaseRulesetDriver): persist_dir: Optional[str] = field(default=None, metadata={"serializable": True}) def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]: if self.persist_dir is None: return [], {} file_name = os.path.join(self.persist_dir, ruleset_name) if ( file_name is not None and os.path.exists(file_name) and (loaded_str := Path(file_name).read_text()) is not None ): try: return self._from_ruleset_dict(json.loads(loaded_str)) except Exception as e: raise ValueError(f"Unable to load data from {file_name}") from e if self.raise_not_found: raise ValueError(f"Ruleset not found with name {file_name}") return [], {}
persist_dir = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
load(ruleset_name)
Source Code in griptape/drivers/ruleset/local_ruleset_driver.py
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]: if self.persist_dir is None: return [], {} file_name = os.path.join(self.persist_dir, ruleset_name) if ( file_name is not None and os.path.exists(file_name) and (loaded_str := Path(file_name).read_text()) is not None ): try: return self._from_ruleset_dict(json.loads(loaded_str)) except Exception as e: raise ValueError(f"Unable to load data from {file_name}") from e if self.raise_not_found: raise ValueError(f"Ruleset not found with name {file_name}") return [], {}
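A minimal sketch, assuming a `Ruleset` accepts a `ruleset_driver`; since `load()` joins `persist_dir` with the ruleset name, the name doubles as the file name:

```python
from griptape.drivers import LocalRulesetDriver
from griptape.rules import Ruleset

# Reads rulesets/assistant_rules.json (a hypothetical file) if it exists.
driver = LocalRulesetDriver(persist_dir="rulesets")
ruleset = Ruleset(name="assistant_rules.json", ruleset_driver=driver)
```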
LocalStructureRunDriver
Bases:
BaseStructureRunDriver
Source Code in griptape/drivers/structure_run/local_structure_run_driver.py
@define class LocalStructureRunDriver(BaseStructureRunDriver): create_structure: Callable[[], Structure] = field(kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) structure = self.create_structure().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env) return structure.output
create_structure = field(kw_only=True)
class-attribute instance-attribute
try_run(*args)
Source Code in griptape/drivers/structure_run/local_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact: old_env = os.environ.copy() try: os.environ.update(self.env) structure = self.create_structure().run(*[arg.value for arg in args]) finally: os.environ.clear() os.environ.update(old_env) return structure.output
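A minimal sketch, assuming `env` is a constructor field on the base structure run driver; note from `try_run` above that it is overlaid onto `os.environ` only for the duration of the child run and then restored:

```python
from griptape.drivers import LocalStructureRunDriver
from griptape.structures import Agent

driver = LocalStructureRunDriver(
    create_structure=lambda: Agent(),  # built fresh on every run
    env={"EXAMPLE_FLAG": "1"},         # visible to the child structure only while it runs
)
```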
LocalVectorStoreDriver
Bases:
BaseVectorStoreDriver
Source Code in griptape/drivers/vector/local_vector_store_driver.py
@define(kw_only=True) class LocalVectorStoreDriver(BaseVectorStoreDriver): entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict) persist_file: Optional[str] = field(default=None) calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock())) def __attrs_post_init__(self) -> None: if self.persist_file is not None: directory = os.path.dirname(self.persist_file) if directory and not os.path.exists(directory): os.makedirs(directory) if not os.path.isfile(self.persist_file): with open(self.persist_file, "w") as file: self.__save_entries_to_file(file) with open(self.persist_file, "r+") as file: if os.path.getsize(self.persist_file) > 0: self.entries = self.load_entries_from_file(file) else: self.__save_entries_to_file(file) def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]: with self.thread_lock: data = json.load(json_file) return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()} def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: vector_id = vector_id or utils.str_to_hash(str(vector)) with self.thread_lock: self.entries[self.__namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry( id=vector_id, vector=vector, meta=meta, namespace=namespace, ) if self.persist_file is not None: # TODO: optimize later since it reserializes all entries from memory and stores them in the JSON file # every time a new vector is inserted with open(self.persist_file, "w") as file: self.__save_entries_to_file(file) return vector_id def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: return self.entries.get(self.__namespaced_vector_id(vector_id, namespace=namespace), None) def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace] def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: if namespace: entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")} else: entries = self.entries entries_and_relatednesses = [ (entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values()) ] entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) result = [ BaseVectorStoreDriver.Entry( id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta, namespace=er[0].namespace ) for er in entries_and_relatednesses ][:count] if include_vectors: return result return [ BaseVectorStoreDriver.Entry(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace) for r in result ] def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") def __save_entries_to_file(self, json_file: TextIO) -> None: with self.thread_lock: serialized_data = {k: v.to_dict() for k, v in self.entries.items()} json.dump(serialized_data, json_file) def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str: return vector_id if namespace is None else f"{namespace}-{vector_id}"
calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
class-attribute instance-attribute
entries = field(factory=dict)
class-attribute instance-attribute
persist_file = field(default=None)
class-attribute instance-attribute
thread_lock = field(default=Factory(lambda: threading.Lock()))
class-attribute instance-attribute
__attrs_post_init__()
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __attrs_post_init__(self) -> None: if self.persist_file is not None: directory = os.path.dirname(self.persist_file) if directory and not os.path.exists(directory): os.makedirs(directory) if not os.path.isfile(self.persist_file): with open(self.persist_file, "w") as file: self.__save_entries_to_file(file) with open(self.persist_file, "r+") as file: if os.path.getsize(self.persist_file) > 0: self.entries = self.load_entries_from_file(file) else: self.__save_entries_to_file(file)
__namespaced_vector_id(vector_id, *, namespace)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str: return vector_id if namespace is None else f"{namespace}-{vector_id}"
__save_entries_to_file(json_file)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __save_entries_to_file(self, json_file: TextIO) -> None: with self.thread_lock: serialized_data = {k: v.to_dict() for k, v in self.entries.items()} json.dump(serialized_data, json_file)
delete_vector(vector_id)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: return [entry for key, entry in self.entries.items() if namespace is None or entry.namespace == namespace]
load_entries_from_file(json_file)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]: with self.thread_lock: data = json.load(json_file) return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: return self.entries.get(self.__namespaced_vector_id(vector_id, namespace=namespace), None)
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: if namespace: entries = {k: v for (k, v) in self.entries.items() if k.startswith(f"{namespace}-")} else: entries = self.entries entries_and_relatednesses = [ (entry, self.calculate_relatedness(vector, entry.vector)) for entry in list(entries.values()) ] entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True) result = [ BaseVectorStoreDriver.Entry( id=er[0].id, vector=er[0].vector, score=er[1], meta=er[0].meta, namespace=er[0].namespace ) for er in entries_and_relatednesses ][:count] if include_vectors: return result return [ BaseVectorStoreDriver.Entry(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace) for r in result ]
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/local_vector_store_driver.py
def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: vector_id = vector_id or utils.str_to_hash(str(vector)) with self.thread_lock: self.entries[self.__namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry( id=vector_id, vector=vector, meta=meta, namespace=namespace, ) if self.persist_file is not None: # TODO: optimize later since it reserializes all entries from memory and stores them in the JSON file # every time a new vector is inserted with open(self.persist_file, "w") as file: self.__save_entries_to_file(file) return vector_id
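A minimal sketch using the base-class `upsert`/`query` helpers, which embed values via the configured embedding driver (`OpenAiEmbeddingDriver` here as an example):

```python
from griptape.drivers import LocalVectorStoreDriver, OpenAiEmbeddingDriver

vector_store = LocalVectorStoreDriver(
    embedding_driver=OpenAiEmbeddingDriver(),  # any BaseEmbeddingDriver works
    persist_file="vectors.json",               # optional on-disk JSON persistence
)
vector_store.upsert("Griptape is a Python framework.", namespace="docs")
for entry in vector_store.query("What is Griptape?", count=1, namespace="docs"):
    print(entry.score, entry.meta)
```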
MarkdownifyWebScraperDriver
Bases:
BaseWebScraperDriver
Attributes
Name | Type | Description |
---|---|---|
include_links | bool | If `True`, the driver will include link urls in the markdown output. |
exclude_tags | list[str] | Optionally provide custom tags to exclude from the scraped content. |
exclude_classes | list[str] | Optionally provide custom classes to exclude from the scraped content. |
exclude_ids | list[str] | Optionally provide custom ids to exclude from the scraped content. |
timeout | Optional[int] | Optionally provide a timeout in milliseconds for the page to continue loading after the browser has emitted the "load" event. |
Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
@define class MarkdownifyWebScraperDriver(BaseWebScraperDriver): """Driver to scrape a webpage and return the content in markdown format. As a prerequisite to using MarkdownifyWebScraperDriver, you need to install the browsers used by playwright. You can do this by running: `poetry run playwright install`. For more details about playwright, see https://playwright.dev/python/docs/library. Attributes: include_links: If `True`, the driver will include link urls in the markdown output. exclude_tags: Optionally provide custom tags to exclude from the scraped content. exclude_classes: Optionally provide custom classes to exclude from the scraped content. exclude_ids: Optionally provide custom ids to exclude from the scraped content. timeout: Optionally provide a timeout in milliseconds for the page to continue loading after the browser has emitted the "load" event. """ DEFAULT_EXCLUDE_TAGS = ["script", "style", "head", "audio", "img", "picture", "source", "video"] include_links: bool = field(default=True, kw_only=True) exclude_tags: list[str] = field( default=Factory(lambda self: self.DEFAULT_EXCLUDE_TAGS, takes_self=True), kw_only=True, ) exclude_classes: list[str] = field(default=Factory(list), kw_only=True) exclude_ids: list[str] = field(default=Factory(list), kw_only=True) timeout: Optional[int] = field(default=None, kw_only=True) def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() def skip_loading_images(route: Any) -> Any: if route.request.resource_type == "image": return route.abort() route.continue_() return None page.route("**/*", skip_loading_images) page.goto(url) # Some websites require a delay before the content is fully loaded # even after the browser has emitted "load" event. if self.timeout: page.wait_for_timeout(self.timeout) content = page.content() if not content: raise Exception("can't access URL") return content def extract_page(self, page: str) -> TextArtifact: bs4 = import_optional_dependency("bs4") markdownify = import_optional_dependency("markdownify") include_links = self.include_links # Custom MarkdownConverter to optionally linked urls. If include_links is False only # the text of the link is returned. class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): def convert_a(self, el: Any, text: str, parent_tags: Any) -> str: if include_links: return super().convert_a(el, text, parent_tags) return text soup = bs4.BeautifulSoup(page, "html.parser") # Remove unwanted elements exclude_selector = ",".join( self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], ) if exclude_selector: for s in soup.select(exclude_selector): s.extract() text = OptionalLinksMarkdownConverter().convert_soup(soup) # Remove leading and trailing whitespace from the entire text text = text.strip() # Remove trailing whitespace from each line text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) # Indent using 2 spaces instead of tabs text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) # Remove triple+ newlines (keep double newlines for paragraphs) text = re.sub(r"\n\n+", "\n\n", text) return TextArtifact(text)
DEFAULT_EXCLUDE_TAGS = ['script', 'style', 'head', 'audio', 'img', 'picture', 'source', 'video']
class-attribute instance-attribute
exclude_classes = field(default=Factory(list), kw_only=True)
class-attribute instance-attribute
exclude_ids = field(default=Factory(list), kw_only=True)
class-attribute instance-attribute
exclude_tags = field(default=Factory(lambda self: self.DEFAULT_EXCLUDE_TAGS, takes_self=True), kw_only=True)
class-attribute instance-attribute
include_links = field(default=True, kw_only=True)
class-attribute instance-attribute
timeout = field(default=None, kw_only=True)
class-attribute instance-attribute
extract_page(page)
Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact: bs4 = import_optional_dependency("bs4") markdownify = import_optional_dependency("markdownify") include_links = self.include_links # Custom MarkdownConverter to optionally include link urls. If include_links is False only # the text of the link is returned. class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter): def convert_a(self, el: Any, text: str, parent_tags: Any) -> str: if include_links: return super().convert_a(el, text, parent_tags) return text soup = bs4.BeautifulSoup(page, "html.parser") # Remove unwanted elements exclude_selector = ",".join( self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids], ) if exclude_selector: for s in soup.select(exclude_selector): s.extract() text = OptionalLinksMarkdownConverter().convert_soup(soup) # Remove leading and trailing whitespace from the entire text text = text.strip() # Remove trailing whitespace from each line text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE) # Indent using 2 spaces instead of tabs text = re.sub(r"(\n?\s*?)\t", r"\1 ", text) # Remove triple+ newlines (keep double newlines for paragraphs) text = re.sub(r"\n\n+", "\n\n", text) return TextArtifact(text)
fetch_url(url)
Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
def fetch_url(self, url: str) -> str: sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright with sync_playwright() as p, p.chromium.launch(headless=True) as browser: page = browser.new_page() def skip_loading_images(route: Any) -> Any: if route.request.resource_type == "image": return route.abort() route.continue_() return None page.route("**/*", skip_loading_images) page.goto(url) # Some websites require a delay before the content is fully loaded # even after the browser has emitted "load" event. if self.timeout: page.wait_for_timeout(self.timeout) content = page.content() if not content: raise Exception("can't access URL") return content
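A minimal sketch of the fetch/extract flow; per the class docstring, playwright's browsers must be installed first (e.g. `poetry run playwright install`):

```python
from griptape.drivers import MarkdownifyWebScraperDriver

scraper = MarkdownifyWebScraperDriver(include_links=False, timeout=1000)
html = scraper.fetch_url("https://example.com")  # renders the page in headless Chromium
markdown = scraper.extract_page(html)            # converts the HTML to markdown
print(markdown.value)
```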
MarqoVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
api_key | str | The API key for the Marqo API. |
url | str | The URL to the Marqo API. |
client | Client | An optional Marqo client. Defaults to a new client with the given URL and API key. |
index | str | The name of the index to use. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
@define class MarqoVectorStoreDriver(BaseVectorStoreDriver): """A Vector Store Driver for Marqo. Attributes: api_key: The API key for the Marqo API. url: The URL to the Marqo API. client: An optional Marqo client. Defaults to a new client with the given URL and API key. index: The name of the index to use. """ api_key: str = field(kw_only=True, metadata={"serializable": True}) url: str = field(kw_only=True, metadata={"serializable": True}) index: str = field(kw_only=True, metadata={"serializable": True}) _client: Optional[marqo.Client] = field( default=None, kw_only=True, alias="client", metadata={"serializable": False} ) @lazy_property() def client(self) -> marqo.Client: return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key) def upsert( self, value: str | TextArtifact | ImageArtifact, *, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs: Any, ) -> str: """Upsert a text document into the Marqo index. Args: value: The value to be indexed. namespace: An optional namespace for the document. meta: An optional dictionary of metadata for the document. vector_id: The ID for the vector. If None, Marqo will generate an ID. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: str: The ID of the document that was added. """ if isinstance(value, TextArtifact): artifact_json = value.to_json() vector_id = utils.str_to_hash(value.value) if vector_id is None else vector_id doc = { "_id": vector_id, "Description": value.value, "artifact": str(artifact_json), } elif isinstance(value, ImageArtifact): raise NotImplementedError("`MarqoVectorStoreDriver` does not upserting Image Artifacts.") else: doc = {"_id": vector_id, "Description": value} # Non-tensor fields if meta: doc["meta"] = str(meta) if namespace: doc["namespace"] = namespace response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] raise ValueError(f"Failed to upsert text: {response}") def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: """Load a document entry from the Marqo index. Args: vector_id: The ID of the vector to load. namespace: The namespace of the vector to load. Returns: The loaded Entry if found, otherwise None. """ result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True) if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0: return BaseVectorStoreDriver.Entry( id=result["_id"], meta={k: v for k, v in result.items() if k != "_id"}, vector=result["_tensor_facets"][0]["_embedding"], ) return None def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: """Load all document entries from the Marqo index. Args: namespace: The namespace to filter entries by. Returns: The list of loaded Entries. 
""" filter_string = f"namespace:{namespace}" if namespace else None if filter_string is not None: results = self.client.index(self.index).search("", limit=10000, filter_string=filter_string) else: results = self.client.index(self.index).search("", limit=10000) # get all _id's from search results ids = [r["_id"] for r in results["hits"]] # get documents corresponding to the ids documents = self.client.index(self.index).get_documents(document_ids=ids, expose_facets=True) # for each document, if it's found, create an Entry object entries = [] for doc in documents["results"]: if doc["_found"]: entries.append( BaseVectorStoreDriver.Entry( id=doc["_id"], vector=doc["_tensor_facets"][0]["_embedding"], meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]}, namespace=doc.get("namespace"), ), ) return entries def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata: bool = True, **kwargs: Any, ) -> list[BaseVectorStoreDriver.Entry]: """Query the Marqo index for documents. Args: vector: The vector to query by. count: The maximum number of results to return. namespace: The namespace to filter results by. include_vectors: Whether to include vector data in the results. include_metadata: Whether to include metadata in the results. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: The list of query results. """ params = { "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, "attributes_to_retrieve": None if include_metadata else ["_id"], "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs results = self.client.index(self.index).search(**params, context={"tensor": [vector], "weight": 1}) return self.__process_results(results, include_vectors=include_vectors) def query( self, query: str | TextArtifact | ImageArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata: bool = True, **kwargs: Any, ) -> list[BaseVectorStoreDriver.Entry]: """Query the Marqo index for documents. Args: query: The query string. count: The maximum number of results to return. namespace: The namespace to filter results by. include_vectors: Whether to include vector data in the results. include_metadata: Whether to include metadata in the results. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: The list of query results. """ params = { "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, "attributes_to_retrieve": None if include_metadata else ["_id"], "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs results = self.client.index(self.index).search(str(query), **params) return self.__process_results(results, include_vectors=include_vectors) def delete_index(self, name: str) -> dict[str, Any]: """Delete an index in the Marqo client. Args: name: The name of the index to delete. """ return self.client.delete_index(name) def get_indexes(self) -> list[str]: """Get a list of all indexes in the Marqo client. Returns: The list of all indexes. """ return [index["index"] for index in self.client.get_indexes()["results"]] def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs: Any, ) -> str: """Upsert a vector into the Marqo index. Args: vector: The vector to be indexed. vector_id: The ID for the vector. If None, Marqo will generate an ID. 
namespace: An optional namespace for the vector. meta: An optional dictionary of metadata for the vector. kwargs: Additional keyword arguments to pass to the Marqo client. Raises: Exception: This function is not yet implemented. Returns: The ID of the vector that was added. """ raise NotImplementedError(f"{self.__class__.__name__} does not support upserting a vector.") def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.") def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]: if include_vectors: results["hits"] = [ {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"] ] return [ BaseVectorStoreDriver.Entry( id=r["_id"], vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], score=r["_score"], meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, ) for r in results["hits"] ]
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
api_key = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
index = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
url = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
__process_results(results, *, include_vectors)
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]: if include_vectors: results["hits"] = [ {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)} for r in results["hits"] ] return [ BaseVectorStoreDriver.Entry( id=r["_id"], vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [], score=r["_score"], meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]}, ) for r in results["hits"] ]
client()
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
@lazy_property() def client(self) -> marqo.Client: return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key)
delete_index(name)
Delete an index in the Marqo client.
Parameters
Name | Type | Description | Default |
---|---|---|---|
name | str | The name of the index to delete. | required |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_index(self, name: str) -> dict[str, Any]: """Delete an index in the Marqo client. Args: name: The name of the index to delete. """ return self.client.delete_index(name)
delete_vector(vector_id)
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn: raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
get_indexes()
Get a list of all indexes in the Marqo client.
Returns
Type | Description |
---|---|
list[str] | The list of all indexes. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def get_indexes(self) -> list[str]: """Get a list of all indexes in the Marqo client. Returns: The list of all indexes. """ return [index["index"] for index in self.client.get_indexes()["results"]]
load_entries(*, namespace=None)
Load all document entries from the Marqo index.
Parameters
Name | Type | Description | Default |
---|---|---|---|
namespace | Optional[str] | The namespace to filter entries by. | None |
Returns
Type | Description |
---|---|
list[Entry] | The list of loaded Entries. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: """Load all document entries from the Marqo index. Args: namespace: The namespace to filter entries by. Returns: The list of loaded Entries. """ filter_string = f"namespace:{namespace}" if namespace else None if filter_string is not None: results = self.client.index(self.index).search("", limit=10000, filter_string=filter_string) else: results = self.client.index(self.index).search("", limit=10000) # get all _id's from search results ids = [r["_id"] for r in results["hits"]] # get documents corresponding to the ids documents = self.client.index(self.index).get_documents(document_ids=ids, expose_facets=True) # for each document, if it's found, create an Entry object entries = [] for doc in documents["results"]: if doc["_found"]: entries.append( BaseVectorStoreDriver.Entry( id=doc["_id"], vector=doc["_tensor_facets"][0]["_embedding"], meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]}, namespace=doc.get("namespace"), ), ) return entries
load_entry(vector_id, *, namespace=None)
Load a document entry from the Marqo index.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector_id | str | The ID of the vector to load. | required |
namespace | Optional[str] | The namespace of the vector to load. | None |
Returns
Type | Description |
---|---|
Optional[Entry] | The loaded Entry if found, otherwise None. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: """Load a document entry from the Marqo index. Args: vector_id: The ID of the vector to load. namespace: The namespace of the vector to load. Returns: The loaded Entry if found, otherwise None. """ result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True) if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0: return BaseVectorStoreDriver.Entry( id=result["_id"], meta={k: v for k, v in result.items() if k != "_id"}, vector=result["_tensor_facets"][0]["_embedding"], ) return None
query(query, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)
Query the Marqo index for documents.
Parameters
Name | Type | Description | Default |
---|---|---|---|
query | str | TextArtifact | ImageArtifact | The query string. | required |
count | Optional[int] | The maximum number of results to return. | None |
namespace | Optional[str] | The namespace to filter results by. | None |
include_vectors | bool | Whether to include vector data in the results. | False |
include_metadata | bool | Whether to include metadata in the results. | True |
kwargs | Any | Additional keyword arguments to pass to the Marqo client. | {} |
Returns
Type | Description |
---|---|
list[Entry] | The list of query results. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def query( self, query: str | TextArtifact | ImageArtifact, *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata: bool = True, **kwargs: Any, ) -> list[BaseVectorStoreDriver.Entry]: """Query the Marqo index for documents. Args: query: The query string. count: The maximum number of results to return. namespace: The namespace to filter results by. include_vectors: Whether to include vector data in the results. include_metadata: Whether to include metadata in the results. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: The list of query results. """ params = { "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, "attributes_to_retrieve": None if include_metadata else ["_id"], "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs results = self.client.index(self.index).search(str(query), **params) return self.__process_results(results, include_vectors=include_vectors)
query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)
Query the Marqo index for documents.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | The vector to query by. | required |
count | Optional[int] | The maximum number of results to return. | None |
namespace | Optional[str] | The namespace to filter results by. | None |
include_vectors | bool | Whether to include vector data in the results. | False |
include_metadata | bool | Whether to include metadata in the results. | True |
kwargs | Any | Additional keyword arguments to pass to the Marqo client. | {} |
Returns
Type | Description |
---|---|
list[Entry] | The list of query results. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, include_metadata: bool = True, **kwargs: Any, ) -> list[BaseVectorStoreDriver.Entry]: """Query the Marqo index for documents. Args: vector: The vector to query by. count: The maximum number of results to return. namespace: The namespace to filter results by. include_vectors: Whether to include vector data in the results. include_metadata: Whether to include metadata in the results. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: The list of query results. """ params = { "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, "attributes_to_retrieve": None if include_metadata else ["_id"], "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs results = self.client.index(self.index).search(**params, context={"tensor": [vector], "weight": 1}) return self.__process_results(results, include_vectors=include_vectors)
upsert(value, *, namespace=None, meta=None, vector_id=None, **kwargs)
Upsert a text document into the Marqo index.
Parameters
Name | Type | Description | Default |
---|---|---|---|
value | str | TextArtifact | ImageArtifact | The value to be indexed. | required |
namespace | Optional[str] | An optional namespace for the document. | None |
meta | Optional[dict] | An optional dictionary of metadata for the document. | None |
vector_id | Optional[str] | The ID for the vector. If None, Marqo will generate an ID. | None |
kwargs | Any | Additional keyword arguments to pass to the Marqo client. | {} |
Returns
Name | Type | Description |
---|---|---|
str | str | The ID of the document that was added. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert( self, value: str | TextArtifact | ImageArtifact, *, namespace: Optional[str] = None, meta: Optional[dict] = None, vector_id: Optional[str] = None, **kwargs: Any, ) -> str: """Upsert a text document into the Marqo index. Args: value: The value to be indexed. namespace: An optional namespace for the document. meta: An optional dictionary of metadata for the document. vector_id: The ID for the vector. If None, Marqo will generate an ID. kwargs: Additional keyword arguments to pass to the Marqo client. Returns: str: The ID of the document that was added. """ if isinstance(value, TextArtifact): artifact_json = value.to_json() vector_id = utils.str_to_hash(value.value) if vector_id is None else vector_id doc = { "_id": vector_id, "Description": value.value, "artifact": str(artifact_json), } elif isinstance(value, ImageArtifact): raise NotImplementedError("`MarqoVectorStoreDriver` does not support upserting Image Artifacts.") else: doc = {"_id": vector_id, "Description": value} # Non-tensor fields if meta: doc["meta"] = str(meta) if namespace: doc["namespace"] = namespace response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"]) if isinstance(response, dict) and "items" in response and response["items"]: return response["items"][0]["_id"] raise ValueError(f"Failed to upsert text: {response}")
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Upsert a vector into the Marqo index.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | The vector to be indexed. | required |
vector_id | Optional[str] | The ID for the vector. If None, Marqo will generate an ID. | None |
namespace | Optional[str] | An optional namespace for the vector. | None |
meta | Optional[dict] | An optional dictionary of metadata for the vector. | None |
kwargs | Any | Additional keyword arguments to pass to the Marqo client. | {} |
Raises
Type | Description |
---|---|
Exception | This function is not yet implemented. |
Returns
Type | Description |
---|---|
str | The ID of the vector that was added. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs: Any, ) -> str: """Upsert a vector into the Marqo index. Args: vector: The vector to be indexed. vector_id: The ID for the vector. If None, Marqo will generate an ID. namespace: An optional namespace for the vector. meta: An optional dictionary of metadata for the vector. kwargs: Additional keyword arguments to pass to the Marqo client. Raises: Exception: This function is not yet implemented. Returns: The ID of the vector that was added. """ raise NotImplementedError(f"{self.__class__.__name__} does not support upserting a vector.")
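A minimal sketch against a hypothetical local Marqo endpoint; adjust `url`, `api_key`, and `index` for your deployment:

```python
from griptape.drivers import MarqoVectorStoreDriver

# A local Marqo instance typically requires no API key.
driver = MarqoVectorStoreDriver(api_key="", url="http://localhost:8882", index="my-index")
doc_id = driver.upsert("Griptape supports Marqo.", namespace="docs")
for entry in driver.query("What does Griptape support?", count=3, namespace="docs"):
    print(entry.id, entry.score)
```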
MongoDbAtlasVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
connection_string | str | The connection string for the MongoDb Atlas cluster. |
database_name | str | The name of the database to use. |
collection_name | str | The name of the collection to use. |
index_name | str | The name of the index to use. |
vector_path | str | The path to the vector field in the collection. |
client | MongoClient | An optional MongoDb client to use. Defaults to a new client using the connection string. |
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@define class MongoDbAtlasVectorStoreDriver(BaseVectorStoreDriver): """A Vector Store Driver for MongoDb Atlas. Attributes: connection_string: The connection string for the MongoDb Atlas cluster. database_name: The name of the database to use. collection_name: The name of the collection to use. index_name: The name of the index to use. vector_path: The path to the vector field in the collection. client: An optional MongoDb client to use. Defaults to a new client using the connection string. """ MAX_NUM_CANDIDATES = 10000 connection_string: str = field(kw_only=True, metadata={"serializable": True}) database_name: str = field(kw_only=True, metadata={"serializable": True}) collection_name: str = field(kw_only=True, metadata={"serializable": True}) index_name: str = field(kw_only=True, metadata={"serializable": True}) vector_path: str = field(kw_only=True, metadata={"serializable": True}) num_candidates_multiplier: int = field( default=10, kw_only=True, metadata={"serializable": True}, ) # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields _client: Optional[MongoClient] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() def client(self) -> MongoClient: return import_optional_dependency("pymongo").MongoClient(self.connection_string) def get_collection(self) -> Collection: """Returns the MongoDB Collection instance for the specified database and collection name.""" return self.client[self.database_name][self.collection_name] def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: """Inserts or updates a vector in the collection. If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. """ collection = self.get_collection() if vector_id is None: result = collection.insert_one({self.vector_path: vector, "namespace": namespace, "meta": meta}) vector_id = str(result.inserted_id) else: collection.replace_one( {"_id": vector_id}, {self.vector_path: vector, "namespace": namespace, "meta": meta}, upsert=True, ) return vector_id def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: """Loads a document entry from the MongoDB collection based on the vector ID. Returns: The loaded Entry if found; otherwise, None is returned. """ collection = self.get_collection() if namespace: doc = collection.find_one({"_id": vector_id, "namespace": namespace}) else: doc = collection.find_one({"_id": vector_id}) if doc is None: return doc return BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"], ) def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: """Loads all document entries from the MongoDB collection. Entries can optionally be filtered by namespace. 
""" collection = self.get_collection() cursor = collection.find() if namespace is None else collection.find({"namespace": namespace}) return [ BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"], ) for doc in cursor ] def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, offset: Optional[int] = None, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: """Queries the MongoDB collection for documents that match the provided vector list. Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. """ collection = self.get_collection() count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT offset = offset or 0 pipeline = [ { "$vectorSearch": { "index": self.index_name, "path": self.vector_path, "queryVector": vector, "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES), "limit": count, }, }, { "$project": { "_id": 1, self.vector_path: 1, "namespace": 1, "meta": 1, "score": {"$meta": "vectorSearchScore"}, }, }, ] if namespace: pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace} return [ BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path] if include_vectors else [], score=doc["score"], meta=doc["meta"], namespace=namespace, ) for doc in collection.aggregate(pipeline) ] def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() collection.delete_one({"_id": vector_id})
MAX_NUM_CANDIDATES = 10000
class-attribute instance-attribute
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
collection_name = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
connection_string = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
database_name = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
index_name = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
num_candidates_multiplier = field(default=10, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
vector_path = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@lazy_property() def client(self) -> MongoClient: return import_optional_dependency("pymongo").MongoClient(self.connection_string)
delete_vector(vector_id)
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None: """Deletes the vector from the collection.""" collection = self.get_collection() collection.delete_one({"_id": vector_id})
get_collection()
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def get_collection(self) -> Collection: """Returns the MongoDB Collection instance for the specified database and collection name.""" return self.client[self.database_name][self.collection_name]
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]: """Loads all document entries from the MongoDB collection. Entries can optionally be filtered by namespace. """ collection = self.get_collection() cursor = collection.find() if namespace is None else collection.find({"namespace": namespace}) return [ BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"], ) for doc in cursor ]
load_entry(vector_id, *, namespace=None)
Loads a document entry from the MongoDB collection based on the vector ID.
Returns
Type | Description |
---|---|
Optional[Entry] | The loaded Entry if found; otherwise, None is returned. |
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: """Loads a document entry from the MongoDB collection based on the vector ID. Returns: The loaded Entry if found; otherwise, None is returned. """ collection = self.get_collection() if namespace: doc = collection.find_one({"_id": vector_id, "namespace": namespace}) else: doc = collection.find_one({"_id": vector_id}) if doc is None: return doc return BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path], namespace=doc["namespace"], meta=doc["meta"], )
query_vector(vector, *, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def query_vector( self, vector: list[float], *, count: Optional[int] = None, namespace: Optional[str] = None, include_vectors: bool = False, offset: Optional[int] = None, **kwargs, ) -> list[BaseVectorStoreDriver.Entry]: """Queries the MongoDB collection for documents that match the provided vector list. Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index. """ collection = self.get_collection() count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT offset = offset or 0 pipeline = [ { "$vectorSearch": { "index": self.index_name, "path": self.vector_path, "queryVector": vector, "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES), "limit": count, }, }, { "$project": { "_id": 1, self.vector_path: 1, "namespace": 1, "meta": 1, "score": {"$meta": "vectorSearchScore"}, }, }, ] if namespace: pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace} return [ BaseVectorStoreDriver.Entry( id=str(doc["_id"]), vector=doc[self.vector_path] if include_vectors else [], score=doc["score"], meta=doc["meta"], namespace=namespace, ) for doc in collection.aggregate(pipeline) ]
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def upsert_vector( self, vector: list[float], *, vector_id: Optional[str] = None, namespace: Optional[str] = None, meta: Optional[dict] = None, **kwargs, ) -> str: """Inserts or updates a vector in the collection. If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted. """ collection = self.get_collection() if vector_id is None: result = collection.insert_one({self.vector_path: vector, "namespace": namespace, "meta": meta}) vector_id = str(result.inserted_id) else: collection.replace_one( {"_id": vector_id}, {self.vector_path: vector, "namespace": namespace, "meta": meta}, upsert=True, ) return vector_id
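A minimal sketch with placeholder connection values; the Atlas vector search index named below must already exist on the collection:

```python
from griptape.drivers import MongoDbAtlasVectorStoreDriver

driver = MongoDbAtlasVectorStoreDriver(
    connection_string="mongodb+srv://user:pass@cluster.example.mongodb.net",  # placeholder
    database_name="griptape",
    collection_name="vectors",
    index_name="vector_index",
    vector_path="embedding",
)
vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs")
matches = driver.query_vector([0.1, 0.2, 0.3], count=5, namespace="docs")
```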
NoOpObservabilityDriver
Bases:
BaseObservabilityDriver
Source Code in griptape/drivers/observability/no_op_observability_driver.py
@define class NoOpObservabilityDriver(BaseObservabilityDriver): def observe(self, call: Observable.Call) -> Any: return call() def get_span_id(self) -> Optional[str]: return None
get_span_id()
Source Code in griptape/drivers/observability/no_op_observability_driver.py
def get_span_id(self) -> Optional[str]: return None
observe(call)
Source Code in griptape/drivers/observability/no_op_observability_driver.py
def observe(self, call: Observable.Call) -> Any: return call()
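A minimal sketch, assuming griptape's `Observability` context manager accepts an `observability_driver`; this driver keeps the observability wiring in place while recording nothing:

```python
from griptape.drivers import NoOpObservabilityDriver
from griptape.observability import Observability

with Observability(observability_driver=NoOpObservabilityDriver()):
    ...  # run structures here; observe() just invokes the call, get_span_id() returns None
```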
OllamaEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | Ollama embedding model name. |
host | Optional[str] | Optional Ollama host. |
client | Client | Ollama Client. |
Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
@define class OllamaEmbeddingDriver(BaseEmbeddingDriver): """Ollama Embedding Driver. Attributes: model: Ollama embedding model name. host: Optional Ollama host. client: Ollama `Client`. """ model: str = field(kw_only=True, metadata={"serializable": True}) host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() def client(self) -> Client: return import_optional_dependency("ollama").Client(host=self.host) def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]: return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
host = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
```python
@lazy_property()
def client(self) -> Client:
    return import_optional_dependency("ollama").Client(host=self.host)
```
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
```python
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    return list(self.client.embeddings(model=self.model, prompt=chunk)["embedding"])
```
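A minimal usage sketch, assuming a local Ollama server; the model name is illustrative and must be an embedding-capable model that has already been pulled:

```python
from griptape.drivers import OllamaEmbeddingDriver

driver = OllamaEmbeddingDriver(
    model="all-minilm",             # illustrative model name
    host="http://localhost:11434",  # Ollama's default local address
)

embedding = driver.try_embed_chunk("Griptape drivers wrap provider APIs.")
print(len(embedding))  # dimensionality depends on the chosen model
```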
OllamaPromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | Model name. |
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
@define
class OllamaPromptDriver(BasePromptDriver):
    """Ollama Prompt Driver.

    Attributes:
        model: Model name.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: SimpleTokenizer(
                characters_per_token=4,
                max_input_tokens=2000,
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    options: dict = field(
        default=Factory(
            lambda self: {
                "temperature": self.temperature,
                "stop": self.tokenizer.stop_sequences,
                "num_predict": self.max_tokens,
            },
            takes_self=True,
        ),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("ollama").Client(host=self.host)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.chat(**params)
        logger.debug(response.model_dump())

        return Message(
            content=self.__to_prompt_stack_message_content(response),
            role=Message.ASSISTANT_ROLE,
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(params)
        stream: Iterator = self.client.chat(**params)

        tool_index = 0
        for chunk in stream:
            logger.debug(chunk)
            message_content = self.__to_prompt_stack_delta_message_content(chunk)
            # Ollama provides multiple Tool calls as separate chunks but with no index to differentiate them.
            # So we must keep track of the index ourselves.
            if isinstance(message_content, ActionCallDeltaMessageContent):
                message_content.index = tool_index
                tool_index += 1
            yield DeltaMessage(content=message_content)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        messages = self._prompt_stack_to_messages(prompt_stack)

        params = {
            "messages": messages,
            "model": self.model,
            "options": self.options,
            **self.extra_params,
        }

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
            params["format"] = prompt_stack.to_output_json_schema()

        # Tool calling is only supported when not streaming
        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_ollama_tools(prompt_stack.tools)

        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        ollama_messages = []
        for message in prompt_stack.messages:
            action_result_contents = message.get_content_type(ActionResultMessageContent)

            # Function calls need to be handled separately from the rest of the message content
            if action_result_contents:
                ollama_messages.extend(
                    [
                        {
                            "role": self.__to_ollama_role(message, action_result_content),
                            "content": self.__to_ollama_message_content(action_result_content),
                        }
                        for action_result_content in action_result_contents
                    ],
                )

                text_contents = message.get_content_type(TextMessageContent)
                if text_contents:
                    ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
            else:
                ollama_message: dict[str, Any] = {
                    "role": self.__to_ollama_role(message),
                    "content": message.to_text(),
                }

                action_call_contents = message.get_content_type(ActionCallMessageContent)
                if action_call_contents:
                    ollama_message["tool_calls"] = [
                        self.__to_ollama_message_content(action_call_content)
                        for action_call_content in action_call_contents
                    ]

                image_contents = message.get_content_type(ImageMessageContent)
                if image_contents:
                    ollama_message["images"] = [
                        self.__to_ollama_message_content(image_content) for image_content in image_contents
                    ]

                ollama_messages.append(ollama_message)

        return ollama_messages

    def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return content.artifact.base64
            # TODO: add support for image urls once resolved https://github.com/ollama/ollama/issues/4474
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {"name": action.to_native_tool_name(), "arguments": action.input},
            }
        if isinstance(content, ActionResultMessageContent):
            return content.artifact.to_text()
        raise ValueError(f"Unsupported content type: {type(content)}")

    def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
        ollama_tools = []

        for tool in tools:
            for activity in tool.activities():
                ollama_tool = {
                    "function": {
                        "name": tool.to_native_tool_name(activity),
                        "description": tool.activity_description(activity),
                    },
                    "type": "function",
                }

                activity_schema = tool.activity_schema(activity)
                if activity_schema is not None:
                    ollama_tool["function"]["parameters"] = tool.to_activity_json_schema(activity, "Parameters Schema")[
                        "properties"
                    ]["values"]

                ollama_tools.append(ollama_tool)
        return ollama_tools

    def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        if message.is_system():
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_prompt_stack_message_content(self, response: ChatResponse) -> list[BaseMessageContent]:
        content = []
        message = response["message"]

        if message.get("content"):
            content.append(TextMessageContent(TextArtifact(response["message"]["content"])))
        if "tool_calls" in message:
            content.extend(
                [
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=tool_call["function"]["name"],
                                name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
                                path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
                                input=tool_call["function"]["arguments"],
                            ),
                        ),
                    )
                    for tool_call in message["tool_calls"]
                ],
            )

        return content

    def __to_prompt_stack_delta_message_content(self, content_delta: ChatResponse) -> BaseDeltaMessageContent:
        message = content_delta["message"]

        if message.get("content"):
            return TextDeltaMessageContent(message["content"])
        if "tool_calls" in message and len(message["tool_calls"]):
            tool_calls = message["tool_calls"]

            # Ollama doesn't _really_ support Tool streaming. They provide the full tool call at once.
            # Multiple, parallel, Tool calls are provided as multiple content deltas.
            # Tracking here: https://github.com/ollama/ollama/issues/7886
            tool_call = tool_calls[0]

            return ActionCallDeltaMessageContent(
                tag=tool_call["function"]["name"],
                name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
                path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
                partial_input=json.dumps(tool_call["function"]["arguments"]),
            )
        return TextDeltaMessageContent("")
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
host = field(default=None, kw_only=True, metadata={'serializable': True})
model = field(kw_only=True, metadata={'serializable': True})
options = field(default=Factory(lambda self: {'temperature': self.temperature, 'stop': self.tokenizer.stop_sequences, 'num_predict': self.max_tokens}, takes_self=True), kw_only=True)
tokenizer = field(default=Factory(lambda self: SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
__to_ollama_message_content(content)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()
    if isinstance(content, ImageMessageContent):
        if isinstance(content.artifact, ImageArtifact):
            return content.artifact.base64
        # TODO: add support for image urls once resolved https://github.com/ollama/ollama/issues/4474
        raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return {
            "type": "function",
            "id": action.tag,
            "function": {"name": action.to_native_tool_name(), "arguments": action.input},
        }
    if isinstance(content, ActionResultMessageContent):
        return content.artifact.to_text()
    raise ValueError(f"Unsupported content type: {type(content)}")
```
__to_ollama_role(message, message_content=None)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    if message.is_system():
        return "system"
    if message.is_assistant():
        return "assistant"
    if isinstance(message_content, ActionResultMessageContent):
        return "tool"
    return "user"
```
__to_ollama_tools(tools)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
    ollama_tools = []

    for tool in tools:
        for activity in tool.activities():
            ollama_tool = {
                "function": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                },
                "type": "function",
            }

            activity_schema = tool.activity_schema(activity)
            if activity_schema is not None:
                ollama_tool["function"]["parameters"] = tool.to_activity_json_schema(activity, "Parameters Schema")[
                    "properties"
                ]["values"]

            ollama_tools.append(ollama_tool)
    return ollama_tools
```
__to_prompt_stack_delta_message_content(content_delta)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def __to_prompt_stack_delta_message_content(self, content_delta: ChatResponse) -> BaseDeltaMessageContent:
    message = content_delta["message"]

    if message.get("content"):
        return TextDeltaMessageContent(message["content"])
    if "tool_calls" in message and len(message["tool_calls"]):
        tool_calls = message["tool_calls"]

        # Ollama doesn't _really_ support Tool streaming. They provide the full tool call at once.
        # Multiple, parallel, Tool calls are provided as multiple content deltas.
        # Tracking here: https://github.com/ollama/ollama/issues/7886
        tool_call = tool_calls[0]

        return ActionCallDeltaMessageContent(
            tag=tool_call["function"]["name"],
            name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
            path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
            partial_input=json.dumps(tool_call["function"]["arguments"]),
        )
    return TextDeltaMessageContent("")
```
__to_prompt_stack_message_content(response)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def __to_prompt_stack_message_content(self, response: ChatResponse) -> list[BaseMessageContent]:
    content = []
    message = response["message"]

    if message.get("content"):
        content.append(TextMessageContent(TextArtifact(response["message"]["content"])))
    if "tool_calls" in message:
        content.extend(
            [
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=tool_call["function"]["name"],
                            name=ToolAction.from_native_tool_name(tool_call["function"]["name"])[0],
                            path=ToolAction.from_native_tool_name(tool_call["function"]["name"])[1],
                            input=tool_call["function"]["arguments"],
                        ),
                    ),
                )
                for tool_call in message["tool_calls"]
            ],
        )

    return content
```
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    messages = self._prompt_stack_to_messages(prompt_stack)

    params = {
        "messages": messages,
        "model": self.model,
        "options": self.options,
        **self.extra_params,
    }

    if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
        params["format"] = prompt_stack.to_output_json_schema()

    # Tool calling is only supported when not streaming
    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_ollama_tools(prompt_stack.tools)

    return params
```
_prompt_stack_to_messages(prompt_stack)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    ollama_messages = []
    for message in prompt_stack.messages:
        action_result_contents = message.get_content_type(ActionResultMessageContent)

        # Function calls need to be handled separately from the rest of the message content
        if action_result_contents:
            ollama_messages.extend(
                [
                    {
                        "role": self.__to_ollama_role(message, action_result_content),
                        "content": self.__to_ollama_message_content(action_result_content),
                    }
                    for action_result_content in action_result_contents
                ],
            )

            text_contents = message.get_content_type(TextMessageContent)
            if text_contents:
                ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
        else:
            ollama_message: dict[str, Any] = {
                "role": self.__to_ollama_role(message),
                "content": message.to_text(),
            }

            action_call_contents = message.get_content_type(ActionCallMessageContent)
            if action_call_contents:
                ollama_message["tool_calls"] = [
                    self.__to_ollama_message_content(action_call_content)
                    for action_call_content in action_call_contents
                ]

            image_contents = message.get_content_type(ImageMessageContent)
            if image_contents:
                ollama_message["images"] = [
                    self.__to_ollama_message_content(image_content) for image_content in image_contents
                ]

            ollama_messages.append(ollama_message)

    return ollama_messages
```
client()
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
@lazy_property()
def client(self) -> Client:
    return import_optional_dependency("ollama").Client(host=self.host)
```
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.chat(**params)
    logger.debug(response.model_dump())

    return Message(
        content=self.__to_prompt_stack_message_content(response),
        role=Message.ASSISTANT_ROLE,
    )
```
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug(params)
    stream: Iterator = self.client.chat(**params)

    tool_index = 0
    for chunk in stream:
        logger.debug(chunk)
        message_content = self.__to_prompt_stack_delta_message_content(chunk)
        # Ollama provides multiple Tool calls as separate chunks but with no index to differentiate them.
        # So we must keep track of the index ourselves.
        if isinstance(message_content, ActionCallDeltaMessageContent):
            message_content.index = tool_index
            tool_index += 1
        yield DeltaMessage(content=message_content)
```
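A minimal usage sketch, assuming a local Ollama server and the framework's usual Agent wiring (the Agent import path is an assumption, not shown in this section):

```python
from griptape.drivers import OllamaPromptDriver
from griptape.structures import Agent  # assumed import path

driver = OllamaPromptDriver(
    model="llama3.1",               # illustrative model name
    host="http://localhost:11434",  # Ollama's default local address
)

agent = Agent(prompt_driver=driver)
agent.run("Summarize what a prompt driver does in one sentence.")
```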
OpenAiAssistantDriver
Bases:
BaseAssistantDriver
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
```python
@define
class OpenAiAssistantDriver(BaseAssistantDriver):
    class EventHandler(AssistantEventHandler):
        @override
        def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
            if delta.value is not None:
                EventBus.publish_event(TextChunkEvent(token=delta.value))

        @override
        def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
            if delta.type == "code_interpreter" and delta.code_interpreter is not None:
                if delta.code_interpreter.input:
                    EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
                if delta.code_interpreter.outputs:
                    EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
                    for output in delta.code_interpreter.outputs:
                        if output.type == "logs" and output.logs:
                            EventBus.publish_event(TextChunkEvent(token=output.logs))

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    thread_id: Optional[str] = field(default=None, kw_only=True)
    assistant_id: str = field(kw_only=True)
    event_handler: AssistantEventHandler = field(
        default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={"serializable": False}
    )
    auto_create_thread: bool = field(default=True, kw_only=True)
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        if self.thread_id is None:
            if self.auto_create_thread:
                thread_id = self.client.beta.threads.create().id
                self.thread_id = thread_id
            else:
                raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
        else:
            thread_id = self.thread_id

        response = self._create_run(thread_id, *args)
        response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})

        return response

    def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
        content = "\n".join(arg.value for arg in args)
        message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content)
        with self.client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.event_handler,
        ) as stream:
            stream.until_done()
            last_messages = stream.get_final_messages()

            message_contents = []
            for message in last_messages:
                message_contents.append(
                    "".join(content.text.value for content in message.content if content.type == "TextContentBlock")
                )
            message_text = "\n".join(message_contents)

            response = TextArtifact(message_text)
            response.meta.update(
                {"assistant_id": self.assistant_id, "thread_id": self.thread_id, "message_id": message_id}
            )

            return response
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
assistant_id = field(kw_only=True)
auto_create_thread = field(default=True, kw_only=True)
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
event_handler = field(default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={'serializable': False})
organization = field(default=None, kw_only=True, metadata={'serializable': True})
thread_id = field(default=None, kw_only=True)
EventHandler
Bases:
AssistantEventHandler
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
```python
class EventHandler(AssistantEventHandler):
    @override
    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        if delta.value is not None:
            EventBus.publish_event(TextChunkEvent(token=delta.value))

    @override
    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        if delta.type == "code_interpreter" and delta.code_interpreter is not None:
            if delta.code_interpreter.input:
                EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
            if delta.code_interpreter.outputs:
                EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
                for output in delta.code_interpreter.outputs:
                    if output.type == "logs" and output.logs:
                        EventBus.publish_event(TextChunkEvent(token=output.logs))
```
_create_run(thread_id, *args)
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
```python
def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
    content = "\n".join(arg.value for arg in args)
    message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content)
    with self.client.beta.threads.runs.stream(
        thread_id=thread_id,
        assistant_id=self.assistant_id,
        event_handler=self.event_handler,
    ) as stream:
        stream.until_done()
        last_messages = stream.get_final_messages()

        message_contents = []
        for message in last_messages:
            message_contents.append(
                "".join(content.text.value for content in message.content if content.type == "TextContentBlock")
            )
        message_text = "\n".join(message_contents)

        response = TextArtifact(message_text)
        response.meta.update(
            {"assistant_id": self.assistant_id, "thread_id": self.thread_id, "message_id": message_id}
        )

        return response
```
client()
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
```python
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(
        base_url=self.base_url,
        api_key=self.api_key,
        organization=self.organization,
    )
```
try_run(*args)
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
```python
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    if self.thread_id is None:
        if self.auto_create_thread:
            thread_id = self.client.beta.threads.create().id
            self.thread_id = thread_id
        else:
            raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
    else:
        thread_id = self.thread_id

    response = self._create_run(thread_id, *args)
    response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})

    return response
```
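A minimal usage sketch; the assistant must already exist in the OpenAI account, and the id below is a placeholder. With `auto_create_thread` left at its default of `True`, a new thread is created on the first run.

```python
from griptape.artifacts import TextArtifact
from griptape.drivers import OpenAiAssistantDriver

driver = OpenAiAssistantDriver(
    assistant_id="asst_placeholder",  # placeholder for an existing Assistant id
)

# try_run accepts artifacts whose values are joined into a single user message.
result = driver.try_run(TextArtifact("What can you help me with?"))
print(result.value)
```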
OpenAiAudioTranscriptionDriver
Bases:
BaseAudioTranscriptionDriver
Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
```python
@define
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        additional_params = {}

        if prompts is not None:
            additional_params["prompt"] = ", ".join(prompts)

        transcription = self.client.audio.transcriptions.create(
            # Even though we're not actually providing a file to the client, the API still requires that we send a file
            # name. We set the file name to use the same format as the audio file so that the API can reject
            # it if the format is unsupported.
            model=self.model,
            file=(f"a.{audio.format}", io.BytesIO(audio.value)),
            response_format="json",
            **additional_params,
        )

        return TextArtifact(value=transcription.text)
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
api_type = field(default=openai.api_type, kw_only=True)
api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
```python
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
```
try_run(audio, prompts=None)
Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
```python
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
    additional_params = {}

    if prompts is not None:
        additional_params["prompt"] = ", ".join(prompts)

    transcription = self.client.audio.transcriptions.create(
        # Even though we're not actually providing a file to the client, the API still requires that we send a file
        # name. We set the file name to use the same format as the audio file so that the API can reject
        # it if the format is unsupported.
        model=self.model,
        file=(f"a.{audio.format}", io.BytesIO(audio.value)),
        response_format="json",
        **additional_params,
    )

    return TextArtifact(value=transcription.text)
```
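A minimal usage sketch; the `AudioArtifact` construction is an assumption for illustration, and `format` must match the real encoding of the audio bytes so the API can validate it:

```python
from pathlib import Path

from griptape.artifacts import AudioArtifact
from griptape.drivers import OpenAiAudioTranscriptionDriver

driver = OpenAiAudioTranscriptionDriver(model="whisper-1")

# The file path is illustrative.
audio = AudioArtifact(Path("speech.wav").read_bytes(), format="wav")
print(driver.try_run(audio).value)
```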
OpenAiChatPromptDriver
Bases:
BasePromptDriver
Attributes
Name | Type | Description |
---|---|---|
base_url | Optional[str] | An optional OpenAi API URL. |
api_key | Optional[str] | An optional OpenAi API key. If not provided, the OPENAI_API_KEY environment variable will be used. |
organization | Optional[str] | An optional OpenAI organization. If not provided, the OPENAI_ORG_ID environment variable will be used. |
client | OpenAI | An openai.OpenAI client. |
model | str | An OpenAI model name. |
tokenizer | BaseTokenizer | An OpenAiTokenizer. |
user | str | A user id. Can be used to track requests by user. |
response_format | Optional[dict] | An optional OpenAi Chat Completion response format. Currently only supports json_object which will enable OpenAi's JSON mode. |
seed | Optional[int] | An optional OpenAi Chat Completion seed. |
ignored_exception_types | tuple[type[Exception], ...] | An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types. |
parallel_tool_calls | bool | A flag to enable parallel tool calls. Defaults to True. |
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
@define
class OpenAiChatPromptDriver(BasePromptDriver):
    """OpenAI Chat Prompt Driver.

    Attributes:
        base_url: An optional OpenAi API URL.
        api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`.
        user: A user id. Can be used to track requests by user.
        response_format: An optional OpenAi Chat Completion response format. Currently only supports `json_object` which will enable OpenAi's JSON mode.
        seed: An optional OpenAi Chat Completion seed.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
        parallel_tool_calls: A flag to enable parallel tool calls. Defaults to `True`.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    response_format: Optional[dict] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
    reasoning_effort: Literal["low", "medium", "high"] = field(
        default="medium", kw_only=True, metadata={"serializable": True}
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )
    parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            ),
        ),
        kw_only=True,
    )
    modalities: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
    audio: dict = field(
        default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    @property
    def is_reasoning_model(self) -> bool:
        return self.model.startswith("o")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        params = self._base_params(prompt_stack)
        logger.debug(params)
        result = self.client.chat.completions.create(**params)

        logger.debug(result.model_dump())
        return self._to_message(result)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        params = self._base_params(prompt_stack)
        logger.debug({"stream": True, **params})
        result = self.client.chat.completions.create(**params, stream=True)

        return self._to_delta_message_stream(result)

    def _to_message(self, result: ChatCompletion) -> Message:
        if len(result.choices) == 1:
            choice_message = result.choices[0].message

            message = Message(
                content=self.__to_prompt_stack_message_content(choice_message),
                role=Message.ASSISTANT_ROLE,
            )

            if result.usage is not None:
                message.usage = Message.Usage(
                    input_tokens=result.usage.prompt_tokens,
                    output_tokens=result.usage.completion_tokens,
                )

            return message
        raise Exception("Completion with more than one choice is not supported yet.")

    def _to_delta_message_stream(self, result: Stream[ChatCompletionChunk]) -> Iterator[DeltaMessage]:
        for message in result:
            if message.usage is not None:
                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=message.usage.prompt_tokens,
                        output_tokens=message.usage.completion_tokens,
                    ),
                )
            if message.choices:
                choice = message.choices[0]
                delta = choice.delta

                content = self.__to_prompt_stack_delta_message_content(delta)
                if content is not None:
                    yield DeltaMessage(content=content)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = {
            "model": self.model,
            **({"user": self.user} if self.user else {}),
            **({"seed": self.seed} if self.seed is not None else {}),
            **({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
            **(
                {"reasoning_effort": self.reasoning_effort}
                if self.is_reasoning_model and self.model != "o1-mini"
                else {}
            ),
            **({"temperature": self.temperature} if not self.is_reasoning_model else {}),
            **({"audio": self.audio} if "audio" in self.modalities else {}),
            **(
                {"stop": self.tokenizer.stop_sequences}
                if not self.is_reasoning_model and self.tokenizer.stop_sequences
                else {}
            ),
            **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
            **({"stream_options": {"include_usage": True}} if self.stream else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice
            params["parallel_tool_calls"] = self.parallel_tool_calls

        if prompt_stack.output_schema is not None:
            if self.structured_output_strategy == "native":
                params["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "Output",
                        "schema": prompt_stack.to_output_json_schema(),
                        "strict": True,
                    },
                }
            elif self.structured_output_strategy == "tool" and self.use_native_tools:
                params["tool_choice"] = "required"

        if self.response_format is not None:
            if self.response_format == {"type": "json_object"}:
                params["response_format"] = self.response_format
                # JSON mode still requires a system message instructing the LLM to output JSON.
                prompt_stack.add_system_message("Provide your response as a valid JSON object.")
            else:
                params["response_format"] = self.response_format

        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_openai_tools(prompt_stack.tools)

        messages = self.__to_openai_messages(prompt_stack.messages)
        params["messages"] = messages

        return params

    def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
        openai_messages = []

        for message in messages:
            # If the message only contains textual content we can send it as a single content.
            if message.has_all_content_type(TextMessageContent):
                openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            # Action results must be sent as separate messages.
            elif action_result_contents := message.get_content_type(ActionResultMessageContent):
                openai_messages.extend(
                    {
                        "role": self.__to_openai_role(message, action_result_content),
                        "content": self.__to_openai_message_content(action_result_content),
                        "tool_call_id": action_result_content.action.tag,
                    }
                    for action_result_content in action_result_contents
                )

                if message.has_any_content_type(TextMessageContent):
                    openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            else:
                openai_message = {
                    "role": self.__to_openai_role(message),
                    "content": [],
                }

                for content in message.content:
                    if isinstance(content, ActionCallMessageContent):
                        if "tool_calls" not in openai_message:
                            openai_message["tool_calls"] = []
                        openai_message["tool_calls"].append(self.__to_openai_message_content(content))
                    elif (
                        isinstance(content, AudioMessageContent)
                        and message.is_assistant()
                        and time.time() < content.artifact.meta["expires_at"]
                    ):
                        # For assistant audio messages, we reference the audio id instead of sending audio message content.
                        openai_message["audio"] = {
                            "id": content.artifact.meta["audio_id"],
                        }
                    else:
                        openai_message["content"].append(self.__to_openai_message_content(content))

                # Some OpenAi-compatible services don't accept an empty array for content
                if not openai_message["content"]:
                    del openai_message["content"]

                openai_messages.append(openai_message)

        return openai_messages

    def __to_openai_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        if message.is_system():
            if self.is_reasoning_model:
                return "developer"
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
        return [
            {
                "function": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
                },
                "type": "function",
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict:
        if isinstance(content, TextMessageContent):
            return {"type": "text", "text": content.artifact.to_text()}
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return {
                    "type": "image_url",
                    "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
                }
            if isinstance(content.artifact, ImageUrlArtifact):
                return {
                    "type": "image_url",
                    "image_url": {"url": content.artifact.value},
                }
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, AudioMessageContent):
            artifact = content.artifact
            metadata = artifact.meta

            # If there's an expiration date, we can assume it's an assistant message.
            if "expires_at" in metadata:
                # If it's expired, we send the transcript instead.
                if time.time() >= metadata["expires_at"]:
                    return {
                        "type": "text",
                        "text": artifact.meta.get("transcript"),
                    }
                # This should never occur, since a non-expired audio content
                # should have already been referenced by the audio id.
                raise ValueError("Assistant audio messages should be sent as audio ids.")
            # If there's no expiration date, we can assume it's a user message where we send the audio every time.
            return {
                "type": "input_audio",
                "input_audio": {
                    "data": base64.b64encode(artifact.value).decode("utf-8"),
                    "format": artifact.format,
                },
            }
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {"name": action.to_native_tool_name(), "arguments": json.dumps(action.input)},
            }
        if isinstance(content, ActionResultMessageContent):
            return content.artifact.to_text()
        return {"type": "text", "text": content.artifact.to_text()}

    def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
        content = []

        if response.content is not None:
            content.append(TextMessageContent(TextArtifact(response.content)))
        if hasattr(response, "audio") and response.audio is not None:
            content.append(
                AudioMessageContent(
                    AudioArtifact(
                        value=base64.b64decode(response.audio.data),
                        format="wav",
                        meta={
                            "audio_id": response.audio.id,
                            "transcript": response.audio.transcript,
                            "expires_at": response.audio.expires_at,
                        },
                    )
                )
            )
        if response.tool_calls is not None:
            content.extend(
                [
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=tool_call.id,
                                name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
                                path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
                                input=json.loads(tool_call.function.arguments),
                            ),
                        ),
                    )
                    for tool_call in response.tool_calls
                ],
            )

        return content

    def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
        if content_delta.content is not None:
            return TextDeltaMessageContent(content_delta.content)
        if content_delta.tool_calls is not None:
            tool_calls = content_delta.tool_calls

            if len(tool_calls) == 1:
                tool_call = tool_calls[0]
                index = tool_call.index

                if tool_call.function is not None:
                    function_name = tool_call.function.name

                    return ActionCallDeltaMessageContent(
                        index=index,
                        tag=tool_call.id,
                        name=ToolAction.from_native_tool_name(function_name)[0] if function_name else None,
                        path=ToolAction.from_native_tool_name(function_name)[1] if function_name else None,
                        partial_input=tool_call.function.arguments,
                    )
                raise ValueError(f"Unsupported tool call delta: {tool_call}")
            raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
        # OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
        if hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
            audio_chunk: dict = getattr(content_delta, "audio")

            return AudioDeltaMessageContent(
                id=audio_chunk.get("id"),
                data=audio_chunk.get("data"),
                expires_at=audio_chunk.get("expires_at"),
                transcript=audio_chunk.get("transcript"),
            )
        return None
```
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
audio = field(default=Factory(lambda: {'voice': 'alloy', 'format': 'pcm16'}), kw_only=True, metadata={'serializable': True})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
ignored_exception_types = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True)
is_reasoning_model (property)
modalities = field(factory=list, kw_only=True, metadata={'serializable': True})
model = field(kw_only=True, metadata={'serializable': True})
organization = field(default=None, kw_only=True, metadata={'serializable': True})
parallel_tool_calls = field(default=True, kw_only=True, metadata={'serializable': True})
reasoning_effort = field(default='medium', kw_only=True, metadata={'serializable': True})
response_format = field(default=None, kw_only=True, metadata={'serializable': True})
seed = field(default=None, kw_only=True, metadata={'serializable': True})
structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True)
tool_choice = field(default='auto', kw_only=True, metadata={'serializable': False})
use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True})
user = field(default='', kw_only=True, metadata={'serializable': True})
__to_openai_message_content(content)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict:
    if isinstance(content, TextMessageContent):
        return {"type": "text", "text": content.artifact.to_text()}
    if isinstance(content, ImageMessageContent):
        if isinstance(content.artifact, ImageArtifact):
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
            }
        if isinstance(content.artifact, ImageUrlArtifact):
            return {
                "type": "image_url",
                "image_url": {"url": content.artifact.value},
            }
        raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
    if isinstance(content, AudioMessageContent):
        artifact = content.artifact
        metadata = artifact.meta

        # If there's an expiration date, we can assume it's an assistant message.
        if "expires_at" in metadata:
            # If it's expired, we send the transcript instead.
            if time.time() >= metadata["expires_at"]:
                return {
                    "type": "text",
                    "text": artifact.meta.get("transcript"),
                }
            # This should never occur, since a non-expired audio content
            # should have already been referenced by the audio id.
            raise ValueError("Assistant audio messages should be sent as audio ids.")
        # If there's no expiration date, we can assume it's a user message where we send the audio every time.
        return {
            "type": "input_audio",
            "input_audio": {
                "data": base64.b64encode(artifact.value).decode("utf-8"),
                "format": artifact.format,
            },
        }
    if isinstance(content, ActionCallMessageContent):
        action = content.artifact.value

        return {
            "type": "function",
            "id": action.tag,
            "function": {"name": action.to_native_tool_name(), "arguments": json.dumps(action.input)},
        }
    if isinstance(content, ActionResultMessageContent):
        return content.artifact.to_text()
    return {"type": "text", "text": content.artifact.to_text()}
```
__to_openai_messages(messages)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
    openai_messages = []

    for message in messages:
        # If the message only contains textual content we can send it as a single content.
        if message.has_all_content_type(TextMessageContent):
            openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
        # Action results must be sent as separate messages.
        elif action_result_contents := message.get_content_type(ActionResultMessageContent):
            openai_messages.extend(
                {
                    "role": self.__to_openai_role(message, action_result_content),
                    "content": self.__to_openai_message_content(action_result_content),
                    "tool_call_id": action_result_content.action.tag,
                }
                for action_result_content in action_result_contents
            )

            if message.has_any_content_type(TextMessageContent):
                openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
        else:
            openai_message = {
                "role": self.__to_openai_role(message),
                "content": [],
            }

            for content in message.content:
                if isinstance(content, ActionCallMessageContent):
                    if "tool_calls" not in openai_message:
                        openai_message["tool_calls"] = []
                    openai_message["tool_calls"].append(self.__to_openai_message_content(content))
                elif (
                    isinstance(content, AudioMessageContent)
                    and message.is_assistant()
                    and time.time() < content.artifact.meta["expires_at"]
                ):
                    # For assistant audio messages, we reference the audio id instead of sending audio message content.
                    openai_message["audio"] = {
                        "id": content.artifact.meta["audio_id"],
                    }
                else:
                    openai_message["content"].append(self.__to_openai_message_content(content))

            # Some OpenAi-compatible services don't accept an empty array for content
            if not openai_message["content"]:
                del openai_message["content"]

            openai_messages.append(openai_message)

    return openai_messages
```
__to_openai_role(message, message_content=None)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_openai_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    if message.is_system():
        if self.is_reasoning_model:
            return "developer"
        return "system"
    if message.is_assistant():
        return "assistant"
    if isinstance(message_content, ActionResultMessageContent):
        return "tool"
    return "user"
```
__to_openai_tools(tools)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
    return [
        {
            "function": {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
            },
            "type": "function",
        }
        for tool in tools
        for activity in tool.activities()
    ]
```
__to_prompt_stack_delta_message_content(content_delta)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
    if content_delta.content is not None:
        return TextDeltaMessageContent(content_delta.content)
    if content_delta.tool_calls is not None:
        tool_calls = content_delta.tool_calls

        if len(tool_calls) == 1:
            tool_call = tool_calls[0]
            index = tool_call.index

            if tool_call.function is not None:
                function_name = tool_call.function.name

                return ActionCallDeltaMessageContent(
                    index=index,
                    tag=tool_call.id,
                    name=ToolAction.from_native_tool_name(function_name)[0] if function_name else None,
                    path=ToolAction.from_native_tool_name(function_name)[1] if function_name else None,
                    partial_input=tool_call.function.arguments,
                )
            raise ValueError(f"Unsupported tool call delta: {tool_call}")
        raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
    # OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
    if hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
        audio_chunk: dict = getattr(content_delta, "audio")

        return AudioDeltaMessageContent(
            id=audio_chunk.get("id"),
            data=audio_chunk.get("data"),
            expires_at=audio_chunk.get("expires_at"),
            transcript=audio_chunk.get("transcript"),
        )
    return None
```
__to_prompt_stack_message_content(response)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
    content = []

    if response.content is not None:
        content.append(TextMessageContent(TextArtifact(response.content)))
    if hasattr(response, "audio") and response.audio is not None:
        content.append(
            AudioMessageContent(
                AudioArtifact(
                    value=base64.b64decode(response.audio.data),
                    format="wav",
                    meta={
                        "audio_id": response.audio.id,
                        "transcript": response.audio.transcript,
                        "expires_at": response.audio.expires_at,
                    },
                )
            )
        )
    if response.tool_calls is not None:
        content.extend(
            [
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=tool_call.id,
                            name=ToolAction.from_native_tool_name(tool_call.function.name)[0],
                            path=ToolAction.from_native_tool_name(tool_call.function.name)[1],
                            input=json.loads(tool_call.function.arguments),
                        ),
                    ),
                )
                for tool_call in response.tool_calls
            ],
        )

    return content
```
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    params = {
        "model": self.model,
        **({"user": self.user} if self.user else {}),
        **({"seed": self.seed} if self.seed is not None else {}),
        **({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
        **(
            {"reasoning_effort": self.reasoning_effort}
            if self.is_reasoning_model and self.model != "o1-mini"
            else {}
        ),
        **({"temperature": self.temperature} if not self.is_reasoning_model else {}),
        **({"audio": self.audio} if "audio" in self.modalities else {}),
        **(
            {"stop": self.tokenizer.stop_sequences}
            if not self.is_reasoning_model and self.tokenizer.stop_sequences
            else {}
        ),
        **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
        **({"stream_options": {"include_usage": True}} if self.stream else {}),
        **self.extra_params,
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_choice"] = self.tool_choice
        params["parallel_tool_calls"] = self.parallel_tool_calls

    if prompt_stack.output_schema is not None:
        if self.structured_output_strategy == "native":
            params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "Output",
                    "schema": prompt_stack.to_output_json_schema(),
                    "strict": True,
                },
            }
        elif self.structured_output_strategy == "tool" and self.use_native_tools:
            params["tool_choice"] = "required"

    if self.response_format is not None:
        if self.response_format == {"type": "json_object"}:
            params["response_format"] = self.response_format
            # JSON mode still requires a system message instructing the LLM to output JSON.
            prompt_stack.add_system_message("Provide your response as a valid JSON object.")
        else:
            params["response_format"] = self.response_format

    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_openai_tools(prompt_stack.tools)

    messages = self.__to_openai_messages(prompt_stack.messages)
    params["messages"] = messages

    return params
```
_to_delta_message_stream(result)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def _to_delta_message_stream(self, result: Stream[ChatCompletionChunk]) -> Iterator[DeltaMessage]:
    for message in result:
        if message.usage is not None:
            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=message.usage.prompt_tokens,
                    output_tokens=message.usage.completion_tokens,
                ),
            )
        if message.choices:
            choice = message.choices[0]
            delta = choice.delta

            content = self.__to_prompt_stack_delta_message_content(delta)
            if content is not None:
                yield DeltaMessage(content=content)
```
_to_message(result)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
def _to_message(self, result: ChatCompletion) -> Message:
    if len(result.choices) == 1:
        choice_message = result.choices[0].message

        message = Message(
            content=self.__to_prompt_stack_message_content(choice_message),
            role=Message.ASSISTANT_ROLE,
        )

        if result.usage is not None:
            message.usage = Message.Usage(
                input_tokens=result.usage.prompt_tokens,
                output_tokens=result.usage.completion_tokens,
            )

        return message
    raise Exception("Completion with more than one choice is not supported yet.")
```
client()
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(
        base_url=self.base_url,
        api_key=self.api_key,
        organization=self.organization,
    )
```
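Because the client is built from `base_url`, `api_key`, and `organization`, the driver can also target OpenAI-compatible servers. A hedged sketch; the endpoint URL and key below are placeholders:

```python
from griptape.drivers import OpenAiChatPromptDriver

# Point the driver at a hypothetical OpenAI-compatible endpoint.
driver = OpenAiChatPromptDriver(
    model="gpt-4o",
    base_url="http://localhost:8000/v1",  # illustrative endpoint
    api_key="not-a-real-key",             # placeholder
)
```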
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    params = self._base_params(prompt_stack)
    logger.debug(params)
    result = self.client.chat.completions.create(**params)

    logger.debug(result.model_dump())
    return self._to_message(result)
```
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    params = self._base_params(prompt_stack)
    logger.debug({"stream": True, **params})
    result = self.client.chat.completions.create(**params, stream=True)

    return self._to_delta_message_stream(result)
```
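A minimal end-to-end sketch of the two entry points; the `PromptStack` construction assumes the framework's common module and its `add_user_message` helper:

```python
from griptape.common import PromptStack
from griptape.drivers import OpenAiChatPromptDriver

driver = OpenAiChatPromptDriver(model="gpt-4o", seed=42)

prompt_stack = PromptStack()
prompt_stack.add_user_message("Name three uses for a vector store.")

# Blocking call.
print(driver.try_run(prompt_stack).to_text())

# Streaming call: deltas arrive incrementally.
for delta in driver.try_stream(prompt_stack):
    if delta.content is not None:
        print(delta.content)  # TextDeltaMessageContent chunks, plus usage on the final delta
```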
OpenAiEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | OpenAI embedding model name. Defaults to text-embedding-3-small. |
base_url | Optional[str] | API URL. Defaults to OpenAI's v1 API URL. |
api_key | Optional[str] | API key to pass directly. Defaults to OPENAI_API_KEY environment variable. |
organization | Optional[str] | OpenAI organization. Defaults to OPENAI_ORGANIZATION environment variable. |
tokenizer | OpenAiTokenizer | Optionally provide custom OpenAiTokenizer. |
client | OpenAI | Optionally provide custom openai.OpenAI client. |
azure_deployment | Optional[str] | An Azure OpenAi deployment id. |
azure_endpoint | Optional[str] | An Azure OpenAi endpoint. |
azure_ad_token | Optional[str] | An optional Azure Active Directory token. |
azure_ad_token_provider | Optional[Callable[[], str]] | An optional Azure Active Directory token provider. |
api_version | Optional[str] | An Azure OpenAi API version. |
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`.
        client: Optionally provide custom `openai.OpenAI` client.
        azure_deployment: An Azure OpenAi deployment id.
        azure_endpoint: An Azure OpenAi endpoint.
        azure_ad_token: An optional Azure Active Directory token.
        azure_ad_token_provider: An optional Azure Active Directory token provider.
        api_version: An Azure OpenAi API version.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        return {"input": chunk, "model": self.model}
```
DEFAULT_MODEL = 'text-embedding-3-small'
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
organization = field(default=None, kw_only=True, metadata={'serializable': True})
tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True)
_params(chunk)
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
def _params(self, chunk: str) -> dict:
    return {"input": chunk, "model": self.model}
```
client()
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
```python
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
```
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    # Address a performance issue in older ada models
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    if self.model.endswith("001"):
        chunk = chunk.replace("\n", " ")
    return self.client.embeddings.create(**self._params(chunk)).data[0].embedding
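Example
A minimal usage sketch (illustrative, not from the source; the driver falls back to the OPENAI_API_KEY environment variable when api_key is omitted):

from griptape.drivers import OpenAiEmbeddingDriver

# Uses the default text-embedding-3-small model.
driver = OpenAiEmbeddingDriver()

# try_embed_chunk returns the raw embedding vector for a single chunk of text.
vector = driver.try_embed_chunk("Hello, world!")
print(len(vector))  # embedding dimensionality, e.g. 1536 for text-embedding-3-small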
OpenAiImageGenerationDriver
Bases:
BaseImageGenerationDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | OpenAI model, for example 'dall-e-2' or 'dall-e-3'. |
api_type | Optional[str] | OpenAI API type, for example 'open_ai' or 'azure'. |
api_version | Optional[str] | API version. |
base_url | Optional[str] | API URL. |
api_key | Optional[str] | OpenAI API key. |
organization | Optional[str] | OpenAI organization ID. |
style | Optional[Literal['vivid', 'natural']] | Optional and only supported for dall-e-3, can be either 'vivid' or 'natural'. |
quality | Optional[Literal['standard', 'hd', 'low', 'medium', 'high', 'auto']] | Optional. For dall-e-3, accepts 'standard' or 'hd'; for gpt-image-1, accepts 'low', 'medium', 'high', or 'auto'. |
image_size | Optional[Literal['256x256', '512x512', '1024x1024', '1024x1792', '1792x1024']] | Size of the generated image. Must be one of the following, depending on the requested model: dall-e-2: [256x256, 512x512, 1024x1024] dall-e-3: [1024x1024, 1024x1792, 1792x1024] gpt-image-1: [1024x1024, 1536x1024, 1024x1536, auto] |
response_format | Literal['b64_json'] | The response format. Currently only supports 'b64_json' which will return a base64 encoded image in a JSON object. |
background | Optional[Literal['transparent', 'opaque', 'auto']] | Optional and only supported for gpt-image-1. Can be either 'transparent', 'opaque', or 'auto'. |
moderation | Optional[Literal['low', 'auto']] | Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'. |
output_compression | Optional[int] | Optional and only supported for gpt-image-1. Can be an integer between 0 and 100. |
output_format | Optional[Literal['png', 'jpeg']] | Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'. |
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@define
class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
    """Driver for the OpenAI image generation API.

    Attributes:
        model: OpenAI model, for example 'dall-e-2' or 'dall-e-3'.
        api_type: OpenAI API type, for example 'open_ai' or 'azure'.
        api_version: API version.
        base_url: API URL.
        api_key: OpenAI API key.
        organization: OpenAI organization ID.
        style: Optional and only supported for dall-e-3, can be either 'vivid' or 'natural'.
        quality: Optional. For dall-e-3, accepts 'standard' or 'hd'; for gpt-image-1, accepts 'low', 'medium', 'high', or 'auto'.
        image_size: Size of the generated image. Must be one of the following, depending on the requested model:
            dall-e-2: [256x256, 512x512, 1024x1024]
            dall-e-3: [1024x1024, 1024x1792, 1792x1024]
            gpt-image-1: [1024x1024, 1536x1024, 1024x1536, auto]
        response_format: The response format. Currently only supports 'b64_json' which will return
            a base64 encoded image in a JSON object.
        background: Optional and only supported for gpt-image-1. Can be either 'transparent', 'opaque', or 'auto'.
        moderation: Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'.
        output_compression: Optional and only supported for gpt-image-1. Can be an integer between 0 and 100.
        output_format: Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'.
    """

    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    style: Optional[Literal["vivid", "natural"]] = field(
        default=None, kw_only=True, metadata={"serializable": True, "model_allowlist": ["dall-e-3"]}
    )
    quality: Optional[Literal["standard", "hd", "low", "medium", "high", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    image_size: Optional[Literal["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    response_format: Literal["b64_json"] = field(
        default="b64_json",
        kw_only=True,
        metadata={"serializable": True, "model_denylist": ["gpt-image-1"]},
    )
    background: Optional[Literal["transparent", "opaque", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    moderation: Optional[Literal["low", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    output_compression: Optional[int] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    output_format: Optional[Literal["png", "jpeg"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @image_size.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_image_size(self, attribute: str, value: str | None) -> None:
        """Validates the image size based on the model.

        Must be one of `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait), or `auto` (default value)
        for `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and one of `1024x1024`,
        `1792x1024`, or `1024x1792` for `dall-e-3`.
        """
        if value is None:
            return

        if self.model.startswith("gpt-image"):
            allowed_sizes = ("1024x1024", "1536x1024", "1024x1536", "auto")
        elif self.model == "dall-e-2":
            allowed_sizes = ("256x256", "512x512", "1024x1024")
        elif self.model == "dall-e-3":
            allowed_sizes = ("1024x1024", "1792x1024", "1024x1792")
        else:
            raise NotImplementedError(f"Image size validation not implemented for model {self.model}")

        if value is not None and value not in allowed_sizes:
            raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        prompt = ", ".join(prompts)

        response = self.client.images.generate(
            model=self.model,
            prompt=prompt,
            n=1,
            **self._build_model_params(
                {
                    "size": "image_size",
                    "quality": "quality",
                    "style": "style",
                    "response_format": "response_format",
                    "background": "background",
                    "moderation": "moderation",
                    "output_compression": "output_compression",
                    "output_format": "output_format",
                }
            ),
        )

        return self._parse_image_response(response, prompt)

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Creates a variation of an image.

        Only supported for dall-e-2. Requires image size to be one of the following:
        [256x256, 512x512, 1024x1024]
        """
        if self.model != "dall-e-2":
            raise NotImplementedError("Image variation only supports dall-e-2")

        response = self.client.images.create_variation(
            image=image.value,
            n=1,
            response_format=self.response_format,
            size=self.image_size,  # pyright: ignore[reportArgumentType]
        )

        return self._parse_image_response(response, "")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        prompt = ", ".join(prompts)
        response = self.client.images.edit(
            prompt=prompt,
            image=image.value,
            mask=mask.value,
            **self._build_model_params(
                {
                    "size": "image_size",
                    "response_format": "response_format",
                }
            ),
        )

        return self._parse_image_response(response, prompt)

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

    def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
        from griptape.loaders.image_loader import ImageLoader

        if response.data is None or response.data[0] is None or response.data[0].b64_json is None:
            raise Exception("Failed to generate image")

        image_data = base64.b64decode(response.data[0].b64_json)
        image_artifact = ImageLoader().parse(image_data)

        image_artifact.meta["prompt"] = prompt
        image_artifact.meta["model"] = self.model

        return image_artifact

    def _build_model_params(self, values: dict) -> dict:
        """Builds parameters while considering field metadata and None values.

        Args:
            values: A dictionary mapping parameter names to field names.

        Field will be added to the params dictionary if all conditions are met:
            - The field value is not None
            - The model_allowlist is None or the model is in the allowlist
            - The model_denylist is None or the model is not in the denylist
        """
        params = {}

        fields = fields_dict(self.__class__)
        for param_name, field_name in values.items():
            metadata = fields[field_name].metadata
            model_allowlist = metadata.get("model_allowlist")
            model_denylist = metadata.get("model_denylist")

            field_value = getattr(self, field_name, None)

            allowlist_condition = model_allowlist is None or self.model in model_allowlist
            denylist_condition = model_denylist is None or self.model not in model_denylist

            if field_value is not None and allowlist_condition and denylist_condition:
                params[param_name] = field_value
        return params
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': False})
api_type = field(default=openai.api_type, kw_only=True)
api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True})
background = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
image_size = field(default=None, kw_only=True, metadata={'serializable': True})
moderation = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']})
organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True})
output_compression = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']})
output_format = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']})
quality = field(default=None, kw_only=True, metadata={'serializable': True})
response_format = field(default='b64_json', kw_only=True, metadata={'serializable': True, 'model_denylist': ['gpt-image-1']})
style = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['dall-e-3']})
_build_model_params(values)
Builds parameters while considering field metadata and None values.
Parameters
Name | Type | Description | Default |
---|---|---|---|
values | dict | A dictionary mapping parameter names to field names. | required |
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def _build_model_params(self, values: dict) -> dict:
    """Builds parameters while considering field metadata and None values.

    Args:
        values: A dictionary mapping parameter names to field names.

    Field will be added to the params dictionary if all conditions are met:
        - The field value is not None
        - The model_allowlist is None or the model is in the allowlist
        - The model_denylist is None or the model is not in the denylist
    """
    params = {}

    fields = fields_dict(self.__class__)
    for param_name, field_name in values.items():
        metadata = fields[field_name].metadata
        model_allowlist = metadata.get("model_allowlist")
        model_denylist = metadata.get("model_denylist")

        field_value = getattr(self, field_name, None)

        allowlist_condition = model_allowlist is None or self.model in model_allowlist
        denylist_condition = model_denylist is None or self.model not in model_denylist

        if field_value is not None and allowlist_condition and denylist_condition:
            params[param_name] = field_value
    return params
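Example
An illustrative sketch of the filtering behavior (hypothetical values; _build_model_params is private and normally invoked internally by the try_* methods):

driver = OpenAiImageGenerationDriver(
    model="gpt-image-1",
    image_size="1024x1024",
    style="vivid",  # allowlisted to dall-e-3 only, so it is filtered out below
)

params = driver._build_model_params(
    {"size": "image_size", "style": "style", "response_format": "response_format"}
)
# params == {"size": "1024x1024"}:
#   - style is dropped because gpt-image-1 is not in its model_allowlist
#   - response_format is dropped because gpt-image-1 is in its model_denylist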
_parse_image_response(response, prompt)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
    from griptape.loaders.image_loader import ImageLoader

    if response.data is None or response.data[0] is None or response.data[0].b64_json is None:
        raise Exception("Failed to generate image")

    image_data = base64.b64decode(response.data[0].b64_json)
    image_artifact = ImageLoader().parse(image_data)

    image_artifact.meta["prompt"] = prompt
    image_artifact.meta["model"] = self.model

    return image_artifact
client()
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
try_image_inpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    prompt = ", ".join(prompts)
    response = self.client.images.edit(
        prompt=prompt,
        image=image.value,
        mask=mask.value,
        **self._build_model_params(
            {
                "size": "image_size",
                "response_format": "response_format",
            }
        ),
    )

    return self._parse_image_response(response, prompt)
try_image_outpainting(prompts, image, mask, negative_prompts=None)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")
try_image_variation(prompts, image, negative_prompts=None)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Creates a variation of an image.

    Only supported for dall-e-2. Requires image size to be one of the following:
    [256x256, 512x512, 1024x1024]
    """
    if self.model != "dall-e-2":
        raise NotImplementedError("Image variation only supports dall-e-2")

    response = self.client.images.create_variation(
        image=image.value,
        n=1,
        response_format=self.response_format,
        size=self.image_size,  # pyright: ignore[reportArgumentType]
    )

    return self._parse_image_response(response, "")
try_text_to_image(prompts, negative_prompts=None)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    prompt = ", ".join(prompts)

    response = self.client.images.generate(
        model=self.model,
        prompt=prompt,
        n=1,
        **self._build_model_params(
            {
                "size": "image_size",
                "quality": "quality",
                "style": "style",
                "response_format": "response_format",
                "background": "background",
                "moderation": "moderation",
                "output_compression": "output_compression",
                "output_format": "output_format",
            }
        ),
    )

    return self._parse_image_response(response, prompt)
validate_image_size(attribute, value)
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@image_size.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_image_size(self, attribute: str, value: str | None) -> None:
    """Validates the image size based on the model.

    Must be one of `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait), or `auto` (default value)
    for `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and one of `1024x1024`,
    `1792x1024`, or `1024x1792` for `dall-e-3`.
    """
    if value is None:
        return

    if self.model.startswith("gpt-image"):
        allowed_sizes = ("1024x1024", "1536x1024", "1024x1536", "auto")
    elif self.model == "dall-e-2":
        allowed_sizes = ("256x256", "512x512", "1024x1024")
    elif self.model == "dall-e-3":
        allowed_sizes = ("1024x1024", "1792x1024", "1024x1792")
    else:
        raise NotImplementedError(f"Image size validation not implemented for model {self.model}")

    if value is not None and value not in allowed_sizes:
        raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")
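Example
A minimal text-to-image sketch (hypothetical prompt; api_key falls back to the OPENAI_API_KEY environment variable):

from griptape.drivers import OpenAiImageGenerationDriver

driver = OpenAiImageGenerationDriver(
    model="dall-e-3",
    image_size="1024x1024",  # validated against the dall-e-3 size list
    quality="hd",
)

image_artifact = driver.try_text_to_image(["a lighthouse at dawn"])
print(image_artifact.meta["model"], image_artifact.meta["prompt"])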
OpenAiTextToSpeechDriver
Bases:
BaseTextToSpeechDriver
Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
@define
class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
    model: str = field(default="tts-1", kw_only=True, metadata={"serializable": True})
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = field(
        default="alloy",
        kw_only=True,
        metadata={"serializable": True},
    )
    format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True)
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            organization=self.organization,
        )

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        response = self.client.audio.speech.create(
            input=". ".join(prompts),
            voice=self.voice,
            model=self.model,
            response_format=self.format,
        )

        return AudioArtifact(value=response.content, format=self.format)
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True)
api_type = field(default=openai.api_type, kw_only=True)
api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True})
base_url = field(default=None, kw_only=True, metadata={'serializable': True})
format = field(default='mp3', kw_only=True, metadata={'serializable': True})
model = field(default='tts-1', kw_only=True, metadata={'serializable': True})
organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True})
voice = field(default='alloy', kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    return openai.OpenAI(
        api_key=self.api_key,
        base_url=self.base_url,
        organization=self.organization,
    )
try_text_to_audio(prompts)
Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    response = self.client.audio.speech.create(
        input=". ".join(prompts),
        voice=self.voice,
        model=self.model,
        response_format=self.format,
    )

    return AudioArtifact(value=response.content, format=self.format)
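Example
A minimal text-to-speech sketch (hypothetical file name; prompts are joined with ". " before being sent to the speech endpoint):

from griptape.drivers import OpenAiTextToSpeechDriver

driver = OpenAiTextToSpeechDriver(voice="nova", format="mp3")

audio_artifact = driver.try_text_to_audio(["Hello from Griptape"])

# AudioArtifact.value holds the raw audio bytes returned by the API.
with open("speech.mp3", "wb") as f:
    f.write(audio_artifact.value)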
OpenSearchVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
host | str | The host of the OpenSearch cluster. |
port | int | The port of the OpenSearch cluster. |
http_auth | str | tuple[str, Optional[str]] | The HTTP authentication credentials to use. |
use_ssl | bool | Whether to use SSL. |
verify_certs | bool | Whether to verify SSL certificates. |
index_name | str | The name of the index to use. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
@define
class OpenSearchVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for OpenSearch.

    Attributes:
        host: The host of the OpenSearch cluster.
        port: The port of the OpenSearch cluster.
        http_auth: The HTTP authentication credentials to use.
        use_ssl: Whether to use SSL.
        verify_certs: Whether to verify SSL certificates.
        index_name: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    port: int = field(default=443, kw_only=True, metadata={"serializable": True})
    http_auth: str | tuple[str, Optional[str]] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    verify_certs: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[OpenSearch] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> OpenSearch:
        opensearchpy = import_optional_dependency("opensearchpy")

        return opensearchpy.OpenSearch(
            hosts=[{"host": self.host, "port": self.port}],
            http_auth=self.http_auth,
            use_ssl=self.use_ssl,
            verify_certs=self.verify_certs,
            connection_class=opensearchpy.RequestsHttpConnection,
        )

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in OpenSearch.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """
        vector_id = vector_id or utils.str_to_hash(str(vector))
        doc = {"vector": vector, "namespace": namespace, "metadata": meta}
        doc.update(kwargs)
        response = self.client.index(index=self.index_name, id=vector_id, body=doc)

        return response["_id"]

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

        Returns:
            If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
        """
        try:
            query = {"bool": {"must": [{"term": {"_id": vector_id}}]}}

            if namespace:
                query["bool"]["must"].append({"term": {"namespace": namespace}})

            response = self.client.search(index=self.index_name, body={"query": query, "size": 1})

            if response["hits"]["total"]["value"] > 0:
                vector_data = response["hits"]["hits"][0]["_source"]
                return BaseVectorStoreDriver.Entry(
                    id=vector_id,
                    meta=vector_data.get("metadata"),
                    vector=vector_data.get("vector"),
                    namespace=vector_data.get("namespace"),
                )
            return None
        except Exception as e:
            logging.exception("Error while loading entry: %s", e)
            return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from OpenSearch that match the optional namespace.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects.
        """
        query_body = {"size": 10000, "query": {"match_all": {}}}

        if namespace:
            query_body["query"] = {"match": {"namespace": namespace}}

        response = self.client.search(index=self.index_name, body=query_body)

        return [
            BaseVectorStoreDriver.Entry(
                id=hit["_id"],
                vector=hit["_source"].get("vector"),
                meta=hit["_source"].get("metadata"),
                namespace=hit["_source"].get("namespace"),
            )
            for hit in response["hits"]["hits"]
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        field_name: str = "vector",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.

        Results can be limited using the count parameter and optionally filtered by a namespace.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity
            score, metadata, and namespace.
        """
        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        # Base k-NN query
        query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

        if namespace:
            query_body["query"] = {
                "bool": {
                    "must": [
                        {"match": {"namespace": namespace}},
                        {"knn": {field_name: {"vector": vector, "k": count}}},
                    ],
                },
            }

        response = self.client.search(index=self.index_name, body=query_body)

        return [
            BaseVectorStoreDriver.Entry(
                id=hit["_id"],
                namespace=hit["_source"].get("namespace") if namespace else None,
                score=hit["_score"],
                vector=hit["_source"].get("vector") if include_vectors else None,
                meta=hit["_source"].get("metadata") if include_metadata else None,
            )
            for hit in response["hits"]["hits"]
        ]

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
host = field(kw_only=True, metadata={'serializable': True})
http_auth = field(default=None, kw_only=True, metadata={'serializable': True})
index_name = field(kw_only=True, metadata={'serializable': True})
port = field(default=443, kw_only=True, metadata={'serializable': True})
use_ssl = field(default=True, kw_only=True, metadata={'serializable': True})
verify_certs = field(default=True, kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
@lazy_property()
def client(self) -> OpenSearch:
    opensearchpy = import_optional_dependency("opensearchpy")

    return opensearchpy.OpenSearch(
        hosts=[{"host": self.host, "port": self.port}],
        http_auth=self.http_auth,
        use_ssl=self.use_ssl,
        verify_certs=self.verify_certs,
        connection_class=opensearchpy.RequestsHttpConnection,
    )
delete_vector(vector_id)
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
load_entries(*, namespace=None)
Retrieves all vector entries from OpenSearch that match the optional namespace.
Returns
Type | Description |
---|---|
list[Entry] | A list of BaseVectorStoreDriver.Entry objects. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from OpenSearch that match the optional namespace.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects.
    """
    query_body = {"size": 10000, "query": {"match_all": {}}}

    if namespace:
        query_body["query"] = {"match": {"namespace": namespace}}

    response = self.client.search(index=self.index_name, body=query_body)

    return [
        BaseVectorStoreDriver.Entry(
            id=hit["_id"],
            vector=hit["_source"].get("vector"),
            meta=hit["_source"].get("metadata"),
            namespace=hit["_source"].get("namespace"),
        )
        for hit in response["hits"]["hits"]
    ]
load_entry(vector_id, *, namespace=None)
Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.
Returns
Type | Description |
---|---|
Optional[Entry] | If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

    Returns:
        If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
    """
    try:
        query = {"bool": {"must": [{"term": {"_id": vector_id}}]}}

        if namespace:
            query["bool"]["must"].append({"term": {"namespace": namespace}})

        response = self.client.search(index=self.index_name, body={"query": query, "size": 1})

        if response["hits"]["total"]["value"] > 0:
            vector_data = response["hits"]["hits"][0]["_source"]
            return BaseVectorStoreDriver.Entry(
                id=vector_id,
                meta=vector_data.get("metadata"),
                vector=vector_data.get("vector"),
                namespace=vector_data.get("namespace"),
            )
        return None
    except Exception as e:
        logging.exception("Error while loading entry: %s", e)
        return None
query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, field_name='vector', **kwargs)
Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.
Results can be limited using the count parameter and optionally filtered by a namespace.
Returns
Type | Description |
---|---|
list[Entry] | A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    field_name: str = "vector",
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.

    Results can be limited using the count parameter and optionally filtered by a namespace.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity
        score, metadata, and namespace.
    """
    count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    # Base k-NN query
    query_body = {"size": count, "query": {"knn": {field_name: {"vector": vector, "k": count}}}}

    if namespace:
        query_body["query"] = {
            "bool": {
                "must": [
                    {"match": {"namespace": namespace}},
                    {"knn": {field_name: {"vector": vector, "k": count}}},
                ],
            },
        }

    response = self.client.search(index=self.index_name, body=query_body)

    return [
        BaseVectorStoreDriver.Entry(
            id=hit["_id"],
            namespace=hit["_source"].get("namespace") if namespace else None,
            score=hit["_score"],
            vector=hit["_source"].get("vector") if include_vectors else None,
            meta=hit["_source"].get("metadata") if include_metadata else None,
        )
        for hit in response["hits"]["hits"]
    ]
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in OpenSearch.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """
    vector_id = vector_id or utils.str_to_hash(str(vector))
    doc = {"vector": vector, "namespace": namespace, "metadata": meta}
    doc.update(kwargs)
    response = self.client.index(index=self.index_name, id=vector_id, body=doc)

    return response["_id"]
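Example
A minimal upsert-and-query sketch (hypothetical endpoint, credentials, and index name):

from griptape.drivers import OpenSearchVectorStoreDriver

driver = OpenSearchVectorStoreDriver(
    host="search-example.us-east-1.es.amazonaws.com",
    http_auth=("admin", "password"),
    index_name="griptape-vectors",
)

vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"title": "example"})

# k-NN query against the "vector" field; returns Entry objects with scores.
for entry in driver.query_vector([0.1, 0.2, 0.3], count=5, namespace="docs"):
    print(entry.id, entry.score)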
OpenTelemetryObservabilityDriver
Bases:
BaseObservabilityDriver
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
@define
class OpenTelemetryObservabilityDriver(BaseObservabilityDriver):
    service_name: str = field(default="griptape", kw_only=True)
    span_processor: SpanProcessor = field(kw_only=True)
    service_version: Optional[str] = field(default=None, kw_only=True)
    deployment_env: Optional[str] = field(default=None, kw_only=True)
    trace_provider: TracerProvider = field(
        default=Factory(
            lambda self: self._trace_provider_factory(),
            takes_self=True,
        ),
        kw_only=True,
    )
    _tracer: Tracer = field(init=False)
    _root_span_context_manager: Any = None

    def _trace_provider_factory(self) -> TracerProvider:
        opentelemetry_trace = import_optional_dependency("opentelemetry.sdk.trace")

        attributes = {"service.name": self.service_name}
        if self.service_version is not None:
            attributes["service.version"] = self.service_version
        if self.deployment_env is not None:
            attributes["deployment.environment"] = self.deployment_env

        return opentelemetry_trace.TracerProvider(
            resource=import_optional_dependency("opentelemetry.sdk.resources").Resource(attributes=attributes)
        )  # pyright: ignore[reportArgumentType]

    def __attrs_post_init__(self) -> None:
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")

        self.trace_provider.add_span_processor(self.span_processor)
        self._tracer = opentelemetry_trace.get_tracer(self.service_name, tracer_provider=self.trace_provider)

    def __enter__(self) -> None:
        opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")

        opentelemetry_instrumentation_threading.ThreadingInstrumentor().instrument()
        self._root_span_context_manager = self._tracer.start_as_current_span("main")  # pyright: ignore[reportCallIssue]
        self._root_span_context_manager.__enter__()

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> bool:
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")

        root_span = opentelemetry_trace.get_current_span()
        if exc_value:
            root_span = opentelemetry_trace.get_current_span()
            root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.ERROR))
            root_span.record_exception(exc_value)
        else:
            root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.OK))

        if self._root_span_context_manager:
            self._root_span_context_manager.__exit__(exc_type, exc_value, exc_traceback)
            self._root_span_context_manager = None

        self.trace_provider.force_flush()
        opentelemetry_instrumentation_threading.ThreadingInstrumentor().uninstrument()

        return False

    def observe(self, call: Observable.Call) -> Any:
        open_telemetry_trace = import_optional_dependency("opentelemetry.trace")

        func = call.func
        instance = call.instance
        tags = call.tags

        class_name = f"{instance.__class__.__name__}." if instance else ""
        span_name = f"{class_name}{func.__name__}()"
        with self._tracer.start_as_current_span(span_name) as span:  # pyright: ignore[reportCallIssue]
            if tags is not None:
                span.set_attribute("tags", tags)

            try:
                result = call()
                span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.OK))
                return result
            except Exception as e:
                span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.ERROR))
                span.record_exception(e)
                raise e

    def get_span_id(self) -> Optional[str]:
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")

        span = opentelemetry_trace.get_current_span()
        if span is opentelemetry_trace.INVALID_SPAN:
            return None
        return opentelemetry_trace.format_span_id(span.get_span_context().span_id)
_root_span_context_manager = None
_tracer = field(init=False)
deployment_env = field(default=None, kw_only=True)
service_name = field(default='griptape', kw_only=True)
service_version = field(default=None, kw_only=True)
span_processor = field(kw_only=True)
trace_provider = field(default=Factory(lambda self: self._trace_provider_factory(), takes_self=True), kw_only=True)
__attrs_post_init__()
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __attrs_post_init__(self) -> None:
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")

    self.trace_provider.add_span_processor(self.span_processor)
    self._tracer = opentelemetry_trace.get_tracer(self.service_name, tracer_provider=self.trace_provider)
__enter__()
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __enter__(self) -> None:
    opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")

    opentelemetry_instrumentation_threading.ThreadingInstrumentor().instrument()
    self._root_span_context_manager = self._tracer.start_as_current_span("main")  # pyright: ignore[reportCallIssue]
    self._root_span_context_manager.__enter__()
__exit__(exc_type, exc_value, exc_traceback)
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[TracebackType],
) -> bool:
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
    opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")

    root_span = opentelemetry_trace.get_current_span()
    if exc_value:
        root_span = opentelemetry_trace.get_current_span()
        root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.ERROR))
        root_span.record_exception(exc_value)
    else:
        root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.OK))

    if self._root_span_context_manager:
        self._root_span_context_manager.__exit__(exc_type, exc_value, exc_traceback)
        self._root_span_context_manager = None

    self.trace_provider.force_flush()
    opentelemetry_instrumentation_threading.ThreadingInstrumentor().uninstrument()

    return False
_trace_provider_factory()
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def _trace_provider_factory(self) -> TracerProvider:
    opentelemetry_trace = import_optional_dependency("opentelemetry.sdk.trace")

    attributes = {"service.name": self.service_name}
    if self.service_version is not None:
        attributes["service.version"] = self.service_version
    if self.deployment_env is not None:
        attributes["deployment.environment"] = self.deployment_env

    return opentelemetry_trace.TracerProvider(
        resource=import_optional_dependency("opentelemetry.sdk.resources").Resource(attributes=attributes)
    )  # pyright: ignore[reportArgumentType]
get_span_id()
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def get_span_id(self) -> Optional[str]:
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")

    span = opentelemetry_trace.get_current_span()
    if span is opentelemetry_trace.INVALID_SPAN:
        return None
    return opentelemetry_trace.format_span_id(span.get_span_context().span_id)
observe(call)
Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def observe(self, call: Observable.Call) -> Any:
    open_telemetry_trace = import_optional_dependency("opentelemetry.trace")

    func = call.func
    instance = call.instance
    tags = call.tags

    class_name = f"{instance.__class__.__name__}." if instance else ""
    span_name = f"{class_name}{func.__name__}()"
    with self._tracer.start_as_current_span(span_name) as span:  # pyright: ignore[reportCallIssue]
        if tags is not None:
            span.set_attribute("tags", tags)

        try:
            result = call()
            span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.OK))
            return result
        except Exception as e:
            span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.ERROR))
            span.record_exception(e)
            raise e
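Example
A minimal wiring sketch, assuming the OpenTelemetry SDK's console exporter (the service name is hypothetical). The driver itself is a context manager: __enter__ opens a root "main" span and instruments threading, and __exit__ flushes the trace provider:

from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

from griptape.drivers import OpenTelemetryObservabilityDriver

driver = OpenTelemetryObservabilityDriver(
    service_name="my-app",
    span_processor=BatchSpanProcessor(ConsoleSpanExporter()),
)

with driver:
    # Inside the context, get_span_id() reports the current span.
    print(driver.get_span_id())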
PerplexityPromptDriver
Bases:
OpenAiChatPromptDriver
Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
@define
class PerplexityPromptDriver(OpenAiChatPromptDriver):
    base_url: str = field(default="https://api.perplexity.ai", kw_only=True, metadata={"serializable": True})
    structured_output_strategy: str = field(default="native", kw_only=True, metadata={"serializable": True})

    @override
    def _to_message(self, result: ChatCompletion) -> Message:
        message = super()._to_message(result)

        message.content[0].artifact.meta["citations"] = getattr(result, "citations", [])

        return message

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = super()._base_params(prompt_stack)

        if "stop" in params:
            del params["stop"]

        return params
base_url = field(default='https://api.perplexity.ai', kw_only=True, metadata={'serializable': True})
structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    params = super()._base_params(prompt_stack)

    if "stop" in params:
        del params["stop"]

    return params
_to_message(result)
Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
@override
def _to_message(self, result: ChatCompletion) -> Message:
    message = super()._to_message(result)

    message.content[0].artifact.meta["citations"] = getattr(result, "citations", [])

    return message
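Example
A minimal sketch (illustrative; assumes PromptStack and TextArtifact are importable from griptape.common and griptape.artifacts, and a hypothetical PERPLEXITY_API_KEY environment variable):

import os

from griptape.artifacts import TextArtifact
from griptape.common import PromptStack
from griptape.drivers import PerplexityPromptDriver

driver = PerplexityPromptDriver(model="sonar-pro", api_key=os.environ["PERPLEXITY_API_KEY"])

message = driver.run(PromptStack.from_artifact(TextArtifact("What is the tallest mountain?")))
print(message.to_artifact().value)
# Citations are attached to the message content by _to_message above.
print(message.content[0].artifact.meta["citations"])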
PerplexityWebSearchDriver
Bases:
BaseWebSearchDriver
Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
@define
class PerplexityWebSearchDriver(BaseWebSearchDriver):
    model: str = field(default="sonar-pro", kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(kw_only=True, default=None)
    _prompt_driver: Optional[PerplexityPromptDriver] = field(default=None, alias="prompt_driver")

    @lazy_property()
    def prompt_driver(self) -> PerplexityPromptDriver:
        if self.api_key is None:
            raise ValueError("api_key is required")
        return PerplexityPromptDriver(model=self.model, api_key=self.api_key)

    def search(self, query: str, **kwargs) -> ListArtifact:
        message = self.prompt_driver.run(PromptStack.from_artifact(TextArtifact(query)))

        return ListArtifact([message.to_artifact()])
_prompt_driver = field(default=None, alias='prompt_driver')
api_key = field(kw_only=True, default=None)
model = field(default='sonar-pro', kw_only=True, metadata={'serializable': True})
prompt_driver()
Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
@lazy_property()
def prompt_driver(self) -> PerplexityPromptDriver:
    if self.api_key is None:
        raise ValueError("api_key is required")
    return PerplexityPromptDriver(model=self.model, api_key=self.api_key)
search(query, **kwargs)
Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    message = self.prompt_driver.run(PromptStack.from_artifact(TextArtifact(query)))

    return ListArtifact([message.to_artifact()])
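Example
A minimal search sketch (hypothetical API key; the result is a ListArtifact wrapping the answer artifact):

import os

from griptape.drivers import PerplexityWebSearchDriver

driver = PerplexityWebSearchDriver(api_key=os.environ["PERPLEXITY_API_KEY"])

results = driver.search("latest developments in vector databases")
print(results.value)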
PgAiKnowledgeBaseVectorStoreDriver
Bases:
BaseVectorStoreDriver
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
@define
class PgAiKnowledgeBaseVectorStoreDriver(BaseVectorStoreDriver):
    connection_string: str = field(kw_only=True, metadata={"serializable": True})
    knowledge_base_name: str = field(kw_only=True, metadata={"serializable": True})
    embedding_driver: BaseEmbeddingDriver = field(
        default=Factory(lambda: DummyEmbeddingDriver()),
        metadata={"serializable": True},
        kw_only=True,
        init=False,
    )
    _engine: sqlalchemy.Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @lazy_property()
    def engine(self) -> sqlalchemy.Engine:
        return import_optional_dependency("sqlalchemy").create_engine(self.connection_string)

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        if isinstance(query, ImageArtifact):
            raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with sqlalchemy.orm.Session(self.engine) as session:
            rows = session.query(sqlalchemy.func.aidb.retrieve_text(self.knowledge_base_name, query, count)).all()

            entries = []
            for (row,) in rows:
                # Remove the first and last parentheses from the row and split by commas
                # Example: '(foo,bar)' -> ['foo', 'bar']
                row_list = "".join(row.replace("(", "", 1).rsplit(")", 1)).split(",")
                entries.append(
                    BaseVectorStoreDriver.Entry(
                        id=row_list[0],
                        score=float(row_list[2]),
                        meta={"artifact": TextArtifact(row_list[1]).to_json()},
                    )
                )

            return entries

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
_engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False})
connection_string = field(kw_only=True, metadata={'serializable': True})
embedding_driver = field(default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}, kw_only=True, init=False)
knowledge_base_name = field(kw_only=True, metadata={'serializable': True})
delete_vector(vector_id)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
engine()
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
@lazy_property()
def engine(self) -> sqlalchemy.Engine:
    return import_optional_dependency("sqlalchemy").create_engine(self.connection_string)
load_artifacts(*, namespace=None)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
    raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")
query(query, *, count=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, **kwargs)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    if isinstance(query, ImageArtifact):
        raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with sqlalchemy.orm.Session(self.engine) as session:
        rows = session.query(sqlalchemy.func.aidb.retrieve_text(self.knowledge_base_name, query, count)).all()

        entries = []
        for (row,) in rows:
            # Remove the first and last parentheses from the row and split by commas
            # Example: '(foo,bar)' -> ['foo', 'bar']
            row_list = "".join(row.replace("(", "", 1).rsplit(")", 1)).split(",")
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=row_list[0],
                    score=float(row_list[2]),
                    meta={"artifact": TextArtifact(row_list[1]).to_json()},
                )
            )

        return entries
upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")
upsert_text_artifact(artifact, namespace=None, meta=None, vector_id=None, **kwargs)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_text_artifact(
    self,
    artifact: TextArtifact,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")
upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")
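Example
A minimal query sketch (hypothetical connection string and knowledge base name; only query is supported, and the other vector-store operations raise NotImplementedError):

from griptape.drivers import PgAiKnowledgeBaseVectorStoreDriver

driver = PgAiKnowledgeBaseVectorStoreDriver(
    connection_string="postgresql://user:pass@localhost:5432/mydb",
    knowledge_base_name="my_knowledge_base",
)

for entry in driver.query("What is pgvector?", count=3):
    print(entry.id, entry.score)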
PgVectorVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
connection_string | Optional[str] | An optional string describing the target Postgres database instance. |
create_engine_params | dict | Additional configuration params passed when creating the database connection. |
engine | Engine | An optional sqlalchemy Postgres engine to use. |
table_name | str | Optionally specify the name of the table used to store vectors. |
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@define
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver to Postgres using the PGVector extension.

    Attributes:
        connection_string: An optional string describing the target Postgres database instance.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: Optionally specify the name of the table used to store vectors.
    """

    connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))
    _engine: Optional[sqlalchemy.Engine] = field(
        default=None, kw_only=True, alias="engine", metadata={"serializable": False}
    )

    @connection_string.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None:
        # If an engine is provided, the connection string is not used.
        if self._engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        if not connection_string.startswith("postgresql://"):
            raise ValueError("The connection string must describe a Postgres database connection")

    @lazy_property()
    def engine(self) -> sqlalchemy.Engine:
        return import_optional_dependency("sqlalchemy").create_engine(
            self.connection_string, **self.create_engine_params
        )

    def setup(
        self,
        *,
        create_schema: bool = True,
        install_uuid_extension: bool = True,
        install_vector_extension: bool = True,
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions."""
        sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql")

        if install_uuid_extension:
            with self.engine.begin() as conn:
                conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'))

        if install_vector_extension:
            with self.engine.begin() as conn:
                conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";'))

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection."""
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

            obj = session.merge(obj)
            session.commit()

            return str(obj.id)

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            result = session.get(self._model, vector_id)

            return BaseVectorStoreDriver.Entry(
                id=result.id,
                vector=result.vector,
                namespace=result.namespace,
                meta=result.meta,
            )

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace."""
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)
            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id),
                    vector=result.vector,
                    namespace=result.namespace,
                    meta=result.meta,
                )
                for result in results
            ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace."""
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with sqlalchemy_orm.Session(self.engine) as session:
            # The query should return both the vector and the distance metric score.
            query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore[reportOptionalCall]

            filter_kwargs: Optional[OrderedDict] = None

            if namespace is not None:
                filter_kwargs = OrderedDict(namespace=namespace)

            if "filter" in kwargs and isinstance(kwargs["filter"], dict):
                filter_kwargs = filter_kwargs or OrderedDict()
                filter_kwargs.update(kwargs["filter"])

            if filter_kwargs is not None:
                query_result = query_result.filter_by(**filter_kwargs)

            results = query_result.limit(count).all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")
        sqlalchemy = import_optional_dependency("sqlalchemy")
        sqlalchemy_dialects_postgresql = import_optional_dependency("sqlalchemy.dialects.postgresql")
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        @dataclass
        class VectorModel(sqlalchemy_orm.declarative_base()):
            __tablename__ = self.table_name

            id = sqlalchemy.Column(
                sqlalchemy_dialects_postgresql.UUID(as_uuid=True),
                primary_key=True,
                default=uuid.uuid4,
                unique=True,
                nullable=False,
            )
            vector = sqlalchemy.Column(pgvector_sqlalchemy.Vector())
            namespace = sqlalchemy.Column(sqlalchemy.String)
            meta = sqlalchemy.Column(sqlalchemy.JSON)

        return VectorModel

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
_engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False})
_model = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))
connection_string = field(default=None, kw_only=True, metadata={'serializable': True})
create_engine_params = field(factory=dict, kw_only=True, metadata={'serializable': True})
table_name = field(kw_only=True, metadata={'serializable': True})
default_vector_model()
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def default_vector_model(self) -> Any:
    pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")
    sqlalchemy = import_optional_dependency("sqlalchemy")
    sqlalchemy_dialects_postgresql = import_optional_dependency("sqlalchemy.dialects.postgresql")
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    @dataclass
    class VectorModel(sqlalchemy_orm.declarative_base()):
        __tablename__ = self.table_name

        id = sqlalchemy.Column(
            sqlalchemy_dialects_postgresql.UUID(as_uuid=True),
            primary_key=True,
            default=uuid.uuid4,
            unique=True,
            nullable=False,
        )
        vector = sqlalchemy.Column(pgvector_sqlalchemy.Vector())
        namespace = sqlalchemy.Column(sqlalchemy.String)
        meta = sqlalchemy.Column(sqlalchemy.JSON)

    return VectorModel
delete_vector(vector_id)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
engine()
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@lazy_property()
def engine(self) -> sqlalchemy.Engine:
    return import_optional_dependency("sqlalchemy").create_engine(
        self.connection_string, **self.create_engine_params
    )
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace."""
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    with sqlalchemy_orm.Session(self.engine) as session:
        query = session.query(self._model)
        if namespace:
            query = query.filter_by(namespace=namespace)
        results = query.all()

        return [
            BaseVectorStoreDriver.Entry(
                id=str(result.id),
                vector=result.vector,
                namespace=result.namespace,
                meta=result.meta,
            )
            for result in results
        ]
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    """Retrieves a specific vector entry from the collection based on its identifier and optional namespace."""
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    with sqlalchemy_orm.Session(self.engine) as session:
        result = session.get(self._model, vector_id)

        return BaseVectorStoreDriver.Entry(
            id=result.id,
            vector=result.vector,
            namespace=result.namespace,
            meta=result.meta,
        )
query_vector(vector, *, count=None, namespace=None, include_vectors=False, distance_metric='cosine_distance', **kwargs)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    distance_metric: str = "cosine_distance",
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace."""
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

    distance_metrics = {
        "cosine_distance": self._model.vector.cosine_distance,
        "l2_distance": self._model.vector.l2_distance,
        "inner_product": self._model.vector.max_inner_product,
    }

    if distance_metric not in distance_metrics:
        raise ValueError("Invalid distance metric provided")

    op = distance_metrics[distance_metric]

    with sqlalchemy_orm.Session(self.engine) as session:
        # The query should return both the vector and the distance metric score.
        query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore[reportOptionalCall]

        filter_kwargs: Optional[OrderedDict] = None

        if namespace is not None:
            filter_kwargs = OrderedDict(namespace=namespace)

        if "filter" in kwargs and isinstance(kwargs["filter"], dict):
            filter_kwargs = filter_kwargs or OrderedDict()
            filter_kwargs.update(kwargs["filter"])

        if filter_kwargs is not None:
            query_result = query_result.filter_by(**filter_kwargs)

        results = query_result.limit(count).all()

        return [
            BaseVectorStoreDriver.Entry(
                id=str(result[0].id),
                vector=result[0].vector if include_vectors else None,
                score=result[1],
                meta=result[0].meta,
                namespace=result[0].namespace,
            )
            for result in results
        ]
setup(*, create_schema=True, install_uuid_extension=True, install_vector_extension=True)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def setup(
    self,
    *,
    create_schema: bool = True,
    install_uuid_extension: bool = True,
    install_vector_extension: bool = True,
) -> None:
    """Provides a mechanism to initialize the database schema and extensions."""
    sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql")

    if install_uuid_extension:
        with self.engine.begin() as conn:
            conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'))

    if install_vector_extension:
        with self.engine.begin() as conn:
            conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";'))

    if create_schema:
        self._model.metadata.create_all(self.engine)
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in the collection."""
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    with sqlalchemy_orm.Session(self.engine) as session:
        obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

        obj = session.merge(obj)
        session.commit()

        return str(obj.id)
validate_connection_string(_, connection_string)
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@connection_string.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None:
    # If an engine is provided, the connection string is not used.
    if self._engine is not None:
        return

    # If an engine is not provided, a connection string is required.
    if connection_string is None:
        raise ValueError("An engine or connection string is required")

    if not connection_string.startswith("postgresql://"):
        raise ValueError("The connection string must describe a Postgres database connection")
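Example usage (a minimal sketch, not from the griptape source): the connection string, table name, and three-element vectors below are placeholders, and the embedding driver is just one possible choice since these raw-vector calls never invoke it.

from griptape.drivers import OpenAiEmbeddingDriver, PgVectorVectorStoreDriver

driver = PgVectorVectorStoreDriver(
    connection_string="postgresql://user:pass@localhost:5432/postgres",  # placeholder DSN
    table_name="griptape_vectors",  # illustrative table name
    embedding_driver=OpenAiEmbeddingDriver(),  # any embedding driver; unused by the raw-vector calls below
)

# Install the uuid-ossp and vector extensions and create the table.
driver.setup()

# Insert a raw vector, then search for its nearest neighbors by cosine distance.
vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs")
for entry in driver.query_vector([0.1, 0.2, 0.3], count=5, namespace="docs"):
    print(entry.id, entry.score)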
PineconeVectorStoreDriver
Bases:
BaseVectorStoreDriver
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@define
class PineconeVectorStoreDriver(BaseVectorStoreDriver):
    api_key: str = field(kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: str = field(kw_only=True, metadata={"serializable": True})
    project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[pinecone.Pinecone] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )
    _index: Optional[pinecone.Index] = field(
        default=None, kw_only=True, alias="index", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> pinecone.Pinecone:
        return import_optional_dependency("pinecone").Pinecone(
            api_key=self.api_key,
            environment=self.environment,
            project_name=self.project_name,
        )

    @lazy_property()
    def index(self) -> pinecone.data.index.Index:
        return self.client.Index(self.index_name)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        vector_id = vector_id or str_to_hash(str(vector))

        params: dict[str, Any] = {"namespace": namespace} | kwargs

        self.index.upsert(vectors=[(vector_id, vector, meta)], **params)  # pyright: ignore[reportArgumentType]

        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()  # pyright: ignore[reportAttributeAccessIssue]
        vectors = list(result["vectors"].values())

        if len(vectors) > 0:
            vector = vectors[0]

            return BaseVectorStoreDriver.Entry(
                id=vector["id"],
                meta=vector["metadata"],
                vector=vector["values"],
                namespace=result["namespace"],
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
        # all values from a namespace:
        # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5
        results = self.index.query(
            vector=self.embedding_driver.embed("", vector_operation="query"),
            top_k=10000,
            include_metadata=True,
            namespace=namespace,
        )

        return [
            BaseVectorStoreDriver.Entry(
                id=r["id"],
                vector=r["values"],
                meta=r["metadata"],
                namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
            )
            for r in results["matches"]  # pyright: ignore[reportIndexIssue]
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        params = {
            "top_k": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "namespace": namespace,
            "include_values": include_vectors,
            "include_metadata": include_metadata,
        } | kwargs

        results = self.index.query(vector=vector, **params)

        return [
            BaseVectorStoreDriver.Entry(
                id=r["id"],
                vector=r["values"],
                score=r["score"],
                meta=r["metadata"],
                namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
            )
            for r in results["matches"]  # pyright: ignore[reportIndexIssue]
        ]

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
_index = field(default=None, kw_only=True, alias='index', metadata={'serializable': False})
api_key = field(kw_only=True, metadata={'serializable': True})
environment = field(kw_only=True, metadata={'serializable': True})
index_name = field(kw_only=True, metadata={'serializable': True})
project_name = field(default=None, kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@lazy_property()
def client(self) -> pinecone.Pinecone:
    return import_optional_dependency("pinecone").Pinecone(
        api_key=self.api_key,
        environment=self.environment,
        project_name=self.project_name,
    )
delete_vector(vector_id)
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
index()
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@lazy_property()
def index(self) -> pinecone.data.index.Index:
    return self.client.Index(self.index_name)
load_entries(*, namespace=None)
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
    # all values from a namespace:
    # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5
    results = self.index.query(
        vector=self.embedding_driver.embed("", vector_operation="query"),
        top_k=10000,
        include_metadata=True,
        namespace=namespace,
    )

    return [
        BaseVectorStoreDriver.Entry(
            id=r["id"],
            vector=r["values"],
            meta=r["metadata"],
            namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
        )
        for r in results["matches"]  # pyright: ignore[reportIndexIssue]
    ]
load_entry(vector_id, *, namespace=None)
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()  # pyright: ignore[reportAttributeAccessIssue]
    vectors = list(result["vectors"].values())

    if len(vectors) > 0:
        vector = vectors[0]

        return BaseVectorStoreDriver.Entry(
            id=vector["id"],
            meta=vector["metadata"],
            vector=vector["values"],
            namespace=result["namespace"],
        )
    return None
query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    params = {
        "top_k": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        "namespace": namespace,
        "include_values": include_vectors,
        "include_metadata": include_metadata,
    } | kwargs

    results = self.index.query(vector=vector, **params)

    return [
        BaseVectorStoreDriver.Entry(
            id=r["id"],
            vector=r["values"],
            score=r["score"],
            meta=r["metadata"],
            namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
        )
        for r in results["matches"]  # pyright: ignore[reportIndexIssue]
    ]
upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    vector_id = vector_id or str_to_hash(str(vector))

    params: dict[str, Any] = {"namespace": namespace} | kwargs

    self.index.upsert(vectors=[(vector_id, vector, meta)], **params)  # pyright: ignore[reportArgumentType]

    return vector_id
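Example usage (a minimal sketch, not from the griptape source): all credentials are placeholders, and the index is assumed to already exist with a dimensionality matching the upserted vectors.

from griptape.drivers import OpenAiEmbeddingDriver, PineconeVectorStoreDriver

driver = PineconeVectorStoreDriver(
    api_key="pc-...",         # placeholder Pinecone API key
    environment="us-east-1",  # placeholder environment
    index_name="my-index",    # placeholder; the index must already exist
    embedding_driver=OpenAiEmbeddingDriver(),  # used by load_entries() to build its scan query
)

vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"source": "example"})
results = driver.query_vector([0.1, 0.2, 0.3], count=3, namespace="docs")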
ProxyWebScraperDriver
Bases:
BaseWebScraperDriver
Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
@define
class ProxyWebScraperDriver(BaseWebScraperDriver):
    proxies: dict = field(kw_only=True, metadata={"serializable": False})
    params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True})

    def fetch_url(self, url: str) -> str:
        response = requests.get(url, proxies=self.proxies, **self.params)

        return response.text

    def extract_page(self, page: str) -> TextArtifact:
        return TextArtifact(page)
params = field(default=Factory(dict), kw_only=True, metadata={'serializable': True})
proxies = field(kw_only=True, metadata={'serializable': False})
extract_page(page)
Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact:
    return TextArtifact(page)
fetch_url(url)
Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
def fetch_url(self, url: str) -> str:
    response = requests.get(url, proxies=self.proxies, **self.params)

    return response.text
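Example usage (a minimal sketch, not from the griptape source): the proxy endpoints are placeholders, and params is forwarded verbatim to requests.get.

from griptape.drivers import ProxyWebScraperDriver

driver = ProxyWebScraperDriver(
    proxies={
        "http": "http://proxy.example.com:8080",   # placeholder proxy
        "https": "http://proxy.example.com:8443",  # placeholder proxy
    },
    params={"timeout": 10},  # forwarded to requests.get
)

page = driver.fetch_url("https://example.com")
artifact = driver.extract_page(page)  # wraps the raw page text in a TextArtifact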
PusherEventListenerDriver
Bases:
BaseEventListenerDriver
Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
@define
class PusherEventListenerDriver(BaseEventListenerDriver):
    app_id: str = field(kw_only=True, metadata={"serializable": True})
    key: str = field(kw_only=True, metadata={"serializable": True})
    secret: str = field(kw_only=True, metadata={"serializable": False})
    cluster: str = field(kw_only=True, metadata={"serializable": True})
    channel: str = field(kw_only=True, metadata={"serializable": True})
    event_name: str = field(kw_only=True, metadata={"serializable": True})
    ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[Pusher] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Pusher:
        return import_optional_dependency("pusher").Pusher(
            app_id=self.app_id,
            key=self.key,
            secret=self.secret,
            cluster=self.cluster,
            ssl=self.ssl,
        )

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        data = [
            {"channel": self.channel, "name": self.event_name, "data": event_payload}
            for event_payload in event_payload_batch
        ]

        self.client.trigger_batch(data)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self.client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload)
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
app_id = field(kw_only=True, metadata={'serializable': True})
channel = field(kw_only=True, metadata={'serializable': True})
cluster = field(kw_only=True, metadata={'serializable': True})
event_name = field(kw_only=True, metadata={'serializable': True})
key = field(kw_only=True, metadata={'serializable': True})
secret = field(kw_only=True, metadata={'serializable': False})
ssl = field(default=True, kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
@lazy_property()
def client(self) -> Pusher:
    return import_optional_dependency("pusher").Pusher(
        app_id=self.app_id,
        key=self.key,
        secret=self.secret,
        cluster=self.cluster,
        ssl=self.ssl,
    )
try_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    self.client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload)
try_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    data = [
        {"channel": self.channel, "name": self.event_name, "data": event_payload}
        for event_payload in event_payload_batch
    ]

    self.client.trigger_batch(data)
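Example usage (a minimal sketch, not from the griptape source): the Pusher credentials are placeholders and the payloads are illustrative.

from griptape.drivers import PusherEventListenerDriver

driver = PusherEventListenerDriver(
    app_id="123456",            # placeholder
    key="app-key",              # placeholder
    secret="app-secret",        # placeholder
    cluster="us2",              # placeholder
    channel="griptape-events",  # illustrative channel name
    event_name="griptape-event",
)

# Publish a single payload, or several payloads in one batched request.
driver.try_publish_event_payload({"status": "done"})
driver.try_publish_event_payload_batch([{"step": 1}, {"step": 2}])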
QdrantVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
location | Optional[str] | An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used. |
url | Optional[str] | An optional Qdrant API URL. |
host | Optional[str] | An optional Qdrant host. |
path | Optional[str] | Persistence path for QdrantLocal. Defaults to None. |
port | int | The port number for the Qdrant client. Defaults to 6333. |
grpc_port | int | The gRPC port number for the Qdrant client. Defaults to 6334. |
prefer_grpc | bool | Whether to prefer gRPC over HTTP. Defaults to False. |
force_disable_check_same_thread | Optional[bool] | For QdrantLocal, force-disables check_same_thread. Defaults to False. Only use this if you can guarantee thread safety outside QdrantClient. |
timeout | Optional[int] | Timeout for REST and gRPC API requests. Defaults to 5 seconds for REST and unlimited for gRPC. |
api_key | Optional[str] | API key for authentication in Qdrant Cloud. Defaults to None. |
https | bool | If true, use the HTTPS (SSL) protocol. Defaults to True. |
prefix | Optional[str] | Add a prefix to the REST URL path. Example: a prefix of service/v1 results in http://localhost:6333/service/v1/{qdrant-endpoint} for the REST API. Defaults to None. |
distance | str | The distance metric to be used for the vectors. Defaults to 'COSINE'. |
collection_name | str | The name of the Qdrant collection. |
vector_name | Optional[str] | An optional name for the vectors. |
content_payload_key | str | The key for the content payload in the metadata. Defaults to 'data'. |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
@define
class QdrantVectorStoreDriver(BaseVectorStoreDriver):
    """Vector Store Driver for Qdrant.

    Attributes:
        location: An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used.
        url: An optional Qdrant API URL.
        host: An optional Qdrant host.
        path: Persistence path for QdrantLocal. Default: None
        port: The port number for the Qdrant client. Defaults: 6333.
        grpc_port: The gRPC port number for the Qdrant client. Defaults: 6334.
        prefer_grpc: A boolean indicating whether to prefer gRPC over HTTP. Defaults: False.
        force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Default: False
            Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
        timeout: Timeout for REST and gRPC API requests. Default: 5 seconds for REST and unlimited for gRPC
        api_key: API key for authentication in Qdrant Cloud. Defaults: None
        https: If true - use HTTPS(SSL) protocol. Default: True
        prefix: Add prefix to the REST URL path. Example: service/v1 will result in
            http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. Defaults: None
        distance: The distance metric to be used for the vectors. Defaults: 'COSINE'.
        collection_name: The name of the Qdrant collection.
        vector_name: An optional name for the vectors.
        content_payload_key: The key for the content payload in the metadata. Defaults: 'data'.
    """

    location: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    port: int = field(default=6333, kw_only=True, metadata={"serializable": True})
    grpc_port: int = field(default=6334, kw_only=True, metadata={"serializable": True})
    prefer_grpc: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    https: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    prefix: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    force_disable_check_same_thread: Optional[bool] = field(
        default=False,
        kw_only=True,
        metadata={"serializable": True},
    )
    timeout: Optional[int] = field(default=5, kw_only=True, metadata={"serializable": True})
    distance: str = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={"serializable": True})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    vector_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True})
    _client: Optional[QdrantClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> QdrantClient:
        return import_optional_dependency("qdrant_client").QdrantClient(
            location=self.location,
            url=self.url,
            host=self.host,
            path=self.path,
            port=self.port,
            prefer_grpc=self.prefer_grpc,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
            https=self.https,
            prefix=self.prefix,
            force_disable_check_same_thread=self.force_disable_check_same_thread,
            timeout=self.timeout,
        )

    def delete_vector(self, vector_id: str) -> None:
        """Delete a vector from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str | id): ID of the vector to delete.
        """
        deletion_response = self.client.delete(
            collection_name=self.collection_name,
            points_selector=import_optional_dependency("qdrant_client.http.models").PointIdsList(points=[vector_id]),
        )
        if deletion_response.status == import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED:
            logging.info("ID %s is successfully deleted", vector_id)

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Query the Qdrant collection based on a query vector.

        Parameters:
            vector (list[float]): Query vector.
            count (Optional[int]): Optional number of results to return.
            namespace (Optional[str]): Optional namespace of the vectors.
            include_vectors (bool): Whether to include vectors in the results.

        Returns:
            list[BaseVectorStoreDriver.Entry]: List of Entry objects.
        """
        # Create a search request
        request = {"collection_name": self.collection_name, "query_vector": vector, "limit": count}
        request = {k: v for k, v in request.items() if v is not None}
        results = self.client.search(**request)

        # Convert results to QueryResult objects
        return [
            BaseVectorStoreDriver.Entry(
                id=str(result.id),
                vector=result.vector if include_vectors else [],  # pyright: ignore[reportArgumentType]
                score=result.score,
                meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
            for result in results
            if result.payload is not None
        ]

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        content: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Upsert vectors into the Qdrant collection.

        Parameters:
            vector (list[float]): The vector to be upserted.
            vector_id (Optional[str]): Optional vector ID.
            namespace (Optional[str]): Optional namespace for the vector.
            meta (Optional[dict]): Optional dictionary containing metadata.
            content (Optional[str]): The text content to be included in the payload.

        Returns:
            str: The ID of the upserted vector.
        """
        if vector_id is None:
            vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))

        if meta is None:
            meta = {}

        if content:
            meta[self.content_payload_key] = content

        points = import_optional_dependency("qdrant_client.http.models").Batch(
            ids=[vector_id],
            vectors=[vector],
            payloads=[meta] if meta else None,
        )

        self.client.upsert(collection_name=self.collection_name, points=points)
        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a vector entry from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str): ID of the vector to load.
            namespace (str, optional): Optional namespace of the vector.

        Returns:
            Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
        """
        results = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
        if results:
            entry = results[0]
            if entry.payload is None:
                entry.payload = {}
            return BaseVectorStoreDriver.Entry(
                id=str(entry.id),
                vector=entry.vector if entry.vector is not None else [],  # pyright: ignore[reportArgumentType]
                meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
        """Load vector entries from the Qdrant collection.

        Parameters:
            namespace: Optional namespace of the vectors.

        Returns:
            List of points.
        """
        results = self.client.retrieve(
            collection_name=self.collection_name,
            ids=kwargs.get("ids", []),
            with_payload=kwargs.get("with_payload", True),
            with_vectors=kwargs.get("with_vectors", True),
        )

        return [
            BaseVectorStoreDriver.Entry(
                id=str(entry.id),
                vector=entry.vector if kwargs.get("with_vectors", True) else [],  # pyright: ignore[reportArgumentType]
                meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
            )
            for entry in results
            if entry.payload is not None
        ]
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
api_key = field(default=None, kw_only=True, metadata={'serializable': True})
collection_name = field(kw_only=True, metadata={'serializable': True})
content_payload_key = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={'serializable': True})
distance = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={'serializable': True})
force_disable_check_same_thread = field(default=False, kw_only=True, metadata={'serializable': True})
grpc_port = field(default=6334, kw_only=True, metadata={'serializable': True})
host = field(default=None, kw_only=True, metadata={'serializable': True})
https = field(default=True, kw_only=True, metadata={'serializable': True})
location = field(default=None, kw_only=True, metadata={'serializable': True})
path = field(default=None, kw_only=True, metadata={'serializable': True})
port = field(default=6333, kw_only=True, metadata={'serializable': True})
prefer_grpc = field(default=False, kw_only=True, metadata={'serializable': True})
prefix = field(default=None, kw_only=True, metadata={'serializable': True})
timeout = field(default=5, kw_only=True, metadata={'serializable': True})
url = field(default=None, kw_only=True, metadata={'serializable': True})
vector_name = field(default=None, kw_only=True, metadata={'serializable': True})
client()
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
@lazy_property()
def client(self) -> QdrantClient:
    return import_optional_dependency("qdrant_client").QdrantClient(
        location=self.location,
        url=self.url,
        host=self.host,
        path=self.path,
        port=self.port,
        prefer_grpc=self.prefer_grpc,
        grpc_port=self.grpc_port,
        api_key=self.api_key,
        https=self.https,
        prefix=self.prefix,
        force_disable_check_same_thread=self.force_disable_check_same_thread,
        timeout=self.timeout,
    )
delete_vector(vector_id)
Delete a vector from the Qdrant collection based on its ID.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector_id | str or id | ID of the vector to delete. | required |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Delete a vector from the Qdrant collection based on its ID.

    Parameters:
        vector_id (str | id): ID of the vector to delete.
    """
    deletion_response = self.client.delete(
        collection_name=self.collection_name,
        points_selector=import_optional_dependency("qdrant_client.http.models").PointIdsList(points=[vector_id]),
    )
    if deletion_response.status == import_optional_dependency("qdrant_client.http.models").UpdateStatus.COMPLETED:
        logging.info("ID %s is successfully deleted", vector_id)
load_entries(*, namespace=None, **kwargs)
Load vector entries from the Qdrant collection.
Parameters
Name | Type | Description | Default |
---|---|---|---|
namespace | Optional[str] | Optional namespace of the vectors. | None |
Returns
Type | Description |
---|---|
list[Entry] | List of points. |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
    """Load vector entries from the Qdrant collection.

    Parameters:
        namespace: Optional namespace of the vectors.

    Returns:
        List of points.
    """
    results = self.client.retrieve(
        collection_name=self.collection_name,
        ids=kwargs.get("ids", []),
        with_payload=kwargs.get("with_payload", True),
        with_vectors=kwargs.get("with_vectors", True),
    )

    return [
        BaseVectorStoreDriver.Entry(
            id=str(entry.id),
            vector=entry.vector if kwargs.get("with_vectors", True) else [],  # pyright: ignore[reportArgumentType]
            meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
        )
        for entry in results
        if entry.payload is not None
    ]
load_entry(vector_id, *, namespace=None)
Load a vector entry from the Qdrant collection based on its ID.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector_id | str | ID of the vector to load. | required |
namespace | str | Optional namespace of the vector. | None |
Returns
Type | Description |
---|---|
Optional[Entry] | Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None. |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Load a vector entry from the Qdrant collection based on its ID.

    Parameters:
        vector_id (str): ID of the vector to load.
        namespace (str, optional): Optional namespace of the vector.

    Returns:
        Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
    """
    results = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
    if results:
        entry = results[0]
        if entry.payload is None:
            entry.payload = {}
        return BaseVectorStoreDriver.Entry(
            id=str(entry.id),
            vector=entry.vector if entry.vector is not None else [],  # pyright: ignore[reportArgumentType]
            meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
        )
    return None
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Query the Qdrant collection based on a query vector.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | Query vector. | required |
count | Optional[int] | Optional number of results to return. | None |
namespace | Optional[str] | Optional namespace of the vectors. | None |
include_vectors | bool | Whether to include vectors in the results. | False |
Returns
Type | Description |
---|---|
list[Entry] | list[BaseVectorStoreDriver.Entry]: List of Entry objects. |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Query the Qdrant collection based on a query vector.

    Parameters:
        vector (list[float]): Query vector.
        count (Optional[int]): Optional number of results to return.
        namespace (Optional[str]): Optional namespace of the vectors.
        include_vectors (bool): Whether to include vectors in the results.

    Returns:
        list[BaseVectorStoreDriver.Entry]: List of Entry objects.
    """
    # Create a search request
    request = {"collection_name": self.collection_name, "query_vector": vector, "limit": count}
    request = {k: v for k, v in request.items() if v is not None}
    results = self.client.search(**request)

    # Convert results to QueryResult objects
    return [
        BaseVectorStoreDriver.Entry(
            id=str(result.id),
            vector=result.vector if include_vectors else [],  # pyright: ignore[reportArgumentType]
            score=result.score,
            meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]},
        )
        for result in results
        if result.payload is not None
    ]
upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, content=None, **kwargs)
Upsert vectors into the Qdrant collection.
Parameters
Name | Type | Description | Default |
---|---|---|---|
vector | list[float] | The vector to be upserted. | required |
vector_id | Optional[str] | Optional vector ID. | None |
namespace | Optional[str] | Optional namespace for the vector. | None |
meta | Optional[dict] | Optional dictionary containing metadata. | None |
content | Optional[str] | The text content to be included in the payload. | None |
Returns
Name | Type | Description |
---|---|---|
str | str | The ID of the upserted vector. |
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    content: Optional[str] = None,
    **kwargs,
) -> str:
    """Upsert vectors into the Qdrant collection.

    Parameters:
        vector (list[float]): The vector to be upserted.
        vector_id (Optional[str]): Optional vector ID.
        namespace (Optional[str]): Optional namespace for the vector.
        meta (Optional[dict]): Optional dictionary containing metadata.
        content (Optional[str]): The text content to be included in the payload.

    Returns:
        str: The ID of the upserted vector.
    """
    if vector_id is None:
        vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))

    if meta is None:
        meta = {}

    if content:
        meta[self.content_payload_key] = content

    points = import_optional_dependency("qdrant_client.http.models").Batch(
        ids=[vector_id],
        vectors=[vector],
        payloads=[meta] if meta else None,
    )

    self.client.upsert(collection_name=self.collection_name, points=points)
    return vector_id
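Example usage (a minimal sketch, not from the griptape source): this uses Qdrant's in-memory mode, and because the driver assumes the collection already exists, it is created through the underlying client first. The collection name and vector size are illustrative.

from qdrant_client.http.models import Distance, VectorParams

from griptape.drivers import OpenAiEmbeddingDriver, QdrantVectorStoreDriver

driver = QdrantVectorStoreDriver(
    location=":memory:",         # in-memory client; no server required
    collection_name="griptape",  # illustrative collection name
    embedding_driver=OpenAiEmbeddingDriver(),  # any embedding driver works here
)

# The driver does not create the collection, so create it up front.
driver.client.create_collection(
    collection_name="griptape",
    vectors_config=VectorParams(size=3, distance=Distance.COSINE),
)

vector_id = driver.upsert_vector([0.1, 0.2, 0.3], content="hello")
hits = driver.query_vector([0.1, 0.2, 0.3], count=1)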
RedisConversationMemoryDriver
Bases:
BaseConversationMemoryDriver
Attributes
Name | Type | Description |
---|---|---|
host | str | The host of the Redis instance. |
port | int | The port of the Redis instance. |
db | int | The database of the Redis instance. |
username | str | The username of the Redis instance. |
password | Optional[str] | The password of the Redis instance. |
index | str | The name of the index to use. |
conversation_id | str | The id of the conversation. |
Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
@define
class RedisConversationMemoryDriver(BaseConversationMemoryDriver):
    """A Conversation Memory Driver for Redis.

    This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store,
    retrieve, and query conversations in a structured manner.
    Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

    Attributes:
        host: The host of the Redis instance.
        port: The port of the Redis instance.
        db: The database of the Redis instance.
        username: The username of the Redis instance.
        password: The password of the Redis instance.
        index: The name of the index to use.
        conversation_id: The id of the conversation.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    username: str = field(kw_only=True, default="default", metadata={"serializable": False})
    port: int = field(kw_only=True, metadata={"serializable": True})
    db: int = field(kw_only=True, default=0, metadata={"serializable": True})
    password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    index: str = field(kw_only=True, metadata={"serializable": True})
    conversation_id: str = field(kw_only=True, default=uuid.uuid4().hex)
    client: Redis = field(
        default=Factory(
            lambda self: import_optional_dependency("redis").Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                username=self.username,
                password=self.password,
                decode_responses=False,
            ),
            takes_self=True,
        ),
    )

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        self.client.hset(self.index, self.conversation_id, json.dumps(self._to_params_dict(runs, metadata)))

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        memory_json = self.client.hget(self.index, self.conversation_id)
        if memory_json is not None:
            return self._from_params_dict(json.loads(memory_json))  # pyright: ignore[reportArgumentType] https://github.com/redis/redis-py/issues/2399
        return [], {}
client = field(default=Factory(lambda self: import_optional_dependency('redis').Redis(host=self.host, port=self.port, db=self.db, username=self.username, password=self.password, decode_responses=False), takes_self=True))
conversation_id = field(kw_only=True, default=uuid.uuid4().hex)
db = field(kw_only=True, default=0, metadata={'serializable': True})
host = field(kw_only=True, metadata={'serializable': True})
index = field(kw_only=True, metadata={'serializable': True})
password = field(default=None, kw_only=True, metadata={'serializable': False})
port = field(kw_only=True, metadata={'serializable': True})
username = field(kw_only=True, default='default', metadata={'serializable': False})
load()
Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    memory_json = self.client.hget(self.index, self.conversation_id)
    if memory_json is not None:
        return self._from_params_dict(json.loads(memory_json))  # pyright: ignore[reportArgumentType] https://github.com/redis/redis-py/issues/2399
    return [], {}
store(runs, metadata)
Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    self.client.hset(self.index, self.conversation_id, json.dumps(self._to_params_dict(runs, metadata)))
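Example usage (a minimal sketch, not from the griptape source): host, port, and index are placeholders for a running Redis instance; store() and load() round-trip the conversation stored under the configured conversation_id.

from griptape.drivers import RedisConversationMemoryDriver

driver = RedisConversationMemoryDriver(
    host="localhost",                # placeholder host
    port=6379,                       # default Redis port
    index="griptape-conversations",  # illustrative hash key
)

# load() returns ([], {}) until store() has written something for this conversation_id.
runs, metadata = driver.load()
driver.store(runs, metadata)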
RedisVectorStoreDriver
Bases:
BaseVectorStoreDriver
Attributes
Name | Type | Description |
---|---|---|
host | str | The host of the Redis instance. |
port | int | The port of the Redis instance. |
db | int | The database of the Redis instance. |
username | str | The username of the Redis instance. |
password | Optional[str] | The password of the Redis instance. |
index | str | The name of the index to use. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
@define
class RedisVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Redis.

    This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store,
    retrieve, and query vectors in a structured manner.
    Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

    Attributes:
        host: The host of the Redis instance.
        port: The port of the Redis instance.
        db: The database of the Redis instance.
        username: The username of the Redis instance.
        password: The password of the Redis instance.
        index: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    username: str = field(kw_only=True, default="default", metadata={"serializable": False})
    port: int = field(kw_only=True, metadata={"serializable": True})
    db: int = field(kw_only=True, default=0, metadata={"serializable": True})
    password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    index: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[Redis] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Redis:
        return import_optional_dependency("redis").Redis(
            host=self.host,
            port=self.port,
            db=self.db,
            username=self.username,
            password=self.password,
            decode_responses=False,
        )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in Redis.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.
        """
        vector_id = vector_id or str_to_hash(str(vector))
        key = self._generate_key(vector_id, namespace)
        bytes_vector = json.dumps(vector).encode("utf-8")

        mapping = {}
        mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
        mapping["vec_string"] = bytes_vector

        if namespace:
            mapping["namespace"] = namespace

        if meta:
            mapping["metadata"] = json.dumps(meta)

        self.client.hset(key, mapping=mapping)

        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

        Returns:
            If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
        """
        key = self._generate_key(vector_id, namespace)
        result = self.client.hgetall(key)
        vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()  # pyright: ignore[reportIndexIssue] https://github.com/redis/redis-py/issues/2399
        meta = json.loads(result[b"metadata"]) if b"metadata" in result else None  # pyright: ignore[reportIndexIssue, reportOperatorIssue]

        return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from Redis that match the optional namespace.

        Returns:
            A list of `BaseVectorStoreDriver.Entry` objects.
        """
        pattern = f"{namespace}:*" if namespace else "*"
        keys = self.client.keys(pattern)

        entries = []
        for key in keys:  # pyright: ignore[reportGeneralTypeIssues] https://github.com/redis/redis-py/issues/2399
            entry = self.load_entry(key.decode("utf-8"), namespace=namespace)
            if entry:
                entries.append(entry)

        return entries

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

        Results can be limited using the count parameter and optionally filtered by a namespace.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity
            score, metadata, and namespace.
        """
        search_query = import_optional_dependency("redis.commands.search.query")

        filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
        query_expression = (
            search_query.Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
            .sort_by("score")
            .return_fields("id", "score", "metadata", "vec_string")
            .paging(0, count or 10)
            .dialect(2)
        )

        query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

        results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]

        query_results = []
        for document in results:
            metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
            namespace = document.id.split(":")[0] if ":" in document.id else None
            vector_id = document.id.split(":")[1] if ":" in document.id else document.id
            vector_float_list = json.loads(document.vec_string) if include_vectors else None
            query_results.append(
                BaseVectorStoreDriver.Entry(
                    id=vector_id,
                    vector=vector_float_list,
                    score=float(document.score),
                    meta=metadata,
                    namespace=namespace,
                ),
            )

        return query_results

    def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str:
        """Generates a Redis key using the provided vector ID and optionally a namespace."""
        return f"{namespace}:{vector_id}" if namespace else vector_id

    def _get_doc_prefix(self, namespace: Optional[str] = None) -> str:
        """Get the document prefix based on the provided namespace."""
        return f"{namespace}:" if namespace else ""

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
db = field(kw_only=True, default=0, metadata={'serializable': True})
host = field(kw_only=True, metadata={'serializable': True})
index = field(kw_only=True, metadata={'serializable': True})
password = field(default=None, kw_only=True, metadata={'serializable': False})
port = field(kw_only=True, metadata={'serializable': True})
username = field(kw_only=True, default='default', metadata={'serializable': False})
_generate_key(vector_id, namespace=None)
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str:
    """Generates a Redis key using the provided vector ID and optionally a namespace."""
    return f"{namespace}:{vector_id}" if namespace else vector_id
_get_doc_prefix(namespace=None)
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def _get_doc_prefix(self, namespace: Optional[str] = None) -> str:
    """Get the document prefix based on the provided namespace."""
    return f"{namespace}:" if namespace else ""
client()
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
@lazy_property()
def client(self) -> Redis:
    return import_optional_dependency("redis").Redis(
        host=self.host,
        port=self.port,
        db=self.db,
        username=self.username,
        password=self.password,
        decode_responses=False,
    )
delete_vector(vector_id)
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
load_entries(*, namespace=None)
Retrieves all vector entries from Redis that match the optional namespace.
Returns
Type | Description |
---|---|
list[Entry] | A list of BaseVectorStoreDriver.Entry objects. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from Redis that match the optional namespace.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` objects.
    """
    pattern = f"{namespace}:*" if namespace else "*"
    keys = self.client.keys(pattern)

    entries = []
    for key in keys:  # pyright: ignore[reportGeneralTypeIssues] https://github.com/redis/redis-py/issues/2399
        entry = self.load_entry(key.decode("utf-8"), namespace=namespace)
        if entry:
            entries.append(entry)

    return entries
load_entry(vector_id, *, namespace=None)
Retrieves a specific vector entry from Redis based on its identifier and optional namespace.
Returns
Type | Description |
---|---|
Optional[Entry] | If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

    Returns:
        If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
    """
    key = self._generate_key(vector_id, namespace)
    result = self.client.hgetall(key)
    vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()  # pyright: ignore[reportIndexIssue] https://github.com/redis/redis-py/issues/2399
    meta = json.loads(result[b"metadata"]) if b"metadata" in result else None  # pyright: ignore[reportIndexIssue, reportOperatorIssue]

    return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)
query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)
Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.
Results can be limited using the count parameter and optionally filtered by a namespace.
Returns
Type | Description |
---|---|
list[Entry] | A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

    Results can be limited using the count parameter and optionally filtered by a namespace.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity
        score, metadata, and namespace.
    """
    search_query = import_optional_dependency("redis.commands.search.query")

    filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
    query_expression = (
        search_query.Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
        .sort_by("score")
        .return_fields("id", "score", "metadata", "vec_string")
        .paging(0, count or 10)
        .dialect(2)
    )

    query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

    results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]

    query_results = []
    for document in results:
        metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
        namespace = document.id.split(":")[0] if ":" in document.id else None
        vector_id = document.id.split(":")[1] if ":" in document.id else document.id
        vector_float_list = json.loads(document.vec_string) if include_vectors else None
        query_results.append(
            BaseVectorStoreDriver.Entry(
                id=vector_id,
                vector=vector_float_list,
                score=float(document.score),
                meta=metadata,
                namespace=namespace,
            ),
        )

    return query_results
upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Inserts or updates a vector in Redis.

    If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
    Metadata associated with the vector can also be provided.
    """
    vector_id = vector_id or str_to_hash(str(vector))
    key = self._generate_key(vector_id, namespace)
    bytes_vector = json.dumps(vector).encode("utf-8")

    mapping = {}
    mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
    mapping["vec_string"] = bytes_vector

    if namespace:
        mapping["namespace"] = namespace

    if meta:
        mapping["metadata"] = json.dumps(meta)

    self.client.hset(key, mapping=mapping)

    return vector_id
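Example usage (a minimal sketch, not from the griptape source): connection details are placeholders, and the RediSearch index named here must already exist with a vector field, since the driver queries the index but does not create it.

from griptape.drivers import OpenAiEmbeddingDriver, RedisVectorStoreDriver

driver = RedisVectorStoreDriver(
    host="localhost",          # placeholder host
    port=6379,                 # default Redis port
    index="griptape-vectors",  # name of a pre-created RediSearch index
    embedding_driver=OpenAiEmbeddingDriver(),  # any embedding driver works here
)

vector_id = driver.upsert_vector([0.1, 0.2, 0.3], namespace="docs", meta={"source": "example"})
entries = driver.query_vector([0.1, 0.2, 0.3], count=5, namespace="docs")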
SnowflakeSqlDriver
Bases:
BaseSqlDriver
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@define
class SnowflakeSqlDriver(BaseSqlDriver):
    get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True)
    _engine: Optional[Engine] = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
    def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
        snowflake_connection = get_connection()
        snowflake = import_optional_dependency("snowflake")

        if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
            raise ValueError("The get_connection function must return a SnowflakeConnection")
        if not snowflake_connection.schema or not snowflake_connection.database:
            raise ValueError("Provide a schema and database for the Snowflake connection")

    @lazy_property()
    def engine(self) -> Engine:
        return import_optional_dependency("sqlalchemy").create_engine(
            "snowflake://not@used/db",
            creator=self.get_connection,
        )

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is not None:
                if results.returns_rows:
                    return [dict(result._mapping) for result in results]
                return None
            raise ValueError("No results found")

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        try:
            metadata_obj = sqlalchemy.MetaData()
            metadata_obj.reflect(bind=self.engine)
            table = sqlalchemy.Table(table_name, metadata_obj, schema=schema, autoload=True, autoload_with=self.engine)
            return str([(c.name, c.type) for c in table.columns])
        except sqlalchemy.exc.NoSuchTableError:
            return None
_engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False})
class-attribute instance-attribute

get_connection = field(kw_only=True)
class-attribute instance-attribute
engine()
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@lazy_property()
def engine(self) -> Engine:
    return import_optional_dependency("sqlalchemy").create_engine(
        "snowflake://not@used/db",
        creator=self.get_connection,
    )
execute_query(query)
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)

    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    return None
execute_query_raw(query)
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as con:
        results = con.execute(sqlalchemy.text(query))

        if results is not None:
            if results.returns_rows:
                return [dict(result._mapping) for result in results]
            return None
        raise ValueError("No results found")
get_table_schema(table_name, schema=None)
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    try:
        metadata_obj = sqlalchemy.MetaData()
        metadata_obj.reflect(bind=self.engine)
        table = sqlalchemy.Table(table_name, metadata_obj, schema=schema, autoload=True, autoload_with=self.engine)
        return str([(c.name, c.type) for c in table.columns])
    except sqlalchemy.exc.NoSuchTableError:
        return None
validate_get_connection(_, get_connection)
Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
    snowflake_connection = get_connection()
    snowflake = import_optional_dependency("snowflake")

    if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
        raise ValueError("The get_connection function must return a SnowflakeConnection")
    if not snowflake_connection.schema or not snowflake_connection.database:
        raise ValueError("Provide a schema and database for the Snowflake connection")
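Example (a sketch, not from the source): constructing the driver with a connection factory. The validator above runs at construction time, so the factory must return a SnowflakeConnection with both database and schema set. Credentials are placeholders, and the RowResult field name is an assumption noted in the comment.

import snowflake.connector

from griptape.drivers import SnowflakeSqlDriver

def get_connection():
    # database and schema are required by the validator above.
    return snowflake.connector.connect(
        account="my-account",  # placeholder credentials
        user="my-user",
        password="my-password",
        database="MY_DB",
        schema="PUBLIC",
    )

driver = SnowflakeSqlDriver(get_connection=get_connection)

rows = driver.execute_query("SELECT CURRENT_VERSION()")
if rows is not None:
    for row in rows:
        print(row.cells)  # assumes BaseSqlDriver.RowResult stores the row dict as `cells`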
SqlDriver
Bases:
BaseSqlDriver
Source Code in griptape/drivers/sql/sql_driver.py
@define
class SqlDriver(BaseSqlDriver):
    engine_url: str = field(kw_only=True)
    create_engine_params: dict = field(factory=dict, kw_only=True)
    _engine: Optional[Engine] = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @lazy_property()
    def engine(self) -> Engine:
        return import_optional_dependency("sqlalchemy").create_engine(self.engine_url, **self.create_engine_params)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is not None:
                if results.returns_rows:
                    return [dict(result._mapping) for result in results]
                con.commit()
                return None
            raise ValueError("No result found")

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        sqlalchemy_exc = import_optional_dependency("sqlalchemy.exc")

        try:
            return str(SqlDriver._get_table_schema(self.engine, table_name, schema))
        except sqlalchemy_exc.NoSuchTableError:
            return None

    @staticmethod
    @lru_cache
    def _get_table_schema(
        engine: Engine, table_name: str, schema: Optional[str] = None
    ) -> Optional[list[tuple[str, str]]]:
        sqlalchemy = import_optional_dependency("sqlalchemy")

        return [(col["name"], col["type"]) for col in sqlalchemy.inspect(engine).get_columns(table_name, schema=schema)]
_engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False})
class-attribute instance-attribute

create_engine_params = field(factory=dict, kw_only=True)
class-attribute instance-attribute

engine_url = field(kw_only=True)
class-attribute instance-attribute
_get_table_schema(engine, table_name, schema=None)
cached staticmethod
Source Code in griptape/drivers/sql/sql_driver.py
@staticmethod
@lru_cache
def _get_table_schema(
    engine: Engine, table_name: str, schema: Optional[str] = None
) -> Optional[list[tuple[str, str]]]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    return [(col["name"], col["type"]) for col in sqlalchemy.inspect(engine).get_columns(table_name, schema=schema)]
engine()
Source Code in griptape/drivers/sql/sql_driver.py
@lazy_property()
def engine(self) -> Engine:
    return import_optional_dependency("sqlalchemy").create_engine(self.engine_url, **self.create_engine_params)
execute_query(query)
Source Code in griptape/drivers/sql/sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    rows = self.execute_query_raw(query)

    if rows:
        return [BaseSqlDriver.RowResult(row) for row in rows]
    return None
execute_query_raw(query)
Source Code in griptape/drivers/sql/sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as con:
        results = con.execute(sqlalchemy.text(query))

        if results is not None:
            if results.returns_rows:
                return [dict(result._mapping) for result in results]
            con.commit()
            return None
        raise ValueError("No result found")
get_table_schema(table_name, schema=None)
Source Code in griptape/drivers/sql/sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    sqlalchemy_exc = import_optional_dependency("sqlalchemy.exc")

    try:
        return str(SqlDriver._get_table_schema(self.engine, table_name, schema))
    except sqlalchemy_exc.NoSuchTableError:
        return None
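Example (a sketch): the in-memory SQLite URL keeps it self-contained; any SQLAlchemy-compatible URL and create_engine_params should behave the same way. The table persists across calls here because SQLAlchemy reuses the in-memory SQLite connection within a thread.

from griptape.drivers import SqlDriver

driver = SqlDriver(engine_url="sqlite:///:memory:")

# Statements that return no rows are committed by execute_query_raw.
driver.execute_query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)")
driver.execute_query("INSERT INTO users (name) VALUES ('ada')")

rows = driver.execute_query("SELECT id, name FROM users")
print(rows)

# Column names and types; cached per (engine, table_name, schema) by lru_cache.
print(driver.get_table_schema("users"))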
StableDiffusion3ControlNetImageGenerationPipelineDriver
Bases:
StableDiffusion3ImageGenerationPipelineDriver
Attributes
Name | Type | Description |
---|---|---|
controlnet_model | str | The ControlNet model to use for image generation. |
controlnet_conditioning_scale | Optional[float] | The conditioning scale for the ControlNet model. Defaults to None. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
@define
class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models with ControlNet.

    For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
    https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

    Attributes:
        controlnet_model: The ControlNet model to use for image generation.
        controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
    """

    controlnet_model: str = field(kw_only=True)
    controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
        sd3_controlnet_pipeline = import_optional_dependency(
            "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
        ).StableDiffusion3ControlNetPipeline

        pipeline_params = {}
        controlnet_pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype
            controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # For both Stable Diffusion and ControlNet, models can be provided either
        # as a path to a local file or as a HuggingFace model repo name.
        # We use the from_single_file method if the model is a local file and the
        # from_pretrained method if the model is a local directory or hosted on HuggingFace.
        if os.path.isfile(self.controlnet_model):
            pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
                self.controlnet_model, **controlnet_pipeline_params
            )
        else:
            pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
                self.controlnet_model, **controlnet_pipeline_params
            )

        if os.path.isfile(model):
            pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)
        else:
            pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        if image is None:
            raise ValueError("Input image is required for ControlNet pipelines.")

        return {"control_image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        additional_params = super().make_additional_params(negative_prompts, device)

        del additional_params["height"]
        del additional_params["width"]

        if self.controlnet_conditioning_scale is not None:
            additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

        return additional_params
controlnet_conditioning_scale = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

controlnet_model = field(kw_only=True)
class-attribute instance-attribute
make_additional_params(negative_prompts, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    additional_params = super().make_additional_params(negative_prompts, device)

    del additional_params["height"]
    del additional_params["width"]

    if self.controlnet_conditioning_scale is not None:
        additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

    return additional_params
make_image_param(image)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    if image is None:
        raise ValueError("Input image is required for ControlNet pipelines.")

    return {"control_image": image}
prepare_pipeline(model, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
    sd3_controlnet_pipeline = import_optional_dependency(
        "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
    ).StableDiffusion3ControlNetPipeline

    pipeline_params = {}
    controlnet_pipeline_params = {}
    if self.torch_dtype is not None:
        pipeline_params["torch_dtype"] = self.torch_dtype
        controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

    if self.drop_t5_encoder:
        pipeline_params["text_encoder_3"] = None
        pipeline_params["tokenizer_3"] = None

    # For both Stable Diffusion and ControlNet, models can be provided either
    # as a path to a local file or as a HuggingFace model repo name.
    # We use the from_single_file method if the model is a local file and the
    # from_pretrained method if the model is a local directory or hosted on HuggingFace.
    if os.path.isfile(self.controlnet_model):
        pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
            self.controlnet_model, **controlnet_pipeline_params
        )
    else:
        pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
            self.controlnet_model, **controlnet_pipeline_params
        )

    if os.path.isfile(model):
        pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)
    else:
        pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()

    if device is not None:
        pipeline.to(device)

    return pipeline
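Example (a sketch under assumptions): these pipeline drivers are normally supplied to HuggingFacePipelineImageGenerationDriver, but the hooks above can also be exercised directly. The model repo names are illustrative, and the final pipeline call follows the usual diffusers calling convention, which is not part of the source above.

import torch
from PIL import Image

from griptape.drivers import StableDiffusion3ControlNetImageGenerationPipelineDriver

driver = StableDiffusion3ControlNetImageGenerationPipelineDriver(
    controlnet_model="InstantX/SD3-Controlnet-Canny",  # illustrative repo name
    controlnet_conditioning_scale=0.7,
    torch_dtype=torch.float16,
    enable_model_cpu_offload=True,
)

pipeline = driver.prepare_pipeline("stabilityai/stable-diffusion-3-medium-diffusers", device=None)

control_image = Image.open("canny_edges.png")  # hypothetical conditioning image
params = {
    **driver.make_image_param(control_image),             # -> {"control_image": ...}
    **driver.make_additional_params(["blurry"], "cuda"),  # height/width are removed here
}
image = pipeline("a watercolor landscape", **params).images[0]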
StableDiffusion3ImageGenerationPipelineDriver
Bases:
BaseDiffusionImageGenerationPipelineDriver
Attributes
Name | Type | Description |
---|---|---|
width | int | The width of the generated image. Defaults to 1024. Must be a multiple of 64. |
height | int | The height of the generated image. Defaults to 1024. Must be a multiple of 64. |
seed | Optional[int] | The random seed to use for image generation. If not provided, a random seed will be used. |
guidance_scale | Optional[float] | The strength of the guidance loss. If not provided, the default value will be used. |
steps | Optional[int] | The number of inference steps to use in image generation. If not provided, the default value will be used. |
torch_dtype | Optional[dtype] | The torch data type to use for image generation. If not provided, the default value will be used. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
@define
class StableDiffusion3ImageGenerationPipelineDriver(BaseDiffusionImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models.

    For more information, see the HuggingFace documentation for the StableDiffusion3Pipeline:
    https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3

    Attributes:
        width: The width of the generated image. Defaults to 1024. Must be a multiple of 64.
        height: The height of the generated image. Defaults to 1024. Must be a multiple of 64.
        seed: The random seed to use for image generation. If not provided, a random seed will be used.
        guidance_scale: The strength of the guidance loss. If not provided, the default value will be used.
        steps: The number of inference steps to use in image generation. If not provided, the default value will be used.
        torch_dtype: The torch data type to use for image generation. If not provided, the default value will be used.
    """

    width: int = field(default=1024, kw_only=True, metadata={"serializable": True})
    height: int = field(default=1024, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    guidance_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    torch_dtype: Optional[torch.dtype] = field(default=None, kw_only=True, metadata={"serializable": True})
    enable_model_cpu_offload: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    drop_t5_encoder: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        sd3_pipeline = import_optional_dependency(
            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3"
        ).StableDiffusion3Pipeline

        pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # A model can be provided either as a path to a local file
        # or as a HuggingFace model repo name.
        if os.path.isfile(model):
            # If the model provided is a local file (not a directory),
            # we load it using the from_single_file method.
            pipeline = sd3_pipeline.from_single_file(model, **pipeline_params)
        else:
            # If the model is a local directory or hosted on HuggingFace,
            # we load it using the from_pretrained method.
            pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        # Move inference to particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        return None

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        torch_generator = import_optional_dependency("torch").Generator

        additional_params = {}
        if negative_prompts:
            additional_params["negative_prompt"] = ", ".join(negative_prompts)

        if self.width is not None:
            additional_params["width"] = self.width

        if self.height is not None:
            additional_params["height"] = self.height

        if self.seed is not None:
            additional_params["generator"] = [torch_generator(device=device).manual_seed(self.seed)]

        if self.guidance_scale is not None:
            additional_params["guidance_scale"] = self.guidance_scale

        if self.steps is not None:
            additional_params["num_inference_steps"] = self.steps

        return additional_params

    @property
    def output_image_dimensions(self) -> tuple[int, int]:
        return self.width, self.height
drop_t5_encoder = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

enable_model_cpu_offload = field(default=False, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

guidance_scale = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

height = field(default=1024, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

output_image_dimensions
property

seed = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

steps = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

torch_dtype = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

width = field(default=1024, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
make_additional_params(negative_prompts, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    torch_generator = import_optional_dependency("torch").Generator

    additional_params = {}
    if negative_prompts:
        additional_params["negative_prompt"] = ", ".join(negative_prompts)

    if self.width is not None:
        additional_params["width"] = self.width

    if self.height is not None:
        additional_params["height"] = self.height

    if self.seed is not None:
        additional_params["generator"] = [torch_generator(device=device).manual_seed(self.seed)]

    if self.guidance_scale is not None:
        additional_params["guidance_scale"] = self.guidance_scale

    if self.steps is not None:
        additional_params["num_inference_steps"] = self.steps

    return additional_params
make_image_param(image)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    return None
prepare_pipeline(model, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    sd3_pipeline = import_optional_dependency(
        "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3"
    ).StableDiffusion3Pipeline

    pipeline_params = {}
    if self.torch_dtype is not None:
        pipeline_params["torch_dtype"] = self.torch_dtype

    if self.drop_t5_encoder:
        pipeline_params["text_encoder_3"] = None
        pipeline_params["tokenizer_3"] = None

    # A model can be provided either as a path to a local file
    # or as a HuggingFace model repo name.
    if os.path.isfile(model):
        # If the model provided is a local file (not a directory),
        # we load it using the from_single_file method.
        pipeline = sd3_pipeline.from_single_file(model, **pipeline_params)
    else:
        # If the model is a local directory or hosted on HuggingFace,
        # we load it using the from_pretrained method.
        pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()

    # Move inference to particular device if requested.
    if device is not None:
        pipeline.to(device)

    return pipeline
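Example (a sketch): constructing the driver and calling the hooks above directly. The model name is illustrative, and the final pipeline invocation assumes the standard diffusers calling convention, which is outside the source above.

import torch

from griptape.drivers import StableDiffusion3ImageGenerationPipelineDriver

driver = StableDiffusion3ImageGenerationPipelineDriver(
    width=1024,
    height=1024,
    seed=42,
    steps=28,
    guidance_scale=7.0,
    torch_dtype=torch.float16,
    drop_t5_encoder=True,  # trades some prompt fidelity for a smaller memory footprint
)

pipeline = driver.prepare_pipeline("stabilityai/stable-diffusion-3-medium-diffusers", "cuda")
params = driver.make_additional_params(["low quality"], "cuda")

image = pipeline("a lighthouse at dawn", **params).images[0]
image.save("output.png")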
StableDiffusion3Img2ImgImageGenerationPipelineDriver
Bases:
StableDiffusion3ImageGenerationPipelineDriver
Attributes
Name | Type | Description |
---|---|---|
strength | Optional[float] | A value [0.0, 1.0] that determines the strength of the initial image in the output. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
@define
class StableDiffusion3Img2ImgImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 model image to image pipelines.

    For more information, see the HuggingFace documentation for the StableDiffusion3Img2ImgPipeline:
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

    Attributes:
        strength: A value [0.0, 1.0] that determines the strength of the initial image in the output.
    """

    strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        sd3_img2img_pipeline = import_optional_dependency(
            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
        ).StableDiffusion3Img2ImgPipeline

        pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # A model can be provided either as a path to a local file
        # or as a HuggingFace model repo name.
        if os.path.isfile(model):
            # If the model provided is a local file (not a directory),
            # we load it using the from_single_file method.
            pipeline = sd3_img2img_pipeline.from_single_file(model, **pipeline_params)
        else:
            # If the model is a local directory or hosted on HuggingFace,
            # we load it using the from_pretrained method.
            pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        # Move inference to particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        if image is None:
            raise ValueError("Input image is required for image to image pipelines.")

        return {"image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        additional_params = super().make_additional_params(negative_prompts, device)

        # Explicit height and width params are not supported, but
        # are instead inferred from input image.
        del additional_params["height"]
        del additional_params["width"]

        if self.strength is not None:
            additional_params["strength"] = self.strength

        return additional_params
strength = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
make_additional_params(negative_prompts, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    additional_params = super().make_additional_params(negative_prompts, device)

    # Explicit height and width params are not supported, but
    # are instead inferred from input image.
    del additional_params["height"]
    del additional_params["width"]

    if self.strength is not None:
        additional_params["strength"] = self.strength

    return additional_params
make_image_param(image)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    if image is None:
        raise ValueError("Input image is required for image to image pipelines.")

    return {"image": image}
prepare_pipeline(model, device)
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    sd3_img2img_pipeline = import_optional_dependency(
        "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
    ).StableDiffusion3Img2ImgPipeline

    pipeline_params = {}
    if self.torch_dtype is not None:
        pipeline_params["torch_dtype"] = self.torch_dtype

    if self.drop_t5_encoder:
        pipeline_params["text_encoder_3"] = None
        pipeline_params["tokenizer_3"] = None

    # A model can be provided either as a path to a local file
    # or as a HuggingFace model repo name.
    if os.path.isfile(model):
        # If the model provided is a local file (not a directory),
        # we load it using the from_single_file method.
        pipeline = sd3_img2img_pipeline.from_single_file(model, **pipeline_params)
    else:
        # If the model is a local directory or hosted on HuggingFace,
        # we load it using the from_pretrained method.
        pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()

    # Move inference to particular device if requested.
    if device is not None:
        pipeline.to(device)

    return pipeline
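Example (a sketch under the same assumptions as above): the input image and model name are placeholders, and the pipeline call assumes the standard diffusers convention.

import torch
from PIL import Image

from griptape.drivers import StableDiffusion3Img2ImgImageGenerationPipelineDriver

driver = StableDiffusion3Img2ImgImageGenerationPipelineDriver(
    strength=0.6,  # closer to 0.0 preserves the input image, closer to 1.0 ignores it
    torch_dtype=torch.float16,
)

pipeline = driver.prepare_pipeline("stabilityai/stable-diffusion-3-medium-diffusers", "cuda")

init_image = Image.open("sketch.png")  # hypothetical input image
params = {
    **driver.make_image_param(init_image),          # -> {"image": ...}
    **driver.make_additional_params(None, "cuda"),  # height/width removed; inferred from the image
}
image = pipeline("an oil painting based on the sketch", **params).images[0]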
TavilyWebSearchDriver
Bases:
BaseWebSearchDriver
Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
@define
class TavilyWebSearchDriver(BaseWebSearchDriver):
    api_key: str = field(kw_only=True)
    params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
    _client: Optional[TavilyClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> TavilyClient:
        return import_optional_dependency("tavily").TavilyClient(self.api_key)

    def search(self, query: str, **kwargs) -> ListArtifact:
        response = self.client.search(query, max_results=self.results_count, **self.params, **kwargs)
        results = response.get("results", [])
        return ListArtifact([(JsonArtifact(result)) for result in results])
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute

api_key = field(kw_only=True)
class-attribute instance-attribute

params = field(factory=dict, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
client()
Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
@lazy_property()
def client(self) -> TavilyClient:
    return import_optional_dependency("tavily").TavilyClient(self.api_key)
search(query, **kwargs)
Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    response = self.client.search(query, max_results=self.results_count, **self.params, **kwargs)
    results = response.get("results", [])
    return ListArtifact([(JsonArtifact(result)) for result in results])
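Example (a sketch): params is forwarded verbatim to client.search, so any option accepted by the Tavily client can be set there; search_depth below is an assumed example of such an option.

import os

from griptape.drivers import TavilyWebSearchDriver

driver = TavilyWebSearchDriver(
    api_key=os.environ["TAVILY_API_KEY"],
    params={"search_depth": "advanced"},  # assumed Tavily option, forwarded as-is
)

results = driver.search("open source agent frameworks")
for artifact in results.value:  # one JsonArtifact per search hit
    print(artifact.value)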
TrafilaturaWebScraperDriver
Bases:
BaseWebScraperDriver
Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
@define
class TrafilaturaWebScraperDriver(BaseWebScraperDriver):
    include_links: bool = field(default=True, kw_only=True)
    no_ssl: bool = field(default=False, kw_only=True)

    def fetch_url(self, url: str) -> str:
        trafilatura = import_optional_dependency("trafilatura")
        use_config = trafilatura.settings.use_config

        config = use_config()

        page = trafilatura.fetch_url(url, no_ssl=self.no_ssl)

        # This disables signal, so that trafilatura can work on any thread:
        # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
        config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

        # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
        # the end result of page parsing is successful.
        logging.getLogger("trafilatura").setLevel(logging.FATAL)

        if page is None:
            raise Exception("can't access URL")

        return page

    def extract_page(self, page: str) -> TextArtifact:
        trafilatura = import_optional_dependency("trafilatura")
        use_config = trafilatura.settings.use_config

        config = use_config()

        extracted_page = trafilatura.extract(
            page,
            include_links=self.include_links,
            output_format="json",
            config=config,
        )

        if not extracted_page:
            raise Exception("can't extract page")

        text = json.loads(extracted_page).get("text")

        return TextArtifact(text)
include_links = field(default=True, kw_only=True)
class-attribute instance-attribute

no_ssl = field(default=False, kw_only=True)
class-attribute instance-attribute
extract_page(page)
Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact:
    trafilatura = import_optional_dependency("trafilatura")
    use_config = trafilatura.settings.use_config

    config = use_config()

    extracted_page = trafilatura.extract(
        page,
        include_links=self.include_links,
        output_format="json",
        config=config,
    )

    if not extracted_page:
        raise Exception("can't extract page")

    text = json.loads(extracted_page).get("text")

    return TextArtifact(text)
fetch_url(url)
Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
def fetch_url(self, url: str) -> str:
    trafilatura = import_optional_dependency("trafilatura")
    use_config = trafilatura.settings.use_config

    config = use_config()

    page = trafilatura.fetch_url(url, no_ssl=self.no_ssl)

    # This disables signal, so that trafilatura can work on any thread:
    # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
    # the end result of page parsing is successful.
    logging.getLogger("trafilatura").setLevel(logging.FATAL)

    if page is None:
        raise Exception("can't access URL")

    return page
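Example (a sketch): fetching and extracting in two explicit steps, mirroring the methods above.

from griptape.drivers import TrafilaturaWebScraperDriver

driver = TrafilaturaWebScraperDriver(include_links=False)

page = driver.fetch_url("https://example.com")  # raises if the URL is unreachable
artifact = driver.extract_page(page)            # TextArtifact with the page's main text
print(artifact.value)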
VoyageAiEmbeddingDriver
Bases:
BaseEmbeddingDriver
Attributes
Name | Type | Description |
---|---|---|
model | str | VoyageAI embedding model name. Defaults to voyage-large-2 . |
api_key | Optional[str] | API key to pass directly. Defaults to VOYAGE_API_KEY environment variable. |
tokenizer | VoyageAiTokenizer | Optionally provide custom VoyageAiTokenizer . |
client | Any | Optionally provide custom VoyageAI Client . |
input_type | str | VoyageAI input type. Defaults to document . |
Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
@define
class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
    """VoyageAI Embedding Driver.

    Attributes:
        model: VoyageAI embedding model name. Defaults to `voyage-large-2`.
        api_key: API key to pass directly. Defaults to `VOYAGE_API_KEY` environment variable.
        tokenizer: Optionally provide custom `VoyageAiTokenizer`.
        client: Optionally provide custom VoyageAI `Client`.
        input_type: VoyageAI input type. Defaults to `document`.
    """

    DEFAULT_MODEL = "voyage-large-2"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    tokenizer: VoyageAiTokenizer = field(
        default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True),
        kw_only=True,
    )
    input_type: str = field(default="document", kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Any:
        return import_optional_dependency("voyageai").Client(api_key=self.api_key)

    def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value, **kwargs)

        pil_image = import_optional_dependency("PIL.Image")
        return self.client.multimodal_embed([[pil_image.open(BytesIO(artifact.value))]], model=self.model).embeddings[0]

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
DEFAULT_MODEL = 'voyage-large-2'
class-attribute instance-attribute

_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute

api_key = field(default=None, kw_only=True, metadata={'serializable': False})
class-attribute instance-attribute

input_type = field(default='document', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
@lazy_property()
def client(self) -> Any:
    return import_optional_dependency("voyageai").Client(api_key=self.api_key)
try_embed_artifact(artifact, **kwargs)
Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
    if isinstance(artifact, TextArtifact):
        return self.try_embed_chunk(artifact.value, **kwargs)

    pil_image = import_optional_dependency("PIL.Image")
    return self.client.multimodal_embed([[pil_image.open(BytesIO(artifact.value))]], model=self.model).embeddings[0]
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    return self.client.embed([chunk], model=self.model, input_type=self.input_type).embeddings[0]
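Example (a sketch): embedding a query string. Per the docstring above, api_key may be omitted to fall back to the VOYAGE_API_KEY environment variable.

import os

from griptape.artifacts import TextArtifact
from griptape.drivers import VoyageAiEmbeddingDriver

driver = VoyageAiEmbeddingDriver(
    api_key=os.environ.get("VOYAGE_API_KEY"),
    input_type="query",  # use the default "document" when embedding corpus text
)

embedding = driver.try_embed_artifact(TextArtifact("What is a vector database?"))
print(len(embedding))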
WebhookEventListenerDriver
Bases:
BaseEventListenerDriver
Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
@define
class WebhookEventListenerDriver(BaseEventListenerDriver):
    webhook_url: str = field(kw_only=True)
    headers: Optional[dict] = field(default=None, kw_only=True)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
        response.raise_for_status()

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers)
        response.raise_for_status()
headers = field(default=None, kw_only=True)
class-attribute instance-attribute

webhook_url = field(kw_only=True)
class-attribute instance-attribute
try_publish_event_payload(event_payload)
Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    response = requests.post(url=self.webhook_url, json=event_payload, headers=self.headers)
    response.raise_for_status()
try_publish_event_payload_batch(event_payload_batch)
Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    response = requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers)
    response.raise_for_status()
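Example (a sketch): in practice this driver is usually attached to an event listener, but the publish methods can also be called directly. The endpoint and token below are placeholders.

from griptape.drivers import WebhookEventListenerDriver

driver = WebhookEventListenerDriver(
    webhook_url="https://hooks.example.com/griptape",  # placeholder endpoint
    headers={"Authorization": "Bearer <token>"},       # placeholder auth header
)

driver.try_publish_event_payload({"event": "run_finished"})
driver.try_publish_event_payload_batch([
    {"event": "task_started"},
    {"event": "task_finished"},
])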