• __all__ = ['AmazonBedrockCohereEmbeddingDriver', 'AmazonBedrockImageGenerationDriver', 'AmazonBedrockPromptDriver', 'AmazonBedrockTitanEmbeddingDriver', 'AmazonDynamoDbConversationMemoryDriver', 'AmazonOpenSearchVectorStoreDriver', 'AmazonRedshiftSqlDriver', 'AmazonS3FileManagerDriver', 'AmazonSageMakerJumpstartEmbeddingDriver', 'AmazonSageMakerJumpstartPromptDriver', 'AmazonSqsEventListenerDriver', 'AnthropicPromptDriver', 'AstraDbVectorStoreDriver', 'AwsIotCoreEventListenerDriver', 'AzureMongoDbVectorStoreDriver', 'AzureOpenAiChatPromptDriver', 'AzureOpenAiEmbeddingDriver', 'AzureOpenAiImageGenerationDriver', 'AzureOpenAiTextToSpeechDriver', 'BaseAssistantDriver', 'BaseAudioTranscriptionDriver', 'BaseConversationMemoryDriver', 'BaseDiffusionImageGenerationPipelineDriver', 'BaseEmbeddingDriver', 'BaseEventListenerDriver', 'BaseFileManagerDriver', 'BaseImageGenerationDriver', 'BaseImageGenerationModelDriver', 'BaseMultiModelImageGenerationDriver', 'BaseObservabilityDriver', 'BasePromptDriver', 'BaseRerankDriver', 'BaseRulesetDriver', 'BaseSqlDriver', 'BaseStructureRunDriver', 'BaseTextToSpeechDriver', 'BaseVectorStoreDriver', 'BaseWebScraperDriver', 'BaseWebSearchDriver', 'BedrockStableDiffusionImageGenerationModelDriver', 'BedrockTitanImageGenerationModelDriver', 'CohereEmbeddingDriver', 'CoherePromptDriver', 'CohereRerankDriver', 'DatadogObservabilityDriver', 'DuckDuckGoWebSearchDriver', 'DummyAudioTranscriptionDriver', 'DummyEmbeddingDriver', 'DummyImageGenerationDriver', 'DummyPromptDriver', 'DummyTextToSpeechDriver', 'DummyVectorStoreDriver', 'ElevenLabsTextToSpeechDriver', 'ExaWebSearchDriver', 'GoogleEmbeddingDriver', 'GooglePromptDriver', 'GoogleWebSearchDriver', 'GriptapeCloudAssistantDriver', 'GriptapeCloudConversationMemoryDriver', 'GriptapeCloudEventListenerDriver', 'GriptapeCloudFileManagerDriver', 'GriptapeCloudObservabilityDriver', 'GriptapeCloudPromptDriver', 'GriptapeCloudRulesetDriver', 'GriptapeCloudStructureRunDriver', 
'GriptapeCloudVectorStoreDriver', 'GrokPromptDriver', 'HuggingFaceHubEmbeddingDriver', 'HuggingFaceHubPromptDriver', 'HuggingFacePipelineImageGenerationDriver', 'HuggingFacePipelinePromptDriver', 'LeonardoImageGenerationDriver', 'LocalConversationMemoryDriver', 'LocalFileManagerDriver', 'LocalRerankDriver', 'LocalRulesetDriver', 'LocalStructureRunDriver', 'LocalVectorStoreDriver', 'MarkdownifyWebScraperDriver', 'MarqoVectorStoreDriver', 'MongoDbAtlasVectorStoreDriver', 'NoOpObservabilityDriver', 'OllamaEmbeddingDriver', 'OllamaPromptDriver', 'OpenAiAssistantDriver', 'OpenAiAudioTranscriptionDriver', 'OpenAiChatPromptDriver', 'OpenAiEmbeddingDriver', 'OpenAiImageGenerationDriver', 'OpenAiTextToSpeechDriver', 'OpenSearchVectorStoreDriver', 'OpenTelemetryObservabilityDriver', 'PerplexityPromptDriver', 'PerplexityWebSearchDriver', 'PgAiKnowledgeBaseVectorStoreDriver', 'PgVectorVectorStoreDriver', 'PineconeVectorStoreDriver', 'ProxyWebScraperDriver', 'PusherEventListenerDriver', 'QdrantVectorStoreDriver', 'RedisConversationMemoryDriver', 'RedisVectorStoreDriver', 'SnowflakeSqlDriver', 'SqlDriver', 'StableDiffusion3ControlNetImageGenerationPipelineDriver', 'StableDiffusion3ImageGenerationPipelineDriver', 'StableDiffusion3Img2ImgImageGenerationPipelineDriver', 'StableDiffusion3ControlNetImageGenerationPipelineDriver', 'BaseImageGenerationDriver', 'BaseMultiModelImageGenerationDriver', 'OpenAiImageGenerationDriver', 'LeonardoImageGenerationDriver', 'AmazonBedrockImageGenerationDriver', 'AzureOpenAiImageGenerationDriver', 'DummyImageGenerationDriver', 'HuggingFacePipelineImageGenerationDriver', 'GriptapeCloudImageGenerationDriver', 'BaseWebScraperDriver', 'TrafilaturaWebScraperDriver', 'MarkdownifyWebScraperDriver', 'ProxyWebScraperDriver', 'BaseWebSearchDriver', 'GoogleWebSearchDriver', 'DuckDuckGoWebSearchDriver', 'ExaWebSearchDriver', 'TavilyWebSearchDriver', 'TrafilaturaWebScraperDriver', 'VoyageAiEmbeddingDriver', 'WebhookEventListenerDriver'] module-attribute

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | Embedding model name. Defaults to `DEFAULT_MODEL`. |
| `input_type` | `str` | Defaults to `search_query`. Prepends special tokens to differentiate each type from one another: `search_document` when you encode documents for embeddings that you store in a vector database; `search_query` when querying your vector DB to find relevant documents. |
| `session` | `Session` | Optionally provide a custom `boto3.Session`. |
| `tokenizer` | `BaseTokenizer` | Optionally provide a custom tokenizer. Defaults to `AmazonBedrockTokenizer`. |
| `client` | `BedrockRuntimeClient` | Optionally provide a custom `bedrock-runtime` client. |
Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@define
class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
    """Amazon Bedrock Cohere Embedding Driver.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        input_type: Defaults to `search_query`. Prepends special tokens to differentiate each type from one another:
            `search_document` when you encode documents for embeddings that you store in a vector database.
            `search_query` when querying your vector DB to find relevant documents.
        session: Optionally provide custom `boto3.Session`.
        tokenizer: Optionally provide a custom tokenizer. Defaults to an `AmazonBedrockTokenizer` for `model`.
        client: Optionally provide custom `bedrock-runtime` client. Created from `session` when omitted.
    """

    DEFAULT_MODEL = "cohere.embed-english-v3"

    model: str = field(default=DEFAULT_MODEL, kw_only=True)
    input_type: str = field(default="search_query", kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        """Create a `bedrock-runtime` client from `session`."""
        return self.session.client("bedrock-runtime")

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk with the configured Cohere model.

        Args:
            chunk: Text to embed.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The embedding vector for `chunk`.

        Raises:
            ValueError: If the response body contains no embeddings.
        """
        payload = {"input_type": self.input_type, "texts": [chunk]}

        response = self.client.invoke_model(
            body=json.dumps(payload),
            modelId=self.model,
            accept="*/*",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())

        embeddings = response_body.get("embeddings")
        # Exactly one text is sent, so one embedding is expected back. Fail with a clear
        # message instead of an opaque TypeError (missing key) or IndexError (empty list).
        if not embeddings:
            raise ValueError(f"Model '{self.model}' returned no embeddings.")

        return embeddings[0]
  • DEFAULT_MODEL = 'cohere.embed-english-v3' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • input_type = field(default='search_query', kw_only=True) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    """Build the `bedrock-runtime` client from this driver's boto3 session (wrapped by `lazy_property`)."""
    boto_session = self.session
    return boto_session.client("bedrock-runtime")

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed a single text chunk with the configured Cohere model.

    Args:
        chunk: Text to embed.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The embedding vector for `chunk`.

    Raises:
        ValueError: If the response body contains no embeddings.
    """
    payload = {"input_type": self.input_type, "texts": [chunk]}

    response = self.client.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        accept="*/*",
        contentType="application/json",
    )
    response_body = json.loads(response.get("body").read())

    embeddings = response_body.get("embeddings")
    # Exactly one text is sent, so one embedding is expected back. Fail with a clear
    # message instead of an opaque TypeError (missing key) or IndexError (empty list).
    if not embeddings:
        raise ValueError(f"Model '{self.model}' returned no embeddings.")

    return embeddings[0]

AmazonBedrockImageGenerationDriver

Bases: BaseMultiModelImageGenerationDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | Bedrock model ID. |
| `session` | `Session` | boto3 session. |
| `client` | `BedrockRuntimeClient` | Bedrock runtime client. |
| `image_width` | `int` | Width of output images. Defaults to 512 and must be a multiple of 64. |
| `image_height` | `int` | Height of output images. Defaults to 512 and must be a multiple of 64. |
| `seed` | `Optional[int]` | Optionally provide a consistent seed to generation requests, increasing consistency in output. |
Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@define
class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
    """Driver for image generation models provided by Amazon Bedrock.

    Request bodies are built, and generated images extracted, by `image_generation_model_driver`;
    this driver sends the requests through the Bedrock runtime `invoke_model` API.

    Attributes:
        model: Bedrock model ID.
        session: boto3 session.
        client: Bedrock runtime client. Created from `session` when not provided.
        image_width: Width of output images. Defaults to 512 and must be a multiple of 64.
        image_height: Height of output images. Defaults to 512 and must be a multiple of 64.
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        """Create a `bedrock-runtime` client from `session`."""
        return self.session.client("bedrock-runtime")

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from text prompts.

        Args:
            prompts: Prompts describing the desired image; also joined into the artifact's `meta`.
            negative_prompts: Optional prompts describing what to avoid.

        Returns:
            An `ImageArtifact` (format `png`) sized `image_width` x `image_height`.
        """
        request = self.image_generation_model_driver.text_to_image_request_parameters(
            prompts,
            self.image_width,
            self.image_height,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=self.image_width,
            height=self.image_height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Generate a variation of `image` guided by `prompts`.

        Returns:
            An `ImageArtifact` (format `png`) with the same dimensions as `image`.
        """
        request = self.image_generation_model_driver.image_variation_request_parameters(
            prompts,
            image=image,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpaint `image` using `mask`, guided by `prompts`.

        Returns:
            An `ImageArtifact` (format `png`) with the same dimensions as `image`.
        """
        request = self.image_generation_model_driver.image_inpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpaint `image` using `mask`, guided by `prompts`.

        Returns:
            An `ImageArtifact` (format `png`) with the same dimensions as `image`.
        """
        request = self.image_generation_model_driver.image_outpainting_request_parameters(
            prompts,
            image=image,
            mask=mask,
            negative_prompts=negative_prompts,
            seed=self.seed,
        )

        image_bytes = self._make_request(request)

        return ImageArtifact(
            value=image_bytes,
            format="png",
            width=image.width,
            height=image.height,
            meta={"prompt": ", ".join(prompts), "model": self.model},
        )

    def _make_request(self, request: dict) -> bytes:
        """Invoke the Bedrock model with `request` and extract the generated image bytes.

        Args:
            request: Model-specific request body built by `image_generation_model_driver`.

        Returns:
            The generated image bytes, as extracted by `image_generation_model_driver`.

        Raises:
            ValueError: If the model driver cannot extract an image from the response.
        """
        response = self.client.invoke_model(
            body=json.dumps(request),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )

        response_body = json.loads(response.get("body").read())

        try:
            image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
        except Exception as e:
            # Every operation (text-to-image, variation, inpainting, outpainting) goes through
            # this helper, so the message must not name a single one of them.
            raise ValueError(f"Image generation failed: {e}") from e

        return image_bytes
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • image_height = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • image_width = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • seed = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

_make_request(request)

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def _make_request(self, request: dict) -> bytes:
    response = self.client.invoke_model(
        body=json.dumps(request),
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )

    response_body = json.loads(response.get("body").read())

    try:
        image_bytes = self.image_generation_model_driver.get_generated_image(response_body)
    except Exception as e:
        raise ValueError(f"Inpainting generation failed: {e}") from e

    return image_bytes

client()

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    """Return a Bedrock runtime client derived from `self.session` (wrapped by `lazy_property`)."""
    service_name = "bedrock-runtime"
    return self.session.client(service_name)

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Inpaint the masked region of `image`, guided by `prompts`.

    Args:
        prompts: Prompts guiding the fill; also joined into the artifact's `meta`.
        image: Source image; its width and height are reused for the result.
        mask: Mask image passed to the model driver.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An `ImageArtifact` (format `png`) with the same dimensions as `image`.
    """
    model_driver = self.image_generation_model_driver
    request_params = model_driver.image_inpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(request_params)
    metadata = {"prompt": ", ".join(prompts), "model": self.model}

    return ImageArtifact(
        value=generated,
        format="png",
        width=image.width,
        height=image.height,
        meta=metadata,
    )

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Extend `image` beyond its masked region, guided by `prompts`.

    Args:
        prompts: Prompts guiding the generation; also joined into the artifact's `meta`.
        image: Source image; its width and height are reused for the result.
        mask: Mask image passed to the model driver.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An `ImageArtifact` (format `png`) with the same dimensions as `image`.
    """
    model_driver = self.image_generation_model_driver
    request_params = model_driver.image_outpainting_request_parameters(
        prompts,
        image=image,
        mask=mask,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(request_params)
    metadata = {"prompt": ", ".join(prompts), "model": self.model}

    return ImageArtifact(
        value=generated,
        format="png",
        width=image.width,
        height=image.height,
        meta=metadata,
    )

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Produce a variation of `image`, guided by `prompts`.

    Args:
        prompts: Prompts guiding the variation; also joined into the artifact's `meta`.
        image: Source image; its width and height are reused for the result.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An `ImageArtifact` (format `png`) with the same dimensions as `image`.
    """
    model_driver = self.image_generation_model_driver
    request_params = model_driver.image_variation_request_parameters(
        prompts,
        image=image,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(request_params)
    metadata = {"prompt": ", ".join(prompts), "model": self.model}

    return ImageArtifact(
        value=generated,
        format="png",
        width=image.width,
        height=image.height,
        meta=metadata,
    )

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Create an image from text prompts at the driver's configured size.

    Args:
        prompts: Prompts describing the image; also joined into the artifact's `meta`.
        negative_prompts: Optional prompts describing what to avoid.

    Returns:
        An `ImageArtifact` (format `png`) sized `image_width` x `image_height`.
    """
    model_driver = self.image_generation_model_driver
    request_params = model_driver.text_to_image_request_parameters(
        prompts,
        self.image_width,
        self.image_height,
        negative_prompts=negative_prompts,
        seed=self.seed,
    )
    generated = self._make_request(request_params)
    metadata = {"prompt": ", ".join(prompts), "model": self.model}

    return ImageArtifact(
        value=generated,
        format="png",
        width=self.image_width,
        height=self.image_height,
        meta=metadata,
    )

AmazonBedrockPromptDriver

Bases: BasePromptDriver

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@define
class AmazonBedrockPromptDriver(BasePromptDriver):
    """Prompt Driver for Amazon Bedrock's Converse API.

    Requests go through a `bedrock-runtime` client, via `converse` for blocking runs and
    `converse_stream` for streaming runs.

    Attributes:
        session: boto3 session used to create the `bedrock-runtime` client.
        additional_model_request_fields: Sent as `additionalModelRequestFields` on every request.
        tokenizer: Defaults to an `AmazonBedrockTokenizer` for `model`.
        use_native_tools: When True and the prompt stack has tools, they are sent in `toolConfig`.
        structured_output_strategy: Defaults to `tool`; the `native` strategy is rejected.
        tool_choice: Sent as `toolConfig.toolChoice`. It is replaced by `{"any": {}}` when an output
            schema is requested with the `tool` strategy.
        client: Optionally provide a custom `bedrock-runtime` client; otherwise one is created from `session`.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    additional_model_request_fields: dict = field(default=Factory(dict), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: dict = field(default=Factory(lambda: {"auto": {}}), kw_only=True, metadata={"serializable": True})
    # Not serialized; exposed through the `client` property below.
    _client: Optional[Any] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `native` structured output strategy, which this driver does not support.

        Raises:
            ValueError: If `value` is `"native"`.
        """
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def client(self) -> Any:
        """Create a `bedrock-runtime` client from `session`."""
        return self.session.client("bedrock-runtime")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send `prompt_stack` with `converse` and convert the reply into an assistant `Message`.

        Token usage is read from the response's `usage` (`inputTokens` / `outputTokens`).
        """
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.converse(**params)
        logger.debug(response)

        usage = response["usage"]
        output_message = response["output"]["message"]

        message_content = output_message.get("content", [])

        # Move reasoning content to the beginning of the content to have it appear first in output
        # (only the first reasoning block found is moved; the list is reordered in place).
        reasoning_content_index = next(
            (i for i, content in enumerate(message_content) if "reasoningContent" in content),
            None,
        )
        if reasoning_content_index is not None:
            message_content.insert(0, message_content.pop(reasoning_content_index))

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in message_content],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send `prompt_stack` with `converse_stream` and yield `DeltaMessage`s as events arrive.

        Content block start/delta events yield content deltas and `metadata` events yield token
        usage. Other event types are skipped.

        Raises:
            Exception: If the response contains no `stream`.
        """
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.converse_stream(**params)

        stream = response.get("stream")
        if stream is not None:
            for event in stream:
                logger.debug(event)
                if "contentBlockDelta" in event or "contentBlockStart" in event:
                    yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
                elif "metadata" in event:
                    usage = event["metadata"]["usage"]
                    yield DeltaMessage(
                        usage=DeltaMessage.Usage(
                            input_tokens=usage["inputTokens"],
                            output_tokens=usage["outputTokens"],
                        ),
                    )
        else:
            raise Exception("model response is empty")

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments shared by `converse` and `converse_stream`.

        System messages are sent as `system` text blocks; all other messages go to `messages`.
        """
        system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
        messages = self.__to_bedrock_messages([message for message in prompt_stack.messages if not message.is_system()])

        params = {
            "modelId": self.model,
            "messages": messages,
            "system": system_messages,
            "inferenceConfig": {
                "temperature": self.temperature,
                # `maxTokens` is omitted entirely when no limit is configured.
                **({"maxTokens": self.max_tokens} if self.max_tokens is not None else {}),
            },
            "additionalModelRequestFields": self.additional_model_request_fields,
            # Spread last in the literal, so extra_params can override any key above
            # (but not `toolConfig`, which is assigned afterwards).
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["toolConfig"] = {
                "tools": [],
                "toolChoice": self.tool_choice,
            }

            # Structured output via the `tool` strategy requires the model to call some tool.
            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["toolConfig"]["toolChoice"] = {"any": {}}

            params["toolConfig"]["tools"] = self.__to_bedrock_tools(prompt_stack.tools)

        return params

    def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
        """Convert prompt-stack messages into Converse `{"role", "content"}` dicts, preserving order."""
        return [
            {
                "role": self.__to_bedrock_role(message),
                "content": [self.__to_bedrock_message_content(content) for content in message.content],
            }
            for message in messages
        ]

    def __to_bedrock_role(self, message: Message) -> str:
        """Return `assistant` for assistant messages and `user` for everything else."""
        if message.is_assistant():
            return "assistant"
        return "user"

    def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one Converse `toolSpec` per (tool, activity) pair, with draft-07 JSON input schemas."""
        return [
            {
                "toolSpec": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "inputSchema": {
                        "json": tool.to_activity_json_schema(activity, "http://json-schema.org/draft-07/schema#"),
                    },
                },
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
        """Convert one prompt-stack content item into a Converse content block.

        Raises:
            ValueError: If an image content item holds an unsupported artifact type.
        """
        if isinstance(content, TextMessageContent):
            return {"text": content.artifact.to_text()}
        if isinstance(content, ImageMessageContent):
            artifact = content.artifact
            if isinstance(artifact, ImageArtifact):
                return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
            if isinstance(artifact, ImageUrlArtifact):
                # NOTE(review): format is hard-coded to "png" whatever the S3 object actually is — confirm
                # non-PNG S3 images are handled correctly by Bedrock.
                return {"image": {"format": "png", "source": {"s3Location": {"uri": artifact.value}}}}
            raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action_call = content.artifact.value

            return {
                "toolUse": {
                    "toolUseId": action_call.tag,
                    # NOTE(review): built by hand rather than via the helper used for tool specs; confirm it
                    # matches `to_native_tool_name` output (e.g. when `path` is None).
                    "name": f"{action_call.name}_{action_call.path}",
                    "input": action_call.input,
                },
            }
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            # A ListArtifact result is flattened into one tool-result block per item.
            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_bedrock_tool_use_content(artifact) for artifact in artifact.value]
            else:
                message_content = [self.__to_bedrock_tool_use_content(artifact)]

            return {
                "toolResult": {
                    "toolUseId": content.action.tag,
                    "content": message_content,
                    "status": "error" if isinstance(artifact, ErrorArtifact) else "success",
                },
            }
        # NOTE(review): fallback returns the raw artifact value, which is not a dict as annotated and is
        # presumably rejected by the API — confirm intended.
        return content.artifact.value

    def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
        """Convert a tool-result artifact into an `image` block (for ImageArtifact) or a `text` block."""
        if isinstance(artifact, ImageArtifact):
            return {"image": {"format": artifact.format, "source": {"bytes": artifact.value}}}
        return {"text": artifact.to_text()}

    def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
        """Convert one Converse response content block into prompt-stack content.

        `text` and `reasoningContent` blocks become text content; `toolUse` blocks become action calls.

        Raises:
            ValueError: For any other block type.
        """
        if "text" in content:
            return TextMessageContent(TextArtifact(content["text"]))
        if "toolUse" in content:
            name, path = ToolAction.from_native_tool_name(content["toolUse"]["name"])
            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(
                        tag=content["toolUse"]["toolUseId"],
                        name=name,
                        path=path,
                        input=content["toolUse"]["input"],
                    ),
                ),
            )
        if "reasoningContent" in content:
            # NOTE(review): assumes the block carries `reasoningText`; a redacted reasoning block would
            # raise KeyError here — confirm against the Converse API.
            return TextMessageContent(TextArtifact(content["reasoningContent"]["reasoningText"]["text"]))
        raise ValueError(f"Unsupported message content type: {content}")

    def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
        """Convert a streaming `contentBlockStart` / `contentBlockDelta` event into delta content.

        Tool-use starts carry the tool's tag and parsed name/path; tool-use deltas carry partial input.

        Raises:
            ValueError: For any unsupported event or block type.
        """
        if "contentBlockStart" in event:
            content_block = event["contentBlockStart"]["start"]

            if "toolUse" in content_block:
                name, path = ToolAction.from_native_tool_name(content_block["toolUse"]["name"])

                return ActionCallDeltaMessageContent(
                    index=event["contentBlockStart"]["contentBlockIndex"],
                    tag=content_block["toolUse"]["toolUseId"],
                    name=name,
                    path=path,
                )
            if "text" in content_block:
                return TextDeltaMessageContent(
                    content_block["text"],
                    index=event["contentBlockStart"]["contentBlockIndex"],
                )
            raise ValueError(f"Unsupported message content type: {event}")
        if "contentBlockDelta" in event:
            content_block_delta = event["contentBlockDelta"]

            if "text" in content_block_delta["delta"]:
                return TextDeltaMessageContent(
                    content_block_delta["delta"]["text"],
                    index=content_block_delta["contentBlockIndex"],
                )
            if "toolUse" in content_block_delta["delta"]:
                return ActionCallDeltaMessageContent(
                    index=content_block_delta["contentBlockIndex"],
                    partial_input=content_block_delta["delta"]["toolUse"]["input"],
                )
            raise ValueError(f"Unsupported message content type: {event}")
        raise ValueError(f"Unsupported message content type: {event}")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • additional_model_request_fields = field(default=Factory(dict), kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

  • tool_choice = field(default=Factory(lambda: {'auto': {}}), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_bedrock_message_content(content)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_message_content(self, content: BaseMessageContent) -> dict:
    """Translate one prompt-stack content item into a Bedrock Converse content block.

    Text becomes a `text` block, images become `image` blocks, action calls become
    `toolUse` blocks and action results become `toolResult` blocks. Any other content
    type falls through to its raw artifact value.

    Raises:
        ValueError: If an image content item holds an unsupported artifact type.
    """
    if isinstance(content, TextMessageContent):
        return {"text": content.artifact.to_text()}

    if isinstance(content, ImageMessageContent):
        picture = content.artifact
        if isinstance(picture, ImageArtifact):
            inline_source = {"bytes": picture.value}
            return {"image": {"format": picture.format, "source": inline_source}}
        if isinstance(picture, ImageUrlArtifact):
            s3_source = {"s3Location": {"uri": picture.value}}
            return {"image": {"format": "png", "source": s3_source}}
        raise ValueError(f"Unsupported image artifact type: {type(picture)}")

    if isinstance(content, ActionCallMessageContent):
        call = content.artifact.value
        tool_use = {
            "toolUseId": call.tag,
            "name": f"{call.name}_{call.path}",
            "input": call.input,
        }
        return {"toolUse": tool_use}

    if isinstance(content, ActionResultMessageContent):
        result = content.artifact
        items = result.value if isinstance(result, ListArtifact) else [result]
        blocks = [self.__to_bedrock_tool_use_content(item) for item in items]
        outcome = "error" if isinstance(result, ErrorArtifact) else "success"
        return {
            "toolResult": {
                "toolUseId": content.action.tag,
                "content": blocks,
                "status": outcome,
            },
        }

    return content.artifact.value

__to_bedrock_messages(messages)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_messages(self, messages: list[Message]) -> list[dict]:
    """Turn prompt-stack messages into Converse `{"role": ..., "content": [...]}` dicts, in order."""
    bedrock_messages = []
    for message in messages:
        role = self.__to_bedrock_role(message)
        blocks = [self.__to_bedrock_message_content(item) for item in message.content]
        bedrock_messages.append({"role": role, "content": blocks})
    return bedrock_messages

__to_bedrock_role(message)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_role(self, message: Message) -> str:
    """Map a message to its Converse role: `assistant` for assistant messages, otherwise `user`."""
    return "assistant" if message.is_assistant() else "user"

__to_bedrock_tool_use_content(artifact)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_tool_use_content(self, artifact: BaseArtifact) -> dict:
    """Encode a tool-result artifact as a Converse block: inline image bytes for images, text otherwise."""
    if not isinstance(artifact, ImageArtifact):
        return {"text": artifact.to_text()}
    image_format = artifact.format
    image_source = {"bytes": artifact.value}
    return {"image": {"format": image_format, "source": image_source}}

__to_bedrock_tools(tools)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_bedrock_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build Bedrock ``toolSpec`` entries, one for every activity of every tool.

    Each spec carries the tool's native name for the activity, the activity
    description, and the activity's input JSON schema (draft-07).
    """
    tool_specs: list[dict] = []
    for tool in tools:
        for activity in tool.activities():
            spec = {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "inputSchema": {
                    "json": tool.to_activity_json_schema(activity, "http://json-schema.org/draft-07/schema#"),
                },
            }
            tool_specs.append({"toolSpec": spec})
    return tool_specs

__to_prompt_stack_delta_message_content(event)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: dict) -> BaseDeltaMessageContent:
    """Translate one ``converse_stream`` event into a delta message content.

    Two event kinds are handled:
      * ``contentBlockStart`` opens either a tool-use block (tag, name, path) or a text block.
      * ``contentBlockDelta`` carries incremental text or partial tool-use input.

    Raises:
        ValueError: If the event, or its start/delta payload, is of an unsupported kind.
    """
    if "contentBlockStart" in event:
        block_start = event["contentBlockStart"]
        start = block_start["start"]
        if "toolUse" in start:
            tool_use = start["toolUse"]
            name, path = ToolAction.from_native_tool_name(tool_use["name"])
            return ActionCallDeltaMessageContent(
                index=block_start["contentBlockIndex"],
                tag=tool_use["toolUseId"],
                name=name,
                path=path,
            )
        if "text" in start:
            return TextDeltaMessageContent(start["text"], index=block_start["contentBlockIndex"])
    elif "contentBlockDelta" in event:
        block_delta = event["contentBlockDelta"]
        delta = block_delta["delta"]
        if "text" in delta:
            return TextDeltaMessageContent(delta["text"], index=block_delta["contentBlockIndex"])
        if "toolUse" in delta:
            return ActionCallDeltaMessageContent(
                index=block_delta["contentBlockIndex"],
                partial_input=delta["toolUse"]["input"],
            )
    # Every unhandled shape (unknown event, start, or delta kind) ends up here.
    raise ValueError(f"Unsupported message content type: {event}")

__to_prompt_stack_message_content(content)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def __to_prompt_stack_message_content(self, content: dict) -> BaseMessageContent:
    """Translate one ``converse`` content block into a prompt-stack message content.

    ``text`` becomes text content, ``toolUse`` becomes an action-call content wrapping
    a ``ToolAction``, and ``reasoningContent`` becomes text content holding the
    reasoning text.

    Raises:
        ValueError: If the content block is of an unsupported kind.
    """
    if "text" in content:
        return TextMessageContent(TextArtifact(content["text"]))

    if "toolUse" in content:
        tool_use = content["toolUse"]
        name, path = ToolAction.from_native_tool_name(tool_use["name"])
        action = ToolAction(tag=tool_use["toolUseId"], name=name, path=path, input=tool_use["input"])
        return ActionCallMessageContent(artifact=ActionArtifact(value=action))

    if "reasoningContent" in content:
        # NOTE(review): Bedrock may return a reasoning block that only carries
        # `redactedContent`; this lookup would raise KeyError for it — confirm.
        reasoning_text = content["reasoningContent"]["reasoningText"]["text"]
        return TextMessageContent(TextArtifact(reasoning_text))

    raise ValueError(f"Unsupported message content type: {content}")

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Assemble the keyword arguments shared by ``converse`` and ``converse_stream``.

    - System messages go to ``system``.
    - Every non-system message is converted to Bedrock format.
    - ``maxTokens`` is sent only when ``max_tokens`` is not ``None``.
    - ``extra_params`` are merged last, so they may override any key.
    - When the stack has tools and native tools are enabled, a ``toolConfig`` is
      attached. If the ``"tool"`` structured-output strategy is combined with an
      output schema, tool choice is forced to ``{"any": {}}``.
    """
    system_messages = [{"text": message.to_text()} for message in prompt_stack.system_messages]
    conversation = [message for message in prompt_stack.messages if not message.is_system()]
    messages = self.__to_bedrock_messages(conversation)

    inference_config: dict = {"temperature": self.temperature}
    if self.max_tokens is not None:
        inference_config["maxTokens"] = self.max_tokens

    params = {
        "modelId": self.model,
        "messages": messages,
        "system": system_messages,
        "inferenceConfig": inference_config,
        "additionalModelRequestFields": self.additional_model_request_fields,
        **self.extra_params,
    }

    if prompt_stack.tools and self.use_native_tools:
        forces_tool_use = prompt_stack.output_schema is not None and self.structured_output_strategy == "tool"
        params["toolConfig"] = {
            "tools": self.__to_bedrock_tools(prompt_stack.tools),
            "toolChoice": {"any": {}} if forces_tool_use else self.tool_choice,
        }

    return params

client()

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@lazy_property()
def client(self) -> Any:
    """Create a ``bedrock-runtime`` client from ``self.session``."""
    service_name = "bedrock-runtime"
    return self.session.client(service_name)

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Perform one non-streaming ``converse`` call and wrap the reply in a ``Message``.

    The first ``reasoningContent`` block in the reply, if any, is moved to the
    front so the reasoning is emitted before the answer. Token usage is copied
    from the response's ``usage`` field.
    """
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.converse(**params)
    logger.debug(response)

    usage = response["usage"]
    blocks = response["output"]["message"].get("content", [])

    # Relocate the first reasoning block to position 0; stop right after the move.
    for position, block in enumerate(blocks):
        if "reasoningContent" in block:
            blocks.insert(0, blocks.pop(position))
            break

    return Message(
        content=[self.__to_prompt_stack_message_content(block) for block in blocks],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=usage["inputTokens"], output_tokens=usage["outputTokens"]),
    )

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a ``converse_stream`` call, yielding a ``DeltaMessage`` per relevant event.

    Content-block start and delta events yield content deltas; ``metadata`` events
    yield token usage. All other events are only logged.

    Raises:
        Exception: If the response contains no ``stream``.
    """
    params = self._base_params(prompt_stack)
    logger.debug(params)
    response = self.client.converse_stream(**params)

    stream = response.get("stream")
    if stream is None:
        raise Exception("model response is empty")

    for event in stream:
        logger.debug(event)
        if "contentBlockDelta" in event or "contentBlockStart" in event:
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
        elif "metadata" in event:
            token_usage = event["metadata"]["usage"]
            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=token_usage["inputTokens"],
                    output_tokens=token_usage["outputTokens"],
                ),
            )

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/amazon_bedrock_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the ``"native"`` structured-output strategy; any other value is accepted.

    Raises:
        ValueError: If ``value`` is ``"native"``.
    """
    if value != "native":
        return value
    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

AmazonBedrockTitanEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | Embedding model name. Defaults to DEFAULT_MODEL. |
| tokenizer | BaseTokenizer | Optionally provide a custom tokenizer. Defaults to an AmazonBedrockTokenizer for the configured model. |
| session | Session | Optionally provide a custom boto3.Session. |
| client | BedrockRuntimeClient | Optionally provide a custom bedrock-runtime client. |
Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@define
class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver for Amazon Bedrock Titan embedding models.

    Text is sent to the model as ``inputText``. Image artifacts are sent as a
    base64-encoded ``inputImage``. Requests use the ``bedrock-runtime``
    ``invoke_model`` API with JSON request and response bodies.

    Attributes:
        model: Embedding model name. Defaults to DEFAULT_MODEL.
        tokenizer: Optionally provide a custom tokenizer. Defaults to an `AmazonBedrockTokenizer` for `model`.
        session: Optionally provide custom `boto3.Session`.
        client: Optionally provide custom `bedrock-runtime` client.
    """

    DEFAULT_MODEL = "amazon.titan-embed-text-v1"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[BedrockRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> BedrockRuntimeClient:
        """Create a ``bedrock-runtime`` client from ``self.session``."""
        service_name = "bedrock-runtime"
        return self.session.client(service_name)

    def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
        """Embed a text artifact as text, or any other artifact as a base64-encoded image."""
        if not isinstance(artifact, TextArtifact):
            image_b64 = base64.b64encode(artifact.value).decode()
            return self._invoke_model({"inputImage": image_b64})["embedding"]
        return self.try_embed_chunk(artifact.value)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk and return its ``embedding`` vector."""
        payload = {"inputText": chunk}
        return self._invoke_model(payload)["embedding"]

    def _invoke_model(self, payload: dict) -> dict[str, Any]:
        """Send ``payload`` as JSON to ``invoke_model`` and return the JSON-decoded body."""
        invoke = self.client.invoke_model
        response = invoke(
            body=json.dumps(payload),
            modelId=self.model,
            accept="application/json",
            contentType="application/json",
        )
        raw_body = response.get("body").read()
        return json.loads(raw_body)
  • DEFAULT_MODEL = 'amazon.titan-embed-text-v1' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

_invoke_model(payload)

Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def _invoke_model(self, payload: dict) -> dict[str, Any]:
    response = self.client.invoke_model(
        body=json.dumps(payload),
        modelId=self.model,
        accept="application/json",
        contentType="application/json",
    )
    return json.loads(response.get("body").read())

client()

Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@lazy_property()
def client(self) -> BedrockRuntimeClient:
    """Create a ``bedrock-runtime`` client from ``self.session``."""
    service_name = "bedrock-runtime"
    return self.session.client(service_name)

try_embed_artifact(artifact, **kwargs)

Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
    """Embed a text or image artifact.

    Text artifacts are delegated to ``try_embed_chunk``. Any other artifact is
    treated as an image: its raw bytes are base64-encoded into ``inputImage``.
    """
    if not isinstance(artifact, TextArtifact):
        image_b64 = base64.b64encode(artifact.value).decode()
        return self._invoke_model({"inputImage": image_b64})["embedding"]
    return self.try_embed_chunk(artifact.value)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed one text chunk via ``inputText`` and return the ``embedding`` vector."""
    payload = {"inputText": chunk}
    response_body = self._invoke_model(payload)
    return response_body["embedding"]

AmazonDynamoDbConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@define
class AmazonDynamoDbConversationMemoryDriver(BaseConversationMemoryDriver):
    """Conversation memory driver that keeps the serialized memory in one DynamoDB item.

    The memory is stored as a JSON string under ``value_attribute_key``. The item is
    addressed by ``partition_key``/``partition_key_value`` and, when both are set,
    ``sort_key``/``sort_key_value``.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    partition_key: str = field(kw_only=True, metadata={"serializable": True})
    value_attribute_key: str = field(kw_only=True, metadata={"serializable": True})
    partition_key_value: str = field(kw_only=True, metadata={"serializable": True})
    sort_key: Optional[str] = field(default=None, metadata={"serializable": True})
    sort_key_value: Optional[str | int] = field(default=None, metadata={"serializable": True})
    _table: Optional[Table] = field(default=None, kw_only=True, alias="table", metadata={"serializable": False})

    @lazy_property()
    def table(self) -> Table:
        """Return the DynamoDB ``Table`` resource named ``table_name``."""
        dynamodb = self.session.resource("dynamodb")
        return dynamodb.Table(self.table_name)

    def store(self, runs: list[Run], metadata: dict) -> None:
        """Write the JSON-serialized runs and metadata to ``value_attribute_key`` via ``update_item``."""
        table = self.table
        key = self._get_key()
        # The attribute name goes through a placeholder (presumably to sidestep DynamoDB reserved words).
        attribute_names = {"#attr": self.value_attribute_key}
        attribute_values = {":value": json.dumps(self._to_params_dict(runs, metadata))}
        table.update_item(
            Key=key,
            UpdateExpression="set #attr = :value",
            ExpressionAttributeNames=attribute_names,
            ExpressionAttributeValues=attribute_values,
        )

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Load runs and metadata, or ``([], {})`` when the item or attribute is missing."""
        response = self.table.get_item(Key=self._get_key())
        item = response.get("Item", {})
        if self.value_attribute_key not in item:
            return [], {}
        serialized = str(item[self.value_attribute_key])
        return self._from_params_dict(json.loads(serialized))

    def _get_key(self) -> dict[str, str | int]:
        """Build the item key; the sort key is included only when both its name and value are set."""
        key: dict[str, str | int] = {self.partition_key: self.partition_key_value}
        if self.sort_key is None or self.sort_key_value is None:
            return key
        key[self.sort_key] = self.sort_key_value
        return key
  • _table = field(default=None, kw_only=True, alias='table', metadata={'serializable': False}) class-attribute instance-attribute

  • partition_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • partition_key_value = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • sort_key = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • sort_key_value = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • table_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • value_attribute_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_get_key()

Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def _get_key(self) -> dict[str, str | int]:
    key: dict[str, str | int] = {self.partition_key: self.partition_key_value}

    if self.sort_key is not None and self.sort_key_value is not None:
        key[self.sort_key] = self.sort_key_value

    return key

load()

Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Fetch the stored memory item and deserialize it into runs and metadata.

    Returns:
        ``([], {})`` when the item does not exist or lacks ``value_attribute_key``;
        otherwise the result of ``_from_params_dict`` on the JSON-decoded attribute.
    """
    response = self.table.get_item(Key=self._get_key())
    item = response.get("Item", {})
    if self.value_attribute_key not in item:
        return [], {}
    serialized = str(item[self.value_attribute_key])
    return self._from_params_dict(json.loads(serialized))

store(runs, metadata)

Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict) -> None:
    """Persist ``runs`` and ``metadata`` as a single JSON string on the memory item.

    Uses ``update_item`` with a ``set`` expression on ``value_attribute_key`` for the
    item addressed by ``_get_key()``.
    """
    table = self.table
    key = self._get_key()
    # The attribute name goes through a placeholder (presumably to sidestep DynamoDB reserved words).
    attribute_names = {"#attr": self.value_attribute_key}
    attribute_values = {":value": json.dumps(self._to_params_dict(runs, metadata))}
    table.update_item(
        Key=key,
        UpdateExpression="set #attr = :value",
        ExpressionAttributeNames=attribute_names,
        ExpressionAttributeValues=attribute_values,
    )

table()

Source Code in griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py
@lazy_property()
def table(self) -> Table:
    """Return the DynamoDB ``Table`` resource named ``table_name`` from ``self.session``."""
    dynamodb = self.session.resource("dynamodb")
    return dynamodb.Table(self.table_name)

AmazonOpenSearchVectorStoreDriver

Bases: OpenSearchVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| session | Session | The boto3 session to use. |
| service | str | Service name for AWS Signature v4. Values can be 'es', or 'aoss' for OpenSearch Serverless. Defaults to 'es'. |
| http_auth | str \| tuple[str, str] | The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session. |
| client | OpenSearch | An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes. |
Source Code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
@define
class AmazonOpenSearchVectorStoreDriver(OpenSearchVectorStoreDriver):
    """A Vector Store Driver for Amazon OpenSearch.

    By default, requests are signed with AWS Signature v4 using the credentials and
    region of ``session`` and the configured ``service``.

    Attributes:
        session: The boto3 session to use.
        service: Service name for AWS Signature v4. Values can be 'es', or 'aoss' for OpenSearch Serverless. Defaults to 'es'.
        http_auth: The HTTP authentication credentials to use. Defaults to using credentials in the boto3 session.
        client: An optional OpenSearch client to use. Defaults to a new client using the host, port, http_auth, use_ssl, and verify_certs attributes.
    """

    session: Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    service: str = field(default="es", kw_only=True)
    http_auth: str | tuple[str, str] = field(
        default=Factory(
            lambda self: import_optional_dependency("opensearchpy").AWSV4SignerAuth(
                self.session.get_credentials(),
                self.session.region_name,
                self.service,
            ),
            takes_self=True,
        ),
    )

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Insert or update a vector document in OpenSearch.

        The document id is ``vector_id`` or, if that is falsy, a hash of ``str(vector)``.
        For ``service == "aoss"`` no id is sent with the request, so the document is
        always indexed without an explicit id. Extra keyword arguments are merged into
        the document body and override the built-in keys on collision.

        Returns:
            The ``_id`` from the OpenSearch index response.
        """
        resolved_id = vector_id or str_to_hash(str(vector))
        document = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}
        index_document = self.client.index
        if self.service != "aoss":
            response = index_document(index=self.index_name, id=resolved_id, body=document)
        else:
            response = index_document(index=self.index_name, body=document)
        return response["_id"]
  • http_auth = field(default=Factory(lambda self: import_optional_dependency('opensearchpy').AWSV4SignerAuth(self.session.get_credentials(), self.session.region_name, self.service), takes_self=True)) class-attribute instance-attribute

  • service = field(default='es', kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/amazon_opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Insert or update a vector document in OpenSearch.

    The document id is ``vector_id`` or, if that is falsy, a hash of ``str(vector)``.
    For ``service == "aoss"`` no id is sent with the request, so the document is
    always indexed without an explicit id. Extra keyword arguments are merged into
    the document body and override the built-in keys on collision.

    Returns:
        The ``_id`` from the OpenSearch index response.
    """
    resolved_id = vector_id or str_to_hash(str(vector))
    document = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}
    index_document = self.client.index
    if self.service != "aoss":
        response = index_document(index=self.index_name, id=resolved_id, body=document)
    else:
        response = index_document(index=self.index_name, body=document)
    return response["_id"]

AmazonRedshiftSqlDriver

Bases: BaseSqlDriver

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver that runs queries through the Amazon Redshift Data API.

    Exactly one of ``cluster_identifier`` or ``workgroup_name`` must be set; this is
    enforced by ``validate_params``. Statements are submitted with ``execute_statement``
    and polled every ``wait_for_query_completion_sec`` seconds until they leave a
    pending state.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    _client: Optional[RedshiftDataAPIServiceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> RedshiftDataAPIServiceClient:
        """Create a ``redshift-data`` (Data API) client from ``self.session``."""
        return self.session.client("redshift-data")

    @workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
        """Ensure exactly one of ``cluster_identifier`` / ``workgroup_name`` is set.

        Raises:
            ValueError: If neither or both are set.
        """
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        if self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records: list) -> list[list]:
        """Flatten records into rows, keeping the value under each cell dict's first key."""
        return [[c[list(c.keys())[0]] for c in r] for r in records]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        """Map each row to a ``{column: value}`` dict by column position."""
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta: dict) -> list:
        """Return the ``name`` of every column-metadata entry, in order."""
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta: dict, records: list) -> list[dict[str, Any]]:
        """Combine column metadata and records into a list of ``{column: value}`` dicts."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in ``RowResult``; ``None`` if there are no rows."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def _connection_kwargs(self) -> dict[str, str]:
        """Return the target/auth Data API kwargs that are set, shared by every request."""
        connection_kwargs: dict[str, str] = {}
        if self.workgroup_name:
            connection_kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            connection_kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            connection_kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            connection_kwargs["SecretArn"] = self.database_credentials_secret_arn
        return connection_kwargs

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Run ``query`` and return every result row as a ``{column: value}`` dict.

        The statement is polled while ``SUBMITTED``, ``PICKED`` or ``STARTED``. Once
        ``FINISHED``, all result pages are read by following ``NextToken``.

        Returns:
            The rows, or ``None`` if the statement ended in any other status
            (e.g. ``FAILED`` or ``ABORTED``).
        """
        function_kwargs = {"Sql": query, "Database": self.database, **self._connection_kwargs()}
        response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
        statement_id = response["Id"]

        statement = self.client.describe_statement(Id=statement_id)
        while statement["Status"] in ("SUBMITTED", "PICKED", "STARTED"):
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=statement_id)

        if statement["Status"] != "FINISHED":
            return None

        page = self.client.get_statement_result(Id=statement_id)
        # Column metadata is read from the first page (later pages presumably describe the same columns).
        column_metadata = page["ColumnMetadata"]
        records = list(page.get("Records", []))
        while "NextToken" in page:
            page = self.client.get_statement_result(Id=statement_id, NextToken=page["NextToken"])
            # Accumulate the records of the page just fetched. Previously they were read from the
            # `execute_statement` response, which silently dropped every page after the first.
            records.extend(page.get("Records", []))

        return self._post_process(column_metadata, records)  # pyright: ignore[reportArgumentType]

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Describe ``table_name`` and return ``str()`` of its column names (unnamed columns skipped)."""
        function_kwargs = {"Database": self.database, "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        function_kwargs.update(self._connection_kwargs())
        response = self.client.describe_table(**function_kwargs)  # pyright: ignore[reportArgumentType]
        return str([col["name"] for col in response["ColumnList"] if "name" in col])
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • cluster_identifier = field(default=None, kw_only=True) class-attribute instance-attribute

  • database = field(kw_only=True) class-attribute instance-attribute

  • database_credentials_secret_arn = field(default=None, kw_only=True) class-attribute instance-attribute

  • db_user = field(default=None, kw_only=True) class-attribute instance-attribute

  • session = field(kw_only=True) class-attribute instance-attribute

  • wait_for_query_completion_sec = field(default=0.3, kw_only=True) class-attribute instance-attribute

  • workgroup_name = field(default=None, kw_only=True) class-attribute instance-attribute

_post_process(meta, records) classmethod

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _post_process(cls, meta: dict, records: list) -> list[dict[str, Any]]:
    columns = cls._process_columns_from_column_metadata(meta)
    rows = cls._process_rows_from_records(records)
    return cls._process_cells_from_rows_and_columns(columns, rows)

_process_cells_from_rows_and_columns(columns, rows) classmethod

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
    return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

_process_columns_from_column_metadata(meta) classmethod

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_columns_from_column_metadata(cls, meta: dict) -> list:
    return [k["name"] for k in meta]

_process_rows_from_records(records) classmethod

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_rows_from_records(cls, records: list) -> list[list]:
    return [[c[list(c.keys())[0]] for c in r] for r in records]

client()

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@lazy_property()
def client(self) -> RedshiftDataAPIServiceClient:
    """Create a ``redshift-data`` (Redshift Data API) client from ``self.session``."""
    service_name = "redshift-data"
    return self.session.client(service_name)

execute_query(query)

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap each resulting row in ``BaseSqlDriver.RowResult``.

    Returns ``None`` when ``execute_query_raw`` yields no rows (empty or ``None``).
    """
    rows = self.execute_query_raw(query)
    if not rows:
        return None
    return [BaseSqlDriver.RowResult(row) for row in rows]

execute_query_raw(query)

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Run ``query`` through the Redshift Data API and return its rows as dicts.

    The statement is submitted with ``execute_statement`` and polled with
    ``describe_statement`` every ``wait_for_query_completion_sec`` seconds while it is
    ``SUBMITTED``, ``PICKED`` or ``STARTED``. Once it is ``FINISHED``, every page of
    ``get_statement_result`` is read by following ``NextToken``. The combined records
    are turned into ``{column: value}`` dicts by ``_post_process``.

    Args:
        query: SQL text to execute.

    Returns:
        One dict per result row, or ``None`` if the statement ended in any status
        other than ``FINISHED`` (e.g. ``FAILED`` or ``ABORTED``).
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
    statement_id = response["Id"]

    statement = self.client.describe_statement(Id=statement_id)
    while statement["Status"] in ("SUBMITTED", "PICKED", "STARTED"):
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=statement_id)

    if statement["Status"] != "FINISHED":
        return None

    page = self.client.get_statement_result(Id=statement_id)
    # Column metadata is read from the first page (later pages presumably describe the same columns).
    column_metadata = page["ColumnMetadata"]
    records = list(page.get("Records", []))
    while "NextToken" in page:
        page = self.client.get_statement_result(Id=statement_id, NextToken=page["NextToken"])
        # Accumulate the records of the page just fetched. Previously they were read from the
        # `execute_statement` response, which silently dropped every page after the first.
        records.extend(page.get("Records", []))

    return self._post_process(column_metadata, records)  # pyright: ignore[reportArgumentType]

get_table_schema(table_name, schema=None)

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe ``table_name`` via the Data API and return its column names as a string.

    Only truthy optional settings are forwarded. The result is ``str()`` of the list
    of column names; columns without a ``name`` are skipped.
    """
    function_kwargs = {"Database": self.database, "Table": table_name}
    optional_kwargs = {
        "Schema": schema,
        "WorkgroupName": self.workgroup_name,
        "ClusterIdentifier": self.cluster_identifier,
        "DbUser": self.db_user,
        "SecretArn": self.database_credentials_secret_arn,
    }
    function_kwargs.update({name: value for name, value in optional_kwargs.items() if value})
    response = self.client.describe_table(**function_kwargs)  # pyright: ignore[reportArgumentType]
    column_names = [column["name"] for column in response["ColumnList"] if "name" in column]
    return str(column_names)

validate_params(_, workgroup_name)

Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
    """Ensure exactly one of ``cluster_identifier`` / ``workgroup_name`` is configured.

    Raises:
        ValueError: If neither or both are set (truthy).
    """
    has_cluster = bool(self.cluster_identifier)
    has_workgroup = bool(self.workgroup_name)
    if not (has_cluster or has_workgroup):
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    if has_cluster and has_workgroup:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

AmazonS3FileManagerDriver

Bases: BaseFileManagerDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| session | Session | The boto3 session to use for S3 operations. |
| bucket | str | The name of the S3 bucket. |
| workdir | str | The absolute working directory (a leading "/" is added if missing). List, load, and save operations are performed relative to this directory. |
| client | S3Client | The S3 client to use for S3 operations. |
Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@define
class AmazonS3FileManagerDriver(BaseFileManagerDriver):
    """AmazonS3FileManagerDriver can be used to list, load, and save files in an Amazon S3 bucket.

    Attributes:
        session: The boto3 session to use for S3 operations.
        bucket: The name of the S3 bucket.
        workdir: The absolute working directory (must start with "/"). List, load, and save
            operations will be performed relative to this directory.
        client: The S3 client to use for S3 operations.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    bucket: str = field(kw_only=True)
    _workdir: str = field(default="/", kw_only=True, alias="workdir")
    _client: Optional[S3Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @property
    def workdir(self) -> str:
        """The working directory, always returned with a leading "/"."""
        if self._workdir.startswith("/"):
            return self._workdir
        return f"/{self._workdir}"

    @workdir.setter
    def workdir(self, value: str) -> None:
        self._workdir = value

    @lazy_property()
    def client(self) -> S3Client:
        """Build an S3 client from ``session``."""
        return self.session.client("s3")

    def try_list_files(self, path: str) -> list[str]:
        """List the immediate children of ``path`` (relative to ``workdir``).

        Raises:
            NotADirectoryError: If nothing is listed under ``path`` but an object
                with exactly that key exists.
            FileNotFoundError: If nothing is listed and no such object exists.
        """
        full_key = self._to_dir_full_key(path)
        files_and_dirs = self._list_files_and_dirs(full_key)
        if files_and_dirs:
            return files_and_dirs

        # Check the exact key; a prefix probe would also match unrelated
        # siblings such as "foo.txt" when listing "foo".
        file_key = full_key.rstrip("/")
        if file_key and self._object_exists(file_key):
            raise NotADirectoryError
        raise FileNotFoundError

    def try_load_file(self, path: str) -> bytes:
        """Download the object at ``path`` (relative to ``workdir``) and return its bytes.

        Raises:
            IsADirectoryError: If ``path`` resolves to an S3 "directory".
            FileNotFoundError: If S3 reports the key as missing.
        """
        botocore = import_optional_dependency("botocore")
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            response = self.client.get_object(Bucket=self.bucket, Key=full_key)
            return response["Body"].read()
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                raise FileNotFoundError from e
            raise

    def try_save_file(self, path: str, value: bytes) -> str:
        """Upload ``value`` to ``path`` (relative to ``workdir``) and return its ``s3://`` URI.

        Raises:
            IsADirectoryError: If ``path`` resolves to an S3 "directory".
        """
        full_key = self._to_full_key(path)
        if self._is_a_directory(full_key):
            raise IsADirectoryError
        self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value)
        return f"s3://{self.bucket}/{full_key}"

    def _to_full_key(self, path: str) -> str:
        """Join ``path`` onto ``workdir`` and normalise it into an S3 key (no leading "/")."""
        path = path.lstrip("/")
        full_key = f"{self.workdir}/{path}"
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_slash = path.endswith("/")

        full_key = self._normpath(full_key)

        if ended_with_slash:
            full_key += "/"
        return full_key.lstrip("/")

    def _to_dir_full_key(self, path: str) -> str:
        """Like ``_to_full_key`` but guarantees a trailing "/" (except for the root key "")."""
        full_key = self._to_full_key(path)
        # S3 "directories" always end with a slash, except for the root.
        if full_key != "" and not full_key.endswith("/"):
            full_key += "/"
        return full_key

    def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
        """List entries directly under the prefix ``full_key``, relative to that prefix.

        Keyword Args:
            max_items: Optional cap forwarded as the paginator's ``MaxItems``.

        Returns:
            Common-prefix ("directory") names without their trailing "/", followed
            per page by object names.
        """
        max_items = kwargs.get("max_items")
        pagination_config: PaginatorConfigTypeDef = {}
        if max_items is not None:
            pagination_config["MaxItems"] = max_items

        paginator = self.client.get_paginator("list_objects_v2")
        pages = paginator.paginate(
            Bucket=self.bucket,
            Prefix=full_key,
            Delimiter="/",
            PaginationConfig=pagination_config,
        )
        files_and_dirs = []
        for page in pages:
            for obj in page.get("CommonPrefixes", []):
                prefix = obj.get("Prefix", "")
                directory = prefix[len(full_key) :].rstrip("/")
                files_and_dirs.append(directory)

            for obj in page.get("Contents", []):
                key = obj.get("Key", "")
                file = key[len(full_key) :]
                files_and_dirs.append(file)
        return files_and_dirs

    def _object_exists(self, full_key: str) -> bool:
        """Return whether an object with exactly ``full_key`` exists (via ``head_object``).

        Raises:
            botocore.exceptions.ClientError: For errors other than ``NoSuchKey``/``404``.
        """
        botocore = import_optional_dependency("botocore")
        try:
            self.client.head_object(Bucket=self.bucket, Key=full_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
                return False
            raise
        return True

    def _is_a_directory(self, full_key: str) -> bool:
        """Return whether ``full_key`` should be treated as an S3 "directory".

        The root key "" and keys ending in "/" are directories. Otherwise a key is a
        directory only if no object has exactly that key and at least one entry
        exists under ``"<full_key>/"``.
        """
        if full_key == "" or full_key.endswith("/"):
            return True
        if self._object_exists(full_key):
            return False
        # Probe "<key>/" rather than "<key>" so that siblings sharing the prefix
        # (e.g. "foo.txt" when checking "foo") are not mistaken for children.
        return len(self._list_files_and_dirs(f"{full_key}/", max_items=1)) > 0

    def _normpath(self, path: str) -> str:
        """Collapse empty, "." and ".." segments of a "/" or "\\" separated path (no leading "/")."""
        unix_path = path.replace("\\", "/")
        parts = unix_path.split("/")
        stack = []

        for part in parts:
            if part == "" or part == ".":
                continue
            if part == "..":
                if stack:
                    stack.pop()
            else:
                stack.append(part)

        return "/".join(stack)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • _workdir = field(default='/', kw_only=True, alias='workdir') class-attribute instance-attribute

  • bucket = field(kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • workdir property writable

_is_a_directory(full_key)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _is_a_directory(self, full_key: str) -> bool:
    """Return whether ``full_key`` should be treated as an S3 "directory".

    The root key "" and keys ending in "/" are directories without any request.
    Otherwise, an existing object with exactly ``full_key`` means "not a
    directory". If no such object exists, the key is a directory only if at
    least one entry exists under ``"<full_key>/"``.

    Raises:
        botocore.exceptions.ClientError: For ``head_object`` errors other than
            ``NoSuchKey``/``404``.
    """
    if full_key == "" or full_key.endswith("/"):
        return True

    botocore = import_optional_dependency("botocore")
    try:
        self.client.head_object(Bucket=self.bucket, Key=full_key)
    except botocore.exceptions.ClientError as e:
        if e.response["Error"]["Code"] in {"NoSuchKey", "404"}:
            # Probe "<key>/" rather than "<key>" so that siblings sharing the
            # prefix (e.g. "foo.txt" when checking "foo") are not counted.
            return len(self._list_files_and_dirs(f"{full_key}/", max_items=1)) > 0
        raise

    return False

_list_files_and_dirs(full_key, **kwargs)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
    """List entries directly beneath the S3 prefix ``full_key``.

    Entries are returned relative to ``full_key``. For each page, "directories"
    (common prefixes, trailing "/" removed) come first, then plain objects.

    Keyword Args:
        max_items: Optional limit forwarded to the paginator as ``MaxItems``.
    """
    max_items = kwargs.get("max_items")
    pagination_config: PaginatorConfigTypeDef = {} if max_items is None else {"MaxItems": max_items}

    strip_len = len(full_key)
    listing = self.client.get_paginator("list_objects_v2").paginate(
        Bucket=self.bucket,
        Prefix=full_key,
        Delimiter="/",
        PaginationConfig=pagination_config,
    )

    entries: list[str] = []
    for page in listing:
        entries.extend(
            common.get("Prefix", "")[strip_len:].rstrip("/") for common in page.get("CommonPrefixes", [])
        )
        entries.extend(item.get("Key", "")[strip_len:] for item in page.get("Contents", []))
    return entries

_normpath(path)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _normpath(self, path: str) -> str:
    """Normalise a "/" or "\\" separated path without touching the filesystem.

    Empty and "." segments are dropped, ".." removes the previous segment (and is
    ignored at the top), and the result is "/"-joined with no leading slash.
    """
    kept: list[str] = []
    for segment in path.replace("\\", "/").split("/"):
        if segment in ("", "."):
            continue
        if segment != "..":
            kept.append(segment)
        elif kept:
            kept.pop()
    return "/".join(kept)

_to_dir_full_key(path)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _to_dir_full_key(self, path: str) -> str:
    """Resolve ``path`` with ``_to_full_key`` and ensure the result ends in "/".

    The root key "" is returned unchanged, because the bucket root has no trailing slash.
    """
    key = self._to_full_key(path)
    needs_slash = bool(key) and not key.endswith("/")
    return f"{key}/" if needs_slash else key

_to_full_key(path)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def _to_full_key(self, path: str) -> str:
    """Combine ``workdir`` and ``path`` into a normalised S3 key without a leading "/".

    A trailing "/" on ``path`` marks a directory. ``_normpath`` strips it, so it is
    re-appended afterwards.
    """
    relative = path.lstrip("/")
    is_dir = relative.endswith("/")
    normalized = self._normpath(f"{self.workdir}/{relative}")
    if is_dir:
        normalized = f"{normalized}/"
    return normalized.lstrip("/")

client()

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
@lazy_property()
def client(self) -> S3Client:
    """Build an S3 client from ``self.session``."""
    s3_client = self.session.client("s3")
    return s3_client

try_list_files(path)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    """List the immediate children of ``path`` (relative to ``workdir``).

    Raises:
        NotADirectoryError: If nothing is listed under ``path`` but an object with
            exactly that key exists.
        FileNotFoundError: If nothing is listed and no such object exists.
        botocore.exceptions.ClientError: For ``head_object`` errors other than
            ``NoSuchKey``/``404``.
    """
    full_key = self._to_dir_full_key(path)
    files_and_dirs = self._list_files_and_dirs(full_key)
    if files_and_dirs:
        return files_and_dirs

    # Check the exact key; a prefix probe would also match unrelated siblings
    # such as "foo.txt" when listing "foo" and misreport NotADirectoryError.
    file_key = full_key.rstrip("/")
    if file_key:
        botocore = import_optional_dependency("botocore")
        try:
            self.client.head_object(Bucket=self.bucket, Key=file_key)
        except botocore.exceptions.ClientError as e:
            if e.response["Error"]["Code"] not in {"NoSuchKey", "404"}:
                raise
        else:
            raise NotADirectoryError
    raise FileNotFoundError

try_load_file(path)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Download the object stored at ``path`` (relative to ``workdir``) and return its bytes.

    Raises:
        IsADirectoryError: If ``path`` resolves to an S3 "directory".
        FileNotFoundError: If S3 reports the key as missing (``NoSuchKey``/``404``).
    """
    botocore = import_optional_dependency("botocore")
    full_key = self._to_full_key(path)
    if self._is_a_directory(full_key):
        raise IsADirectoryError

    try:
        s3_object = self.client.get_object(Bucket=self.bucket, Key=full_key)
        return s3_object["Body"].read()
    except botocore.exceptions.ClientError as err:
        error_code = err.response["Error"]["Code"]
        if error_code not in {"NoSuchKey", "404"}:
            raise err
        raise FileNotFoundError from err

try_save_file(path, value)

Source Code in griptape/drivers/file_manager/amazon_s3_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> str:
    """Upload ``value`` to ``path`` (relative to ``workdir``).

    Returns:
        The ``s3://<bucket>/<key>`` URI of the written object.

    Raises:
        IsADirectoryError: If ``path`` resolves to an S3 "directory".
    """
    key = self._to_full_key(path)
    if not self._is_a_directory(key):
        self.client.put_object(Bucket=self.bucket, Key=key, Body=value)
        return f"s3://{self.bucket}/{key}"
    raise IsADirectoryError

AmazonSageMakerJumpstartEmbeddingDriver

Bases: BaseEmbeddingDriver

Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that invokes an Amazon SageMaker JumpStart endpoint.

    Attributes:
        session: boto3 session used to create the ``sagemaker-runtime`` client.
        endpoint: Name of the SageMaker endpoint to invoke.
        custom_attributes: Value forwarded as ``CustomAttributes`` on every invocation.
        inference_component_name: Forwarded as ``InferenceComponentName`` when not ``None``.
        client: Optional pre-built ``sagemaker-runtime`` client.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[SageMakerRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> SageMakerRuntimeClient:
        """Build a ``sagemaker-runtime`` client from ``session``."""
        runtime_client = self.session.client("sagemaker-runtime")
        return runtime_client

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed ``chunk`` via the endpoint; extra ``kwargs`` are accepted but unused.

        If the response's ``embedding`` is a list of lists, the first inner list is returned.

        Raises:
            ValueError: If the response has no ``embedding`` key or it is empty.
        """
        runtime = self.client
        component_args = (
            {} if self.inference_component_name is None else {"InferenceComponentName": self.inference_component_name}
        )
        endpoint_response = runtime.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
            CustomAttributes=self.custom_attributes,
            **component_args,
        )

        decoded = json.loads(endpoint_response.get("Body").read().decode("utf-8"))
        if "embedding" not in decoded:
            raise ValueError("invalid response from model")

        embedding = decoded["embedding"]
        if not embedding:
            raise ValueError("model response is empty")

        head = embedding[0]
        return head if isinstance(head, list) else embedding
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • custom_attributes = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • inference_component_name = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@lazy_property()
def client(self) -> SageMakerRuntimeClient:
    """Build a ``sagemaker-runtime`` client from ``self.session``."""
    runtime_client = self.session.client("sagemaker-runtime")
    return runtime_client

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed ``chunk`` by invoking the SageMaker endpoint; extra ``kwargs`` are unused.

    If the response's ``embedding`` is a list of lists, the first inner list is returned.

    Raises:
        ValueError: If the response has no ``embedding`` key or it is empty.
    """
    runtime = self.client
    component_args = (
        {} if self.inference_component_name is None else {"InferenceComponentName": self.inference_component_name}
    )
    endpoint_response = runtime.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps({"text_inputs": chunk, "mode": "embedding"}).encode("utf-8"),
        CustomAttributes=self.custom_attributes,
        **component_args,
    )

    decoded = json.loads(endpoint_response.get("Body").read().decode("utf-8"))
    if "embedding" not in decoded:
        raise ValueError("invalid response from model")

    embedding = decoded["embedding"]
    if not embedding:
        raise ValueError("model response is empty")

    head = embedding[0]
    return head if isinstance(head, list) else embedding

AmazonSageMakerJumpstartPromptDriver

Bases: BasePromptDriver

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@define
class AmazonSageMakerJumpstartPromptDriver(BasePromptDriver):
    """Prompt driver that sends chat-templated prompts to an Amazon SageMaker JumpStart endpoint.

    Attributes:
        session: boto3 session used to create the ``sagemaker-runtime`` client.
        endpoint: Name of the SageMaker endpoint to invoke.
        custom_attributes: Value forwarded as ``CustomAttributes`` on every invocation.
        inference_component_name: Forwarded as ``InferenceComponentName`` when not ``None``.
        stream: Must be ``False``; ``validate_stream`` rejects ``True``.
        max_tokens: Sent as ``max_new_tokens`` and passed to the default tokenizer.
        tokenizer: Hugging Face tokenizer used for the chat template and token counts.
        structured_output_strategy: Only ``"rule"`` is accepted.
        client: Optional pre-built ``sagemaker-runtime`` client.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    _client: Optional[Any] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Accept only the ``"rule"`` structured output strategy."""
        if value == "rule":
            return value
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    @lazy_property()
    def client(self) -> Any:
        """Build a ``sagemaker-runtime`` client from ``session``."""
        runtime_client = self.session.client("sagemaker-runtime")
        return runtime_client

    @stream.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
        """Reject ``stream=True``; this driver cannot stream."""
        if not stream:
            return
        raise ValueError("streaming is not supported")

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Invoke the endpoint once and wrap the generated text in an assistant ``Message``.

        Raises:
            ValueError: If the endpoint returns an empty JSON list.
        """
        payload = {
            "inputs": self.prompt_stack_to_string(prompt_stack),
            "parameters": dict(self._base_params(prompt_stack)),
        }
        logger.debug(payload)

        component_args = (
            {} if self.inference_component_name is None else {"InferenceComponentName": self.inference_component_name}
        )
        response = self.client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload),
            CustomAttributes=self.custom_attributes,
            **component_args,
        )

        decoded_body = json.loads(response["Body"].read().decode("utf8"))
        logger.debug(decoded_body)

        # The endpoint may answer with either a single object or a list of objects.
        is_list = isinstance(decoded_body, list)
        if is_list and not decoded_body:
            raise ValueError("model response is empty")
        first_result = decoded_body[0] if is_list else decoded_body
        generated_text = first_result["generated_text"]

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]

        return Message(
            content=[TextMessageContent(TextArtifact(generated_text))],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Always raises: streaming is unsupported by this driver."""
        reason = "streaming is not supported"
        raise NotImplementedError(reason)

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Render ``prompt_stack`` through the chat template and decode it back to text."""
        token_ids = self.__prompt_stack_to_tokens(prompt_stack)
        return self.tokenizer.tokenizer.decode(token_ids)  # pyright: ignore[reportArgumentType]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the ``parameters`` object for the request; ``extra_params`` override defaults."""
        params = {
            "temperature": self.temperature,
            "max_new_tokens": self.max_tokens,
            "do_sample": True,
            "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
            "stop_strings": self.tokenizer.stop_sequences,
            "return_full_text": False,
        }
        params.update(self.extra_params)
        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert each message to a ``{"role", "content"}`` dict for the chat template."""
        return [{"role": message.role, "content": message.to_text()} for message in prompt_stack.messages]

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        """Apply the tokenizer's chat template (with generation prompt) and return token ids.

        Raises:
            ValueError: If the tokenizer does not return a list.
        """
        chat_messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(chat_messages, add_generation_prompt=True, tokenize=True)
        if not isinstance(tokens, list):
            raise ValueError("Invalid output type.")
        return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • custom_attributes = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • inference_component_name = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • max_tokens = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • stream = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

__prompt_stack_to_tokens(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    """Apply the tokenizer's chat template (with generation prompt) and return token ids.

    Raises:
        ValueError: If the tokenizer does not return a list.
    """
    chat_messages = self._prompt_stack_to_messages(prompt_stack)
    tokens = self.tokenizer.tokenizer.apply_chat_template(chat_messages, add_generation_prompt=True, tokenize=True)
    if not isinstance(tokens, list):
        raise ValueError("Invalid output type.")
    return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the ``parameters`` object sent with each invocation.

    ``prompt_stack`` is not read here. ``extra_params`` is merged last, so its
    keys override the defaults below.
    """
    params = {
        "temperature": self.temperature,
        "max_new_tokens": self.max_tokens,
        "do_sample": True,
        "eos_token_id": self.tokenizer.tokenizer.eos_token_id,
        "stop_strings": self.tokenizer.stop_sequences,
        "return_full_text": False,
    }
    params.update(self.extra_params)
    return params

_prompt_stack_to_messages(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    """Convert each message to a ``{"role", "content"}`` dict for the chat template."""
    return [{"role": message.role, "content": message.to_text()} for message in prompt_stack.messages]

client()

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@lazy_property()
def client(self) -> Any:
    """Build a ``sagemaker-runtime`` client from ``self.session``."""
    runtime_client = self.session.client("sagemaker-runtime")
    return runtime_client

prompt_stack_to_string(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Render ``prompt_stack`` through the chat template and decode the ids back to text."""
    token_ids = self.__prompt_stack_to_tokens(prompt_stack)
    return self.tokenizer.tokenizer.decode(token_ids)  # pyright: ignore[reportArgumentType]

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Invoke the endpoint once and wrap the generated text in an assistant ``Message``.

    Raises:
        ValueError: If the endpoint returns an empty JSON list.
    """
    payload = {
        "inputs": self.prompt_stack_to_string(prompt_stack),
        "parameters": dict(self._base_params(prompt_stack)),
    }
    logger.debug(payload)

    component_args = (
        {} if self.inference_component_name is None else {"InferenceComponentName": self.inference_component_name}
    )
    response = self.client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload),
        CustomAttributes=self.custom_attributes,
        **component_args,
    )

    decoded_body = json.loads(response["Body"].read().decode("utf8"))
    logger.debug(decoded_body)

    # The endpoint may answer with either a single object or a list of objects.
    is_list = isinstance(decoded_body, list)
    if is_list and not decoded_body:
        raise ValueError("model response is empty")
    first_result = decoded_body[0] if is_list else decoded_body
    generated_text = first_result["generated_text"]

    input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
    output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]

    return Message(
        content=[TextMessageContent(TextArtifact(generated_text))],
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Always raises: this driver does not support streaming."""
    reason = "streaming is not supported"
    raise NotImplementedError(reason)

validate_stream(_, stream)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@stream.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_stream(self, _: Attribute, stream: bool) -> None:  # noqa: FBT001
    """Reject ``stream=True``; this driver cannot stream."""
    if not stream:
        return
    raise ValueError("streaming is not supported")

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/amazon_sagemaker_jumpstart_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Accept only the ``"rule"`` structured output strategy.

    Raises:
        ValueError: For any other strategy value.
    """
    if value == "rule":
        return value
    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

AmazonSqsEventListenerDriver

Bases: BaseEventListenerDriver

Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@define
class AmazonSqsEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that publishes event payloads to an Amazon SQS queue.

    Attributes:
        queue_url: URL of the target SQS queue.
        session: boto3 session used to create the SQS client.
        client: Optional pre-built SQS client.
    """

    queue_url: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    _client: Optional[SQSClient] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> SQSClient:
        """Build an SQS client from ``session``."""
        return self.session.client("sqs")

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Send one event payload to the queue as a JSON message body."""
        self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Send event payloads to the queue using ``send_message_batch``.

        Each payload's ``id`` (stringified) is used as the batch entry ``Id``. SQS's
        SendMessageBatch accepts at most 10 entries per request and rejects empty
        batches, so payloads are sent in chunks of 10 and an empty batch is a no-op.
        """
        max_entries_per_request = 10
        entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
            {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
            for event_payload in event_payload_batch
        ]

        for start in range(0, len(entries), max_entries_per_request):
            self.client.send_message_batch(
                QueueUrl=self.queue_url, Entries=entries[start : start + max_entries_per_request]
            )
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • queue_url = field(kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
@lazy_property()
def client(self) -> SQSClient:
    """Build an SQS client from ``self.session``."""
    sqs_client = self.session.client("sqs")
    return sqs_client

try_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Send a single event payload to the queue as a JSON message body."""
    sqs = self.client
    sqs.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

try_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Send event payloads to the queue using ``send_message_batch``.

    Each payload's ``id`` (stringified) becomes the batch entry ``Id``. SQS's
    SendMessageBatch accepts at most 10 entries per request and rejects empty
    batches, so payloads are sent in chunks of 10 and an empty batch is a no-op.
    """
    max_entries_per_request = 10
    entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
        {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
        for event_payload in event_payload_batch
    ]

    for start in range(0, len(entries), max_entries_per_request):
        self.client.send_message_batch(QueueUrl=self.queue_url, Entries=entries[start : start + max_entries_per_request])

AnthropicPromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `Optional[str]` | Anthropic API key. |
| `model` | `str` | Anthropic model name. |
| `client` | `Client` | Custom `Anthropic` client. |
Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@define
class AnthropicPromptDriver(BasePromptDriver):
    """Anthropic Prompt Driver.

    Converts griptape prompt stacks into Anthropic Messages API requests and converts the
    responses (or streamed events) back into griptape messages.

    Attributes:
        api_key: Anthropic API key.
        model: Anthropic model name.
        tokenizer: Tokenizer whose `stop_sequences` are sent with every request.
        top_p: Value sent as the request's `top_p`.
        top_k: Value sent as the request's `top_k`.
        tool_choice: `tool_choice` sent when native tools are used; replaced by `{"type": "any"}`
            when the prompt stack has an output schema and `structured_output_strategy` is `"tool"`.
        use_native_tools: Whether the prompt stack's tools are sent as Anthropic native tools.
        structured_output_strategy: Structured output strategy; `"native"` is rejected by the validator.
        max_tokens: Value sent as the request's `max_tokens`.
        client: Custom `Anthropic` client.
    """

    api_key: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: float = field(default=0.999, kw_only=True, metadata={"serializable": True})
    top_k: int = field(default=250, kw_only=True, metadata={"serializable": True})
    tool_choice: dict = field(default=Factory(lambda: {"type": "auto"}), kw_only=True, metadata={"serializable": False})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    max_tokens: int = field(default=1000, kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        """Build an `anthropic.Anthropic` client authenticated with `api_key`."""
        return import_optional_dependency("anthropic").Anthropic(api_key=self.api_key)

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `"native"` structured output strategy, which this driver does not support.

        Raises:
            ValueError: If `value` is `"native"`.
        """
        if value == "native":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send the prompt stack in a single request and convert the response.

        Returns:
            An assistant `Message` with the converted content blocks and the reported token usage.
        """
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.messages.create(**params)

        logger.debug(response.model_dump())

        return Message(
            content=[self.__to_prompt_stack_message_content(content) for content in response.content],
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream a response for the prompt stack.

        Yields:
            `DeltaMessage`s with content deltas for `content_block_start`/`content_block_delta`
            events, input-token usage for `message_start`, and output-token usage for
            `message_delta`. Other event types are only logged.
        """
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(params)
        events = self.client.messages.create(**params)

        for event in events:
            logger.debug(event)
            if event.type == "content_block_delta" or event.type == "content_block_start":
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))
            elif event.type == "message_start":
                yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens))
            elif event.type == "message_delta":
                yield DeltaMessage(usage=DeltaMessage.Usage(output_tokens=event.usage.output_tokens))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Assemble the keyword arguments for `client.messages.create`.

        System messages are removed from the conversation. Only the first one (if any) is sent,
        as the top-level `system` parameter. `extra_params` are merged last, so they override
        any key set here.
        """
        messages = self.__to_anthropic_messages([i for i in prompt_stack.messages if not i.is_system()])

        system_messages = prompt_stack.system_messages
        system_message = system_messages[0].to_text() if system_messages else None

        params = {
            "model": self.model,
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
            "messages": messages,
            **({"system": system_message} if system_message else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice

            # With an output schema and the "tool" strategy, require the model to call some tool
            # (presumably the structured-output tool) instead of answering in free text.
            if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
                params["tool_choice"] = {"type": "any"}

            params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)

        return params

    def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
        """Convert griptape messages into Anthropic `{"role", "content"}` message dicts."""
        return [
            {"role": self.__to_anthropic_role(message), "content": self.__to_anthropic_content(message)}
            for message in messages
        ]

    def __to_anthropic_role(self, message: Message) -> str:
        """Return `"assistant"` for assistant messages and `"user"` for every other message."""
        if message.is_assistant():
            return "assistant"
        return "user"

    def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one Anthropic native tool definition per activity of every tool.

        Returns:
            A list of `{"name", "description", "input_schema"}` dicts.
        """
        tool_schemas = [
            {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
            }
            for tool in tools
            for activity in tool.activities()
        ]

        # Anthropic doesn't support $schema and $id. Remove them only if present, so that a
        # schema lacking either key does not raise KeyError.
        for tool_schema in tool_schemas:
            input_schema = tool_schema["input_schema"]
            input_schema.pop("$schema", None)
            input_schema.pop("$id", None)

        return tool_schemas

    def __to_anthropic_content(self, message: Message) -> str | list[dict]:
        """Return plain text for text-only messages, otherwise a list of content-block dicts."""
        if message.has_all_content_type(TextMessageContent):
            return message.to_text()
        return [self.__to_anthropic_message_content(content) for content in message.content]

    def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
        """Convert one griptape message content item into an Anthropic content-block dict.

        Raises:
            ValueError: For image content whose artifact is neither an `ImageArtifact` nor an
                `ImageUrlArtifact`.
        """
        if isinstance(content, TextMessageContent):
            return {"type": "text", "text": content.artifact.value}
        if isinstance(content, ImageMessageContent):
            artifact = content.artifact
            if isinstance(artifact, ImageArtifact):
                return {
                    "type": "image",
                    "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
                }
            if isinstance(artifact, ImageUrlArtifact):
                return {
                    "type": "image",
                    "source": {"type": "url", "url": artifact.value},
                }
            raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {"type": "tool_use", "id": action.tag, "name": action.to_native_tool_name(), "input": action.input}
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            # A ListArtifact result is expanded into one tool-result block per item.
            if isinstance(artifact, ListArtifact):
                message_content = [self.__to_anthropic_tool_result_content(item) for item in artifact.value]
            else:
                message_content = [self.__to_anthropic_tool_result_content(artifact)]

            return {
                "type": "tool_result",
                "tool_use_id": content.action.tag,
                "content": message_content,
                "is_error": isinstance(artifact, ErrorArtifact),
            }
        # Any other content type is passed through as its artifact's raw value.
        # NOTE(review): presumably a provider-specific payload that is already a content-block dict; confirm.
        return content.artifact.value

    def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
        """Convert a tool-result artifact into an image block (images) or a text block (text/error/info).

        Raises:
            ValueError: For any other artifact type.
        """
        if isinstance(artifact, ImageArtifact):
            return {
                "type": "image",
                "source": {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64},
            }
        if isinstance(artifact, (TextArtifact, ErrorArtifact, InfoArtifact)):
            return {"type": "text", "text": artifact.to_text()}
        raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")

    def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
        """Convert an Anthropic response content block (text or tool use) into griptape message content.

        Raises:
            ValueError: For any other content block type.
        """
        if content.type == "text":
            return TextMessageContent(TextArtifact(content.text))
        if content.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(content.name)

            return ActionCallMessageContent(
                artifact=ActionArtifact(
                    value=ToolAction(tag=content.id, name=name, path=path, input=content.input),  # pyright: ignore[reportArgumentType]
                ),
            )
        raise ValueError(f"Unsupported message content type: {content.type}")

    def __to_prompt_stack_delta_message_content(
        self,
        event: ContentBlockDeltaEvent | ContentBlockStartEvent,
    ) -> BaseDeltaMessageContent:
        """Convert a streamed content-block start/delta event into griptape delta content.

        Raises:
            ValueError: For unsupported event, block, or delta types.
        """
        if event.type == "content_block_start":
            content_block = event.content_block

            if content_block.type == "tool_use":
                name, path = ToolAction.from_native_tool_name(content_block.name)

                return ActionCallDeltaMessageContent(index=event.index, tag=content_block.id, name=name, path=path)
            if content_block.type == "text":
                return TextDeltaMessageContent(content_block.text, index=event.index)
            raise ValueError(f"Unsupported content block type: {content_block.type}")
        if event.type == "content_block_delta":
            content_block_delta = event.delta

            if content_block_delta.type == "text_delta":
                return TextDeltaMessageContent(content_block_delta.text, index=event.index)
            if content_block_delta.type == "input_json_delta":
                # Tool input arrives as partial JSON fragments keyed by the block index.
                return ActionCallDeltaMessageContent(index=event.index, partial_input=content_block_delta.partial_json)
            raise ValueError(f"Unsupported message content type: {event}")
        raise ValueError(f"Unsupported message content type: {event}")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • max_tokens = field(default=1000, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: AnthropicTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

  • tool_choice = field(default=Factory(lambda: {'type': 'auto'}), kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • top_k = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • top_p = field(default=0.999, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_anthropic_content(message)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_content(self, message: Message) -> str | list[dict]:
    """Convert a message's content into the shape the Anthropic Messages API expects.

    A message made only of text collapses to a single string; any other mix of content
    is converted item by item into a list of content-block dicts.
    """
    if not message.has_all_content_type(TextMessageContent):
        return [self.__to_anthropic_message_content(item) for item in message.content]
    return message.to_text()

__to_anthropic_message_content(content)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_message_content(self, content: BaseMessageContent) -> dict:
    """Convert one griptape message content item into an Anthropic content-block dict.

    Text becomes a text block, images become base64 or URL image blocks, action calls
    become `tool_use` blocks, and action results become `tool_result` blocks. Any other
    content type is returned as its artifact's raw value.

    Raises:
        ValueError: For image content whose artifact is neither an `ImageArtifact`
            nor an `ImageUrlArtifact`.
    """
    if isinstance(content, TextMessageContent):
        return {"type": "text", "text": content.artifact.value}

    if isinstance(content, ImageMessageContent):
        image = content.artifact
        if isinstance(image, ImageArtifact):
            source = {"type": "base64", "media_type": image.mime_type, "data": image.base64}
        elif isinstance(image, ImageUrlArtifact):
            source = {"type": "url", "url": image.value}
        else:
            raise ValueError(f"Unsupported image artifact type: {type(image)}")
        return {"type": "image", "source": source}

    if isinstance(content, ActionCallMessageContent):
        call = content.artifact.value
        return {"type": "tool_use", "id": call.tag, "name": call.to_native_tool_name(), "input": call.input}

    if isinstance(content, ActionResultMessageContent):
        result = content.artifact
        # A ListArtifact result is expanded into one tool-result block per item.
        parts = result.value if isinstance(result, ListArtifact) else [result]
        blocks = [self.__to_anthropic_tool_result_content(part) for part in parts]
        return {
            "type": "tool_result",
            "tool_use_id": content.action.tag,
            "content": blocks,
            "is_error": isinstance(result, ErrorArtifact),
        }

    # NOTE(review): fall-through presumably carries provider-specific dict payloads; confirm.
    return content.artifact.value

__to_anthropic_messages(messages)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_messages(self, messages: list[Message]) -> list[dict]:
    """Convert griptape messages, in order, into Anthropic `{"role", "content"}` dicts."""
    converted: list[dict] = []
    for message in messages:
        role = self.__to_anthropic_role(message)
        body = self.__to_anthropic_content(message)
        converted.append({"role": role, "content": body})
    return converted

__to_anthropic_role(message)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_role(self, message: Message) -> str:
    """Map a griptape message to an Anthropic role: `"assistant"` for assistant messages, else `"user"`."""
    return "assistant" if message.is_assistant() else "user"

__to_anthropic_tool_result_content(artifact)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tool_result_content(self, artifact: BaseArtifact) -> dict:
    """Convert one tool-result artifact into an Anthropic tool-result content block.

    Images become base64 image blocks; text, error and info artifacts become text blocks.

    Raises:
        ValueError: For any other artifact type.
    """
    if isinstance(artifact, ImageArtifact):
        source = {"type": "base64", "media_type": artifact.mime_type, "data": artifact.base64}
        return {"type": "image", "source": source}

    textual_types = (TextArtifact, ErrorArtifact, InfoArtifact)
    if not isinstance(artifact, textual_types):
        raise ValueError(f"Unsupported tool result artifact type: {type(artifact)}")
    return {"type": "text", "text": artifact.to_text()}

__to_anthropic_tools(tools)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_anthropic_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build one Anthropic native tool definition per activity of every tool.

    Returns:
        A list of `{"name", "description", "input_schema"}` dicts.
    """
    tool_schemas = [
        {
            "name": tool.to_native_tool_name(activity),
            "description": tool.activity_description(activity),
            "input_schema": tool.to_activity_json_schema(activity, "Input Schema"),
        }
        for tool in tools
        for activity in tool.activities()
    ]

    # Anthropic doesn't support $schema and $id. Remove them only if present, so that a
    # schema lacking either key does not raise KeyError.
    for tool_schema in tool_schemas:
        input_schema = tool_schema["input_schema"]
        input_schema.pop("$schema", None)
        input_schema.pop("$id", None)

    return tool_schemas

__to_prompt_stack_delta_message_content(event)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_delta_message_content(
    self,
    event: ContentBlockDeltaEvent | ContentBlockStartEvent,
) -> BaseDeltaMessageContent:
    """Convert a streamed content-block event into griptape delta message content.

    A `content_block_start` event opens either a tool-use block (tag, name and path known
    up front) or a text block. A `content_block_delta` event carries incremental text or a
    partial tool-input JSON fragment.

    Raises:
        ValueError: If the event, block, or delta type is not one of those handled here.
    """
    if event.type == "content_block_start":
        block = event.content_block
        if block.type == "text":
            return TextDeltaMessageContent(block.text, index=event.index)
        if block.type == "tool_use":
            name, path = ToolAction.from_native_tool_name(block.name)
            return ActionCallDeltaMessageContent(index=event.index, tag=block.id, name=name, path=path)
        raise ValueError(f"Unsupported content block type: {block.type}")

    if event.type == "content_block_delta":
        delta = event.delta
        if delta.type == "input_json_delta":
            return ActionCallDeltaMessageContent(index=event.index, partial_input=delta.partial_json)
        if delta.type == "text_delta":
            return TextDeltaMessageContent(delta.text, index=event.index)

    # Unknown delta types and unknown event types share the same error message.
    raise ValueError(f"Unsupported message content type: {event}")

__to_prompt_stack_message_content(content)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def __to_prompt_stack_message_content(self, content: ContentBlock) -> BaseMessageContent:
    """Convert an Anthropic response content block into griptape message content.

    Text blocks become `TextMessageContent`; tool-use blocks become `ActionCallMessageContent`
    wrapping a `ToolAction` whose name and path are decoded from the native tool name.

    Raises:
        ValueError: For any other content block type.
    """
    if content.type == "tool_use":
        name, path = ToolAction.from_native_tool_name(content.name)
        action = ToolAction(tag=content.id, name=name, path=path, input=content.input)  # pyright: ignore[reportArgumentType]
        return ActionCallMessageContent(artifact=ActionArtifact(value=action))
    if content.type == "text":
        return TextMessageContent(TextArtifact(content.text))
    raise ValueError(f"Unsupported message content type: {content.type}")

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Assemble the keyword arguments for `client.messages.create`.

    System messages are removed from the conversation. Only the first one (if any) is sent,
    as the top-level `system` parameter. `extra_params` are merged last, so they override
    any key set here. Native tools are attached only when the prompt stack has tools and
    `use_native_tools` is enabled.

    Args:
        prompt_stack: The prompt stack to convert.

    Returns:
        A dict of request parameters.
    """
    conversation = [message for message in prompt_stack.messages if not message.is_system()]
    messages = self.__to_anthropic_messages(conversation)

    system_messages = prompt_stack.system_messages
    system_message = system_messages[0].to_text() if system_messages else None

    params: dict = {
        "model": self.model,
        "temperature": self.temperature,
        "stop_sequences": self.tokenizer.stop_sequences,
        "top_p": self.top_p,
        "top_k": self.top_k,
        "max_tokens": self.max_tokens,
        "messages": messages,
    }
    if system_message:
        params["system"] = system_message
    params.update(self.extra_params)

    if not (prompt_stack.tools and self.use_native_tools):
        return params

    # With an output schema and the "tool" strategy, the model must call some tool.
    forces_tool_call = prompt_stack.output_schema is not None and self.structured_output_strategy == "tool"
    params["tool_choice"] = {"type": "any"} if forces_tool_call else self.tool_choice
    params["tools"] = self.__to_anthropic_tools(prompt_stack.tools)
    return params

client()

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    """Lazily build an `anthropic.Anthropic` client authenticated with `api_key`."""
    anthropic = import_optional_dependency("anthropic")
    return anthropic.Anthropic(api_key=self.api_key)

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack in one non-streaming request and convert the response.

    Returns:
        An assistant `Message` built from the response content blocks, with token usage attached.
    """
    request = self._base_params(prompt_stack)
    logger.debug(request)
    response = self.client.messages.create(**request)
    logger.debug(response.model_dump())

    converted = [self.__to_prompt_stack_message_content(block) for block in response.content]
    usage = Message.Usage(input_tokens=response.usage.input_tokens, output_tokens=response.usage.output_tokens)
    return Message(content=converted, role=Message.ASSISTANT_ROLE, usage=usage)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a response for the prompt stack.

    Yields:
        `DeltaMessage`s with content deltas for content-block start/delta events, input-token
        usage for `message_start`, and output-token usage for `message_delta`. Every other
        event type is only logged.
    """
    request = {**self._base_params(prompt_stack), "stream": True}
    logger.debug(request)
    stream = self.client.messages.create(**request)

    for event in stream:
        logger.debug(event)
        if event.type == "content_block_start" or event.type == "content_block_delta":
            delta_content = self.__to_prompt_stack_delta_message_content(event)
            yield DeltaMessage(content=delta_content)
        elif event.type == "message_start":
            input_usage = DeltaMessage.Usage(input_tokens=event.message.usage.input_tokens)
            yield DeltaMessage(usage=input_usage)
        elif event.type == "message_delta":
            output_usage = DeltaMessage.Usage(output_tokens=event.usage.output_tokens)
            yield DeltaMessage(usage=output_usage)

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/anthropic_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the `"native"` structured output strategy, which this driver does not support.

    Returns:
        The strategy value, unchanged, when it is accepted.

    Raises:
        ValueError: If `value` is `"native"`.
    """
    if value != "native":
        return value
    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

AstraDbVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `embedding_driver` | `BaseEmbeddingDriver` | a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store |
| `api_endpoint` | `str` | the "API Endpoint" for the Astra DB instance. |
| `token` | `Optional[str \| TokenProvider]` | a Database Token ("AstraCS:...") secret to access Astra DB. An instance of `astrapy.authentication.TokenProvider` is also accepted. |
| `collection_name` | `str` | the name of the collection on Astra DB. The collection must have been created beforehand, and support vectors with a vector dimension matching the embeddings being used by this driver. |
| `environment` | `Optional[str]` | the environment ("prod", "hcd", ...) hosting the target Data API. It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values. |
| `astra_db_namespace` | `Optional[str]` | optional specification of the namespace (in the Astra database) for the data. *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store. |
| `caller_name` | `str` | the name of the caller for the Astra DB client. Defaults to "griptape". |
| `client` | `DataAPIClient` | an instance of `astrapy.DataAPIClient` for the Astra DB. |
| `collection` | `Collection` | an instance of `astrapy.Collection` for the Astra DB. |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@define
class AstraDbVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Astra DB.

    Attributes:
        embedding_driver: a `griptape.drivers.BaseEmbeddingDriver` for embedding computations within the store
        api_endpoint: the "API Endpoint" for the Astra DB instance.
        token: a Database Token ("AstraCS:...") secret to access Astra DB. An instance of `astrapy.authentication.TokenProvider` is also accepted.
        collection_name: the name of the collection on Astra DB. The collection must have been created beforehand,
            and support vectors with a vector dimension matching the embeddings being used by this driver.
        environment: the environment ("prod", "hcd", ...) hosting the target Data API.
            It can be omitted for production Astra DB targets. See `astrapy.constants.Environment` for allowed values.
        astra_db_namespace: optional specification of the namespace (in the Astra database) for the data.
            *Note*: not to be confused with the "namespace" mentioned elsewhere, which is a grouping within this vector store.
        caller_name: the name of the caller for the Astra DB client. Defaults to "griptape".
        client: an instance of `astrapy.DataAPIClient` for the Astra DB.
        collection: an instance of `astrapy.Collection` for the Astra DB.
    """

    api_endpoint: str = field(kw_only=True, metadata={"serializable": True})
    token: Optional[str | astrapy.authentication.TokenProvider] = field(
        kw_only=True, default=None, metadata={"serializable": False}
    )
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: Optional[str] = field(kw_only=True, default=None, metadata={"serializable": True})
    astra_db_namespace: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    caller_name: str = field(default="griptape", kw_only=True, metadata={"serializable": False})
    _client: Optional[astrapy.DataAPIClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )
    _collection: Optional[astrapy.Collection] = field(
        default=None, kw_only=True, alias="collection", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> astrapy.DataAPIClient:
        """Build the astrapy `DataAPIClient`, registering `caller_name` as the caller."""
        astrapy = import_optional_dependency("astrapy")

        return astrapy.DataAPIClient(
            callers=[(self.caller_name, None)],
            # A falsy `environment` is replaced by astrapy's "unset" marker rather than passed through.
            environment=self.environment or astrapy.utils.unset.UnsetType(),
        )

    @lazy_property()
    def collection(self) -> astrapy.Collection:
        """Resolve the configured collection through `client`.

        Raises:
            ValueError: If `token` is None.
        """
        if self.token is None:
            raise ValueError("Astra DB token must be provided.")
        return self.client.get_database(
            self.api_endpoint, token=self.token, keyspace=self.astra_db_namespace
        ).get_collection(self.collection_name)

    def delete_vector(self, vector_id: str) -> None:
        """Delete a vector from Astra DB store.

        The method succeeds regardless of whether a vector with the provided ID
        was actually stored or not in the first place.

        Args:
            vector_id: ID of the vector to delete.
        """
        self.collection.delete_one({"_id": vector_id})

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Write a vector to the Astra DB store.

        In case the provided ID exists already, an overwrite will take place.

        Args:
            vector: the vector to be upserted.
            vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
            namespace: a namespace (a grouping within the vector store) to assign the vector to.
                It is stored in the document's `keyspace` field.
            meta: a metadata dictionary associated to the vector.
            kwargs: additional keyword arguments. Currently none is used: any that are passed are silently ignored.

        Returns:
            the ID of the written vector: `vector_id` when given, otherwise the ID returned by the insert.
        """
        # Leave out fields whose value is None so they are not written to the document.
        document = {
            k: v
            for k, v in {"$vector": vector, "_id": vector_id, "keyspace": namespace, "meta": meta}.items()
            if v is not None
        }
        if vector_id is not None:
            # Replace the document with this ID, inserting it if it does not exist yet.
            self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
            return vector_id
        insert_result = self.collection.insert_one(document)
        return insert_result.inserted_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a single vector entry from the Astra DB store given its ID.

        Args:
            vector_id: the ID of the required vector.
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            The vector entry (a `BaseVectorStoreDriver.Entry`) if found, otherwise None.
        """
        find_filter = {k: v for k, v in {"_id": vector_id, "keyspace": namespace}.items() if v is not None}
        # NOTE(review): the {"*": 1} projection presumably requests every field, including `$vector`; confirm.
        match = self.collection.find_one(filter=find_filter, projection={"*": 1})
        if match is not None:
            return BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load entries from the Astra DB store.

        Args:
            namespace: a namespace, within the vector store, to constrain the search.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries.
        """
        find_filter: dict[str, str] = {} if namespace is None else {"keyspace": namespace}
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"], vector=match.get("$vector"), meta=match.get("meta"), namespace=match.get("keyspace")
            )
            for match in self.collection.find(filter=find_filter, projection={"*": 1})
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Run a similarity search on the Astra DB store, based on a vector list.

        Args:
            vector: the vector to be queried.
            count: the maximum number of results to return. If omitted (or 0),
                `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT` is used.
            namespace: the namespace to filter results by. It takes precedence over a
                `keyspace` condition supplied through `filter`.
            include_vectors: whether to include vector data in the results.
            kwargs: additional keyword arguments. Currently only the free-form dict `filter`
                is recognized (and goes straight to the Data API query);
                any others are silently ignored.

        Returns:
            A list of vector (`BaseVectorStoreDriver.Entry`) entries,
            with their `score` attribute set to the vector similarity to the query.
        """
        query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
        find_filter_ns: dict[str, Any] = {} if namespace is None else {"keyspace": namespace}
        # The namespace condition is merged last, so it overrides a `keyspace` key in the caller's filter.
        find_filter = {**(query_filter or {}), **find_filter_ns}
        find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
        ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        matches = self.collection.find(
            filter=find_filter,
            sort={"$vector": vector},
            limit=ann_limit,
            projection=find_projection,
            include_similarity=True,
        )
        return [
            BaseVectorStoreDriver.Entry(
                id=match["_id"],
                vector=match.get("$vector"),
                score=match["$similarity"],
                meta=match.get("meta"),
                namespace=match.get("keyspace"),
            )
            for match in matches
        ]
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • _collection = field(default=None, kw_only=True, alias='collection', metadata={'serializable': False}) class-attribute instance-attribute

  • api_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • astra_db_namespace = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • caller_name = field(default='griptape', kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • collection_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • environment = field(kw_only=True, default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def client(self) -> astrapy.DataAPIClient:
    """Create the astrapy `DataAPIClient`, registering `caller_name` as the calling application.

    A falsy `environment` is replaced by astrapy's "unset" marker.
    """
    astrapy = import_optional_dependency("astrapy")
    callers = [(self.caller_name, None)]
    environment = self.environment or astrapy.utils.unset.UnsetType()
    return astrapy.DataAPIClient(callers=callers, environment=environment)

collection()

Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
@lazy_property()
def collection(self) -> astrapy.Collection:
    """Resolve the configured collection through `client`.

    Raises:
        ValueError: If no `token` was configured.
    """
    if self.token is None:
        raise ValueError("Astra DB token must be provided.")
    database = self.client.get_database(self.api_endpoint, token=self.token, keyspace=self.astra_db_namespace)
    return database.get_collection(self.collection_name)

delete_vector(vector_id)

Delete a vector from Astra DB store.

The method succeeds regardless of whether a vector with the provided ID was actually stored or not in the first place.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `vector_id` | `str` | ID of the vector to delete. | *required* |
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Remove the vector stored under `vector_id` from the Astra DB collection.

    No error is raised when no document has that ID; the call is simply a no-op then.

    Args:
        vector_id: ID of the vector to delete.
    """
    id_filter = {"_id": vector_id}
    self.collection.delete_one(id_filter)

load_entries(*, namespace=None)

Load entries from the Astra DB store.

Parameters

NameTypeDescriptionDefault
namespaceOptional[str]a namespace, within the vector store, to constrain the search.
None

Returns

TypeDescription
list[Entry]A list of vector (BaseVectorStoreDriver.Entry) entries.
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Fetch every entry from the Astra DB collection, optionally restricted to one namespace.

    Args:
        namespace: when given, only documents whose `keyspace` field equals it are returned.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` objects, one per stored document.
    """
    find_filter: dict[str, str] = {"keyspace": namespace} if namespace is not None else {}
    # `{"*": 1}` asks for all fields, so `$vector` is included in each document.
    documents = self.collection.find(filter=find_filter, projection={"*": 1})
    return [
        BaseVectorStoreDriver.Entry(
            id=doc["_id"],
            vector=doc.get("$vector"),
            meta=doc.get("meta"),
            namespace=doc.get("keyspace"),
        )
        for doc in documents
    ]

load_entry(vector_id, *, namespace=None)

Load a single vector entry from the Astra DB store given its ID.

Parameters

NameTypeDescriptionDefault
vector_idstrthe ID of the required vector.
required
namespaceOptional[str]a namespace, within the vector store, to constrain the search.
None

Returns

TypeDescription
Optional[Entry]The vector entry (a BaseVectorStoreDriver.Entry) if found, otherwise None.
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Look up a single entry in the Astra DB collection by its ID.

    Args:
        vector_id: the ID of the required vector.
        namespace: when given, the document's `keyspace` field must also equal it.

    Returns:
        The matching `BaseVectorStoreDriver.Entry`, or None when nothing matches.
    """
    # Only non-None criteria become part of the filter.
    find_filter: dict[str, str] = {}
    if vector_id is not None:
        find_filter["_id"] = vector_id
    if namespace is not None:
        find_filter["keyspace"] = namespace

    document = self.collection.find_one(filter=find_filter, projection={"*": 1})
    if document is None:
        return None
    return BaseVectorStoreDriver.Entry(
        id=document["_id"],
        vector=document.get("$vector"),
        meta=document.get("meta"),
        namespace=document.get("keyspace"),
    )

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Run a similarity search on the Astra DB store, based on a vector list.

Parameters

NameTypeDescriptionDefault
vectorlist[float]the vector to be queried.
required
countOptional[int]the maximum number of results to return. If omitted, defaults will apply.
None
namespaceOptional[str]the namespace to filter results by.
None
include_vectorsboolwhether to include vector data in the results.
False
kwargsAnyadditional keyword arguments. Currently only the free-form dict filter is recognized (and goes straight to the Data API query); any other keyword arguments are ignored.
{}

Returns

TypeDescription
list[Entry]A list of vector (BaseVectorStoreDriver.Entry) entries, with their score attribute set to the vector similarity to the query.
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
    """Run a similarity search on the Astra DB store, based on a vector list.

    Args:
        vector: the vector to be queried.
        count: the maximum number of results to return. If omitted (or 0),
            `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT` is used.
        namespace: the namespace to filter results by (matched against the documents' `keyspace` field).
        include_vectors: whether to include vector data in the results.
        kwargs: additional keyword arguments. Currently only the free-form dict `filter`
            is recognized (and goes straight to the Data API query); any other keyword
            arguments are ignored.

    Returns:
        A list of vector (`BaseVectorStoreDriver.Entry`) entries,
        with their `score` attribute set to the vector similarity to the query.
    """
    query_filter: Optional[dict[str, Any]] = kwargs.get("filter")
    find_filter_ns: dict[str, Any] = {} if namespace is None else {"keyspace": namespace}
    # The namespace constraint is merged last, so it takes precedence over a `keyspace` key in `filter`.
    find_filter = {**(query_filter or {}), **find_filter_ns}
    # `{"*": 1}` requests every field (including `$vector`); otherwise no projection is passed.
    find_projection: Optional[dict[str, int]] = {"*": 1} if include_vectors else None
    ann_limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    matches = self.collection.find(
        filter=find_filter,
        sort={"$vector": vector},
        limit=ann_limit,
        projection=find_projection,
        include_similarity=True,
    )
    return [
        BaseVectorStoreDriver.Entry(
            id=match["_id"],
            vector=match.get("$vector"),
            score=match["$similarity"],
            meta=match.get("meta"),
            namespace=match.get("keyspace"),
        )
        for match in matches
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Write a vector to the Astra DB store.

In case the provided ID exists already, an overwrite will take place.

Parameters

NameTypeDescriptionDefault
vectorlist[float]the vector to be upserted.
required
vector_idOptional[str]the ID for the vector to store. If omitted, a server-provided new ID will be employed.
None
namespaceOptional[str]a namespace (a grouping within the vector store) to assign the vector to.
None
metaOptional[dict]a metadata dictionary associated with the vector.
None
kwargsAnyadditional keyword arguments. Currently none is used; any that are passed are ignored.
{}

Returns

TypeDescription
strthe ID of the written vector (str).
Source Code in griptape/drivers/vector/astradb_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs: Any,
) -> str:
    """Write a vector to the Astra DB store.

    In case the provided ID exists already, an overwrite will take place.

    Args:
        vector: the vector to be upserted.
        vector_id: the ID for the vector to store. If omitted, a server-provided new ID will be employed.
        namespace: a namespace (a grouping within the vector store) to assign the vector to;
            stored in the document's `keyspace` field.
        meta: a metadata dictionary associated with the vector.
        kwargs: additional keyword arguments. Currently none is used; any that are passed are ignored.

    Returns:
        the ID of the written vector (str).
    """
    # Fields whose value is None are left out of the document entirely, so a missing
    # `vector_id` results in no `_id` key being sent.
    document = {
        k: v
        for k, v in {"$vector": vector, "_id": vector_id, "keyspace": namespace, "meta": meta}.items()
        if v is not None
    }
    if vector_id is not None:
        # `upsert=True` inserts the document when no document with this `_id` exists yet.
        self.collection.find_one_and_replace({"_id": vector_id}, document, upsert=True)
        return vector_id
    insert_result = self.collection.insert_one(document)
    return insert_result.inserted_id

AwsIotCoreEventListenerDriver

Bases: BaseEventListenerDriver

Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@define
class AwsIotCoreEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that publishes event payloads to an AWS IoT Core topic.

    Attributes:
        iot_endpoint: The IoT Core data endpoint, either a bare hostname or a full URL.
            It is passed to the boto3 `iot-data` client as `endpoint_url`.
        topic: The topic that event payloads are published to.
        session: The boto3 session used to create the client.
        client: Optional pre-built `iot-data` client (passed as `client=`).
    """

    iot_endpoint: str = field(kw_only=True)
    topic: str = field(kw_only=True)
    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    _client: Optional[IoTDataPlaneClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> IoTDataPlaneClient:
        # Use the configured endpoint instead of boto3's default one. `endpoint_url`
        # must be a URL, so a scheme is added when only a hostname was given.
        endpoint_url = self.iot_endpoint
        if not endpoint_url.startswith(("https://", "http://")):
            endpoint_url = f"https://{endpoint_url}"
        return self.session.client("iot-data", endpoint_url=endpoint_url)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Publish one event payload, JSON-encoded, to `topic`."""
        self.client.publish(topic=self.topic, payload=json.dumps(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Publish a batch of payloads as a single JSON array message to `topic`."""
        self.client.publish(topic=self.topic, payload=json.dumps(event_payload_batch))
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • iot_endpoint = field(kw_only=True) class-attribute instance-attribute

  • session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True) class-attribute instance-attribute

  • topic = field(kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
@lazy_property()
def client(self) -> IoTDataPlaneClient:
    """Create the boto3 `iot-data` client, pointed at the configured `iot_endpoint`."""
    # Use the configured endpoint instead of boto3's default one. `endpoint_url`
    # must be a URL, so a scheme is added when only a hostname was given.
    endpoint_url = self.iot_endpoint
    if not endpoint_url.startswith(("https://", "http://")):
        endpoint_url = f"https://{endpoint_url}"
    return self.session.client("iot-data", endpoint_url=endpoint_url)

try_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Publish a single event payload to the configured topic as a JSON string."""
    message = json.dumps(event_payload)
    self.client.publish(topic=self.topic, payload=message)

try_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/aws_iot_core_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Publish a whole batch of payloads as one JSON array message on the configured topic."""
    message = json.dumps(event_payload_batch)
    self.client.publish(topic=self.topic, payload=message)

AzureMongoDbVectorStoreDriver

Bases: MongoDbAtlasVectorStoreDriver

Source Code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
@define
class AzureMongoDbVectorStoreDriver(MongoDbAtlasVectorStoreDriver):
    """A Vector Store Driver for CosmosDB with MongoDB vCore API."""

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Queries the MongoDB collection for documents that match the provided vector list.

        Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

        Args:
            vector: The query vector.
            count: Maximum number of entries to return; defaults to `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT`.
            namespace: If set, only documents whose `namespace` field equals it are returned.
            include_vectors: Whether to fill each entry's `vector`; otherwise it is an empty list.
            offset: Number of leading (best-scoring) matches to skip; defaults to 0.
            kwargs: Accepted for interface compatibility; not used.

        Returns:
            Up to `count` entries, each with `score` taken from the Cosmos DB `searchScore`.
        """
        collection = self.get_collection()

        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        offset = offset or 0

        pipeline: list[dict] = [
            {
                "$search": {
                    "cosmosSearch": {
                        "vector": vector,
                        "path": self.vector_path,
                        # Over-fetch so that the skipped matches and the namespace post-filter
                        # still leave up to `count` results.
                        "k": min((offset + count) * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    },
                    "returnStoredSource": True,
                },
            },
        ]

        if namespace:
            pipeline.append({"$match": {"namespace": namespace}})

        pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

        # Paging is applied after filtering; without `$limit`, up to `k` documents would be returned.
        if offset:
            pipeline.append({"$skip": offset})
        pipeline.append({"$limit": count})

        entries = []
        for doc in collection.aggregate(pipeline):
            # After `$project`, the stored fields live under "document"; only "_id" and the score stay top-level.
            stored = doc["document"]
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=str(doc["_id"]),
                    vector=stored[self.vector_path] if include_vectors else [],
                    score=doc["similarityScore"],
                    meta=stored.get("meta"),
                    namespace=stored.get("namespace"),
                )
            )
        return entries

query_vector(vector, *, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Source Code in griptape/drivers/vector/azure_mongodb_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Queries the MongoDB collection for documents that match the provided vector list.

    Results can be customized based on parameters like count, namespace, inclusion of vectors, offset, and index.

    Args:
        vector: The query vector.
        count: Maximum number of entries to return; defaults to `BaseVectorStoreDriver.DEFAULT_QUERY_COUNT`.
        namespace: If set, only documents whose `namespace` field equals it are returned.
        include_vectors: Whether to fill each entry's `vector`; otherwise it is an empty list.
        offset: Number of leading (best-scoring) matches to skip; defaults to 0.
        kwargs: Accepted for interface compatibility; not used.

    Returns:
        Up to `count` entries, each with `score` taken from the Cosmos DB `searchScore`.
    """
    collection = self.get_collection()

    count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    offset = offset or 0

    pipeline: list[dict] = [
        {
            "$search": {
                "cosmosSearch": {
                    "vector": vector,
                    "path": self.vector_path,
                    # Over-fetch so that the skipped matches and the namespace post-filter
                    # still leave up to `count` results.
                    "k": min((offset + count) * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                },
                "returnStoredSource": True,
            },
        },
    ]

    if namespace:
        pipeline.append({"$match": {"namespace": namespace}})

    pipeline.append({"$project": {"similarityScore": {"$meta": "searchScore"}, "document": "$$ROOT"}})

    # Paging is applied after filtering; without `$limit`, up to `k` documents would be returned.
    if offset:
        pipeline.append({"$skip": offset})
    pipeline.append({"$limit": count})

    entries = []
    for doc in collection.aggregate(pipeline):
        # After `$project`, the stored fields live under "document"; only "_id" and the score stay top-level.
        stored = doc["document"]
        entries.append(
            BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]),
                vector=stored[self.vector_path] if include_vectors else [],
                score=doc["similarityScore"],
                meta=stored.get("meta"),
                namespace=stored.get("namespace"),
            )
        )
    return entries

AzureOpenAiChatPromptDriver

Bases: OpenAiChatPromptDriver

Attributes

NameTypeDescription
azure_deploymentstrAn optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpointstrAn Azure OpenAi endpoint.
azure_ad_tokenOptional[str]An optional Azure Active Directory token.
azure_ad_token_providerOptional[Callable[[], str]]An optional Azure Active Directory token provider.
api_versionstrAn Azure OpenAi API version.
clientAzureOpenAIAn openai.AzureOpenAI client.
Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@define
class AzureOpenAiChatPromptDriver(OpenAiChatPromptDriver):
    """Chat prompt driver for OpenAI models hosted on Azure.

    Attributes:
        azure_deployment: Azure OpenAi deployment id; falls back to `model` when not given.
        azure_endpoint: The Azure OpenAi endpoint.
        azure_ad_token: Optional Azure Active Directory token.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token.
        api_version: The Azure OpenAi API version string.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        default=Factory(lambda self: self.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(metadata={"serializable": True}, kw_only=True)
    azure_ad_token: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(kw_only=True, default="2024-10-21", metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        alias="client", default=None, kw_only=True, metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        settings = {
            "organization": self.organization,
            "api_key": self.api_key,
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider,
        }
        return openai.AzureOpenAI(**settings)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the parent's request params, minus keys gated on the Azure `api_version`."""
        params = super()._base_params(prompt_stack)
        # `api_version` is compared as a plain string, which orders ISO "YYYY-MM-DD" prefixes correctly.
        dropped: list[str] = []
        if self.api_version < "2024-02-01":
            dropped.append("seed")
        if self.api_version < "2024-10-21":
            dropped.extend(["stream_options", "parallel_tool_calls"])
        # TODO: Add once Azure supports modalities
        dropped.append("modalities")
        for key in dropped:
            params.pop(key, None)
        return params
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_version = field(default='2024-10-21', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

  • azure_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the parent's request params, minus keys gated on the Azure `api_version`."""
    params = super()._base_params(prompt_stack)
    # `api_version` is compared as a plain string, which orders ISO "YYYY-MM-DD" prefixes correctly.
    dropped: list[str] = []
    if self.api_version < "2024-02-01":
        dropped.append("seed")
    if self.api_version < "2024-10-21":
        dropped.extend(["stream_options", "parallel_tool_calls"])
    # TODO: Add once Azure supports modalities
    dropped.append("modalities")
    for key in dropped:
        params.pop(key, None)
    return params

client()

Source Code in griptape/drivers/prompt/azure_openai_chat_prompt_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Create the `openai.AzureOpenAI` client from this driver's connection settings."""
    settings = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**settings)

AzureOpenAiEmbeddingDriver

Bases: OpenAiEmbeddingDriver

Attributes

NameTypeDescription
azure_deploymentstrAn optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpointstrAn Azure OpenAi endpoint.
azure_ad_tokenOptional[str]An optional Azure Active Directory token.
azure_ad_token_providerOptional[Callable[[], str]]An optional Azure Active Directory token provider.
api_versionstrAn Azure OpenAi API version.
tokenizerOpenAiTokenizerAn OpenAiTokenizer.
clientAzureOpenAIAn openai.AzureOpenAI client.
Source Code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@define
class AzureOpenAiEmbeddingDriver(OpenAiEmbeddingDriver):
    """Embedding driver for OpenAI embedding models hosted on Azure.

    Attributes:
        azure_deployment: Azure OpenAi deployment id; falls back to `model` when not given.
        azure_endpoint: The Azure OpenAi endpoint.
        azure_ad_token: Optional Azure Active Directory token.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token.
        api_version: The Azure OpenAi API version string.
        tokenizer: An `OpenAiTokenizer`; by default one built for `model`.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        default=Factory(lambda self: self.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(metadata={"serializable": True}, kw_only=True)
    azure_ad_token: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(kw_only=True, default="2024-10-21", metadata={"serializable": True})
    tokenizer: OpenAiTokenizer = field(
        kw_only=True,
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
    )
    _client: Optional[openai.AzureOpenAI] = field(
        alias="client", default=None, kw_only=True, metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        settings = {
            "organization": self.organization,
            "api_key": self.api_key,
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider,
        }
        return openai.AzureOpenAI(**settings)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_version = field(default='2024-10-21', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

  • azure_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/azure_openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Create the `openai.AzureOpenAI` client from this driver's connection settings."""
    settings = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**settings)

AzureOpenAiImageGenerationDriver

Bases: OpenAiImageGenerationDriver

Attributes

NameTypeDescription
azure_deploymentstrAn optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpointstrAn Azure OpenAi endpoint.
azure_ad_tokenOptional[str]An optional Azure Active Directory token.
azure_ad_token_providerOptional[Callable[[], str]]An optional Azure Active Directory token provider.
api_versionstrAn Azure OpenAi API version.
clientAzureOpenAIAn openai.AzureOpenAI client.
Source Code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@define
class AzureOpenAiImageGenerationDriver(OpenAiImageGenerationDriver):
    """Image generation driver for OpenAI image models hosted on Azure.

    Attributes:
        azure_deployment: Azure OpenAi deployment id; falls back to `model` when not given.
        azure_endpoint: The Azure OpenAi endpoint.
        azure_ad_token: Optional Azure Active Directory token.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token.
        api_version: The Azure OpenAi API version string.
        client: An `openai.AzureOpenAI` client.
    """

    azure_deployment: str = field(
        default=Factory(lambda self: self.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(metadata={"serializable": True}, kw_only=True)
    azure_ad_token: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(kw_only=True, default="2024-02-01", metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        alias="client", default=None, kw_only=True, metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        settings = {
            "organization": self.organization,
            "api_key": self.api_key,
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider,
        }
        return openai.AzureOpenAI(**settings)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_version = field(default='2024-02-01', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

  • azure_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/image_generation/azure_openai_image_generation_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Create the `openai.AzureOpenAI` client from this driver's connection settings."""
    settings = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**settings)

AzureOpenAiTextToSpeechDriver

Bases: OpenAiTextToSpeechDriver

Attributes

NameTypeDescription
azure_deploymentstrAn optional Azure OpenAi deployment id. Defaults to the model name.
azure_endpointstrAn Azure OpenAi endpoint.
azure_ad_tokenOptional[str]An optional Azure Active Directory token.
azure_ad_token_providerOptional[Callable[[], str]]An optional Azure Active Directory token provider.
api_versionstrAn Azure OpenAi API version.
clientAzureOpenAIAn openai.AzureOpenAI client.
Source Code in griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py
@define
class AzureOpenAiTextToSpeechDriver(OpenAiTextToSpeechDriver):
    """Text-to-speech driver for OpenAI speech models hosted on Azure.

    Attributes:
        model: The model name; defaults to "tts".
        azure_deployment: Azure OpenAi deployment id; falls back to `model` when not given.
        azure_endpoint: The Azure OpenAi endpoint.
        azure_ad_token: Optional Azure Active Directory token.
        azure_ad_token_provider: Optional callable returning an Azure Active Directory token.
        api_version: The Azure OpenAi API version string.
        client: An `openai.AzureOpenAI` client.
    """

    model: str = field(kw_only=True, default="tts", metadata={"serializable": True})
    azure_deployment: str = field(
        default=Factory(lambda self: self.model, takes_self=True),
        kw_only=True,
        metadata={"serializable": True},
    )
    azure_endpoint: str = field(metadata={"serializable": True}, kw_only=True)
    azure_ad_token: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    azure_ad_token_provider: Optional[Callable[[], str]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": False},
    )
    api_version: str = field(kw_only=True, default="2024-07-01-preview", metadata={"serializable": True})
    _client: Optional[openai.AzureOpenAI] = field(
        alias="client", default=None, kw_only=True, metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.AzureOpenAI:
        settings = {
            "organization": self.organization,
            "api_key": self.api_key,
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
            "azure_deployment": self.azure_deployment,
            "azure_ad_token": self.azure_ad_token,
            "azure_ad_token_provider": self.azure_ad_token_provider,
        }
        return openai.AzureOpenAI(**settings)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_version = field(default='2024-07-01-preview', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • azure_ad_token = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_ad_token_provider = field(kw_only=True, default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • azure_deployment = field(kw_only=True, default=Factory(lambda self: self.model, takes_self=True), metadata={'serializable': True}) class-attribute instance-attribute

  • azure_endpoint = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default='tts', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/text_to_speech/azure_openai_text_to_speech_driver.py
@lazy_property()
def client(self) -> openai.AzureOpenAI:
    """Create the `openai.AzureOpenAI` client from this driver's connection settings."""
    settings = {
        "organization": self.organization,
        "api_key": self.api_key,
        "api_version": self.api_version,
        "azure_endpoint": self.azure_endpoint,
        "azure_deployment": self.azure_deployment,
        "azure_ad_token": self.azure_ad_token,
        "azure_ad_token_provider": self.azure_ad_token_provider,
    }
    return openai.AzureOpenAI(**settings)

BaseAssistantDriver

Bases: ABC
Source Code in griptape/drivers/assistant/base_assistant_driver.py
@define
class BaseAssistantDriver(ABC):
    """Abstract base for assistant drivers.

    Subclasses implement `try_run`; `run` is the public entry point and forwards to it.
    """

    def run(self, *args: BaseArtifact) -> TextArtifact:
        """Run the assistant on the given artifacts and return its text result."""
        output = self.try_run(*args)
        return output

    @abstractmethod
    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        """Driver-specific assistant call; must be implemented by subclasses."""

run(*args)

Source Code in griptape/drivers/assistant/base_assistant_driver.py
def run(self, *args: BaseArtifact) -> TextArtifact:
    """Run the assistant on the given artifacts; forwards all of them to `try_run`."""
    output = self.try_run(*args)
    return output

try_run(*args) abstractmethod

Source Code in griptape/drivers/assistant/base_assistant_driver.py
@abstractmethod
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    """Driver-specific assistant call; must be implemented by subclasses."""

BaseAudioTranscriptionDriver

Bases: SerializableMixin , ExponentialBackoffMixin , ABC

Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
@define
class BaseAudioTranscriptionDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for drivers that turn an `AudioArtifact` into a `TextArtifact`.

    Attributes:
        model: The transcription model identifier.
    """

    model: str = field(metadata={"serializable": True}, kw_only=True)

    def before_run(self) -> None:
        """Publish a `StartAudioTranscriptionEvent` on the event bus."""
        start_event = StartAudioTranscriptionEvent()
        EventBus.publish_event(start_event)

    def after_run(self) -> None:
        """Publish a `FinishAudioTranscriptionEvent` on the event bus."""
        finish_event = FinishAudioTranscriptionEvent()
        EventBus.publish_event(finish_event)

    def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        """Transcribe `audio`, retrying via `self.retrying()` and emitting start/finish events."""
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.before_run()
                transcription = self.try_run(audio, prompts)
                self.after_run()
                return transcription

        # Only reached when the retry loop finishes without returning a result.
        raise Exception("Failed to run audio transcription")

    @abstractmethod
    def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        """Perform one driver-specific transcription attempt; must be implemented by subclasses."""
  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

after_run()

Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def after_run(self) -> None:
    """Publish a `FinishAudioTranscriptionEvent` on the event bus."""
    finish_event = FinishAudioTranscriptionEvent()
    EventBus.publish_event(finish_event)

before_run()

Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def before_run(self) -> None:
    """Publish a `StartAudioTranscriptionEvent` on the event bus."""
    start_event = StartAudioTranscriptionEvent()
    EventBus.publish_event(start_event)

run(audio, prompts=None)

Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
def run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
    """Transcribe `audio`, retrying `try_run` according to `self.retrying()`.

    Each attempt calls `before_run`, then `try_run`, then `after_run`, and
    returns the transcription.

    Raises:
        Exception: If the retry loop finishes without any attempt returning.
    """
    for retry_attempt in self.retrying():
        with retry_attempt:
            self.before_run()
            transcription = self.try_run(audio, prompts)
            self.after_run()
            return transcription

    raise Exception("Failed to run audio transcription")

try_run(audio, prompts=None) abstractmethod

Source Code in griptape/drivers/audio_transcription/base_audio_transcription_driver.py
@abstractmethod
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
    """Perform a single transcription attempt; implemented by subclasses.

    `run` calls this inside its retry loop with the same `audio` and `prompts`.
    """
    ...

BaseConversationMemoryDriver

Bases: SerializableMixin , ABC

Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
class BaseConversationMemoryDriver(SerializableMixin, ABC):
    """Base class for drivers that persist conversation memory.

    Subclasses implement `store` and `load`. The private helpers convert between
    (runs, metadata) and a plain ``{"runs": [...], "metadata": {...}}`` dict.
    """

    @abstractmethod
    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        """Persist `runs` together with `metadata`."""
        ...

    @abstractmethod
    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Return the stored runs and their metadata."""
        ...

    def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict:
        """Serialize each run with `to_dict` and bundle it with `metadata`."""
        serialized_runs = [run.to_dict() for run in runs]
        return {"runs": serialized_runs, "metadata": metadata}

    def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]:
        """Rebuild runs and metadata from a dict shaped like `_to_params_dict` output.

        Missing ``"runs"`` / ``"metadata"`` keys default to an empty list / dict.
        """
        # Imported locally; presumably to avoid an import cycle (not visible here).
        from griptape.memory.structure import Run

        restored_runs = [Run.from_dict(raw_run) for raw_run in params_dict.get("runs", [])]
        return restored_runs, params_dict.get("metadata", {})

_from_params_dict(params_dict)

Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
def _from_params_dict(self, params_dict: dict[str, Any]) -> tuple[list[Run], dict[str, Any]]:
    """Rebuild runs and metadata from a dict shaped like `_to_params_dict` output.

    Missing ``"runs"`` / ``"metadata"`` keys default to an empty list / dict.
    """
    # Imported locally; presumably to avoid an import cycle (not visible here).
    from griptape.memory.structure import Run

    restored_runs = [Run.from_dict(raw_run) for raw_run in params_dict.get("runs", [])]
    return restored_runs, params_dict.get("metadata", {})

_to_params_dict(runs, metadata)

Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
def _to_params_dict(self, runs: list[Run], metadata: dict[str, Any]) -> dict:
    """Serialize each run with `to_dict` and bundle the result with `metadata`."""
    serialized_runs = [run.to_dict() for run in runs]
    return {"runs": serialized_runs, "metadata": metadata}

load() abstractmethod

Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Return the stored runs and their metadata; implemented by subclasses."""
    ...

store(runs, metadata) abstractmethod

Source Code in griptape/drivers/memory/conversation/base_conversation_memory_driver.py
@abstractmethod
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    """Persist `runs` together with `metadata`; implemented by subclasses."""
    ...

BaseDiffusionImageGenerationPipelineDriver

Bases: ABC
Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@define
class BaseDiffusionImageGenerationPipelineDriver(ABC):
    """Abstract interface for configuring a diffusion image generation pipeline.

    Every member is abstract; concrete drivers decide how the pipeline is
    built and which keyword arguments it receives.
    """

    @abstractmethod
    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Build and return the pipeline object for `model` (optionally on `device`)."""
        ...

    @abstractmethod
    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Return a dict of image keyword arguments for `image`, or None."""
        ...

    @abstractmethod
    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
        """Return extra parameters derived from `negative_prompts` and `device`.

        NOTE(review): presumably merged into the pipeline call's kwargs; confirm in callers.
        """
        ...

    @property
    @abstractmethod
    def output_image_dimensions(self) -> tuple[int, int]:
        """Two-int size of generated images; presumably (width, height), so confirm in implementations."""
        ...
  • output_image_dimensions abstractmethod property

make_additional_params(negative_prompts, device) abstractmethod

Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict:
    """Return extra parameters derived from `negative_prompts` and `device`.

    NOTE(review): presumably merged into the pipeline call's kwargs; confirm in callers.
    """
    ...

make_image_param(image) abstractmethod

Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Return a dict of image keyword arguments for `image`, or None."""
    ...

prepare_pipeline(model, device) abstractmethod

Source Code in griptape/drivers/image_generation_pipeline/base_image_generation_pipeline_driver.py
@abstractmethod
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Build and return the pipeline object for `model` (optionally on `device`)."""
    ...

BaseEmbeddingDriver

Bases: SerializableMixin , ExponentialBackoffMixin , ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | The name of the model to use. |
| `tokenizer` | `Optional[BaseTokenizer]` | An instance of `BaseTokenizer` to use when calculating tokens. |
Source Code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base Embedding Driver.

    Subclasses implement `try_embed_chunk`, and may override `try_embed_artifact`
    to support non-text artifacts. Callers use `embed`. It wraps the embedding
    call in the `self.retrying()` loop. Strings longer than
    `tokenizer.max_input_tokens` are split into chunks, and the chunk
    embeddings are averaged.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        chunker: `TextChunker` built from `tokenizer` after init; `None` when no
            tokenizer is configured.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: Optional[BaseChunker] = field(init=False)

    def __attrs_post_init__(self) -> None:
        # Long-string chunking needs a tokenizer; without one there is no chunker.
        self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None

    def embed_text_artifact(
        self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Deprecated: use `embed`, which accepts a `TextArtifact` directly."""
        warnings.warn(
            "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(artifact, vector_operation=vector_operation)

    def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Deprecated: use `embed`, which accepts a string directly."""
        warnings.warn(
            "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(string, vector_operation=vector_operation)

    def embed(
        self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Embed a string, `TextArtifact`, or `ImageArtifact`.

        Strings over `tokenizer.max_input_tokens` are embedded chunk by chunk via
        `_embed_long_string`. Single embedding calls are retried with
        `self.retrying()`.

        Args:
            value: The input to embed.
            vector_operation: Passed through to the underlying embed call.

        Returns:
            The embedding vector.

        Raises:
            RuntimeError: If `value` is not a supported type.
        """
        # Anything that recurses into `embed` is dispatched BEFORE entering the
        # retry loop. The recursive call already retries on its own, and nesting
        # it inside another retry loop would multiply attempts. It would also
        # re-embed chunks that already succeeded whenever a later chunk fails.
        if isinstance(value, str):
            if self.tokenizer is not None and self.tokenizer.count_tokens(value) > self.tokenizer.max_input_tokens:
                return self._embed_long_string(value, vector_operation=vector_operation)
        elif isinstance(value, TextArtifact):
            return self.embed(value.to_text(), vector_operation=vector_operation)

        for attempt in self.retrying():
            with attempt:
                if isinstance(value, str):
                    return self.try_embed_chunk(value, vector_operation=vector_operation)
                if isinstance(value, ImageArtifact):
                    return self.try_embed_artifact(value, vector_operation=vector_operation)
        raise RuntimeError("Failed to embed string.")

    def try_embed_artifact(
        self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Embed a single artifact.

        The default implementation handles only `TextArtifact` (delegating its
        value to `try_embed_chunk`) and raises `ValueError` for anything else.
        """
        # TODO: Mark as abstract method for griptape 2.0
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value, vector_operation=vector_operation)
        raise ValueError(f"{self.__class__.__name__} does not support embedding images.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Embed one string; `embed` supplies the retry loop around this call."""
        # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead
        ...

    def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        The string is split with `self.chunker`, and each chunk is embedded via
        `embed`. The chunk vectors are averaged, weighted by chunk length, and
        scaled to unit length.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

        embedding_chunks = []
        length_chunks = []
        for chunk in chunks:
            embedding_chunks.append(self.embed(chunk.value, vector_operation=vector_operation))
            length_chunks.append(len(chunk))

        # generate weighted averages
        averaged = np.average(embedding_chunks, axis=0, weights=length_chunks)

        # normalize length to 1; a zero vector has no direction, so return it
        # as-is instead of dividing by zero (which would produce NaNs)
        norm = np.linalg.norm(averaged)
        if norm == 0:
            return averaged.tolist()
        return (averaged / norm).tolist()
  • chunker = field(init=False) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=None, kw_only=True) class-attribute instance-attribute

`__attrs_post_init__()`

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    """Derive `chunker` from `tokenizer` once attrs has finished initialization.

    A `TextChunker` is created only when a (truthy) tokenizer is configured;
    otherwise `chunker` is set to None.
    """
    if self.tokenizer:
        self.chunker = TextChunker(tokenizer=self.tokenizer)
    else:
        self.chunker = None

_embed_long_string(string, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Embeds a string that is too long to embed in one go.

    The string is split with `self.chunker`, and each chunk is embedded via
    `embed`. The chunk vectors are averaged, weighted by chunk length, and
    scaled to unit length.

    Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
    """
    chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

    embedding_chunks = []
    length_chunks = []
    for chunk in chunks:
        embedding_chunks.append(self.embed(chunk.value, vector_operation=vector_operation))
        length_chunks.append(len(chunk))

    # generate weighted averages
    averaged = np.average(embedding_chunks, axis=0, weights=length_chunks)

    # normalize length to 1; a zero vector has no direction, so return it
    # as-is instead of dividing by zero (which would produce NaNs)
    norm = np.linalg.norm(averaged)
    if norm == 0:
        return averaged.tolist()
    return (averaged / norm).tolist()

embed(value, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed(
    self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Embed a string, `TextArtifact`, or `ImageArtifact`.

    Strings over `tokenizer.max_input_tokens` are embedded chunk by chunk via
    `_embed_long_string`. Single embedding calls are retried with
    `self.retrying()`.

    Args:
        value: The input to embed.
        vector_operation: Passed through to the underlying embed call.

    Returns:
        The embedding vector.

    Raises:
        RuntimeError: If `value` is not a supported type.
    """
    # Anything that recurses into `embed` is dispatched BEFORE entering the
    # retry loop. The recursive call already retries on its own, and nesting it
    # inside another retry loop would multiply attempts. It would also re-embed
    # chunks that already succeeded whenever a later chunk fails.
    if isinstance(value, str):
        if self.tokenizer is not None and self.tokenizer.count_tokens(value) > self.tokenizer.max_input_tokens:
            return self._embed_long_string(value, vector_operation=vector_operation)
    elif isinstance(value, TextArtifact):
        return self.embed(value.to_text(), vector_operation=vector_operation)

    for attempt in self.retrying():
        with attempt:
            if isinstance(value, str):
                return self.try_embed_chunk(value, vector_operation=vector_operation)
            if isinstance(value, ImageArtifact):
                return self.try_embed_artifact(value, vector_operation=vector_operation)
    raise RuntimeError("Failed to embed string.")

embed_string(string, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Deprecated alias for `embed` on a plain string.

    Emits a `DeprecationWarning` attributed to the caller, then delegates.
    """
    notice = (
        "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. "
        "`BaseEmbeddingDriver.embed` is a drop-in replacement."
    )
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return self.embed(string, vector_operation=vector_operation)

embed_text_artifact(artifact, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(
    self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Deprecated alias for `embed` on a `TextArtifact`.

    Emits a `DeprecationWarning` attributed to the caller, then delegates.
    """
    notice = (
        "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. "
        "`BaseEmbeddingDriver.embed` is a drop-in replacement."
    )
    warnings.warn(notice, DeprecationWarning, stacklevel=2)
    return self.embed(artifact, vector_operation=vector_operation)

try_embed_artifact(artifact, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def try_embed_artifact(
    self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Embed a single artifact.

    The default implementation only handles `TextArtifact` (delegating its
    value to `try_embed_chunk`). Any other artifact raises `ValueError`.
    """
    # TODO: Mark as abstract method for griptape 2.0
    if not isinstance(artifact, TextArtifact):
        raise ValueError(f"{self.__class__.__name__} does not support embedding images.")
    return self.try_embed_chunk(artifact.value, vector_operation=vector_operation)

try_embed_chunk(chunk, *, vector_operation=None) abstractmethod

Source Code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Embed one string; `embed` calls this inside its retry loop."""
    # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead
    ...

BaseEventListenerDriver

Bases: FuturesExecutorMixin , ExponentialBackoffMixin , ABC

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@define
class BaseEventListenerDriver(FuturesExecutorMixin, ExponentialBackoffMixin, ABC):
    """Base class for drivers that forward events to an external destination.

    Events are converted to dict payloads. They are published one at a time,
    or collected into batches of `batch_size` when `batched` is True.
    Publishing is submitted to an executor from `create_futures_executor()`.
    Failures after retries are logged, not raised.

    Attributes:
        batched: Whether to collect payloads into batches before publishing.
        batch_size: Number of pending payloads that triggers publishing a batch.
    """

    batched: bool = field(default=True, kw_only=True)
    batch_size: int = field(default=10, kw_only=True)

    _batch: list[dict] = field(default=Factory(list), kw_only=True)

    @property
    def batch(self) -> list[dict]:
        """Payloads collected but not yet published."""
        return self._batch

    def publish_event(self, event: BaseEvent | dict) -> None:
        """Publish `event` immediately, or add it to the pending batch.

        In batched mode the whole batch is submitted once it reaches
        `batch_size`. An executor is created only when something is actually
        submitted, so appending to a partial batch stays cheap.
        """
        event_payload = event if isinstance(event, dict) else event.to_dict()

        if not self.batched:
            with self.create_futures_executor() as futures_executor:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
            return

        self._batch.append(event_payload)
        if len(self.batch) >= self.batch_size:
            with self.create_futures_executor() as futures_executor:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
                # Rebind rather than clear: the submitted list must stay intact.
                self._batch = []

    def flush_events(self) -> None:
        """Submit any pending batch and reset it; no-op when the batch is empty."""
        if self.batch:
            with self.create_futures_executor() as futures_executor:
                futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
            self._batch = []

    @abstractmethod
    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Send one payload to the destination; implemented by subclasses."""
        ...

    @abstractmethod
    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Send a list of payloads to the destination; implemented by subclasses."""
        ...

    def _safe_publish_event_payload(self, event_payload: dict) -> None:
        """Publish one payload with retries; a final failure is logged, not raised."""
        try:
            for attempt in self.retrying():
                with attempt:
                    self.try_publish_event_payload(event_payload)
        except Exception:
            logger.warning("Failed to publish event after %s attempts", self.max_attempts, exc_info=True)

    def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Publish a batch with retries; a final failure is logged, not raised."""
        try:
            for attempt in self.retrying():
                with attempt:
                    self.try_publish_event_payload_batch(event_payload_batch)
        except Exception:
            logger.warning("Failed to publish event batch after %s attempts", self.max_attempts, exc_info=True)
  • _batch = field(default=Factory(list), kw_only=True) class-attribute instance-attribute

  • batch property

  • batch_size = field(default=10, kw_only=True) class-attribute instance-attribute

  • batched = field(default=True, kw_only=True) class-attribute instance-attribute

_safe_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def _safe_publish_event_payload(self, event_payload: dict) -> None:
    """Publish a single payload, retrying via `self.retrying()`.

    If every attempt fails, the exception is logged with its traceback and not
    propagated.
    """
    try:
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.try_publish_event_payload(event_payload)
    except Exception:
        # Deliberate best effort: log and swallow.
        logger.warning("Failed to publish event after %s attempts", self.max_attempts, exc_info=True)

_safe_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def _safe_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Publish a batch of payloads, retrying via `self.retrying()`.

    If every attempt fails, the exception is logged with its traceback and not
    propagated.
    """
    try:
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.try_publish_event_payload_batch(event_payload_batch)
    except Exception:
        # Deliberate best effort: log and swallow.
        logger.warning("Failed to publish event batch after %s attempts", self.max_attempts, exc_info=True)

flush_events()

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def flush_events(self) -> None:
    """Publish whatever is left in the pending batch, then reset it.

    Does nothing, and creates no executor, when the batch is empty.
    """
    if not self.batch:
        return
    with self.create_futures_executor() as futures_executor:
        futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
    self._batch = []

publish_event(event)

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
def publish_event(self, event: BaseEvent | dict) -> None:
    """Publish `event` immediately, or add it to the pending batch.

    In batched mode the whole batch is submitted once it reaches `batch_size`.
    An executor is created only when something is actually submitted, so
    appending to a partial batch stays cheap (matching `flush_events`).
    """
    event_payload = event if isinstance(event, dict) else event.to_dict()

    if not self.batched:
        with self.create_futures_executor() as futures_executor:
            futures_executor.submit(with_contextvars(self._safe_publish_event_payload), event_payload)
        return

    self._batch.append(event_payload)
    if len(self.batch) >= self.batch_size:
        with self.create_futures_executor() as futures_executor:
            futures_executor.submit(with_contextvars(self._safe_publish_event_payload_batch), self.batch)
            # Rebind rather than clear: the submitted list must stay intact.
            self._batch = []

try_publish_event_payload(event_payload) abstractmethod

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@abstractmethod
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Send one event payload; implemented by subclasses.

    Called inside the retry loop of `_safe_publish_event_payload`.
    """
    ...

try_publish_event_payload_batch(event_payload_batch) abstractmethod

Source Code in griptape/drivers/event_listener/base_event_listener_driver.py
@abstractmethod
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Send a list of event payloads; implemented by subclasses.

    Called inside the retry loop of `_safe_publish_event_payload_batch`.
    """
    ...

BaseFileManagerDriver

Bases: ABC

Attributes

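| Name | Type | Description |
| --- | --- | --- |
| `workdir` | `str` | Working directory; stored from the `workdir` init argument, with the `workdir` property implemented by subclasses. |
| `encoding` | `Optional[str]` | Text encoding. When `None`, `load_file` returns a `BlobArtifact`; otherwise file contents are decoded into a `TextArtifact`. |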
Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@define
class BaseFileManagerDriver(ABC):
    """BaseFileManagerDriver can be used to list, load, and save files.

    Subclasses provide storage access through the `workdir` property and the
    `try_list_files` / `try_load_file` / `try_save_file` primitives. The public
    methods wrap those results in artifacts.

    Attributes:
        workdir: Working directory. The value passed as `workdir` is stored in
            `_workdir`; the `workdir` property itself is implemented by subclasses.
        encoding: Text encoding. When None, `load_file` returns a `BlobArtifact`
            and strings are encoded/decoded with the `str.encode`/`bytes.decode`
            default (UTF-8). Otherwise that codec is used and `load_file` returns
            a `TextArtifact`.
    """

    _workdir: str = field(kw_only=True, alias="workdir")
    encoding: Optional[str] = field(default=None, kw_only=True)

    @property
    @abstractmethod
    def workdir(self) -> str:
        """Current working directory; implemented by subclasses."""
        ...

    @workdir.setter
    @abstractmethod
    def workdir(self, value: str) -> None:
        """Change the working directory; implemented by subclasses."""
        ...

    def list_files(self, path: str) -> TextArtifact:
        """Return the entries under `path` as newline-separated text."""
        listing = "\n".join(self.try_list_files(path))
        return TextArtifact(listing)

    @abstractmethod
    def try_list_files(self, path: str) -> list[str]:
        """Return the entry names under `path`."""
        ...

    def load_file(self, path: str) -> BlobArtifact | TextArtifact:
        """Load `path` as a blob, or as decoded text when an encoding is set."""
        encoding = self.encoding
        if encoding is not None:
            return TextArtifact(self.try_load_file(path).decode(encoding=encoding), encoding=encoding)
        return BlobArtifact(self.try_load_file(path))

    @abstractmethod
    def try_load_file(self, path: str) -> bytes:
        """Return the raw bytes stored at `path`."""
        ...

    def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
        """Save `value` at `path`, encoding strings first.

        Raises:
            ValueError: If `value` is a `bytearray` or `memoryview`.
        """
        if isinstance(value, (bytearray, memoryview)):
            raise ValueError(f"Unsupported type: {type(value)}")
        if isinstance(value, str):
            payload = value.encode() if self.encoding is None else value.encode(encoding=self.encoding)
        else:
            payload = value

        saved_at = self.try_save_file(path, payload)

        return InfoArtifact(f"Successfully saved file at: {saved_at}")

    @abstractmethod
    def try_save_file(self, path: str, value: bytes) -> str:
        """Write `value` to `path` and return the saved location."""
        ...

    def load_artifact(self, path: str) -> BaseArtifact:
        """Load `path` and deserialize its JSON contents into an artifact."""
        raw_bytes = self.try_load_file(path)
        json_text = raw_bytes.decode() if self.encoding is None else raw_bytes.decode(encoding=self.encoding)
        return BaseArtifact.from_json(json_text)

    def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
        """Serialize `artifact` to JSON and save it at `path`."""
        serialized = artifact.to_json()
        if self.encoding is None:
            payload = serialized.encode()
        else:
            payload = serialized.encode(encoding=self.encoding)

        saved_at = self.try_save_file(path, payload)

        return InfoArtifact(f"Successfully saved artifact at: {saved_at}")
  • _workdir = field(kw_only=True, alias='workdir') class-attribute instance-attribute

  • encoding = field(default=None, kw_only=True) class-attribute instance-attribute

  • workdir abstractmethod property writable

list_files(path)

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def list_files(self, path: str) -> TextArtifact:
    """Return the entries under `path` (from `try_list_files`) as newline-separated text."""
    listing = "\n".join(self.try_list_files(path))
    return TextArtifact(listing)

load_artifact(path)

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def load_artifact(self, path: str) -> BaseArtifact:
    """Load `path` and deserialize its JSON contents into an artifact.

    Bytes are decoded with `self.encoding`, or with the `bytes.decode`
    default (UTF-8) when no encoding is set.
    """
    raw_bytes = self.try_load_file(path)
    json_text = raw_bytes.decode() if self.encoding is None else raw_bytes.decode(encoding=self.encoding)
    return BaseArtifact.from_json(json_text)

load_file(path)

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def load_file(self, path: str) -> BlobArtifact | TextArtifact:
    """Load `path` as a `BlobArtifact`, or as decoded `TextArtifact` when an encoding is set."""
    encoding = self.encoding
    if encoding is not None:
        return TextArtifact(self.try_load_file(path).decode(encoding=encoding), encoding=encoding)
    return BlobArtifact(self.try_load_file(path))

save_artifact(path, artifact)

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
    """Serialize `artifact` to JSON, encode it, and save it at `path`."""
    serialized = artifact.to_json()
    if self.encoding is None:
        payload = serialized.encode()
    else:
        payload = serialized.encode(encoding=self.encoding)

    saved_at = self.try_save_file(path, payload)

    return InfoArtifact(f"Successfully saved artifact at: {saved_at}")

save_file(path, value)

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
    """Save `value` at `path`, encoding strings with `self.encoding` first.

    Raises:
        ValueError: If `value` is a `bytearray` or `memoryview`.
    """
    if isinstance(value, (bytearray, memoryview)):
        raise ValueError(f"Unsupported type: {type(value)}")
    if isinstance(value, str):
        payload = value.encode() if self.encoding is None else value.encode(encoding=self.encoding)
    else:
        payload = value

    saved_at = self.try_save_file(path, payload)

    return InfoArtifact(f"Successfully saved file at: {saved_at}")

try_list_files(path) abstractmethod

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_list_files(self, path: str) -> list[str]:
    """Return the entry names under `path`; implemented by subclasses."""
    ...

try_load_file(path) abstractmethod

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_load_file(self, path: str) -> bytes:
    """Return the raw bytes stored at `path`; implemented by subclasses."""
    ...

try_save_file(path, value) abstractmethod

Source Code in griptape/drivers/file_manager/base_file_manager_driver.py
@abstractmethod
def try_save_file(self, path: str, value: bytes) -> str:
    """Write `value` to `path` and return the saved location; implemented by subclasses."""
    ...

BaseImageGenerationDriver

Bases: SerializableMixin , ExponentialBackoffMixin , ABC

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@define
class BaseImageGenerationDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for image generation drivers.

    Subclasses implement the four `try_*` operations. Each public `run_*`
    method wraps its `try_*` counterpart in the `self.retrying()` loop. It
    publishes a start event before the call and a finish event after the call
    returns.

    Attributes:
        model: Name of the image generation model; serialized with the driver.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})

    def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
        """Publish a `StartImageGenerationEvent` carrying the prompts."""
        start_event = StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)
        EventBus.publish_event(start_event)

    def after_run(self) -> None:
        """Publish a `FinishImageGenerationEvent`."""
        finish_event = FinishImageGenerationEvent()
        EventBus.publish_event(finish_event)

    def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from `prompts`, with retries.

        Raises:
            Exception: If the retry loop finishes without an attempt returning.
        """
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.before_run(prompts, negative_prompts)
                generated = self.try_text_to_image(prompts, negative_prompts)
                self.after_run()
                return generated

        raise Exception("Failed to run text to image generation")

    def run_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Generate a variation of `image` guided by `prompts`, with retries.

        Raises:
            Exception: If the retry loop finishes without an attempt returning.
        """
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.before_run(prompts, negative_prompts)
                variation = self.try_image_variation(prompts, image, negative_prompts)
                self.after_run()
                return variation

        raise Exception("Failed to generate image variations")

    def run_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Run inpainting on `image` with `mask`, with retries.

        Raises:
            Exception: If the retry loop finishes without an attempt returning.
        """
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.before_run(prompts, negative_prompts)
                inpainted = self.try_image_inpainting(prompts, image, mask, negative_prompts)
                self.after_run()
                return inpainted

        raise Exception("Failed to run image inpainting")

    def run_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Run outpainting on `image` with `mask`, with retries.

        Raises:
            Exception: If the retry loop finishes without an attempt returning.
        """
        for retry_attempt in self.retrying():
            with retry_attempt:
                self.before_run(prompts, negative_prompts)
                outpainted = self.try_image_outpainting(prompts, image, mask, negative_prompts)
                self.after_run()
                return outpainted

        raise Exception("Failed to run image outpainting")

    @abstractmethod
    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Perform one text-to-image attempt; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Perform one image-variation attempt; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Perform one inpainting attempt; implemented by subclasses."""
        ...

    @abstractmethod
    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Perform one outpainting attempt; implemented by subclasses."""
        ...
  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

after_run()

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def after_run(self) -> None:
    """Announce on the event bus that image generation has finished."""
    finish_event = FinishImageGenerationEvent()
    EventBus.publish_event(finish_event)

before_run(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def before_run(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> None:
    """Announce on the event bus that image generation is starting for `prompts`."""
    start_event = StartImageGenerationEvent(prompts=prompts, negative_prompts=negative_prompts)
    EventBus.publish_event(start_event)

run_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Run inpainting on `image` with `mask`, retrying via `self.retrying()`.

    Each attempt calls `before_run`, `try_image_inpainting`, then `after_run`.

    Raises:
        Exception: If the retry loop finishes without an attempt returning.
    """
    for retry_attempt in self.retrying():
        with retry_attempt:
            self.before_run(prompts, negative_prompts)
            inpainted = self.try_image_inpainting(prompts, image, mask, negative_prompts)
            self.after_run()
            return inpainted

    raise Exception("Failed to run image inpainting")

run_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Run outpainting on `image` with `mask`, retrying via `self.retrying()`.

    Each attempt calls `before_run`, `try_image_outpainting`, then `after_run`.

    Raises:
        Exception: If the retry loop finishes without an attempt returning.
    """
    for retry_attempt in self.retrying():
        with retry_attempt:
            self.before_run(prompts, negative_prompts)
            outpainted = self.try_image_outpainting(prompts, image, mask, negative_prompts)
            self.after_run()
            return outpainted

    raise Exception("Failed to run image outpainting")

run_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Generate a variation of `image` guided by `prompts`, retrying via `self.retrying()`.

    Each attempt calls `before_run`, `try_image_variation`, then `after_run`.

    Raises:
        Exception: If the retry loop finishes without an attempt returning.
    """
    for retry_attempt in self.retrying():
        with retry_attempt:
            self.before_run(prompts, negative_prompts)
            variation = self.try_image_variation(prompts, image, negative_prompts)
            self.after_run()
            return variation

    raise Exception("Failed to generate image variations")

run_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
def run_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from `prompts`, retrying via `self.retrying()`.

    Each attempt calls `before_run`, `try_text_to_image`, then `after_run`.

    Raises:
        Exception: If the retry loop finishes without an attempt returning.
    """
    for retry_attempt in self.retrying():
        with retry_attempt:
            self.before_run(prompts, negative_prompts)
            generated = self.try_text_to_image(prompts, negative_prompts)
            self.after_run()
            return generated

    raise Exception("Failed to run text to image generation")

try_image_inpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Perform one inpainting attempt; implemented by subclasses.

    `run_image_inpainting` calls this inside its retry loop.
    """
    ...

try_image_outpainting(prompts, image, mask, negative_prompts=None) abstractmethod

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Perform one outpainting attempt; implemented by subclasses.

    `run_image_outpainting` calls this inside its retry loop.
    """
    ...

try_image_variation(prompts, image, negative_prompts=None) abstractmethod

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Perform one image-variation attempt; implemented by subclasses.

    `run_image_variation` calls this inside its retry loop.
    """
    ...

try_text_to_image(prompts, negative_prompts=None) abstractmethod

Source Code in griptape/drivers/image_generation/base_image_generation_driver.py
@abstractmethod
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Perform one text-to-image attempt; implemented by subclasses.

    `run_text_to_image` calls this inside its retry loop.
    """
    ...

BaseImageGenerationModelDriver

Bases: SerializableMixin , ABC

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@define
class BaseImageGenerationModelDriver(SerializableMixin, ABC):
    """Model-specific request/response adapter for image generation.

    Each `*_request_parameters` method returns a request-parameter dict for one
    operation, and `get_generated_image` extracts image bytes from a response
    dict. The exact schemas are model-specific and defined by subclasses.
    """

    @abstractmethod
    def get_generated_image(self, response: dict) -> bytes:
        """Return the generated image bytes contained in `response`."""
        ...

    @abstractmethod
    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        """Build request parameters for generating an image from `prompts` at the given size."""
        ...

    @abstractmethod
    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        """Build request parameters for an image-variation request based on `image`."""
        ...

    @abstractmethod
    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        """Build request parameters for an inpainting request on `image` with `mask`."""
        ...

    @abstractmethod
    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict[str, Any]:
        """Build request parameters for an outpainting request on `image` with `mask`."""
        ...

get_generated_image(response) abstractmethod

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def get_generated_image(self, response: dict) -> bytes:
    """Return the raw image bytes contained in a model `response` dict."""

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) `abstractmethod`

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    """Build the model-specific request parameters for an inpainting call.

    Args:
        prompts: Text prompts for the request.
        image: The image to edit.
        mask: Mask image for the edit (mask convention is model-specific).
        negative_prompts: Optional negative prompts.
        seed: Optional seed value.

    Returns:
        The request parameters as a dict.
    """

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None) `abstractmethod`

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    """Build the model-specific request parameters for an outpainting call.

    Args:
        prompts: Text prompts for the request.
        image: The image to extend.
        mask: Mask image for the edit (mask convention is model-specific).
        negative_prompts: Optional negative prompts.
        seed: Optional seed value.

    Returns:
        The request parameters as a dict.
    """

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None) `abstractmethod`

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    """Build the model-specific request parameters for an image variation call.

    Args:
        prompts: Text prompts for the request.
        image: The source image.
        negative_prompts: Optional negative prompts.
        seed: Optional seed value.

    Returns:
        The request parameters as a dict.
    """

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None) `abstractmethod`

Source Code in griptape/drivers/image_generation_model/base_image_generation_model_driver.py
@abstractmethod
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict[str, Any]:
    """Build the model-specific request parameters for a text-to-image call.

    Args:
        prompts: Text prompts describing the image.
        image_width: Requested image width.
        image_height: Requested image height.
        negative_prompts: Optional negative prompts.
        seed: Optional seed value.

    Returns:
        The request parameters as a dict.
    """

BaseMultiModelImageGenerationDriver

Bases: BaseImageGenerationDriver, ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `image_generation_model_driver` | `BaseImageGenerationModelDriver` | Image Model Driver to use. |
Source Code in griptape/drivers/image_generation/base_multi_model_image_generation_driver.py
@define
class BaseMultiModelImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image Generation Driver for platforms like Amazon Bedrock that host many models.

    Instances of this Image Generation Driver require an Image Generation Model Driver, which is used to structure
    the image generation request in the format required by the model and to process the output.

    Attributes:
        image_generation_model_driver: Image Generation Model Driver to use.
    """

    image_generation_model_driver: BaseImageGenerationModelDriver = field(kw_only=True, metadata={"serializable": True})
  • image_generation_model_driver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

BaseObservabilityDriver

Bases: ABC
Source Code in griptape/drivers/observability/base_observability_driver.py
@define
class BaseObservabilityDriver(ABC):
    """Abstract observability driver, usable as a context manager.

    Entering does nothing and exiting never suppresses exceptions; subclasses may override both and
    must implement `observe` and `get_span_id`.
    """

    def __enter__(self) -> None:  # noqa: B027
        """Enter the observability context (no-op in the base class)."""

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> bool:
        """Leave the observability context; returning False lets any in-flight exception propagate."""
        suppress_exception = False
        return suppress_exception

    @abstractmethod
    def observe(self, call: Observable.Call) -> Any:
        """Observe an `Observable.Call` and return a subclass-defined result."""

    @abstractmethod
    def get_span_id(self) -> Optional[str]:
        """Return a span identifier, or None; exact meaning is subclass-defined."""

`__enter__()`

Source Code in griptape/drivers/observability/base_observability_driver.py
def __enter__(self) -> None:  # noqa: B027
    """Enter the observability context; the base implementation does nothing and yields None."""

`__exit__(exc_type, exc_value, exc_traceback)`

Source Code in griptape/drivers/observability/base_observability_driver.py
def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[TracebackType],
) -> bool:
    """Leave the observability context.

    Returning False tells Python not to suppress an exception raised inside the `with` block.
    """
    suppress_exception = False
    return suppress_exception

get_span_id() `abstractmethod`

Source Code in griptape/drivers/observability/base_observability_driver.py
@abstractmethod
def get_span_id(self) -> Optional[str]:
    """Return a span identifier, or None; exact meaning is subclass-defined."""

observe(call) `abstractmethod`

Source Code in griptape/drivers/observability/base_observability_driver.py
@abstractmethod
def observe(self, call: Observable.Call) -> Any:
    """Observe an `Observable.Call` and return a subclass-defined result."""

BasePromptDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `temperature` | `float` | The temperature to use for the completion. |
| `max_tokens` | `Optional[int]` | The maximum number of tokens to generate. If not specified, the value will be automatically generated based on the tokenizer. |
| `prompt_stack_to_string` | `str` | A function that converts a `PromptStack` to a string. |
| `ignored_exception_types` | `tuple[type[Exception], ...]` | A tuple of exception types to ignore. |
| `model` | `str` | The model name. |
| `tokenizer` | `BaseTokenizer` | An instance of `BaseTokenizer` to use when calculating tokens. |
| `stream` | `bool` | Whether to stream the completion or not. While streaming, `TextChunkEvent`s, `AudioChunkEvent`s, and `ActionChunkEvent`s are published to the `EventBus`. |
| `use_native_tools` | `bool` | Whether to use LLM's native function calling capabilities. Must be supported by the model. |
| `extra_params` | `dict` | Extra parameters to pass to the model. |
Source Code in griptape/drivers/prompt/base_prompt_driver.py
@define(kw_only=True)
class BasePromptDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for the Prompt Drivers.

    Attributes:
        temperature: The temperature to use for the completion.
        max_tokens: The maximum number of tokens to generate. If not specified, the value will be automatically generated based on the tokenizer.
        prompt_stack_to_string: A method that converts a `PromptStack` to a string.
        ignored_exception_types: A tuple of exception types to ignore.
        model: The model name.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        stream: Whether to stream the completion or not. While streaming, `TextChunkEvent`s, `AudioChunkEvent`s, and
            `ActionChunkEvent`s are published to the `EventBus`.
        use_native_tools: Whether to use LLM's native function calling capabilities. Must be supported by the model.
        structured_output_strategy: How a Prompt Stack's `output_schema` is enforced: `"tool"` adds a
            `StructuredOutputTool` to the Prompt Stack's tools, `"rule"` adds a JSON schema rule to the system prompt.
        extra_params: Extra parameters to pass to the model.
    """

    temperature: float = field(default=0.1, metadata={"serializable": True})
    max_tokens: Optional[int] = field(default=None, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(default=Factory(lambda: (ImportError, ValueError)))
    model: str = field(metadata={"serializable": True})
    tokenizer: BaseTokenizer
    stream: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    extra_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})

    def before_run(self, prompt_stack: PromptStack) -> None:
        """Apply structured output to `prompt_stack`, then publish a `StartPromptEvent`."""
        self._init_structured_output(prompt_stack)
        EventBus.publish_event(StartPromptEvent(model=self.model, prompt_stack=prompt_stack))

    def after_run(self, result: Message) -> None:
        """Publish a `FinishPromptEvent` with the result value and its token usage."""
        EventBus.publish_event(
            FinishPromptEvent(
                model=self.model,
                result=result.value,
                input_token_count=result.usage.input_tokens,
                output_token_count=result.usage.output_tokens,
            ),
        )

    @observable(tags=["PromptDriver.run()"])
    def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
        """Run the driver on a Prompt Stack or a single artifact.

        Every attempt yielded by `retrying()` calls `before_run`, then the streaming or non-streaming path
        (chosen by `stream`), then `after_run`.

        Args:
            prompt_input: A Prompt Stack, or an artifact that is wrapped with `PromptStack.from_artifact`.

        Returns:
            The assistant `Message` produced by the model.
        """
        if isinstance(prompt_input, BaseArtifact):
            prompt_stack = PromptStack.from_artifact(prompt_input)
        else:
            prompt_stack = prompt_input

        for attempt in self.retrying():
            with attempt:
                self.before_run(prompt_stack)

                result = self.__process_stream(prompt_stack) if self.stream else self.__process_run(prompt_stack)

                self.after_run(result)

                return result
        raise Exception("prompt driver failed after all retry attempts")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Converts a Prompt Stack to a string for token counting or model prompt_input.

        This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

        Args:
            prompt_stack: The Prompt Stack to convert to a string.

        Returns:
            A single string representation of the Prompt Stack.
        """
        prompt_lines = []

        for message in prompt_stack.messages:
            content = message.to_text()
            if message.is_user():
                prompt_lines.append(f"User: {content}")
            elif message.is_assistant():
                prompt_lines.append(f"Assistant: {content}")
            else:
                prompt_lines.append(content)

        # Trailing cue so the model continues as the assistant.
        prompt_lines.append("Assistant:")

        return "\n\n".join(prompt_lines)

    @abstractmethod
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Subclass hook: run the model once without streaming and return the full Message."""

    @abstractmethod
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Subclass hook: run the model with streaming, yielding `DeltaMessage`s."""

    def _init_structured_output(self, prompt_stack: PromptStack) -> None:
        """Apply `structured_output_strategy` to `prompt_stack` when it declares an `output_schema`.

        `run` calls this (via `before_run`) once per retry attempt on the same Prompt Stack, so both strategies
        only add their tool/rule when it is not already present.
        """
        from griptape.tools import StructuredOutputTool

        if (output_schema := prompt_stack.output_schema) is None:
            return

        if self.structured_output_strategy == "tool":
            structured_output_tool = StructuredOutputTool(output_schema=output_schema)
            if structured_output_tool not in prompt_stack.tools:
                prompt_stack.tools.append(structured_output_tool)
        elif self.structured_output_strategy == "rule":
            rule_text = JsonSchemaRule(prompt_stack.to_output_json_schema()).to_text()
            output_artifact = TextArtifact(rule_text)
            system_messages = prompt_stack.system_messages
            if system_messages:
                last_system_message = system_messages[-1]
                # Skip when an earlier call (e.g. a previous retry attempt) already appended this rule.
                if rule_text not in last_system_message.to_text():
                    last_system_message.content.extend(
                        [
                            TextMessageContent(TextArtifact("\n\n")),
                            TextMessageContent(output_artifact),
                        ]
                    )
            else:
                prompt_stack.messages.insert(
                    0,
                    Message(
                        content=[TextMessageContent(output_artifact)],
                        role=Message.SYSTEM_ROLE,
                    ),
                )

    def __process_run(self, prompt_stack: PromptStack) -> Message:
        """Non-streaming path: delegate to `try_run`."""
        return self.try_run(prompt_stack)

    def __process_stream(self, prompt_stack: PromptStack) -> Message:
        """Streaming path: consume `try_stream`, publish chunk events, and assemble the final Message."""
        delta_contents: dict[int, list[BaseDeltaMessageContent]] = {}
        usage = DeltaMessage.Usage()

        # Aggregate all content deltas from the stream, grouped by content index
        for message_delta in self.try_stream(prompt_stack):
            usage += message_delta.usage
            content = message_delta.content

            if content is not None:
                delta_contents.setdefault(content.index, []).append(content)
                if isinstance(content, TextDeltaMessageContent):
                    EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
                elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
                    EventBus.publish_event(AudioChunkEvent(data=content.data))
                elif isinstance(content, ActionCallDeltaMessageContent):
                    EventBus.publish_event(
                        ActionChunkEvent(
                            partial_input=content.partial_input,
                            tag=content.tag,
                            name=content.name,
                            path=content.path,
                            index=content.index,
                        ),
                    )

        # Build a complete content from the content deltas
        return self.__build_message(list(delta_contents.values()), usage)

    def __build_message(
        self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage
    ) -> Message:
        """Merge each index's deltas into full contents and wrap them in an assistant Message with usage."""
        content = []
        for delta_content in delta_contents:
            text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
            audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
            action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]

            if text_deltas:
                content.append(TextMessageContent.from_deltas(text_deltas))
            if audio_deltas:
                content.append(AudioMessageContent.from_deltas(audio_deltas))
            if action_deltas:
                content.append(ActionCallMessageContent.from_deltas(action_deltas))

        return Message(
            content=content,
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens),
        )
  • extra_params = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • ignored_exception_types = field(default=Factory(lambda: (ImportError, ValueError))) class-attribute instance-attribute

  • max_tokens = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(metadata={'serializable': True}) class-attribute instance-attribute

  • stream = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • temperature = field(default=0.1, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer instance-attribute

  • use_native_tools = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__build_message(delta_contents, usage)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __build_message(
    self, delta_contents: list[list[BaseDeltaMessageContent]], usage: DeltaMessage.Usage
) -> Message:
    """Merge per-index delta groups into full message contents and wrap them in an assistant Message.

    Within each group, text, audio, and action-call deltas are merged separately, in that order.
    """
    # (delta type, full content type) pairs, in the order their merged contents are emitted.
    merge_plan = (
        (TextDeltaMessageContent, TextMessageContent),
        (AudioDeltaMessageContent, AudioMessageContent),
        (ActionCallDeltaMessageContent, ActionCallMessageContent),
    )

    merged = []
    for group in delta_contents:
        for delta_type, full_type in merge_plan:
            matching = [delta for delta in group if isinstance(delta, delta_type)]
            if matching:
                merged.append(full_type.from_deltas(matching))

    totals = Message.Usage(input_tokens=usage.input_tokens, output_tokens=usage.output_tokens)
    return Message(content=merged, role=Message.ASSISTANT_ROLE, usage=totals)

__process_run(prompt_stack)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __process_run(self, prompt_stack: PromptStack) -> Message:
    """Non-streaming path: delegate directly to `try_run` and return its Message."""
    message = self.try_run(prompt_stack)
    return message

__process_stream(prompt_stack)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def __process_stream(self, prompt_stack: PromptStack) -> Message:
    """Streaming path: consume `try_stream`, publish a chunk event per delta, and assemble the final Message.

    Deltas are grouped by their content `index`; usage from every delta is summed.
    """
    grouped: dict[int, list[BaseDeltaMessageContent]] = {}
    usage = DeltaMessage.Usage()

    for delta in self.try_stream(prompt_stack):
        usage += delta.usage
        chunk = delta.content
        if chunk is None:
            continue

        grouped.setdefault(chunk.index, []).append(chunk)

        if isinstance(chunk, TextDeltaMessageContent):
            EventBus.publish_event(TextChunkEvent(token=chunk.text, index=chunk.index))
        elif isinstance(chunk, AudioDeltaMessageContent) and chunk.data is not None:
            EventBus.publish_event(AudioChunkEvent(data=chunk.data))
        elif isinstance(chunk, ActionCallDeltaMessageContent):
            EventBus.publish_event(
                ActionChunkEvent(
                    partial_input=chunk.partial_input,
                    tag=chunk.tag,
                    name=chunk.name,
                    path=chunk.path,
                    index=chunk.index,
                ),
            )

    # Merge each index's deltas into complete contents.
    return self.__build_message(list(grouped.values()), usage)

_init_structured_output(prompt_stack)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def _init_structured_output(self, prompt_stack: PromptStack) -> None:
    """Apply `structured_output_strategy` to `prompt_stack` when it declares an `output_schema`.

    With `"tool"`, a `StructuredOutputTool` is appended to the Prompt Stack's tools if not already present.
    With `"rule"`, a JSON schema rule is appended to the last system message, or inserted as a new leading
    system message when there is none. This runs once per retry attempt on the same Prompt Stack, so the
    rule is skipped when the last system message already contains it.
    """
    from griptape.tools import StructuredOutputTool

    if (output_schema := prompt_stack.output_schema) is None:
        return

    if self.structured_output_strategy == "tool":
        structured_output_tool = StructuredOutputTool(output_schema=output_schema)
        if structured_output_tool not in prompt_stack.tools:
            prompt_stack.tools.append(structured_output_tool)
    elif self.structured_output_strategy == "rule":
        rule_text = JsonSchemaRule(prompt_stack.to_output_json_schema()).to_text()
        output_artifact = TextArtifact(rule_text)
        system_messages = prompt_stack.system_messages
        if system_messages:
            last_system_message = system_messages[-1]
            # Skip when an earlier call (e.g. a previous retry attempt) already appended this rule.
            if rule_text not in last_system_message.to_text():
                last_system_message.content.extend(
                    [
                        TextMessageContent(TextArtifact("\n\n")),
                        TextMessageContent(output_artifact),
                    ]
                )
        else:
            prompt_stack.messages.insert(
                0,
                Message(
                    content=[TextMessageContent(output_artifact)],
                    role=Message.SYSTEM_ROLE,
                ),
            )

after_run(result)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def after_run(self, result: Message) -> None:
    """Publish a `FinishPromptEvent` carrying the model name, result value, and token counts."""
    usage = result.usage
    finish_event = FinishPromptEvent(
        model=self.model,
        result=result.value,
        input_token_count=usage.input_tokens,
        output_token_count=usage.output_tokens,
    )
    EventBus.publish_event(finish_event)

before_run(prompt_stack)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
def before_run(self, prompt_stack: PromptStack) -> None:
    """Prepare `prompt_stack` for structured output, then announce the run with a `StartPromptEvent`."""
    # Applied before publishing so subscribers receive the Prompt Stack with structured output already applied.
    self._init_structured_output(prompt_stack)
    start_event = StartPromptEvent(model=self.model, prompt_stack=prompt_stack)
    EventBus.publish_event(start_event)

prompt_stack_to_string(prompt_stack)

Converts a Prompt Stack to a string for token counting or model prompt_input.

This base implementation is only a rough approximation, and should be overridden by subclasses with model-specific tokens.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `prompt_stack` | `PromptStack` | The Prompt Stack to convert to a string. | *required* |

Returns

| Type | Description |
| --- | --- |
| `str` | A single string representation of the Prompt Stack. |
Source Code in griptape/drivers/prompt/base_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Render a Prompt Stack as plain text for token counting or as model prompt_input.

    This is a rough, model-agnostic approximation; subclasses should override it with model-specific tokens.
    User and assistant messages get a "User: " / "Assistant: " prefix, other messages are emitted as-is,
    and a final "Assistant:" cue is appended. Entries are separated by blank lines.

    Args:
        prompt_stack: The Prompt Stack to render.

    Returns:
        The rendered prompt string.
    """

    def render(message) -> str:
        text = message.to_text()
        if message.is_user():
            return f"User: {text}"
        if message.is_assistant():
            return f"Assistant: {text}"
        return text

    rendered = [render(message) for message in prompt_stack.messages]
    rendered.append("Assistant:")
    return "\n\n".join(rendered)

run(prompt_input)

Source Code in griptape/drivers/prompt/base_prompt_driver.py
@observable(tags=["PromptDriver.run()"])
def run(self, prompt_input: PromptStack | BaseArtifact) -> Message:
    """Run the driver on a Prompt Stack, or on a single artifact wrapped into one.

    Each attempt from `retrying()` runs `before_run`, the streaming or non-streaming path (per `stream`),
    and `after_run`, then returns the resulting Message.
    """
    prompt_stack = (
        PromptStack.from_artifact(prompt_input) if isinstance(prompt_input, BaseArtifact) else prompt_input
    )

    for attempt in self.retrying():
        with attempt:
            self.before_run(prompt_stack)

            process = self.__process_stream if self.stream else self.__process_run
            result = process(prompt_stack)

            self.after_run(result)
            return result

    raise Exception("prompt driver failed after all retry attempts")

try_run(prompt_stack) `abstractmethod`

Source Code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Subclass hook: run the model once without streaming and return the full Message."""

try_stream(prompt_stack) `abstractmethod`

Source Code in griptape/drivers/prompt/base_prompt_driver.py
@abstractmethod
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Subclass hook: run the model with streaming, yielding `DeltaMessage`s as they arrive."""

BaseRerankDriver

Bases: ABC
Source Code in griptape/drivers/rerank/base_rerank_driver.py
@define(kw_only=True)
class BaseRerankDriver(ABC):
    """Abstract interface for drivers that rerank text artifacts against a query."""

    @abstractmethod
    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Return `artifacts` reordered for `query`; the ranking criterion is subclass-defined."""

run(query, artifacts) `abstractmethod`

Source Code in griptape/drivers/rerank/base_rerank_driver.py
@abstractmethod
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Return `artifacts` reordered for `query`; the ranking criterion is subclass-defined."""

BaseRulesetDriver

Bases: SerializableMixin, ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `raise_not_found` | `bool` | Whether to raise an error if the ruleset is not found. Defaults to True. |
Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
@define
class BaseRulesetDriver(SerializableMixin, ABC):
    """Base class for ruleset drivers.

    Attributes:
        raise_not_found: Whether to raise an error if the ruleset is not found. Defaults to True.
    """

    raise_not_found: bool = field(default=True, kw_only=True, metadata={"serializable": True})

    @abstractmethod
    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
        """Load the ruleset named `ruleset_name`, returning its rules and its metadata dict."""

    def _from_ruleset_dict(self, params_dict: dict[str, Any]) -> tuple[list[BaseRule], dict[str, Any]]:
        """Convert a dict with optional "rules" and "meta" keys into `(rules, meta)`.

        Each entry of "rules" must have a "value" key and may have a "meta" dict.
        """
        rules = [self._get_rule(entry["value"], entry.get("meta", {})) for entry in params_dict.get("rules", [])]
        meta = params_dict.get("meta", {})
        return rules, meta

    def _get_rule(self, value: Any, meta: dict[str, Any]) -> BaseRule:
        """Build a `JsonSchemaRule` for dict values, otherwise a `Rule` from `str(value)`."""
        from griptape.rules import JsonSchemaRule, Rule

        if isinstance(value, dict):
            return JsonSchemaRule(value=value, meta=meta)
        return Rule(value=str(value), meta=meta)
  • raise_not_found = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_from_ruleset_dict(params_dict)

Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
def _from_ruleset_dict(self, params_dict: dict[str, Any]) -> tuple[list[BaseRule], dict[str, Any]]:
    """Convert a dict with optional "rules" and "meta" keys into `(rules, meta)`.

    Each entry of "rules" must have a "value" key and may carry its own "meta" dict; missing keys
    default to an empty list / empty dict.
    """
    rules = [self._get_rule(entry["value"], entry.get("meta", {})) for entry in params_dict.get("rules", [])]
    meta = params_dict.get("meta", {})
    return rules, meta

_get_rule(value, meta)

Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
def _get_rule(self, value: Any, meta: dict[str, Any]) -> BaseRule:
    """Build a `JsonSchemaRule` when `value` is a dict, otherwise a `Rule` from `str(value)`."""
    from griptape.rules import JsonSchemaRule, Rule

    if isinstance(value, dict):
        return JsonSchemaRule(value=value, meta=meta)
    return Rule(value=str(value), meta=meta)

load(ruleset_name) `abstractmethod`

Source Code in griptape/drivers/ruleset/base_ruleset_driver.py
@abstractmethod
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
    """Load the ruleset named `ruleset_name`, returning its rules and its metadata dict."""

BaseSqlDriver

Bases: ABC
Source Code in griptape/drivers/sql/base_sql_driver.py
@define
class BaseSqlDriver(ABC):
    """Abstract interface for drivers that execute SQL queries and describe tables."""

    @dataclass
    class RowResult:
        """A single result row; `cells` maps each key (presumably a column name — confirm) to its value."""

        cells: dict[str, Any]

    @abstractmethod
    def execute_query(self, query: str) -> Optional[list[RowResult]]:
        """Execute `query` and return its rows as `RowResult`s, or None."""

    @abstractmethod
    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        """Execute `query` and return its rows as plain dicts, or None."""

    @abstractmethod
    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Return a textual schema description for `table_name` (optionally within `schema`), or None."""

RowResult `dataclass`

Source Code in griptape/drivers/sql/base_sql_driver.py
@dataclass
class RowResult:
    """A single result row; `cells` maps each key (presumably a column name — confirm) to its value."""

    cells: dict[str, Any]
  • cells instance-attribute

execute_query(query) `abstractmethod`

Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query(self, query: str) -> Optional[list[RowResult]]:
    """Execute `query` and return its rows as `RowResult`s, or None."""

execute_query_raw(query) `abstractmethod`

Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    """Execute `query` and return its rows as plain dicts, or None."""

get_table_schema(table_name, schema=None) `abstractmethod`

Source Code in griptape/drivers/sql/base_sql_driver.py
@abstractmethod
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Return a textual schema description for `table_name` (optionally within `schema`), or None."""

BaseStructureRunDriver

Bases: ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `env` | `dict[str, str]` | Environment variables to set before running the Structure. |
Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
@define
class BaseStructureRunDriver(ABC):
    """Base class for Structure Run Drivers.

    `run` is the public entry point and forwards its arguments unchanged to the `try_run` hook.

    Attributes:
        env: Environment variables to set before running the Structure.
    """

    env: dict[str, str] = field(default=Factory(dict), kw_only=True)

    def run(self, *args: BaseArtifact) -> BaseArtifact:
        """Run the Structure with `args` by delegating to `try_run`."""
        output = self.try_run(*args)
        return output

    @abstractmethod
    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        """Subclass hook that performs the run and returns its output artifact."""
  • env = field(default=Factory(dict), kw_only=True) class-attribute instance-attribute

run(*args)

Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
def run(self, *args: BaseArtifact) -> BaseArtifact:
    """Run the Structure with `args` by delegating, unchanged, to `try_run`."""
    output = self.try_run(*args)
    return output

try_run(*args) `abstractmethod`

Source Code in griptape/drivers/structure_run/base_structure_run_driver.py
@abstractmethod
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
    """Subclass hook that performs the run and returns its output artifact."""

BaseTextToSpeechDriver

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
@define
class BaseTextToSpeechDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base class for text-to-speech drivers.

    Attributes:
        model: Identifier of the text-to-speech model to use.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})

    def before_run(self, prompts: list[str]) -> None:
        """Publish a `StartTextToSpeechEvent` for `prompts`."""
        start_event = StartTextToSpeechEvent(prompts=prompts)
        EventBus.publish_event(start_event)

    def after_run(self) -> None:
        """Publish a `FinishTextToSpeechEvent`."""
        finish_event = FinishTextToSpeechEvent()
        EventBus.publish_event(finish_event)

    def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        """Convert `prompts` to audio, wrapping each attempt from `retrying()` with start/finish events."""
        for attempt in self.retrying():
            with attempt:
                self.before_run(prompts)
                audio = self.try_text_to_audio(prompts)
                self.after_run()
                return audio

        raise Exception("Failed to run text to audio generation")

    @abstractmethod
    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        """Subclass hook: synthesize `prompts` into an `AudioArtifact`."""
  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

after_run()

Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def after_run(self) -> None:
    """Publish a `FinishTextToSpeechEvent` once audio generation completes."""
    finish_event = FinishTextToSpeechEvent()
    EventBus.publish_event(finish_event)

before_run(prompts)

Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def before_run(self, prompts: list[str]) -> None:
    """Publish a `StartTextToSpeechEvent` carrying `prompts`."""
    start_event = StartTextToSpeechEvent(prompts=prompts)
    EventBus.publish_event(start_event)

run_text_to_audio(prompts)

Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
def run_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    """Convert `prompts` to audio.

    Each attempt from `retrying()` publishes the start event, calls `try_text_to_audio`, publishes the
    finish event, and returns the audio. If `retrying()` yields no attempts, an `Exception` is raised.
    """
    for attempt in self.retrying():
        with attempt:
            self.before_run(prompts)
            audio = self.try_text_to_audio(prompts)
            self.after_run()
            return audio

    raise Exception("Failed to run text to audio generation")

try_text_to_audio(prompts) `abstractmethod`

Source Code in griptape/drivers/text_to_speech/base_text_to_speech_driver.py
@abstractmethod
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    """Subclass hook: synthesize `prompts` into an `AudioArtifact`."""

BaseVectorStoreDriver

Bases: SerializableMixin, FuturesExecutorMixin, ABC

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@define
class BaseVectorStoreDriver(SerializableMixin, FuturesExecutorMixin, ABC):
    """Base class for vector store drivers.

    Concrete drivers implement the storage primitives (`upsert_vector`, `load_entry`, `load_entries`,
    `delete_vector`, and optionally `query_vector`). This base class builds artifact-level operations on
    top of them: embedding values with `embedding_driver`, deriving deterministic vector IDs, skipping
    entries that already exist, and upserting collections concurrently.

    Attributes:
        embedding_driver: Embedding Driver used to convert artifacts and queries into vectors.
    """

    DEFAULT_QUERY_COUNT = 5

    @define
    class Entry(SerializableMixin):
        """A vector record held by a vector store.

        Attributes:
            id: ID of the vector.
            vector: The vector values, if available.
            score: Score attached to the entry, if the backend supplied one.
            meta: Metadata dict; `upsert` stores the serialized artifact under the "artifact" key.
            namespace: Namespace the entry belongs to, if any.
        """

        id: str = field(metadata={"serializable": True})
        vector: Optional[list[float]] = field(default=None, metadata={"serializable": True})
        score: Optional[float] = field(default=None, metadata={"serializable": True})
        meta: Optional[dict] = field(default=None, metadata={"serializable": True})
        namespace: Optional[str] = field(default=None, metadata={"serializable": True})

        def to_artifact(self) -> BaseArtifact:
            """Deserialize the artifact JSON stored under `meta["artifact"]`."""
            return BaseArtifact.from_json(self.meta["artifact"])  # pyright: ignore[reportOptionalSubscript]

    embedding_driver: BaseEmbeddingDriver = field(kw_only=True, metadata={"serializable": True})

    def upsert_text_artifacts(
        self,
        artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> list[str] | dict[str, list[str]]:
        """Deprecated alias for `upsert_collection`; emits a DeprecationWarning and delegates."""
        warnings.warn(
            "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert_collection` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert_collection(artifacts, meta=meta, **kwargs)

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Deprecated alias for `upsert`; emits a DeprecationWarning and delegates."""
        warnings.warn(
            # Names this (singular) method, not `upsert_text_artifacts`.
            "`BaseVectorStoreDriver.upsert_text_artifact` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

    def upsert_text(
        self,
        string: str,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Deprecated alias for `upsert`; emits a DeprecationWarning and delegates."""
        warnings.warn(
            "`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

    @overload
    def upsert_collection(
        self,
        artifacts: list[TextArtifact] | list[ImageArtifact],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> list[str]: ...

    @overload
    def upsert_collection(
        self,
        artifacts: dict[str, list[TextArtifact]] | dict[str, list[ImageArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> dict[str, list[str]]: ...

    def upsert_collection(
        self,
        artifacts: list[TextArtifact]
        | list[ImageArtifact]
        | dict[str, list[TextArtifact]]
        | dict[str, list[ImageArtifact]],
        *,
        meta: Optional[dict] = None,
        **kwargs,
    ):
        """Upsert many artifacts concurrently using the futures executor.

        Args:
            artifacts: Either a flat list (upserted without a namespace) or a mapping of
                namespace -> artifacts.
            meta: Metadata applied to every upserted artifact.
            **kwargs: Forwarded to `upsert`.

        Returns:
            A list of vector IDs for list input, or a namespace -> list of vector IDs mapping for dict input.
            Namespaces whose artifact list is empty do not appear in the mapping.
        """
        with self.create_futures_executor() as futures_executor:
            if isinstance(artifacts, list):
                return utils.execute_futures_list(
                    [
                        futures_executor.submit(with_contextvars(self.upsert), a, namespace=None, meta=meta, **kwargs)
                        for a in artifacts
                    ],
                )
            futures_dict = {}

            for namespace, artifact_list in artifacts.items():
                for a in artifact_list:
                    futures_dict.setdefault(namespace, []).append(
                        futures_executor.submit(
                            with_contextvars(self.upsert), a, namespace=namespace, meta=meta, **kwargs
                        )
                    )

            return utils.execute_futures_list_dict(futures_dict)

    def upsert(
        self,
        value: str | TextArtifact | ImageArtifact,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Embed and store a single value, skipping it if an entry with the same ID already exists.

        Args:
            value: Raw string (wrapped in a TextArtifact) or an artifact.
            namespace: Namespace to store the vector in.
            meta: Extra metadata; the serialized artifact is added under the "artifact" key.
            vector_id: Explicit ID; when omitted, a deterministic ID is derived from the artifact text
                (plus its reference, if any).
            **kwargs: Forwarded to `upsert_vector`.

        Returns:
            The vector ID.
        """
        artifact = TextArtifact(value) if isinstance(value, str) else value

        meta = {} if meta is None else meta

        if vector_id is None:
            id_source = artifact.to_text() if artifact.reference is None else artifact.to_text() + str(artifact.reference)
            vector_id = self._get_default_vector_id(id_source)

        # Same content -> same ID, so an existing entry means there is nothing to embed or store.
        if self.does_entry_exist(vector_id, namespace=namespace):
            return vector_id
        meta = {**meta, "artifact": artifact.to_json()}

        vector = self.embedding_driver.embed(artifact, vector_operation="upsert")

        return self.upsert_vector(vector, vector_id=vector_id, namespace=namespace, meta=meta, **kwargs)

    def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
        """Return True if `load_entry` finds an entry; any lookup error is treated as "not present"."""
        try:
            return self.load_entry(vector_id, namespace=namespace) is not None
        except Exception:
            return False

    def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
        """Load all entries (optionally within `namespace`) and return only their TextArtifacts."""
        result = self.load_entries(namespace=namespace)
        artifacts = [r.to_artifact() for r in result]

        return ListArtifact([a for a in artifacts if isinstance(a, TextArtifact)])

    @abstractmethod
    def delete_vector(self, vector_id: str) -> None: ...

    @abstractmethod
    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str: ...

    @abstractmethod
    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[Entry]: ...

    @abstractmethod
    def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]: ...

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[Entry]:
        """Query by a precomputed vector. Drivers that support vector queries override this.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        # TODO: Mark as abstract method for griptape 2.0
        raise NotImplementedError(f"{self.__class__.__name__} does not support vector query.")

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[Entry]:
        """Embed `query` and delegate to `query_vector`.

        Raises:
            ValueError: If the Embedding Driver cannot embed the type of `query`.
        """
        try:
            vector = self.embedding_driver.embed(query, vector_operation="query")
        except ValueError as e:
            # Exception constructors do not %-format their arguments, so build the message explicitly.
            raise ValueError(
                f"The Embedding Driver, {self.embedding_driver.__class__.__name__}, used by the Vector Store does not "
                f"support embedding the {type(query).__name__} type. "
                "To resolve, provide an Embedding Driver that supports this type."
            ) from e
        return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

    def _get_default_vector_id(self, value: str) -> str:
        """Derive a deterministic UUIDv5 vector ID from `value`."""
        return str(uuid.uuid5(uuid.NAMESPACE_OID, value))
  • DEFAULT_QUERY_COUNT = 5 class-attribute instance-attribute

  • embedding_driver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

Entry

Bases: SerializableMixin

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@define
class Entry(SerializableMixin):
    """A single vector record held by a vector store.

    Attributes:
        id: ID of the vector.
        vector: The vector values, if available.
        score: Score attached to the entry, if the backend supplied one.
        meta: Metadata dict; the serialized artifact is expected under the "artifact" key.
        namespace: Namespace the entry belongs to, if any.
    """

    id: str = field(metadata={"serializable": True})
    vector: Optional[list[float]] = field(default=None, metadata={"serializable": True})
    score: Optional[float] = field(default=None, metadata={"serializable": True})
    meta: Optional[dict] = field(default=None, metadata={"serializable": True})
    namespace: Optional[str] = field(default=None, metadata={"serializable": True})

    def to_artifact(self) -> BaseArtifact:
        """Rebuild the artifact from the JSON stored in ``meta["artifact"]``."""
        serialized_artifact = self.meta["artifact"]  # pyright: ignore[reportOptionalSubscript]
        return BaseArtifact.from_json(serialized_artifact)
  • id = field(metadata={'serializable': True}) class-attribute instance-attribute

  • meta = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • namespace = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • score = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • vector = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

_get_default_vector_id(value)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def _get_default_vector_id(self, value: str) -> str:
    return str(uuid.uuid5(uuid.NAMESPACE_OID, value))

delete_vector(vector_id) (abstractmethod)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def delete_vector(self, vector_id: str) -> None:
    """Delete the vector stored under ``vector_id``; concrete drivers must implement this."""

does_entry_exist(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def does_entry_exist(self, vector_id: str, *, namespace: Optional[str] = None) -> bool:
    """Report whether ``load_entry`` finds an entry for ``vector_id`` in ``namespace``.

    This is best-effort: any exception raised by the lookup is treated as "entry not present".
    """
    try:
        found = self.load_entry(vector_id, namespace=namespace)
    except Exception:
        return False
    return found is not None

load_artifacts(*, namespace=None)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
    """Load every entry (optionally restricted to ``namespace``) and return its text artifacts.

    Entries whose stored artifact is not a ``TextArtifact`` are left out of the result.
    """
    entries = self.load_entries(namespace=namespace)
    text_artifacts = []
    for entry in entries:
        restored = entry.to_artifact()
        if isinstance(restored, TextArtifact):
            text_artifacts.append(restored)
    return ListArtifact(text_artifacts)

load_entries(*, namespace=None) (abstractmethod)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entries(self, *, namespace: Optional[str] = None) -> list[Entry]:
    """Return all stored entries, optionally limited to ``namespace``; concrete drivers must implement this."""

load_entry(vector_id, *, namespace=None) (abstractmethod)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[Entry]:
    """Return the entry stored under ``vector_id`` (optionally within ``namespace``), or None if there is none."""

query(query, *, count=None, namespace=None, include_vectors=False, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[Entry]:
    """Embed ``query`` with the Embedding Driver and delegate the search to ``query_vector``.

    Args:
        query: Text or artifact to search for.
        count: Maximum number of results, forwarded to ``query_vector``.
        namespace: Namespace to search, forwarded to ``query_vector``.
        include_vectors: Whether results should carry their vectors, forwarded to ``query_vector``.
        **kwargs: Forwarded to ``query_vector``.

    Raises:
        ValueError: If the Embedding Driver cannot embed the type of ``query``.
    """
    try:
        vector = self.embedding_driver.embed(query, vector_operation="query")
    except ValueError as e:
        # Exception constructors do not %-format their arguments, so build the message explicitly.
        raise ValueError(
            f"The Embedding Driver, {self.embedding_driver.__class__.__name__}, used by the Vector Store does not "
            f"support embedding the {type(query).__name__} type. "
            "To resolve, provide an Embedding Driver that supports this type."
        ) from e
    return self.query_vector(vector, count=count, namespace=namespace, include_vectors=include_vectors, **kwargs)

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[Entry]:
    """Search using an already-computed ``vector``.

    The base implementation always refuses; drivers that support vector search override it.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # TODO: Mark as abstract method for griptape 2.0
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support vector query.")

upsert(value, *, namespace=None, meta=None, vector_id=None, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert(
    self,
    value: str | TextArtifact | ImageArtifact,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    """Embed and store one value unless an entry with the same vector ID already exists.

    Args:
        value: Raw string (wrapped in a TextArtifact) or an artifact.
        namespace: Namespace to store the vector in.
        meta: Extra metadata; a copy is stored with the serialized artifact under "artifact".
        vector_id: Explicit ID. When omitted, the ID is derived from the artifact's text plus its
            reference (if it has one).
        **kwargs: Forwarded to ``upsert_vector``.

    Returns:
        The vector ID.
    """
    artifact = value if not isinstance(value, str) else TextArtifact(value)
    base_meta = {} if meta is None else meta

    if vector_id is None:
        id_source = artifact.to_text()
        if artifact.reference is not None:
            id_source += str(artifact.reference)
        vector_id = self._get_default_vector_id(id_source)

    # Deterministic IDs make an existing entry mean "already stored": skip embedding entirely.
    if self.does_entry_exist(vector_id, namespace=namespace):
        return vector_id

    stored_meta = {**base_meta, "artifact": artifact.to_json()}
    embedding = self.embedding_driver.embed(artifact, vector_operation="upsert")
    return self.upsert_vector(embedding, vector_id=vector_id, namespace=namespace, meta=stored_meta, **kwargs)

upsert_collection(artifacts, *, meta=None, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_collection(
    self,
    artifacts: list[TextArtifact]
    | list[ImageArtifact]
    | dict[str, list[TextArtifact]]
    | dict[str, list[ImageArtifact]],
    *,
    meta: Optional[dict] = None,
    **kwargs,
):
    """Upsert many artifacts concurrently through the futures executor.

    A flat list is upserted without a namespace and yields a list of vector IDs. A mapping of
    namespace -> artifacts yields a mapping of namespace -> vector IDs; namespaces with no
    artifacts are absent from that mapping. ``meta`` and ``kwargs`` are passed to every ``upsert`` call.
    """
    with self.create_futures_executor() as executor:
        if isinstance(artifacts, list):
            pending = [
                executor.submit(with_contextvars(self.upsert), item, namespace=None, meta=meta, **kwargs)
                for item in artifacts
            ]
            return utils.execute_futures_list(pending)

        pending_by_namespace = {}
        for ns, group in artifacts.items():
            for item in group:
                # Context variables are captured per submission so each worker sees the caller's context.
                future = executor.submit(with_contextvars(self.upsert), item, namespace=ns, meta=meta, **kwargs)
                pending_by_namespace.setdefault(ns, []).append(future)

        return utils.execute_futures_list_dict(pending_by_namespace)

upsert_text(string, *, namespace=None, meta=None, vector_id=None, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    """Deprecated: forward ``string`` and all options to ``upsert`` after emitting a DeprecationWarning."""
    deprecation_message = (
        "`BaseVectorStoreDriver.upsert_text` is deprecated and will be removed in a future release. "
        "`BaseVectorStoreDriver.upsert` is a drop-in replacement."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return self.upsert(string, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

upsert_text_artifact(artifact, *, namespace=None, meta=None, vector_id=None, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifact(
    self,
    artifact: TextArtifact,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    """Deprecated alias for ``upsert``.

    Emits a DeprecationWarning naming this method, then forwards ``artifact`` and all options to ``upsert``.

    Returns:
        The vector ID returned by ``upsert``.
    """
    warnings.warn(
        # The message must name this (singular) method, not `upsert_text_artifacts`.
        "`BaseVectorStoreDriver.upsert_text_artifact` is deprecated and will be removed in a future release. `BaseVectorStoreDriver.upsert` is a drop-in replacement.",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.upsert(artifact, namespace=namespace, meta=meta, vector_id=vector_id, **kwargs)

upsert_text_artifacts(artifacts, *, meta=None, **kwargs)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
def upsert_text_artifacts(
    self,
    artifacts: list[TextArtifact] | dict[str, list[TextArtifact]],
    *,
    meta: Optional[dict] = None,
    **kwargs,
) -> list[str] | dict[str, list[str]]:
    """Deprecated: forward ``artifacts`` to ``upsert_collection`` after emitting a DeprecationWarning."""
    deprecation_message = (
        "`BaseVectorStoreDriver.upsert_text_artifacts` is deprecated and will be removed in a future release. "
        "`BaseVectorStoreDriver.upsert_collection` is a drop-in replacement."
    )
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
    return self.upsert_collection(artifacts, meta=meta, **kwargs)

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs) (abstractmethod)

Source Code in griptape/drivers/vector/base_vector_store_driver.py
@abstractmethod
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Store ``vector`` with an optional ID, namespace and metadata, and return the vector ID used.

    Concrete drivers must implement this.
    """

BaseWebScraperDriver

Bases: ABC
Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
class BaseWebScraperDriver(ABC):
    """Abstract base for web scrapers: fetch a page's raw source, then extract a TextArtifact from it."""

    def scrape_url(self, url: str) -> TextArtifact:
        """Fetch ``url`` with ``fetch_url`` and return the result of ``extract_page`` on the fetched source."""
        return self.extract_page(self.fetch_url(url))

    @abstractmethod
    def fetch_url(self, url: str) -> str:
        """Retrieve the raw page source for ``url``; concrete drivers must implement this."""

    @abstractmethod
    def extract_page(self, page: str) -> TextArtifact:
        """Turn raw page source ``page`` into a TextArtifact; concrete drivers must implement this."""

extract_page(page) (abstractmethod)

Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
@abstractmethod
def extract_page(self, page: str) -> TextArtifact:
    """Turn raw page source ``page`` into a TextArtifact; concrete drivers must implement this."""

fetch_url(url) (abstractmethod)

Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
@abstractmethod
def fetch_url(self, url: str) -> str:
    """Retrieve the raw page source for ``url``; concrete drivers must implement this."""

scrape_url(url)

Source Code in griptape/drivers/web_scraper/base_web_scraper_driver.py
def scrape_url(self, url: str) -> TextArtifact:
    """Fetch ``url`` with ``fetch_url`` and return the result of ``extract_page`` on the fetched source."""
    return self.extract_page(self.fetch_url(url))

BaseWebSearchDriver

Bases: ABC
Source Code in griptape/drivers/web_search/base_web_search_driver.py
@define
class BaseWebSearchDriver(ABC):
    """Abstract base for web search drivers.

    Attributes:
        results_count: Number of results to request per search. Defaults to 5. Whether it is honored
            is up to each concrete driver.
    """

    results_count: int = field(default=5, kw_only=True)

    @abstractmethod
    def search(self, query: str, **kwargs) -> ListArtifact:
        """Search the web for ``query`` and return the results; concrete drivers must implement this."""
  • results_count = field(default=5, kw_only=True) class-attribute instance-attribute

search(query, **kwargs) (abstractmethod)

Source Code in griptape/drivers/web_search/base_web_search_driver.py
@abstractmethod
def search(self, query: str, **kwargs) -> ListArtifact:
    """Search the web for ``query`` and return the results; concrete drivers must implement this."""

BedrockStableDiffusionImageGenerationModelDriver

Bases: BaseImageGenerationModelDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| cfg_scale | int | Specifies how strictly image generation follows the provided prompt. Defaults to 7. |
| mask_source | str | Specifies mask image configuration for masked image-to-image generations. This is not a configurable field: inpainting requests use "MASK_IMAGE_BLACK" and outpainting requests use "MASK_IMAGE_WHITE". |
| style_preset | Optional[str] | If provided, specifies a specific image generation style preset. |
| clip_guidance_preset | Optional[str] | If provided, requests a specific clip guidance preset to be used in the diffusion process. |
| sampler | Optional[str] | If provided, requests a specific sampler to be used in the diffusion process. |
| steps | Optional[int] | If provided, specifies the number of diffusion steps to use in the image generation. |
| start_schedule | Optional[float] | If provided, specifies the start_schedule parameter used to determine the influence of the input image in image-to-image generation. |
Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
@define
class BedrockStableDiffusionImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image generation model driver for Stable Diffusion models on Amazon Bedrock.

    For more information on all supported parameters, see the Stable Diffusion documentation:
        https://platform.stability.ai/docs/api-reference#tag/v1generation

    The mask configuration is not a configurable attribute of this driver. Inpainting requests always send
    `mask_source="MASK_IMAGE_BLACK"`, and outpainting requests always send `mask_source="MASK_IMAGE_WHITE"`.

    Attributes:
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7.
        style_preset: If provided, specifies a specific image generation style preset.
        clip_guidance_preset: If provided, requests a specific clip guidance preset to be used in the diffusion process.
        sampler: If provided, requests a specific sampler to be used in the diffusion process.
        steps: If provided, specifies the number of diffusion steps to use in the image generation.
        start_schedule: If provided, specifies the start_schedule parameter used to determine the influence of the input
            image in image-to-image generation.
    """

    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    style_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    clip_guidance_preset: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    sampler: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    start_schedule: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build a text-to-image request body with the given output dimensions."""
        return self._request_parameters(
            prompts,
            width=image_width,
            height=image_height,
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an image-to-image request body; output dimensions are taken from `image`."""
        return self._request_parameters(prompts, image=image, negative_prompts=negative_prompts, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an inpainting request body; the mask is sent with `mask_source="MASK_IMAGE_BLACK"`."""
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_BLACK",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an outpainting request body; the mask is sent with `mask_source="MASK_IMAGE_WHITE"`."""
        return self._request_parameters(
            prompts,
            image=image,
            mask=mask,
            mask_source="MASK_IMAGE_WHITE",
            negative_prompts=negative_prompts,
            seed=seed,
        )

    def _request_parameters(
        self,
        prompts: list[str],
        width: Optional[int] = None,
        height: Optional[int] = None,
        image: Optional[ImageArtifact] = None,
        mask: Optional[ImageArtifact] = None,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
        mask_source: Optional[str] = None,
    ) -> dict:
        """Assemble the Stable Diffusion request body.

        Prompts get weight 1.0 and negative prompts get weight -1.0. When `image` is given, its base64 data
        and dimensions are used instead of `width`/`height`. Keys whose value is None are dropped.

        Raises:
            ValueError: If `mask` is given without a `mask_source`.
        """
        if negative_prompts is None:
            negative_prompts = []

        text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
        text_prompts += [{"text": negative_prompt, "weight": -1.0} for negative_prompt in negative_prompts]

        request = {
            "text_prompts": text_prompts,
            "cfg_scale": self.cfg_scale,
            "style_preset": self.style_preset,
            "clip_guidance_preset": self.clip_guidance_preset,
            "sampler": self.sampler,
            "steps": self.steps,
            "seed": seed,
            "start_schedule": self.start_schedule,
        }

        if image is not None:
            request["init_image"] = image.base64
            request["width"] = image.width
            request["height"] = image.height
        else:
            request["width"] = width
            request["height"] = height

        if mask is not None:
            if not mask_source:
                raise ValueError("mask_source must be provided when mask is provided")

            request["mask_source"] = mask_source
            request["mask_image"] = mask.base64

        # Unset optional parameters are omitted rather than sent as null.
        return {k: v for k, v in request.items() if v is not None}

    def get_generated_image(self, response: dict) -> bytes:
        """Decode the image bytes from the first artifact of `response`.

        Raises:
            Exception: If the artifact's finishReason is "ERROR".
        """
        image_response = response["artifacts"][0]

        # finishReason may be SUCCESS, CONTENT_FILTERED, or ERROR.
        if image_response.get("finishReason") == "ERROR":
            raise Exception(f"Image generation failed: {image_response.get('finishReason')}")
        if image_response.get("finishReason") == "CONTENT_FILTERED":
            logging.warning("Image generation triggered content filter and may be blurred")

        return base64.decodebytes(bytes(image_response.get("base64"), "utf-8"))
  • cfg_scale = field(default=7, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • clip_guidance_preset = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • sampler = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • start_schedule = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • steps = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • style_preset = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_request_parameters(prompts, width=None, height=None, image=None, mask=None, negative_prompts=None, seed=None, mask_source=None)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def _request_parameters(
    self,
    prompts: list[str],
    width: Optional[int] = None,
    height: Optional[int] = None,
    image: Optional[ImageArtifact] = None,
    mask: Optional[ImageArtifact] = None,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
    mask_source: Optional[str] = None,
) -> dict:
    if negative_prompts is None:
        negative_prompts = []

    text_prompts = [{"text": prompt, "weight": 1.0} for prompt in prompts]
    text_prompts += [{"text": negative_prompt, "weight": -1.0} for negative_prompt in negative_prompts]

    request = {
        "text_prompts": text_prompts,
        "cfg_scale": self.cfg_scale,
        "style_preset": self.style_preset,
        "clip_guidance_preset": self.clip_guidance_preset,
        "sampler": self.sampler,
        "steps": self.steps,
        "seed": seed,
        "start_schedule": self.start_schedule,
    }

    if image is not None:
        request["init_image"] = image.base64
        request["width"] = image.width
        request["height"] = image.height
    else:
        request["width"] = width
        request["height"] = height

    if mask is not None:
        if not mask_source:
            raise ValueError("mask_source must be provided when mask is provided")

        request["mask_source"] = mask_source
        request["mask_image"] = mask.base64

    return {k: v for k, v in request.items() if v is not None}

get_generated_image(response)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    """Return the decoded image bytes of the first artifact in a Bedrock Stable Diffusion response.

    Raises:
        Exception: When the artifact's finishReason is "ERROR".
    """
    first_artifact = response["artifacts"][0]
    finish_reason = first_artifact.get("finishReason")

    # Known finishReason values: SUCCESS, CONTENT_FILTERED, ERROR.
    if finish_reason == "ERROR":
        raise Exception(f"Image generation failed: {finish_reason}")
    if finish_reason == "CONTENT_FILTERED":
        logging.warning("Image generation triggered content filter and may be blurred")

    encoded_image = bytes(first_artifact.get("base64"), "utf-8")
    return base64.decodebytes(encoded_image)

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build an inpainting request body; the mask is always sent with ``mask_source="MASK_IMAGE_BLACK"``."""
    inpaint_options = {
        "image": image,
        "mask": mask,
        "mask_source": "MASK_IMAGE_BLACK",
        "negative_prompts": negative_prompts,
        "seed": seed,
    }
    return self._request_parameters(prompts, **inpaint_options)

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build an outpainting request body; the mask is always sent with ``mask_source="MASK_IMAGE_WHITE"``."""
    outpaint_options = {
        "image": image,
        "mask": mask,
        "mask_source": "MASK_IMAGE_WHITE",
        "negative_prompts": negative_prompts,
        "seed": seed,
    }
    return self._request_parameters(prompts, **outpaint_options)

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build an image-to-image variation request body; output dimensions are taken from ``image``."""
    variation_options = {"image": image, "negative_prompts": negative_prompts, "seed": seed}
    return self._request_parameters(prompts, **variation_options)

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_stable_diffusion_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build a text-to-image request body for an image of ``image_width`` x ``image_height``."""
    generation_options = {
        "width": image_width,
        "height": image_height,
        "negative_prompts": negative_prompts,
        "seed": seed,
    }
    return self._request_parameters(prompts, **generation_options)

BedrockTitanImageGenerationModelDriver

Bases: BaseImageGenerationModelDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| quality | str | The quality of the generated image. Defaults to "standard". |
| cfg_scale | int | Specifies how strictly image generation follows the provided prompt. Defaults to 7, (1.0 to 10.0]. |
| outpainting_mode | str | Specifies the outpainting mode. Defaults to "PRECISE". |
Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
@define
class BedrockTitanImageGenerationModelDriver(BaseImageGenerationModelDriver):
    """Image Generation Model Driver for Amazon Bedrock Titan Image Generator.

    Each ``*_request_parameters`` method fills in its task-specific section and then
    delegates to ``_add_common_params``. That helper is the single place that writes
    ``imageGenerationConfig`` (image count, quality, size, cfg scale, optional seed).

    Attributes:
        quality: The quality of the generated image, defaults to standard.
        cfg_scale: Specifies how strictly image generation follows the provided prompt. Defaults to 7, (1.0 to 10.0].
        outpainting_mode: Specifies the outpainting mode, defaults to PRECISE.
    """

    quality: str = field(default="standard", kw_only=True, metadata={"serializable": True})
    cfg_scale: int = field(default=7, kw_only=True, metadata={"serializable": True})
    outpainting_mode: str = field(default="PRECISE", kw_only=True, metadata={"serializable": True})

    def text_to_image_request_parameters(
        self,
        prompts: list[str],
        image_width: int,
        image_height: int,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build a ``TEXT_IMAGE`` request body.

        Args:
            prompts: Prompts joined with ``", "`` into the request text.
            image_width: Output image width.
            image_height: Output image height.
            negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
            seed: Optional seed, forwarded to ``_add_common_params``.

        Returns:
            The request dict including ``imageGenerationConfig``.
        """
        request: dict[str, Any] = {
            "taskType": "TEXT_IMAGE",
            "textToImageParams": {"text": ", ".join(prompts)},
        }

        if negative_prompts:
            request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image_width, image_height, seed=seed)

    def image_variation_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an ``IMAGE_VARIATION`` request body.

        The output size is taken from ``image.width`` / ``image.height``.

        Args:
            prompts: Prompts joined with ``", "`` into the request text.
            image: Source image; its base64 payload is sent as the single entry of ``images``.
            negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
            seed: Optional seed, forwarded to ``_add_common_params``.

        Returns:
            The request dict including ``imageGenerationConfig``.
        """
        request: dict[str, Any] = {
            "taskType": "IMAGE_VARIATION",
            "imageVariationParams": {"text": ", ".join(prompts), "images": [image.base64]},
        }

        if negative_prompts:
            request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_inpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an ``INPAINTING`` request body.

        Args:
            prompts: Prompts joined with ``", "`` into the request text.
            image: Source image (base64 payload and size are used).
            mask: Mask image (base64 payload sent as ``maskImage``).
            negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
            seed: Optional seed, forwarded to ``_add_common_params``.

        Returns:
            The request dict including ``imageGenerationConfig``.
        """
        prompt = ", ".join(prompts)

        request = {
            "taskType": "INPAINTING",
            "inPaintingParams": {"text": prompt, "image": image.base64, "maskImage": mask.base64},
        }

        if negative_prompts:
            request["inPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def image_outpainting_request_parameters(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
        seed: Optional[int] = None,
    ) -> dict:
        """Build an ``OUTPAINTING`` request body using ``self.outpainting_mode``.

        Args:
            prompts: Prompts joined with ``", "`` into the request text.
            image: Source image (base64 payload and size are used).
            mask: Mask image (base64 payload sent as ``maskImage``).
            negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
            seed: Optional seed, forwarded to ``_add_common_params``.

        Returns:
            The request dict including ``imageGenerationConfig``.
        """
        prompt = ", ".join(prompts)

        request = {
            "taskType": "OUTPAINTING",
            "outPaintingParams": {
                "text": prompt,
                "image": image.base64,
                "maskImage": mask.base64,
                "outPaintingMode": self.outpainting_mode,
            },
        }

        if negative_prompts:
            request["outPaintingParams"]["negativeText"] = ", ".join(negative_prompts)

        return self._add_common_params(request, image.width, image.height, seed=seed)

    def get_generated_image(self, response: dict) -> bytes:
        """Decode the first base64-encoded image in ``response["images"]`` to raw bytes."""
        b64_image_data = response["images"][0]

        return base64.decodebytes(bytes(b64_image_data, "utf-8"))

    def _add_common_params(self, request: dict[str, Any], width: int, height: int, seed: Optional[int] = None) -> dict:
        """Set the shared ``imageGenerationConfig`` section on ``request`` (mutated and returned).

        Any existing ``imageGenerationConfig`` entry is replaced.
        """
        request["imageGenerationConfig"] = {
            "numberOfImages": 1,
            "quality": self.quality,
            "width": width,
            "height": height,
            "cfgScale": self.cfg_scale,
        }

        # Compare against None rather than truthiness: 0 is a legitimate seed and must not be dropped.
        if seed is not None:
            request["imageGenerationConfig"]["seed"] = seed

        return request
  • cfg_scale = field(default=7, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • outpainting_mode = field(default='PRECISE', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • quality = field(default='standard', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_add_common_params(request, width, height, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def _add_common_params(self, request: dict[str, Any], width: int, height: int, seed: Optional[int] = None) -> dict:
    """Set the shared Titan ``imageGenerationConfig`` section on a request body.

    Any ``imageGenerationConfig`` already present in ``request`` is replaced. The dict is
    mutated in place and also returned.

    Args:
        request: Task-specific request body (``taskType`` plus its params section).
        width: Output image width, written as ``width``.
        height: Output image height, written as ``height``.
        seed: Optional seed; added whenever it is not ``None``.

    Returns:
        The same ``request`` dict, now containing ``imageGenerationConfig``.
    """
    request["imageGenerationConfig"] = {
        "numberOfImages": 1,
        "quality": self.quality,
        "width": width,
        "height": height,
        "cfgScale": self.cfg_scale,
    }

    # Compare against None rather than truthiness: 0 is a legitimate seed and must not be dropped.
    if seed is not None:
        request["imageGenerationConfig"]["seed"] = seed

    return request

get_generated_image(response)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def get_generated_image(self, response: dict) -> bytes:
    """Return the raw bytes of the first image in a Titan response.

    Args:
        response: Parsed response body; ``response["images"][0]`` must be a base64 string.

    Returns:
        The decoded image bytes.
    """
    encoded_image = bytes(response["images"][0], "utf-8")
    return base64.decodebytes(encoded_image)

image_inpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_inpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build a Titan ``INPAINTING`` request body.

    Args:
        prompts: Prompts joined with ``", "`` into the request text.
        image: Source image; its base64 payload and width/height are used.
        mask: Mask image; its base64 payload is sent as ``maskImage``.
        negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
        seed: Optional seed, forwarded to ``self._add_common_params``.

    Returns:
        Whatever ``self._add_common_params`` returns for the assembled request.
    """
    inpainting_params = {
        "text": ", ".join(prompts),
        "image": image.base64,
        "maskImage": mask.base64,
    }
    if negative_prompts:
        inpainting_params["negativeText"] = ", ".join(negative_prompts)

    request = {"taskType": "INPAINTING", "inPaintingParams": inpainting_params}
    return self._add_common_params(request, image.width, image.height, seed=seed)

image_outpainting_request_parameters(prompts, image, mask, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_outpainting_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build a Titan ``OUTPAINTING`` request body using ``self.outpainting_mode``.

    Args:
        prompts: Prompts joined with ``", "`` into the request text.
        image: Source image; its base64 payload and width/height are used.
        mask: Mask image; its base64 payload is sent as ``maskImage``.
        negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
        seed: Optional seed, forwarded to ``self._add_common_params``.

    Returns:
        Whatever ``self._add_common_params`` returns for the assembled request.
    """
    outpainting_params = {
        "text": ", ".join(prompts),
        "image": image.base64,
        "maskImage": mask.base64,
        "outPaintingMode": self.outpainting_mode,
    }
    if negative_prompts:
        outpainting_params["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(
        {"taskType": "OUTPAINTING", "outPaintingParams": outpainting_params},
        image.width,
        image.height,
        seed=seed,
    )

image_variation_request_parameters(prompts, image, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def image_variation_request_parameters(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build a Titan ``IMAGE_VARIATION`` request body.

    ``imageGenerationConfig`` (size taken from ``image``, plus the optional seed) is written
    solely by ``self._add_common_params``. It is not pre-built here, because that helper
    replaces it anyway.

    Args:
        prompts: Prompts joined with ``", "`` into the request text.
        image: Source image; its base64 payload is the single entry of ``images``.
        negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
        seed: Optional seed, forwarded to ``self._add_common_params``.

    Returns:
        Whatever ``self._add_common_params`` returns for the assembled request.
    """
    request: dict[str, Any] = {
        "taskType": "IMAGE_VARIATION",
        "imageVariationParams": {"text": ", ".join(prompts), "images": [image.base64]},
    }

    if negative_prompts:
        request["imageVariationParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image.width, image.height, seed=seed)

text_to_image_request_parameters(prompts, image_width, image_height, negative_prompts=None, seed=None)

Source Code in griptape/drivers/image_generation_model/bedrock_titan_image_generation_model_driver.py
def text_to_image_request_parameters(
    self,
    prompts: list[str],
    image_width: int,
    image_height: int,
    negative_prompts: Optional[list[str]] = None,
    seed: Optional[int] = None,
) -> dict:
    """Build a Titan ``TEXT_IMAGE`` request body.

    ``imageGenerationConfig`` (size plus the optional seed) is written solely by
    ``self._add_common_params``. It is not pre-built here, because that helper replaces
    it anyway.

    Args:
        prompts: Prompts joined with ``", "`` into the request text.
        image_width: Output image width.
        image_height: Output image height.
        negative_prompts: Optional prompts joined into ``negativeText``; skipped when empty.
        seed: Optional seed, forwarded to ``self._add_common_params``.

    Returns:
        Whatever ``self._add_common_params`` returns for the assembled request.
    """
    request: dict[str, Any] = {
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {"text": ", ".join(prompts)},
    }

    if negative_prompts:
        request["textToImageParams"]["negativeText"] = ", ".join(negative_prompts)

    return self._add_common_params(request, image_width, image_height, seed=seed)

CohereEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

Name | Type | Description
api_key | str | Cohere API key.
model | str | Cohere model name.
client | Client | Custom `cohere.Client`.
tokenizer | CohereTokenizer | Custom `CohereTokenizer`.
input_type | str | Cohere embedding input type.
Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
@define
class CohereEmbeddingDriver(BaseEmbeddingDriver):
    """Cohere Embedding Driver.

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        client: Custom `cohere.Client`.
        tokenizer: Custom `CohereTokenizer`.
        input_type: Cohere embedding input type.
    """

    # A Cohere embedding model id ("models/embedding-001" is a Google model name, not a Cohere one).
    DEFAULT_MODEL = "embed-english-v3.0"

    api_key: str = field(kw_only=True, metadata={"serializable": False})
    input_type: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: CohereTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
        kw_only=True,
    )

    @lazy_property()
    def client(self) -> Client:
        """Build a `cohere.Client` from `api_key` (wrapped by `lazy_property`)."""
        return import_optional_dependency("cohere").Client(self.api_key)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk with Cohere.

        Args:
            chunk: Text to embed.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The first embedding vector from the response.

        Raises:
            ValueError: If the response's `embeddings` is not a list (non-float embedding types).
        """
        result = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)

        if isinstance(result.embeddings, list):
            return result.embeddings[0]
        raise ValueError("Non-float embeddings are not supported.")
  • DEFAULT_MODEL = 'models/embedding-001' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • input_type = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    """Create the default `cohere.Client` from ``self.api_key`` (wrapped by `lazy_property`)."""
    cohere_module = import_optional_dependency("cohere")
    return cohere_module.Client(self.api_key)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/cohere_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed one chunk of text and return its vector.

    Args:
        chunk: Text to embed.
        **kwargs: Accepted for interface compatibility; not used.

    Returns:
        The first entry of the response's ``embeddings`` list.

    Raises:
        ValueError: When ``embeddings`` is not a list (non-float embedding types).
    """
    response = self.client.embed(texts=[chunk], model=self.model, input_type=self.input_type)
    embeddings = response.embeddings
    if not isinstance(embeddings, list):
        raise ValueError("Non-float embeddings are not supported.")
    return embeddings[0]

CoherePromptDriver

Bases: BasePromptDriver

Attributes

Name | Type | Description
api_key | str | Cohere API key.
model | str | Cohere model name.
client | ClientV2 | Custom `cohere.ClientV2`.
Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@define(kw_only=True)
class CoherePromptDriver(BasePromptDriver):
    """Cohere Prompt Driver.

    Sends prompt stacks to Cohere's v2 chat API through `cohere.ClientV2`, either as a single
    request (`try_run`) or as a stream (`try_stream`).

    Attributes:
        api_key: Cohere API key.
        model: Cohere model name.
        force_single_step: Serializable flag; not read by the request-building code in this class.
        use_native_tools: When True and the prompt stack has tools, they are sent via Cohere's `tools` parameter.
        client: Custom `cohere.ClientV2`.
    """

    api_key: str = field(default=None, metadata={"serializable": False})
    model: str = field(metadata={"serializable": True})
    force_single_step: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[ClientV2] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True),
    )

    @lazy_property()
    def client(self) -> ClientV2:
        """Build a `cohere.ClientV2` from `api_key` (wrapped by `lazy_property`)."""
        return import_optional_dependency("cohere").ClientV2(self.api_key, log_warning_experimental_features=False)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Run one non-streaming chat request and return the assistant `Message`.

        Token counts fall back to 0 when the response carries no usage/token data.
        """
        params = self._base_params(prompt_stack)

        logger.debug(params)

        result: ChatResponse = self.client.chat(**params)
        logger.debug(result.model_dump())

        return Message(
            content=self.__to_prompt_stack_message_content(result.message),
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(
                input_tokens=result.usage.tokens.input_tokens if result.usage and result.usage.tokens else 0,
                output_tokens=result.usage.tokens.output_tokens if result.usage and result.usage.tokens else 0,
            ),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream a chat request, yielding content deltas and a final usage delta."""
        params = self._base_params(prompt_stack)
        logger.debug(params)
        result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params)

        for event in result:
            if event.type == "message-end":
                # The v2 stream reports token usage on its closing "message-end" event
                # (delta.usage.tokens); fall back to 0 when any level is missing, like try_run.
                usage = event.delta.usage if event.delta else None
                tokens = usage.tokens if usage else None

                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=tokens.input_tokens if tokens else 0,
                        output_tokens=tokens.output_tokens if tokens else 0,
                    ),
                )
            elif event.type == "stream-end":
                # v1-style terminal event; kept so existing emitters of it still report usage.
                legacy_tokens = event.response.meta.tokens

                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=legacy_tokens.input_tokens,
                        output_tokens=legacy_tokens.output_tokens,
                    ),
                )
            elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"):
                yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments passed to `client.chat` / `client.chat_stream`.

        `extra_params` is merged last, so it overrides the defaults listed before it.
        """
        params = {
            "model": self.model,
            "messages": self.__to_cohere_messages(prompt_stack.messages),
            "temperature": self.temperature,
            "stop_sequences": self.tokenizer.stop_sequences,
            "max_tokens": self.max_tokens,
            **self.extra_params,
        }

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
            params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()}

        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

        return params

    def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
        """Convert prompt-stack messages into Cohere v2 chat messages."""
        cohere_messages = []

        for message in messages:
            # If the message only contains textual content we can send it as a single content.
            if message.is_text():
                cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            # Action results must be sent as separate messages.
            elif message.has_any_content_type(ActionResultMessageContent):
                cohere_messages.extend(
                    {
                        "role": self.__to_cohere_role(message, action_result),
                        "content": self.__to_cohere_message_content(action_result),
                        "tool_call_id": action_result.action.tag,
                    }
                    for action_result in message.get_content_type(ActionResultMessageContent)
                )

                if message.has_any_content_type(TextMessageContent):
                    cohere_messages.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            else:
                # Action calls must be attached to the message as tool_calls, not sent as content.
                action_calls = [part for part in message.content if isinstance(part, ActionCallMessageContent)]
                cohere_message = {
                    "role": self.__to_cohere_role(message),
                    "content": [
                        self.__to_cohere_message_content(part)
                        for part in message.content
                        if not isinstance(part, ActionCallMessageContent)
                    ],
                }
                if action_calls:
                    cohere_message["tool_calls"] = [
                        self.__to_cohere_message_content(action_call) for action_call in action_calls
                    ]

                cohere_messages.append(cohere_message)

        return cohere_messages

    def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]:
        """Convert one message content into Cohere's content / tool-call representation."""
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {
                    "name": action.to_native_tool_name(),
                    "arguments": json.dumps(action.input),
                },
            }
        if isinstance(content, ActionResultMessageContent):
            artifact = content.artifact

            # A ListArtifact result becomes one text item per contained artifact.
            if isinstance(artifact, ListArtifact):
                return [{"type": "text", "text": item.to_text()} for item in artifact.value]
            return {"type": "text", "text": artifact.to_text()}
        return {"type": "text", "text": content.artifact.to_text()}

    def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        """Map a message (and optionally one of its contents) to a Cohere role string."""
        if message.is_system():
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Describe every activity of every tool as a Cohere function tool."""
        return [
            {
                "function": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
                },
                "type": "function",
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]:
        """Convert a Cohere assistant response into prompt-stack contents.

        Text items come first, then the tool plan (if any), then one action call per tool call
        that has an id, a function, a function name and arguments; incomplete calls are skipped.
        """
        content: list[BaseMessageContent] = []
        if response.content:
            content.extend(TextMessageContent(TextArtifact(item.text)) for item in response.content)
        if response.tool_plan:
            content.append(TextMessageContent(TextArtifact(response.tool_plan)))
        if response.tool_calls is not None:
            for tool_call in response.tool_calls:
                if tool_call.id is None or tool_call.function is None:
                    continue
                function = tool_call.function
                if function.name is None or function.arguments is None:
                    continue
                # Parse the native tool name once instead of twice.
                name, path = ToolAction.from_native_tool_name(function.name)
                content.append(
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(tag=tool_call.id, name=name, path=path, input=json.loads(function.arguments)),
                        ),
                    ),
                )

        return content

    def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
        """Convert one supported stream event into a delta content; raise ValueError otherwise."""
        if event.type == "content-delta":
            return TextDeltaMessageContent(event.delta.message.content.text, index=0)
        if event.type == "tool-plan-delta":
            return TextDeltaMessageContent(event.delta.message["tool_plan"])
        if event.type == "tool-call-start":
            tool_call_delta = event.delta.message["tool_calls"]
            name, path = ToolAction.from_native_tool_name(tool_call_delta["function"]["name"])

            return ActionCallDeltaMessageContent(tag=tool_call_delta["id"], name=name, path=path)
        if event.type == "tool-call-delta":
            tool_call_delta = event.delta.message["tool_calls"]["function"]

            return ActionCallDeltaMessageContent(partial_input=tool_call_delta["arguments"])
        raise ValueError(f"Unsupported event type: {event.type}")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, metadata={'serializable': False}) class-attribute instance-attribute

  • force_single_step = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True)) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_cohere_message_content(content)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_message_content(self, content: BaseMessageContent) -> str | dict | list[dict]:
    """Convert a single prompt-stack content into Cohere's representation.

    - Action calls become a ``{"type": "function", ...}`` tool-call dict with JSON-encoded arguments.
    - Action results become a text item; a ``ListArtifact`` result yields one text item per element.
    - Anything else becomes a text item built from ``content.artifact.to_text()``.
    """
    if isinstance(content, ActionCallMessageContent):
        tool_action = content.artifact.value
        return {
            "type": "function",
            "id": tool_action.tag,
            "function": {
                "name": tool_action.to_native_tool_name(),
                "arguments": json.dumps(tool_action.input),
            },
        }

    if not isinstance(content, ActionResultMessageContent):
        return {"type": "text", "text": content.artifact.to_text()}

    result_artifact = content.artifact
    if isinstance(result_artifact, ListArtifact):
        return [{"type": "text", "text": item.to_text()} for item in result_artifact.value]
    return {"type": "text", "text": result_artifact.to_text()}

__to_cohere_messages(messages)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_messages(self, messages: list[Message]) -> list[dict]:
    """Translate prompt-stack messages into Cohere v2 chat message dicts.

    - Text-only messages collapse to one message whose content is a plain string.
    - Messages carrying action results emit one message per result, with ``tool_call_id`` set
      from ``action_result.action.tag``. A trailing text message follows if the message also
      has text content.
    - All other messages keep their non-call contents as a content list; action calls are
      moved into ``tool_calls``.
    """
    converted: list[dict] = []

    for message in messages:
        if message.is_text():
            converted.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            continue

        if message.has_any_content_type(ActionResultMessageContent):
            for action_result in message.get_content_type(ActionResultMessageContent):
                converted.append(
                    {
                        "role": self.__to_cohere_role(message, action_result),
                        "content": self.__to_cohere_message_content(action_result),
                        "tool_call_id": action_result.action.tag,
                    }
                )
            if message.has_any_content_type(TextMessageContent):
                converted.append({"role": self.__to_cohere_role(message), "content": message.to_text()})
            continue

        calls = [part for part in message.content if isinstance(part, ActionCallMessageContent)]
        cohere_message = {
            "role": self.__to_cohere_role(message),
            "content": [
                self.__to_cohere_message_content(part)
                for part in message.content
                if not isinstance(part, ActionCallMessageContent)
            ],
        }
        if calls:
            cohere_message["tool_calls"] = [self.__to_cohere_message_content(call) for call in calls]
        converted.append(cohere_message)

    return converted

__to_cohere_role(message, message_content=None)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    """Pick the Cohere role for a message.

    System and assistant messages map directly. Otherwise the role is ``"tool"`` when
    ``message_content`` is an action result, else ``"user"``.
    """
    if message.is_system():
        role = "system"
    elif message.is_assistant():
        role = "assistant"
    elif isinstance(message_content, ActionResultMessageContent):
        role = "tool"
    else:
        role = "user"
    return role

__to_cohere_tools(tools)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_cohere_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Describe each activity of each tool as a Cohere ``function`` tool.

    Output order is tool by tool, and within a tool in ``tool.activities()`` order.
    """
    cohere_tools: list[dict] = []
    for tool in tools:
        for activity in tool.activities():
            function_spec = {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
            }
            cohere_tools.append({"function": function_spec, "type": "function"})
    return cohere_tools

__to_prompt_stack_delta_message_content(event)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, event: Any) -> BaseDeltaMessageContent:
    """Convert one Cohere stream event into a prompt-stack delta content.

    Handles ``content-delta``, ``tool-plan-delta``, ``tool-call-start`` and ``tool-call-delta``.

    Raises:
        ValueError: For any other event type.
    """
    event_type = event.type

    if event_type == "content-delta":
        return TextDeltaMessageContent(event.delta.message.content.text, index=0)

    if event_type == "tool-plan-delta":
        return TextDeltaMessageContent(event.delta.message["tool_plan"])

    if event_type == "tool-call-start":
        started_call = event.delta.message["tool_calls"]
        tool_name, tool_path = ToolAction.from_native_tool_name(started_call["function"]["name"])
        return ActionCallDeltaMessageContent(tag=started_call["id"], name=tool_name, path=tool_path)

    if event_type == "tool-call-delta":
        argument_fragment = event.delta.message["tool_calls"]["function"]["arguments"]
        return ActionCallDeltaMessageContent(partial_input=argument_fragment)

    raise ValueError(f"Unsupported event type: {event_type}")

__to_prompt_stack_message_content(response)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def __to_prompt_stack_message_content(self, response: AssistantMessageResponse) -> list[BaseMessageContent]:
    """Convert a Cohere v2 assistant response into prompt-stack message content.

    Text items come first, then the tool plan (if any) as text. After that comes one
    ``ActionCallMessageContent`` per tool call that has an id, a function, a function name
    and function arguments; incomplete tool calls are skipped.

    Args:
        response: The ``message`` of a Cohere chat response.

    Returns:
        The converted content list (possibly empty).
    """
    content: list[BaseMessageContent] = []

    if response.content:
        content.extend(TextMessageContent(TextArtifact(item.text)) for item in response.content)
    if response.tool_plan:
        content.append(TextMessageContent(TextArtifact(response.tool_plan)))
    if response.tool_calls is not None:
        for tool_call in response.tool_calls:
            if tool_call.id is None or tool_call.function is None:
                continue
            function = tool_call.function
            if function.name is None or function.arguments is None:
                continue
            # Parse the native tool name once instead of twice.
            name, path = ToolAction.from_native_tool_name(function.name)
            content.append(
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(tag=tool_call.id, name=name, path=path, input=json.loads(function.arguments)),
                    ),
                ),
            )

    return content

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments passed to ``client.chat`` / ``client.chat_stream``.

    ``extra_params`` is merged last, so it overrides the defaults before it. A JSON
    ``response_format`` is added when the stack has an output schema and
    ``structured_output_strategy`` is ``"native"``. ``tools`` is added when the stack has
    tools and ``use_native_tools`` is enabled.

    Args:
        prompt_stack: The prompt stack to convert.

    Returns:
        The request parameter dict.
    """
    params = {
        "model": self.model,
        "messages": self.__to_cohere_messages(prompt_stack.messages),
        "temperature": self.temperature,
        "stop_sequences": self.tokenizer.stop_sequences,
        "max_tokens": self.max_tokens,
        **self.extra_params,
    }

    if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
        params["response_format"] = {"type": "json_object", "schema": prompt_stack.to_output_json_schema()}

    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_cohere_tools(prompt_stack.tools)

    return params

client()

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@lazy_property()
def client(self) -> ClientV2:
    """Create the default `cohere.ClientV2` from ``self.api_key`` (wrapped by `lazy_property`)."""
    cohere_module = import_optional_dependency("cohere")
    return cohere_module.ClientV2(self.api_key, log_warning_experimental_features=False)

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send one non-streaming chat request and wrap the reply as an assistant `Message`.

    Token counts fall back to 0 when the response has no ``usage`` or no ``usage.tokens``.
    """
    request_params = self._base_params(prompt_stack)
    logger.debug(request_params)

    response: ChatResponse = self.client.chat(**request_params)
    logger.debug(response.model_dump())

    converted_content = self.__to_prompt_stack_message_content(response.message)
    token_counts = response.usage.tokens if response.usage else None

    return Message(
        content=converted_content,
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(
            input_tokens=token_counts.input_tokens if token_counts else 0,
            output_tokens=token_counts.output_tokens if token_counts else 0,
        ),
    )

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/cohere_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a chat request, yielding content deltas and a final usage delta.

    Cohere's v2 stream closes with a ``message-end`` event carrying ``delta.usage.tokens``.
    Token counts fall back to 0 when any level of that is missing, matching ``try_run``.
    The v1-style ``stream-end`` event is still accepted for backward compatibility.
    """
    params = self._base_params(prompt_stack)
    logger.debug(params)
    result: Iterator[StreamedChatResponseV2] = self.client.chat_stream(**params)

    for event in result:
        if event.type == "message-end":
            usage = event.delta.usage if event.delta else None
            tokens = usage.tokens if usage else None

            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=tokens.input_tokens if tokens else 0,
                    output_tokens=tokens.output_tokens if tokens else 0,
                ),
            )
        elif event.type == "stream-end":
            legacy_tokens = event.response.meta.tokens

            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=legacy_tokens.input_tokens,
                    output_tokens=legacy_tokens.output_tokens,
                ),
            )
        elif event.type in ("tool-plan-delta", "content-delta", "tool-call-start", "tool-call-delta"):
            yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(event))

CohereRerankDriver

Bases: BaseRerankDriver

Source Code in griptape/drivers/rerank/cohere_rerank_driver.py
@define(kw_only=True)
class CohereRerankDriver(BaseRerankDriver):
    """Rerank driver backed by Cohere's rerank endpoint.

    Attributes:
        model: Cohere rerank model name.
        top_n: Optional value forwarded to Cohere as `top_n`.
        api_key: Cohere API key.
        client: Cohere `Client`; built from `api_key` by default.
    """

    model: str = field(default="rerank-english-v3.0", metadata={"serializable": True})
    top_n: Optional[int] = field(default=None)

    api_key: str = field(metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Rerank `artifacts` by relevance to `query`.

        Falsy artifacts are dropped. Artifacts with identical text collapse into one entry;
        the last occurrence wins.

        Args:
            query: Query to rank against.
            artifacts: Candidate artifacts.

        Returns:
            The surviving artifacts in the order of Cohere's results, or `[]` when nothing
            remains to rerank. No API call is made in that case.
        """
        # Cohere errors out if passed "empty" documents or no documents at all.
        # Key by the text itself rather than `str(hash(text))`: hash keys can collide, which
        # would silently merge distinct documents and map results to the wrong artifact.
        artifacts_by_text = {artifact.to_text(): artifact for artifact in artifacts if artifact}
        if not artifacts_by_text:
            return []

        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=list(artifacts_by_text),
            return_documents=True,
            top_n=self.top_n,
        )
        return [
            artifacts_by_text[result.document.text] for result in response.results if result.document is not None
        ]
  • api_key = field(metadata={'serializable': True}) class-attribute instance-attribute

  • client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True)) class-attribute instance-attribute

  • model = field(default='rerank-english-v3.0', metadata={'serializable': True}) class-attribute instance-attribute

  • top_n = field(default=None) class-attribute instance-attribute

run(query, artifacts)

Source Code in griptape/drivers/rerank/cohere_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Rerank `artifacts` by relevance to `query`.

    Falsy artifacts are dropped. Artifacts with identical text collapse into one entry; the
    last occurrence wins.

    Args:
        query: Query to rank against.
        artifacts: Candidate artifacts.

    Returns:
        The surviving artifacts in the order of Cohere's results, or `[]` when nothing remains
        to rerank. No API call is made in that case.
    """
    # Cohere errors out if passed "empty" documents or no documents at all.
    # Key by the text itself rather than `str(hash(text))`: hash keys can collide, which would
    # silently merge distinct documents and map results to the wrong artifact.
    artifacts_by_text = {artifact.to_text(): artifact for artifact in artifacts if artifact}
    if not artifacts_by_text:
        return []

    response = self.client.rerank(
        model=self.model,
        query=query,
        documents=list(artifacts_by_text),
        return_documents=True,
        top_n=self.top_n,
    )
    return [artifacts_by_text[result.document.text] for result in response.results if result.document is not None]

DatadogObservabilityDriver

Bases: OpenTelemetryObservabilityDriver

Source Code in griptape/drivers/observability/datadog_observability_driver.py
@define
class DatadogObservabilityDriver(OpenTelemetryObservabilityDriver):
    """OpenTelemetry observability driver that exports spans to a Datadog endpoint over OTLP/HTTP.

    Attributes:
        datadog_agent_endpoint: Base URL spans are sent to. Read from the `DD_AGENT_ENDPOINT`
            environment variable, falling back to `http://localhost:4318`.
        span_processor: Span processor. Defaults to a `BatchSpanProcessor` wrapping an
            `OTLPSpanExporter` that posts to `{datadog_agent_endpoint}/v1/traces`.
    """

    # Must stay declared before `span_processor`: that field's default factory reads this value.
    datadog_agent_endpoint: str = field(
        default=Factory(lambda: os.environ.get("DD_AGENT_ENDPOINT", "http://localhost:4318")),
        kw_only=True,
    )
    span_processor: SpanProcessor = field(
        default=Factory(
            lambda self: import_optional_dependency("opentelemetry.sdk.trace.export").BatchSpanProcessor(
                import_optional_dependency(
                    "opentelemetry.exporter.otlp.proto.http.trace_exporter"
                ).OTLPSpanExporter(endpoint=f"{self.datadog_agent_endpoint}/v1/traces")
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
  • datadog_agent_endpoint = field(default=Factory(lambda: os.getenv('DD_AGENT_ENDPOINT', 'http://localhost:4318')), kw_only=True) class-attribute instance-attribute

  • span_processor = field(default=Factory(lambda self: import_optional_dependency('opentelemetry.sdk.trace.export').BatchSpanProcessor(import_optional_dependency('opentelemetry.exporter.otlp.proto.http.trace_exporter').OTLPSpanExporter(endpoint=f'{self.datadog_agent_endpoint}/v1/traces')), takes_self=True), kw_only=True) class-attribute instance-attribute

DuckDuckGoWebSearchDriver

Bases: BaseWebSearchDriver

Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
@define
class DuckDuckGoWebSearchDriver(BaseWebSearchDriver):
    """Web search driver backed by the `duckduckgo_search` package's `DDGS` client.

    Attributes:
        language: Language half of the region string passed to DuckDuckGo (default "en").
        country: Country half of the region string (default "us").
        client: Optional pre-built `DDGS` client; created lazily when omitted.
    """

    language: str = field(default="en", kw_only=True)
    country: str = field(default="us", kw_only=True)
    _client: Optional[DDGS] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> DDGS:
        """Lazily construct a default `DDGS` client."""
        ddgs_module = import_optional_dependency("duckduckgo_search")
        return ddgs_module.DDGS()

    def search(self, query: str, **kwargs) -> ListArtifact:
        """Run a DuckDuckGo text search and return one JSON `TextArtifact` per hit.

        Each artifact holds a JSON object with "title", "url" and "description" keys.

        Raises:
            Exception: Wraps any error raised while searching or converting results.
        """
        try:
            region = f"{self.language}-{self.country}"
            hits = self.client.text(query, region=region, max_results=self.results_count, **kwargs)
            artifacts = []
            for hit in hits:
                payload = {"title": hit["title"], "url": hit["href"], "description": hit["body"]}
                artifacts.append(TextArtifact(json.dumps(payload)))
            return ListArtifact(artifacts)
        except Exception as e:
            raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}") from e
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • country = field(default='us', kw_only=True) class-attribute instance-attribute

  • language = field(default='en', kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
@lazy_property()
def client(self) -> DDGS:
    """Lazily construct a default `duckduckgo_search.DDGS` client."""
    ddgs_module = import_optional_dependency("duckduckgo_search")
    return ddgs_module.DDGS()

search(query, **kwargs)

Source Code in griptape/drivers/web_search/duck_duck_go_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    """Run a DuckDuckGo text search and return one JSON `TextArtifact` per hit.

    Each artifact holds a JSON object with "title", "url" and "description" keys.

    Raises:
        Exception: Wraps any error raised while searching or converting results.
    """
    try:
        region = f"{self.language}-{self.country}"
        hits = self.client.text(query, region=region, max_results=self.results_count, **kwargs)
        artifacts = []
        for hit in hits:
            payload = {"title": hit["title"], "url": hit["href"], "description": hit["body"]}
            artifacts.append(TextArtifact(json.dumps(payload)))
        return ListArtifact(artifacts)
    except Exception as e:
        raise Exception(f"Error searching '{query}' with DuckDuckGo: {e}") from e

DummyAudioTranscriptionDriver

Bases: BaseAudioTranscriptionDriver

Source Code in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py
@define
class DummyAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
    """Placeholder audio transcription driver that raises `DummyError` whenever it is used."""

    model: str = field(init=False)

    def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
        """Always raise `DummyError`; no transcription is performed.

        Raises:
            DummyError: Always, naming this class and the `try_run` method.
        """
        # Report the method that was actually called. The old text named "try_transcription",
        # which this driver does not define; the other Dummy*Driver classes report their own method.
        raise DummyError(__class__.__name__, "try_run")
  • model = field(init=False) class-attribute instance-attribute

try_run(audio, prompts=None)

Source Code in griptape/drivers/audio_transcription/dummy_audio_transcription_driver.py
def try_run(self, audio: AudioArtifact, prompts: Optional[list] = None) -> TextArtifact:
    """Always raise `DummyError`; no transcription is performed.

    Raises:
        DummyError: Always, naming this class and the `try_run` method.
    """
    # Report the method that was actually called. The old text named "try_transcription",
    # which this driver does not define; the other Dummy*Driver classes report their own method.
    raise DummyError(__class__.__name__, "try_run")

DummyEmbeddingDriver

Bases: BaseEmbeddingDriver

Source Code in griptape/drivers/embedding/dummy_embedding_driver.py
@define
class DummyEmbeddingDriver(BaseEmbeddingDriver):
    """Placeholder embedding driver; every embedding attempt raises `DummyError`."""

    # The placeholder has no model: the field is fixed to None and excluded from __init__.
    model: None = field(init=False, default=None, kw_only=True)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Always raise `DummyError` instead of embedding `chunk`."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_embed_chunk")
  • model = field(init=False, default=None, kw_only=True) class-attribute instance-attribute

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/dummy_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Always raise `DummyError` instead of embedding `chunk`."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_embed_chunk")

DummyImageGenerationDriver

Bases: BaseImageGenerationDriver

Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
@define
class DummyImageGenerationDriver(BaseImageGenerationDriver):
    """Placeholder image generation driver: every generation mode raises `DummyError`."""

    # The placeholder has no model: the field is fixed to None and excluded from __init__.
    model: None = field(init=False, default=None, kw_only=True)

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Always raise `DummyError` instead of generating an image from text."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_text_to_image")

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Always raise `DummyError` instead of producing a variation of `image`."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_image_variation")

    def try_image_inpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Always raise `DummyError` instead of inpainting `image`."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_image_inpainting")

    def try_image_outpainting(
        self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        """Always raise `DummyError` instead of outpainting `image`."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_image_outpainting")
  • model = field(init=False, default=None, kw_only=True) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_inpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Always raise `DummyError` instead of inpainting `image`."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_image_inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_outpainting(
    self, prompts: list[str], image: ImageArtifact, mask: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Always raise `DummyError` instead of outpainting `image`."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_image_outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Always raise `DummyError` instead of producing a variation of `image`."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_image_variation")

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/dummy_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Always raise `DummyError` instead of generating an image from text."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_text_to_image")

DummyPromptDriver

Bases: BasePromptDriver

Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@define
class DummyPromptDriver(BasePromptDriver):
    """Placeholder prompt driver: both the blocking and the streaming path raise `DummyError`."""

    # The placeholder has no model: the field is fixed to None and excluded from __init__.
    model: None = field(init=False, default=None, kw_only=True)
    tokenizer: DummyTokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Always raise `DummyError`; no prompt is run."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_run")

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Always raise `DummyError`; nothing is streamed."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_stream")
  • model = field(init=False, default=None, kw_only=True) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda: DummyTokenizer()), kw_only=True) class-attribute instance-attribute

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Always raise `DummyError`; no prompt is run."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_run")

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/dummy_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Always raise `DummyError`; nothing is streamed."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_stream")

DummyTextToSpeechDriver

Bases: BaseTextToSpeechDriver

Source Code in griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py
@define
class DummyTextToSpeechDriver(BaseTextToSpeechDriver):
    """Placeholder text-to-speech driver; every synthesis attempt raises `DummyError`."""

    # The placeholder has no model: the field is fixed to None and excluded from __init__.
    model: None = field(init=False, default=None, kw_only=True)

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        """Always raise `DummyError` instead of synthesizing audio."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "try_text_to_audio")
  • model = field(init=False, default=None, kw_only=True) class-attribute instance-attribute

try_text_to_audio(prompts)

Source Code in griptape/drivers/text_to_speech/dummy_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    """Always raise `DummyError` instead of synthesizing audio."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "try_text_to_audio")

DummyVectorStoreDriver

Bases: BaseVectorStoreDriver

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
@define()
class DummyVectorStoreDriver(BaseVectorStoreDriver):
    """Placeholder vector store driver; every storage or query operation raises `DummyError`.

    Attributes:
        embedding_driver: Embedding driver; defaults to a `DummyEmbeddingDriver`.
    """

    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True,
        default=Factory(lambda: DummyEmbeddingDriver()),
        metadata={"serializable": True},
    )

    def delete_vector(self, vector_id: str) -> None:
        """Always raise `DummyError`; nothing is deleted."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "delete_vector")

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Always raise `DummyError`; nothing is stored."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "upsert_vector")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Always raise `DummyError`; nothing is loaded."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "load_entry")

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Always raise `DummyError`; nothing is loaded."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "load_entries")

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Always raise `DummyError`; no vector query is run."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "query_vector")

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Always raise `DummyError`; no query is run."""
        driver_name = __class__.__name__
        raise DummyError(driver_name, "query")
  • embedding_driver = field(kw_only=True, default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Always raise `DummyError`; nothing is deleted."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "delete_vector")

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Always raise `DummyError`; nothing is loaded."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "load_entries")

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Always raise `DummyError`; nothing is loaded."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "load_entry")

query(query, *, count=None, namespace=None, include_vectors=False, **kwargs)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Always raise `DummyError`; no query is run."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "query")

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Always raise `DummyError`; no vector query is run."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "query_vector")

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/dummy_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Always raise `DummyError`; nothing is stored."""
    driver_name = __class__.__name__
    raise DummyError(driver_name, "upsert_vector")

ElevenLabsTextToSpeechDriver

Bases: BaseTextToSpeechDriver

Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
@define
class ElevenLabsTextToSpeechDriver(BaseTextToSpeechDriver):
    """Text-to-speech driver backed by the ElevenLabs SDK.

    Attributes:
        api_key: ElevenLabs API key.
        voice: Voice passed to the client's `generate` call.
        output_format: ElevenLabs output format string shaped `{format}_{sample_rate}_{bitrate}`
            (default `mp3_44100_128`). Its leading component becomes the artifact's format.
        client: Optional pre-built `ElevenLabs` client; created lazily from `api_key` when omitted.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    voice: str = field(kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="mp3_44100_128", kw_only=True, metadata={"serializable": True})
    _client: Optional[ElevenLabs] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> ElevenLabs:
        return import_optional_dependency("elevenlabs.client").ElevenLabs(api_key=self.api_key)

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        """Synthesize `prompts` (joined with ". ") into a single `AudioArtifact`.

        Args:
            prompts: Text fragments to speak.

        Returns:
            An `AudioArtifact` holding the concatenated audio bytes.
        """
        audio = self.client.generate(
            text=". ".join(prompts),
            voice=self.voice,
            model=self.model,
            output_format=self.output_format,
        )

        # `audio` is iterated as a sequence of byte chunks. Join them in one pass: repeated `+=`
        # on bytes copies the accumulated buffer for every chunk (quadratic in total size).
        content = b"".join(audio)

        # All ElevenLabs audio format strings have the following structure:
        # {format}_{sample_rate}_{bitrate}
        artifact_format = self.output_format.split("_")[0]

        return AudioArtifact(value=content, format=artifact_format)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • output_format = field(default='mp3_44100_128', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • voice = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
@lazy_property()
def client(self) -> ElevenLabs:
    """Lazily construct an `ElevenLabs` client authenticated with `api_key`."""
    elevenlabs_client = import_optional_dependency("elevenlabs.client")
    return elevenlabs_client.ElevenLabs(api_key=self.api_key)

try_text_to_audio(prompts)

Source Code in griptape/drivers/text_to_speech/elevenlabs_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    """Synthesize `prompts` (joined with ". ") into a single `AudioArtifact`.

    Args:
        prompts: Text fragments to speak.

    Returns:
        An `AudioArtifact` holding the concatenated audio bytes.
    """
    audio = self.client.generate(
        text=". ".join(prompts),
        voice=self.voice,
        model=self.model,
        output_format=self.output_format,
    )

    # `audio` is iterated as a sequence of byte chunks. Join them in one pass: repeated `+=` on
    # bytes copies the accumulated buffer for every chunk (quadratic in total size).
    content = b"".join(audio)

    # All ElevenLabs audio format strings have the following structure:
    # {format}_{sample_rate}_{bitrate}
    artifact_format = self.output_format.split("_")[0]

    return AudioArtifact(value=content, format=artifact_format)

ExaWebSearchDriver

Bases: BaseWebSearchDriver

Source Code in griptape/drivers/web_search/exa_web_search_driver.py
@define
class ExaWebSearchDriver(BaseWebSearchDriver):
    """Web search driver backed by the Exa API (`exa_py`).

    Attributes:
        api_key: Exa API key; defaults to None.
        highlights: Forwarded to `search_and_contents` as `highlights`.
        use_autoprompt: Forwarded to `search_and_contents` as `use_autoprompt`.
        params: Extra keyword arguments forwarded to every search; per-call kwargs override them.
        client: Optional pre-built `Exa` client; created lazily from `api_key` when omitted.
    """

    # Annotated Optional to match its `None` default (previously annotated as plain `str`).
    api_key: Optional[str] = field(kw_only=True, default=None)
    highlights: bool = field(default=False, kw_only=True)
    use_autoprompt: bool = field(default=False, kw_only=True)
    params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
    _client: Optional[Exa] = field(default=None, kw_only=True, alias="client")

    @lazy_property()
    def client(self) -> Exa:
        return import_optional_dependency("exa_py").Exa(api_key=self.api_key)

    def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]:
        """Search Exa for `query` and return one `JsonArtifact` per result.

        Each artifact holds "title", "url", "highlights" and "text" from the result.
        """
        # Merge so per-call kwargs override configured `params`. Unpacking both directly raised
        # TypeError ("got multiple values for keyword argument") whenever they shared a key.
        options = {**self.params, **kwargs}
        response = self.client.search_and_contents(  # pyright: ignore[reportCallIssue]
            highlights=self.highlights,  # pyright: ignore[reportArgumentType]
            use_autoprompt=self.use_autoprompt,
            query=query,
            num_results=self.results_count,
            text=True,
            **options,
        )
        return ListArtifact(
            [
                JsonArtifact(
                    {"title": result.title, "url": result.url, "highlights": result.highlights, "text": result.text}
                )
                for result in response.results
            ]
        )
  • _client = field(default=None, kw_only=True, alias='client') class-attribute instance-attribute

  • api_key = field(kw_only=True, default=None) class-attribute instance-attribute

  • highlights = field(default=False, kw_only=True) class-attribute instance-attribute

  • params = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • use_autoprompt = field(default=False, kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/web_search/exa_web_search_driver.py
@lazy_property()
def client(self) -> Exa:
    """Lazily construct an `exa_py.Exa` client authenticated with `api_key`."""
    exa_py = import_optional_dependency("exa_py")
    return exa_py.Exa(api_key=self.api_key)

search(query, **kwargs)

Source Code in griptape/drivers/web_search/exa_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]:
    """Search Exa for `query` and return one `JsonArtifact` per result.

    Each artifact holds "title", "url", "highlights" and "text" from the result.
    """
    # Merge so per-call kwargs override configured `params`. Unpacking both directly raised
    # TypeError ("got multiple values for keyword argument") whenever they shared a key.
    options = {**self.params, **kwargs}
    response = self.client.search_and_contents(  # pyright: ignore[reportCallIssue]
        highlights=self.highlights,  # pyright: ignore[reportArgumentType]
        use_autoprompt=self.use_autoprompt,
        query=query,
        num_results=self.results_count,
        text=True,
        **options,
    )
    return ListArtifact(
        [
            JsonArtifact({"title": result.title, "url": result.url, "highlights": result.highlights, "text": result.text})
            for result in response.results
        ]
    )

GoogleEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `Optional[str]` | Google API key. |
| `model` | `str` | Google model name. |
| `task_type` | `str` | Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to `retrieval_document`. |
| `title` | `Optional[str]` | Optional title for the content. Only works with the `retrieval_document` task type. |
Source Code in griptape/drivers/embedding/google_embedding_driver.py
@define
class GoogleEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver backed by `google.generativeai.embed_content`.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        task_type: Embedding model task type (https://ai.google.dev/tutorials/python_quickstart#use_embeddings). Defaults to `retrieval_document`.
        title: Optional title for the content. Only works with `retrieval_document` task type.
    """

    DEFAULT_MODEL = "models/embedding-001"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    task_type: str = field(default="retrieval_document", kw_only=True, metadata={"serializable": True})
    title: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed `chunk` and return the response's "embedding" entry.

        `kwargs` is accepted for interface compatibility but is not forwarded.
        """
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        response = genai.embed_content(
            model=self.model,
            content=chunk,
            task_type=self.task_type,
            title=self.title,
        )
        return response["embedding"]

    def _params(self, chunk: str) -> dict:
        # NOTE(review): not used by `try_embed_chunk` above; confirm whether anything still calls it.
        return dict(input=chunk, model=self.model)
  • DEFAULT_MODEL = 'models/embedding-001' class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • task_type = field(default='retrieval_document', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • title = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_params(chunk)

Source Code in griptape/drivers/embedding/google_embedding_driver.py
def _params(self, chunk: str) -> dict:
    """Return request parameters: `chunk` under "input" and the driver's model under "model"."""
    return dict(input=chunk, model=self.model)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/google_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed `chunk` with `google.generativeai.embed_content` and return its "embedding" entry.

    `kwargs` is accepted for interface compatibility but is not forwarded.
    """
    genai = import_optional_dependency("google.generativeai")
    genai.configure(api_key=self.api_key)

    response = genai.embed_content(
        model=self.model,
        content=chunk,
        task_type=self.task_type,
        title=self.title,
    )
    return response["embedding"]

GooglePromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `Optional[str]` | Google API key. |
| `model` | `str` | Google model name. |
| `client` | `GenerativeModel` | Custom `GenerativeModel` client. |
| `top_p` | `Optional[float]` | Optional value for top_p. |
| `top_k` | `Optional[int]` | Optional value for top_k. |
Source Code in griptape/drivers/prompt/google_prompt_driver.py
@define
class GooglePromptDriver(BasePromptDriver):
    """Google Prompt Driver.

    Attributes:
        api_key: Google API key.
        model: Google model name.
        client: Custom `GenerativeModel` client.
        top_p: Optional value for top_p.
        top_k: Optional value for top_k.
    """

    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True),
        kw_only=True,
    )
    top_p: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="tool", kw_only=True, metadata={"serializable": True}
    )
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True})
    _client: Optional[GenerativeModel] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the `native` structured output strategy, which this driver does not support.

        Raises:
            ValueError: If `value` is "native".
        """
        if value != "native":
            return value

        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    @lazy_property()
    def client(self) -> GenerativeModel:
        """Configure `google.generativeai` with `api_key` and build a `GenerativeModel` for `model`."""
        genai = import_optional_dependency("google.generativeai")
        genai.configure(api_key=self.api_key)

        generative_model = genai.GenerativeModel(self.model)
        return generative_model

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Run `prompt_stack` through `generate_content` and wrap the reply in an assistant `Message`.

        Each response part becomes one message content. Usage comes from the response's
        `usage_metadata` prompt and candidate token counts.
        """
        google_messages = self.__to_google_messages(prompt_stack)
        request_params = self._base_params(prompt_stack)
        logger.debug((google_messages, request_params["generation_config"].__dict__))

        response: GenerateContentResponse = self.client.generate_content(google_messages, **request_params)
        logger.debug(response.to_dict())

        usage = response.usage_metadata
        contents = [self.__to_prompt_stack_message_content(part) for part in response.parts]
        token_usage = Message.Usage(
            input_tokens=usage.prompt_token_count,
            output_tokens=usage.candidates_token_count,
        )

        return Message(content=contents, role=Message.ASSISTANT_ROLE, usage=token_usage)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream `prompt_stack` through `generate_content(stream=True)` as `DeltaMessage`s.

        Each chunk yields one delta whose content comes from the chunk's first part (None when the
        chunk has no parts). The prompt token count is attached only while no count has been
        recorded yet. After that, deltas carry only the candidate token count.
        """
        google_messages = self.__to_google_messages(prompt_stack)
        request_params = self._base_params(prompt_stack)
        request_params["stream"] = True
        logger.debug((google_messages, request_params))

        response: GenerateContentResponse = self.client.generate_content(google_messages, **request_params)

        recorded_prompt_tokens = None
        for chunk in response:
            logger.debug(chunk.to_dict())
            usage = chunk.usage_metadata
            delta_content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None

            if recorded_prompt_tokens is not None:
                delta_usage = DeltaMessage.Usage(output_tokens=usage.candidates_token_count)
            else:
                # The prompt token count is static across chunks, so report it only once.
                recorded_prompt_tokens = usage.prompt_token_count
                delta_usage = DeltaMessage.Usage(
                    input_tokens=usage.prompt_token_count,
                    output_tokens=usage.candidates_token_count,
                )

            yield DeltaMessage(content=delta_content, usage=delta_usage)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments for `GenerativeModel.generate_content`.

        Side effect: when `prompt_stack` has system messages, they are installed as the client's
        `_system_instruction`.

        Returns:
            A dict with "generation_config", plus "tool_config" and "tools" when the prompt stack
            has tools and native tools are enabled.
        """
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        system_messages = prompt_stack.system_messages
        if system_messages:
            # NOTE(review): nothing clears `_system_instruction` when a later prompt stack has no
            # system messages, so an earlier instruction appears to persist on the cached client.
            # Confirm this is intended.
            system_parts = [protos.Part(text=message.to_text()) for message in system_messages]
            self.client._system_instruction = types.ContentDict(role="system", parts=system_parts)

        # For some reason, providing stop sequences when streaming breaks native functions
        # https://github.com/google-gemini/generative-ai-python/issues/446
        if self.stream and self.use_native_tools:
            stop_sequences = []
        else:
            stop_sequences = self.tokenizer.stop_sequences

        config_kwargs = {
            "stop_sequences": stop_sequences,
            "max_output_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            **self.extra_params,
        }
        params = {"generation_config": types.GenerationConfig(**config_kwargs)}

        if prompt_stack.tools and self.use_native_tools:
            # The "tool" structured-output strategy forces "auto" calling mode, overriding `tool_choice`.
            forces_auto = prompt_stack.output_schema is not None and self.structured_output_strategy == "tool"
            mode = "auto" if forces_auto else self.tool_choice
            params["tool_config"] = {"function_calling_config": {"mode": mode}}
            params["tools"] = self.__to_google_tools(prompt_stack.tools)

        return params

    def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
        """Convert the non-system messages of `prompt_stack` into Google `ContentDict`s, in order."""
        types = import_optional_dependency("google.generativeai.types")

        google_messages = []
        for message in prompt_stack.messages:
            if message.is_system():
                continue
            role = self.__to_google_role(message)
            parts = [self.__to_google_message_content(content) for content in message.content]
            google_messages.append(types.ContentDict({"role": role, "parts": parts}))

        return google_messages

    def __to_google_role(self, message: Message) -> str:
        """Map a message to a Google role: assistant messages become "model", all others "user"."""
        return "model" if message.is_assistant() else "user"

    def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Build one Google `FunctionDeclaration` per activity of every tool.

        Activities whose schema has no (or empty) "properties" are declared without parameters.
        """
        types = import_optional_dependency("google.generativeai.types")

        tool_declarations = []
        for tool in tools:
            for activity in tool.activities():
                schema = tool.to_activity_json_schema(activity, "Parameters Schema")

                # When the parameters are nested under a "values" property, declare that nested
                # schema instead. Read "properties" defensively, as the check below already does;
                # indexing it directly raised KeyError for schemas without it.
                if "values" in (schema.get("properties") or {}):
                    schema = schema["properties"]["values"]

                schema = remove_key_in_dict_recursively(schema, "additionalProperties")
                properties = schema.get("properties")
                extra = (
                    {
                        "parameters": {
                            "type": schema["type"],
                            "properties": properties,
                            "required": schema.get("required", []),
                        }
                    }
                    if properties
                    else {}
                )
                tool_declarations.append(
                    types.FunctionDeclaration(
                        name=tool.to_native_tool_name(activity),
                        description=tool.activity_description(activity),
                        **extra,
                    )
                )

        return tool_declarations

    def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
        """Convert one prompt stack content item into a value usable inside a Google message's ``parts``.

        Args:
            content: The prompt stack content to convert.

        Returns:
            A string for text, a ``ContentDict`` with ``mime_type``/``data`` for images, a ``protos.Part`` for
            action calls and action results, or the raw artifact value for generic content.

        Raises:
            ValueError: If the content type, or the image artifact type, is not supported.
        """
        types = import_optional_dependency("google.generativeai.types")
        protos = import_optional_dependency("google.generativeai.protos")

        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()

        if isinstance(content, ImageMessageContent):
            image = content.artifact
            if not isinstance(image, ImageArtifact):
                # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
                # Can be worked around by using GenericMessageContent, similar to videos.
                raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
            return types.ContentDict(mime_type=image.mime_type, data=image.value)

        if isinstance(content, ActionCallMessageContent):
            tool_action = content.artifact.value
            function_call = protos.FunctionCall(name=tool_action.tag, args=tool_action.input)
            return protos.Part(function_call=function_call)

        if isinstance(content, ActionResultMessageContent):
            result_artifact = content.artifact
            function_response = protos.FunctionResponse(
                name=content.action.to_native_tool_name(),
                response=result_artifact.to_dict(),
            )
            return protos.Part(function_response=function_response)

        if isinstance(content, GenericMessageContent):
            # Passed through unchanged.
            return content.artifact.value

        raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

    def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
        """Convert a Google response ``Part`` into prompt stack message content.

        Text parts become ``TextMessageContent``. Function-call parts become an ``ActionCallMessageContent``
        wrapping a ``ToolAction`` built from the call's name and arguments.

        Raises:
            ValueError: If the part carries neither text nor a function call.
        """
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextMessageContent(TextArtifact(content.text))

        if not content.function_call:
            raise ValueError(f"Unsupported message content type {content}")

        call = content.function_call
        tool_name, tool_path = ToolAction.from_native_tool_name(call.name)
        # MessageToDict turns the underlying protobuf message into a plain dict; fall back to {} when it has no "args".
        call_args = json_format.MessageToDict(call._pb).get("args", {})
        tool_action = ToolAction(tag=call.name, name=tool_name, path=tool_path, input=call_args)
        return ActionCallMessageContent(artifact=ActionArtifact(value=tool_action))

    def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
        """Convert a streamed Google ``Part`` into prompt stack delta content.

        Text parts become ``TextDeltaMessageContent``. Function-call parts become an
        ``ActionCallDeltaMessageContent`` whose ``partial_input`` is the JSON-encoded call arguments.

        Raises:
            ValueError: If the part carries neither text nor a function call.
        """
        json_format = import_optional_dependency("google.protobuf.json_format")

        if content.text:
            return TextDeltaMessageContent(content.text)

        if not content.function_call:
            raise ValueError(f"Unsupported message content type {content}")

        call = content.function_call
        tool_name, tool_path = ToolAction.from_native_tool_name(call.name)
        call_args = json_format.MessageToDict(call._pb).get("args", {})
        return ActionCallDeltaMessageContent(
            tag=call.name,
            name=tool_name,
            path=tool_path,
            partial_input=json.dumps(call_args),
        )
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='tool', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: GoogleTokenizer(api_key=self.api_key, model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

  • tool_choice = field(default='auto', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • top_k = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • top_p = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_google_message_content(content)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str:
    """Convert one prompt stack content item into a value usable inside a Google message's ``parts``.

    Args:
        content: The prompt stack content to convert.

    Returns:
        A string for text, a ``ContentDict`` with ``mime_type``/``data`` for images, a ``protos.Part`` for
        action calls and action results, or the raw artifact value for generic content.

    Raises:
        ValueError: If the content type, or the image artifact type, is not supported.
    """
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()

    if isinstance(content, ImageMessageContent):
        image = content.artifact
        if not isinstance(image, ImageArtifact):
            # TODO: Google requires uploading to the files endpoint: https://ai.google.dev/gemini-api/docs/image-understanding#upload-image
            # Can be worked around by using GenericMessageContent, similar to videos.
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        return types.ContentDict(mime_type=image.mime_type, data=image.value)

    if isinstance(content, ActionCallMessageContent):
        tool_action = content.artifact.value
        function_call = protos.FunctionCall(name=tool_action.tag, args=tool_action.input)
        return protos.Part(function_call=function_call)

    if isinstance(content, ActionResultMessageContent):
        result_artifact = content.artifact
        function_response = protos.FunctionResponse(
            name=content.action.to_native_tool_name(),
            response=result_artifact.to_dict(),
        )
        return protos.Part(function_response=function_response)

    if isinstance(content, GenericMessageContent):
        # Passed through unchanged.
        return content.artifact.value

    raise ValueError(f"Unsupported prompt stack content type: {type(content)}")

__to_google_messages(prompt_stack)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType:
    """Convert the prompt stack's non-system messages into Google ``ContentDict`` entries.

    System messages are skipped here; ``_base_params`` sends them separately as the client's system instruction.

    Args:
        prompt_stack: The prompt stack to convert.

    Returns:
        One ``ContentDict`` (with ``role`` and ``parts``) per non-system message, in stack order.
    """
    genai_types = import_optional_dependency("google.generativeai.types")

    google_contents = []
    for message in prompt_stack.messages:
        if message.is_system():
            continue
        role = self.__to_google_role(message)
        parts = [self.__to_google_message_content(part) for part in message.content]
        google_contents.append(genai_types.ContentDict({"role": role, "parts": parts}))
    return google_contents

__to_google_role(message)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_role(self, message: Message) -> str:
    """Return the Google role for ``message``: ``"model"`` for assistant messages, ``"user"`` for everything else."""
    return "model" if message.is_assistant() else "user"

__to_google_tools(tools)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_google_tools(self, tools: list[BaseTool]) -> list:
    """Build one Google ``FunctionDeclaration`` per activity of every tool.

    Args:
        tools: Tools whose activities are exposed to the model as callable functions.

    Returns:
        A list of ``google.generativeai.types.FunctionDeclaration`` objects (not plain dicts).
    """
    types = import_optional_dependency("google.generativeai.types")

    tool_declarations = []
    for tool in tools:
        for activity in tool.activities():
            schema = tool.to_activity_json_schema(activity, "Parameters Schema")

            # When the activity's arguments are nested under a "values" property, use that nested schema.
            # `.get(...) or {}` avoids a KeyError for schemas that have no "properties" at all; those are
            # handled below as activities without parameters.
            if "values" in (schema.get("properties") or {}):
                schema = schema["properties"]["values"]

            # Strip "additionalProperties" at every nesting level
            # (presumably not accepted by Google's function schema format — verify against the SDK).
            schema = remove_key_in_dict_recursively(schema, "additionalProperties")
            tool_declaration = types.FunctionDeclaration(
                name=tool.to_native_tool_name(activity),
                description=tool.activity_description(activity),
                # Only send `parameters` when the activity actually declares properties.
                **(
                    {
                        "parameters": {
                            "type": schema["type"],
                            "properties": schema["properties"],
                            "required": schema.get("required", []),
                        }
                    }
                    if schema.get("properties")
                    else {}
                ),
            )

            tool_declarations.append(tool_declaration)

    return tool_declarations

__to_prompt_stack_delta_message_content(content)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent:
    """Convert a streamed Google ``Part`` into prompt stack delta content.

    Text parts become ``TextDeltaMessageContent``. Function-call parts become an
    ``ActionCallDeltaMessageContent`` whose ``partial_input`` is the JSON-encoded call arguments.

    Raises:
        ValueError: If the part carries neither text nor a function call.
    """
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextDeltaMessageContent(content.text)

    if not content.function_call:
        raise ValueError(f"Unsupported message content type {content}")

    call = content.function_call
    tool_name, tool_path = ToolAction.from_native_tool_name(call.name)
    call_args = json_format.MessageToDict(call._pb).get("args", {})
    return ActionCallDeltaMessageContent(
        tag=call.name,
        name=tool_name,
        path=tool_path,
        partial_input=json.dumps(call_args),
    )

__to_prompt_stack_message_content(content)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent:
    """Convert a Google response ``Part`` into prompt stack message content.

    Text parts become ``TextMessageContent``. Function-call parts become an ``ActionCallMessageContent``
    wrapping a ``ToolAction`` built from the call's name and arguments.

    Raises:
        ValueError: If the part carries neither text nor a function call.
    """
    json_format = import_optional_dependency("google.protobuf.json_format")

    if content.text:
        return TextMessageContent(TextArtifact(content.text))

    if not content.function_call:
        raise ValueError(f"Unsupported message content type {content}")

    call = content.function_call
    tool_name, tool_path = ToolAction.from_native_tool_name(call.name)
    # MessageToDict turns the underlying protobuf message into a plain dict; fall back to {} when it has no "args".
    call_args = json_format.MessageToDict(call._pb).get("args", {})
    tool_action = ToolAction(tag=call.name, name=tool_name, path=tool_path, input=call_args)
    return ActionCallMessageContent(artifact=ActionArtifact(value=tool_action))

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments for ``GenerativeModel.generate_content``.

    Side effect: sets the cached client's system instruction from the prompt stack's system messages.
    When there are none, the instruction is cleared.

    Args:
        prompt_stack: The prompt stack being sent.

    Returns:
        A dict with ``generation_config`` and, when native tools are in use, ``tool_config`` and ``tools``.
    """
    types = import_optional_dependency("google.generativeai.types")
    protos = import_optional_dependency("google.generativeai.protos")

    system_messages = prompt_stack.system_messages
    # `client` is a lazy (cached) property, so the same GenerativeModel is reused across requests. Always
    # overwrite the instruction. Otherwise a prompt stack without system messages would silently inherit the
    # system instruction of a previous request.
    self.client._system_instruction = (
        types.ContentDict(
            role="system",
            parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages],
        )
        if system_messages
        else None
    )

    params = {
        "generation_config": types.GenerationConfig(
            **{
                # For some reason, providing stop sequences when streaming breaks native functions
                # https://github.com/google-gemini/generative-ai-python/issues/446
                "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences,
                "max_output_tokens": self.max_tokens,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "top_k": self.top_k,
                # Caller-supplied extras go last so they override the values above.
                **self.extra_params,
            },
        ),
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}}

        # With an output schema and the "tool" structured-output strategy, the mode is forced to "auto",
        # overriding `tool_choice`.
        if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool":
            params["tool_config"]["function_calling_config"]["mode"] = "auto"

        params["tools"] = self.__to_google_tools(prompt_stack.tools)

    return params

client()

Source Code in griptape/drivers/prompt/google_prompt_driver.py
@lazy_property()
def client(self) -> GenerativeModel:
    """Create the ``GenerativeModel`` for ``self.model``, configuring the SDK with ``self.api_key`` first.

    Built on first access and then cached by ``lazy_property``.
    """
    google_genai = import_optional_dependency("google.generativeai")
    google_genai.configure(api_key=self.api_key)
    generative_model = google_genai.GenerativeModel(self.model)
    return generative_model

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack in a single, non-streaming request and return the assistant reply.

    Returns:
        An assistant ``Message`` with one content item per response part and the request's token usage.
    """
    google_messages = self.__to_google_messages(prompt_stack)
    request_params = self._base_params(prompt_stack)
    logger.debug((google_messages, request_params["generation_config"].__dict__))
    response: GenerateContentResponse = self.client.generate_content(google_messages, **request_params)
    logger.debug(response.to_dict())

    usage = response.usage_metadata
    message_content = [self.__to_prompt_stack_message_content(part) for part in response.parts]
    token_usage = Message.Usage(
        input_tokens=usage.prompt_token_count,
        output_tokens=usage.candidates_token_count,
    )

    return Message(content=message_content, role=Message.ASSISTANT_ROLE, usage=token_usage)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a response for the prompt stack as a sequence of ``DeltaMessage`` chunks.

    Only the first part of each chunk is converted; a chunk without parts yields a ``DeltaMessage`` with no content.
    Input tokens are reported on the first chunk, and on later chunks only while no prompt token count has been
    recorded yet. Every chunk reports its output-token count.
    """
    google_messages = self.__to_google_messages(prompt_stack)
    request_params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug((google_messages, request_params))
    response: GenerateContentResponse = self.client.generate_content(
        google_messages,
        **request_params,
    )

    prompt_token_count = None
    for chunk in response:
        logger.debug(chunk.to_dict())
        usage_metadata = chunk.usage_metadata
        delta_content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None

        if prompt_token_count is None:
            # The prompt token count is the same on every chunk, so it is only emitted once.
            prompt_token_count = usage_metadata.prompt_token_count
            usage = DeltaMessage.Usage(
                input_tokens=usage_metadata.prompt_token_count,
                output_tokens=usage_metadata.candidates_token_count,
            )
        else:
            usage = DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count)

        yield DeltaMessage(content=delta_content, usage=usage)

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/google_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the ``"native"`` structured output strategy, which this driver does not support.

    Raises:
        ValueError: If ``value`` is ``"native"``.
    """
    if value == "native":
        # Use the instance's runtime class rather than the implicit `__class__` closure cell, so subclasses
        # report their own name.
        raise ValueError(f"{self.__class__.__name__} does not support `{value}` structured output strategy.")

    return value

GoogleWebSearchDriver

Bases: BaseWebSearchDriver

Source Code in griptape/drivers/web_search/google_web_search_driver.py
@define
class GoogleWebSearchDriver(BaseWebSearchDriver):
    """Web search driver backed by the Google Custom Search JSON API.

    Attributes:
        api_key: API key sent as the ``key`` query parameter.
        search_id: Search engine ID sent as the ``cx`` query parameter.
        language: Language code used to build the ``lr`` restriction (``lang_<language>``).
        country: Country code sent as the ``gl`` query parameter.
    """

    api_key: str = field(kw_only=True)
    search_id: str = field(kw_only=True)
    language: str = field(default="en", kw_only=True)
    country: str = field(default="us", kw_only=True)

    def search(self, query: str, **kwargs) -> ListArtifact:
        """Run ``query`` and return every result as a JSON-encoded ``TextArtifact`` inside a ``ListArtifact``.

        Args:
            query: The search query.
            **kwargs: Extra Custom Search query parameters; they override the defaults built by this driver.
        """
        return ListArtifact([TextArtifact(json.dumps(result)) for result in self._search_google(query, **kwargs)])

    def _search_google(self, query: str, **kwargs) -> list[dict]:
        """Call the Custom Search endpoint and return ``url``/``title``/``description`` dicts.

        Args:
            query: The search query (``q``).
            **kwargs: Extra query parameters, applied last so they override the defaults.

        Returns:
            One dict per result item; an empty list when the query has no results.

        Raises:
            Exception: If the API responds with a non-200 status code.
        """
        query_params = {
            "key": self.api_key,
            "cx": self.search_id,
            "q": query,
            "start": 0,
            "lr": f"lang_{self.language}",
            # Presumably `results_count` is inherited from BaseWebSearchDriver — confirm.
            "num": self.results_count,
            "gl": self.country,
            **kwargs,
        }
        response = requests.get("https://www.googleapis.com/customsearch/v1", params=query_params)

        if response.status_code == 200:
            data = response.json()

            # The API omits "items" entirely when a query has no results. Treat that as an empty result set
            # instead of raising KeyError.
            return [
                {"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data.get("items", [])
            ]

        raise Exception(
            f"Google Search API returned an error with status code "
            f"{response.status_code} and reason '{response.reason}'",
        )
  • api_key = field(kw_only=True) class-attribute instance-attribute

  • country = field(default='us', kw_only=True) class-attribute instance-attribute

  • language = field(default='en', kw_only=True) class-attribute instance-attribute

  • search_id = field(kw_only=True) class-attribute instance-attribute

_search_google(query, **kwargs)

Source Code in griptape/drivers/web_search/google_web_search_driver.py
def _search_google(self, query: str, **kwargs) -> list[dict]:
    """Call the Custom Search endpoint and return ``url``/``title``/``description`` dicts.

    Args:
        query: The search query (``q``).
        **kwargs: Extra query parameters, applied last so they override the defaults.

    Returns:
        One dict per result item; an empty list when the query has no results.

    Raises:
        Exception: If the API responds with a non-200 status code.
    """
    query_params = {
        "key": self.api_key,
        "cx": self.search_id,
        "q": query,
        "start": 0,
        "lr": f"lang_{self.language}",
        # Presumably `results_count` is inherited from BaseWebSearchDriver — confirm.
        "num": self.results_count,
        "gl": self.country,
        **kwargs,
    }
    response = requests.get("https://www.googleapis.com/customsearch/v1", params=query_params)

    if response.status_code == 200:
        data = response.json()

        # The API omits "items" entirely when a query has no results. Treat that as an empty result set
        # instead of raising KeyError.
        return [{"url": r["link"], "title": r["title"], "description": r["snippet"]} for r in data.get("items", [])]

    raise Exception(
        f"Google Search API returned an error with status code "
        f"{response.status_code} and reason '{response.reason}'",
    )

search(query, **kwargs)

Source Code in griptape/drivers/web_search/google_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    """Search Google for ``query`` and wrap each result dict as a JSON ``TextArtifact`` in a ``ListArtifact``.

    Args:
        query: The search query.
        **kwargs: Extra query parameters forwarded to ``_search_google``.
    """
    results = self._search_google(query, **kwargs)
    artifacts = [TextArtifact(json.dumps(result)) for result in results]
    return ListArtifact(artifacts)

GriptapeCloudAssistantDriver

Bases: BaseAssistantDriver

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
@define
class GriptapeCloudAssistantDriver(BaseAssistantDriver):
    """Runs a Griptape Cloud Assistant over HTTP and returns its final output.

    A run is created with ``POST api/assistants/{assistant_id}/runs``. Its events are then read from
    ``api/assistant-runs/{assistant_run_id}/events/stream``. Events originating from the assistant are
    re-published on the local ``EventBus``, and the output of the ``FinishStructureRunEvent`` is returned.
    """

    # Root URL of the Griptape Cloud API; defaults to $GT_CLOUD_BASE_URL or https://cloud.griptape.ai.
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    # API key; when not passed, read from $GT_CLOUD_API_KEY (os.environ[...] raises KeyError if it is unset).
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Bearer auth header. It is computed from `api_key` once, at construction time, so later changes to
    # `api_key` are not reflected here.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    # Forwarded verbatim as the "input" field of the run request body.
    input: Optional[str] = field(default=None, kw_only=True)
    # ID of the Assistant to run.
    assistant_id: str = field(kw_only=True)
    # Thread to run in. When None and `auto_create_thread` is True, try_run finds or creates one.
    thread_id: Optional[str] = field(default=None, kw_only=True)
    # Alias used to look up an existing Thread (or to name a new one) when a Thread is auto-created.
    thread_alias: Optional[str] = field(default=None, kw_only=True)
    # Each `*_ids` list below is only sent when not None; each `additional_*_ids` list is always sent.
    ruleset_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_ruleset_ids: list[str] = field(factory=list, kw_only=True)
    knowledge_base_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_knowledge_base_ids: list[str] = field(factory=list, kw_only=True)
    structure_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_structure_ids: list[str] = field(factory=list, kw_only=True)
    tool_ids: Optional[list[str]] = field(default=None, kw_only=True)
    additional_tool_ids: list[str] = field(factory=list, kw_only=True)
    # Sent as the "stream" flag of the run request.
    stream: bool = field(default=False, kw_only=True)
    # NOTE(review): `poll_interval` and `max_attempts` are not referenced by any method in this class; the run
    # result is read from an event stream rather than polled.
    poll_interval: int = field(default=1, kw_only=True)
    max_attempts: int = field(default=20, kw_only=True)
    # When True, try_run finds or creates a Thread if `thread_id` is None.
    auto_create_thread: bool = field(default=True, kw_only=True)

    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        """Run the Assistant with ``args`` and return its output, with run IDs added to the artifact's ``meta``."""
        if self.thread_id is None and self.auto_create_thread:
            self._create_or_find_thread(self.thread_alias)
        assistant_run_id = self._create_run(*args)
        run_result = self._get_run_result(assistant_run_id)

        run_result.meta.update(
            {"assistant_id": self.assistant_id, "assistant_run_id": assistant_run_id, "thread_id": self.thread_id}
        )

        return run_result

    def _create_or_find_thread(self, thread_alias: Optional[str] = None) -> None:
        """Set ``self.thread_id``: reuse the Thread with ``thread_alias`` if it exists, otherwise create one."""
        if thread_alias is None:
            self.thread_id = self._create_thread()
        else:
            thread = self._find_thread_by_alias(thread_alias)

            if thread is None:
                self.thread_id = self._create_thread(thread_alias)
            else:
                self.thread_id = thread["thread_id"]

    def _create_thread(self, thread_alias: Optional[str] = None) -> str:
        """Create a Thread with a random hex name (plus ``thread_alias`` if given) and return its ID."""
        url = griptape_cloud_url(self.base_url, "api/threads")

        body = {"name": uuid.uuid4().hex}
        if thread_alias is not None:
            body["alias"] = thread_alias

        response = requests.post(url, json=body, headers=self.headers)
        response.raise_for_status()
        return response.json()["thread_id"]

    def _create_run(self, *args: BaseArtifact) -> str:
        """Start an Assistant run and return its ``assistant_run_id``."""
        url = griptape_cloud_url(self.base_url, f"api/assistants/{self.assistant_id}/runs")

        response = requests.post(
            url,
            json={
                "args": [arg.value for arg in args],
                "stream": self.stream,
                "thread_id": self.thread_id,
                "input": self.input,
                **({"ruleset_ids": self.ruleset_ids} if self.ruleset_ids is not None else {}),
                "additional_ruleset_ids": self.additional_ruleset_ids,
                **({"knowledge_base_ids": self.knowledge_base_ids} if self.knowledge_base_ids is not None else {}),
                "additional_knowledge_base_ids": self.additional_knowledge_base_ids,
                **({"structure_ids": self.structure_ids} if self.structure_ids is not None else {}),
                "additional_structure_ids": self.additional_structure_ids,
                **({"tool_ids": self.tool_ids} if self.tool_ids is not None else {}),
                "additional_tool_ids": self.additional_tool_ids,
            },
            headers=self.headers,
        )
        response.raise_for_status()
        return response.json()["assistant_run_id"]

    def _get_run_result(self, assistant_run_id: str) -> TextArtifact:
        """Consume the run's events, re-publish assistant events locally, and return the final output.

        Raises:
            ValueError: If no assistant ``FinishStructureRunEvent`` appears in the stream.
        """
        events = self._get_run_events(assistant_run_id)
        output = None

        for event in events:
            if event["origin"] == "ASSISTANT":
                event_payload = event["payload"]
                # A ValueError while deserializing/publishing is logged and the event is skipped.
                try:
                    EventBus.publish_event(BaseEvent.from_dict(event_payload))
                except ValueError as e:
                    logger.warning("Failed to deserialize event: %s", e)
                # If several finish events arrive, the last one wins.
                if event["type"] == "FinishStructureRunEvent":
                    output = TextArtifact.from_dict(event_payload["output_task_output"])

        if output is None:
            raise ValueError("Output not found.")

        return output

    def _get_run_events(self, assistant_run_id: str) -> Iterator[dict]:
        """Yield the JSON payload of every ``data:`` line in the run's event stream."""
        url = griptape_cloud_url(self.base_url, f"api/assistant-runs/{assistant_run_id}/events/stream")
        with requests.get(url, headers=self.headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data:"):
                        yield json.loads(decoded_line.removeprefix("data:").strip())

    def _find_thread_by_alias(self, thread_alias: str) -> Optional[dict]:
        """Return the first Thread whose alias equals ``thread_alias``, or None."""
        url = griptape_cloud_url(self.base_url, "api/threads")
        response = requests.get(url, params={"alias": thread_alias}, headers=self.headers)
        response.raise_for_status()

        threads = response.json()["threads"]

        # The alias is also sent as a query filter; this keeps only exact matches.
        return next((thread for thread in threads if thread["alias"] == thread_alias), None)
  • additional_knowledge_base_ids = field(factory=list, kw_only=True) class-attribute instance-attribute

  • additional_ruleset_ids = field(factory=list, kw_only=True) class-attribute instance-attribute

  • additional_structure_ids = field(factory=list, kw_only=True) class-attribute instance-attribute

  • additional_tool_ids = field(factory=list, kw_only=True) class-attribute instance-attribute

  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • assistant_id = field(kw_only=True) class-attribute instance-attribute

  • auto_create_thread = field(default=True, kw_only=True) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • input = field(default=None, kw_only=True) class-attribute instance-attribute

  • knowledge_base_ids = field(default=None, kw_only=True) class-attribute instance-attribute

  • max_attempts = field(default=20, kw_only=True) class-attribute instance-attribute

  • poll_interval = field(default=1, kw_only=True) class-attribute instance-attribute

  • ruleset_ids = field(default=None, kw_only=True) class-attribute instance-attribute

  • stream = field(default=False, kw_only=True) class-attribute instance-attribute

  • structure_ids = field(default=None, kw_only=True) class-attribute instance-attribute

  • thread_alias = field(default=None, kw_only=True) class-attribute instance-attribute

  • thread_id = field(default=None, kw_only=True) class-attribute instance-attribute

  • tool_ids = field(default=None, kw_only=True) class-attribute instance-attribute

_create_or_find_thread(thread_alias=None)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _create_or_find_thread(self, thread_alias: Optional[str] = None) -> None:
    """Set ``self.thread_id`` to an existing or a newly created Thread.

    Without an alias, a fresh Thread is always created. With an alias, the Thread carrying that alias is reused
    when it exists; otherwise a new Thread is created under that alias.
    """
    existing = None if thread_alias is None else self._find_thread_by_alias(thread_alias)
    if existing is not None:
        self.thread_id = existing["thread_id"]
        return

    self.thread_id = self._create_thread() if thread_alias is None else self._create_thread(thread_alias)

_create_run(*args)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _create_run(self, *args: BaseArtifact) -> str:
    """Start an Assistant run and return its ``assistant_run_id``.

    The request body always contains the ``additional_*_ids`` lists. Each base list (``ruleset_ids``,
    ``knowledge_base_ids``, ``structure_ids``, ``tool_ids``) is only included when it is not ``None``.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    url = griptape_cloud_url(self.base_url, f"api/assistants/{self.assistant_id}/runs")

    body: dict = {
        "args": [arg.value for arg in args],
        "stream": self.stream,
        "thread_id": self.thread_id,
        "input": self.input,
    }
    # (base key, base value, additional key, additional value), in request-body order.
    id_lists = (
        ("ruleset_ids", self.ruleset_ids, "additional_ruleset_ids", self.additional_ruleset_ids),
        (
            "knowledge_base_ids",
            self.knowledge_base_ids,
            "additional_knowledge_base_ids",
            self.additional_knowledge_base_ids,
        ),
        ("structure_ids", self.structure_ids, "additional_structure_ids", self.additional_structure_ids),
        ("tool_ids", self.tool_ids, "additional_tool_ids", self.additional_tool_ids),
    )
    for base_key, base_ids, additional_key, additional_ids in id_lists:
        if base_ids is not None:
            body[base_key] = base_ids
        body[additional_key] = additional_ids

    response = requests.post(url, json=body, headers=self.headers)
    response.raise_for_status()
    return response.json()["assistant_run_id"]

_create_thread(thread_alias=None)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _create_thread(self, thread_alias: Optional[str] = None) -> str:
    """Create a Thread and return its ``thread_id``.

    The Thread is named with a random hex UUID; ``thread_alias`` is attached as its alias when given.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    threads_url = griptape_cloud_url(self.base_url, "api/threads")

    payload = {"name": uuid.uuid4().hex}
    if thread_alias is not None:
        payload["alias"] = thread_alias

    response = requests.post(threads_url, json=payload, headers=self.headers)
    response.raise_for_status()
    created = response.json()
    return created["thread_id"]

_find_thread_by_alias(thread_alias)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _find_thread_by_alias(self, thread_alias: str) -> Optional[dict]:
    """Return the first Thread whose ``alias`` equals ``thread_alias``, or ``None`` if there is none.

    The alias is also sent as a query filter; the loop below keeps only exact matches.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    threads_url = griptape_cloud_url(self.base_url, "api/threads")
    response = requests.get(threads_url, params={"alias": thread_alias}, headers=self.headers)
    response.raise_for_status()

    for thread in response.json()["threads"]:
        if thread["alias"] == thread_alias:
            return thread
    return None

_get_run_events(assistant_run_id)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _get_run_events(self, assistant_run_id: str) -> Iterator[dict]:
    """Yield the JSON payload of every ``data:`` line in an Assistant run's event stream.

    Blank lines and lines without the ``data:`` prefix are skipped. The streamed HTTP response is closed when
    the generator finishes or is closed.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    events_url = griptape_cloud_url(self.base_url, f"api/assistant-runs/{assistant_run_id}/events/stream")
    with requests.get(events_url, headers=self.headers, stream=True) as response:
        response.raise_for_status()
        for raw_line in response.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data:"):
                continue
            yield json.loads(text.removeprefix("data:").strip())

_get_run_result(assistant_run_id)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def _get_run_result(self, assistant_run_id: str) -> TextArtifact:
    """Consume a run's event stream, re-publish assistant events locally, and return the run's final output.

    Every event whose ``origin`` is ``"ASSISTANT"`` is deserialized and published on the local ``EventBus``.
    A ``ValueError`` raised while doing so is logged as a warning and the event is skipped. The output comes
    from the last ``FinishStructureRunEvent`` seen.

    Raises:
        ValueError: If the stream contains no assistant ``FinishStructureRunEvent``.
    """
    final_output = None

    for event in self._get_run_events(assistant_run_id):
        if event["origin"] != "ASSISTANT":
            continue

        payload = event["payload"]
        try:
            EventBus.publish_event(BaseEvent.from_dict(payload))
        except ValueError as e:
            logger.warning("Failed to deserialize event: %s", e)

        if event["type"] == "FinishStructureRunEvent":
            final_output = TextArtifact.from_dict(payload["output_task_output"])

    if final_output is None:
        raise ValueError("Output not found.")
    return final_output

try_run(*args)

Source Code in griptape/drivers/assistant/griptape_cloud_assistant_driver.py
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    """Run the Assistant with ``args`` and return its output, annotated with run metadata.

    When ``thread_id`` is unset and ``auto_create_thread`` is enabled, a Thread is first found (by
    ``thread_alias``) or created. The returned artifact's ``meta`` is updated with ``assistant_id``,
    ``assistant_run_id`` and ``thread_id``.
    """
    needs_thread = self.thread_id is None and self.auto_create_thread
    if needs_thread:
        self._create_or_find_thread(self.thread_alias)

    run_id = self._create_run(*args)
    result = self._get_run_result(run_id)
    result.meta.update({"assistant_id": self.assistant_id, "assistant_run_id": run_id, "thread_id": self.thread_id})
    return result

GriptapeCloudConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `thread_id` | `Optional[str]` | The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. |
| `alias` | `Optional[str]` | The alias of the Thread to store the conversation memory in. |
| `base_url` | `str` | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. |
| `api_key` | `str` | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. |

Raises

| Type | Description |
| --- | --- |
| `ValueError` | If `api_key` is not provided. |

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
    """A driver for storing conversation memory in the Gen AI Builder.

    Attributes:
        thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
            retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
        alias: The alias of the Thread to store the conversation memory in.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will
            attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is explicitly passed as `None`.
        KeyError: If `api_key` is not passed and the `GT_CLOUD_API_KEY` environment variable is unset.
    """

    thread_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    alias: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Derived from `api_key`; not accepted by __init__.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    # Cached Thread payload; reset to None by `store()` so the next access refetches it.
    _thread: Optional[dict] = field(default=None, init=False)

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        """Reject a `None` API key."""
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    @property
    def thread(self) -> dict:
        """Try to get the Thread by ID, alias, or create a new one.

        Resolution order:
            1. `thread_id` (falling back to `GT_CLOUD_THREAD_ID`); a non-200 response is ignored.
            2. The first Thread returned for `alias`, if an alias is set.
            3. A newly created Thread, named after `alias` or a random hex UUID.

        The resolved Thread is cached in `_thread`. `thread_id` is updated to match
        it, and `alias` is also updated when a new Thread is created.
        """
        if self._thread is None:
            from urllib.parse import quote

            thread = None
            if self.thread_id is None:
                self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

            if self.thread_id is not None:
                # A missing Thread is not an error here: fall through to alias lookup / creation.
                res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
                if res.status_code == 200:
                    thread = res.json()

            # use name as 'alias' to get thread
            if thread is None and self.alias is not None:
                # Percent-encode the alias so characters such as "&", "#", "+" or spaces
                # cannot truncate or inject extra parameters into the query string.
                res = self._call_api("get", f"/threads?alias={quote(self.alias, safe='')}").json()
                if res.get("threads"):
                    thread = res["threads"][0]
                    self.thread_id = thread.get("thread_id")

            # no thread by name or thread_id
            if thread is None:
                data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
                thread = self._call_api("post", "/threads", data).json()
                self.thread_id = thread["thread_id"]
                self.alias = thread.get("alias")

            self._thread = thread

        return self._thread  # pyright: ignore[reportReturnType]

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        """Replace the Thread's messages with `runs` and merge `metadata` into the request body.

        Each run becomes a message with its JSON-serialized input/output and its ID
        under `metadata.run_id`; `run.meta` is merged into that message via `dict_merge`.
        """
        # serialize the run artifacts to json strings
        messages = [
            dict_merge(
                {
                    "input": run.input.to_json(),
                    "output": run.output.to_json(),
                    "metadata": {"run_id": run.id},
                },
                run.meta,
            )
            for run in runs
        ]

        body = dict_merge(
            {
                "messages": messages,
            },
            metadata,
        )

        # patch the Thread with the new messages and metadata
        # all old Messages are replaced with the new ones
        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
        self._call_api("patch", f"/threads/{thread_id}", body)
        self._thread = None

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Load the Thread's messages as `Run`s together with the Thread's metadata.

        A message's `metadata.run_id`, when present, becomes the `Run`'s ID and is
        removed from its `meta`. Messages whose metadata is missing or null load
        with an empty `meta` instead of raising.
        """
        from griptape.memory.structure import Run

        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

        # get the Messages from the Thread
        messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

        # retrieve the Thread to get the metadata
        thread_response = self._call_api("get", f"/threads/{thread_id}").json()

        runs = []
        for message in messages_response.get("messages", []):
            # `message.get("metadata")` may be None; `"run_id" in None` would raise TypeError.
            message_metadata = message.get("metadata") or {}
            run_kwargs = {"id": message_metadata.pop("run_id")} if "run_id" in message_metadata else {}
            runs.append(
                Run(
                    **run_kwargs,
                    meta=message_metadata,
                    input=BaseArtifact.from_json(message["input"]),
                    output=BaseArtifact.from_json(message["output"]),
                )
            )

        return runs, thread_response.get("metadata", {})

    def _get_url(self, path: str) -> str:
        """Build the absolute API URL for `path` (leading slashes are ignored)."""
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
    ) -> requests.Response:
        """Send an authenticated request; raise on HTTP error status unless `raise_for_status` is False."""
        res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
  • _thread = field(default=None, init=False) class-attribute instance-attribute

  • alias = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False) class-attribute instance-attribute

  • thread property

  • thread_id = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

_call_api(method, path, json=None, *, raise_for_status=True)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _call_api(
    self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
    """Send an authenticated request to the Gen AI Builder API.

    Args:
        method: HTTP method name passed to ``requests.request``.
        path: API path, resolved through ``_get_url``.
        json: Optional JSON body.
        raise_for_status: When true, HTTP error statuses raise via ``raise_for_status()``.

    Returns:
        The raw ``requests.Response``.
    """
    response = requests.request(
        method,
        self._get_url(path),
        json=json,
        headers=self.headers,
    )
    if not raise_for_status:
        return response
    response.raise_for_status()
    return response

_get_url(path)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _get_url(self, path: str) -> str:
    """Return the absolute API URL for ``path``; leading slashes on ``path`` are ignored."""
    relative = "api/" + path.lstrip("/")
    return griptape_cloud_url(self.base_url, relative)

load()

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Load the Thread's messages as `Run`s together with the Thread's metadata.

    A message's `metadata.run_id`, when present, becomes the `Run`'s ID and is
    removed from its `meta`. Messages whose metadata is missing or null load
    with an empty `meta` instead of raising.
    """
    from griptape.memory.structure import Run

    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

    # get the Messages from the Thread
    messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

    # retrieve the Thread to get the metadata
    thread_response = self._call_api("get", f"/threads/{thread_id}").json()

    runs = []
    for message in messages_response.get("messages", []):
        # `message.get("metadata")` may be None; `"run_id" in None` would raise TypeError.
        message_metadata = message.get("metadata") or {}
        run_kwargs = {"id": message_metadata.pop("run_id")} if "run_id" in message_metadata else {}
        runs.append(
            Run(
                **run_kwargs,
                meta=message_metadata,
                input=BaseArtifact.from_json(message["input"]),
                output=BaseArtifact.from_json(message["output"]),
            )
        )

    return runs, thread_response.get("metadata", {})

store(runs, metadata)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    """Replace the Thread's messages with ``runs`` and merge ``metadata`` into the request body.

    Each run becomes one message holding its JSON-serialized input and output and
    its ID under ``metadata.run_id``; ``run.meta`` is merged into the message via
    ``dict_merge``. The cached Thread is cleared afterwards.
    """
    messages = []
    for run in runs:
        message = {
            "input": run.input.to_json(),
            "output": run.output.to_json(),
            "metadata": {"run_id": run.id},
        }
        messages.append(dict_merge(message, run.meta))

    payload = dict_merge({"messages": messages}, metadata)

    # The PATCH replaces every existing Message on the Thread with the ones sent here.
    target_id = self.thread_id if self.thread_id is not None else self.thread["thread_id"]
    self._call_api("patch", f"/threads/{target_id}", payload)
    self._thread = None

validate_api_key(_, value)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
    """Reject a ``None`` API key; any other value is returned unchanged."""
    if value is not None:
        return value
    raise ValueError(f"{self.__class__.__name__} requires an API key")

GriptapeCloudEventListenerDriver

Bases: BaseEventListenerDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `base_url` | `str` | The base URL of Gen AI Builder. Defaults to the `GT_CLOUD_BASE_URL` environment variable. |
| `api_key` | `Optional[str]` | The API key to authenticate with Gen AI Builder. |
| `headers` | `dict` | The headers to use when making requests to Gen AI Builder. Defaults to include the Authorization header. |
| `structure_run_id` | `Optional[str]` | The ID of the Structure Run to publish events to. Defaults to the `GT_CLOUD_STRUCTURE_RUN_ID` environment variable. |
Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
@define
class GriptapeCloudEventListenerDriver(BaseEventListenerDriver):
    """Driver for publishing events to Gen AI Builder.

    Attributes:
        base_url: The base URL of Gen AI Builder. Defaults to the GT_CLOUD_BASE_URL environment variable.
        api_key: The API key to authenticate with Gen AI Builder.
        headers: The headers to use when making requests to Gen AI Builder. Defaults to include the Authorization header.
        structure_run_id: The ID of the Structure Run to publish events to. Defaults to the GT_CLOUD_STRUCTURE_RUN_ID environment variable.

    Raises:
        ValueError: If `structure_run_id` or `api_key` is `None` once defaults have been applied.
    """

    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
        kw_only=True,
    )
    api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    structure_run_id: Optional[str] = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
    )

    @structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
        """Require a Structure Run ID from the constructor or the environment."""
        if structure_run_id is None:
            raise ValueError(
                "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
            )

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_api_key(self, _: Attribute, api_key: Optional[str]) -> None:
        """Require an API key, explaining how to obtain one when it is missing."""
        if api_key is None:
            raise ValueError(
                "No value was found for the 'GT_CLOUD_API_KEY' environment variable. "
                "This environment variable is required when running in Gen AI Builder for authorization. "
                "You can generate a Gen AI Builder API Key by visiting https://cloud.griptape.ai/keys . "
                "Specify it as an environment variable when creating a Managed Structure in Gen AI Builder."
            )

    def publish_event(self, event: BaseEvent | dict) -> None:
        """Publish `event`, tagging it with the current observability span ID when one exists.

        Dict events are shallow-copied before tagging, so the caller's dict is left unmodified.
        """
        from griptape.observability.observability import Observability

        # Copy dict input: adding "span_id" must not mutate a dict the caller still holds.
        event_payload = event.to_dict() if isinstance(event, BaseEvent) else dict(event)

        span_id = Observability.get_span_id()
        if span_id is not None:
            event_payload["span_id"] = span_id

        super().publish_event(event_payload)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """POST a single event payload wrapped in its request envelope."""
        self._post_event(self._get_event_request(event_payload))

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """POST several event payloads in one request, each wrapped in its envelope."""
        self._post_event([self._get_event_request(event_payload) for event_payload in event_payload_batch])

    def _get_event_request(self, event_payload: dict) -> dict:
        """Wrap a payload with its `timestamp` (default: now) and `type` (default: "UserEvent")."""
        return {
            "payload": event_payload,
            "timestamp": event_payload.get("timestamp", time.time()),
            "type": event_payload.get("type", "UserEvent"),
        }

    def _post_event(self, json: list[dict] | dict) -> None:
        """POST to this Structure Run's events endpoint, raising on HTTP error status."""
        requests.post(
            url=griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/events"),
            json=json,
            headers=self.headers,
        ).raise_for_status()
  • api_key = field(default=Factory(lambda: os.getenv('GT_CLOUD_API_KEY')), kw_only=True) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')), kw_only=True) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • structure_run_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_STRUCTURE_RUN_ID')), kw_only=True) class-attribute instance-attribute

_get_event_request(event_payload)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def _get_event_request(self, event_payload: dict) -> dict:
    """Wrap ``event_payload`` in the request envelope sent to the events endpoint.

    ``timestamp`` and ``type`` are copied from the payload, falling back to the
    current ``time.time()`` and ``"UserEvent"`` respectively. The payload itself
    is embedded as-is (not copied).
    """
    timestamp = event_payload.get("timestamp", time.time())
    event_type = event_payload.get("type", "UserEvent")
    return {"payload": event_payload, "timestamp": timestamp, "type": event_type}

_post_event(json)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def _post_event(self, json: list[dict] | dict) -> None:
    """POST one event request (or a batch) to this Structure Run's events endpoint.

    Raises:
        requests.HTTPError: If the endpoint responds with an error status.
    """
    endpoint = griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/events")
    response = requests.post(url=endpoint, json=json, headers=self.headers)
    response.raise_for_status()

publish_event(event)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def publish_event(self, event: BaseEvent | dict) -> None:
    """Publish `event`, tagging it with the current observability span ID when one exists.

    Dict events are shallow-copied before tagging, so the caller's dict is left unmodified.
    """
    from griptape.observability.observability import Observability

    # Copy dict input: adding "span_id" must not mutate a dict the caller still holds.
    event_payload = event.to_dict() if isinstance(event, BaseEvent) else dict(event)

    span_id = Observability.get_span_id()
    if span_id is not None:
        event_payload["span_id"] = span_id

    super().publish_event(event_payload)

try_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Send a single event payload after wrapping it in its request envelope."""
    request_body = self._get_event_request(event_payload)
    self._post_event(request_body)

try_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Send several event payloads in one request, each wrapped in its envelope."""
    envelopes = list(map(self._get_event_request, event_payload_batch))
    self._post_event(envelopes)

validate_api_key(_, api_key)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_api_key(self, _: Attribute, api_key: Optional[str]) -> None:
    """Fail fast, with setup instructions, when no API key is available."""
    if api_key is not None:
        return
    raise ValueError(
        "No value was found for the 'GT_CLOUD_API_KEY' environment variable. "
        "This environment variable is required when running in Gen AI Builder for authorization. "
        "You can generate a Gen AI Builder API Key by visiting https://cloud.griptape.ai/keys . "
        "Specify it as an environment variable when creating a Managed Structure in Gen AI Builder."
    )

validate_run_id(_, structure_run_id)

Source Code in griptape/drivers/event_listener/griptape_cloud_event_listener_driver.py
@structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
    """Require a Structure Run ID from either the constructor or the environment."""
    if structure_run_id is not None:
        return
    raise ValueError(
        "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID).",
    )

GriptapeCloudFileManagerDriver

Bases: BaseFileManagerDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `bucket_id` | `Optional[str]` | The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`. |
| `workdir` | `str` | The working directory. List, load, and save operations will be performed relative to this directory. |
| `base_url` | `str` | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. |
| `api_key` | `str` | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. |

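Raises

ValueError: If `bucket_id` is not provided (is `None`), or if the Bucket cannot be retrieved when the driver is created.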

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
@define
class GriptapeCloudFileManagerDriver(BaseFileManagerDriver):
    """GriptapeCloudFileManagerDriver can be used to list, load, and save files as Assets in Gen AI Builder Buckets.

    Attributes:
        bucket_id: The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to
            retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`.
        workdir: The working directory. List, load, and save operations will be performed relative to this directory.
            A leading "/" is added when missing.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will
            attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `bucket_id` is `None`, or if the Bucket cannot be retrieved from the API at construction time.
        KeyError: If `api_key` is not passed and the `GT_CLOUD_API_KEY` environment variable is unset.
    """

    bucket_id: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_BUCKET_ID")), kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Derived from `api_key`; not accepted by __init__.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    _workdir: str = field(default="/", kw_only=True, alias="workdir")

    @property
    def workdir(self) -> str:
        """The working directory, always returned with a leading "/"."""
        if self._workdir.startswith("/"):
            return self._workdir
        return f"/{self._workdir}"

    @workdir.setter
    def workdir(self, value: str) -> None:
        self._workdir = value

    @bucket_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str:
        """Reject a `None` Bucket ID."""
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an Bucket ID")
        return value

    def __attrs_post_init__(self) -> None:
        """Check at construction time that the configured Bucket can be fetched."""
        try:
            self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 404:
                raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e
            raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e

    def try_list_files(self, path: str, postfix: str = "") -> list[str]:
        """List asset names under `path`, optionally filtered by `postfix`."""
        full_key = self._to_full_key(path)

        data = {"prefix": full_key}
        if postfix:
            data["postfix"] = postfix
        # HTTP error statuses are not raised here; the response body is parsed regardless.
        list_assets_response = self._call_api(
            method="get", path=f"/buckets/{self.bucket_id}/assets", params=data, raise_for_status=False
        ).json()

        return [asset["name"] for asset in list_assets_response.get("assets", [])]

    def try_load_file(self, path: str) -> bytes:
        """Download the asset at `path`; a 404 is reported as `FileNotFoundError`."""
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        try:
            sas_url, headers = self._get_asset_url(full_key)
            response = requests.get(sas_url, headers=headers)
            response.raise_for_status()
            return response.content
        except requests.exceptions.HTTPError as e:
            if e.response.status_code == 404:
                raise FileNotFoundError from e
            raise e

    def try_save_file(self, path: str, value: bytes) -> str:
        """Register the asset at `path`, upload `value` to it, and return its Bucket-relative location."""
        full_key = self._to_full_key(path)

        if self._is_a_directory(full_key):
            raise IsADirectoryError

        self._call_api(
            method="put",
            path=f"/buckets/{self.bucket_id}/assets",
            json={"name": full_key},
            raise_for_status=True,
        )

        sas_url, headers = self._get_asset_url(full_key)
        response = requests.put(sas_url, data=value, headers=headers)
        response.raise_for_status()

        return f"buckets/{self.bucket_id}/assets/{full_key}"

    def _get_asset_url(self, full_key: str) -> tuple[str, dict]:
        """Ask the API for a URL (and optional headers) used to read/write the asset at `full_key`."""
        from urllib.parse import quote

        # Percent-encode the key so "#", "?", "%" or spaces in a file name cannot truncate the
        # path or start a query string; "/" stays literal because keys use it as a separator.
        url_response = self._call_api(
            method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{quote(full_key)}", raise_for_status=True
        ).json()
        return url_response["url"], url_response.get("headers", {})

    def _get_url(self, path: str) -> str:
        """Build the absolute API URL for `path` (leading slashes are ignored)."""
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self,
        method: str,
        path: str,
        json: Optional[dict] = None,
        params: Optional[dict] = None,
        *,
        raise_for_status: bool = True,
    ) -> requests.Response:
        """Send an authenticated request; raise on HTTP error status unless `raise_for_status` is False."""
        res = requests.request(method, self._get_url(path), json=json, params=params, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res

    def _is_a_directory(self, path: str) -> bool:
        """Treat the root key ("") and any key ending in "/" as a directory."""
        return path == "" or path.endswith("/")

    def _to_full_key(self, path: str) -> str:
        """Resolve `path` against `workdir` into an asset key without a leading "/".

        Trailing slashes on `workdir` are dropped before joining, so workdir "/docs/"
        and path "a.txt" give "docs/a.txt" rather than "docs//a.txt".
        """
        path = path.lstrip("/")
        full_key = f"{self.workdir.rstrip('/')}/{path}"
        return full_key.lstrip("/")
  • _workdir = field(default='/', kw_only=True, alias='workdir') class-attribute instance-attribute

  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • bucket_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_BUCKET_ID')), kw_only=True) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False) class-attribute instance-attribute

  • workdir property writable

__attrs_post_init__()

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def __attrs_post_init__(self) -> None:
    """Verify, right after construction, that the configured Bucket can be fetched.

    Raises:
        ValueError: If the Bucket request fails with an ``HTTPError``. A 404 gets
            a "No Bucket found" message; any other status gets an "Unexpected
            error" message. The original error is chained as the cause.
    """
    bucket_path = f"/buckets/{self.bucket_id}"
    try:
        self._call_api(method="get", path=bucket_path).json()
    except requests.exceptions.HTTPError as err:
        if err.response.status_code != 404:
            raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from err
        raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from err

_call_api(method, path, json=None, params=None, *, raise_for_status=True)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def _call_api(
    self,
    method: str,
    path: str,
    json: Optional[dict] = None,
    params: Optional[dict] = None,
    *,
    raise_for_status: bool = True,
) -> requests.Response:
    """Send an authenticated request to the Gen AI Builder API.

    Args:
        method: HTTP method name passed to ``requests.request``.
        path: API path, resolved through ``_get_url``.
        json: Optional JSON body.
        params: Optional query parameters.
        raise_for_status: When true, HTTP error statuses raise via ``raise_for_status()``.

    Returns:
        The raw ``requests.Response``.
    """
    url = self._get_url(path)
    response = requests.request(method, url, json=json, params=params, headers=self.headers)
    if not raise_for_status:
        return response
    response.raise_for_status()
    return response

_get_asset_url(full_key)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def _get_asset_url(self, full_key: str) -> tuple[str, dict]:
    """Ask the API for a URL (and optional headers) used to read or write the asset at ``full_key``.

    Returns:
        ``(url, headers)``; ``headers`` is ``{}`` when the response omits it.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    from urllib.parse import quote

    # Percent-encode the key so "#", "?", "%" or spaces in a file name cannot truncate the
    # path or start a query string; "/" stays literal because keys use it as a separator.
    url_response = self._call_api(
        method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{quote(full_key)}", raise_for_status=True
    ).json()
    return url_response["url"], url_response.get("headers", {})

_get_url(path)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def _get_url(self, path: str) -> str:
    """Return the absolute Gen AI Builder API URL for ``path`` (leading slashes ignored)."""
    trimmed = path.lstrip("/")
    return griptape_cloud_url(self.base_url, "/".join(("api", trimmed)))

_is_a_directory(path)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def _is_a_directory(self, path: str) -> bool:
    """Return whether ``path`` names a directory: the root key ``""`` or a key ending in ``/``."""
    if path == "":
        return True
    return path.endswith("/")

_to_full_key(path)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def _to_full_key(self, path: str) -> str:
    """Resolve ``path`` against ``workdir`` into an asset key without a leading "/".

    Trailing slashes on ``workdir`` are dropped before joining, so workdir
    ``"/docs/"`` and path ``"a.txt"`` give ``"docs/a.txt"`` rather than ``"docs//a.txt"``.
    """
    path = path.lstrip("/")
    full_key = f"{self.workdir.rstrip('/')}/{path}"
    return full_key.lstrip("/")

try_list_files(path, postfix='')

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def try_list_files(self, path: str, postfix: str = "") -> list[str]:
    """List the names of assets whose keys start with ``path`` (relative to ``workdir``).

    Args:
        path: Prefix, resolved through ``_to_full_key``.
        postfix: Optional suffix filter; sent only when non-empty.

    Returns:
        The ``name`` of each entry in the response's ``assets`` list, or ``[]`` if absent.
    """
    query = {"prefix": self._to_full_key(path)}
    if postfix:
        query["postfix"] = postfix

    # HTTP error statuses are deliberately not raised; the body is parsed regardless.
    response = self._call_api(
        method="get",
        path=f"/buckets/{self.bucket_id}/assets",
        params=query,
        raise_for_status=False,
    )
    assets = response.json().get("assets", [])
    return [asset["name"] for asset in assets]

try_load_file(path)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Download the asset at ``path`` and return its bytes.

    Raises:
        IsADirectoryError: If ``path`` resolves to the root key or a key ending in "/".
        FileNotFoundError: If fetching the asset URL or the asset itself returns 404.
        requests.HTTPError: For any other HTTP error status.
    """
    key = self._to_full_key(path)
    if self._is_a_directory(key):
        raise IsADirectoryError

    try:
        asset_url, asset_headers = self._get_asset_url(key)
        download = requests.get(asset_url, headers=asset_headers)
        download.raise_for_status()
    except requests.exceptions.HTTPError as err:
        if err.response.status_code == 404:
            raise FileNotFoundError from err
        raise err
    return download.content

try_save_file(path, value)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> str:
    """Register the asset at ``path`` via PUT, then upload ``value`` to its asset URL.

    Returns:
        The location string ``"buckets/{bucket_id}/assets/{key}"``.

    Raises:
        IsADirectoryError: If ``path`` resolves to the root key or a key ending in "/".
        requests.HTTPError: If registering or uploading fails.
    """
    key = self._to_full_key(path)
    if self._is_a_directory(key):
        raise IsADirectoryError

    self._call_api(
        method="put",
        path=f"/buckets/{self.bucket_id}/assets",
        json={"name": key},
        raise_for_status=True,
    )
    upload_url, upload_headers = self._get_asset_url(key)
    requests.put(upload_url, data=value, headers=upload_headers).raise_for_status()

    return f"buckets/{self.bucket_id}/assets/{key}"

validate_bucket_id(_, value)

Source Code in griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
@bucket_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str:
    """Require a Bucket ID; a non-``None`` value is returned unchanged."""
    if value is not None:
        return value
    raise ValueError(f"{self.__class__.__name__} requires an Bucket ID")

GriptapeCloudImageGenerationDriver

Bases: BaseImageGenerationDriver

Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
@define
class GriptapeCloudImageGenerationDriver(BaseImageGenerationDriver):
    """Image generation driver that sends text-to-image requests to Gen AI Builder.

    Only text-to-image is implemented; image variation, inpainting and outpainting
    raise `NotImplementedError`.

    Attributes:
        model: Model name forwarded in the request's `driver_configuration`.
        base_url: Base URL of the API. Defaults to `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: API key. Defaults to the `GT_CLOUD_API_KEY` environment variable.
        headers: Request headers. Defaults to a Bearer `Authorization` header built from `api_key`.
        style: Style forwarded in `driver_configuration`.
        quality: Either `"standard"` or `"hd"`.
        image_size: One of `"1024x1024"`, `"1024x1792"` or `"1792x1024"`.
    """

    model: Optional[str] = field(default=None, kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    style: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    quality: Literal["standard", "hd"] = field(default="standard", kw_only=True, metadata={"serializable": True})
    image_size: Literal["1024x1024", "1024x1792", "1792x1024"] = field(
        default="1024x1024", kw_only=True, metadata={"serializable": True}
    )

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate an image from `prompts` and return the artifact from the response.

        NOTE(review): `negative_prompts` is accepted but not forwarded to the API.
        """
        endpoint = griptape_cloud_url(self.base_url, "api/images/generations")
        driver_configuration = {
            "model": self.model,
            "image_size": self.image_size,
            "quality": self.quality,
            "style": self.style,
        }

        http_response = requests.post(
            endpoint,
            headers=self.headers,
            json={"prompts": prompts, "driver_configuration": driver_configuration},
        )
        http_response.raise_for_status()
        body = http_response.json()
        return ImageArtifact.from_dict(body["artifact"])

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises `NotImplementedError`."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support image variation")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises `NotImplementedError`."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support inpainting")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Unsupported by this driver; always raises `NotImplementedError`."""
        driver_name = self.__class__.__name__
        raise NotImplementedError(f"{driver_name} does not support outpainting")
  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • image_size = field(default='1024x1024', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default=None, kw_only=True) class-attribute instance-attribute

  • quality = field(default='standard', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • style = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Inpainting is not implemented by this driver.

    Raises:
        NotImplementedError: Always.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Outpainting is not implemented by this driver; every call raises.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Image variation is not implemented by this driver; every call raises.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support image variation")

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/griptape_cloud_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from text prompts through the Griptape Cloud image-generation endpoint.

    Args:
        prompts: Prompts sent to the API as-is.
        negative_prompts: Accepted for interface compatibility; not included in the request.

    Returns:
        The ImageArtifact deserialized from the ``artifact`` key of the JSON response.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    endpoint = griptape_cloud_url(self.base_url, "api/images/generations")

    driver_configuration = {
        "model": self.model,
        "image_size": self.image_size,
        "quality": self.quality,
        "style": self.style,
    }
    request_body = {"prompts": prompts, "driver_configuration": driver_configuration}

    http_response = requests.post(endpoint, headers=self.headers, json=request_body)
    http_response.raise_for_status()
    body = http_response.json()

    return ImageArtifact.from_dict(body["artifact"])

GriptapeCloudObservabilityDriver

Bases: OpenTelemetryObservabilityDriver

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
@define
class GriptapeCloudObservabilityDriver(OpenTelemetryObservabilityDriver):
    """OpenTelemetry observability driver that ships spans to a Griptape Cloud structure run.

    Spans are handed to an OpenTelemetry ``BatchSpanProcessor`` whose exporter (see
    ``build_span_exporter``) POSTs them as JSON to
    ``api/structure-runs/{structure_run_id}/spans`` under ``base_url``.
    """

    # Defaults to $GT_CLOUD_BASE_URL, falling back to https://cloud.griptape.ai.
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), kw_only=True
    )
    # Defaults to $GT_CLOUD_API_KEY; os.environ[...] raises KeyError when it is unset and no key is passed.
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]), kw_only=True)
    # Bearer-token auth header built from api_key (declared after api_key so the takes_self factory can read it).
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    # Defaults to $GT_CLOUD_STRUCTURE_RUN_ID; validate_run_id rejects a None value.
    structure_run_id: Optional[str] = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_STRUCTURE_RUN_ID")), kw_only=True
    )
    # Batching processor wrapping the Griptape Cloud exporter; built from the fields declared above.
    span_processor: SpanProcessor = field(
        default=Factory(
            lambda self: import_optional_dependency("opentelemetry.sdk.trace.export").BatchSpanProcessor(
                GriptapeCloudObservabilityDriver.build_span_exporter(
                    base_url=self.base_url,
                    api_key=self.api_key,
                    headers=self.headers,
                    structure_run_id=self.structure_run_id,
                )
            ),
            takes_self=True,
        ),
        kw_only=True,
    )

    @structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
        """Require a structure run id, either passed in or read from GT_CLOUD_STRUCTURE_RUN_ID.

        Raises:
            ValueError: If ``structure_run_id`` is None.
        """
        # NOTE(review): attrs normally runs validators only after every field has been set, so the
        # span_processor factory (which constructs a BatchSpanProcessor) may already have run when this
        # raises. Confirm whether that leaves a worker thread behind.
        if structure_run_id is None:
            raise ValueError(
                "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
            )

    @staticmethod
    def format_trace_id(trace_id: int) -> str:
        """Render an integer trace id as a canonical UUID string."""
        return str(UUID(int=trace_id))

    @staticmethod
    def format_span_id(span_id: int) -> str:
        """Render an integer span id as a canonical UUID string (zero-padded to 128 bits by UUID)."""
        return str(UUID(int=span_id))

    @staticmethod
    def build_span_exporter(base_url: str, api_key: str, headers: dict, structure_run_id: str) -> SpanExporter:
        """Build an OpenTelemetry span exporter that POSTs spans to the structure run's span endpoint.

        The exporter class is defined inside this function, so its OpenTelemetry base class is only
        resolved through ``import_optional_dependency`` when an exporter is actually built.
        """

        @define
        class SpanExporter(import_optional_dependency("opentelemetry.sdk.trace.export").SpanExporter):
            base_url: str = field(kw_only=True)
            api_key: str = field(kw_only=True)
            headers: dict = field(kw_only=True)
            structure_run_id: str = field(kw_only=True)

            def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
                """Serialize spans to JSON and POST them in one request.

                Spans whose ``context`` is None are skipped. Ids are formatted as UUID strings and
                timestamps as ISO strings (None when unset). Returns SUCCESS only for an HTTP 200
                response and FAILURE for any other status. Exceptions raised by ``requests.post``
                are not caught here.
                """
                opentelemetry_util = import_optional_dependency("opentelemetry.sdk.util")
                opentelemetry_trace_export = import_optional_dependency("opentelemetry.sdk.trace.export")

                url = griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/spans")
                payload = [
                    {
                        "trace_id": GriptapeCloudObservabilityDriver.format_trace_id(span.context.trace_id),
                        "span_id": GriptapeCloudObservabilityDriver.format_span_id(span.context.span_id),
                        "parent_id": GriptapeCloudObservabilityDriver.format_span_id(span.parent.span_id)
                        if span.parent
                        else None,
                        "name": span.name,
                        "start_time": opentelemetry_util.ns_to_iso_str(span.start_time) if span.start_time else None,
                        "end_time": opentelemetry_util.ns_to_iso_str(span.end_time) if span.end_time else None,
                        "status": span.status.status_code.name,
                        "attributes": {**span.attributes} if span.attributes else {},
                        "events": [
                            {
                                "name": event.name,
                                "timestamp": opentelemetry_util.ns_to_iso_str(event.timestamp)
                                if event.timestamp
                                else None,
                                "attributes": {**event.attributes} if event.attributes else {},
                            }
                            for event in span.events
                        ],
                    }
                    for span in spans
                    if span.context is not None
                ]
                response = requests.post(url=url, json=payload, headers=self.headers)
                return (
                    opentelemetry_trace_export.SpanExportResult.SUCCESS
                    if response.status_code == 200
                    else opentelemetry_trace_export.SpanExportResult.FAILURE
                )

        return SpanExporter(
            base_url=base_url,
            api_key=api_key,
            headers=headers,
            structure_run_id=structure_run_id,
        )

    def get_span_id(self) -> Optional[str]:
        """Return the current OpenTelemetry span id as a UUID string, or None when no span is active."""
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        span = opentelemetry_trace.get_current_span()
        if span is opentelemetry_trace.INVALID_SPAN:
            return None
        return GriptapeCloudObservabilityDriver.format_span_id(span.get_span_context().span_id)
  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']), kw_only=True) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')), kw_only=True) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • span_processor = field(default=Factory(lambda self: import_optional_dependency('opentelemetry.sdk.trace.export').BatchSpanProcessor(GriptapeCloudObservabilityDriver.build_span_exporter(base_url=self.base_url, api_key=self.api_key, headers=self.headers, structure_run_id=self.structure_run_id)), takes_self=True), kw_only=True) class-attribute instance-attribute

  • structure_run_id = field(default=Factory(lambda: os.getenv('GT_CLOUD_STRUCTURE_RUN_ID')), kw_only=True) class-attribute instance-attribute

build_span_exporter(base_url, api_key, headers, structure_run_id) (staticmethod)

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
@staticmethod
def build_span_exporter(base_url: str, api_key: str, headers: dict, structure_run_id: str) -> SpanExporter:
    """Create a span exporter that POSTs spans to a Griptape Cloud structure run.

    Because the exporter class is defined here, its OpenTelemetry base class is only resolved
    (via ``import_optional_dependency``) when an exporter is actually requested.
    """
    trace_export_module = import_optional_dependency("opentelemetry.sdk.trace.export")

    @define
    class SpanExporter(trace_export_module.SpanExporter):
        base_url: str = field(kw_only=True)
        api_key: str = field(kw_only=True)
        headers: dict = field(kw_only=True)
        structure_run_id: str = field(kw_only=True)

        def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
            """Serialize ``spans`` and send them to the run's span endpoint in one request.

            Spans lacking a context are dropped. The result is SUCCESS for an HTTP 200
            response and FAILURE for any other status.
            """
            sdk_util = import_optional_dependency("opentelemetry.sdk.util")
            export_api = import_optional_dependency("opentelemetry.sdk.trace.export")

            def as_iso(nanos):
                # Unset (falsy) timestamps are reported as None.
                return sdk_util.ns_to_iso_str(nanos) if nanos else None

            def serialize_event(evt):
                return {
                    "name": evt.name,
                    "timestamp": as_iso(evt.timestamp),
                    "attributes": {**evt.attributes} if evt.attributes else {},
                }

            def serialize_span(item):
                ctx = item.context
                parent = item.parent
                return {
                    "trace_id": GriptapeCloudObservabilityDriver.format_trace_id(ctx.trace_id),
                    "span_id": GriptapeCloudObservabilityDriver.format_span_id(ctx.span_id),
                    "parent_id": GriptapeCloudObservabilityDriver.format_span_id(parent.span_id) if parent else None,
                    "name": item.name,
                    "start_time": as_iso(item.start_time),
                    "end_time": as_iso(item.end_time),
                    "status": item.status.status_code.name,
                    "attributes": {**item.attributes} if item.attributes else {},
                    "events": [serialize_event(evt) for evt in item.events],
                }

            endpoint = griptape_cloud_url(self.base_url, f"api/structure-runs/{self.structure_run_id}/spans")
            body = [serialize_span(item) for item in spans if item.context is not None]
            reply = requests.post(url=endpoint, json=body, headers=self.headers)

            if reply.status_code == 200:
                return export_api.SpanExportResult.SUCCESS
            return export_api.SpanExportResult.FAILURE

    return SpanExporter(
        base_url=base_url,
        api_key=api_key,
        headers=headers,
        structure_run_id=structure_run_id,
    )

format_span_id(span_id) (staticmethod)

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
@staticmethod
def format_span_id(span_id: int) -> str:
    """Render an integer span id as a canonical, zero-padded UUID string."""
    as_uuid = UUID(int=span_id)
    return str(as_uuid)

format_trace_id(trace_id) (staticmethod)

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
@staticmethod
def format_trace_id(trace_id: int) -> str:
    """Render an integer trace id as a canonical UUID string."""
    as_uuid = UUID(int=trace_id)
    return str(as_uuid)

get_span_id()

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
def get_span_id(self) -> Optional[str]:
    """Return the active OpenTelemetry span's id as a UUID string, or None if no span is active."""
    trace_api = import_optional_dependency("opentelemetry.trace")
    current = trace_api.get_current_span()
    if current is not trace_api.INVALID_SPAN:
        return GriptapeCloudObservabilityDriver.format_span_id(current.get_span_context().span_id)
    return None

validate_run_id(_, structure_run_id)

Source Code in griptape/drivers/observability/griptape_cloud_observability_driver.py
@structure_run_id.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_run_id(self, _: Attribute, structure_run_id: Optional[str]) -> None:
    """Reject a missing structure run id.

    Raises:
        ValueError: If ``structure_run_id`` is None.
    """
    if structure_run_id is not None:
        return
    raise ValueError(
        "structure_run_id must be set either in the constructor or as an environment variable (GT_CLOUD_STRUCTURE_RUN_ID)."
    )

GriptapeCloudPromptDriver

Bases: BasePromptDriver

Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@define
class GriptapeCloudPromptDriver(BasePromptDriver):
    """Prompt driver that sends prompt stacks to the Griptape Cloud chat API.

    Blocking requests go to ``api/chat/messages`` and streamed requests to
    ``api/chat/messages/stream`` under ``base_url``. Both are authenticated with a bearer
    token built from ``api_key``.
    """

    # Optional model override; only sent to the API when truthy (see _base_params).
    model: Optional[str] = field(default=None, kw_only=True)
    # Defaults to $GT_CLOUD_BASE_URL, falling back to https://cloud.griptape.ai.
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    # Defaults to $GT_CLOUD_API_KEY; os.environ[...] raises KeyError when it is unset and no key is passed.
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Bearer-token auth header derived from api_key.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True
    )
    # Character-count tokenizer (4 characters per token) with a fixed 2000-token input budget.
    # The output budget follows self.max_tokens, which is presumably inherited from BasePromptDriver.
    # NOTE(review): the 2000 input-token limit is hard-coded; confirm it matches the service.
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: SimpleTokenizer(
                characters_per_token=4,
                max_input_tokens=2000,
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True)
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """POST the prompt stack to the chat endpoint and deserialize the reply into a Message.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        url = griptape_cloud_url(self.base_url, "api/chat/messages")

        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = requests.post(url, headers=self.headers, json=params)
        response.raise_for_status()
        response_json = response.json()
        logger.debug(response_json)

        return Message.from_dict(response_json)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """POST the prompt stack to the streaming endpoint and yield a DeltaMessage per ``data:`` line.

        Empty lines and lines without the ``data:`` prefix are skipped.

        Raises:
            requests.HTTPError: If the API responds with an error status.
        """
        url = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
        params = self._base_params(prompt_stack)
        logger.debug(params)
        with requests.post(url, headers=self.headers, json=params, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data:"):
                        delta_message_payload = decoded_line.removeprefix("data:").strip()
                        logger.debug(delta_message_payload)
                        yield DeltaMessage.from_json(delta_message_payload)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the JSON request body shared by try_run and try_stream.

        ``output_schema`` is included only when the prompt stack defines one, and ``model`` only
        when ``self.model`` is truthy.
        """
        return {
            "messages": prompt_stack.to_dict()["messages"],
            "tools": self.__to_griptape_tools(prompt_stack.tools),
            **({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}),
            "driver_configuration": {
                **({"model": self.model} if self.model else {}),
                "max_tokens": self.max_tokens,
                "use_native_tools": self.use_native_tools,
                "temperature": self.temperature,
                "structured_output_strategy": self.structured_output_strategy,
                "extra_params": self.extra_params,
            },
        }

    # Double-underscore name: Python mangles it to _GriptapeCloudPromptDriver__to_griptape_tools.
    def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Serialize tools into a list of {name, activities} dicts.

        Each activity is described by its function ``__name__``, the tool's description of it,
        and its JSON schema (generated with schema name ``"Schema"``).
        """
        return [
            {
                "name": tool.name,
                "activities": [
                    {
                        "name": activity.__name__,
                        "description": tool.activity_description(activity),
                        "json_schema": tool.to_activity_json_schema(activity, "Schema"),
                    }
                    for activity in tool.activities()
                ],
            }
            for tool in tools
        ]
  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • model = field(default=None, kw_only=True) class-attribute instance-attribute

  • structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True) class-attribute instance-attribute

__to_griptape_tools(tools)

Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Serialize tools into the ``{"name", "activities"}`` structure the Griptape Cloud API expects.

    Every activity is described by its function name, the tool's description of it, and the
    JSON schema the tool generates for it under the schema name ``"Schema"``.
    """
    serialized: list[dict] = []
    for tool in tools:
        tool_name = tool.name
        activity_specs = [
            {
                "name": act.__name__,
                "description": tool.activity_description(act),
                "json_schema": tool.to_activity_json_schema(act, "Schema"),
            }
            for act in tool.activities()
        ]
        serialized.append({"name": tool_name, "activities": activity_specs})
    return serialized

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Assemble the JSON body sent to the Griptape Cloud chat endpoints.

    ``output_schema`` is only added when the prompt stack declares one, and the driver
    configuration only carries ``model`` when ``self.model`` is truthy.
    """
    params: dict = {
        "messages": prompt_stack.to_dict()["messages"],
        "tools": self.__to_griptape_tools(prompt_stack.tools),
    }
    if prompt_stack.output_schema:
        params["output_schema"] = prompt_stack.to_output_json_schema()

    driver_configuration: dict = {}
    if self.model:
        driver_configuration["model"] = self.model
    driver_configuration["max_tokens"] = self.max_tokens
    driver_configuration["use_native_tools"] = self.use_native_tools
    driver_configuration["temperature"] = self.temperature
    driver_configuration["structured_output_strategy"] = self.structured_output_strategy
    driver_configuration["extra_params"] = self.extra_params

    params["driver_configuration"] = driver_configuration
    return params

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack to the Griptape Cloud chat endpoint and return the reply as a Message.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    endpoint = griptape_cloud_url(self.base_url, "api/chat/messages")

    request_body = self._base_params(prompt_stack)
    logger.debug(request_body)

    reply = requests.post(endpoint, headers=self.headers, json=request_body)
    reply.raise_for_status()
    reply_payload = reply.json()
    logger.debug(reply_payload)
    return Message.from_dict(reply_payload)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a chat completion, yielding one DeltaMessage per ``data:`` line from the server.

    Blank lines and lines without the ``data:`` prefix are ignored.

    Raises:
        requests.HTTPError: If the API responds with an error status.
    """
    endpoint = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
    request_body = self._base_params(prompt_stack)
    logger.debug(request_body)
    with requests.post(endpoint, headers=self.headers, json=request_body, stream=True) as stream:
        stream.raise_for_status()
        for raw_line in stream.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data:"):
                continue
            chunk = text.removeprefix("data:").strip()
            logger.debug(chunk)
            yield DeltaMessage.from_json(chunk)

GriptapeCloudRulesetDriver

Bases: BaseRulesetDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| ruleset_id | Optional[str] | The ID of the ruleset to load. If not provided, or if no ruleset with this ID can be fetched, the driver looks the ruleset up by using the ruleset name as an alias. |
| base_url | str | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable GT_CLOUD_BASE_URL or https://cloud.griptape.ai. |
| api_key | Optional[str] | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable GT_CLOUD_API_KEY. |

Raises

ValueError: If `api_key` is not provided (neither passed in nor set via GT_CLOUD_API_KEY).

Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
@define(kw_only=True)
class GriptapeCloudRulesetDriver(BaseRulesetDriver):
    """A driver for loading rulesets and their rules from the Gen AI Builder.

    Attributes:
        ruleset_id: The ID of the ruleset to load. If not provided, or if no ruleset with this ID can be fetched,
            the driver looks the ruleset up by using the ruleset name passed to `load` as an alias.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will
            attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is not provided.
    """

    ruleset_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
    # Bearer-token header derived from api_key; not accepted by __init__ (init=False).
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        """Raise ValueError when no API key was passed in or found in the environment."""
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
        """Load the ruleset from Gen AI Builder, using the ruleset name as an alias if ruleset_id is not provided.

        Args:
            ruleset_name: Name of the ruleset. It is used as the lookup alias when `ruleset_id` is unset
                or its lookup does not return HTTP 200.

        Returns:
            A tuple of the ruleset's rules and the ruleset's metadata. Each rule's metadata is tagged with
            its `griptape_cloud_rule_id`. Returns `([], {})` when nothing is found and `raise_not_found`
            is False.

        Raises:
            ValueError: If no ruleset is found and `raise_not_found` is True.
            requests.HTTPError: If the alias lookup or the rules lookup returns an error status.
        """
        # Local import keeps this fix self-contained without touching the module import block.
        from urllib.parse import quote

        ruleset = None

        if self.ruleset_id is not None:
            # A failed ID lookup is tolerated (no raise) so the alias lookup below can take over.
            res = self._call_api("get", f"/rulesets/{self.ruleset_id}", raise_for_status=False)
            if res.status_code == 200:
                ruleset = res.json()

        # Use the name as an 'alias' to get the ruleset. Percent-encode it so characters such as
        # "&", "#", "+" or "=" in the name cannot truncate or alter the query string.
        if ruleset is None:
            res = self._call_api("get", f"/rulesets?alias={quote(ruleset_name, safe='')}").json()
            if res.get("rulesets"):
                ruleset = res["rulesets"][0]

        # no ruleset by name or ruleset_id
        if ruleset is None:
            if self.raise_not_found:
                raise ValueError(f"No ruleset found with alias: {ruleset_name} or ruleset_id: {self.ruleset_id}")
            return [], {}

        rules = self._call_api("get", f"/rules?ruleset_id={ruleset['ruleset_id']}").json().get("rules", [])

        for rule in rules:
            rule["metadata"] = dict_merge(rule.get("metadata", {}), {"griptape_cloud_rule_id": rule["rule_id"]})

        return [self._get_rule(rule["rule"], rule["metadata"]) for rule in rules], ruleset.get("metadata", {})

    def _get_url(self, path: str) -> str:
        """Return the full API URL for `path`, with leading slashes stripped and `api/` prefixed."""
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(self, method: str, path: str, *, raise_for_status: bool = True) -> requests.Response:
        """Send an authenticated request and, by default, raise on HTTP error statuses."""
        res = requests.request(method, self._get_url(path), headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
  • api_key = field(default=Factory(lambda: os.getenv('GT_CLOUD_API_KEY'))) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False) class-attribute instance-attribute

  • ruleset_id = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

_call_api(method, path, *, raise_for_status=True)

Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
def _call_api(self, method: str, path: str, *, raise_for_status: bool = True) -> requests.Response:
    """Issue an authenticated request against the Gen AI Builder API.

    Args:
        method: HTTP method name forwarded to ``requests.request``.
        path: API path, resolved to a full URL through ``self._get_url``.
        raise_for_status: When True (the default), error statuses raise ``requests.HTTPError``.

    Returns:
        The raw ``requests.Response``.
    """
    response = requests.request(method, self._get_url(path), headers=self.headers)
    if not raise_for_status:
        return response
    response.raise_for_status()
    return response

_get_url(path)

Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
def _get_url(self, path: str) -> str:
    """Build the full URL for ``path`` beneath the ``api/`` prefix of ``base_url``.

    Leading slashes are stripped, so ``"/rules"`` and ``"rules"`` resolve to the same URL.
    """
    relative_path = "api/" + path.lstrip("/")
    return griptape_cloud_url(self.base_url, relative_path)

load(ruleset_name)

Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
    """Load the ruleset from Gen AI Builder, using the ruleset name as an alias if ruleset_id is not provided.

    Args:
        ruleset_name: Name of the ruleset. It is used as the lookup alias when `ruleset_id` is unset
            or its lookup does not return HTTP 200.

    Returns:
        A tuple of the ruleset's rules and the ruleset's metadata. Each rule's metadata is tagged with
        its `griptape_cloud_rule_id`. Returns `([], {})` when nothing is found and `raise_not_found`
        is False.

    Raises:
        ValueError: If no ruleset is found and `raise_not_found` is True.
        requests.HTTPError: If the alias lookup or the rules lookup returns an error status.
    """
    # Local import keeps this fix self-contained without touching the module import block.
    from urllib.parse import quote

    ruleset = None

    if self.ruleset_id is not None:
        # A failed ID lookup is tolerated (no raise) so the alias lookup below can take over.
        res = self._call_api("get", f"/rulesets/{self.ruleset_id}", raise_for_status=False)
        if res.status_code == 200:
            ruleset = res.json()

    # Use the name as an 'alias' to get the ruleset. Percent-encode it so characters such as
    # "&", "#", "+" or "=" in the name cannot truncate or alter the query string.
    if ruleset is None:
        res = self._call_api("get", f"/rulesets?alias={quote(ruleset_name, safe='')}").json()
        if res.get("rulesets"):
            ruleset = res["rulesets"][0]

    # no ruleset by name or ruleset_id
    if ruleset is None:
        if self.raise_not_found:
            raise ValueError(f"No ruleset found with alias: {ruleset_name} or ruleset_id: {self.ruleset_id}")
        return [], {}

    rules = self._call_api("get", f"/rules?ruleset_id={ruleset['ruleset_id']}").json().get("rules", [])

    for rule in rules:
        rule["metadata"] = dict_merge(rule.get("metadata", {}), {"griptape_cloud_rule_id": rule["rule_id"]})

    return [self._get_rule(rule["rule"], rule["metadata"]) for rule in rules], ruleset.get("metadata", {})

validate_api_key(_, value)

Source Code in griptape/drivers/ruleset/griptape_cloud_ruleset_driver.py
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
    """Ensure an API key is present.

    Raises:
        ValueError: If ``value`` is None.
    """
    if value is not None:
        return value
    raise ValueError(f"{self.__class__.__name__} requires an API key")

GriptapeCloudStructureRunDriver

Bases: BaseStructureRunDriver

Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
@define
class GriptapeCloudStructureRunDriver(BaseStructureRunDriver):
    """Structure-run driver that executes a structure hosted on Griptape Cloud.

    ``try_run`` creates a run via ``api/structures/{structure_id}/runs``. Unless ``async_run`` is set,
    it then reads the run's event stream until the stream ends and returns the resulting artifact.
    """

    # Hard-coded default; unlike the other Griptape Cloud drivers on this page, GT_CLOUD_BASE_URL is not consulted.
    base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
    # Required; there is no environment-variable fallback.
    api_key: str = field(kw_only=True)
    # Bearer-token auth header derived from api_key.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    # Required ID of the structure to run.
    structure_id: str = field(kw_only=True)
    # NOTE(review): the two polling settings below are not referenced by any method in this class.
    # They may be vestigial now that results are read from the event stream; confirm before relying on them.
    structure_run_wait_time_interval: int = field(default=2, kw_only=True)
    structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)
    # When True, try_run returns right after the run is created instead of waiting for its output.
    async_run: bool = field(default=False, kw_only=True)

    def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
        """Start a run with ``args``; return an acknowledgement if async, else the run's output artifact."""
        structure_run_id = self._create_run(*args)

        if self.async_run:
            return InfoArtifact("Run started successfully")
        return self._get_run_result(structure_run_id)

    def _create_run(self, *args: BaseArtifact) -> str:
        """Create a structure run and return its ``structure_run_id``.

        Each artifact's ``value`` is sent as a run argument, and every ``self.env`` entry
        (iterated with ``.items()``) is sent as a manually-sourced environment variable.

        Raises:
            requests.HTTPError: If the API rejects the request.
        """
        url = griptape_cloud_url(self.base_url, f"api/structures/{self.structure_id}/runs")

        env_vars = [{"name": key, "value": value, "source": "manual"} for key, value in self.env.items()]

        response = requests.post(
            url,
            json={"args": [arg.value for arg in args], "env_vars": env_vars},
            headers=self.headers,
        )
        response.raise_for_status()
        response_json = response.json()

        return response_json["structure_run_id"]

    def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
        """Consume the run's event stream and return its final output.

        USER-origin events are republished on the local EventBus. Any top-level ``span_id`` is moved
        under ``meta`` first, and payloads that raise ValueError during deserialization are logged
        and skipped. A ``FinishStructureRunEvent`` sets the output artifact, and a SYSTEM
        ``StructureRunError`` sets an ErrorArtifact. The last such event in the stream wins.

        Raises:
            ValueError: If the stream ends without producing any output.
        """
        events = self._get_run_events(structure_run_id)
        output = None

        for event in events:
            event_type = event["type"]
            event_payload = event.get("payload", {})
            if event["origin"] == "USER":
                try:
                    if "span_id" in event_payload:
                        span_id = event_payload.pop("span_id")
                        if "meta" in event_payload:
                            event_payload["meta"]["span_id"] = span_id
                        else:
                            event_payload["meta"] = {"span_id": span_id}
                    EventBus.publish_event(BaseEvent.from_dict(event_payload))
                except ValueError as e:
                    logger.warning("Failed to deserialize event: %s", e)
                if event["type"] == "FinishStructureRunEvent":
                    output = BaseArtifact.from_dict(event_payload["output_task_output"])
            elif event["origin"] == "SYSTEM":
                if event_type == "StructureRunError":
                    output = ErrorArtifact(event_payload["status_detail"]["error"])

        if output is None:
            raise ValueError("Output not found.")

        return output

    def _get_run_events(self, structure_run_id: str) -> Iterator[dict]:
        """Yield the decoded JSON payload of each ``data:`` line in the run's event stream.

        Raises:
            requests.HTTPError: If the stream request returns an error status.
        """
        url = griptape_cloud_url(self.base_url, f"api/structure-runs/{structure_run_id}/events/stream")
        with requests.get(url, headers=self.headers, stream=True) as response:
            response.raise_for_status()
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode("utf-8")
                    if decoded_line.startswith("data:"):
                        yield json.loads(decoded_line.removeprefix("data:").strip())
  • api_key = field(kw_only=True) class-attribute instance-attribute

  • async_run = field(default=False, kw_only=True) class-attribute instance-attribute

  • base_url = field(default='https://cloud.griptape.ai', kw_only=True) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • structure_id = field(kw_only=True) class-attribute instance-attribute

  • structure_run_max_wait_time_attempts = field(default=20, kw_only=True) class-attribute instance-attribute

  • structure_run_wait_time_interval = field(default=2, kw_only=True) class-attribute instance-attribute

_create_run(*args)

Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _create_run(self, *args: BaseArtifact) -> str:
    """Start a new run of the configured structure and return its ID.

    Each artifact's ``value`` becomes a run argument, and every entry of ``self.env`` is
    forwarded as an environment variable with source ``"manual"``.

    Raises:
        requests.HTTPError: If the API rejects the request.
    """
    endpoint = griptape_cloud_url(self.base_url, f"api/structures/{self.structure_id}/runs")

    run_env = []
    for var_name, var_value in self.env.items():
        run_env.append({"name": var_name, "value": var_value, "source": "manual"})

    run_args = [artifact.value for artifact in args]
    created = requests.post(
        endpoint,
        json={"args": run_args, "env_vars": run_env},
        headers=self.headers,
    )
    created.raise_for_status()
    return created.json()["structure_run_id"]

_get_run_events(structure_run_id)

Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_events(self, structure_run_id: str) -> Iterator[dict]:
    """Yield the JSON payloads streamed for a structure run's events.

    Only non-empty lines beginning with ``data:`` are decoded; all other lines are skipped.

    Raises:
        requests.HTTPError: If the event stream request fails.
    """
    stream_url = griptape_cloud_url(self.base_url, f"api/structure-runs/{structure_run_id}/events/stream")
    with requests.get(stream_url, headers=self.headers, stream=True) as stream:
        stream.raise_for_status()
        for raw_line in stream.iter_lines():
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            if not text.startswith("data:"):
                continue
            yield json.loads(text.removeprefix("data:").strip())

_get_run_result(structure_run_id)

Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
    """Consume a run's event stream and return the run's final output.

    USER-origin events are re-published on the local ``EventBus``. Before
    publishing, any top-level ``span_id`` in the payload is moved under
    ``meta``. A USER ``FinishStructureRunEvent`` supplies the output artifact,
    and a SYSTEM ``StructureRunError`` supplies an ``ErrorArtifact``. When
    several events set the output, the last one wins.

    Raises:
        ValueError: If the stream ends without producing any output.
    """
    result = None

    for event in self._get_run_events(structure_run_id):
        event_type = event["type"]
        payload = event.get("payload", {})
        origin = event["origin"]

        if origin == "USER":
            try:
                if "span_id" in payload:
                    # Relocate span_id into "meta" (presumably where BaseEvent.from_dict expects it).
                    payload.setdefault("meta", {})["span_id"] = payload.pop("span_id")
                EventBus.publish_event(BaseEvent.from_dict(payload))
            except ValueError as e:
                # Best-effort: an event that fails to deserialize is logged, not fatal.
                logger.warning("Failed to deserialize event: %s", e)
            if event_type == "FinishStructureRunEvent":
                result = BaseArtifact.from_dict(payload["output_task_output"])
        elif origin == "SYSTEM" and event_type == "StructureRunError":
            result = ErrorArtifact(payload["status_detail"]["error"])

    if result is None:
        raise ValueError("Output not found.")

    return result

try_run(*args)

Source Code in griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
    """Start a structure run and, unless ``async_run`` is set, return its output.

    Returns:
        The artifact produced by the run when ``async_run`` is falsy. Otherwise
        an ``InfoArtifact`` acknowledging that the run was started.
    """
    run_id = self._create_run(*args)
    if not self.async_run:
        return self._get_run_result(run_id)
    return InfoArtifact("Run started successfully")

GriptapeCloudVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `str` | API Key for Gen AI Builder. |
| `knowledge_base_id` | `str` | Knowledge Base ID for Gen AI Builder. |
| `base_url` | `str` | Base URL for Gen AI Builder. |
| `headers` | `dict` | Headers for Gen AI Builder. |
Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
@define
class GriptapeCloudVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver for Gen AI Builder Knowledge Bases.

    Only querying is supported. Every upsert, load and delete operation raises
    ``NotImplementedError``.

    Attributes:
        api_key: API Key for Gen AI Builder. Defaults to the ``GT_CLOUD_API_KEY`` environment variable.
        knowledge_base_id: Knowledge Base ID for Gen AI Builder.
        base_url: Base URL for Gen AI Builder. Defaults to ``GT_CLOUD_BASE_URL`` or ``https://cloud.griptape.ai``.
        headers: Headers for Gen AI Builder. Defaults to a bearer ``Authorization`` header built from ``api_key``.
        embedding_driver: Fixed ``DummyEmbeddingDriver`` that cannot be set via ``__init__``. Queries are sent
            to the service as text, so this driver does no local embedding.
    """

    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    knowledge_base_id: str = field(kw_only=True, metadata={"serializable": True})
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    embedding_driver: BaseEmbeddingDriver = field(
        default=Factory(lambda: DummyEmbeddingDriver()),
        metadata={"serializable": True},
        kw_only=True,
        init=False,
    )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
        raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: Optional[bool] = None,
        distance_metric: Optional[str] = None,
        # GriptapeCloudVectorStoreDriver-specific params:
        filter: Optional[dict] = None,  # noqa: A002
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a query on the Knowledge Base.

        Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s).

        Only non-``None`` values among ``count``, ``distance_metric``, ``filter`` and ``include_vectors``
        are sent as ``query_args``. ``namespace`` and extra ``kwargs`` are accepted for interface
        compatibility but ignored.

        Raises:
            ValueError: If ``query`` is an ``ImageArtifact``.
            requests.HTTPError: If the Knowledge Base query request returns an error status.
        """
        if isinstance(query, ImageArtifact):
            raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")
        url = griptape_cloud_url(self.base_url, f"api/knowledge-bases/{self.knowledge_base_id}/query")

        query_args = {
            "count": count,
            "distance_metric": distance_metric,
            "filter": filter,
            "include_vectors": include_vectors,
        }
        query_args = {k: v for k, v in query_args.items() if v is not None}

        request: dict[str, Any] = {
            "query": str(query),
            "query_args": query_args,
        }

        response = requests.post(url, json=request, headers=self.headers)
        # Surface HTTP failures instead of silently treating an error body as "no entries".
        response.raise_for_status()
        entries = response.json().get("entries", [])
        return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries]

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • embedding_driver = field(default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}, kw_only=True, init=False) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • knowledge_base_id = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Always raise: this driver does not support deleting vectors.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support deletion."
    raise NotImplementedError(message)

load_artifacts(*, namespace=None)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
    """Always raise: this driver does not support loading artifacts.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support Artifact loading."
    raise NotImplementedError(message)

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Always raise: this driver does not support loading entries.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support entry loading."
    raise NotImplementedError(message)

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    """Always raise: this driver does not support loading a single entry.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support entry loading."
    raise NotImplementedError(message)

query(query, *, count=None, namespace=None, include_vectors=None, distance_metric=None, filter=None, **kwargs)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: Optional[bool] = None,
    distance_metric: Optional[str] = None,
    # GriptapeCloudVectorStoreDriver-specific params:
    filter: Optional[dict] = None,  # noqa: A002
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Performs a query on the Knowledge Base.

    Performs a query on the Knowledge Base and returns Artifacts with close vector proximity to the query, optionally filtering to only those that match the provided filter(s).

    Only non-``None`` values among ``count``, ``distance_metric``, ``filter`` and ``include_vectors``
    are sent as ``query_args``. ``namespace`` and extra ``kwargs`` are accepted for interface
    compatibility but ignored.

    Raises:
        ValueError: If ``query`` is an ``ImageArtifact``.
        requests.HTTPError: If the Knowledge Base query request returns an error status.
    """
    if isinstance(query, ImageArtifact):
        raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")
    url = griptape_cloud_url(self.base_url, f"api/knowledge-bases/{self.knowledge_base_id}/query")

    query_args = {
        "count": count,
        "distance_metric": distance_metric,
        "filter": filter,
        "include_vectors": include_vectors,
    }
    query_args = {k: v for k, v in query_args.items() if v is not None}

    request: dict[str, Any] = {
        "query": str(query),
        "query_args": query_args,
    }

    response = requests.post(url, json=request, headers=self.headers)
    # Surface HTTP failures instead of silently treating an error body as "no entries".
    response.raise_for_status()
    entries = response.json().get("entries", [])
    return [BaseVectorStoreDriver.Entry.from_dict(entry) for entry in entries]

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Always raise: this driver does not support upserting text.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support text upsert."
    raise NotImplementedError(message)

upsert_text_artifact(artifact, namespace=None, meta=None, vector_id=None, **kwargs)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_text_artifact(
    self,
    artifact: TextArtifact,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    """Always raise: this driver does not support upserting text artifacts.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support text artifact upsert."
    raise NotImplementedError(message)

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/griptape_cloud_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Always raise: this driver does not support upserting raw vectors.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support vector upsert."
    raise NotImplementedError(message)

GrokPromptDriver

Bases: OpenAiChatPromptDriver

Source Code in griptape/drivers/prompt/grok_prompt_driver.py
@define
class GrokPromptDriver(OpenAiChatPromptDriver):
    """Prompt driver for xAI Grok models, built on the OpenAI chat prompt driver.

    Attributes:
        base_url: API base URL. Defaults to ``https://api.x.ai/v1``.
        tokenizer: A ``GrokTokenizer`` constructed from this driver's ``model``, ``api_key`` and ``base_url``.
    """

    base_url: str = field(default="https://api.x.ai/v1", kw_only=True, metadata={"serializable": True})
    tokenizer: GrokTokenizer = field(
        default=Factory(
            lambda self: GrokTokenizer(
                model=self.model,
                api_key=self.api_key,
                base_url=self.base_url,
            ),
            takes_self=True,
        ),
        kw_only=True,
        metadata={"serializable": True},
    )
  • base_url = field(default='https://api.x.ai/v1', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: GrokTokenizer(base_url=self.base_url, api_key=self.api_key, model=self.model), takes_self=True), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

HuggingFaceHubEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_token` | `str` | Hugging Face Hub API token. |
| `model` | `str` | Hugging Face Hub model name. |
| `client` | `InferenceClient` | Custom `InferenceClient`. |
Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@define
class HuggingFaceHubEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver backed by a Hugging Face Hub ``InferenceClient``.

    Attributes:
        api_token: Hugging Face Hub API token.
        model: Hugging Face Hub model name.
        client: Custom `InferenceClient`. If none is supplied, the ``lazy_property`` builds one
            from ``model`` and ``api_token`` (presumably on first access).
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[InferenceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> InferenceClient:
        hub = import_optional_dependency("huggingface_hub")
        return hub.InferenceClient(model=self.model, token=self.api_token)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        # Flatten the feature array into a single vector of Python floats.
        features = self.client.feature_extraction(chunk)
        return list(map(float, features.flatten().tolist()))
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_token = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
@lazy_property()
def client(self) -> InferenceClient:
    """Build a ``huggingface_hub.InferenceClient`` for ``model``, authenticated with ``api_token``."""
    hub = import_optional_dependency("huggingface_hub")
    return hub.InferenceClient(model=self.model, token=self.api_token)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/huggingface_hub_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed ``chunk`` via ``client.feature_extraction`` and return a flat list of floats.

    The array returned by the client is flattened, so a multi-dimensional
    result becomes a single vector.
    """
    features = self.client.feature_extraction(chunk)
    return list(map(float, features.flatten().tolist()))

HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_token` | `str` | Hugging Face Hub API token. |
| `use_gpu` | `str` | Use GPU during model run. |
| `model` | `str` | Hugging Face Hub model name. |
| `client` | `InferenceClient` | Custom `InferenceClient`. |
| `tokenizer` | `HuggingFaceTokenizer` | Custom `HuggingFaceTokenizer`. |
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """Hugging Face Hub Prompt Driver.

    Sends prompts rendered with the model's chat template to
    ``InferenceClient.text_generation``.

    Attributes:
        api_token: Hugging Face Hub API token.
        max_tokens: Maximum number of new tokens to generate, sent as ``max_new_tokens``.
        model: Hugging Face Hub model name.
        structured_output_strategy: ``"tool"`` is rejected. ``"native"`` (the default) constrains
            generation with a JSON grammar built from the prompt stack's output schema.
        client: Custom `InferenceClient`.
        tokenizer: Custom `HuggingFaceTokenizer`.
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    _client: Optional[InferenceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> InferenceClient:
        """Build an ``InferenceClient`` for ``model``, authenticated with ``api_token``."""
        return import_optional_dependency("huggingface_hub").InferenceClient(
            model=self.model,
            token=self.api_token,
        )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject the ``"tool"`` structured output strategy.

        Raises:
            ValueError: If ``value`` is ``"tool"``.
        """
        if value == "tool":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Generate a complete response and report token usage.

        Input tokens are counted from the chat-templated prompt stack. Output
        tokens are counted by re-encoding the generated text.
        """
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = self._base_params(prompt_stack)
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(
            prompt,
            **full_params,
        )
        logger.debug(response)
        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(response))  # pyright: ignore[reportArgumentType]

        return Message(
            content=response,
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream the response token by token, then yield a final usage ``DeltaMessage``."""
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(prompt, **full_params)

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))

        full_text = ""
        for token in response:
            logger.debug(token)
            full_text += token
            yield DeltaMessage(content=TextDeltaMessageContent(token, index=0))

        output_tokens = len(self.tokenizer.tokenizer.encode(full_text))  # pyright: ignore[reportArgumentType]
        yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens))

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Render the prompt stack through the chat template and decode it back to text."""
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments passed to ``text_generation``.

        ``extra_params`` are merged last, so they override the defaults. When the
        prompt stack has an output schema and the strategy is ``"native"``, a
        JSON ``grammar`` constraint is added.
        """
        params = {
            "return_full_text": False,
            "max_new_tokens": self.max_tokens,
            **self.extra_params,
        }

        if prompt_stack.output_schema and self.structured_output_strategy == "native":
            # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
            output_schema = prompt_stack.to_output_json_schema()
            # Grammar does not support $schema and $id. Drop them if present rather than
            # raising KeyError for schemas that omit either key.
            for unsupported_key in ("$schema", "$id"):
                output_schema.pop(unsupported_key, None)
            params["grammar"] = {"type": "json", "value": output_schema}

        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert the prompt stack into ``{"role", "content"}`` dicts for the chat template.

        Raises:
            ValueError: If any message does not have exactly one content item.
        """
        messages = []
        for message in prompt_stack.messages:
            if len(message.content) == 1:
                messages.append({"role": message.role, "content": message.to_text()})
            else:
                raise ValueError("Invalid input content length.")

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        """Apply the tokenizer's chat template (with generation prompt) and return token IDs.

        Raises:
            ValueError: If the chat template does not return a list.
        """
        messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        raise ValueError("Invalid output type.")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_token = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • max_tokens = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

__prompt_stack_to_tokens(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    """Render the prompt stack with the tokenizer's chat template and return token IDs.

    The template is applied with ``add_generation_prompt=True`` and ``tokenize=True``.

    Raises:
        ValueError: If the chat template does not return a list.
    """
    chat = self._prompt_stack_to_messages(prompt_stack)
    token_ids = self.tokenizer.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=True)
    if not isinstance(token_ids, list):
        raise ValueError("Invalid output type.")
    # Per the transformers docs for apply_chat_template, a tokenized result is List[int].
    return token_ids  # pyright: ignore[reportReturnType]

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments passed to ``InferenceClient.text_generation``.

    ``extra_params`` are merged last, so they override the defaults. When the
    prompt stack has an output schema and the strategy is ``"native"``, a JSON
    ``grammar`` constraint is added.
    """
    params = {
        "return_full_text": False,
        "max_new_tokens": self.max_tokens,
        **self.extra_params,
    }

    if prompt_stack.output_schema and self.structured_output_strategy == "native":
        # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
        output_schema = prompt_stack.to_output_json_schema()
        # Grammar does not support $schema and $id. Drop them if present rather than
        # raising KeyError for schemas that omit either key.
        for unsupported_key in ("$schema", "$id"):
            output_schema.pop(unsupported_key, None)
        params["grammar"] = {"type": "json", "value": output_schema}

    return params

_prompt_stack_to_messages(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    """Convert each prompt-stack message into a ``{"role", "content"}`` dict.

    Raises:
        ValueError: If a message does not have exactly one content item.
    """
    converted: list[dict] = []
    for msg in prompt_stack.messages:
        if len(msg.content) != 1:
            raise ValueError("Invalid input content length.")
        converted.append({"role": msg.role, "content": msg.to_text()})
    return converted

client()

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@lazy_property()
def client(self) -> InferenceClient:
    """Build a ``huggingface_hub.InferenceClient`` for ``model``, authenticated with ``api_token``."""
    hf_hub = import_optional_dependency("huggingface_hub")
    return hf_hub.InferenceClient(model=self.model, token=self.api_token)

prompt_stack_to_string(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Return the chat-templated prompt as text by decoding its token IDs."""
    token_ids = self.__prompt_stack_to_tokens(prompt_stack)
    return self.tokenizer.tokenizer.decode(token_ids)  # pyright: ignore[reportArgumentType]

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Generate a complete, non-streaming response for ``prompt_stack``.

    Input tokens are counted from the chat-templated prompt stack. Output
    tokens are counted by re-encoding the generated text.
    """
    prompt_text = self.prompt_stack_to_string(prompt_stack)
    request_params = self._base_params(prompt_stack)
    logger.debug({"prompt": prompt_text, **request_params})

    generated = self.client.text_generation(prompt_text, **request_params)
    logger.debug(generated)

    usage = Message.Usage(
        input_tokens=len(self.__prompt_stack_to_tokens(prompt_stack)),
        output_tokens=len(self.tokenizer.tokenizer.encode(generated)),  # pyright: ignore[reportArgumentType]
    )
    return Message(content=generated, role=Message.ASSISTANT_ROLE, usage=usage)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream the response for ``prompt_stack`` one token at a time.

    Yields one text ``DeltaMessage`` per generated token, then a final
    ``DeltaMessage`` with token usage. Output tokens are counted by re-encoding
    the concatenated streamed text.
    """
    prompt_text = self.prompt_stack_to_string(prompt_stack)
    request_params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug({"prompt": prompt_text, **request_params})

    token_stream = self.client.text_generation(prompt_text, **request_params)
    prompt_token_count = len(self.__prompt_stack_to_tokens(prompt_stack))

    generated_text = ""
    for piece in token_stream:
        logger.debug(piece)
        generated_text += piece
        yield DeltaMessage(content=TextDeltaMessageContent(piece, index=0))

    completion_token_count = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]
    yield DeltaMessage(
        usage=DeltaMessage.Usage(input_tokens=prompt_token_count, output_tokens=completion_token_count)
    )

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Reject the ``"tool"`` structured output strategy, which this driver does not support.

    Returns:
        ``value`` unchanged when it is accepted.

    Raises:
        ValueError: If ``value`` is ``"tool"``.
    """
    if value != "tool":
        return value
    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

HuggingFacePipelineImageGenerationDriver

Bases: BaseImageGenerationDriver, ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `pipeline_driver` | `BaseDiffusionImageGenerationPipelineDriver` | A pipeline image generation model driver typed for the specific pipeline required by the model. |
| `device` | `Optional[str]` | The hardware device used for inference. For example, "cpu", "cuda", or "mps". |
| `output_format` | `str` | The format the generated image is returned in. Defaults to "png". |
Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
@define
class HuggingFacePipelineImageGenerationDriver(BaseImageGenerationDriver, ABC):
    """Image generation driver for models hosted by Hugging Face's Diffusion Pipeline.

    For more information, see the HuggingFace documentation for Diffusers:
            https://huggingface.co/docs/diffusers/en/index

    Text-to-image and image variation are supported. Inpainting and outpainting
    raise ``NotImplementedError``.

    Attributes:
        pipeline_driver: A pipeline image generation model driver typed for the specific pipeline required by the model.
        device: The hardware device used for inference. For example, "cpu", "cuda", or "mps".
        output_format: The format the generated image is returned in. Defaults to "png".
    """

    pipeline_driver: BaseDiffusionImageGenerationPipelineDriver = field(kw_only=True, metadata={"serializable": True})
    device: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    output_format: str = field(default="png", kw_only=True, metadata={"serializable": True})

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)
        prompt = ", ".join(prompts)
        extra_kwargs = self.pipeline_driver.make_additional_params(negative_prompts, self.device)
        generated = pipeline(prompt, **extra_kwargs).images[0]
        return self._encode_output_image(generated, prompt)

    def try_image_variation(
        self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
    ) -> ImageArtifact:
        pil_image = import_optional_dependency("PIL.Image")
        pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)
        prompt = ", ".join(prompts)

        source = pil_image.open(io.BytesIO(image.value))
        # The input image's size drives the output size, so resize it to the configured dimensions first.
        target_width, target_height = self.pipeline_driver.output_image_dimensions
        if (source.width, source.height) != (target_width, target_height):
            source = source.resize((target_width, target_height))

        generated = pipeline(
            prompt,
            **self.pipeline_driver.make_image_param(source),
            **self.pipeline_driver.make_additional_params(negative_prompts, self.device),
        ).images[0]
        return self._encode_output_image(generated, prompt)

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Inpainting is not supported by this driver.")

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        raise NotImplementedError("Outpainting is not supported by this driver.")

    def _encode_output_image(self, generated, prompt: str) -> ImageArtifact:
        """Save ``generated`` in ``output_format`` and wrap the bytes in an ``ImageArtifact``.

        ``generated`` is the first image returned by the pipeline. It is assumed to be
        a PIL image, i.e. it exposes ``save``, ``height`` and ``width``.
        """
        encoded = io.BytesIO()
        generated.save(encoded, format=self.output_format.upper())
        return ImageArtifact(
            value=encoded.getvalue(),
            format=self.output_format.lower(),
            height=generated.height,
            width=generated.width,
            meta={"prompt": prompt},
        )
  • device = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • output_format = field(default='png', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • pipeline_driver = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Unsupported: this driver cannot inpaint images.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Inpainting is not supported by this driver."
    raise NotImplementedError(reason)

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Unsupported: this driver cannot outpaint images.

    Raises:
        NotImplementedError: Always.
    """
    reason = "Outpainting is not supported by this driver."
    raise NotImplementedError(reason)

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_image_variation(
    self, prompts: list[str], image: ImageArtifact, negative_prompts: Optional[list[str]] = None
) -> ImageArtifact:
    """Generate a variation of ``image`` guided by the prompts (joined with ``", "``).

    The input image is first resized to ``pipeline_driver.output_image_dimensions``,
    because its size determines the output size. The result is encoded in
    ``output_format``.
    """
    pil_image = import_optional_dependency("PIL.Image")
    pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)
    prompt = ", ".join(prompts)

    source = pil_image.open(io.BytesIO(image.value))
    target_width, target_height = self.pipeline_driver.output_image_dimensions
    if (source.width, source.height) != (target_width, target_height):
        source = source.resize((target_width, target_height))

    generated = pipeline(
        prompt,
        **self.pipeline_driver.make_image_param(source),
        **self.pipeline_driver.make_additional_params(negative_prompts, self.device),
    ).images[0]

    encoded = io.BytesIO()
    generated.save(encoded, format=self.output_format.upper())
    return ImageArtifact(
        value=encoded.getvalue(),
        format=self.output_format.lower(),
        height=generated.height,
        width=generated.width,
        meta={"prompt": prompt},
    )

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/huggingface_pipeline_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from ``prompts`` using the configured diffusion pipeline.

    Args:
        prompts: Prompt fragments, joined with ", " into a single prompt.
        negative_prompts: Forwarded to ``pipeline_driver.make_additional_params``.

    Returns:
        An ``ImageArtifact`` encoded in ``self.output_format`` with the output image's dimensions.
    """
    pipeline = self.pipeline_driver.prepare_pipeline(self.model, self.device)
    prompt = ", ".join(prompts)

    additional_params = self.pipeline_driver.make_additional_params(negative_prompts, self.device)
    generated = pipeline(prompt, **additional_params).images[0]

    with io.BytesIO() as buffer:
        generated.save(buffer, format=self.output_format.upper())
        encoded = buffer.getvalue()

    return ImageArtifact(
        value=encoded,
        format=self.output_format.lower(),
        height=generated.height,
        width=generated.width,
        meta={"prompt": prompt},
    )

HuggingFacePipelinePromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | Hugging Face Hub model name. |
Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@define
class HuggingFacePipelinePromptDriver(BasePromptDriver):
    """Hugging Face Pipeline Prompt Driver.

    Runs a local ``transformers`` text-generation pipeline over the chat-formatted prompt stack.

    Attributes:
        max_tokens: Maximum number of new tokens to generate per completion.
        model: Hugging Face Hub model name.
        tokenizer: Tokenizer used for chat templating and token counting; defaults to one built for ``model``.
        structured_output_strategy: Only ``"rule"`` is supported; ``"native"`` and ``"tool"`` are rejected.
        pipeline: Optional pre-built pipeline passed as ``pipeline=``; when omitted, the ``pipeline``
            property builds one from ``model``.
    """

    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    structured_output_strategy: StructuredOutputStrategy = field(
        default="rule", kw_only=True, metadata={"serializable": True}
    )
    _pipeline: Optional[TextGenerationPipeline] = field(
        default=None, kw_only=True, alias="pipeline", metadata={"serializable": False}
    )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        """Reject structured output strategies this driver cannot honor ("native", "tool")."""
        if value in ("native", "tool"):
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @lazy_property()
    def pipeline(self) -> TextGenerationPipeline:
        """Build a ``transformers`` text-generation pipeline for ``model`` using this driver's tokenizer."""
        return import_optional_dependency("transformers").pipeline(
            task="text-generation",
            model=self.model,
            max_new_tokens=self.max_tokens,
            tokenizer=self.tokenizer.tokenizer,
        )

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Run the pipeline once and wrap the single generated reply in an assistant ``Message``.

        Raises:
            Exception: If the pipeline output is not a list, or contains more than one choice.
        """
        messages = self._prompt_stack_to_messages(prompt_stack)
        full_params = self._base_params(prompt_stack)
        logger.debug(
            (
                messages,
                full_params,
            )
        )

        result = self.pipeline(messages, **full_params)
        logger.debug(result)

        if isinstance(result, list):
            if len(result) == 1:
                # Chat-style pipelines return the whole conversation; the last entry is the new reply.
                generated_text = result[0]["generated_text"][-1]["content"]

                input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
                output_tokens = len(self.tokenizer.tokenizer.encode(generated_text))  # pyright: ignore[reportArgumentType]

                return Message(
                    content=[TextMessageContent(TextArtifact(generated_text))],
                    role=Message.ASSISTANT_ROLE,
                    usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
                )
            raise Exception("completion with more than one choice is not supported yet")
        raise Exception("invalid output format")

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Streaming is not implemented for this driver; always raises ``NotImplementedError``."""
        raise NotImplementedError("streaming is not supported")

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        """Render the prompt stack through the tokenizer's chat template and decode it back to text."""
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the generation kwargs passed to the pipeline; ``extra_params`` override the defaults.

        Sampling is enabled only for a strictly positive temperature. transformers raises
        ``ValueError`` when ``do_sample=True`` is combined with a non-positive temperature, so
        in that case greedy decoding is used and ``temperature`` is omitted.
        """
        do_sample = self.temperature is not None and self.temperature > 0
        params: dict = {"max_new_tokens": self.max_tokens}
        if do_sample:
            params["temperature"] = self.temperature
        params["do_sample"] = do_sample
        return {**params, **self.extra_params}

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert each prompt-stack message into a ``{"role", "content"}`` chat dict."""
        messages = []

        for message in prompt_stack.messages:
            messages.append({"role": message.role, "content": message.to_text()})

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        """Tokenize the prompt stack via the chat template (with generation prompt) for token counting.

        Raises:
            ValueError: If the tokenizer does not return a list of token ids.
        """
        messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        raise ValueError("Invalid output type.")
  • _pipeline = field(default=None, kw_only=True, alias='pipeline', metadata={'serializable': False}) class-attribute instance-attribute

  • max_tokens = field(default=250, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='rule', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

__prompt_stack_to_tokens(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    """Tokenize ``prompt_stack`` through the tokenizer's chat template.

    The generation prompt is appended so the count matches what the model actually sees.

    Raises:
        ValueError: If ``apply_chat_template`` returns something other than a list.
    """
    token_ids = self.tokenizer.tokenizer.apply_chat_template(
        self._prompt_stack_to_messages(prompt_stack),
        add_generation_prompt=True,
        tokenize=True,
    )
    if not isinstance(token_ids, list):
        raise ValueError("Invalid output type.")
    return token_ids  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the generation kwargs passed to the pipeline; ``extra_params`` override the defaults.

    Sampling is enabled only for a strictly positive temperature. transformers raises
    ``ValueError`` when ``do_sample=True`` is combined with a non-positive temperature, so
    in that case greedy decoding is used and ``temperature`` is omitted.

    Args:
        prompt_stack: Unused; kept for interface compatibility.

    Returns:
        A dict with ``max_new_tokens``, ``do_sample`` and (when sampling) ``temperature``,
        merged with ``self.extra_params``.
    """
    do_sample = self.temperature is not None and self.temperature > 0
    params: dict = {"max_new_tokens": self.max_tokens}
    if do_sample:
        params["temperature"] = self.temperature
    params["do_sample"] = do_sample
    return {**params, **self.extra_params}

_prompt_stack_to_messages(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    """Convert every prompt-stack message into a chat dict with ``role`` and ``content`` keys.

    The content is the message's text rendering (``message.to_text()``).
    """
    return [{"role": entry.role, "content": entry.to_text()} for entry in prompt_stack.messages]

pipeline()

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@lazy_property()
def pipeline(self) -> TextGenerationPipeline:
    """Build a ``transformers`` text-generation pipeline for ``self.model``.

    The pipeline reuses this driver's tokenizer and caps generation at ``self.max_tokens`` new tokens.
    """
    transformers = import_optional_dependency("transformers")
    return transformers.pipeline(
        task="text-generation",
        model=self.model,
        max_new_tokens=self.max_tokens,
        tokenizer=self.tokenizer.tokenizer,
    )

prompt_stack_to_string(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    """Render ``prompt_stack`` through the chat template and decode the token ids back into text."""
    token_ids = self.__prompt_stack_to_tokens(prompt_stack)
    return self.tokenizer.tokenizer.decode(token_ids)  # pyright: ignore[reportArgumentType]

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Run the pipeline once and wrap the generated reply in an assistant ``Message``.

    Token usage is computed locally: input tokens via the chat template, output tokens
    by re-encoding the generated text.

    Raises:
        Exception: If the pipeline output is not a list, or holds more than one choice.
    """
    messages = self._prompt_stack_to_messages(prompt_stack)
    full_params = self._base_params(prompt_stack)
    logger.debug((messages, full_params))

    result = self.pipeline(messages, **full_params)
    logger.debug(result)

    if not isinstance(result, list):
        raise Exception("invalid output format")
    if len(result) != 1:
        raise Exception("completion with more than one choice is not supported yet")

    # Chat-style pipelines echo the conversation; the final entry is the newly generated reply.
    generated_text = result[0]["generated_text"][-1]["content"]
    usage = Message.Usage(
        input_tokens=len(self.__prompt_stack_to_tokens(prompt_stack)),
        output_tokens=len(self.tokenizer.tokenizer.encode(generated_text)),  # pyright: ignore[reportArgumentType]
    )

    return Message(
        content=[TextMessageContent(TextArtifact(generated_text))],
        role=Message.ASSISTANT_ROLE,
        usage=usage,
    )

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Streaming is not offered by this driver.

    Raises:
        NotImplementedError: Always, as soon as the method is called.
    """
    reason = "streaming is not supported"
    raise NotImplementedError(reason)

validate_structured_output_strategy(_, value)

Source Code in griptape/drivers/prompt/huggingface_pipeline_prompt_driver.py
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    """Allow only structured output strategies this driver can honor.

    Raises:
        ValueError: If ``value`` is ``"native"`` or ``"tool"``.
    """
    unsupported = ("native", "tool")
    if value not in unsupported:
        return value
    raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

LeonardoImageGenerationDriver

Bases: BaseImageGenerationDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | The ID of the model to use when generating images. |
| api_key | str | The API key to use when making requests to the Leonardo API. |
| requests_session | Session | The requests session to use when making requests to the Leonardo API. |
| api_base | str | The base URL of the Leonardo API. |
| max_attempts | int | The maximum number of times to poll the Leonardo API for a completed image. |
| image_width | int | The width of the generated image in the range [32, 1024] and divisible by 8. |
| image_height | int | The height of the generated image in the range [32, 1024] and divisible by 8. |
| steps | Optional[int] | Optionally specify the number of inference steps to run for each image generation request, [30, 60]. |
| seed | Optional[int] | Optionally provide a consistent seed to generation requests, increasing consistency in output. |
| init_strength | Optional[float] | Optionally specify the strength of the initial image, [0.0, 1.0]. |
Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
@define
class LeonardoImageGenerationDriver(BaseImageGenerationDriver):
    """Driver for the Leonardo image generation API.

    Details on Leonardo image generation parameters can be found here:
    https://docs.leonardo.ai/reference/creategeneration

    Attributes:
        model: The ID of the model to use when generating images.
        api_key: The API key to use when making requests to the Leonardo API.
        requests_session: The requests session to use when making requests to the Leonardo API.
        api_base: The base URL of the Leonardo API.
        max_attempts: The maximum number of times to poll the Leonardo API for a completed image.
        image_width: The width of the generated image in the range [32, 1024] and divisible by 8.
        image_height: The height of the generated image in the range [32, 1024] and divisible by 8.
        steps: Optionally specify the number of inference steps to run for each image generation request, [30, 60].
        seed: Optionally provide a consistent seed to generation requests, increasing consistency in output.
        init_strength: Optionally specify the strength of the initial image, [0.0, 1.0].
        control_net: When True, ``controlNet`` and ``controlNetType`` are added to generation requests.
        control_net_type: ControlNet type sent with generation requests when ``control_net`` is enabled.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    requests_session: requests.Session = field(default=Factory(lambda: requests.Session()), kw_only=True)
    api_base: str = "https://cloud.leonardo.ai/api/rest/v1"
    max_attempts: int = field(default=10, kw_only=True, metadata={"serializable": True})
    image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
    image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    init_strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    control_net: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    control_net_type: Optional[Literal["POSE", "CANNY", "DEPTH"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Create a generation from ``prompts``, wait for it, and return the downloaded PNG."""
        if negative_prompts is None:
            negative_prompts = []

        generation_id = self._create_generation(prompts=prompts, negative_prompts=negative_prompts)
        image_url = self._get_image_url(generation_id=generation_id)
        image_data = self._download_image(url=image_url)

        return ImageArtifact(
            value=image_data,
            format="png",
            width=self.image_width,
            height=self.image_height,
            meta={
                "model": self.model,
                "prompt": ", ".join(prompts),
            },
        )

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Upload ``image`` as an init image, generate a variation from it, and return the PNG."""
        if negative_prompts is None:
            negative_prompts = []

        init_image_id = self._upload_init_image(image)
        generation_id = self._create_generation(
            prompts=prompts,
            negative_prompts=negative_prompts,
            init_image_id=init_image_id,
        )
        image_url = self._get_image_url(generation_id=generation_id)
        image_data = self._download_image(url=image_url)

        return ImageArtifact(
            value=image_data,
            format="png",
            width=self.image_width,
            height=self.image_height,
            meta={
                "model": self.model,
                "prompt": ", ".join(prompts),
            },
        )

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpainting is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Inpainting is not supported; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support inpainting")

    def _upload_init_image(self, image: ImageArtifact) -> str:
        """Register an init image with Leonardo, upload its bytes to the returned URL, and return its ID."""
        request = {"extension": image.mime_type.split("/")[1]}

        prep_response = self._make_api_request("/init-image", request=request)
        if prep_response is None or prep_response["uploadInitImage"] is None:
            raise Exception(f"failed to prepare init image: {prep_response}")

        fields = json.loads(prep_response["uploadInitImage"]["fields"])
        pre_signed_url = prep_response["uploadInitImage"]["url"]
        init_image_id = prep_response["uploadInitImage"]["id"]

        files = {"file": image.value}
        upload_response = requests.post(pre_signed_url, data=fields, files=files)
        if not upload_response.ok:
            raise Exception(f"failed to upload init image: {upload_response.text}")

        return init_image_id

    def _create_generation(
        self,
        prompts: list[str],
        negative_prompts: list[str],
        init_image_id: Optional[str] = None,
    ) -> str:
        """Submit a generation job built from the prompts and driver settings; return its generation ID."""
        prompt = ", ".join(prompts)
        negative_prompt = ", ".join(negative_prompts)
        request = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "width": self.image_width,
            "height": self.image_height,
            "num_images": 1,
            "modelId": self.model,
        }

        if init_image_id is not None:
            request["init_image_id"] = init_image_id

        if self.init_strength is not None:
            request["init_strength"] = self.init_strength

        if self.steps:
            request["num_inference_steps"] = self.steps

        if self.seed is not None:
            request["seed"] = self.seed

        if self.control_net:
            request["controlNet"] = self.control_net
            request["controlNetType"] = self.control_net_type

        response = self._make_api_request("/generations", request=request)
        if response is None or response["sdGenerationJob"] is None:
            raise Exception(f"failed to create generation: {response}")

        return response["sdGenerationJob"]["generationId"]

    def _make_api_request(self, endpoint: str, request: dict, method: str = "POST") -> dict:
        """Send an authenticated JSON request to ``api_base + endpoint`` and return the decoded body."""
        url = f"{self.api_base}{endpoint}"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        response = self.requests_session.request(url=url, method=method, json=request, headers=headers)
        if not response.ok:
            raise Exception(f"failed to make API request: {response.text}")

        return response.json()

    def _get_image_url(self, generation_id: str) -> str:
        """Poll the Leonardo API until a generation finishes and return its first image URL.

        Polls up to ``max_attempts`` times. After each ``PENDING`` response it sleeps
        ``attempt + 1`` seconds, except after the final attempt.

        Raises:
            Exception: If the response has no generation record, the generation reports
                ``FAILED``, it finishes without images, or it is still pending after
                ``max_attempts`` polls.
        """
        for attempt in range(self.max_attempts):
            response = self.requests_session.get(
                url=f"{self.api_base}/generations/{generation_id}",
                headers={"Authorization": f"Bearer {self.api_key}"},
            ).json()

            generation = response.get("generations_by_pk") if isinstance(response, dict) else None
            if generation is None:
                raise Exception(f"failed to get generation {generation_id}: {response}")

            status = generation.get("status")
            if status == "PENDING":
                # Sleeping after the last poll would only delay the failure below.
                if attempt + 1 < self.max_attempts:
                    time.sleep(attempt + 1)
                continue
            if status == "FAILED":
                raise Exception(f"image generation {generation_id} failed")

            images = generation.get("generated_images") or []
            if not images:
                raise Exception(f"image generation {generation_id} completed without images: {generation}")
            return images[0]["url"]
        raise Exception("image generation failed to complete")

    def _download_image(self, url: str) -> bytes:
        """Download the generated image bytes from ``url``.

        Raises:
            Exception: If the download request does not succeed, rather than returning the
                error body as if it were image data.
        """
        response = self.requests_session.get(url=url, headers={"Authorization": f"Bearer {self.api_key}"})
        if not response.ok:
            raise Exception(f"failed to download image: {response.text}")

        return response.content
  • api_base = 'https://cloud.leonardo.ai/api/rest/v1' class-attribute instance-attribute

  • api_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • control_net = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • control_net_type = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • image_height = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • image_width = field(default=512, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • init_strength = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • max_attempts = field(default=10, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • requests_session = field(default=Factory(lambda: requests.Session()), kw_only=True) class-attribute instance-attribute

  • seed = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • steps = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_create_generation(prompts, negative_prompts, init_image_id=None)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _create_generation(
    self,
    prompts: list[str],
    negative_prompts: list[str],
    init_image_id: Optional[str] = None,
) -> str:
    """Submit a Leonardo generation job and return its generation ID.

    Optional settings are included only when set. Note that ``steps`` is checked for
    truthiness, while the others are checked against ``None``.

    Raises:
        Exception: If the API returns no ``sdGenerationJob``.
    """
    request = {
        "prompt": ", ".join(prompts),
        "negative_prompt": ", ".join(negative_prompts),
        "width": self.image_width,
        "height": self.image_height,
        "num_images": 1,
        "modelId": self.model,
    }

    optional_fields = (
        ("init_image_id", init_image_id, init_image_id is not None),
        ("init_strength", self.init_strength, self.init_strength is not None),
        ("num_inference_steps", self.steps, bool(self.steps)),
        ("seed", self.seed, self.seed is not None),
    )
    request.update({key: value for key, value, wanted in optional_fields if wanted})

    if self.control_net:
        request.update({"controlNet": self.control_net, "controlNetType": self.control_net_type})

    response = self._make_api_request("/generations", request=request)
    job = None if response is None else response["sdGenerationJob"]
    if job is None:
        raise Exception(f"failed to create generation: {response}")

    return job["generationId"]

_download_image(url)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _download_image(self, url: str) -> bytes:
    """Download the generated image bytes from ``url``.

    Args:
        url: Image URL returned by the generation status endpoint.

    Returns:
        The raw response body.

    Raises:
        Exception: If the download request does not succeed, rather than returning the
            error body as if it were image data.
    """
    response = self.requests_session.get(url=url, headers={"Authorization": f"Bearer {self.api_key}"})
    if not response.ok:
        raise Exception(f"failed to download image: {response.text}")

    return response.content

_get_image_url(generation_id)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _get_image_url(self, generation_id: str) -> str:
    """Poll the Leonardo API until a generation finishes and return its first image URL.

    Polls up to ``max_attempts`` times. After each ``PENDING`` response it sleeps
    ``attempt + 1`` seconds, except after the final attempt.

    Args:
        generation_id: ID returned when the generation was created.

    Returns:
        URL of the first generated image.

    Raises:
        Exception: If the response has no generation record, the generation reports
            ``FAILED``, it finishes without images, or it is still pending after
            ``max_attempts`` polls.
    """
    for attempt in range(self.max_attempts):
        response = self.requests_session.get(
            url=f"{self.api_base}/generations/{generation_id}",
            headers={"Authorization": f"Bearer {self.api_key}"},
        ).json()

        generation = response.get("generations_by_pk") if isinstance(response, dict) else None
        if generation is None:
            raise Exception(f"failed to get generation {generation_id}: {response}")

        status = generation.get("status")
        if status == "PENDING":
            # Sleeping after the last poll would only delay the failure below.
            if attempt + 1 < self.max_attempts:
                time.sleep(attempt + 1)
            continue
        if status == "FAILED":
            raise Exception(f"image generation {generation_id} failed")

        images = generation.get("generated_images") or []
        if not images:
            raise Exception(f"image generation {generation_id} completed without images: {generation}")
        return images[0]["url"]
    raise Exception("image generation failed to complete")

_make_api_request(endpoint, request, method='POST')

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _make_api_request(self, endpoint: str, request: dict, method: str = "POST") -> dict:
    """Send an authenticated JSON request to the Leonardo API and return the decoded body.

    Args:
        endpoint: Path appended to ``api_base`` (e.g. ``"/generations"``).
        request: JSON payload.
        method: HTTP method, ``"POST"`` by default.

    Raises:
        Exception: If the response is not OK; the message includes the response text.
    """
    response = self.requests_session.request(
        url=f"{self.api_base}{endpoint}",
        method=method,
        json=request,
        headers={"Authorization": f"Bearer {self.api_key}"},
    )
    if response.ok:
        return response.json()
    raise Exception(f"failed to make API request: {response.text}")

_upload_init_image(image)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def _upload_init_image(self, image: ImageArtifact) -> str:
    """Register ``image`` as a Leonardo init image and upload its bytes.

    The API first returns a pre-signed upload URL plus form fields. The image is then
    posted there as multipart data.

    Returns:
        The init image ID to reference in a generation request.

    Raises:
        Exception: If preparing or uploading the init image fails.
    """
    extension = image.mime_type.split("/")[1]
    prep_response = self._make_api_request("/init-image", request={"extension": extension})
    if prep_response is None or prep_response["uploadInitImage"] is None:
        raise Exception(f"failed to prepare init image: {prep_response}")

    upload_target = prep_response["uploadInitImage"]
    form_fields = json.loads(upload_target["fields"])
    upload_url = upload_target["url"]
    init_image_id = upload_target["id"]

    upload_response = requests.post(upload_url, data=form_fields, files={"file": image.value})
    if not upload_response.ok:
        raise Exception(f"failed to upload init image: {upload_response.text}")

    return init_image_id

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject inpainting requests; the Leonardo driver does not implement them.

    Raises:
        NotImplementedError: Always; the message names the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support inpainting")

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject outpainting requests; the Leonardo driver does not implement them.

    Raises:
        NotImplementedError: Always; the message names the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Generate a variation of ``image`` through Leonardo and return the downloaded PNG.

    ``image`` is uploaded as an init image first. The reported dimensions come from
    ``image_width``/``image_height``.
    """
    negatives = [] if negative_prompts is None else negative_prompts

    init_image_id = self._upload_init_image(image)
    generation_id = self._create_generation(
        prompts=prompts,
        negative_prompts=negatives,
        init_image_id=init_image_id,
    )
    image_data = self._download_image(url=self._get_image_url(generation_id=generation_id))

    return ImageArtifact(
        value=image_data,
        format="png",
        width=self.image_width,
        height=self.image_height,
        meta={"model": self.model, "prompt": ", ".join(prompts)},
    )

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/leonardo_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate an image from ``prompts`` through Leonardo and return the downloaded PNG.

    The reported dimensions come from ``image_width``/``image_height``.
    """
    negatives = [] if negative_prompts is None else negative_prompts

    generation_id = self._create_generation(prompts=prompts, negative_prompts=negatives)
    image_data = self._download_image(url=self._get_image_url(generation_id=generation_id))

    return ImageArtifact(
        value=image_data,
        format="png",
        width=self.image_width,
        height=self.image_height,
        meta={"model": self.model, "prompt": ", ".join(prompts)},
    )

LocalConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
@define(kw_only=True)
class LocalConversationMemoryDriver(BaseConversationMemoryDriver):
    """Conversation memory driver that persists runs and metadata as JSON in a local file.

    Attributes:
        persist_file: Path of the JSON file. When None, nothing is stored and ``load``
            returns empty memory.
    """

    persist_file: Optional[str] = field(default=None, metadata={"serializable": True})

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        """Serialize ``runs`` and ``metadata`` to ``persist_file`` (no-op when it is None)."""
        if self.persist_file is None:
            return
        payload = self._to_params_dict(runs, metadata)
        Path(self.persist_file).write_text(json.dumps(payload))

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Read runs and metadata back from ``persist_file``.

        Returns ``([], {})`` when no file is configured or it does not exist.

        Raises:
            ValueError: If the file's contents cannot be decoded into runs and metadata.
        """
        if self.persist_file is None or not os.path.exists(self.persist_file):
            return [], {}

        loaded_str = Path(self.persist_file).read_text()
        try:
            return self._from_params_dict(json.loads(loaded_str))
        except Exception as e:
            raise ValueError(f"Unable to load data from {self.persist_file}") from e
  • persist_file = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

load()

Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Read runs and metadata back from ``persist_file``.

    Returns:
        ``([], {})`` when no file is configured or the file does not exist; otherwise
        whatever ``_from_params_dict`` builds from the decoded JSON.

    Raises:
        ValueError: If the file's contents cannot be decoded into runs and metadata.
    """
    if self.persist_file is None or not os.path.exists(self.persist_file):
        return [], {}

    loaded_str = Path(self.persist_file).read_text()
    try:
        return self._from_params_dict(json.loads(loaded_str))
    except Exception as e:
        raise ValueError(f"Unable to load data from {self.persist_file}") from e

store(runs, metadata)

Source Code in griptape/drivers/memory/conversation/local_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    """Serialize ``runs`` and ``metadata`` as JSON into ``persist_file``.

    Does nothing when ``persist_file`` is None. An existing file is overwritten.
    """
    if self.persist_file is None:
        return
    payload = self._to_params_dict(runs, metadata)
    Path(self.persist_file).write_text(json.dumps(payload))

LocalFileManagerDriver

Bases: BaseFileManagerDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| workdir | str | The working directory. List, load, and save operations on relative paths are performed relative to this directory; absolute paths are used as-is. Defaults to the current working directory. A relative value is resolved against the process's current working directory when accessed. The value must be a string; None is not supported. |
Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
@define
class LocalFileManagerDriver(BaseFileManagerDriver):
    """LocalFileManagerDriver can be used to list, load, and save files on the local file system.

    Attributes:
        workdir: The working directory. List, load, and save operations on relative paths are
                 performed relative to this directory; absolute paths are used as-is.
                 Defaults to the current working directory. A relative value is resolved against
                 the process's current working directory each time it is read.
                 It must be a string: None is not supported (``os.path.isabs(None)`` raises TypeError).
    """

    _workdir: str = field(default=Factory(lambda: os.getcwd()), kw_only=True, alias="workdir")

    @property
    def workdir(self) -> str:
        """Absolute working directory, resolving a relative ``_workdir`` against ``os.getcwd()``."""
        if os.path.isabs(self._workdir):
            return self._workdir
        return os.path.join(os.getcwd(), self._workdir)

    @workdir.setter
    def workdir(self, value: str) -> None:
        self._workdir = value

    def try_list_files(self, path: str) -> list[str]:
        """Return the entry names of the directory at ``path``."""
        full_path = self._full_path(path)
        return os.listdir(full_path)

    def try_load_file(self, path: str) -> bytes:
        """Return the bytes of the file at ``path``; raise ``IsADirectoryError`` for directories."""
        full_path = self._full_path(path)
        if self._is_dir(full_path):
            raise IsADirectoryError
        return Path(full_path).read_bytes()

    def try_save_file(self, path: str, value: bytes) -> str:
        """Write ``value`` to ``path``, creating parent directories, and return the full path written."""
        full_path = self._full_path(path)
        if self._is_dir(full_path):
            raise IsADirectoryError
        os.makedirs(os.path.dirname(full_path), exist_ok=True)
        Path(full_path).write_bytes(value)
        return full_path

    def _full_path(self, path: str) -> str:
        """Resolve ``path`` against ``workdir`` (unless absolute) and normalize it."""
        full_path = path if os.path.isabs(path) else os.path.join(self.workdir, path.lstrip("/"))
        # Need to keep the trailing slash if it was there,
        # because it means the path is a directory.
        ended_with_sep = path.endswith("/")
        full_path = os.path.normpath(full_path)
        if ended_with_sep:
            full_path = full_path.rstrip("/") + "/"
        return full_path

    def _is_dir(self, full_path: str) -> bool:
        """Treat a trailing "/" or an existing directory as a directory."""
        return full_path.endswith("/") or Path(full_path).is_dir()
  • _workdir = field(default=Factory(lambda: os.getcwd()), kw_only=True, alias='workdir') class-attribute instance-attribute

  • workdir property writable

_full_path(path)

Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def _full_path(self, path: str) -> str:
    """Resolve ``path`` to a normalized full path.

    Relative paths are joined onto ``self.workdir``; absolute paths are used as-is.
    A trailing "/" on the input marks a directory, and is preserved on the result.
    """
    # os.path.normpath strips a trailing separator, so note it before normalizing.
    is_dir_hint = path.endswith("/")
    if os.path.isabs(path):
        joined = path
    else:
        joined = os.path.join(self.workdir, path.lstrip("/"))
    normalized = os.path.normpath(joined)
    return normalized.rstrip("/") + "/" if is_dir_hint else normalized

_is_dir(full_path)

Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def _is_dir(self, full_path: str) -> bool:
    """Report whether ``full_path`` denotes a directory.

    A trailing "/" counts as a directory even if nothing exists at that path yet.
    """
    if full_path.endswith("/"):
        return True
    return Path(full_path).is_dir()

try_list_files(path)

Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_list_files(self, path: str) -> list[str]:
    """List the entry names in the directory at ``path`` (resolved via ``_full_path``)."""
    return os.listdir(self._full_path(path))

try_load_file(path)

Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_load_file(self, path: str) -> bytes:
    """Read and return the raw bytes of the file at ``path``.

    Raises:
        IsADirectoryError: If ``path`` resolves to a directory (or ends with "/").
    """
    full_path = self._full_path(path)
    if not self._is_dir(full_path):
        return Path(full_path).read_bytes()
    raise IsADirectoryError

try_save_file(path, value)

Source Code in griptape/drivers/file_manager/local_file_manager_driver.py
def try_save_file(self, path: str, value: bytes) -> str:
    """Write ``value`` to ``path``, creating any missing parent directories.

    Args:
        path: Target path; resolved with ``self._full_path``.
        value: Bytes to write (the file is overwritten).

    Returns:
        The resolved full path that was written.

    Raises:
        IsADirectoryError: If the resolved path is a directory (or ends with "/").
    """
    full_path = self._full_path(path)
    if self._is_dir(full_path):
        raise IsADirectoryError
    parent_dir = os.path.dirname(full_path)
    # A bare file name (e.g. when the workdir is relative/empty) has no parent to create,
    # and os.makedirs("") would raise FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    Path(full_path).write_bytes(value)
    return full_path

LocalRerankDriver

Bases: BaseRerankDriver , FuturesExecutorMixin

Source Code in griptape/drivers/rerank/local_rerank_driver.py
@define(kw_only=True)
class LocalRerankDriver(BaseRerankDriver, FuturesExecutorMixin):
    """Rerank driver that orders artifacts by embedding relatedness to the query.

    Attributes:
        calculate_relatedness: Scores two vectors; defaults to cosine similarity.
        embedding_driver: Driver used to embed the query and every artifact.
    """

    calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={"serializable": True}
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Return ``artifacts`` ordered from most to least related to ``query``.

        The query is embedded with ``vector_operation="query"``; artifacts are embedded on the
        futures executor with ``vector_operation="upsert"``. The sort is stable, so artifacts
        with equal scores keep their input order.
        """
        query_vector = self.embedding_driver.embed(query, vector_operation="query")

        with self.create_futures_executor() as executor:
            # with_contextvars is applied once per submission (not hoisted); presumably each call
            # captures the caller's context for its task -- confirm before sharing one wrapper.
            pending = [
                executor.submit(
                    with_contextvars(self.embedding_driver.embed_text_artifact), artifact, vector_operation="upsert"
                )
                for artifact in artifacts
            ]
            artifact_vectors = execute_futures_list(pending)

        scored = [
            (artifact, self.calculate_relatedness(query_vector, artifact_vector))
            for artifact, artifact_vector in zip(artifacts, artifact_vectors)
        ]
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        return [artifact for artifact, _score in ranked]
  • calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) class-attribute instance-attribute

  • embedding_driver = field(kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={'serializable': True}) class-attribute instance-attribute

run(query, artifacts)

Source Code in griptape/drivers/rerank/local_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Rank ``artifacts`` by relatedness of their embeddings to the embedding of ``query``.

    Returns the same artifact objects, most related first (stable for equal scores).
    """
    query_embedding = self.embedding_driver.embed(query, vector_operation="query")

    with self.create_futures_executor() as futures_executor:
        futures = []
        for artifact in artifacts:
            # One with_contextvars wrapper per task, exactly as submitted before.
            embed = with_contextvars(self.embedding_driver.embed_text_artifact)
            futures.append(futures_executor.submit(embed, artifact, vector_operation="upsert"))
        artifact_embeddings = execute_futures_list(futures)

    relatedness_by_artifact = [
        (artifact, self.calculate_relatedness(query_embedding, embedding))
        for artifact, embedding in zip(artifacts, artifact_embeddings)
    ]
    relatedness_by_artifact.sort(key=operator.itemgetter(1), reverse=True)
    return [pair[0] for pair in relatedness_by_artifact]

LocalRulesetDriver

Bases: BaseRulesetDriver

Source Code in griptape/drivers/ruleset/local_ruleset_driver.py
@define(kw_only=True)
class LocalRulesetDriver(BaseRulesetDriver):
    """Ruleset driver that reads rulesets from JSON files in a local directory.

    Attributes:
        persist_dir: Directory holding one JSON file per ruleset, named after the ruleset.
            When ``None``, every lookup returns an empty ruleset.
    """

    persist_dir: Optional[str] = field(default=None, metadata={"serializable": True})

    def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
        """Load the rules and metadata stored for ``ruleset_name`` under ``persist_dir``.

        Args:
            ruleset_name: File name (relative to ``persist_dir``) of the ruleset JSON.

        Returns:
            The ``(rules, meta)`` pair produced by ``_from_ruleset_dict``, or ``([], {})`` when
            there is no ``persist_dir`` or no such file and ``raise_not_found`` is false.

        Raises:
            ValueError: If the file exists but its contents cannot be parsed/converted, or if it
                is missing and ``raise_not_found`` (presumably defined on the base class) is true.
        """
        if self.persist_dir is None:
            return [], {}

        file_name = os.path.join(self.persist_dir, ruleset_name)

        # Only a regular file can hold a ruleset; a same-named directory is treated as "not found"
        # rather than letting read_text() raise IsADirectoryError.
        if os.path.isfile(file_name):
            loaded_str = Path(file_name).read_text()
            try:
                return self._from_ruleset_dict(json.loads(loaded_str))
            except Exception as e:
                raise ValueError(f"Unable to load data from {file_name}") from e

        if self.raise_not_found:
            raise ValueError(f"Ruleset not found with name {file_name}")
        return [], {}
  • persist_dir = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

load(ruleset_name)

Source Code in griptape/drivers/ruleset/local_ruleset_driver.py
def load(self, ruleset_name: str) -> tuple[list[BaseRule], dict[str, Any]]:
    """Load the rules and metadata stored for ``ruleset_name`` under ``persist_dir``.

    Args:
        ruleset_name: File name (relative to ``persist_dir``) of the ruleset JSON.

    Returns:
        The ``(rules, meta)`` pair produced by ``_from_ruleset_dict``, or ``([], {})`` when there
        is no ``persist_dir`` or no such file and ``raise_not_found`` is false.

    Raises:
        ValueError: If the file exists but its contents cannot be parsed/converted, or if it is
            missing and ``raise_not_found`` is true.
    """
    if self.persist_dir is None:
        return [], {}

    file_name = os.path.join(self.persist_dir, ruleset_name)

    # Only a regular file can hold a ruleset; a same-named directory is treated as "not found"
    # rather than letting read_text() raise IsADirectoryError.
    if os.path.isfile(file_name):
        loaded_str = Path(file_name).read_text()
        try:
            return self._from_ruleset_dict(json.loads(loaded_str))
        except Exception as e:
            raise ValueError(f"Unable to load data from {file_name}") from e

    if self.raise_not_found:
        raise ValueError(f"Ruleset not found with name {file_name}")
    return [], {}

LocalStructureRunDriver

Bases: BaseStructureRunDriver

Source Code in griptape/drivers/structure_run/local_structure_run_driver.py
@define
class LocalStructureRunDriver(BaseStructureRunDriver):
    """Structure-run driver that builds and runs a Structure in the current process.

    Attributes:
        create_structure: Zero-argument factory returning the Structure to run.
    """

    create_structure: Callable[[], Structure] = field(kw_only=True)

    def try_run(self, *args: BaseArtifact) -> BaseArtifact:
        """Run a freshly created Structure with the values of ``args`` and return its output.

        ``self.env`` is merged into ``os.environ`` for the duration of the run, and the prior
        environment is restored afterwards even if the run raises.
        """
        saved_environ = os.environ.copy()
        try:
            os.environ.update(self.env)
            structure = self.create_structure()
            finished = structure.run(*(artifact.value for artifact in args))
        finally:
            # Restore exactly the pre-run environment, dropping keys that self.env added.
            os.environ.clear()
            os.environ.update(saved_environ)

        return finished.output
  • create_structure = field(kw_only=True) class-attribute instance-attribute

try_run(*args)

Source Code in griptape/drivers/structure_run/local_structure_run_driver.py
def try_run(self, *args: BaseArtifact) -> BaseArtifact:
    """Create a Structure, run it on the artifacts' values under ``self.env``, and return its output.

    The process environment is snapshotted first and restored in full afterwards.
    """
    previous_env = os.environ.copy()
    try:
        os.environ.update(self.env)
        run_result = self.create_structure().run(*(arg.value for arg in args))
    finally:
        os.environ.clear()
        os.environ.update(previous_env)
    return run_result.output

LocalVectorStoreDriver

Bases: BaseVectorStoreDriver

Source Code in griptape/drivers/vector/local_vector_store_driver.py
@define(kw_only=True)
class LocalVectorStoreDriver(BaseVectorStoreDriver):
    """Vector store driver that keeps entries in memory, optionally persisted to a JSON file.

    Attributes:
        entries: Stored entries keyed by ``"{namespace}-{vector_id}"`` (or the bare id when no
            namespace is given).
        persist_file: Optional JSON file path. When set, entries are loaded from it on construction
            and the whole store is rewritten to it after every upsert.
        calculate_relatedness: Scores two vectors; defaults to cosine similarity.
        thread_lock: Lock held while inserting into ``entries`` and while reading/writing the file.
    """

    entries: dict[str, BaseVectorStoreDriver.Entry] = field(factory=dict)
    persist_file: Optional[str] = field(default=None)
    calculate_relatedness: Callable = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
    thread_lock: threading.Lock = field(default=Factory(lambda: threading.Lock()))

    def __attrs_post_init__(self) -> None:
        # Create the persist file (and its directory) when needed, then either load the entries
        # it holds or seed an empty file with the current in-memory entries.
        if self.persist_file is not None:
            directory = os.path.dirname(self.persist_file)

            if directory and not os.path.exists(directory):
                os.makedirs(directory)

            if not os.path.isfile(self.persist_file):
                with open(self.persist_file, "w") as file:
                    self.__save_entries_to_file(file)

            with open(self.persist_file, "r+") as file:
                if os.path.getsize(self.persist_file) > 0:
                    self.entries = self.load_entries_from_file(file)
                else:
                    self.__save_entries_to_file(file)

    def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]:
        """Parse ``json_file`` into a mapping of storage key to Entry (under ``thread_lock``)."""
        with self.thread_lock:
            data = json.load(json_file)

            return {k: BaseVectorStoreDriver.Entry.from_dict(v) for k, v in data.items()}

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Insert or replace ``vector`` and return its id (a hash of the vector when none is given)."""
        vector_id = vector_id or utils.str_to_hash(str(vector))

        with self.thread_lock:
            self.entries[self.__namespaced_vector_id(vector_id, namespace=namespace)] = self.Entry(
                id=vector_id,
                vector=vector,
                meta=meta,
                namespace=namespace,
            )

        if self.persist_file is not None:
            # TODO: optimize later since it reserializes all entries from memory and stores them in the JSON file
            #  every time a new vector is inserted
            with open(self.persist_file, "w") as file:
                self.__save_entries_to_file(file)

        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Return the entry stored under ``vector_id`` in ``namespace``, or None."""
        return self.entries.get(self.__namespaced_vector_id(vector_id, namespace=namespace), None)

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Return all entries, or only those whose ``namespace`` matches when one is given."""
        return [entry for entry in self.entries.values() if namespace is None or entry.namespace == namespace]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Return stored entries ranked by relatedness to ``vector``, most related first.

        Args:
            vector: Query vector.
            count: Maximum number of results; ``None`` returns all.
            namespace: When truthy, only entries stored in exactly this namespace are scored.
            include_vectors: When False, returned entries carry an empty ``vector``.
            kwargs: Accepted for interface compatibility and ignored.

        Returns:
            New Entry objects whose ``score`` is the relatedness value.
        """
        if namespace:
            # Match the entry's own namespace: a key-prefix test ("{namespace}-") would also
            # match keys belonging to namespaces such as "{namespace}-other".
            candidates = [entry for entry in self.entries.values() if entry.namespace == namespace]
        else:
            candidates = list(self.entries.values())

        entries_and_relatednesses = [(entry, self.calculate_relatedness(vector, entry.vector)) for entry in candidates]

        entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

        result = [
            BaseVectorStoreDriver.Entry(
                id=entry.id, vector=entry.vector, score=score, meta=entry.meta, namespace=entry.namespace
            )
            for entry, score in entries_and_relatednesses[:count]
        ]

        if include_vectors:
            return result
        return [
            BaseVectorStoreDriver.Entry(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
            for r in result
        ]

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

    def __save_entries_to_file(self, json_file: TextIO) -> None:
        # Serialize every entry and dump the mapping as JSON while holding the lock.
        with self.thread_lock:
            serialized_data = {k: v.to_dict() for k, v in self.entries.items()}

            json.dump(serialized_data, json_file)

    def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
        # Storage key: "{namespace}-{vector_id}", or the bare id without a namespace.
        return vector_id if namespace is None else f"{namespace}-{vector_id}"
  • calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y))) class-attribute instance-attribute

  • entries = field(factory=dict) class-attribute instance-attribute

  • persist_file = field(default=None) class-attribute instance-attribute

  • thread_lock = field(default=Factory(lambda: threading.Lock())) class-attribute instance-attribute

__attrs_post_init__()

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __attrs_post_init__(self) -> None:
    if self.persist_file is not None:
        directory = os.path.dirname(self.persist_file)

        if directory and not os.path.exists(directory):
            os.makedirs(directory)

        if not os.path.isfile(self.persist_file):
            with open(self.persist_file, "w") as file:
                self.__save_entries_to_file(file)

        with open(self.persist_file, "r+") as file:
            if os.path.getsize(self.persist_file) > 0:
                self.entries = self.load_entries_from_file(file)
            else:
                self.__save_entries_to_file(file)

__namespaced_vector_id(vector_id, *, namespace)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __namespaced_vector_id(self, vector_id: str, *, namespace: Optional[str]) -> str:
    return vector_id if namespace is None else f"{namespace}-{vector_id}"

__save_entries_to_file(json_file)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def __save_entries_to_file(self, json_file: TextIO) -> None:
    with self.thread_lock:
        serialized_data = {k: v.to_dict() for k, v in self.entries.items()}

        json.dump(serialized_data, json_file)

delete_vector(vector_id)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Always raise ``NotImplementedError``: this driver cannot delete vectors."""
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support deletion.")

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Return every stored entry, or only those whose ``namespace`` equals the given one."""
    if namespace is None:
        return list(self.entries.values())
    return [entry for entry in self.entries.values() if entry.namespace == namespace]

load_entries_from_file(json_file)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entries_from_file(self, json_file: TextIO) -> dict[str, BaseVectorStoreDriver.Entry]:
    """Parse ``json_file`` and rebuild each value with ``BaseVectorStoreDriver.Entry.from_dict``.

    Both the JSON read and the conversion happen while ``self.thread_lock`` is held.
    """
    with self.thread_lock:
        raw_entries = json.load(json_file)
        loaded = {}
        for key, payload in raw_entries.items():
            loaded[key] = BaseVectorStoreDriver.Entry.from_dict(payload)
        return loaded

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Return the entry stored for ``vector_id`` in ``namespace``, or None if there is none."""
    storage_key = self.__namespaced_vector_id(vector_id, namespace=namespace)
    return self.entries.get(storage_key)

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Return stored entries ranked by relatedness to ``vector``, most related first.

    Args:
        vector: Query vector.
        count: Maximum number of results; ``None`` returns all.
        namespace: When truthy, only entries stored in exactly this namespace are scored.
        include_vectors: When False, returned entries carry an empty ``vector``.
        kwargs: Accepted for interface compatibility and ignored.

    Returns:
        New Entry objects whose ``score`` is the relatedness value.
    """
    if namespace:
        # Match the entry's own namespace: a key-prefix test ("{namespace}-") would also match
        # keys belonging to namespaces such as "{namespace}-other".
        candidates = [entry for entry in self.entries.values() if entry.namespace == namespace]
    else:
        candidates = list(self.entries.values())

    entries_and_relatednesses = [(entry, self.calculate_relatedness(vector, entry.vector)) for entry in candidates]

    entries_and_relatednesses.sort(key=operator.itemgetter(1), reverse=True)

    result = [
        BaseVectorStoreDriver.Entry(
            id=entry.id, vector=entry.vector, score=score, meta=entry.meta, namespace=entry.namespace
        )
        for entry, score in entries_and_relatednesses[:count]
    ]

    if include_vectors:
        return result
    return [
        BaseVectorStoreDriver.Entry(id=r.id, vector=[], score=r.score, meta=r.meta, namespace=r.namespace)
        for r in result
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/local_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Store ``vector`` under its (namespaced) id and return that id.

    A falsy ``vector_id`` is replaced by a hash of ``str(vector)``. When ``persist_file`` is
    set, the entire store is rewritten to it afterwards.
    """
    resolved_id = vector_id or utils.str_to_hash(str(vector))

    with self.thread_lock:
        new_entry = self.Entry(id=resolved_id, vector=vector, meta=meta, namespace=namespace)
        self.entries[self.__namespaced_vector_id(resolved_id, namespace=namespace)] = new_entry

    if self.persist_file is None:
        return resolved_id

    # TODO: optimize later -- every upsert reserializes all in-memory entries into the JSON file.
    with open(self.persist_file, "w") as file:
        self.__save_entries_to_file(file)
    return resolved_id

MarkdownifyWebScraperDriver

Bases: BaseWebScraperDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `include_links` | `bool` | If `True`, the driver will include link urls in the markdown output. |
| `exclude_tags` | `list[str]` | Optionally provide custom tags to exclude from the scraped content. |
| `exclude_classes` | `list[str]` | Optionally provide custom classes to exclude from the scraped content. |
| `exclude_ids` | `list[str]` | Optionally provide custom ids to exclude from the scraped content. |
| `timeout` | `Optional[int]` | Optionally provide a timeout in milliseconds for the page to continue loading after the browser has emitted the "load" event. |
Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
@define
class MarkdownifyWebScraperDriver(BaseWebScraperDriver):
    """Driver to scrape a webpage and return the content in markdown format.

    As a prerequisite to using MarkdownifyWebScraperDriver, you need to install the browsers used by
    playwright. You can do this by running: `poetry run playwright install`.
    For more details about playwright, see https://playwright.dev/python/docs/library.

    Attributes:
        include_links: If `True`, the driver will include link urls in the markdown output.
        exclude_tags: Optionally provide custom tags to exclude from the scraped content.
        exclude_classes: Optionally provide custom classes to exclude from the scraped content.
        exclude_ids: Optionally provide custom ids to exclude from the scraped content.
        timeout: Optionally provide a timeout in milliseconds for the page to continue loading after
            the browser has emitted the "load" event.
    """

    DEFAULT_EXCLUDE_TAGS = ["script", "style", "head", "audio", "img", "picture", "source", "video"]

    include_links: bool = field(default=True, kw_only=True)
    exclude_tags: list[str] = field(
        # Copy the class-level default: returning DEFAULT_EXCLUDE_TAGS itself would let mutating one
        # instance's exclude_tags silently change the defaults of every other instance.
        default=Factory(lambda self: list(self.DEFAULT_EXCLUDE_TAGS), takes_self=True),
        kw_only=True,
    )
    exclude_classes: list[str] = field(default=Factory(list), kw_only=True)
    exclude_ids: list[str] = field(default=Factory(list), kw_only=True)
    timeout: Optional[int] = field(default=None, kw_only=True)

    def fetch_url(self, url: str) -> str:
        """Load ``url`` in headless Chromium (skipping image requests) and return the page HTML.

        Raises:
            Exception: If the page content is empty.
        """
        sync_playwright = import_optional_dependency("playwright.sync_api").sync_playwright

        with sync_playwright() as p, p.chromium.launch(headless=True) as browser:
            page = browser.new_page()

            def skip_loading_images(route: Any) -> Any:
                if route.request.resource_type == "image":
                    return route.abort()
                route.continue_()
                return None

            page.route("**/*", skip_loading_images)

            page.goto(url)

            # Some websites require a delay before the content is fully loaded
            # even after the browser has emitted "load" event.
            if self.timeout:
                page.wait_for_timeout(self.timeout)

            content = page.content()

            if not content:
                raise Exception("can't access URL")

            return content

    def extract_page(self, page: str) -> TextArtifact:
        """Convert the HTML ``page`` into cleaned-up markdown wrapped in a TextArtifact."""
        bs4 = import_optional_dependency("bs4")
        markdownify = import_optional_dependency("markdownify")
        include_links = self.include_links

        # Custom MarkdownConverter that optionally keeps link urls. If include_links is False, only
        # the text of the link is returned.
        class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter):
            def convert_a(self, el: Any, text: str, parent_tags: Any) -> str:
                if include_links:
                    return super().convert_a(el, text, parent_tags)
                return text

        soup = bs4.BeautifulSoup(page, "html.parser")

        # Remove unwanted elements
        exclude_selector = ",".join(
            self.exclude_tags + [f".{c}" for c in self.exclude_classes] + [f"#{i}" for i in self.exclude_ids],
        )
        if exclude_selector:
            for s in soup.select(exclude_selector):
                s.extract()

        text = OptionalLinksMarkdownConverter().convert_soup(soup)

        # Remove leading and trailing whitespace from the entire text
        text = text.strip()

        # Remove trailing whitespace from each line
        text = re.sub(r"[ \t]+$", "", text, flags=re.MULTILINE)

        # Indent using 2 spaces instead of tabs
        text = re.sub(r"(\n?\s*?)\t", r"\1  ", text)

        # Collapse runs of blank lines into a single blank line (keep paragraph breaks)
        text = re.sub(r"\n\n+", "\n\n", text)

        return TextArtifact(text)
  • DEFAULT_EXCLUDE_TAGS = ['script', 'style', 'head', 'audio', 'img', 'picture', 'source', 'video'] class-attribute instance-attribute

  • exclude_classes = field(default=Factory(list), kw_only=True) class-attribute instance-attribute

  • exclude_ids = field(default=Factory(list), kw_only=True) class-attribute instance-attribute

  • exclude_tags = field(default=Factory(lambda self: self.DEFAULT_EXCLUDE_TAGS, takes_self=True), kw_only=True) class-attribute instance-attribute

  • include_links = field(default=True, kw_only=True) class-attribute instance-attribute

  • timeout = field(default=None, kw_only=True) class-attribute instance-attribute

extract_page(page)

Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact:
    """Turn the HTML ``page`` into tidy markdown and wrap it in a TextArtifact.

    Elements matching ``exclude_tags``, ``exclude_classes`` (``.class``) and ``exclude_ids``
    (``#id``) are removed before conversion; links keep their urls only when
    ``include_links`` is set.
    """
    bs4 = import_optional_dependency("bs4")
    markdownify = import_optional_dependency("markdownify")
    keep_links = self.include_links

    class OptionalLinksMarkdownConverter(markdownify.MarkdownConverter):
        # When links are disabled, render only the anchor's text.
        def convert_a(self, el: Any, text: str, parent_tags: Any) -> str:
            if not keep_links:
                return text
            return super().convert_a(el, text, parent_tags)

    soup = bs4.BeautifulSoup(page, "html.parser")

    class_selectors = [f".{name}" for name in self.exclude_classes]
    id_selectors = [f"#{name}" for name in self.exclude_ids]
    exclude_selector = ",".join(self.exclude_tags + class_selectors + id_selectors)
    if exclude_selector:
        for element in soup.select(exclude_selector):
            element.extract()

    markdown = OptionalLinksMarkdownConverter().convert_soup(soup).strip()
    # Strip trailing spaces/tabs on every line.
    markdown = re.sub(r"[ \t]+$", "", markdown, flags=re.MULTILINE)
    # Replace tab indentation with two spaces.
    markdown = re.sub(r"(\n?\s*?)\t", r"\1  ", markdown)
    # Collapse runs of blank lines down to one blank line.
    markdown = re.sub(r"\n\n+", "\n\n", markdown)
    return TextArtifact(markdown)

fetch_url(url)

Source Code in griptape/drivers/web_scraper/markdownify_web_scraper_driver.py
def fetch_url(self, url: str) -> str:
    """Render ``url`` in headless Chromium, skipping image requests, and return the page HTML.

    Waits an extra ``self.timeout`` milliseconds after load when that is set.

    Raises:
        Exception: If the rendered page content is empty.
    """
    playwright_api = import_optional_dependency("playwright.sync_api")

    with playwright_api.sync_playwright() as pw, pw.chromium.launch(headless=True) as browser:
        page = browser.new_page()

        def skip_loading_images(route: Any) -> Any:
            # Let every non-image request through; abort image downloads.
            if route.request.resource_type != "image":
                route.continue_()
                return None
            return route.abort()

        page.route("**/*", skip_loading_images)
        page.goto(url)

        # Some sites keep populating content after the "load" event fires.
        if self.timeout:
            page.wait_for_timeout(self.timeout)

        html = page.content()
        if html:
            return html
        raise Exception("can't access URL")

MarqoVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `api_key` | `str` | The API key for the Marqo API. |
| `url` | `str` | The URL to the Marqo API. |
| `client` | `Client` | An optional Marqo client. Defaults to a new client with the given URL and API key. |
| `index` | `str` | The name of the index to use. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
@define
class MarqoVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Marqo.

    Attributes:
        api_key: The API key for the Marqo API.
        url: The URL to the Marqo API.
        client: An optional Marqo client. Defaults to a new client with the given URL and API key.
        index: The name of the index to use.
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    url: str = field(kw_only=True, metadata={"serializable": True})
    index: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[marqo.Client] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> marqo.Client:
        return import_optional_dependency("marqo").Client(self.url, api_key=self.api_key)

    def upsert(
        self,
        value: str | TextArtifact | ImageArtifact,
        *,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Upsert a text document into the Marqo index.

        Args:
            value: The value to be indexed.
            namespace: An optional namespace for the document.
            meta: An optional dictionary of metadata for the document.
            vector_id: The ID for the vector. For a TextArtifact it defaults to a hash of its text;
                for a plain string, when None, no ``_id`` is sent so Marqo generates one.
            kwargs: Additional keyword arguments to pass to the Marqo client.

        Returns:
            str: The ID of the document that was added.

        Raises:
            NotImplementedError: If ``value`` is an ImageArtifact.
            ValueError: If Marqo's response contains no added items.
        """
        if isinstance(value, TextArtifact):
            artifact_json = value.to_json()
            vector_id = utils.str_to_hash(value.value) if vector_id is None else vector_id

            doc = {
                "_id": vector_id,
                "Description": value.value,
                "artifact": str(artifact_json),
            }
        elif isinstance(value, ImageArtifact):
            raise NotImplementedError("`MarqoVectorStoreDriver` does not support upserting Image Artifacts.")
        else:
            doc = {"Description": value}
            # Only send an explicit _id when one was given. Omitting the field is how Marqo is asked to
            # generate an id; sending "_id": None is not a string id.
            if vector_id is not None:
                doc["_id"] = vector_id

        # Non-tensor fields
        if meta:
            doc["meta"] = str(meta)
        if namespace:
            doc["namespace"] = namespace

        response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"])
        if isinstance(response, dict) and "items" in response and response["items"]:
            return response["items"][0]["_id"]
        raise ValueError(f"Failed to upsert text: {response}")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a document entry from the Marqo index.

        Args:
            vector_id: The ID of the vector to load.
            namespace: The namespace of the vector to load (currently not used for the lookup).

        Returns:
            The loaded Entry if found, otherwise None.
        """
        result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True)

        if result and "_tensor_facets" in result and len(result["_tensor_facets"]) > 0:
            # Build meta and namespace the same way load_entries does, so both loaders agree and the
            # embedding is not duplicated inside meta.
            return BaseVectorStoreDriver.Entry(
                id=result["_id"],
                meta={k: v for k, v in result.items() if k not in ["_id", "_tensor_facets", "_found"]},
                vector=result["_tensor_facets"][0]["_embedding"],
                namespace=result.get("namespace"),
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load all document entries from the Marqo index.

        Args:
            namespace: The namespace to filter entries by.

        Returns:
            The list of loaded Entries.
        """
        filter_string = f"namespace:{namespace}" if namespace else None

        if filter_string is not None:
            results = self.client.index(self.index).search("", limit=10000, filter_string=filter_string)
        else:
            results = self.client.index(self.index).search("", limit=10000)

        # get all _id's from search results
        ids = [r["_id"] for r in results["hits"]]

        # get documents corresponding to the ids
        documents = self.client.index(self.index).get_documents(document_ids=ids, expose_facets=True)

        # for each document, if it's found, create an Entry object
        entries = []
        for doc in documents["results"]:
            if doc["_found"]:
                entries.append(
                    BaseVectorStoreDriver.Entry(
                        id=doc["_id"],
                        vector=doc["_tensor_facets"][0]["_embedding"],
                        meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]},
                        namespace=doc.get("namespace"),
                    ),
                )

        return entries

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Query the Marqo index for documents.

        Args:
            vector: The vector to query by.
            count: The maximum number of results to return.
            namespace: The namespace to filter results by.
            include_vectors: Whether to include vector data in the results.
            include_metadata: Whether to include metadata in the results.
            kwargs: Additional keyword arguments to pass to the Marqo client.

        Returns:
            The list of query results.
        """
        params = {
            "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "attributes_to_retrieve": None if include_metadata else ["_id"],
            "filter_string": f"namespace:{namespace}" if namespace else None,
        } | kwargs

        # Marqo's search context takes a list of weighted vectors: {"tensor": [{"vector": ..., "weight": ...}]}.
        search_context = {"tensor": [{"vector": vector, "weight": 1}]}
        results = self.client.index(self.index).search(**params, context=search_context)

        return self.__process_results(results, include_vectors=include_vectors)

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        **kwargs: Any,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Query the Marqo index for documents.

        Args:
            query: The query string.
            count: The maximum number of results to return.
            namespace: The namespace to filter results by.
            include_vectors: Whether to include vector data in the results.
            include_metadata: Whether to include metadata in the results.
            kwargs: Additional keyword arguments to pass to the Marqo client.

        Returns:
            The list of query results.
        """
        params = {
            "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "attributes_to_retrieve": None if include_metadata else ["_id"],
            "filter_string": f"namespace:{namespace}" if namespace else None,
        } | kwargs

        results = self.client.index(self.index).search(str(query), **params)
        return self.__process_results(results, include_vectors=include_vectors)

    def delete_index(self, name: str) -> dict[str, Any]:
        """Delete an index in the Marqo client.

        Args:
            name: The name of the index to delete.
        """
        return self.client.delete_index(name)

    def get_indexes(self) -> list[str]:
        """Get a list of all indexes in the Marqo client.

        Returns:
            The list of all indexes.
        """
        return [index["index"] for index in self.client.get_indexes()["results"]]

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Upsert a vector into the Marqo index.

        Args:
            vector: The vector to be indexed.
            vector_id: The ID for the vector. If None, Marqo will generate an ID.
            namespace: An optional namespace for the vector.
            meta: An optional dictionary of metadata for the vector.
            kwargs: Additional keyword arguments to pass to the Marqo client.

        Raises:
            NotImplementedError: Always; upserting raw vectors is not supported by this driver.

        Returns:
            The ID of the vector that was added.
        """
        raise NotImplementedError(f"{self.__class__.__name__} does not support upserting a vector.")

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")

    def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]:
        # When vectors are requested, merge each hit with its full document (facets exposed) so the
        # embedding is available; results["hits"] is replaced in place.
        if include_vectors:
            results["hits"] = [
                {**r, **self.client.index(self.index).get_document(r["_id"], expose_facets=True)}
                for r in results["hits"]
            ]

        return [
            BaseVectorStoreDriver.Entry(
                id=r["_id"],
                vector=r["_tensor_facets"][0]["_embedding"] if include_vectors else [],
                score=r["_score"],
                meta={k: v for k, v in r.items() if k not in ["_score", "_tensor_facets"]},
            )
            for r in results["hits"]
        ]
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • url = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__process_results(results, *, include_vectors)

Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def __process_results(self, results: dict, *, include_vectors: bool) -> list[BaseVectorStoreDriver.Entry]:
    """Convert a Marqo search response into vector store entries.

    Args:
        results: The response returned by the Marqo index ``search`` call; its ``"hits"`` list is read
            and is not modified.
        include_vectors: When True, each hit is re-fetched with ``expose_facets=True`` so the embedding
            of its first tensor facet can be attached to the entry.

    Returns:
        One Entry per hit, in the order the hits were returned.
    """
    hits = results["hits"]
    if include_vectors and hits:
        # Re-fetch every hit with its facets exposed to obtain "_tensor_facets". Build a new list
        # rather than overwriting results["hits"], and resolve the index handle only once.
        marqo_index = self.client.index(self.index)
        hits = [{**hit, **marqo_index.get_document(hit["_id"], expose_facets=True)} for hit in hits]

    return [
        BaseVectorStoreDriver.Entry(
            id=hit["_id"],
            vector=hit["_tensor_facets"][0]["_embedding"] if include_vectors else [],
            score=hit["_score"],
            meta={k: v for k, v in hit.items() if k not in ["_score", "_tensor_facets"]},
            # Documents are stored with a "namespace" field by upsert; surface it on the entry.
            namespace=hit.get("namespace"),
        )
        for hit in hits
    ]

client()

Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
@lazy_property()
def client(self) -> marqo.Client:
    """Build the Marqo client on first access from the configured URL and API key."""
    marqo_sdk = import_optional_dependency("marqo")
    return marqo_sdk.Client(self.url, api_key=self.api_key)

delete_index(name)

Delete an index in the Marqo client.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `name` | `str` | The name of the index to delete. | required |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_index(self, name: str) -> dict[str, Any]:
    """Remove the named index from the Marqo instance.

    Args:
        name: Name of the index to remove.

    Returns:
        The value returned by the Marqo client's ``delete_index`` call.
    """
    marqo_client = self.client
    response = marqo_client.delete_index(name)
    return response

delete_vector(vector_id)

Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Reject deletion requests; this driver cannot delete individual vectors.

    Args:
        vector_id: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    driver_name = type(self).__name__
    raise NotImplementedError(f"{driver_name} does not support deletion.")

get_indexes()

Get a list of all indexes in the Marqo client.

Returns

| Type | Description |
|------|-------------|
| `list[str]` | The list of all indexes. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def get_indexes(self) -> list[str]:
    """List the names of every index known to the Marqo client.

    Returns:
        The ``"index"`` value of each entry under ``"results"`` in the client's response, in order.
    """
    index_names = []
    for listing in self.client.get_indexes()["results"]:
        index_names.append(listing["index"])
    return index_names

load_entries(*, namespace=None)

Load all document entries from the Marqo index.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `namespace` | `Optional[str]` | The namespace to filter entries by. | `None` |

Returns

| Type | Description |
|------|-------------|
| `list[Entry]` | The list of loaded Entries. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load all document entries from the Marqo index.

    Runs an empty-query search (limit 10000) to collect document IDs, then fetches those
    documents with their facets exposed so each entry carries its embedding.

    Args:
        namespace: If given, only documents whose ``namespace`` field matches are returned.

    Returns:
        The list of loaded Entries; documents the fetch reports as not found are skipped.
    """
    marqo_index = self.client.index(self.index)

    # Only send a filter string when a namespace was requested.
    search_params = {"limit": 10000}
    if namespace:
        search_params["filter_string"] = f"namespace:{namespace}"
    results = marqo_index.search("", **search_params)

    ids = [hit["_id"] for hit in results["hits"]]
    if not ids:
        # Nothing matched: skip the get_documents round trip with an empty ID list.
        return []

    documents = marqo_index.get_documents(document_ids=ids, expose_facets=True)

    return [
        BaseVectorStoreDriver.Entry(
            id=doc["_id"],
            vector=doc["_tensor_facets"][0]["_embedding"],
            meta={k: v for k, v in doc.items() if k not in ["_id", "_tensor_facets", "_found"]},
            namespace=doc.get("namespace"),
        )
        for doc in documents["results"]
        if doc["_found"]
    ]

load_entry(vector_id, *, namespace=None)

Load a document entry from the Marqo index.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `vector_id` | `str` | The ID of the vector to load. | required |
| `namespace` | `Optional[str]` | The namespace of the vector to load. | `None` |

Returns

| Type | Description |
|------|-------------|
| `Optional[Entry]` | The loaded Entry if found, otherwise None. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Load a document entry from the Marqo index.

    Args:
        vector_id: The ID of the vector to load.
        namespace: If given, the document is only returned when its stored ``namespace`` field matches.

    Returns:
        The loaded Entry if a document with at least one tensor facet is found (and, when a namespace
        is requested, it matches); otherwise None.
    """
    result = self.client.index(self.index).get_document(document_id=vector_id, expose_facets=True)

    if not result or not result.get("_tensor_facets"):
        return None
    # Honor the namespace argument using the same stored field that load_entries filters on.
    if namespace and result.get("namespace") != namespace:
        return None

    return BaseVectorStoreDriver.Entry(
        id=result["_id"],
        vector=result["_tensor_facets"][0]["_embedding"],
        # The embedding is already exposed via `vector`; keep the raw facets out of meta.
        meta={k: v for k, v in result.items() if k not in ["_id", "_tensor_facets"]},
        namespace=result.get("namespace"),
    )

query(query, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)

Query the Marqo index for documents.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `query` | `str \| TextArtifact \| ImageArtifact` | The query string. | required |
| `count` | `Optional[int]` | The maximum number of results to return. | `None` |
| `namespace` | `Optional[str]` | The namespace to filter results by. | `None` |
| `include_vectors` | `bool` | Whether to include vector data in the results. | `False` |
| `include_metadata` | `bool` | Whether to include metadata in the results. | `True` |
| `kwargs` | `Any` | Additional keyword arguments to pass to the Marqo client. | `{}` |

Returns

| Type | Description |
|------|-------------|
| `list[Entry]` | The list of query results. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    **kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
    """Run a text search against the Marqo index.

    Args:
        query: The value to search for; it is converted with ``str()`` before being sent.
        count: Maximum number of hits; falls back to ``BaseVectorStoreDriver.DEFAULT_QUERY_COUNT``.
        namespace: When set, restricts hits to documents whose ``namespace`` field matches.
        include_vectors: Whether each returned entry should carry its embedding.
        include_metadata: When False, only ``_id`` is requested from Marqo.
        kwargs: Extra search arguments; they override the defaults computed here.

    Returns:
        The matching entries.
    """
    limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    attributes = None if include_metadata else ["_id"]
    namespace_filter = f"namespace:{namespace}" if namespace else None

    # Caller-supplied kwargs take precedence over the computed defaults.
    search_kwargs = {"limit": limit, "attributes_to_retrieve": attributes, "filter_string": namespace_filter}
    search_kwargs.update(kwargs)

    marqo_index = self.client.index(self.index)
    raw_results = marqo_index.search(str(query), **search_kwargs)
    return self.__process_results(raw_results, include_vectors=include_vectors)

query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)

Query the Marqo index for documents.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `vector` | `list[float]` | The vector to query by. | required |
| `count` | `Optional[int]` | The maximum number of results to return. | `None` |
| `namespace` | `Optional[str]` | The namespace to filter results by. | `None` |
| `include_vectors` | `bool` | Whether to include vector data in the results. | `False` |
| `include_metadata` | `bool` | Whether to include metadata in the results. | `True` |
| `kwargs` | `Any` | Additional keyword arguments to pass to the Marqo client. | `{}` |

Returns

| Type | Description |
|------|-------------|
| `list[Entry]` | The list of query results. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    **kwargs: Any,
) -> list[BaseVectorStoreDriver.Entry]:
    """Query the Marqo index by a raw embedding vector.

    Args:
        vector: The vector to query by.
        count: The maximum number of results to return.
        namespace: The namespace to filter results by.
        include_vectors: Whether to include vector data in the results.
        include_metadata: Whether to include metadata in the results.
        kwargs: Additional keyword arguments to pass to the Marqo client.

    Returns:
        The list of query results.
    """
    params = {
        "limit": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        "attributes_to_retrieve": None if include_metadata else ["_id"],
        "filter_string": f"namespace:{namespace}" if namespace else None,
    } | kwargs

    # Marqo's search context takes a list of {"vector": ..., "weight": ...} objects under "tensor";
    # the weight belongs to each vector entry, not to the context object itself.
    context = {"tensor": [{"vector": vector, "weight": 1}]}

    results = self.client.index(self.index).search(**params, context=context)

    return self.__process_results(results, include_vectors=include_vectors)

upsert(value, *, namespace=None, meta=None, vector_id=None, **kwargs)

Upsert a text document into the Marqo index.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `value` | `str \| TextArtifact \| ImageArtifact` | The value to be indexed. | required |
| `namespace` | `Optional[str]` | An optional namespace for the document. | `None` |
| `meta` | `Optional[dict]` | An optional dictionary of metadata for the document. | `None` |
| `vector_id` | `Optional[str]` | The ID for the vector. If None, Marqo will generate an ID. | `None` |
| `kwargs` | `Any` | Additional keyword arguments to pass to the Marqo client. | `{}` |

Returns

| Name | Type | Description |
|------|------|-------------|
| `str` | `str` | The ID of the document that was added. |
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert(
    self,
    value: str | TextArtifact | ImageArtifact,
    *,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs: Any,
) -> str:
    """Upsert a text document into the Marqo index.

    Args:
        value: The value to be indexed. Plain strings and TextArtifacts are supported.
        namespace: An optional namespace, stored in the document's ``namespace`` field.
        meta: An optional dictionary of metadata, stored as its ``str()`` form in a ``meta`` field.
        vector_id: The document ID. For a TextArtifact it defaults to a hash of the text; for a
            plain string it is omitted when None, leaving ID assignment to Marqo.
        kwargs: Accepted for interface compatibility; not forwarded to the Marqo client.

    Returns:
        str: The ID of the document that was added.

    Raises:
        NotImplementedError: If ``value`` is an ImageArtifact.
        ValueError: If Marqo's response reports errors or lists no added item.
    """
    if isinstance(value, TextArtifact):
        artifact_json = value.to_json()
        vector_id = utils.str_to_hash(value.value) if vector_id is None else vector_id

        doc = {
            "_id": vector_id,
            "Description": value.value,
            "artifact": str(artifact_json),
        }
    elif isinstance(value, ImageArtifact):
        raise NotImplementedError("`MarqoVectorStoreDriver` does not support upserting Image Artifacts.")
    elif vector_id is None:
        # Do not send "_id": None; omitting the key lets Marqo assign an ID as documented above.
        doc = {"Description": value}
    else:
        doc = {"_id": vector_id, "Description": value}

    # Non-tensor fields
    if meta:
        doc["meta"] = str(meta)
    if namespace:
        doc["namespace"] = namespace

    response = self.client.index(self.index).add_documents([doc], tensor_fields=["Description"])
    # Treat a response flagged with errors as a failure rather than returning the failed item's ID.
    if isinstance(response, dict) and not response.get("errors") and response.get("items"):
        return response["items"][0]["_id"]
    raise ValueError(f"Failed to upsert text: {response}")

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Upsert a vector into the Marqo index.

Parameters

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `vector` | `list[float]` | The vector to be indexed. | required |
| `vector_id` | `Optional[str]` | The ID for the vector. If None, Marqo will generate an ID. | `None` |
| `namespace` | `Optional[str]` | An optional namespace for the vector. | `None` |
| `meta` | `Optional[dict]` | An optional dictionary of metadata for the vector. | `None` |
| `kwargs` | `Any` | Additional keyword arguments to pass to the Marqo client. | `{}` |

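Raises

| Type | Description |
|------|-------------|
| `NotImplementedError` | Always raised; this driver does not support upserting a vector. |

Returns

| Type | Description |
|------|-------------|
| `str` | The ID of the vector that was added (never reached, because the method always raises). |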
Source Code in griptape/drivers/vector/marqo_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs: Any,
) -> str:
    """Reject raw-vector upserts; this driver only indexes values through ``upsert``.

    Args:
        vector: Ignored.
        vector_id: Ignored.
        namespace: Ignored.
        meta: Ignored.
        kwargs: Ignored.

    Raises:
        NotImplementedError: Always.

    Returns:
        Never returns normally, despite the ``str`` annotation.
    """
    message = f"{type(self).__name__} does not support upserting a vector."
    raise NotImplementedError(message)

MongoDbAtlasVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
|------|------|-------------|
| `connection_string` | `str` | The connection string for the MongoDb Atlas cluster. |
| `database_name` | `str` | The name of the database to use. |
| `collection_name` | `str` | The name of the collection to use. |
| `index_name` | `str` | The name of the index to use. |
| `vector_path` | `str` | The path to the vector field in the collection. |
| `client` | `MongoClient` | An optional MongoDb client to use. Defaults to a new client using the connection string. |
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@define
class MongoDbAtlasVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for MongoDb Atlas.

    Attributes:
        connection_string: The connection string for the MongoDb Atlas cluster.
        database_name: The name of the database to use.
        collection_name: The name of the collection to use.
        index_name: The name of the index to use.
        vector_path: The path to the vector field in the collection.
        num_candidates_multiplier: Multiplied by the requested result count to compute the
            ``numCandidates`` value sent to ``$vectorSearch`` (capped at ``MAX_NUM_CANDIDATES``).
        client: An optional MongoDb client to use. Defaults to a new client using the connection string.
    """

    # Upper bound on the numCandidates value sent to $vectorSearch.
    MAX_NUM_CANDIDATES = 10000

    connection_string: str = field(kw_only=True, metadata={"serializable": True})
    database_name: str = field(kw_only=True, metadata={"serializable": True})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    vector_path: str = field(kw_only=True, metadata={"serializable": True})
    num_candidates_multiplier: int = field(
        default=10,
        kw_only=True,
        metadata={"serializable": True},
    )  # https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#fields
    _client: Optional[MongoClient] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> MongoClient:
        """Lazily create a pymongo ``MongoClient`` from ``connection_string``."""
        return import_optional_dependency("pymongo").MongoClient(self.connection_string)

    def get_collection(self) -> Collection:
        """Returns the MongoDB Collection instance for the specified database and collection name."""
        return self.client[self.database_name][self.collection_name]

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.

        Args:
            vector: The embedding stored under ``vector_path``.
            vector_id: The document ``_id``. When None, a new document is inserted and its ID is returned.
            namespace: Stored in the document's ``namespace`` field.
            meta: Stored in the document's ``meta`` field.
            kwargs: Unused.

        Returns:
            The document ID as a string.
        """
        collection = self.get_collection()

        if vector_id is None:
            # NOTE(review): MongoDB assigns the _id here (typically an ObjectId); the str() form
            # returned may not match load_entry/delete_vector, which filter on the raw string — confirm.
            result = collection.insert_one({self.vector_path: vector, "namespace": namespace, "meta": meta})
            vector_id = str(result.inserted_id)
        else:
            collection.replace_one(
                {"_id": vector_id},
                {self.vector_path: vector, "namespace": namespace, "meta": meta},
                upsert=True,
            )
        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Loads a document entry from the MongoDB collection based on the vector ID.

        Args:
            vector_id: The ``_id`` to look up.
            namespace: If given, the document must also have this ``namespace``.

        Returns:
            The loaded Entry if found; otherwise, None is returned.
        """
        collection = self.get_collection()
        if namespace:
            doc = collection.find_one({"_id": vector_id, "namespace": namespace})
        else:
            doc = collection.find_one({"_id": vector_id})

        if doc is None:
            return None
        return BaseVectorStoreDriver.Entry(
            id=str(doc["_id"]),
            vector=doc[self.vector_path],
            namespace=doc["namespace"],
            meta=doc["meta"],
        )

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Loads all document entries from the MongoDB collection.

        Entries can optionally be filtered by namespace.
        """
        collection = self.get_collection()
        cursor = collection.find() if namespace is None else collection.find({"namespace": namespace})

        return [
            BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]),
                vector=doc[self.vector_path],
                namespace=doc["namespace"],
                meta=doc["meta"],
            )
            for doc in cursor
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        offset: Optional[int] = None,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Queries the MongoDB collection for documents that match the provided vector list.

        Args:
            vector: The query embedding.
            count: Maximum number of results; defaults to ``BaseVectorStoreDriver.DEFAULT_QUERY_COUNT``.
            namespace: If given, restricts the search to documents with this ``namespace``.
            include_vectors: Whether each entry should carry the stored vector.
            offset: Accepted for interface compatibility; it is not applied to the pipeline.
            kwargs: Unused.

        Returns:
            Entries with their ``vectorSearchScore`` and each document's stored namespace.
        """
        collection = self.get_collection()

        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

        pipeline = [
            {
                "$vectorSearch": {
                    "index": self.index_name,
                    "path": self.vector_path,
                    "queryVector": vector,
                    "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                    "limit": count,
                },
            },
            {
                "$project": {
                    "_id": 1,
                    self.vector_path: 1,
                    "namespace": 1,
                    "meta": 1,
                    "score": {"$meta": "vectorSearchScore"},
                },
            },
        ]

        if namespace:
            pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace}

        return [
            BaseVectorStoreDriver.Entry(
                id=str(doc["_id"]),
                vector=doc[self.vector_path] if include_vectors else [],
                score=doc["score"],
                meta=doc["meta"],
                # Use the document's own (projected) namespace, not the filter argument, so
                # unfiltered queries still report where each entry lives.
                namespace=doc.get("namespace"),
            )
            for doc in collection.aggregate(pipeline)
        ]

    def delete_vector(self, vector_id: str) -> None:
        """Deletes the vector from the collection."""
        collection = self.get_collection()
        collection.delete_one({"_id": vector_id})
  • MAX_NUM_CANDIDATES = 10000 class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • collection_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • connection_string = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • database_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • num_candidates_multiplier = field(default=10, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • vector_path = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
@lazy_property()
def client(self) -> MongoClient:
    """Lazily create a pymongo ``MongoClient`` from ``connection_string``."""
    pymongo_module = import_optional_dependency("pymongo")
    return pymongo_module.MongoClient(self.connection_string)

delete_vector(vector_id)

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Delete the single document whose ``_id`` equals ``vector_id`` via ``delete_one``."""
    self.get_collection().delete_one({"_id": vector_id})

get_collection()

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def get_collection(self) -> Collection:
    """Look up ``collection_name`` inside ``database_name`` on the client."""
    database = self.client[self.database_name]
    return database[self.collection_name]

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Read every document in the collection as an Entry.

    Args:
        namespace: When not None, only documents whose ``namespace`` field equals it are read.

    Returns:
        One Entry per matching document, in cursor order.
    """
    collection = self.get_collection()
    if namespace is None:
        cursor = collection.find()
    else:
        cursor = collection.find({"namespace": namespace})

    entries = []
    for document in cursor:
        entries.append(
            BaseVectorStoreDriver.Entry(
                id=str(document["_id"]),
                vector=document[self.vector_path],
                namespace=document["namespace"],
                meta=document["meta"],
            )
        )
    return entries

load_entry(vector_id, *, namespace=None)

Loads a document entry from the MongoDB collection based on the vector ID.

Returns

| Type | Description |
|------|-------------|
| `Optional[Entry]` | The loaded Entry if found; otherwise, None is returned. |
Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Fetch one document by ``_id`` and convert it into an Entry.

    Args:
        vector_id: The ``_id`` to look up.
        namespace: If truthy, the document must also carry this ``namespace``.

    Returns:
        The matching Entry, or None when no document matches.
    """
    lookup = {"_id": vector_id}
    if namespace:
        lookup["namespace"] = namespace

    document = self.get_collection().find_one(lookup)
    if document is None:
        return None

    return BaseVectorStoreDriver.Entry(
        id=str(document["_id"]),
        vector=document[self.vector_path],
        namespace=document["namespace"],
        meta=document["meta"],
    )

query_vector(vector, *, count=None, namespace=None, include_vectors=False, offset=None, **kwargs)

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    offset: Optional[int] = None,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Queries the MongoDB collection for documents that match the provided vector list.

    Args:
        vector: The query embedding.
        count: Maximum number of results; defaults to ``BaseVectorStoreDriver.DEFAULT_QUERY_COUNT``.
        namespace: If given, restricts the search to documents with this ``namespace``.
        include_vectors: Whether each entry should carry the stored vector.
        offset: Accepted for interface compatibility; it is not applied to the pipeline.
        kwargs: Unused.

    Returns:
        Entries with their ``vectorSearchScore`` and each document's stored namespace.
    """
    collection = self.get_collection()

    count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

    pipeline = [
        {
            "$vectorSearch": {
                "index": self.index_name,
                "path": self.vector_path,
                "queryVector": vector,
                "numCandidates": min(count * self.num_candidates_multiplier, self.MAX_NUM_CANDIDATES),
                "limit": count,
            },
        },
        {
            "$project": {
                "_id": 1,
                self.vector_path: 1,
                "namespace": 1,
                "meta": 1,
                "score": {"$meta": "vectorSearchScore"},
            },
        },
    ]

    if namespace:
        pipeline[0]["$vectorSearch"]["filter"] = {"namespace": namespace}

    return [
        BaseVectorStoreDriver.Entry(
            id=str(doc["_id"]),
            vector=doc[self.vector_path] if include_vectors else [],
            score=doc["score"],
            meta=doc["meta"],
            # Use the document's own (projected) namespace, not the filter argument, so
            # unfiltered queries still report where each entry lives.
            namespace=doc.get("namespace"),
        )
        for doc in collection.aggregate(pipeline)
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/mongodb_atlas_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Store a vector document, replacing any existing document with the same ID.

    Args:
        vector: The embedding stored under ``vector_path``.
        vector_id: The document ``_id``. When None, a new document is inserted and its ID returned.
        namespace: Stored in the document's ``namespace`` field.
        meta: Stored in the document's ``meta`` field.
        kwargs: Unused.

    Returns:
        The document ID as a string.
    """
    collection = self.get_collection()
    document = {self.vector_path: vector, "namespace": namespace, "meta": meta}

    if vector_id is not None:
        collection.replace_one({"_id": vector_id}, document, upsert=True)
        return vector_id

    # NOTE(review): MongoDB assigns the _id here (typically an ObjectId); its str() form may not
    # match later lookups that filter on the raw string — confirm.
    insert_result = collection.insert_one(document)
    return str(insert_result.inserted_id)

NoOpObservabilityDriver

Bases: BaseObservabilityDriver

Source Code in griptape/drivers/observability/no_op_observability_driver.py
@define
class NoOpObservabilityDriver(BaseObservabilityDriver):
    """Observability driver that records nothing: calls run unwrapped and no span ID is reported."""

    def observe(self, call: Observable.Call) -> Any:
        """Invoke ``call`` directly and return its result unchanged."""
        outcome = call()
        return outcome

    def get_span_id(self) -> Optional[str]:
        """Report that there is no active span by returning None."""
        span_id = None
        return span_id

get_span_id()

Source Code in griptape/drivers/observability/no_op_observability_driver.py
def get_span_id(self) -> Optional[str]:
    """Report that there is no active span by returning None."""
    span_id = None
    return span_id

observe(call)

Source Code in griptape/drivers/observability/no_op_observability_driver.py
def observe(self, call: Observable.Call) -> Any:
    """Invoke ``call`` directly and return its result unchanged."""
    outcome = call()
    return outcome

OllamaEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
|------|------|-------------|
| `model` | `str` | Ollama embedding model name. |
| `host` | `Optional[str]` | Optional Ollama host. |
| `client` | `Client` | Ollama Client. |
Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
@define
class OllamaEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver backed by an Ollama server.

    Attributes:
        model: Ollama embedding model name.
        host: Optional Ollama host; passed to the Ollama ``Client`` when it is created.
        client: Ollama `Client`. Created lazily from ``host`` when not supplied.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        """Create the Ollama client for ``host`` on first access."""
        ollama_module = import_optional_dependency("ollama")
        return ollama_module.Client(host=self.host)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed ``chunk`` with the configured model and return the embedding as a list."""
        response = self.client.embeddings(model=self.model, prompt=chunk)
        embedding = response["embedding"]
        return list(embedding)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • host = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
@lazy_property()
def client(self) -> Client:
    """Create the Ollama client for ``host`` on first access."""
    ollama_module = import_optional_dependency("ollama")
    return ollama_module.Client(host=self.host)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/ollama_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed ``chunk`` with the configured model and return the embedding as a list."""
    response = self.client.embeddings(model=self.model, prompt=chunk)
    embedding = response["embedding"]
    return list(embedding)

OllamaPromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
|------|------|-------------|
| `model` | `str` | Model name. |
Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
@define
class OllamaPromptDriver(BasePromptDriver):
    """Ollama Prompt Driver.

    Attributes:
        model: Model name.
        host: Optional Ollama host, passed to `ollama.Client(host=...)` when the client is built lazily.
        tokenizer: Tokenizer; defaults to a `SimpleTokenizer` with 4 characters per token, 2000 max input
            tokens and `max_tokens` max output tokens.
        options: Ollama request `options`; defaults to `temperature`, the tokenizer's stop sequences as
            `stop`, and `max_tokens` as `num_predict`.
        use_native_tools: Whether tool definitions are sent to Ollama through the request's `tools` key.
        client: Ollama `Client`; built lazily from `host` when one is not supplied.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: SimpleTokenizer(
                characters_per_token=4,
                max_input_tokens=2000,
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    options: dict = field(
        default=Factory(
            lambda self: {
                "temperature": self.temperature,
                "stop": self.tokenizer.stop_sequences,
                "num_predict": self.max_tokens,
            },
            takes_self=True,
        ),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Client:
        return import_optional_dependency("ollama").Client(host=self.host)

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send `prompt_stack` in one non-streaming chat call and return the assistant reply."""
        params = self._base_params(prompt_stack)
        logger.debug(params)
        response = self.client.chat(**params)
        logger.debug(response.model_dump())

        return Message(
            content=self.__to_prompt_stack_message_content(response),
            role=Message.ASSISTANT_ROLE,
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream the chat response for `prompt_stack` as `DeltaMessage`s."""
        params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(params)
        stream: Iterator = self.client.chat(**params)

        tool_index = 0
        for chunk in stream:
            logger.debug(chunk)
            message_content = self.__to_prompt_stack_delta_message_content(chunk)
            # Ollama provides multiple Tool calls as separate chunks but with no index to differentiate them.
            # So we must keep track of the index ourselves.
            if isinstance(message_content, ActionCallDeltaMessageContent):
                message_content.index = tool_index
                tool_index += 1
            yield DeltaMessage(content=message_content)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments for `Client.chat`.

        `extra_params` is spread last, so it can override `messages`, `model` and `options`.
        """
        messages = self._prompt_stack_to_messages(prompt_stack)

        params = {
            "messages": messages,
            "model": self.model,
            "options": self.options,
            **self.extra_params,
        }

        if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
            params["format"] = prompt_stack.to_output_json_schema()

        # Native tool definitions are sent for both streaming and non-streaming requests; streamed
        # tool calls are converted in `__to_prompt_stack_delta_message_content`.
        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_ollama_tools(prompt_stack.tools)

        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        """Convert the prompt stack's messages into Ollama chat message dicts."""
        ollama_messages = []
        for message in prompt_stack.messages:
            action_result_contents = message.get_content_type(ActionResultMessageContent)

            # Function calls need to be handled separately from the rest of the message content
            if action_result_contents:
                ollama_messages.extend(
                    [
                        {
                            "role": self.__to_ollama_role(message, action_result_content),
                            "content": self.__to_ollama_message_content(action_result_content),
                        }
                        for action_result_content in action_result_contents
                    ],
                )

                text_contents = message.get_content_type(TextMessageContent)
                if text_contents:
                    ollama_messages.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
            else:
                ollama_message: dict[str, Any] = {
                    "role": self.__to_ollama_role(message),
                    "content": message.to_text(),
                }

                action_call_contents = message.get_content_type(ActionCallMessageContent)
                if action_call_contents:
                    ollama_message["tool_calls"] = [
                        self.__to_ollama_message_content(action_call_content)
                        for action_call_content in action_call_contents
                    ]

                image_contents = message.get_content_type(ImageMessageContent)
                if image_contents:
                    ollama_message["images"] = [
                        self.__to_ollama_message_content(image_content) for image_content in image_contents
                    ]

                ollama_messages.append(ollama_message)

        return ollama_messages

    def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
        if isinstance(content, TextMessageContent):
            return content.artifact.to_text()
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return content.artifact.base64
            # TODO: add support for image urls once resolved https://github.com/ollama/ollama/issues/4474
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {"name": action.to_native_tool_name(), "arguments": action.input},
            }
        if isinstance(content, ActionResultMessageContent):
            return content.artifact.to_text()
        raise ValueError(f"Unsupported content type: {type(content)}")

    def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
        ollama_tools = []

        for tool in tools:
            for activity in tool.activities():
                ollama_tool = {
                    "function": {
                        "name": tool.to_native_tool_name(activity),
                        "description": tool.activity_description(activity),
                    },
                    "type": "function",
                }

                activity_schema = tool.activity_schema(activity)
                if activity_schema is not None:
                    ollama_tool["function"]["parameters"] = tool.to_activity_json_schema(activity, "Parameters Schema")[
                        "properties"
                    ]["values"]

                ollama_tools.append(ollama_tool)
        return ollama_tools

    def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        if message.is_system():
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_prompt_stack_message_content(self, response: ChatResponse) -> list[BaseMessageContent]:
        """Convert a non-streaming chat response into text and action-call message contents."""
        content: list[BaseMessageContent] = []
        message = response["message"]

        if message.get("content"):
            content.append(TextMessageContent(TextArtifact(message["content"])))

        # `.get` treats a missing field and a `tool_calls` field explicitly set to None the same way;
        # a membership check alone would let None through to the loop and raise TypeError.
        tool_calls = message.get("tool_calls")
        if tool_calls:
            for tool_call in tool_calls:
                native_name = tool_call["function"]["name"]
                parsed_name = ToolAction.from_native_tool_name(native_name)
                content.append(
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=native_name,
                                name=parsed_name[0],
                                path=parsed_name[1],
                                input=tool_call["function"]["arguments"],
                            ),
                        ),
                    )
                )

        return content

    def __to_prompt_stack_delta_message_content(self, content_delta: ChatResponse) -> BaseDeltaMessageContent:
        """Convert one streamed chunk into a text delta or an action-call delta."""
        message = content_delta["message"]
        if message.get("content"):
            return TextDeltaMessageContent(message["content"])

        # `.get` also covers a `tool_calls` field that is present but None (where `len` would raise).
        tool_calls = message.get("tool_calls")
        if tool_calls:
            # Ollama doesn't _really_ support Tool streaming. They provide the full tool call at once.
            # Multiple, parallel, Tool calls are provided as multiple content deltas.
            # Tracking here: https://github.com/ollama/ollama/issues/7886
            # NOTE(review): only the first tool call of a chunk is used; confirm Ollama never batches them.
            tool_call = tool_calls[0]
            native_name = tool_call["function"]["name"]
            parsed_name = ToolAction.from_native_tool_name(native_name)

            return ActionCallDeltaMessageContent(
                tag=native_name,
                name=parsed_name[0],
                path=parsed_name[1],
                partial_input=json.dumps(tool_call["function"]["arguments"]),
            )
        return TextDeltaMessageContent("")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • host = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • options = field(default=Factory(lambda self: {'temperature': self.temperature, 'stop': self.tokenizer.stop_sequences, 'num_predict': self.max_tokens}, takes_self=True), kw_only=True) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_ollama_message_content(content)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def __to_ollama_message_content(self, content: BaseMessageContent) -> str | dict:
    """Convert one Griptape message content into its Ollama representation.

    Text and action-result contents become their artifact text, image contents become the
    artifact's base64 string, and action-call contents become an Ollama tool-call dict.

    Raises:
        ValueError: If the content type, or an image content's artifact type, is unsupported.
    """
    if isinstance(content, TextMessageContent):
        return content.artifact.to_text()

    if isinstance(content, ImageMessageContent):
        image = content.artifact
        if not isinstance(image, ImageArtifact):
            # TODO: add support for image urls once resolved https://github.com/ollama/ollama/issues/4474
            raise ValueError(f"Unsupported image artifact type: {type(image)}")
        return image.base64

    if isinstance(content, ActionCallMessageContent):
        tool_action = content.artifact.value
        function_spec = {"name": tool_action.to_native_tool_name(), "arguments": tool_action.input}
        return {"type": "function", "id": tool_action.tag, "function": function_spec}

    if isinstance(content, ActionResultMessageContent):
        return content.artifact.to_text()

    raise ValueError(f"Unsupported content type: {type(content)}")

__to_ollama_role(message, message_content=None)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def __to_ollama_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    """Pick the Ollama chat role for `message`.

    System and assistant messages keep their role. Any other message is a "tool" message
    when `message_content` is an action result, and a "user" message otherwise.
    """
    if message.is_system():
        role = "system"
    elif message.is_assistant():
        role = "assistant"
    elif isinstance(message_content, ActionResultMessageContent):
        role = "tool"
    else:
        role = "user"
    return role

__to_ollama_tools(tools)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def __to_ollama_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Build one Ollama function-tool definition per activity of every tool.

    A `parameters` entry is added only for activities whose `activity_schema` is not None.
    It holds the `properties.values` part of the activity's JSON schema.
    """
    definitions: list[dict] = []
    for tool in tools:
        for activity in tool.activities():
            function_def = {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
            }
            if tool.activity_schema(activity) is not None:
                json_schema = tool.to_activity_json_schema(activity, "Parameters Schema")
                function_def["parameters"] = json_schema["properties"]["values"]
            definitions.append({"function": function_def, "type": "function"})
    return definitions

__to_prompt_stack_delta_message_content(content_delta)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content_delta: ChatResponse) -> BaseDeltaMessageContent:
    """Convert one streamed Ollama chunk into a text delta or an action-call delta.

    Returns an empty text delta when the chunk carries neither content nor tool calls.
    """
    message = content_delta["message"]
    if message.get("content"):
        return TextDeltaMessageContent(message["content"])

    # `.get` also covers a `tool_calls` field that is present but None (where `len` would raise).
    tool_calls = message.get("tool_calls")
    if tool_calls:
        # Ollama doesn't _really_ support Tool streaming. They provide the full tool call at once.
        # Multiple, parallel, Tool calls are provided as multiple content deltas.
        # Tracking here: https://github.com/ollama/ollama/issues/7886
        # NOTE(review): only the first tool call of a chunk is used; confirm Ollama never batches them.
        tool_call = tool_calls[0]
        native_name = tool_call["function"]["name"]
        parsed_name = ToolAction.from_native_tool_name(native_name)

        return ActionCallDeltaMessageContent(
            tag=native_name,
            name=parsed_name[0],
            path=parsed_name[1],
            partial_input=json.dumps(tool_call["function"]["arguments"]),
        )
    return TextDeltaMessageContent("")

__to_prompt_stack_message_content(response)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def __to_prompt_stack_message_content(self, response: ChatResponse) -> list[BaseMessageContent]:
    """Convert a non-streaming Ollama chat response into Griptape message contents.

    The reply text, if any, comes first, followed by one action-call content per tool call.
    """
    content: list[BaseMessageContent] = []
    message = response["message"]

    if message.get("content"):
        content.append(TextMessageContent(TextArtifact(message["content"])))

    # `.get` treats a missing field and a `tool_calls` field explicitly set to None the same way;
    # a membership check alone would let None through to the loop and raise TypeError.
    tool_calls = message.get("tool_calls")
    if tool_calls:
        for tool_call in tool_calls:
            native_name = tool_call["function"]["name"]
            parsed_name = ToolAction.from_native_tool_name(native_name)
            content.append(
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=native_name,
                            name=parsed_name[0],
                            path=parsed_name[1],
                            input=tool_call["function"]["arguments"],
                        ),
                    ),
                )
            )

    return content

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments passed to Ollama's `Client.chat`.

    `extra_params` is spread last, so it can override `messages`, `model` and `options`.
    `format` is added only for an output schema with the "native" structured-output
    strategy, and `tools` only when the stack has tools and `use_native_tools` is set.
    """
    messages = self._prompt_stack_to_messages(prompt_stack)

    params = {
        "messages": messages,
        "model": self.model,
        "options": self.options,
        **self.extra_params,
    }

    if prompt_stack.output_schema is not None and self.structured_output_strategy == "native":
        params["format"] = prompt_stack.to_output_json_schema()

    # Native tool definitions are sent for both streaming and non-streaming requests; streamed
    # tool calls are converted in `__to_prompt_stack_delta_message_content`.
    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_ollama_tools(prompt_stack.tools)

    return params

_prompt_stack_to_messages(prompt_stack)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    """Translate each message of `prompt_stack` into Ollama chat message dicts.

    A message carrying action results is split: every action result becomes its own
    message, followed by one message with the message's text when it has text contents.
    Any other message becomes a single dict with `role` and `content`, plus `tool_calls`
    and/or `images` when it has action calls or images.
    """
    converted: list[dict] = []

    for message in prompt_stack.messages:
        action_results = message.get_content_type(ActionResultMessageContent)

        if action_results:
            # Tool results go to Ollama as separate messages, one per result.
            for action_result in action_results:
                converted.append(
                    {
                        "role": self.__to_ollama_role(message, action_result),
                        "content": self.__to_ollama_message_content(action_result),
                    }
                )
            if message.get_content_type(TextMessageContent):
                converted.append({"role": self.__to_ollama_role(message), "content": message.to_text()})
            continue

        entry: dict[str, Any] = {"role": self.__to_ollama_role(message), "content": message.to_text()}

        action_calls = message.get_content_type(ActionCallMessageContent)
        if action_calls:
            entry["tool_calls"] = [self.__to_ollama_message_content(call) for call in action_calls]

        images = message.get_content_type(ImageMessageContent)
        if images:
            entry["images"] = [self.__to_ollama_message_content(image) for image in images]

        converted.append(entry)

    return converted

client()

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
@lazy_property()
def client(self) -> Client:
    """Lazily build an Ollama `Client` that talks to `self.host`."""
    ollama_module = import_optional_dependency("ollama")
    return ollama_module.Client(host=self.host)

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send `prompt_stack` to Ollama in one non-streaming chat call.

    Returns:
        An assistant `Message` built from the chat response.
    """
    request = self._base_params(prompt_stack)
    logger.debug(request)

    chat_response = self.client.chat(**request)
    logger.debug(chat_response.model_dump())

    assistant_content = self.__to_prompt_stack_message_content(chat_response)
    return Message(content=assistant_content, role=Message.ASSISTANT_ROLE)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/ollama_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream an Ollama chat response as `DeltaMessage`s.

    Ollama sends each tool call as its own chunk without an index, so action-call deltas
    receive consecutive indexes from a local counter.
    """
    request = {**self._base_params(prompt_stack), "stream": True}
    logger.debug(request)
    chunks: Iterator = self.client.chat(**request)

    next_tool_index = 0
    for chunk in chunks:
        logger.debug(chunk)
        delta_content = self.__to_prompt_stack_delta_message_content(chunk)
        if isinstance(delta_content, ActionCallDeltaMessageContent):
            delta_content.index = next_tool_index
            next_tool_index += 1
        yield DeltaMessage(content=delta_content)

OpenAiAssistantDriver

Bases: BaseAssistantDriver

Source Code in griptape/drivers/assistant/openai_assistant_driver.py
@define
class OpenAiAssistantDriver(BaseAssistantDriver):
    """Runs an OpenAI Assistant on a thread through the Assistants (beta) API.

    Attributes:
        base_url: Optional OpenAI API base URL.
        api_key: Optional OpenAI API key (not serialized).
        organization: Optional OpenAI organization.
        thread_id: Thread to post to; created on first run when unset and `auto_create_thread` is True.
        assistant_id: ID of the OpenAI Assistant to run.
        event_handler: Handler for streamed run events; the default publishes `TextChunkEvent`s.
        auto_create_thread: Whether `try_run` may create a thread when `thread_id` is unset.
        client: `openai.OpenAI` client; built lazily from the connection settings when not supplied.
    """

    class EventHandler(AssistantEventHandler):
        """Publishes streamed assistant output to the `EventBus` as `TextChunkEvent`s."""

        @override
        def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
            if delta.value is not None:
                EventBus.publish_event(TextChunkEvent(token=delta.value))

        @override
        def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
            # Only code-interpreter calls are surfaced: their input code and any log outputs.
            if delta.type == "code_interpreter" and delta.code_interpreter is not None:
                if delta.code_interpreter.input:
                    EventBus.publish_event(TextChunkEvent(token=delta.code_interpreter.input))
                if delta.code_interpreter.outputs:
                    EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
                    for output in delta.code_interpreter.outputs:
                        if output.type == "logs" and output.logs:
                            EventBus.publish_event(TextChunkEvent(token=output.logs))

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    thread_id: Optional[str] = field(default=None, kw_only=True)
    assistant_id: str = field(kw_only=True)
    event_handler: AssistantEventHandler = field(
        default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={"serializable": False}
    )
    auto_create_thread: bool = field(default=True, kw_only=True)

    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    def try_run(self, *args: BaseArtifact) -> TextArtifact:
        """Run the assistant on the configured thread, creating the thread if allowed.

        Raises:
            ValueError: If `thread_id` is unset and `auto_create_thread` is disabled.
        """
        if self.thread_id is None:
            if self.auto_create_thread:
                thread_id = self.client.beta.threads.create().id
                self.thread_id = thread_id
            else:
                raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
        else:
            thread_id = self.thread_id

        response = self._create_run(thread_id, *args)

        response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})

        return response

    def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
        """Post `args` as one user message to `thread_id`, stream a run, and collect the reply.

        Returns:
            A `TextArtifact` with the text blocks of the run's final messages (one line per
            message), carrying `assistant_id`, `thread_id` and `message_id` in its `meta`.
        """
        # `to_text()` works for any artifact; `value` is not guaranteed to be a str.
        user_text = "\n".join(arg.to_text() for arg in args)
        user_message = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_text)
        with self.client.beta.threads.runs.stream(
            thread_id=thread_id,
            assistant_id=self.assistant_id,
            event_handler=self.event_handler,
        ) as stream:
            stream.until_done()
            final_messages = stream.get_final_messages()

        # Content blocks are discriminated by `type`; text blocks (`TextContentBlock`) have type "text".
        message_text = "\n".join(
            "".join(block.text.value for block in final_message.content if block.type == "text")
            for final_message in final_messages
        )

        response = TextArtifact(message_text)
        # `messages.create` returns the created message object; store its id, not the object.
        response.meta.update({"assistant_id": self.assistant_id, "thread_id": thread_id, "message_id": user_message.id})
        return response
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • assistant_id = field(kw_only=True) class-attribute instance-attribute

  • auto_create_thread = field(default=True, kw_only=True) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • event_handler = field(default=Factory(lambda: OpenAiAssistantDriver.EventHandler()), kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • organization = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • thread_id = field(default=None, kw_only=True) class-attribute instance-attribute

EventHandler

Bases: AssistantEventHandler
Source Code in griptape/drivers/assistant/openai_assistant_driver.py
class EventHandler(AssistantEventHandler):
    """Publishes streamed assistant output to the `EventBus` as `TextChunkEvent`s."""

    @override
    def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None:
        """Publish each text delta whose value is not None."""
        token = delta.value
        if token is None:
            return
        EventBus.publish_event(TextChunkEvent(token=token))

    @override
    def on_tool_call_delta(self, delta: ToolCallDelta, snapshot: ToolCall) -> None:
        """Publish code-interpreter input and log outputs; other tool-call deltas are ignored."""
        if delta.type != "code_interpreter" or delta.code_interpreter is None:
            return

        interpreter = delta.code_interpreter
        if interpreter.input:
            EventBus.publish_event(TextChunkEvent(token=interpreter.input))

        if not interpreter.outputs:
            return

        EventBus.publish_event(TextChunkEvent(token="\n\noutput >"))
        for output in interpreter.outputs:
            if output.type == "logs" and output.logs:
                EventBus.publish_event(TextChunkEvent(token=output.logs))

_create_run(thread_id, *args)

Source Code in griptape/drivers/assistant/openai_assistant_driver.py
def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
    """Post `args` as one user message to `thread_id`, stream a run, and collect the reply.

    Returns:
        A `TextArtifact` with the text blocks of the run's final messages (one line per
        message), carrying `assistant_id`, `thread_id` and `message_id` in its `meta`.
    """
    # `to_text()` works for any artifact; `value` is not guaranteed to be a str.
    user_text = "\n".join(arg.to_text() for arg in args)
    user_message = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=user_text)
    with self.client.beta.threads.runs.stream(
        thread_id=thread_id,
        assistant_id=self.assistant_id,
        event_handler=self.event_handler,
    ) as stream:
        stream.until_done()
        final_messages = stream.get_final_messages()

    # Content blocks are discriminated by `type`; text blocks (`TextContentBlock`) have type "text".
    message_text = "\n".join(
        "".join(block.text.value for block in final_message.content if block.type == "text")
        for final_message in final_messages
    )

    response = TextArtifact(message_text)
    # `messages.create` returns the created message object; store its id, not the object.
    response.meta.update({"assistant_id": self.assistant_id, "thread_id": thread_id, "message_id": user_message.id})
    return response

client()

Source Code in griptape/drivers/assistant/openai_assistant_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Lazily build an `openai.OpenAI` client from this driver's connection settings."""
    connection = {"base_url": self.base_url, "api_key": self.api_key, "organization": self.organization}
    return openai.OpenAI(**connection)

try_run(*args)

Source Code in griptape/drivers/assistant/openai_assistant_driver.py
def try_run(self, *args: BaseArtifact) -> TextArtifact:
    """Run the assistant on the configured thread, creating the thread if allowed.

    Raises:
        ValueError: If `thread_id` is unset and `auto_create_thread` is disabled.
    """
    if self.thread_id is not None:
        thread_id = self.thread_id
    elif self.auto_create_thread:
        thread_id = self.client.beta.threads.create().id
        self.thread_id = thread_id
    else:
        raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")

    result = self._create_run(thread_id, *args)
    result.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})
    return result

OpenAiAudioTranscriptionDriver

Bases: BaseAudioTranscriptionDriver

Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
@define
class OpenAiAudioTranscriptionDriver(BaseAudioTranscriptionDriver):
    """Audio transcription driver backed by OpenAI's audio transcriptions endpoint.

    Attributes:
        api_type: OpenAI API type; defaults to `openai.api_type`. Not read by the methods below.
        api_version: OpenAI API version; defaults to `openai.api_version`. Not read by the methods below.
        base_url: Optional OpenAI API base URL.
        api_key: Optional OpenAI API key (not serialized).
        organization: OpenAI organization; defaults to `openai.organization`.
        client: `openai.OpenAI` client; built lazily from `api_key`, `base_url` and `organization` when not supplied.
    """

    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        connection = {"api_key": self.api_key, "base_url": self.base_url, "organization": self.organization}
        return openai.OpenAI(**connection)

    def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
        """Transcribe `audio` and return the transcript text.

        Args:
            audio: Audio whose raw bytes are uploaded from memory.
            prompts: Optional prompts, joined with ", " and sent as the request's `prompt`.
        """
        extra_kwargs = {} if prompts is None else {"prompt": ", ".join(prompts)}

        result = self.client.audio.transcriptions.create(
            model=self.model,
            # No real file is uploaded, but the API still requires a file name. Using the audio's
            # own format as the extension lets the API reject unsupported formats.
            file=(f"a.{audio.format}", io.BytesIO(audio.value)),
            response_format="json",
            **extra_kwargs,
        )
        return TextArtifact(value=result.text)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • api_type = field(default=openai.api_type, kw_only=True) class-attribute instance-attribute

  • api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Lazily build an `openai.OpenAI` client from the driver's key, base URL and organization."""
    connection = {"api_key": self.api_key, "base_url": self.base_url, "organization": self.organization}
    return openai.OpenAI(**connection)

try_run(audio, prompts=None)

Source Code in griptape/drivers/audio_transcription/openai_audio_transcription_driver.py
def try_run(self, audio: AudioArtifact, prompts: Optional[list[str]] = None) -> TextArtifact:
    """Transcribe `audio` and return the transcript text.

    Args:
        audio: Audio whose raw bytes are uploaded from memory.
        prompts: Optional prompts, joined with ", " and sent as the request's `prompt`.
    """
    extra_kwargs = {} if prompts is None else {"prompt": ", ".join(prompts)}

    result = self.client.audio.transcriptions.create(
        model=self.model,
        # No real file is uploaded, but the API still requires a file name. Using the audio's
        # own format as the extension lets the API reject unsupported formats.
        file=(f"a.{audio.format}", io.BytesIO(audio.value)),
        response_format="json",
        **extra_kwargs,
    )
    return TextArtifact(value=result.text)

OpenAiChatPromptDriver

Bases: BasePromptDriver

Attributes

| Name | Type | Description |
|------|------|-------------|
| `base_url` | `Optional[str]` | An optional OpenAI API URL. |
| `api_key` | `Optional[str]` | An optional OpenAI API key. If not provided, the `OPENAI_API_KEY` environment variable will be used. |
| `organization` | `Optional[str]` | An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used. |
| `client` | `OpenAI` | An `openai.OpenAI` client. |
| `model` | `str` | An OpenAI model name. |
| `tokenizer` | `BaseTokenizer` | An `OpenAiTokenizer`. |
| `user` | `str` | A user id. Can be used to track requests by user. |
| `response_format` | `Optional[dict]` | An optional OpenAI Chat Completion response format. Currently only supports `json_object`, which enables OpenAI's JSON mode. |
| `seed` | `Optional[int]` | An optional OpenAI Chat Completion seed. |
| `ignored_exception_types` | `tuple[type[Exception], ...]` | An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types. |
| `parallel_tool_calls` | `bool` | A flag to enable parallel tool calls. Defaults to `True`. |
Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@define
class OpenAiChatPromptDriver(BasePromptDriver):
    """OpenAI Chat Prompt Driver.

    Attributes:
        base_url: An optional OpenAi API URL.
        api_key: An optional OpenAi API key. If not provided, the `OPENAI_API_KEY` environment variable will be used.
        organization: An optional OpenAI organization. If not provided, the `OPENAI_ORG_ID` environment variable will be used.
        client: An `openai.OpenAI` client.
        model: An OpenAI model name.
        tokenizer: An `OpenAiTokenizer`.
        user: A user id. Can be used to track requests by user.
        response_format: An optional OpenAi Chat Completion response format. `{"type": "json_object"}` enables
            OpenAi's JSON mode and appends a system instruction asking for JSON; any other value is sent as-is.
            When set, it takes precedence over the response format derived from a prompt stack's output schema.
        seed: An optional OpenAi Chat Completion seed.
        tool_choice: The `tool_choice` sent whenever native tools are used. Defaults to `"auto"`.
        reasoning_effort: Sent as `reasoning_effort` for reasoning models other than `o1-mini`. Defaults to `"medium"`.
        use_native_tools: Whether the prompt stack's tools are sent as native OpenAi function tools. Defaults to `True`.
        structured_output_strategy: How a prompt stack's output schema is enforced: `"native"` sends a strict
            `json_schema` response format; `"tool"` sets `tool_choice="required"` when native tools are used.
        ignored_exception_types: An optional tuple of exception types to ignore. Defaults to OpenAI's known exception types.
        parallel_tool_calls: A flag to enable parallel tool calls. Defaults to `True`.
        modalities: Output modalities to request; omitted when empty or for reasoning models.
        audio: Audio output options, sent only when `"audio"` is in `modalities`.
    """

    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: BaseTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    user: str = field(default="", kw_only=True, metadata={"serializable": True})
    response_format: Optional[dict] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": False})
    reasoning_effort: Literal["low", "medium", "high"] = field(
        default="medium", kw_only=True, metadata={"serializable": True}
    )
    use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )
    parallel_tool_calls: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    ignored_exception_types: tuple[type[Exception], ...] = field(
        default=Factory(
            lambda: (
                openai.BadRequestError,
                openai.AuthenticationError,
                openai.PermissionDeniedError,
                openai.NotFoundError,
                openai.ConflictError,
                openai.UnprocessableEntityError,
            ),
        ),
        kw_only=True,
    )
    modalities: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
    audio: dict = field(
        default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Build the `openai.OpenAI` client from `base_url`, `api_key` and `organization`."""
        return openai.OpenAI(
            base_url=self.base_url,
            api_key=self.api_key,
            organization=self.organization,
        )

    @property
    def is_reasoning_model(self) -> bool:
        """Whether `model` names an o-series reasoning model (e.g. `o1`, `o1-mini`, `o3-mini`).

        Requires a leading `o` followed by a digit, not just any leading `o`. Otherwise models served through
        OpenAI-compatible endpoints whose names merely start with "o" (e.g. `openchat`, `orca-mini`) would lose
        `temperature`/`stop` and get the `developer` role.
        """
        return len(self.model) > 1 and self.model[0] == "o" and self.model[1].isdigit()

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send `prompt_stack` as a single non-streaming chat completion and convert the result."""
        params = self._base_params(prompt_stack)
        logger.debug(params)
        result = self.client.chat.completions.create(**params)

        logger.debug(result.model_dump())
        return self._to_message(result)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Send `prompt_stack` as a streaming chat completion and yield Griptape delta messages."""
        params = self._base_params(prompt_stack)
        logger.debug({"stream": True, **params})
        result = self.client.chat.completions.create(**params, stream=True)

        return self._to_delta_message_stream(result)

    def _to_message(self, result: ChatCompletion) -> Message:
        """Convert a non-streaming completion into an assistant `Message`, attaching usage when reported.

        Raises:
            Exception: If the completion has no choices, or more than one.
        """
        if not result.choices:
            raise Exception("Completion did not return any choices.")
        if len(result.choices) > 1:
            raise Exception("Completion with more than one choice is not supported yet.")

        message = Message(
            content=self.__to_prompt_stack_message_content(result.choices[0].message),
            role=Message.ASSISTANT_ROLE,
        )
        if result.usage is not None:
            message.usage = Message.Usage(
                input_tokens=result.usage.prompt_tokens,
                output_tokens=result.usage.completion_tokens,
            )

        return message

    def _to_delta_message_stream(self, result: Stream[ChatCompletionChunk]) -> Iterator[DeltaMessage]:
        """Yield a usage delta for chunks that report usage, then a content delta for the first choice (if any)."""
        for message in result:
            if message.usage is not None:
                yield DeltaMessage(
                    usage=DeltaMessage.Usage(
                        input_tokens=message.usage.prompt_tokens,
                        output_tokens=message.usage.completion_tokens,
                    ),
                )
            if message.choices:
                choice = message.choices[0]
                delta = choice.delta

                content = self.__to_prompt_stack_delta_message_content(delta)

                if content is not None:
                    yield DeltaMessage(content=content)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the keyword arguments for `chat.completions.create`.

        `prompt_stack` is not modified. The JSON-mode instruction is appended to the outgoing messages only, so
        repeated calls with the same prompt stack (e.g. retries) do not accumulate duplicate system messages.
        """
        params = {
            "model": self.model,
            **({"user": self.user} if self.user else {}),
            **({"seed": self.seed} if self.seed is not None else {}),
            **({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
            **(
                {"reasoning_effort": self.reasoning_effort}
                if self.is_reasoning_model and self.model != "o1-mini"
                else {}
            ),
            **({"temperature": self.temperature} if not self.is_reasoning_model else {}),
            **({"audio": self.audio} if "audio" in self.modalities else {}),
            **(
                {"stop": self.tokenizer.stop_sequences}
                if not self.is_reasoning_model and self.tokenizer.stop_sequences
                else {}
            ),
            **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
            **({"stream_options": {"include_usage": True}} if self.stream else {}),
            **self.extra_params,
        }

        if prompt_stack.tools and self.use_native_tools:
            params["tool_choice"] = self.tool_choice
            params["parallel_tool_calls"] = self.parallel_tool_calls

        if prompt_stack.output_schema is not None:
            if self.structured_output_strategy == "native":
                params["response_format"] = {
                    "type": "json_schema",
                    "json_schema": {
                        "name": "Output",
                        "schema": prompt_stack.to_output_json_schema(),
                        "strict": True,
                    },
                }
            elif self.structured_output_strategy == "tool" and self.use_native_tools:
                params["tool_choice"] = "required"

        messages = self.__to_openai_messages(prompt_stack.messages)

        if self.response_format is not None:
            # An explicit response_format overrides any structured-output format set above.
            params["response_format"] = self.response_format
            if self.response_format == {"type": "json_object"}:
                # JSON mode still requires a system message instructing the LLM to output JSON. It is converted
                # like any other message and appended last, without touching `prompt_stack`.
                json_instruction = Message(
                    content=[TextMessageContent(TextArtifact("Provide your response as a valid JSON object."))],
                    role=Message.SYSTEM_ROLE,
                )
                messages.extend(self.__to_openai_messages([json_instruction]))

        if prompt_stack.tools and self.use_native_tools:
            params["tools"] = self.__to_openai_tools(prompt_stack.tools)

        params["messages"] = messages

        return params

    def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
        """Convert Griptape messages into OpenAI chat message dicts, in order."""
        openai_messages = []

        for message in messages:
            # If the message only contains textual content we can send it as a single content.
            if message.has_all_content_type(TextMessageContent):
                openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            # Action results must be sent as separate messages.
            elif action_result_contents := message.get_content_type(ActionResultMessageContent):
                openai_messages.extend(
                    {
                        "role": self.__to_openai_role(message, action_result_content),
                        "content": self.__to_openai_message_content(action_result_content),
                        "tool_call_id": action_result_content.action.tag,
                    }
                    for action_result_content in action_result_contents
                )

                if message.has_any_content_type(TextMessageContent):
                    openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            else:
                openai_message = {
                    "role": self.__to_openai_role(message),
                    "content": [],
                }

                for content in message.content:
                    if isinstance(content, ActionCallMessageContent):
                        if "tool_calls" not in openai_message:
                            openai_message["tool_calls"] = []
                        openai_message["tool_calls"].append(self.__to_openai_message_content(content))
                    elif (
                        isinstance(content, AudioMessageContent)
                        and message.is_assistant()
                        and time.time() < content.artifact.meta["expires_at"]
                    ):
                        # For assistant audio messages, we reference the audio id instead of sending audio message content.
                        openai_message["audio"] = {
                            "id": content.artifact.meta["audio_id"],
                        }
                    else:
                        openai_message["content"].append(self.__to_openai_message_content(content))

                # Some OpenAi-compatible services don't accept an empty array for content
                if not openai_message["content"]:
                    del openai_message["content"]

                openai_messages.append(openai_message)

        return openai_messages

    def __to_openai_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
        """Map a message (and optionally one of its contents) to an OpenAI chat role."""
        if message.is_system():
            if self.is_reasoning_model:
                return "developer"
            return "system"
        if message.is_assistant():
            return "assistant"
        if isinstance(message_content, ActionResultMessageContent):
            return "tool"
        return "user"

    def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Expose every activity of every tool as an OpenAI function tool."""
        return [
            {
                "function": {
                    "name": tool.to_native_tool_name(activity),
                    "description": tool.activity_description(activity),
                    "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
                },
                "type": "function",
            }
            for tool in tools
            for activity in tool.activities()
        ]

    def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict:
        """Convert one message content into an OpenAI content part (a plain string for action results)."""
        if isinstance(content, TextMessageContent):
            return {"type": "text", "text": content.artifact.to_text()}
        if isinstance(content, ImageMessageContent):
            if isinstance(content.artifact, ImageArtifact):
                return {
                    "type": "image_url",
                    "image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
                }
            if isinstance(content.artifact, ImageUrlArtifact):
                return {
                    "type": "image_url",
                    "image_url": {"url": content.artifact.value},
                }
            raise ValueError(f"Unsupported image artifact type: {type(content.artifact)}")
        if isinstance(content, AudioMessageContent):
            artifact = content.artifact
            metadata = artifact.meta

            # If there's an expiration date, we can assume it's an assistant message.
            if "expires_at" in metadata:
                # If it's expired, we send the transcript instead.
                if time.time() >= metadata["expires_at"]:
                    return {
                        "type": "text",
                        "text": artifact.meta.get("transcript"),
                    }
                # This should never occur, since a non-expired audio content
                # should have already been referenced by the audio id.
                raise ValueError("Assistant audio messages should be sent as audio ids.")
            # If there's no expiration date, we can assume it's a user message where we send the audio every time.
            return {
                "type": "input_audio",
                "input_audio": {
                    "data": base64.b64encode(artifact.value).decode("utf-8"),
                    "format": artifact.format,
                },
            }
        if isinstance(content, ActionCallMessageContent):
            action = content.artifact.value

            return {
                "type": "function",
                "id": action.tag,
                "function": {"name": action.to_native_tool_name(), "arguments": json.dumps(action.input)},
            }
        if isinstance(content, ActionResultMessageContent):
            return content.artifact.to_text()
        return {"type": "text", "text": content.artifact.to_text()}

    def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
        """Convert a completion message into Griptape contents: text, then audio, then tool calls (when present)."""
        content = []

        if response.content is not None:
            content.append(TextMessageContent(TextArtifact(response.content)))
        if hasattr(response, "audio") and response.audio is not None:
            content.append(
                AudioMessageContent(
                    AudioArtifact(
                        value=base64.b64decode(response.audio.data),
                        format="wav",
                        meta={
                            "audio_id": response.audio.id,
                            "transcript": response.audio.transcript,
                            "expires_at": response.audio.expires_at,
                        },
                    )
                )
            )
        if response.tool_calls is not None:
            for tool_call in response.tool_calls:
                # Parse the native tool name once rather than once per field.
                native_name = ToolAction.from_native_tool_name(tool_call.function.name)
                arguments = tool_call.function.arguments
                content.append(
                    ActionCallMessageContent(
                        ActionArtifact(
                            ToolAction(
                                tag=tool_call.id,
                                name=native_name[0],
                                path=native_name[1],
                                # json.loads("") raises, so treat empty/missing arguments as no input.
                                input=json.loads(arguments) if arguments else {},
                            ),
                        ),
                    )
                )

        return content

    def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
        """Convert a streamed delta into a Griptape delta content, or `None` when it carries nothing usable.

        Raises:
            ValueError: If a tool-call delta has more than one tool call, or a tool call without a function.
        """
        if content_delta.content is not None:
            return TextDeltaMessageContent(content_delta.content)
        if content_delta.tool_calls is not None:
            tool_calls = content_delta.tool_calls

            if len(tool_calls) == 1:
                tool_call = tool_calls[0]
                index = tool_call.index

                if tool_call.function is not None:
                    function_name = tool_call.function.name
                    # The function name can be empty on a delta; when present, parse it once.
                    name, path = None, None
                    if function_name:
                        native_name = ToolAction.from_native_tool_name(function_name)
                        name, path = native_name[0], native_name[1]
                    return ActionCallDeltaMessageContent(
                        index=index,
                        tag=tool_call.id,
                        name=name,
                        path=path,
                        partial_input=tool_call.function.arguments,
                    )
                raise ValueError(f"Unsupported tool call delta: {tool_call}")
            raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
        # OpenAi doesn't have types for audio deltas so we need to use getattr.
        audio_chunk: Optional[dict] = getattr(content_delta, "audio", None)
        if audio_chunk is not None:
            return AudioDeltaMessageContent(
                id=audio_chunk.get("id"),
                data=audio_chunk.get("data"),
                expires_at=audio_chunk.get("expires_at"),
                transcript=audio_chunk.get("transcript"),
            )
        return None
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • audio = field(default=Factory(lambda: {'voice': 'alloy', 'format': 'pcm16'}), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • ignored_exception_types = field(default=Factory(lambda: (openai.BadRequestError, openai.AuthenticationError, openai.PermissionDeniedError, openai.NotFoundError, openai.ConflictError, openai.UnprocessableEntityError)), kw_only=True) class-attribute instance-attribute

  • is_reasoning_model property

  • modalities = field(factory=list, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • organization = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • parallel_tool_calls = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • reasoning_effort = field(default='medium', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • response_format = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • seed = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

  • tool_choice = field(default='auto', kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • use_native_tools = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • user = field(default='', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

__to_openai_message_content(content)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict:
    """Convert one Griptape message content into its OpenAI Chat Completions representation.

    Text, image, audio and action-call contents become dict content parts. An action result becomes a
    plain string (used as a ``tool`` message's content). Any other content type falls back to a text part.

    Raises:
        ValueError: For an image artifact that is neither ``ImageArtifact`` nor ``ImageUrlArtifact``, or for
            assistant audio that has not yet expired (it is expected to be referenced by audio id instead).
    """
    artifact = content.artifact

    if isinstance(content, TextMessageContent):
        return {"type": "text", "text": artifact.to_text()}

    if isinstance(content, ImageMessageContent):
        if isinstance(artifact, ImageArtifact):
            image_url = f"data:{artifact.mime_type};base64,{artifact.base64}"
        elif isinstance(artifact, ImageUrlArtifact):
            image_url = artifact.value
        else:
            raise ValueError(f"Unsupported image artifact type: {type(artifact)}")
        return {"type": "image_url", "image_url": {"url": image_url}}

    if isinstance(content, AudioMessageContent):
        meta = artifact.meta
        if "expires_at" not in meta:
            # No expiry: treated as user audio, which is re-sent in full on every request.
            encoded = base64.b64encode(artifact.value).decode("utf-8")
            return {"type": "input_audio", "input_audio": {"data": encoded, "format": artifact.format}}
        if time.time() < meta["expires_at"]:
            # Non-expired assistant audio should already have been referenced by its audio id.
            raise ValueError("Assistant audio messages should be sent as audio ids.")
        # Expired assistant audio: fall back to its transcript.
        return {"type": "text", "text": meta.get("transcript")}

    if isinstance(content, ActionCallMessageContent):
        action = artifact.value
        function = {"name": action.to_native_tool_name(), "arguments": json.dumps(action.input)}
        return {"type": "function", "id": action.tag, "function": function}

    if isinstance(content, ActionResultMessageContent):
        return artifact.to_text()

    return {"type": "text", "text": artifact.to_text()}

__to_openai_messages(messages)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_openai_messages(self, messages: list[Message]) -> list[dict]:
    """Convert Griptape messages into OpenAI chat message dicts, preserving order.

    - Text-only messages become one message whose ``content`` is the joined text.
    - Messages with action results become one ``tool_call_id``-tagged message per result, followed by a
      plain text message if the original also carried text.
    - Anything else becomes one message with a list of content parts, a ``tool_calls`` list for action
      calls, and an ``audio`` id reference for non-expired assistant audio. ``content`` is dropped when empty.
    """
    converted: list[dict] = []

    for message in messages:
        # If the message only contains textual content we can send it as a single content.
        if message.has_all_content_type(TextMessageContent):
            converted.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            continue

        # Action results must be sent as separate messages.
        action_results = message.get_content_type(ActionResultMessageContent)
        if action_results:
            for result_content in action_results:
                converted.append(
                    {
                        "role": self.__to_openai_role(message, result_content),
                        "content": self.__to_openai_message_content(result_content),
                        "tool_call_id": result_content.action.tag,
                    }
                )
            if message.has_any_content_type(TextMessageContent):
                converted.append({"role": self.__to_openai_role(message), "content": message.to_text()})
            continue

        outgoing: dict = {"role": self.__to_openai_role(message), "content": []}

        for part in message.content:
            if isinstance(part, ActionCallMessageContent):
                outgoing.setdefault("tool_calls", []).append(self.__to_openai_message_content(part))
            elif (
                isinstance(part, AudioMessageContent)
                and message.is_assistant()
                and time.time() < part.artifact.meta["expires_at"]
            ):
                # Non-expired assistant audio is referenced by id instead of being re-sent.
                outgoing["audio"] = {"id": part.artifact.meta["audio_id"]}
            else:
                outgoing["content"].append(self.__to_openai_message_content(part))

        # Some OpenAi-compatible services don't accept an empty array for content
        if not outgoing["content"]:
            outgoing.pop("content")

        converted.append(outgoing)

    return converted

__to_openai_role(message, message_content=None)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_openai_role(self, message: Message, message_content: Optional[BaseMessageContent] = None) -> str:
    """Map a Griptape message (and optionally one of its contents) to an OpenAI chat role.

    System messages map to ``"developer"`` for reasoning models and ``"system"`` otherwise. Assistant
    messages map to ``"assistant"``. Any other message is ``"tool"`` when ``message_content`` is an
    action result, and ``"user"`` otherwise.
    """
    if message.is_system():
        return "developer" if self.is_reasoning_model else "system"
    if message.is_assistant():
        return "assistant"
    return "tool" if isinstance(message_content, ActionResultMessageContent) else "user"

__to_openai_tools(tools)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_openai_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Expose every activity of every tool as an OpenAI ``function`` tool definition.

    Definitions are emitted tool by tool, and within a tool in ``tool.activities()`` order.
    """
    definitions: list[dict] = []
    for tool in tools:
        for activity in tool.activities():
            function_spec = {
                "name": tool.to_native_tool_name(activity),
                "description": tool.activity_description(activity),
                "parameters": tool.to_activity_json_schema(activity, "Parameters Schema"),
            }
            definitions.append({"function": function_spec, "type": "function"})
    return definitions

__to_prompt_stack_delta_message_content(content_delta)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
    """Convert one streamed ``ChoiceDelta`` into a Griptape delta content.

    Text deltas are checked first, then tool-call deltas, then audio deltas.

    Returns:
        The converted delta content, or ``None`` when the delta carries none of these.

    Raises:
        ValueError: If a tool-call delta contains more than one tool call, or a tool call without a function.
    """
    if content_delta.content is not None:
        return TextDeltaMessageContent(content_delta.content)
    if content_delta.tool_calls is not None:
        tool_calls = content_delta.tool_calls

        if len(tool_calls) == 1:
            tool_call = tool_calls[0]
            index = tool_call.index

            if tool_call.function is not None:
                function_name = tool_call.function.name
                # The function name can be empty on a delta; when present, parse it once instead of twice.
                name, path = None, None
                if function_name:
                    native_name = ToolAction.from_native_tool_name(function_name)
                    name, path = native_name[0], native_name[1]
                return ActionCallDeltaMessageContent(
                    index=index,
                    tag=tool_call.id,
                    name=name,
                    path=path,
                    partial_input=tool_call.function.arguments,
                )
            raise ValueError(f"Unsupported tool call delta: {tool_call}")
        raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
    # OpenAi doesn't have types for audio deltas so we need to use getattr.
    audio_chunk: Optional[dict] = getattr(content_delta, "audio", None)
    if audio_chunk is not None:
        return AudioDeltaMessageContent(
            id=audio_chunk.get("id"),
            data=audio_chunk.get("data"),
            expires_at=audio_chunk.get("expires_at"),
            transcript=audio_chunk.get("transcript"),
        )
    return None

__to_prompt_stack_message_content(response)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) -> list[BaseMessageContent]:
    """Convert a non-streaming ``ChatCompletionMessage`` into Griptape message contents.

    Text, audio and tool calls are appended in that order, each only when present on ``response``.
    Audio becomes an ``AudioArtifact`` with ``format="wav"`` whose meta keeps the audio id, transcript
    and expiry. Each tool call becomes an ``ActionCallMessageContent``. Empty or missing tool arguments
    are treated as ``{}`` instead of failing in ``json.loads``.
    """
    content = []

    if response.content is not None:
        content.append(TextMessageContent(TextArtifact(response.content)))
    if hasattr(response, "audio") and response.audio is not None:
        content.append(
            AudioMessageContent(
                AudioArtifact(
                    value=base64.b64decode(response.audio.data),
                    format="wav",
                    meta={
                        "audio_id": response.audio.id,
                        "transcript": response.audio.transcript,
                        "expires_at": response.audio.expires_at,
                    },
                )
            )
        )
    if response.tool_calls is not None:
        for tool_call in response.tool_calls:
            # Parse the native tool name once rather than once per field.
            native_name = ToolAction.from_native_tool_name(tool_call.function.name)
            arguments = tool_call.function.arguments
            content.append(
                ActionCallMessageContent(
                    ActionArtifact(
                        ToolAction(
                            tag=tool_call.id,
                            name=native_name[0],
                            path=native_name[1],
                            # json.loads("") raises, so treat empty/missing arguments as no input.
                            input=json.loads(arguments) if arguments else {},
                        ),
                    ),
                )
            )

    return content

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the keyword arguments for ``chat.completions.create``.

    Optional parameters are included only when they apply, e.g. ``temperature`` and ``stop`` are omitted for
    reasoning models. ``extra_params`` is merged in and may be overridden by the tool / response-format /
    message entries set afterwards.

    ``prompt_stack`` is not modified. In JSON mode the instruction message is appended to the outgoing
    messages only, so repeated calls with the same prompt stack (e.g. retries) don't accumulate duplicates.
    """
    params = {
        "model": self.model,
        **({"user": self.user} if self.user else {}),
        **({"seed": self.seed} if self.seed is not None else {}),
        **({"modalities": self.modalities} if self.modalities and not self.is_reasoning_model else {}),
        **(
            {"reasoning_effort": self.reasoning_effort}
            if self.is_reasoning_model and self.model != "o1-mini"
            else {}
        ),
        **({"temperature": self.temperature} if not self.is_reasoning_model else {}),
        **({"audio": self.audio} if "audio" in self.modalities else {}),
        **(
            {"stop": self.tokenizer.stop_sequences}
            if not self.is_reasoning_model and self.tokenizer.stop_sequences
            else {}
        ),
        **({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
        **({"stream_options": {"include_usage": True}} if self.stream else {}),
        **self.extra_params,
    }

    if prompt_stack.tools and self.use_native_tools:
        params["tool_choice"] = self.tool_choice
        params["parallel_tool_calls"] = self.parallel_tool_calls

    if prompt_stack.output_schema is not None:
        if self.structured_output_strategy == "native":
            params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": "Output",
                    "schema": prompt_stack.to_output_json_schema(),
                    "strict": True,
                },
            }
        elif self.structured_output_strategy == "tool" and self.use_native_tools:
            params["tool_choice"] = "required"

    messages = self.__to_openai_messages(prompt_stack.messages)

    if self.response_format is not None:
        # An explicit response_format overrides any structured-output format set above.
        params["response_format"] = self.response_format
        if self.response_format == {"type": "json_object"}:
            # JSON mode still requires a system message instructing the LLM to output JSON. Convert it like
            # any other message and append it last, without mutating the caller's prompt stack.
            json_instruction = Message(
                content=[TextMessageContent(TextArtifact("Provide your response as a valid JSON object."))],
                role=Message.SYSTEM_ROLE,
            )
            messages.extend(self.__to_openai_messages([json_instruction]))

    if prompt_stack.tools and self.use_native_tools:
        params["tools"] = self.__to_openai_tools(prompt_stack.tools)

    params["messages"] = messages

    return params

_to_delta_message_stream(result)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def _to_delta_message_stream(self, result: Stream[ChatCompletionChunk]) -> Iterator[DeltaMessage]:
    """Translate a streamed chat completion into prompt-stack delta messages.

    A chunk may report token usage (emitted as a usage-only ``DeltaMessage``) and/or
    carry choices; only the first choice's delta is converted into content.

    Args:
        result: The stream returned by ``chat.completions.create(..., stream=True)``.

    Yields:
        ``DeltaMessage`` objects carrying either usage or content.
    """
    for chunk in result:
        usage = chunk.usage
        if usage is not None:
            yield DeltaMessage(
                usage=DeltaMessage.Usage(
                    input_tokens=usage.prompt_tokens,
                    output_tokens=usage.completion_tokens,
                ),
            )

        # Some chunks (e.g. the trailing usage chunk requested via include_usage) have no choices.
        if not chunk.choices:
            continue

        first_delta = chunk.choices[0].delta
        delta_content = self.__to_prompt_stack_delta_message_content(first_delta)
        if delta_content is not None:
            yield DeltaMessage(content=delta_content)

_to_message(result)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
def _to_message(self, result: ChatCompletion) -> Message:
    """Convert a non-streamed chat completion into an assistant ``Message``.

    Args:
        result: The completion returned by ``chat.completions.create``.

    Returns:
        An assistant-role ``Message`` built from the single choice, with token usage
        attached when the API reported it.

    Raises:
        Exception: If the completion contains no choices or more than one choice.
    """
    choice_count = len(result.choices)
    # Report an empty completion accurately instead of claiming there were too many choices.
    if choice_count == 0:
        raise Exception("Completion contained no choices.")
    if choice_count > 1:
        raise Exception("Completion with more than one choice is not supported yet.")

    choice_message = result.choices[0].message
    message = Message(
        content=self.__to_prompt_stack_message_content(choice_message),
        role=Message.ASSISTANT_ROLE,
    )
    if result.usage is not None:
        message.usage = Message.Usage(
            input_tokens=result.usage.prompt_tokens,
            output_tokens=result.usage.completion_tokens,
        )

    return message

client()

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Construct the OpenAI SDK client from this driver's connection settings."""
    connection_settings = {
        "base_url": self.base_url,
        "api_key": self.api_key,
        "organization": self.organization,
    }
    return openai.OpenAI(**connection_settings)

try_run(prompt_stack)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send ``prompt_stack`` to the chat completions endpoint as a single request.

    Both the request parameters and the raw response are logged at debug level.
    """
    request_params = self._base_params(prompt_stack)
    logger.debug(request_params)

    completion = self.client.chat.completions.create(**request_params)
    logger.debug(completion.model_dump())

    return self._to_message(completion)

try_stream(prompt_stack)

Source Code in griptape/drivers/prompt/openai_chat_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream a chat completion for ``prompt_stack`` as delta messages.

    The logged parameters include ``stream=True`` so the debug output mirrors the request sent.
    """
    request_params = self._base_params(prompt_stack)
    logger.debug({"stream": True, **request_params})

    chunk_stream = self.client.chat.completions.create(**request_params, stream=True)
    return self._to_delta_message_stream(chunk_stream)

OpenAiEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | OpenAI embedding model name. Defaults to text-embedding-3-small. |
| base_url | Optional[str] | API URL. Defaults to OpenAI's v1 API URL. |
| api_key | Optional[str] | API key to pass directly. Defaults to the OPENAI_API_KEY environment variable. |
| organization | Optional[str] | OpenAI organization. Defaults to the 'OPENAI_ORGANIZATION' environment variable. |
| tokenizer | OpenAiTokenizer | Optionally provide a custom OpenAiTokenizer. |
| client | OpenAI | Optionally provide a custom openai.OpenAI client. |
| azure_deployment | OpenAI | An Azure OpenAI deployment id. |
| azure_endpoint | OpenAI | An Azure OpenAI endpoint. |
| azure_ad_token | OpenAI | An optional Azure Active Directory token. |
| azure_ad_token_provider | OpenAI | An optional Azure Active Directory token provider. |
| api_version | OpenAI | An Azure OpenAI API version. |
Source Code in griptape/drivers/embedding/openai_embedding_driver.py
@define
class OpenAiEmbeddingDriver(BaseEmbeddingDriver):
    """OpenAI Embedding Driver.

    Attributes:
        model: OpenAI embedding model name. Defaults to `text-embedding-3-small`.
        base_url: API URL. Defaults to OpenAI's v1 API URL.
        api_key: API key to pass directly. Defaults to `OPENAI_API_KEY` environment variable.
        organization: OpenAI organization. Defaults to 'OPENAI_ORGANIZATION' environment variable.
        tokenizer: Optionally provide custom `OpenAiTokenizer`. Defaults to one built for `model`.
        client: Optionally provide custom `openai.OpenAI` client.
    """

    DEFAULT_MODEL = "text-embedding-3-small"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Credentials are marked non-serializable.
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # The default tokenizer is derived from `model`, hence takes_self=True.
    tokenizer: OpenAiTokenizer = field(
        default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True),
        kw_only=True,
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Construct the OpenAI client from this driver's connection settings."""
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed a single text chunk.

        Args:
            chunk: Text to embed.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The embedding vector of the first (and only) input.
        """
        # Address a performance issue in older ada models
        # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
        if self.model.endswith("001"):
            chunk = chunk.replace("\n", " ")
        return self.client.embeddings.create(**self._params(chunk)).data[0].embedding

    def _params(self, chunk: str) -> dict:
        """Build the keyword arguments for `client.embeddings.create`."""
        return {"input": chunk, "model": self.model}
  • DEFAULT_MODEL = 'text-embedding-3-small' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • organization = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: OpenAiTokenizer(model=self.model), takes_self=True), kw_only=True) class-attribute instance-attribute

_params(chunk)

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
def _params(self, chunk: str) -> dict:
    """Build the keyword arguments for ``client.embeddings.create``."""
    params = {"input": chunk}
    params["model"] = self.model
    return params

client()

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Create the OpenAI SDK client using this driver's key, URL and organization."""
    api_key, base_url, organization = self.api_key, self.base_url, self.organization
    return openai.OpenAI(api_key=api_key, base_url=base_url, organization=organization)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/openai_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed a single text chunk and return its embedding vector.

    Args:
        chunk: Text to embed.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The embedding of the first (and only) input.
    """
    # Older "*-001" (ada) models have a performance issue with newlines; see
    # https://github.com/openai/openai-python/issues/418#issuecomment-1525939500
    text = chunk.replace("\n", " ") if self.model.endswith("001") else chunk
    response = self.client.embeddings.create(**self._params(text))
    return response.data[0].embedding

OpenAiImageGenerationDriver

Bases: BaseImageGenerationDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | OpenAI model, for example 'dall-e-2', 'dall-e-3', or 'gpt-image-1'. |
| api_type | Optional[str] | OpenAI API type, for example 'open_ai' or 'azure'. |
| api_version | Optional[str] | API version. |
| base_url | Optional[str] | API URL. |
| api_key | Optional[str] | OpenAI API key. |
| organization | Optional[str] | OpenAI organization ID. |
| style | Optional[Literal['vivid', 'natural']] | Optional and only supported for dall-e-3; can be either 'vivid' or 'natural'. |
| quality | Optional[Literal['standard', 'hd', 'low', 'medium', 'high', 'auto']] | Optional. dall-e-3 accepts 'standard' or 'hd'; gpt-image-1 accepts 'low', 'medium', 'high', or 'auto'. |
| image_size | Optional[Literal['256x256', '512x512', '1024x1024', '1024x1792', '1792x1024']] | Size of the generated image. Must be one of the following, depending on the requested model: dall-e-2: [256x256, 512x512, 1024x1024]; dall-e-3: [1024x1024, 1024x1792, 1792x1024]; gpt-image-1: [1024x1024, 1536x1024, 1024x1536, auto]. |
| response_format | Literal['b64_json'] | The response format. Currently only 'b64_json' is supported, which returns a base64-encoded image in a JSON object. |
| background | Optional[Literal['transparent', 'opaque', 'auto']] | Optional and only supported for gpt-image-1. Can be 'transparent', 'opaque', or 'auto'. |
| moderation | Optional[Literal['low', 'auto']] | Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'. |
| output_compression | Optional[int] | Optional and only supported for gpt-image-1. Can be an integer between 0 and 100. |
| output_format | Optional[Literal['png', 'jpeg']] | Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'. |
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@define
class OpenAiImageGenerationDriver(BaseImageGenerationDriver):
    """Driver for the OpenAI image generation API.

    Attributes:
        model: OpenAI model, for example 'dall-e-2', 'dall-e-3', or 'gpt-image-1'.
        api_type: OpenAI API type, for example 'open_ai' or 'azure'.
        api_version: API version.
        base_url: API URL.
        api_key: OpenAI API key.
        organization: OpenAI organization ID.
        style: Optional and only supported for dall-e-3, can be either 'vivid' or 'natural'.
        quality: Optional. dall-e-3 accepts 'standard' or 'hd'; gpt-image-1 accepts 'low', 'medium',
            'high', or 'auto'.
        image_size: Size of the generated image. Must be one of the following, depending on the requested model:
            dall-e-2: [256x256, 512x512, 1024x1024]
            dall-e-3: [1024x1024, 1024x1792, 1792x1024]
            gpt-image-1: [1024x1024, 1536x1024, 1024x1536, auto]
        response_format: The response format. Currently only supports 'b64_json' which will return
            a base64 encoded image in a JSON object. Not sent for gpt-image-1 (see its `model_denylist`).
        background: Optional and only supported for gpt-image-1. Can be either 'transparent', 'opaque', or 'auto'.
        moderation: Optional and only supported for gpt-image-1. Can be either 'low' or 'auto'.
        output_compression: Optional and only supported for gpt-image-1. Can be an integer between 0 and 100.
        output_format: Optional and only supported for gpt-image-1. Can be either 'png' or 'jpeg'.
    """

    # `model_allowlist` / `model_denylist` metadata is consumed by `_build_model_params` to decide
    # which optional request parameters are sent for the configured model.
    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    style: Optional[Literal["vivid", "natural"]] = field(
        default=None, kw_only=True, metadata={"serializable": True, "model_allowlist": ["dall-e-3"]}
    )
    quality: Optional[Literal["standard", "hd", "low", "medium", "high", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    # Includes the gpt-image-1 sizes that `validate_image_size` accepts.
    image_size: Optional[
        Literal["256x256", "512x512", "1024x1024", "1024x1792", "1792x1024", "1536x1024", "1024x1536", "auto"]
    ] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True},
    )
    response_format: Literal["b64_json"] = field(
        default="b64_json",
        kw_only=True,
        metadata={"serializable": True, "model_denylist": ["gpt-image-1"]},
    )
    background: Optional[Literal["transparent", "opaque", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    moderation: Optional[Literal["low", "auto"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    output_compression: Optional[int] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    output_format: Optional[Literal["png", "jpeg"]] = field(
        default=None,
        kw_only=True,
        metadata={"serializable": True, "model_allowlist": ["gpt-image-1"]},
    )
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @image_size.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_image_size(self, attribute: str, value: str | None) -> None:
        """Validates the image size based on the model.

        Must be one of `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait), or `auto` (default value) for
        `gpt-image-1`, one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`, and
        one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3`.

        Raises:
            NotImplementedError: If a size is set for a model without known size constraints.
            ValueError: If the size is not allowed for the model.
        """
        if value is None:
            return

        if self.model.startswith("gpt-image"):
            allowed_sizes = ("1024x1024", "1536x1024", "1024x1536", "auto")
        elif self.model == "dall-e-2":
            allowed_sizes = ("256x256", "512x512", "1024x1024")
        elif self.model == "dall-e-3":
            allowed_sizes = ("1024x1024", "1792x1024", "1024x1792")
        else:
            raise NotImplementedError(f"Image size validation not implemented for model {self.model}")

        if value not in allowed_sizes:
            raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Construct the OpenAI client from this driver's connection settings."""
        return openai.OpenAI(api_key=self.api_key, base_url=self.base_url, organization=self.organization)

    def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
        """Generate one image from `prompts` (joined with ", "); `negative_prompts` is not used."""
        prompt = ", ".join(prompts)

        response = self.client.images.generate(
            model=self.model,
            prompt=prompt,
            n=1,
            **self._build_model_params(
                {
                    "size": "image_size",
                    "quality": "quality",
                    "style": "style",
                    "response_format": "response_format",
                    "background": "background",
                    "moderation": "moderation",
                    "output_compression": "output_compression",
                    "output_format": "output_format",
                }
            ),
        )

        return self._parse_image_response(response, prompt)

    def try_image_variation(
        self,
        prompts: list[str],
        image: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Creates a variation of an image.

        Only supported for dall-e-2. Requires image size to be one of the following:
            [256x256, 512x512, 1024x1024]

        Raises:
            NotImplementedError: If the model is not dall-e-2.
        """
        if self.model != "dall-e-2":
            raise NotImplementedError("Image variation only supports dall-e-2")
        response = self.client.images.create_variation(
            image=image.value,
            n=1,
            response_format=self.response_format,
            size=self.image_size,  # pyright: ignore[reportArgumentType]
        )

        return self._parse_image_response(response, "")

    def try_image_inpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Edit the masked region of `image` according to `prompts` (joined with ", ").

        `negative_prompts` is not used by this driver.
        """
        prompt = ", ".join(prompts)
        response = self.client.images.edit(
            # Without an explicit model the edit endpoint falls back to dall-e-2, which would
            # ignore a configured gpt-image-1 and mislabel the artifact's meta["model"].
            model=self.model,
            prompt=prompt,
            image=image.value,
            mask=mask.value,
            n=1,
            **self._build_model_params(
                {
                    "size": "image_size",
                    "response_format": "response_format",
                }
            ),
        )

        return self._parse_image_response(response, prompt)

    def try_image_outpainting(
        self,
        prompts: list[str],
        image: ImageArtifact,
        mask: ImageArtifact,
        negative_prompts: Optional[list[str]] = None,
    ) -> ImageArtifact:
        """Outpainting is not supported by this driver; always raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support outpainting")

    def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
        """Decode the first base64 image in `response` into an ImageArtifact tagged with prompt and model.

        Raises:
            Exception: If the response carries no image data.
        """
        # Guard against both a missing and an empty `data` list before indexing.
        first_image = response.data[0] if response.data else None
        if first_image is None or first_image.b64_json is None:
            raise Exception("Failed to generate image")

        from griptape.loaders.image_loader import ImageLoader

        image_data = base64.b64decode(first_image.b64_json)

        image_artifact = ImageLoader().parse(image_data)

        image_artifact.meta["prompt"] = prompt
        image_artifact.meta["model"] = self.model

        return image_artifact

    def _build_model_params(self, values: dict) -> dict:
        """Builds parameters while considering field metadata and None values.

        Args:
            values: A dictionary mapping parameter names to field names.

        Field will be added to the params dictionary if all conditions are met:
            - The field value is not None
            - The model_allowlist is None or the model is in the allowlist
            - The model_denylist is None or the model is not in the denylist
        """
        params = {}

        fields = fields_dict(self.__class__)
        for param_name, field_name in values.items():
            metadata = fields[field_name].metadata
            model_allowlist = metadata.get("model_allowlist")
            model_denylist = metadata.get("model_denylist")

            field_value = getattr(self, field_name, None)

            allowlist_condition = model_allowlist is None or self.model in model_allowlist
            denylist_condition = model_denylist is None or self.model not in model_denylist

            if field_value is not None and allowlist_condition and denylist_condition:
                params[param_name] = field_value
        return params
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • api_type = field(default=openai.api_type, kw_only=True) class-attribute instance-attribute

  • api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • background = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • image_size = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • moderation = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']}) class-attribute instance-attribute

  • organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • output_compression = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']}) class-attribute instance-attribute

  • output_format = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['gpt-image-1']}) class-attribute instance-attribute

  • quality = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • response_format = field(default='b64_json', kw_only=True, metadata={'serializable': True, 'model_denylist': ['gpt-image-1']}) class-attribute instance-attribute

  • style = field(default=None, kw_only=True, metadata={'serializable': True, 'model_allowlist': ['dall-e-3']}) class-attribute instance-attribute

_build_model_params(values)

Builds parameters while considering field metadata and None values.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| values | dict | A dictionary mapping parameter names to field names. | required |
Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def _build_model_params(self, values: dict) -> dict:
    """Collect request kwargs for the driver fields that apply to the current model.

    Args:
        values: Maps each API parameter name to the name of the driver field supplying its value.

    Returns:
        A dict holding only the parameters whose field value is not None, whose
        ``model_allowlist`` metadata (if present) contains ``self.model``, and whose
        ``model_denylist`` metadata (if present) does not contain it.
    """
    field_defs = fields_dict(self.__class__)

    def _applies_to_model(metadata) -> bool:
        allowlist = metadata.get("model_allowlist")
        denylist = metadata.get("model_denylist")
        if allowlist is not None and self.model not in allowlist:
            return False
        return denylist is None or self.model not in denylist

    params = {}
    for param_name, field_name in values.items():
        # Look up the field definition first so an unknown field name still raises KeyError.
        metadata = field_defs[field_name].metadata
        value = getattr(self, field_name, None)
        if value is not None and _applies_to_model(metadata):
            params[param_name] = value
    return params

_parse_image_response(response, prompt)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def _parse_image_response(self, response: ImagesResponse, prompt: str) -> ImageArtifact:
    """Decode the first image of an Images API response into an ``ImageArtifact``.

    Args:
        response: Response from ``images.generate``, ``images.edit`` or ``images.create_variation``.
        prompt: Prompt used for the request; stored in the artifact's ``meta``.

    Returns:
        The parsed image with ``prompt`` and ``model`` recorded in ``meta``.

    Raises:
        Exception: If the response has no image data or the first image lacks ``b64_json``.
    """
    # Guard against both a missing and an empty ``data`` list before indexing.
    first_image = response.data[0] if response.data else None
    if first_image is None or first_image.b64_json is None:
        raise Exception("Failed to generate image")

    # Imported lazily, and only once the response is known to contain an image.
    from griptape.loaders.image_loader import ImageLoader

    image_data = base64.b64decode(first_image.b64_json)

    image_artifact = ImageLoader().parse(image_data)

    image_artifact.meta["prompt"] = prompt
    image_artifact.meta["model"] = self.model

    return image_artifact

client()

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Build the OpenAI SDK client from the configured key, base URL and organization."""
    client_kwargs = dict(api_key=self.api_key, base_url=self.base_url, organization=self.organization)
    return openai.OpenAI(**client_kwargs)

try_image_inpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_inpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Edit the masked region of ``image`` according to ``prompts``.

    Args:
        prompts: Prompt fragments, joined with ``", "`` into a single prompt.
        image: Source image; its ``value`` is uploaded as-is.
        mask: Mask image; its ``value`` is uploaded as-is.
        negative_prompts: Not used by this driver.

    Returns:
        The edited image parsed by ``_parse_image_response``.
    """
    prompt = ", ".join(prompts)
    response = self.client.images.edit(
        # Without an explicit model the edit endpoint falls back to dall-e-2, which would
        # ignore a configured gpt-image-1 and mislabel the artifact's meta["model"].
        model=self.model,
        prompt=prompt,
        image=image.value,
        mask=mask.value,
        n=1,
        **self._build_model_params(
            {
                "size": "image_size",
                "response_format": "response_format",
            }
        ),
    )

    return self._parse_image_response(response, prompt)

try_image_outpainting(prompts, image, mask, negative_prompts=None)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_outpainting(
    self,
    prompts: list[str],
    image: ImageArtifact,
    mask: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Reject outpainting requests; this driver has no outpainting support.

    Raises:
        NotImplementedError: Always, naming the concrete driver class.
    """
    driver_name = self.__class__.__name__
    raise NotImplementedError(f"{driver_name} does not support outpainting")

try_image_variation(prompts, image, negative_prompts=None)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_image_variation(
    self,
    prompts: list[str],
    image: ImageArtifact,
    negative_prompts: Optional[list[str]] = None,
) -> ImageArtifact:
    """Create a variation of ``image``.

    Only supported for dall-e-2, which requires the image size to be one of:
        [256x256, 512x512, 1024x1024]

    Args:
        prompts: Not used by this method.
        image: Source image; its ``value`` is uploaded as-is.
        negative_prompts: Not used by this method.

    Returns:
        The variation parsed by ``_parse_image_response`` with an empty prompt.

    Raises:
        NotImplementedError: If ``self.model`` is not ``dall-e-2``.
    """
    if self.model == "dall-e-2":
        response = self.client.images.create_variation(
            image=image.value,
            n=1,
            response_format=self.response_format,
            size=self.image_size,  # pyright: ignore[reportArgumentType]
        )
        return self._parse_image_response(response, "")

    raise NotImplementedError("Image variation only supports dall-e-2")

try_text_to_image(prompts, negative_prompts=None)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
    """Generate a single image from ``prompts`` joined with ``", "``.

    Optional request parameters are filtered per model by ``_build_model_params``;
    ``negative_prompts`` is not used by this driver.
    """
    # API parameter name -> driver field name.
    optional_params = {
        "size": "image_size",
        "quality": "quality",
        "style": "style",
        "response_format": "response_format",
        "background": "background",
        "moderation": "moderation",
        "output_compression": "output_compression",
        "output_format": "output_format",
    }
    joined_prompt = ", ".join(prompts)

    response = self.client.images.generate(
        model=self.model,
        prompt=joined_prompt,
        n=1,
        **self._build_model_params(optional_params),
    )

    return self._parse_image_response(response, joined_prompt)

validate_image_size(attribute, value)

Source Code in griptape/drivers/image_generation/openai_image_generation_driver.py
@image_size.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_image_size(self, attribute: str, value: str | None) -> None:
    """Check that ``value`` is an image size the configured model accepts.

    gpt-image models accept `1024x1024`, `1536x1024` (landscape), `1024x1536` (portrait) or
    `auto` (default value); `dall-e-2` accepts `256x256`, `512x512` or `1024x1024`; and
    `dall-e-3` accepts `1024x1024`, `1792x1024` or `1024x1792`. ``None`` is always accepted.

    Raises:
        NotImplementedError: If a size is set for a model without known size constraints.
        ValueError: If the size is not allowed for the model.
    """
    if value is None:
        return

    sizes_by_model = {
        "dall-e-2": ("256x256", "512x512", "1024x1024"),
        "dall-e-3": ("1024x1024", "1792x1024", "1024x1792"),
    }
    if self.model.startswith("gpt-image"):
        allowed_sizes = ("1024x1024", "1536x1024", "1024x1536", "auto")
    elif self.model in sizes_by_model:
        allowed_sizes = sizes_by_model[self.model]
    else:
        raise NotImplementedError(f"Image size validation not implemented for model {self.model}")

    if value not in allowed_sizes:
        raise ValueError(f"Image size, {value}, must be one of the following: {allowed_sizes}")

OpenAiTextToSpeechDriver

Bases: BaseTextToSpeechDriver

Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
@define
class OpenAiTextToSpeechDriver(BaseTextToSpeechDriver):
    """Text-to-speech driver backed by the OpenAI audio speech API.

    Attributes:
        model: OpenAI TTS model name. Defaults to `tts-1`.
        voice: Voice used to synthesize speech. Defaults to `alloy`.
        format: Audio format requested from the API and recorded on the resulting `AudioArtifact`.
            Defaults to `mp3`.
        api_type: OpenAI API type. Not used when building the client.
        api_version: API version. Not used when building the client.
        base_url: API URL passed to the client.
        api_key: API key passed to the client. Marked non-serializable.
        organization: OpenAI organization ID passed to the client.
        client: Optionally provide a custom `openai.OpenAI` client.
    """

    model: str = field(default="tts-1", kw_only=True, metadata={"serializable": True})
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = field(
        default="alloy",
        kw_only=True,
        metadata={"serializable": True},
    )
    format: Literal["mp3", "opus", "aac", "flac"] = field(default="mp3", kw_only=True, metadata={"serializable": True})
    api_type: Optional[str] = field(default=openai.api_type, kw_only=True)
    api_version: Optional[str] = field(default=openai.api_version, kw_only=True, metadata={"serializable": True})
    base_url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    # Marked non-serializable, matching how the other OpenAI drivers declare credentials.
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    organization: Optional[str] = field(default=openai.organization, kw_only=True, metadata={"serializable": True})
    _client: Optional[openai.OpenAI] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> openai.OpenAI:
        """Construct the OpenAI client from this driver's connection settings."""
        return openai.OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            organization=self.organization,
        )

    def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
        """Synthesize `prompts` (joined with ". ") into an AudioArtifact in `self.format`."""
        response = self.client.audio.speech.create(
            input=". ".join(prompts),
            voice=self.voice,
            model=self.model,
            response_format=self.format,
        )

        return AudioArtifact(value=response.content, format=self.format)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True) class-attribute instance-attribute

  • api_type = field(default=openai.api_type, kw_only=True) class-attribute instance-attribute

  • api_version = field(default=openai.api_version, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • base_url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • format = field(default='mp3', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default='tts-1', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • organization = field(default=openai.organization, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • voice = field(default='alloy', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
@lazy_property()
def client(self) -> openai.OpenAI:
    """Create the OpenAI SDK client for the speech endpoint from this driver's settings."""
    settings = {"api_key": self.api_key, "base_url": self.base_url, "organization": self.organization}
    return openai.OpenAI(**settings)

try_text_to_audio(prompts)

Source Code in griptape/drivers/text_to_speech/openai_text_to_speech_driver.py
def try_text_to_audio(self, prompts: list[str]) -> AudioArtifact:
    """Synthesize speech for ``prompts`` and wrap the returned bytes in an ``AudioArtifact``.

    The prompts are joined with ``". "`` into a single input string; the artifact's
    format matches the ``response_format`` requested from the API.
    """
    speech_text = ". ".join(prompts)
    audio_format = self.format
    speech = self.client.audio.speech.create(
        input=speech_text,
        voice=self.voice,
        model=self.model,
        response_format=audio_format,
    )
    return AudioArtifact(value=speech.content, format=audio_format)

OpenSearchVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| host | str | The host of the OpenSearch cluster. |
| port | int | The port of the OpenSearch cluster. |
| http_auth | str \| tuple[str, Optional[str]] | The HTTP authentication credentials to use. |
| use_ssl | bool | Whether to use SSL. |
| verify_certs | bool | Whether to verify SSL certificates. |
| index_name | str | The name of the index to use. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
@define
class OpenSearchVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver backed by an OpenSearch index.

    Attributes:
        host: The host of the OpenSearch cluster.
        port: The port of the OpenSearch cluster.
        http_auth: The HTTP authentication credentials to use.
        use_ssl: Whether to use SSL.
        verify_certs: Whether to verify SSL certificates.
        index_name: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    port: int = field(default=443, kw_only=True, metadata={"serializable": True})
    # NOTE(review): defaults to None although the annotation does not admit None.
    http_auth: str | tuple[str, Optional[str]] = field(default=None, kw_only=True, metadata={"serializable": True})
    use_ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    verify_certs: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[OpenSearch] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> OpenSearch:
        """Create an ``opensearchpy.OpenSearch`` client for the configured single host.

        The client uses ``RequestsHttpConnection`` as its transport, together with this
        driver's auth and SSL settings.
        """
        opensearchpy = import_optional_dependency("opensearchpy")
        node = {"host": self.host, "port": self.port}

        return opensearchpy.OpenSearch(
            hosts=[node],
            http_auth=self.http_auth,
            use_ssl=self.use_ssl,
            verify_certs=self.verify_certs,
            connection_class=opensearchpy.RequestsHttpConnection,
        )

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Index ``vector`` in OpenSearch, replacing any document that already has the same ID.

        Args:
            vector: The embedding to store.
            vector_id: Document ID. When falsy, a hash of ``str(vector)`` is used instead.
            namespace: Optional namespace stored alongside the vector.
            meta: Optional metadata, stored under the ``metadata`` key.
            **kwargs: Extra document fields. They are merged last and can override the fields above.

        Returns:
            The ``_id`` reported by OpenSearch for the indexed document.
        """
        document_id = vector_id if vector_id else utils.str_to_hash(str(vector))
        document = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}
        indexed = self.client.index(index=self.index_name, id=document_id, body=document)

        return indexed["_id"]

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Look up a single entry by ID, optionally also requiring a matching namespace.

        Any exception raised while searching is logged and turned into a ``None`` result.

        Returns:
            The matching BaseVectorStoreDriver.Entry, or None if nothing matched or the lookup failed.
        """
        try:
            must_clauses = [{"term": {"_id": vector_id}}]
            if namespace:
                must_clauses.append({"term": {"namespace": namespace}})

            search_body = {"query": {"bool": {"must": must_clauses}}, "size": 1}
            response = self.client.search(index=self.index_name, body=search_body)

            if response["hits"]["total"]["value"] <= 0:
                return None

            source = response["hits"]["hits"][0]["_source"]
            return BaseVectorStoreDriver.Entry(
                id=vector_id,
                meta=source.get("metadata"),
                vector=source.get("vector"),
                namespace=source.get("namespace"),
            )
        except Exception as e:
            logging.exception("Error while loading entry: %s", e)
            return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Fetch the stored vector entries, optionally restricted to one namespace.

        At most 10,000 documents are requested in a single search.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects.
        """
        if namespace:
            search_query = {"match": {"namespace": namespace}}
        else:
            search_query = {"match_all": {}}

        response = self.client.search(index=self.index_name, body={"size": 10000, "query": search_query})

        entries = []
        for hit in response["hits"]["hits"]:
            source = hit["_source"]
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=hit["_id"],
                    vector=source.get("vector"),
                    meta=source.get("metadata"),
                    namespace=source.get("namespace"),
                )
            )
        return entries

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        field_name: str = "vector",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Run a k-NN search against ``field_name`` for vectors similar to ``vector``.

        Args:
            vector: The query embedding.
            count: Number of neighbours (``k`` and result ``size``). Falsy values fall back
                to ``BaseVectorStoreDriver.DEFAULT_QUERY_COUNT``.
            namespace: When given, results must also match this namespace.
            include_vectors: Whether to copy stored vectors into the results.
            include_metadata: Whether to copy stored metadata into the results.
            field_name: Document field holding the vectors.
            **kwargs: Accepted but not used.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects carrying the hit ID and score.
        """
        limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
        knn_clause = {"knn": {field_name: {"vector": vector, "k": limit}}}

        if namespace:
            search_query = {"bool": {"must": [{"match": {"namespace": namespace}}, knn_clause]}}
        else:
            search_query = knn_clause

        response = self.client.search(index=self.index_name, body={"size": limit, "query": search_query})

        # NOTE(review): the entry's namespace is only populated when a namespace filter was given.
        return [
            BaseVectorStoreDriver.Entry(
                id=hit["_id"],
                namespace=hit["_source"].get("namespace") if namespace else None,
                score=hit["_score"],
                vector=hit["_source"].get("vector") if include_vectors else None,
                meta=hit["_source"].get("metadata") if include_metadata else None,
            )
            for hit in response["hits"]["hits"]
        ]

    def delete_vector(self, vector_id: str) -> NoReturn:
        """Deletion is not supported by this driver.

        Raises:
            NotImplementedError: On every call.
        """
        message = f"{self.__class__.__name__} does not support deletion."
        raise NotImplementedError(message)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • host = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • http_auth = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • port = field(default=443, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • use_ssl = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • verify_certs = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
@lazy_property()
def client(self) -> OpenSearch:
    """Create an ``opensearchpy.OpenSearch`` client for the configured single host.

    The client uses ``RequestsHttpConnection`` as its transport, together with the driver's
    auth and SSL settings.
    """
    opensearchpy = import_optional_dependency("opensearchpy")
    node = {"host": self.host, "port": self.port}

    return opensearchpy.OpenSearch(
        hosts=[node],
        http_auth=self.http_auth,
        use_ssl=self.use_ssl,
        verify_certs=self.verify_certs,
        connection_class=opensearchpy.RequestsHttpConnection,
    )

delete_vector(vector_id)

Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Always fail: this driver cannot delete vectors.

    Raises:
        NotImplementedError: On every call.
    """
    message = f"{self.__class__.__name__} does not support deletion."
    raise NotImplementedError(message)

load_entries(*, namespace=None)

Retrieves all vector entries from OpenSearch that match the optional namespace.

Returns

| Type | Description |
| --- | --- |
| list[Entry] | A list of BaseVectorStoreDriver.Entry objects. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Fetch the stored vector entries, optionally restricted to one namespace.

    At most 10,000 documents are requested in a single search.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects.
    """
    if namespace:
        search_query = {"match": {"namespace": namespace}}
    else:
        search_query = {"match_all": {}}

    response = self.client.search(index=self.index_name, body={"size": 10000, "query": search_query})

    entries = []
    for hit in response["hits"]["hits"]:
        source = hit["_source"]
        entries.append(
            BaseVectorStoreDriver.Entry(
                id=hit["_id"],
                vector=source.get("vector"),
                meta=source.get("metadata"),
                namespace=source.get("namespace"),
            )
        )
    return entries

load_entry(vector_id, *, namespace=None)

Retrieves a specific vector entry from OpenSearch based on its identifier and optional namespace.

Returns

| Type | Description |
| --- | --- |
| Optional[Entry] | The matching BaseVectorStoreDriver.Entry if one is found; otherwise None. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Look up a single entry by ID, optionally also requiring a matching namespace.

    Any exception raised while searching is logged and turned into a ``None`` result.

    Returns:
        The matching BaseVectorStoreDriver.Entry, or None if nothing matched or the lookup failed.
    """
    try:
        must_clauses = [{"term": {"_id": vector_id}}]
        if namespace:
            must_clauses.append({"term": {"namespace": namespace}})

        search_body = {"query": {"bool": {"must": must_clauses}}, "size": 1}
        response = self.client.search(index=self.index_name, body=search_body)

        if response["hits"]["total"]["value"] <= 0:
            return None

        source = response["hits"]["hits"][0]["_source"]
        return BaseVectorStoreDriver.Entry(
            id=vector_id,
            meta=source.get("metadata"),
            vector=source.get("vector"),
            namespace=source.get("namespace"),
        )
    except Exception as e:
        logging.exception("Error while loading entry: %s", e)
        return None

query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, field_name='vector', **kwargs)

Performs a nearest neighbor search on OpenSearch to find vectors similar to the provided vector list.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns

| Type | Description |
| --- | --- |
| list[Entry] | A list of BaseVectorStoreDriver.Entry objects, each holding the retrieved vector, its similarity score, metadata, and namespace. |
Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    field_name: str = "vector",
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Run a k-NN search against ``field_name`` for vectors similar to ``vector``.

    Args:
        vector: The query embedding.
        count: Number of neighbours (``k`` and result ``size``). Falsy values fall back
            to ``BaseVectorStoreDriver.DEFAULT_QUERY_COUNT``.
        namespace: When given, results must also match this namespace.
        include_vectors: Whether to copy stored vectors into the results.
        include_metadata: Whether to copy stored metadata into the results.
        field_name: Document field holding the vectors.
        **kwargs: Accepted but not used.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects carrying the hit ID and score.
    """
    limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT
    knn_clause = {"knn": {field_name: {"vector": vector, "k": limit}}}

    if namespace:
        search_query = {"bool": {"must": [{"match": {"namespace": namespace}}, knn_clause]}}
    else:
        search_query = knn_clause

    response = self.client.search(index=self.index_name, body={"size": limit, "query": search_query})

    # NOTE(review): the entry's namespace is only populated when a namespace filter was given.
    return [
        BaseVectorStoreDriver.Entry(
            id=hit["_id"],
            namespace=hit["_source"].get("namespace") if namespace else None,
            score=hit["_score"],
            vector=hit["_source"].get("vector") if include_vectors else None,
            meta=hit["_source"].get("metadata") if include_metadata else None,
        )
        for hit in response["hits"]["hits"]
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/opensearch_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Index ``vector`` in OpenSearch, replacing any document that already has the same ID.

    Args:
        vector: The embedding to store.
        vector_id: Document ID. When falsy, a hash of ``str(vector)`` is used instead.
        namespace: Optional namespace stored alongside the vector.
        meta: Optional metadata, stored under the ``metadata`` key.
        **kwargs: Extra document fields. They are merged last and can override the fields above.

    Returns:
        The ``_id`` reported by OpenSearch for the indexed document.
    """
    document_id = vector_id if vector_id else utils.str_to_hash(str(vector))
    document = {"vector": vector, "namespace": namespace, "metadata": meta, **kwargs}
    indexed = self.client.index(index=self.index_name, id=document_id, body=document)

    return indexed["_id"]

OpenTelemetryObservabilityDriver

Bases: BaseObservabilityDriver

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
@define
class OpenTelemetryObservabilityDriver(BaseObservabilityDriver):
    """Observability driver that traces calls with the OpenTelemetry SDK.

    Attributes:
        service_name: Value of the ``service.name`` resource attribute; also used as the tracer name.
        span_processor: Span processor registered on ``trace_provider`` during initialization.
        service_version: Optional ``service.version`` resource attribute.
        deployment_env: Optional ``deployment.environment`` resource attribute.
        trace_provider: Tracer provider; built by ``_trace_provider_factory`` unless supplied.
    """

    service_name: str = field(default="griptape", kw_only=True)
    span_processor: SpanProcessor = field(kw_only=True)
    service_version: Optional[str] = field(default=None, kw_only=True)
    deployment_env: Optional[str] = field(default=None, kw_only=True)
    trace_provider: TracerProvider = field(
        default=Factory(
            lambda self: self._trace_provider_factory(),
            takes_self=True,
        ),
        kw_only=True,
    )
    # Created in __attrs_post_init__ from trace_provider.
    _tracer: Tracer = field(init=False)
    # Context manager of the root "main" span; set in __enter__ and reset to None in __exit__.
    _root_span_context_manager: Any = None

    def _trace_provider_factory(self) -> TracerProvider:
        """Build an SDK TracerProvider whose resource describes this service.

        ``service.name`` is always set. ``service.version`` and ``deployment.environment``
        are added only when the corresponding attributes are not None.
        """
        opentelemetry_trace = import_optional_dependency("opentelemetry.sdk.trace")

        attributes = {"service.name": self.service_name}
        if self.service_version is not None:
            attributes["service.version"] = self.service_version
        if self.deployment_env is not None:
            attributes["deployment.environment"] = self.deployment_env
        return opentelemetry_trace.TracerProvider(
            resource=import_optional_dependency("opentelemetry.sdk.resources").Resource(attributes=attributes)
        )  # pyright: ignore[reportArgumentType]

    def __attrs_post_init__(self) -> None:
        """Register ``span_processor`` on ``trace_provider`` and create the tracer."""
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        self.trace_provider.add_span_processor(self.span_processor)
        self._tracer = opentelemetry_trace.get_tracer(self.service_name, tracer_provider=self.trace_provider)

    def __enter__(self) -> None:
        """Enable threading instrumentation, then open and enter the root ``"main"`` span."""
        opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")

        opentelemetry_instrumentation_threading.ThreadingInstrumentor().instrument()
        self._root_span_context_manager = self._tracer.start_as_current_span("main")  # pyright: ignore[reportCallIssue]
        self._root_span_context_manager.__enter__()

    def __exit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> bool:
        """Close the root span, flush spans and disable threading instrumentation.

        The current span is marked ERROR, with ``exc_value`` recorded, when an exception is
        propagating. Otherwise it is marked OK.

        Returns:
            False, so any exception propagates to the caller.
        """
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")
        # Presumably the root span opened in __enter__, assuming no nested span is still current.
        # Fetched once: nothing between here and the status update changes the current context.
        root_span = opentelemetry_trace.get_current_span()
        if exc_value:
            root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.ERROR))
            root_span.record_exception(exc_value)
        else:
            root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.OK))
        if self._root_span_context_manager:
            self._root_span_context_manager.__exit__(exc_type, exc_value, exc_traceback)
            self._root_span_context_manager = None
        self.trace_provider.force_flush()
        opentelemetry_instrumentation_threading.ThreadingInstrumentor().uninstrument()
        return False

    def observe(self, call: Observable.Call) -> Any:
        """Run ``call`` inside a span named ``"<Class>.<func>()"`` (``"<func>()"`` without an instance).

        The span gets a ``tags`` attribute when ``call.tags`` is not None. It is marked OK on
        success, or ERROR with the exception recorded on failure. Exceptions propagate to the caller.

        Returns:
            Whatever ``call()`` returns.
        """
        open_telemetry_trace = import_optional_dependency("opentelemetry.trace")
        func = call.func
        instance = call.instance
        tags = call.tags

        class_name = f"{instance.__class__.__name__}." if instance else ""
        span_name = f"{class_name}{func.__name__}()"
        with self._tracer.start_as_current_span(span_name) as span:  # pyright: ignore[reportCallIssue]
            if tags is not None:
                span.set_attribute("tags", tags)

            try:
                result = call()
                span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.OK))
                return result
            except Exception as e:
                span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.ERROR))
                span.record_exception(e)
                # Bare raise re-raises the active exception with its original traceback.
                raise

    def get_span_id(self) -> Optional[str]:
        """Return the current span's ID formatted by ``format_span_id``, or None if no valid span is current."""
        opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
        span = opentelemetry_trace.get_current_span()
        if span is opentelemetry_trace.INVALID_SPAN:
            return None
        return opentelemetry_trace.format_span_id(span.get_span_context().span_id)
  • _root_span_context_manager = None class-attribute instance-attribute

  • _tracer = field(init=False) class-attribute instance-attribute

  • deployment_env = field(default=None, kw_only=True) class-attribute instance-attribute

  • service_name = field(default='griptape', kw_only=True) class-attribute instance-attribute

  • service_version = field(default=None, kw_only=True) class-attribute instance-attribute

  • span_processor = field(kw_only=True) class-attribute instance-attribute

  • trace_provider = field(default=Factory(lambda self: self._trace_provider_factory(), takes_self=True), kw_only=True) class-attribute instance-attribute

`__attrs_post_init__()`

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __attrs_post_init__(self) -> None:
    """Register ``span_processor`` on ``trace_provider`` and create the tracer named after ``service_name``."""
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
    provider = self.trace_provider
    provider.add_span_processor(self.span_processor)
    self._tracer = opentelemetry_trace.get_tracer(self.service_name, tracer_provider=provider)

`__enter__()`

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __enter__(self) -> None:
    """Enable threading instrumentation, then open and enter the root ``"main"`` span."""
    threading_instrumentation = import_optional_dependency("opentelemetry.instrumentation.threading")
    threading_instrumentation.ThreadingInstrumentor().instrument()

    root_span_cm = self._tracer.start_as_current_span("main")  # pyright: ignore[reportCallIssue]
    self._root_span_context_manager = root_span_cm
    root_span_cm.__enter__()

`__exit__(exc_type, exc_value, exc_traceback)`

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def __exit__(
    self,
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
    exc_traceback: Optional[TracebackType],
) -> bool:
    """Close the root span, flush spans and disable threading instrumentation.

    The current span is marked ERROR, with ``exc_value`` recorded, when an exception is
    propagating. Otherwise it is marked OK.

    Returns:
        False, so any exception propagates to the caller.
    """
    opentelemetry_trace = import_optional_dependency("opentelemetry.trace")
    opentelemetry_instrumentation_threading = import_optional_dependency("opentelemetry.instrumentation.threading")
    # Presumably the root span opened in __enter__, assuming no nested span is still current.
    # Fetched once: nothing between here and the status update changes the current context.
    root_span = opentelemetry_trace.get_current_span()
    if exc_value:
        root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.ERROR))
        root_span.record_exception(exc_value)
    else:
        root_span.set_status(opentelemetry_trace.Status(opentelemetry_trace.StatusCode.OK))
    if self._root_span_context_manager:
        self._root_span_context_manager.__exit__(exc_type, exc_value, exc_traceback)
        self._root_span_context_manager = None
    self.trace_provider.force_flush()
    opentelemetry_instrumentation_threading.ThreadingInstrumentor().uninstrument()
    return False

_trace_provider_factory()

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def _trace_provider_factory(self) -> TracerProvider:
    """Build an SDK TracerProvider whose resource describes this service.

    ``service.name`` is always set. ``service.version`` and ``deployment.environment``
    are added only when the corresponding attributes are not None.
    """
    sdk_trace = import_optional_dependency("opentelemetry.sdk.trace")

    attributes = {"service.name": self.service_name}
    optional_attributes = (
        ("service.version", self.service_version),
        ("deployment.environment", self.deployment_env),
    )
    attributes.update({key: value for key, value in optional_attributes if value is not None})

    sdk_resources = import_optional_dependency("opentelemetry.sdk.resources")
    resource = sdk_resources.Resource(attributes=attributes)
    return sdk_trace.TracerProvider(resource=resource)  # pyright: ignore[reportArgumentType]

get_span_id()

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def get_span_id(self) -> Optional[str]:
    """Return the current span's ID formatted by ``format_span_id``, or None when no valid span is current."""
    trace_api = import_optional_dependency("opentelemetry.trace")
    current = trace_api.get_current_span()
    if current is not trace_api.INVALID_SPAN:
        return trace_api.format_span_id(current.get_span_context().span_id)
    return None

observe(call)

Source Code in griptape/drivers/observability/open_telemetry_observability_driver.py
def observe(self, call: Observable.Call) -> Any:
    """Run ``call`` inside a span named ``"<Class>.<func>()"`` (``"<func>()"`` without an instance).

    The span gets a ``tags`` attribute when ``call.tags`` is not None. It is marked OK on
    success, or ERROR with the exception recorded on failure. Exceptions propagate to the caller.

    Returns:
        Whatever ``call()`` returns.
    """
    open_telemetry_trace = import_optional_dependency("opentelemetry.trace")
    func = call.func
    instance = call.instance
    tags = call.tags

    class_name = f"{instance.__class__.__name__}." if instance else ""
    span_name = f"{class_name}{func.__name__}()"
    with self._tracer.start_as_current_span(span_name) as span:  # pyright: ignore[reportCallIssue]
        if tags is not None:
            span.set_attribute("tags", tags)

        try:
            result = call()
            span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.OK))
            return result
        except Exception as e:
            span.set_status(open_telemetry_trace.Status(open_telemetry_trace.StatusCode.ERROR))
            span.record_exception(e)
            # Bare raise re-raises the active exception with its original traceback.
            raise

PerplexityPromptDriver

Bases: OpenAiChatPromptDriver

Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
@define
class PerplexityPromptDriver(OpenAiChatPromptDriver):
    """OpenAiChatPromptDriver pointed at the Perplexity API.

    Attributes:
        base_url: API endpoint; defaults to ``https://api.perplexity.ai``.
        structured_output_strategy: Defaults to ``"native"``.
    """

    base_url: str = field(default="https://api.perplexity.ai", kw_only=True, metadata={"serializable": True})
    structured_output_strategy: str = field(default="native", kw_only=True, metadata={"serializable": True})

    @override
    def _to_message(self, result: ChatCompletion) -> Message:
        """Convert ``result`` with the parent driver, then attach Perplexity citations.

        Citations are read from ``result.citations`` (an empty list when absent) and stored in
        the first content artifact's ``meta["citations"]``.
        """
        message = super()._to_message(result)
        citations = getattr(result, "citations", [])
        first_artifact = message.content[0].artifact
        first_artifact.meta["citations"] = citations
        return message

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Return the parent driver's request parameters with any ``stop`` entry removed."""
        params = super()._base_params(prompt_stack)
        params.pop("stop", None)
        return params
  • base_url = field(default='https://api.perplexity.ai', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

_base_params(prompt_stack)

Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Return the parent driver's request parameters with any ``stop`` entry removed."""
    params = super()._base_params(prompt_stack)
    params.pop("stop", None)
    return params

_to_message(result)

Source Code in griptape/drivers/prompt/perplexity_prompt_driver.py
@override
def _to_message(self, result: ChatCompletion) -> Message:
    """Convert ``result`` with the parent driver, then attach Perplexity citations.

    Citations are read from ``result.citations`` (an empty list when absent) and stored in
    the first content artifact's ``meta["citations"]``.
    """
    message = super()._to_message(result)
    citations = getattr(result, "citations", [])
    first_artifact = message.content[0].artifact
    first_artifact.meta["citations"] = citations
    return message

PerplexityWebSearchDriver

Bases: BaseWebSearchDriver

Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
@define
class PerplexityWebSearchDriver(BaseWebSearchDriver):
    """Web search driver that answers a query by prompting a Perplexity model.

    Attributes:
        model: Model name passed to the PerplexityPromptDriver.
        api_key: API key passed to the lazily-built PerplexityPromptDriver.
    """

    model: str = field(default="sonar-pro", kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(kw_only=True, default=None)
    _prompt_driver: Optional[PerplexityPromptDriver] = field(default=None, alias="prompt_driver")

    @lazy_property()
    def prompt_driver(self) -> PerplexityPromptDriver:
        """Build a PerplexityPromptDriver for ``model`` using ``api_key``.

        Raises:
            ValueError: If ``api_key`` is None.
        """
        api_key = self.api_key
        if api_key is None:
            raise ValueError("api_key is required")
        return PerplexityPromptDriver(model=self.model, api_key=api_key)

    def search(self, query: str, **kwargs) -> ListArtifact:
        """Send ``query`` as a one-artifact prompt and wrap the reply in a ListArtifact.

        Extra keyword arguments are accepted but ignored.
        """
        driver = self.prompt_driver
        reply = driver.run(PromptStack.from_artifact(TextArtifact(query)))
        return ListArtifact([reply.to_artifact()])
  • _prompt_driver = field(default=None, alias='prompt_driver') class-attribute instance-attribute

  • api_key = field(kw_only=True, default=None) class-attribute instance-attribute

  • model = field(default='sonar-pro', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

prompt_driver()

Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
@lazy_property()
def prompt_driver(self) -> PerplexityPromptDriver:
    """Build a PerplexityPromptDriver for ``model`` using ``api_key``.

    Raises:
        ValueError: If ``api_key`` is None.
    """
    api_key = self.api_key
    if api_key is None:
        raise ValueError("api_key is required")
    return PerplexityPromptDriver(model=self.model, api_key=api_key)

search(query, **kwargs)

Source Code in griptape/drivers/web_search/perplexity_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    """Send ``query`` as a one-artifact prompt and wrap the reply in a ListArtifact.

    Extra keyword arguments are accepted but ignored.
    """
    driver = self.prompt_driver
    reply = driver.run(PromptStack.from_artifact(TextArtifact(query)))
    return ListArtifact([reply.to_artifact()])

PgAiKnowledgeBaseVectorStoreDriver

Bases: BaseVectorStoreDriver

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
@define
class PgAiKnowledgeBaseVectorStoreDriver(BaseVectorStoreDriver):
    """Query-only vector store driver that reads from an ``aidb`` knowledge base over SQL.

    Only ``query`` is supported. Every upsert, load and delete method raises
    ``NotImplementedError``.

    Attributes:
        connection_string: Database URL passed to ``sqlalchemy.create_engine``.
        knowledge_base_name: Knowledge base name passed to ``aidb.retrieve_text``.
        embedding_driver: Fixed to a DummyEmbeddingDriver and not settable via ``__init__``
            (presumably because retrieval embeds the query inside the database).
    """

    connection_string: str = field(kw_only=True, metadata={"serializable": True})
    knowledge_base_name: str = field(kw_only=True, metadata={"serializable": True})
    embedding_driver: BaseEmbeddingDriver = field(
        default=Factory(lambda: DummyEmbeddingDriver()),
        metadata={"serializable": True},
        kw_only=True,
        init=False,
    )
    # NOTE(review): defaults to None although the annotation does not admit None.
    _engine: sqlalchemy.Engine = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @lazy_property()
    def engine(self) -> sqlalchemy.Engine:
        """Create the SQLAlchemy engine from ``connection_string``."""
        return import_optional_dependency("sqlalchemy").create_engine(self.connection_string)

    def query(
        self,
        query: str | TextArtifact | ImageArtifact,
        *,
        count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieve the best-matching chunks for ``query`` from the knowledge base.

        Runs ``aidb.retrieve_text(knowledge_base_name, query, count)`` and converts each returned
        record into an Entry. Field 0 becomes the ID, field 1 the matched text (stored as a
        serialized TextArtifact under ``meta["artifact"]``), and field 2 the score.

        Raises:
            ValueError: If ``query`` is an ImageArtifact.
        """
        if isinstance(query, ImageArtifact):
            raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")

        sqlalchemy = import_optional_dependency("sqlalchemy")

        with sqlalchemy.orm.Session(self.engine) as session:
            rows = session.query(sqlalchemy.func.aidb.retrieve_text(self.knowledge_base_name, query, count)).all()

        def _split_record(record: str) -> list[str]:
            # Split a record literal such as '(id,"some, text",0.5)' into its fields.
            # Fields may be double-quoted. Inside a field, a doubled quote ("") is a literal quote,
            # and a backslash escapes the character that follows it.
            inner = record.strip()
            if inner.startswith("(") and inner.endswith(")"):
                inner = inner[1:-1]

            fields: list[str] = []
            current: list[str] = []
            in_quotes = False
            i = 0
            while i < len(inner):
                char = inner[i]
                if char == "\\" and i + 1 < len(inner):
                    current.append(inner[i + 1])
                    i += 2
                    continue
                if char == '"':
                    if in_quotes and inner[i + 1 : i + 2] == '"':
                        current.append('"')
                        i += 2
                        continue
                    in_quotes = not in_quotes
                elif char == "," and not in_quotes:
                    fields.append("".join(current))
                    current = []
                else:
                    current.append(char)
                i += 1
            fields.append("".join(current))
            return fields

        entries = []
        for (row,) in rows:
            # Each row is a record rendered as text. Values containing commas, quotes or whitespace
            # are double-quoted, so a plain split(",") would break the text and keep the quotes.
            fields = _split_record(row)
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=fields[0],
                    score=float(fields[2]),
                    meta={"artifact": TextArtifact(fields[1]).to_json()},
                )
            )

        return entries

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support vector upsert.")

    def upsert_text_artifact(
        self,
        artifact: TextArtifact,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        vector_id: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support text artifact upsert.")

    def upsert_text(
        self,
        string: str,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support text upsert.")

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support entry loading.")

    def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support Artifact loading.")

    def delete_vector(self, vector_id: str) -> NoReturn:
        """Unsupported: raises NotImplementedError."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
  • _engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False}) class-attribute instance-attribute

  • connection_string = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • embedding_driver = field(default=Factory(lambda: DummyEmbeddingDriver()), metadata={'serializable': True}, kw_only=True, init=False) class-attribute instance-attribute

  • knowledge_base_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

delete_vector(vector_id)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Always fail: this driver cannot delete vectors.

    Raises:
        NotImplementedError: On every call.
    """
    message = f"{self.__class__.__name__} does not support deletion."
    raise NotImplementedError(message)

engine()

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
@lazy_property()
def engine(self) -> sqlalchemy.Engine:
    """Create the SQLAlchemy engine from ``connection_string``."""
    sqlalchemy_module = import_optional_dependency("sqlalchemy")
    return sqlalchemy_module.create_engine(self.connection_string)

load_artifacts(*, namespace=None)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_artifacts(self, *, namespace: Optional[str] = None) -> ListArtifact:
    """Unsupported: this driver cannot list stored artifacts.

    Raises:
        NotImplementedError: On every call.
    """
    message = f"{self.__class__.__name__} does not support Artifact loading."
    raise NotImplementedError(message)

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Unsupported: this driver cannot list stored entries.

    Raises:
        NotImplementedError: On every call.
    """
    message = f"{self.__class__.__name__} does not support entry loading."
    raise NotImplementedError(message)

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> BaseVectorStoreDriver.Entry:
    """Unsupported: this driver cannot load a single entry by ID.

    Raises:
        NotImplementedError: On every call.
    """
    message = f"{self.__class__.__name__} does not support entry loading."
    raise NotImplementedError(message)

query(query, *, count=BaseVectorStoreDriver.DEFAULT_QUERY_COUNT, **kwargs)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def query(
    self,
    query: str | TextArtifact | ImageArtifact,
    *,
    count: Optional[int] = BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieve entries relevant to ``query`` from the pgai knowledge base.

    The search is delegated to the ``aidb.retrieve_text`` SQL function, called with
    ``knowledge_base_name``, the query text and ``count``. Each returned row is a
    Postgres record rendered as text, e.g. ``(id,"some text, with commas",0.87)``;
    it is parsed into an ``Entry`` whose ``meta["artifact"]`` holds the retrieved
    text as a serialized ``TextArtifact``.

    Args:
        query: Text (or a ``TextArtifact``) to search for.
        count: Result count passed to ``aidb.retrieve_text``.
        **kwargs: Accepted for interface compatibility; unused.

    Returns:
        One ``Entry`` per returned row, in database order.

    Raises:
        ValueError: If ``query`` is an ``ImageArtifact``, or a returned row cannot be
            parsed into id, text and score.
    """
    import csv

    if isinstance(query, ImageArtifact):
        raise ValueError(f"{self.__class__.__name__} does not support querying with Image Artifacts.")

    # The SQL driver cannot bind an artifact object, so pass its plain text instead.
    query_text = query.value if isinstance(query, TextArtifact) else query

    sqlalchemy = import_optional_dependency("sqlalchemy")

    with sqlalchemy.orm.Session(self.engine) as session:
        rows = session.query(sqlalchemy.func.aidb.retrieve_text(self.knowledge_base_name, query_text, count)).all()

    entries = []
    for (row,) in rows:
        # Remove the record's outer parentheses: '(a,"b, c",0.5)' -> 'a,"b, c",0.5'.
        inner = "".join(row.replace("(", "", 1).rsplit(")", 1))
        # Postgres double-quotes record fields containing commas, quotes, parentheses,
        # backslashes or whitespace, and doubles embedded quotes and backslashes.
        # csv undoes the quoting and quote-doubling; the backslash doubling is undone
        # here (an unquoted field never contains a backslash, so this is safe for all).
        fields = [part.replace("\\\\", "\\") for part in next(csv.reader([inner]), [])]
        if len(fields) < 3:
            raise ValueError(f"Unexpected row returned by aidb.retrieve_text: {row!r}")

        entries.append(
            BaseVectorStoreDriver.Entry(
                id=fields[0],
                score=float(fields[-1]),
                # Rejoin the middle fields in case the text was split despite quoting.
                meta={"artifact": TextArtifact(",".join(fields[1:-1])).to_json()},
            )
        )

    return entries

upsert_text(string, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_text(
    self,
    string: str,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Refuse to store raw text; this driver has no text upsert support.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    message = f"{self.__class__.__name__} does not support text upsert."
    raise NotImplementedError(message)

upsert_text_artifact(artifact, namespace=None, meta=None, vector_id=None, **kwargs)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_text_artifact(
    self,
    artifact: TextArtifact,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    vector_id: Optional[str] = None,
    **kwargs,
) -> str:
    """Refuse to store a text artifact; this driver has no artifact upsert support.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    message = f"{self.__class__.__name__} does not support text artifact upsert."
    raise NotImplementedError(message)

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/pgai_knowledge_base_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Refuse to store a raw vector; this driver has no vector upsert support.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    message = f"{self.__class__.__name__} does not support vector upsert."
    raise NotImplementedError(message)

PgVectorVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `connection_string` | `Optional[str]` | An optional string describing the target Postgres database instance. |
| `create_engine_params` | `dict` | Additional configuration params passed when creating the database connection. |
| `engine` | `Engine` | An optional sqlalchemy Postgres engine to use. |
| `table_name` | `str` | The name of the table used to store vectors. |
Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@define
class PgVectorVectorStoreDriver(BaseVectorStoreDriver):
    """A vector store driver to Postgres using the PGVector extension.

    Attributes:
        connection_string: An optional SQLAlchemy URL for the target Postgres database, either
            ``postgresql://...`` or ``postgresql+<driver>://...``. Required unless ``engine`` is given.
        create_engine_params: Additional configuration params passed when creating the database connection.
        engine: An optional sqlalchemy Postgres engine to use.
        table_name: The name of the table used to store vectors.
    """

    connection_string: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    create_engine_params: dict = field(factory=dict, kw_only=True, metadata={"serializable": True})
    table_name: str = field(kw_only=True, metadata={"serializable": True})
    _model: Any = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True))
    _engine: Optional[sqlalchemy.Engine] = field(
        default=None, kw_only=True, alias="engine", metadata={"serializable": False}
    )

    @connection_string.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None:
        """Require a Postgres connection string unless an engine was supplied.

        Raises:
            ValueError: If neither an engine nor a connection string is given, or the
                connection string's SQLAlchemy dialect is not ``postgresql``.
        """
        # If an engine is provided, the connection string is not used.
        if self._engine is not None:
            return

        # If an engine is not provided, a connection string is required.
        if connection_string is None:
            raise ValueError("An engine or connection string is required")

        # SQLAlchemy URLs have the form dialect[+driver]://...; compare only the dialect so
        # explicit drivers such as postgresql+psycopg2:// are accepted as well.
        scheme, separator, _rest = connection_string.partition("://")
        if not separator or scheme.split("+", 1)[0] != "postgresql":
            raise ValueError("The connection string must describe a Postgres database connection")

    @lazy_property()
    def engine(self) -> sqlalchemy.Engine:
        """Create the SQLAlchemy engine from ``connection_string`` and ``create_engine_params``."""
        return import_optional_dependency("sqlalchemy").create_engine(
            self.connection_string, **self.create_engine_params
        )

    def setup(
        self,
        *,
        create_schema: bool = True,
        install_uuid_extension: bool = True,
        install_vector_extension: bool = True,
    ) -> None:
        """Provides a mechanism to initialize the database schema and extensions.

        Each enabled extension is installed with ``CREATE EXTENSION IF NOT EXISTS`` in its own
        transaction; ``create_schema`` then creates the model's table via ``metadata.create_all``.
        """
        sqlalchemy_sql = import_optional_dependency("sqlalchemy.sql")

        if install_uuid_extension:
            with self.engine.begin() as conn:
                conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'))

        if install_vector_extension:
            with self.engine.begin() as conn:
                conn.execute(sqlalchemy_sql.text('CREATE EXTENSION IF NOT EXISTS "vector";'))

        if create_schema:
            self._model.metadata.create_all(self.engine)

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in the collection.

        Extra ``kwargs`` are passed to the model constructor. Returns the stored row's id as a string.
        """
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            obj = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)

            obj = session.merge(obj)
            session.commit()

            return str(obj.id)

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from the collection by its primary key.

        ``namespace`` is accepted for interface compatibility; the lookup does not filter on it.
        Returns ``None`` when no row has ``vector_id``.
        """
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            result = session.get(self._model, vector_id)

            # Session.get yields None for an unknown primary key.
            if result is None:
                return None

            return BaseVectorStoreDriver.Entry(
                # The column is a UUID; stringify it like load_entries and query_vector do.
                id=str(result.id),
                vector=result.vector,
                namespace=result.namespace,
                meta=result.meta,
            )

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from the collection, optionally filtering to only those that match the provided namespace.

        Filtering only applies when ``namespace`` is truthy.
        """
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        with sqlalchemy_orm.Session(self.engine) as session:
            query = session.query(self._model)
            if namespace:
                query = query.filter_by(namespace=namespace)

            results = query.all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result.id),
                    vector=result.vector,
                    namespace=result.namespace,
                    meta=result.meta,
                )
                for result in results
            ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        distance_metric: str = "cosine_distance",
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a search on the collection to find vectors similar to the provided input vector, optionally filtering to only those that match the provided namespace.

        ``distance_metric`` must be ``"cosine_distance"``, ``"l2_distance"`` or ``"inner_product"``.
        A dict passed as ``filter`` in ``kwargs`` adds extra ``filter_by`` equality conditions.
        Results are ordered ascending by the metric, which is returned as each entry's ``score``.

        Raises:
            ValueError: If ``distance_metric`` is not recognized.
        """
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        count = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

        distance_metrics = {
            "cosine_distance": self._model.vector.cosine_distance,
            "l2_distance": self._model.vector.l2_distance,
            "inner_product": self._model.vector.max_inner_product,
        }

        if distance_metric not in distance_metrics:
            raise ValueError("Invalid distance metric provided")

        op = distance_metrics[distance_metric]

        with sqlalchemy_orm.Session(self.engine) as session:
            # The query should return both the vector and the distance metric score.
            query_result = session.query(self._model, op(vector).label("score")).order_by(op(vector))  # pyright: ignore[reportOptionalCall]

            filter_kwargs: Optional[OrderedDict] = None

            if namespace is not None:
                filter_kwargs = OrderedDict(namespace=namespace)

            if "filter" in kwargs and isinstance(kwargs["filter"], dict):
                filter_kwargs = filter_kwargs or OrderedDict()
                filter_kwargs.update(kwargs["filter"])

            if filter_kwargs is not None:
                query_result = query_result.filter_by(**filter_kwargs)

            results = query_result.limit(count).all()

            return [
                BaseVectorStoreDriver.Entry(
                    id=str(result[0].id),
                    vector=result[0].vector if include_vectors else None,
                    score=result[1],
                    meta=result[0].meta,
                    namespace=result[0].namespace,
                )
                for result in results
            ]

    def default_vector_model(self) -> Any:
        """Build the declarative SQLAlchemy model mapped to ``table_name``.

        Columns: ``id`` (Postgres UUID primary key defaulting to ``uuid.uuid4``), ``vector``
        (pgvector ``Vector``), ``namespace`` (string) and ``meta`` (JSON).
        """
        pgvector_sqlalchemy = import_optional_dependency("pgvector.sqlalchemy")
        sqlalchemy = import_optional_dependency("sqlalchemy")
        sqlalchemy_dialects_postgresql = import_optional_dependency("sqlalchemy.dialects.postgresql")
        sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

        @dataclass
        class VectorModel(sqlalchemy_orm.declarative_base()):
            __tablename__ = self.table_name

            id = sqlalchemy.Column(
                sqlalchemy_dialects_postgresql.UUID(as_uuid=True),
                primary_key=True,
                default=uuid.uuid4,
                unique=True,
                nullable=False,
            )
            vector = sqlalchemy.Column(pgvector_sqlalchemy.Vector())
            namespace = sqlalchemy.Column(sqlalchemy.String)
            meta = sqlalchemy.Column(sqlalchemy.JSON)

        return VectorModel

    def delete_vector(self, vector_id: str) -> NoReturn:
        """Deletion is not supported by this driver; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
  • _engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False}) class-attribute instance-attribute

  • _model = field(default=Factory(lambda self: self.default_vector_model(), takes_self=True)) class-attribute instance-attribute

  • connection_string = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • create_engine_params = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • table_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

default_vector_model()

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def default_vector_model(self) -> Any:
    """Build the declarative SQLAlchemy model mapped to ``self.table_name``.

    The model has four columns: ``id`` (a Postgres UUID primary key that defaults
    to ``uuid.uuid4``), ``vector`` (a pgvector ``Vector``), ``namespace`` (string)
    and ``meta`` (JSON). A new declarative base is created on each call.

    Returns:
        The mapped ``VectorModel`` class.
    """
    pgvector_sa = import_optional_dependency("pgvector.sqlalchemy")
    sa = import_optional_dependency("sqlalchemy")
    pg_dialect = import_optional_dependency("sqlalchemy.dialects.postgresql")
    orm = import_optional_dependency("sqlalchemy.orm")

    declarative_root = orm.declarative_base()
    table_name = self.table_name

    @dataclass
    class VectorModel(declarative_root):
        __tablename__ = table_name

        id = sa.Column(
            pg_dialect.UUID(as_uuid=True),
            primary_key=True,
            default=uuid.uuid4,
            unique=True,
            nullable=False,
        )
        vector = sa.Column(pgvector_sa.Vector())
        namespace = sa.Column(sa.String)
        meta = sa.Column(sa.JSON)

    return VectorModel

delete_vector(vector_id)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Refuse to delete a vector; this driver has no deletion support.

    Args:
        vector_id: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support deletion."
    raise NotImplementedError(message)

engine()

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@lazy_property()
def engine(self) -> sqlalchemy.Engine:
    """Create a SQLAlchemy engine for ``connection_string``.

    ``create_engine_params`` are forwarded as keyword arguments to ``create_engine``.
    ``lazy_property`` is defined elsewhere; presumably it defers and caches this call
    (and honours an explicitly supplied ``engine``) — confirm there.
    """
    create_engine = import_optional_dependency("sqlalchemy").create_engine
    return create_engine(self.connection_string, **self.create_engine_params)

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load every stored row, optionally limited to one namespace.

    Args:
        namespace: When truthy, only rows whose ``namespace`` equals it are returned;
            ``None`` or an empty string returns all rows.

    Returns:
        One ``Entry`` per row, with the row id converted to ``str``.
    """
    orm = import_optional_dependency("sqlalchemy.orm")

    with orm.Session(self.engine) as session:
        rows_query = session.query(self._model)
        if namespace:
            rows_query = rows_query.filter_by(namespace=namespace)

        entries = []
        for row in rows_query.all():
            entries.append(
                BaseVectorStoreDriver.Entry(
                    id=str(row.id),
                    vector=row.vector,
                    namespace=row.namespace,
                    meta=row.meta,
                )
            )
        return entries

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Fetch a single stored entry by its primary key.

    Args:
        vector_id: Primary key of the row to load.
        namespace: Accepted for interface compatibility; the lookup is by primary key
            only and does not filter on namespace.

    Returns:
        The matching ``Entry`` with ``id`` as a ``str`` (matching ``load_entries`` and
        ``query_vector``), or ``None`` when no row has ``vector_id``.
    """
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    with sqlalchemy_orm.Session(self.engine) as session:
        result = session.get(self._model, vector_id)

        # Session.get yields None for an unknown primary key.
        if result is None:
            return None

        return BaseVectorStoreDriver.Entry(
            # The column is a UUID; stringify it so Entry.id is consistently a str.
            id=str(result.id),
            vector=result.vector,
            namespace=result.namespace,
            meta=result.meta,
        )

query_vector(vector, *, count=None, namespace=None, include_vectors=False, distance_metric='cosine_distance', **kwargs)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    distance_metric: str = "cosine_distance",
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Return the stored rows closest to ``vector`` under ``distance_metric``.

    Args:
        vector: Query embedding.
        count: Maximum number of rows; a falsy value falls back to ``DEFAULT_QUERY_COUNT``.
        namespace: If not ``None``, only rows in this namespace are considered.
        include_vectors: Whether the returned entries carry the stored vector.
        distance_metric: ``"cosine_distance"``, ``"l2_distance"`` or ``"inner_product"``.
        **kwargs: A dict under ``"filter"`` adds extra ``filter_by`` equality conditions,
            merged after (and able to override) ``namespace``. Other keys are ignored.

    Returns:
        Entries ordered by ascending metric value, each with that value as ``score``.

    Raises:
        ValueError: If ``distance_metric`` is not recognized.
    """
    sqlalchemy_orm = import_optional_dependency("sqlalchemy.orm")

    limit = count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT

    vector_column = self._model.vector
    metric_by_name = {
        "cosine_distance": vector_column.cosine_distance,
        "l2_distance": vector_column.l2_distance,
        "inner_product": vector_column.max_inner_product,
    }
    distance = metric_by_name.get(distance_metric)
    if distance is None:
        raise ValueError("Invalid distance metric provided")

    filters: Optional[OrderedDict] = None if namespace is None else OrderedDict(namespace=namespace)
    extra_filter = kwargs.get("filter")
    if isinstance(extra_filter, dict):
        filters = filters or OrderedDict()
        filters.update(extra_filter)

    with sqlalchemy_orm.Session(self.engine) as session:
        # Select the metric alongside each row so it can be reported as the score.
        rows_query = session.query(self._model, distance(vector).label("score")).order_by(distance(vector))  # pyright: ignore[reportOptionalCall]
        if filters is not None:
            rows_query = rows_query.filter_by(**filters)

        return [
            BaseVectorStoreDriver.Entry(
                id=str(model_row.id),
                vector=model_row.vector if include_vectors else None,
                score=score,
                meta=model_row.meta,
                namespace=model_row.namespace,
            )
            for model_row, score in rows_query.limit(limit).all()
        ]

setup(*, create_schema=True, install_uuid_extension=True, install_vector_extension=True)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def setup(
    self,
    *,
    create_schema: bool = True,
    install_uuid_extension: bool = True,
    install_vector_extension: bool = True,
) -> None:
    """Prepare the database for use by this driver.

    Each enabled extension is installed with ``CREATE EXTENSION IF NOT EXISTS`` in
    its own transaction; the model's table is then created when ``create_schema``
    is set.

    Args:
        create_schema: Create the model's table via ``metadata.create_all``.
        install_uuid_extension: Install the ``uuid-ossp`` extension.
        install_vector_extension: Install the ``vector`` extension.
    """
    sql = import_optional_dependency("sqlalchemy.sql")

    extension_steps = (
        (install_uuid_extension, 'CREATE EXTENSION IF NOT EXISTS "uuid-ossp";'),
        (install_vector_extension, 'CREATE EXTENSION IF NOT EXISTS "vector";'),
    )
    for enabled, statement in extension_steps:
        if not enabled:
            continue
        with self.engine.begin() as connection:
            connection.execute(sql.text(statement))

    if create_schema:
        self._model.metadata.create_all(self.engine)

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Insert a row or update an existing one with the same id, then commit.

    Args:
        vector: Embedding to store.
        vector_id: Primary key; when ``None``, the model's column default supplies one.
        namespace: Namespace stored with the row.
        meta: JSON metadata stored with the row.
        **kwargs: Extra keyword arguments forwarded to the model constructor.

    Returns:
        The persisted row's id as a string.
    """
    orm = import_optional_dependency("sqlalchemy.orm")

    with orm.Session(self.engine) as session:
        candidate = self._model(id=vector_id, vector=vector, namespace=namespace, meta=meta, **kwargs)
        persisted = session.merge(candidate)
        session.commit()
        return str(persisted.id)

validate_connection_string(_, connection_string)

Source Code in griptape/drivers/vector/pgvector_vector_store_driver.py
@connection_string.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_connection_string(self, _: Attribute, connection_string: Optional[str]) -> None:
    """Require a usable Postgres connection string when no engine was supplied.

    ``connection_string`` is ignored when ``engine`` was given. Otherwise it must be a
    SQLAlchemy URL whose dialect is ``postgresql``, with or without an explicit driver
    (e.g. ``postgresql://...`` or ``postgresql+psycopg2://...``).

    Raises:
        ValueError: If neither an engine nor a connection string is provided, or the
            connection string's dialect is not ``postgresql``.
    """
    # If an engine is provided, the connection string is not used.
    if self._engine is not None:
        return

    # If an engine is not provided, a connection string is required.
    if connection_string is None:
        raise ValueError("An engine or connection string is required")

    # SQLAlchemy URLs have the form dialect[+driver]://...; compare only the dialect so
    # explicit drivers such as postgresql+psycopg2:// are accepted as well.
    scheme, separator, _rest = connection_string.partition("://")
    if not separator or scheme.split("+", 1)[0] != "postgresql":
        raise ValueError("The connection string must describe a Postgres database connection")

PineconeVectorStoreDriver

Bases: BaseVectorStoreDriver

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@define
class PineconeVectorStoreDriver(BaseVectorStoreDriver):
    """Vector store driver backed by a Pinecone index.

    Attributes:
        api_key: API key passed to the Pinecone client.
        index_name: Name of the index opened through the client.
        environment: Passed to the Pinecone client as ``environment``.
        project_name: Optional value passed to the Pinecone client as ``project_name``.
        client: Optional Pinecone client (constructor alias for ``_client``).
        index: Optional index handle (constructor alias for ``_index``).
    """

    api_key: str = field(kw_only=True, metadata={"serializable": True})
    index_name: str = field(kw_only=True, metadata={"serializable": True})
    environment: str = field(kw_only=True, metadata={"serializable": True})
    project_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[pinecone.Pinecone] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )
    _index: Optional[pinecone.Index] = field(
        default=None, kw_only=True, alias="index", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> pinecone.Pinecone:
        """Create the Pinecone client from ``api_key``, ``environment`` and ``project_name``."""
        return import_optional_dependency("pinecone").Pinecone(
            api_key=self.api_key,
            environment=self.environment,
            project_name=self.project_name,
        )

    @lazy_property()
    def index(self) -> pinecone.data.index.Index:
        """Open the index named ``index_name`` through ``client``."""
        return self.client.Index(self.index_name)

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Upsert ``vector`` into the index and return its id.

        When ``vector_id`` is falsy, the id is a hash of the vector's string form.
        Extra ``kwargs`` are forwarded to ``Index.upsert``.
        """
        vector_id = vector_id or str_to_hash(str(vector))

        params: dict[str, Any] = {"namespace": namespace} | kwargs

        self.index.upsert(vectors=[(vector_id, vector, meta)], **params)  # pyright: ignore[reportArgumentType]

        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Fetch one vector by id, returning ``None`` when it does not exist."""
        result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()  # pyright: ignore[reportAttributeAccessIssue]
        vectors = list(result["vectors"].values())

        if len(vectors) > 0:
            vector = vectors[0]

            return BaseVectorStoreDriver.Entry(
                id=vector["id"],
                meta=vector["metadata"],
                vector=vector["values"],
                namespace=result["namespace"],
            )
        return None

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Load up to 10,000 entries (values and metadata) from ``namespace``."""
        # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
        # all values from a namespace:
        # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5

        results = self.index.query(
            vector=self.embedding_driver.embed("", vector_operation="query"),
            top_k=10000,
            # Pinecone omits vector values unless asked; without this every Entry.vector was empty.
            include_values=True,
            include_metadata=True,
            namespace=namespace,
        )

        return [
            BaseVectorStoreDriver.Entry(
                id=r["id"],
                vector=r["values"],
                meta=r["metadata"],
                namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
            )
            for r in results["matches"]  # pyright: ignore[reportIndexIssue]
        ]

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        include_metadata: bool = True,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Query the index for the nearest neighbours of ``vector``.

        ``include_vectors`` is forwarded as Pinecone's ``include_values``; extra ``kwargs``
        override the computed query parameters.
        """
        params = {
            "top_k": count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
            "namespace": namespace,
            "include_values": include_vectors,
            "include_metadata": include_metadata,
        } | kwargs

        results = self.index.query(vector=vector, **params)

        return [
            BaseVectorStoreDriver.Entry(
                id=r["id"],
                vector=r["values"],
                score=r["score"],
                meta=r["metadata"],
                namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
            )
            for r in results["matches"]  # pyright: ignore[reportIndexIssue]
        ]

    def delete_vector(self, vector_id: str) -> NoReturn:
        """Deletion is not supported by this driver; always raises ``NotImplementedError``."""
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • _index = field(default=None, kw_only=True, alias='index', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • environment = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • project_name = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@lazy_property()
def client(self) -> pinecone.Pinecone:
    return import_optional_dependency("pinecone").Pinecone(
        api_key=self.api_key,
        environment=self.environment,
        project_name=self.project_name,
    )

delete_vector(vector_id)

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Refuse to delete a vector; this driver has no deletion support.

    Args:
        vector_id: Ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = f"{self.__class__.__name__} does not support deletion."
    raise NotImplementedError(message)

index()

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
@lazy_property()
def index(self) -> pinecone.data.index.Index:
    """Open the Pinecone index named ``index_name`` through ``client``."""
    pinecone_client = self.client
    return pinecone_client.Index(self.index_name)

load_entries(*, namespace=None)

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Load up to 10,000 entries, with vector values and metadata, from ``namespace``.

    The entries are gathered by querying the index with the embedding of an empty
    string and ``top_k=10000``.

    Args:
        namespace: Namespace to read from; ``None`` is passed through to Pinecone.

    Returns:
        One ``Entry`` per match, carrying its id, values, metadata and namespace.
    """
    # This is a hacky way to query up to 10,000 values from Pinecone. Waiting on an official API for fetching
    # all values from a namespace:
    # https://community.pinecone.io/t/is-there-a-way-to-query-all-the-vectors-and-or-metadata-from-a-namespace/797/5

    results = self.index.query(
        vector=self.embedding_driver.embed("", vector_operation="query"),
        top_k=10000,
        # Pinecone omits vector values unless asked; without this every Entry.vector was empty.
        include_values=True,
        include_metadata=True,
        namespace=namespace,
    )

    return [
        BaseVectorStoreDriver.Entry(
            id=r["id"],
            vector=r["values"],
            meta=r["metadata"],
            namespace=results["namespace"],  # pyright: ignore[reportIndexIssue]
        )
        for r in results["matches"]  # pyright: ignore[reportIndexIssue]
    ]

load_entry(vector_id, *, namespace=None)

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Fetch a single vector by id from ``namespace``.

    Returns:
        An ``Entry`` for the first fetched vector, or ``None`` when nothing matched.
    """
    fetched = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict()  # pyright: ignore[reportAttributeAccessIssue]
    found = list(fetched["vectors"].values())

    if not found:
        return None

    first = found[0]
    return BaseVectorStoreDriver.Entry(
        id=first["id"],
        meta=first["metadata"],
        vector=first["values"],
        namespace=fetched["namespace"],
    )

query_vector(vector, *, count=None, namespace=None, include_vectors=False, include_metadata=True, **kwargs)

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    include_metadata: bool = True,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Query the index for the nearest neighbours of ``vector``.

    Args:
        vector: Query embedding.
        count: ``top_k``; a falsy value falls back to ``DEFAULT_QUERY_COUNT``.
        namespace: Namespace to search.
        include_vectors: Forwarded as Pinecone's ``include_values``.
        include_metadata: Forwarded as Pinecone's ``include_metadata``.
        **kwargs: Extra query parameters; they override the ones above.

    Returns:
        One ``Entry`` per match, with its score.
    """
    query_params = dict(
        top_k=count or BaseVectorStoreDriver.DEFAULT_QUERY_COUNT,
        namespace=namespace,
        include_values=include_vectors,
        include_metadata=include_metadata,
    )
    query_params.update(kwargs)

    response = self.index.query(vector=vector, **query_params)

    entries = []
    for match in response["matches"]:  # pyright: ignore[reportIndexIssue]
        entries.append(
            BaseVectorStoreDriver.Entry(
                id=match["id"],
                vector=match["values"],
                score=match["score"],
                meta=match["metadata"],
                namespace=response["namespace"],  # pyright: ignore[reportIndexIssue]
            )
        )
    return entries

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/pinecone_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Upsert ``vector`` into the index and return the id it was stored under.

    Args:
        vector: Embedding to store.
        vector_id: Id to use; when falsy, a hash of ``str(vector)`` is used instead.
        namespace: Namespace to write to.
        meta: Metadata stored alongside the vector.
        **kwargs: Extra ``Index.upsert`` parameters; they override ``namespace``.

    Returns:
        The id the vector was upserted with.
    """
    resolved_id = vector_id if vector_id else str_to_hash(str(vector))

    upsert_params: dict[str, Any] = {"namespace": namespace, **kwargs}

    self.index.upsert(vectors=[(resolved_id, vector, meta)], **upsert_params)  # pyright: ignore[reportArgumentType]

    return resolved_id

ProxyWebScraperDriver

Bases: BaseWebScraperDriver

Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
@define
class ProxyWebScraperDriver(BaseWebScraperDriver):
    """Web scraper driver that downloads pages through the configured proxies.

    Attributes:
        proxies: Mapping passed to ``requests.get`` as ``proxies``; not serialized.
        params: Extra keyword arguments forwarded to ``requests.get``.
    """

    proxies: dict = field(kw_only=True, metadata={"serializable": False})
    params: dict = field(default=Factory(dict), kw_only=True, metadata={"serializable": True})

    def fetch_url(self, url: str) -> str:
        """Download ``url`` through ``proxies`` and return the response body as text."""
        # NOTE(review): no timeout is passed unless ``params`` supplies one, so a stalled
        # proxy can block this call indefinitely.
        return requests.get(url, proxies=self.proxies, **self.params).text

    def extract_page(self, page: str) -> TextArtifact:
        """Wrap the raw page text in a ``TextArtifact``; no parsing or cleanup is done."""
        artifact = TextArtifact(page)
        return artifact
  • params = field(default=Factory(dict), kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • proxies = field(kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

extract_page(page)

Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact:
    """Wrap the raw page text in a ``TextArtifact``; no parsing or cleanup is done."""
    artifact = TextArtifact(page)
    return artifact

fetch_url(url)

Source Code in griptape/drivers/web_scraper/proxy_web_scraper_driver.py
def fetch_url(self, url: str) -> str:
    """Download ``url`` through ``self.proxies`` and return the response body as text.

    ``self.params`` are forwarded to ``requests.get`` as extra keyword arguments.
    """
    # NOTE(review): no timeout is passed unless ``params`` supplies one, so a stalled
    # proxy can block this call indefinitely.
    return requests.get(url, proxies=self.proxies, **self.params).text

PusherEventListenerDriver

Bases: BaseEventListenerDriver

Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
@define
class PusherEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that publishes event payloads to a Pusher channel.

    Attributes:
        app_id: Pusher application id.
        key: Pusher application key.
        secret: Pusher application secret; not serialized.
        cluster: Pusher cluster name.
        channel: Channel that every payload is published to.
        event_name: Event name used for every published payload.
        ssl: Passed to the Pusher client as ``ssl``.
        client: Optional Pusher client (constructor alias for ``_client``).
    """

    app_id: str = field(kw_only=True, metadata={"serializable": True})
    key: str = field(kw_only=True, metadata={"serializable": True})
    secret: str = field(kw_only=True, metadata={"serializable": False})
    cluster: str = field(kw_only=True, metadata={"serializable": True})
    channel: str = field(kw_only=True, metadata={"serializable": True})
    event_name: str = field(kw_only=True, metadata={"serializable": True})
    ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    _client: Optional[Pusher] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Pusher:
        """Create the Pusher client from the driver's credentials."""
        pusher_factory = import_optional_dependency("pusher").Pusher
        return pusher_factory(
            app_id=self.app_id,
            key=self.key,
            secret=self.secret,
            cluster=self.cluster,
            ssl=self.ssl,
        )

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        """Publish several payloads with a single ``trigger_batch`` call."""
        batch = []
        for payload in event_payload_batch:
            batch.append({"channel": self.channel, "name": self.event_name, "data": payload})

        self.client.trigger_batch(batch)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        """Publish one payload to ``channel`` under ``event_name``."""
        pusher_client = self.client
        pusher_client.trigger(channels=self.channel, event_name=self.event_name, data=event_payload)
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • app_id = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • channel = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • cluster = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • event_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • key = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • secret = field(kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • ssl = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
@lazy_property()
def client(self) -> Pusher:
    """Lazily build the Pusher client from this driver's connection settings.

    Returns:
        A `pusher.Pusher` configured with this driver's app_id, key, secret, cluster and ssl values.
    """
    pusher_module = import_optional_dependency("pusher")
    settings = {
        "app_id": self.app_id,
        "key": self.key,
        "secret": self.secret,
        "cluster": self.cluster,
        "ssl": self.ssl,
    }
    return pusher_module.Pusher(**settings)

try_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """Publish a single event payload to the configured Pusher channel.

    Args:
        event_payload: Data sent as the body of an event named `self.event_name` on `self.channel`.
    """
    pusher_client = self.client
    pusher_client.trigger(
        channels=self.channel,
        event_name=self.event_name,
        data=event_payload,
    )

try_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/pusher_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """Publish a batch of event payloads to the configured Pusher channel.

    Pusher's batch-trigger endpoint accepts at most 10 events per request, so larger
    batches are sent as successive chunks. An empty batch sends nothing.

    Args:
        event_payload_batch: Event payloads; each becomes the `data` of one event named
            `self.event_name` on `self.channel`.
    """
    max_events_per_request = 10  # Pusher HTTP API batch_events limit
    events = [
        {"channel": self.channel, "name": self.event_name, "data": event_payload}
        for event_payload in event_payload_batch
    ]

    for start in range(0, len(events), max_events_per_request):
        self.client.trigger_batch(events[start : start + max_events_per_request])

QdrantVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

NameTypeDescription
locationOptional[str]An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used.
urlOptional[str]An optional Qdrant API URL.
hostOptional[str]An optional Qdrant host.
pathOptional[str]Persistence path for QdrantLocal. Default: None
portintThe port number for the Qdrant client. Defaults: 6333.
grpc_portintThe gRPC port number for the Qdrant client. Defaults: 6334.
prefer_grpcboolA boolean indicating whether to prefer gRPC over HTTP. Defaults: False.
force_disable_check_same_threadOptional[bool]For QdrantLocal, force disable check_same_thread. Default: False. Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
timeoutOptional[int]Timeout for REST and gRPC API requests. Default: 5 seconds for REST and unlimited for gRPC
api_keyOptional[str]API key for authentication in Qdrant Cloud. Defaults: None
httpsboolIf true - use HTTPS(SSL) protocol. Default: True
prefixOptional[str]Add prefix to the REST URL path. Example: service/v1 will result in http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. Defaults: None
distancestrThe distance metric to be used for the vectors. Defaults: 'COSINE'.
collection_namestrThe name of the Qdrant collection.
vector_nameOptional[str]An optional name for the vectors.
content_payload_keystrThe key for the content payload in the metadata. Defaults: 'data'.
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
@define
class QdrantVectorStoreDriver(BaseVectorStoreDriver):
    """Vector Store Driver for Qdrant.

    Attributes:
        location: An optional location for the Qdrant client. If set to ':memory:', an in-memory client is used.
        url: An optional Qdrant API URL.
        host: An optional Qdrant host.
        path: Persistence path for QdrantLocal. Default: None
        port: The port number for the Qdrant client. Defaults: 6333.
        grpc_port: The gRPC port number for the Qdrant client. Defaults: 6334.
        prefer_grpc: A boolean indicating whether to prefer gRPC over HTTP. Defaults: False.
        force_disable_check_same_thread: For QdrantLocal, force disable check_same_thread. Default: False.
            Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient.
        timeout: Timeout for REST and gRPC API requests. Default: 5 seconds for REST and unlimited for gRPC
        api_key: API key for authentication in Qdrant Cloud. Defaults: None.
        https: If true - use HTTPS(SSL) protocol. Default: True
        prefix: Add prefix to the REST URL path. Example: service/v1 will result in
            http://localhost:6333/service/v1/{qdrant-endpoint} for REST API. Defaults: None
        distance: The distance metric to be used for the vectors. Defaults: 'COSINE'.
        collection_name: The name of the Qdrant collection.
        vector_name: An optional name for the vectors. NOTE: not currently applied by this driver's
            query or upsert methods.
        content_payload_key: The key for the content payload in the metadata. Defaults: 'data'.
    """

    location: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    url: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    host: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    port: int = field(default=6333, kw_only=True, metadata={"serializable": True})
    grpc_port: int = field(default=6334, kw_only=True, metadata={"serializable": True})
    prefer_grpc: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    https: bool = field(default=True, kw_only=True, metadata={"serializable": True})
    prefix: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    force_disable_check_same_thread: Optional[bool] = field(
        default=False,
        kw_only=True,
        metadata={"serializable": True},
    )
    timeout: Optional[int] = field(default=5, kw_only=True, metadata={"serializable": True})
    distance: str = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={"serializable": True})
    collection_name: str = field(kw_only=True, metadata={"serializable": True})
    vector_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    content_payload_key: str = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={"serializable": True})
    _client: Optional[QdrantClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> QdrantClient:
        return import_optional_dependency("qdrant_client").QdrantClient(
            location=self.location,
            url=self.url,
            host=self.host,
            path=self.path,
            port=self.port,
            prefer_grpc=self.prefer_grpc,
            grpc_port=self.grpc_port,
            api_key=self.api_key,
            https=self.https,
            prefix=self.prefix,
            force_disable_check_same_thread=self.force_disable_check_same_thread,
            timeout=self.timeout,
        )

    @staticmethod
    def _payload_to_meta(payload: dict) -> dict:
        """Return a copy of a Qdrant payload without the internal `_score` and `_tensor_facets` keys."""
        return {k: v for k, v in payload.items() if k not in ("_score", "_tensor_facets")}

    def delete_vector(self, vector_id: str) -> None:
        """Delete a vector from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str): ID of the vector to delete.
        """
        models = import_optional_dependency("qdrant_client.http.models")
        deletion_response = self.client.delete(
            collection_name=self.collection_name,
            points_selector=models.PointIdsList(points=[vector_id]),
        )
        if deletion_response.status == models.UpdateStatus.COMPLETED:
            logging.info("ID %s is successfully deleted", vector_id)

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Query the Qdrant collection based on a query vector.

        Parameters:
            vector (list[float]): Query vector.
            count (Optional[int]): Optional number of results to return; when None, no limit is sent
                and the client's default applies.
            namespace (Optional[str]): Accepted for interface compatibility; not used by this driver.
            include_vectors (bool): Whether to request and include vectors in the results.

        Returns:
            list[BaseVectorStoreDriver.Entry]: One entry per hit that has a payload.
        """
        # Qdrant only returns stored vectors when `with_vectors` is requested explicitly.
        request = {
            "collection_name": self.collection_name,
            "query_vector": vector,
            "limit": count,
            "with_vectors": include_vectors,
        }
        # Drop unset options (e.g. limit=None) so the client's defaults apply.
        request = {k: v for k, v in request.items() if v is not None}
        results = self.client.search(**request)

        return [
            BaseVectorStoreDriver.Entry(
                id=str(result.id),
                vector=result.vector if include_vectors else [],  # pyright: ignore[reportArgumentType]
                score=result.score,
                meta=self._payload_to_meta(result.payload),
            )
            for result in results
            if result.payload is not None
        ]

    def upsert_vector(
        self,
        vector: list[float],
        *,
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        content: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Upsert vectors into the Qdrant collection.

        Parameters:
            vector (list[float]): The vector to be upserted.
            vector_id (Optional[str]): Optional vector ID; derived deterministically from the vector when omitted.
            namespace (Optional[str]): Accepted for interface compatibility; not used by this driver.
            meta (Optional[dict]): Optional dictionary containing metadata. It is not modified.
            content (Optional[str]): The text content to be included in the payload under `content_payload_key`.

        Returns:
            str: The ID of the upserted vector.
        """
        if vector_id is None:
            # uuid5 over the vector's string form: the same vector always maps to the same ID.
            vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))

        # Work on a copy so adding the content key never mutates the caller's dict.
        payload = dict(meta) if meta else {}

        if content:
            payload[self.content_payload_key] = content

        models = import_optional_dependency("qdrant_client.http.models")
        points = models.Batch(
            ids=[vector_id],
            vectors=[vector],
            payloads=[payload] if payload else None,
        )

        self.client.upsert(collection_name=self.collection_name, points=points)
        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Load a vector entry from the Qdrant collection based on its ID.

        Parameters:
            vector_id (str): ID of the vector to load.
            namespace (Optional[str]): Accepted for interface compatibility; not used by this driver.

        Returns:
            Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
        """
        results = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
        if not results:
            return None

        entry = results[0]
        return BaseVectorStoreDriver.Entry(
            id=str(entry.id),
            vector=entry.vector if entry.vector is not None else [],  # pyright: ignore[reportArgumentType]
            meta=self._payload_to_meta(entry.payload or {}),
        )

    def load_entries(self, *, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
        """Load vector entries from the Qdrant collection.

        Only the points whose IDs are passed via `kwargs["ids"]` are retrieved; with no IDs, nothing is loaded.

        Parameters:
            namespace: Accepted for interface compatibility; not used by this driver.
            **kwargs: Optional `ids` (default []), `with_payload` (default True) and `with_vectors` (default True).

        Returns:
            List of entries; points without a payload are skipped.
        """
        with_vectors = kwargs.get("with_vectors", True)
        results = self.client.retrieve(
            collection_name=self.collection_name,
            ids=kwargs.get("ids", []),
            with_payload=kwargs.get("with_payload", True),
            with_vectors=with_vectors,
        )

        return [
            BaseVectorStoreDriver.Entry(
                id=str(entry.id),
                # Mirror load_entry: never pass a None vector through.
                vector=entry.vector if with_vectors and entry.vector is not None else [],  # pyright: ignore[reportArgumentType]
                meta=self._payload_to_meta(entry.payload),
            )
            for entry in results
            if entry.payload is not None
        ]
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • collection_name = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • content_payload_key = field(default=CONTENT_PAYLOAD_KEY, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • distance = field(default=DEFAULT_DISTANCE, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • force_disable_check_same_thread = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • grpc_port = field(default=6334, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • host = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • https = field(default=True, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • location = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • path = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • port = field(default=6333, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • prefer_grpc = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • prefix = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • timeout = field(default=5, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • url = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • vector_name = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
@lazy_property()
def client(self) -> QdrantClient:
    """Lazily construct the Qdrant client from this driver's connection attributes."""
    qdrant_module = import_optional_dependency("qdrant_client")
    connection_options = {
        "location": self.location,
        "url": self.url,
        "host": self.host,
        "path": self.path,
        "port": self.port,
        "prefer_grpc": self.prefer_grpc,
        "grpc_port": self.grpc_port,
        "api_key": self.api_key,
        "https": self.https,
        "prefix": self.prefix,
        "force_disable_check_same_thread": self.force_disable_check_same_thread,
        "timeout": self.timeout,
    }
    return qdrant_module.QdrantClient(**connection_options)

delete_vector(vector_id)

Delete a vector from the Qdrant collection based on its ID.

Parameters

NameTypeDescriptionDefault
vector_idstrID of the vector to delete.
required
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def delete_vector(self, vector_id: str) -> None:
    """Remove one point from the Qdrant collection by ID.

    A success message is logged when Qdrant reports the update as completed.

    Parameters:
        vector_id (str): ID of the vector to delete.
    """
    models = import_optional_dependency("qdrant_client.http.models")
    selector = models.PointIdsList(points=[vector_id])
    response = self.client.delete(collection_name=self.collection_name, points_selector=selector)
    if response.status == models.UpdateStatus.COMPLETED:
        logging.info("ID %s is successfully deleted", vector_id)

load_entries(*, namespace=None, **kwargs)

Load vector entries from the Qdrant collection.

Parameters

NameTypeDescriptionDefault
namespaceOptional[str]Optional namespace of the vectors.
None

Returns

TypeDescription
list[Entry]List of entries.
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None, **kwargs) -> list[BaseVectorStoreDriver.Entry]:
    """Load vector entries from the Qdrant collection.

    Only the points whose IDs are passed via `kwargs["ids"]` are retrieved; with no IDs, nothing is loaded.

    Parameters:
        namespace: Accepted for interface compatibility; not used by this driver.
        **kwargs: Optional `ids` (default []), `with_payload` (default True) and `with_vectors` (default True).

    Returns:
        List of entries; points without a payload are skipped.
    """
    with_vectors = kwargs.get("with_vectors", True)
    results = self.client.retrieve(
        collection_name=self.collection_name,
        ids=kwargs.get("ids", []),
        with_payload=kwargs.get("with_payload", True),
        with_vectors=with_vectors,
    )

    return [
        BaseVectorStoreDriver.Entry(
            id=str(entry.id),
            # Mirror load_entry: never pass a None vector through.
            vector=entry.vector if with_vectors and entry.vector is not None else [],  # pyright: ignore[reportArgumentType]
            meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]},
        )
        for entry in results
        if entry.payload is not None
    ]

load_entry(vector_id, *, namespace=None)

Load a vector entry from the Qdrant collection based on its ID.

Parameters

NameTypeDescriptionDefault
vector_idstrID of the vector to load.
required
namespaceOptional[str]Optional namespace of the vector.
None

Returns

TypeDescription
Optional[Entry]Optional[BaseVectorStoreDriver.Entry]: Vector entry if found, else None.
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Fetch a single point from the Qdrant collection by ID.

    Parameters:
        vector_id (str): ID of the vector to load.
        namespace (Optional[str]): Accepted for interface compatibility; not used here.

    Returns:
        Optional[BaseVectorStoreDriver.Entry]: The entry when the point exists, otherwise None.
    """
    points = self.client.retrieve(collection_name=self.collection_name, ids=[vector_id])
    if not points:
        return None

    point = points[0]
    payload = point.payload if point.payload is not None else {}
    stored_vector = [] if point.vector is None else point.vector
    filtered_meta = {key: value for key, value in payload.items() if key not in ["_score", "_tensor_facets"]}
    return BaseVectorStoreDriver.Entry(
        id=str(point.id),
        vector=stored_vector,  # pyright: ignore[reportArgumentType]
        meta=filtered_meta,
    )

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Query the Qdrant collection based on a query vector.

Parameters

NameTypeDescriptionDefault
vectorlist[float]Query vector.
required
countOptional[int]Optional number of results to return.
None
namespaceOptional[str]Optional namespace of the vectors.
None
include_vectorsboolWhether to include vectors in the results.
False

Returns

TypeDescription
list[Entry]list[BaseVectorStoreDriver.Entry]: List of Entry objects.
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Query the Qdrant collection based on a query vector.

    Parameters:
        vector (list[float]): Query vector.
        count (Optional[int]): Optional number of results to return; when None, no limit is sent
            and the client's default applies.
        namespace (Optional[str]): Accepted for interface compatibility; not used by this driver.
        include_vectors (bool): Whether to request and include vectors in the results.

    Returns:
        list[BaseVectorStoreDriver.Entry]: One entry per hit that has a payload.
    """
    # Qdrant only returns stored vectors when `with_vectors` is requested explicitly.
    request = {
        "collection_name": self.collection_name,
        "query_vector": vector,
        "limit": count,
        "with_vectors": include_vectors,
    }
    # Drop unset options (e.g. limit=None) so the client's defaults apply.
    request = {k: v for k, v in request.items() if v is not None}
    results = self.client.search(**request)

    return [
        BaseVectorStoreDriver.Entry(
            id=str(result.id),
            vector=result.vector if include_vectors else [],  # pyright: ignore[reportArgumentType]
            score=result.score,
            meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]},
        )
        for result in results
        if result.payload is not None
    ]

upsert_vector(vector, *, vector_id=None, namespace=None, meta=None, content=None, **kwargs)

Upsert vectors into the Qdrant collection.

Parameters

NameTypeDescriptionDefault
vectorlist[float]The vector to be upserted.
required
vector_idOptional[str]Optional vector ID.
None
namespaceOptional[str]Optional namespace for the vector.
None
metaOptional[dict]Optional dictionary containing metadata.
None
contentOptional[str]The text content to be included in the payload.
None

Returns

NameTypeDescription
strstrThe ID of the upserted vector.
Source Code in griptape/drivers/vector/qdrant_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    *,
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    content: Optional[str] = None,
    **kwargs,
) -> str:
    """Upsert vectors into the Qdrant collection.

    Parameters:
        vector (list[float]): The vector to be upserted.
        vector_id (Optional[str]): Optional vector ID; derived deterministically from the vector when omitted.
        namespace (Optional[str]): Accepted for interface compatibility; not used by this driver.
        meta (Optional[dict]): Optional dictionary containing metadata. It is not modified.
        content (Optional[str]): The text content to be included in the payload under `content_payload_key`.

    Returns:
        str: The ID of the upserted vector.
    """
    if vector_id is None:
        # uuid5 over the vector's string form: the same vector always maps to the same ID.
        vector_id = str(uuid.uuid5(uuid.NAMESPACE_DNS, str(vector)))

    # Work on a copy so adding the content key never mutates the caller's dict.
    payload = dict(meta) if meta else {}

    if content:
        payload[self.content_payload_key] = content

    models = import_optional_dependency("qdrant_client.http.models")
    points = models.Batch(
        ids=[vector_id],
        vectors=[vector],
        payloads=[payload] if payload else None,
    )

    self.client.upsert(collection_name=self.collection_name, points=points)
    return vector_id

RedisConversationMemoryDriver

Bases: BaseConversationMemoryDriver

Attributes

NameTypeDescription
hoststrThe host of the Redis instance.
portintThe port of the Redis instance.
dbintThe database of the Redis instance.
usernamestrThe username of the Redis instance.
passwordOptional[str]The password of the Redis instance.
indexstrThe name of the index to use.
conversation_idstrThe id of the conversation.
Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
@define
class RedisConversationMemoryDriver(BaseConversationMemoryDriver):
    """A Conversation Memory Driver for Redis.

    Each conversation is stored as a JSON string in a Redis hash. The hash key is `index` and the
    field is `conversation_id`. Only plain hash commands (HSET/HGET) are used.

    Attributes:
        host: The host of the Redis instance.
        port: The port of the Redis instance.
        db: The database of the Redis instance.
        username: The username of the Redis instance.
        password: The password of the Redis instance.
        index: The name of the index to use.
        conversation_id: The id of the conversation. Defaults to a new random hex UUID for each driver instance.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    username: str = field(kw_only=True, default="default", metadata={"serializable": False})
    port: int = field(kw_only=True, metadata={"serializable": True})
    db: int = field(kw_only=True, default=0, metadata={"serializable": True})
    password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    index: str = field(kw_only=True, metadata={"serializable": True})
    # A Factory is required here: a plain `uuid.uuid4().hex` default is evaluated once at class
    # definition time, so every instance would share (and overwrite) the same conversation.
    conversation_id: str = field(kw_only=True, default=Factory(lambda: uuid.uuid4().hex))

    client: Redis = field(
        default=Factory(
            lambda self: import_optional_dependency("redis").Redis(
                host=self.host,
                port=self.port,
                db=self.db,
                username=self.username,
                password=self.password,
                decode_responses=False,
            ),
            takes_self=True,
        ),
    )

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        """Serialize runs and metadata to JSON and write them to the `index` hash under `conversation_id`."""
        self.client.hset(self.index, self.conversation_id, json.dumps(self._to_params_dict(runs, metadata)))

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Read this conversation back from Redis.

        Returns:
            The deserialized (runs, metadata), or ([], {}) when nothing is stored for `conversation_id`.
        """
        memory_json = self.client.hget(self.index, self.conversation_id)
        if memory_json is not None:
            return self._from_params_dict(json.loads(memory_json))  # pyright: ignore[reportArgumentType] https://github.com/redis/redis-py/issues/2399
        return [], {}
  • client = field(default=Factory(lambda self: import_optional_dependency('redis').Redis(host=self.host, port=self.port, db=self.db, username=self.username, password=self.password, decode_responses=False), takes_self=True)) class-attribute instance-attribute

  • conversation_id = field(kw_only=True, default=uuid.uuid4().hex) class-attribute instance-attribute

  • db = field(kw_only=True, default=0, metadata={'serializable': True}) class-attribute instance-attribute

  • host = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • password = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • port = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • username = field(kw_only=True, default='default', metadata={'serializable': False}) class-attribute instance-attribute

load()

Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Fetch this conversation's stored runs and metadata.

    Returns:
        (runs, metadata) decoded from the JSON stored in the `index` hash, or ([], {}) if nothing is stored.
    """
    stored = self.client.hget(self.index, self.conversation_id)
    if stored is None:
        return [], {}
    return self._from_params_dict(json.loads(stored))  # pyright: ignore[reportArgumentType] https://github.com/redis/redis-py/issues/2399

store(runs, metadata)

Source Code in griptape/drivers/memory/conversation/redis_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    """Serialize runs and metadata to JSON and write them to the `index` hash under `conversation_id`."""
    params = self._to_params_dict(runs, metadata)
    encoded = json.dumps(params)
    self.client.hset(self.index, self.conversation_id, encoded)

RedisVectorStoreDriver

Bases: BaseVectorStoreDriver

Attributes

NameTypeDescription
hoststrThe host of the Redis instance.
portintThe port of the Redis instance.
dbintThe database of the Redis instance.
usernamestrThe username of the Redis instance.
passwordOptional[str]The password of the Redis instance.
indexstrThe name of the index to use.
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
@define
class RedisVectorStoreDriver(BaseVectorStoreDriver):
    """A Vector Store Driver for Redis.

    This driver interfaces with a Redis instance and utilizes the Redis hashes and RediSearch module to store, retrieve, and query vectors in a structured manner.
    Proper setup of the Redis instance and RediSearch is necessary for the driver to function correctly.

    Each vector is stored in a hash keyed `<namespace>:<vector_id>` (or just `<vector_id>` without a namespace).
    The hash has the fields `vector` (float32 bytes) and `vec_string` (JSON-encoded vector). It also has
    `namespace` and `metadata` (JSON) when those are provided.

    Attributes:
        host: The host of the Redis instance.
        port: The port of the Redis instance.
        db: The database of the Redis instance.
        username: The username of the Redis instance.
        password: The password of the Redis instance.
        index: The name of the index to use.
    """

    host: str = field(kw_only=True, metadata={"serializable": True})
    username: str = field(kw_only=True, default="default", metadata={"serializable": False})
    port: int = field(kw_only=True, metadata={"serializable": True})
    db: int = field(kw_only=True, default=0, metadata={"serializable": True})
    password: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    index: str = field(kw_only=True, metadata={"serializable": True})
    _client: Optional[Redis] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Redis:
        return import_optional_dependency("redis").Redis(
            host=self.host,
            port=self.port,
            db=self.db,
            username=self.username,
            password=self.password,
            decode_responses=False,
        )

    def upsert_vector(
        self,
        vector: list[float],
        vector_id: Optional[str] = None,
        namespace: Optional[str] = None,
        meta: Optional[dict] = None,
        **kwargs,
    ) -> str:
        """Inserts or updates a vector in Redis.

        If a vector with the given vector ID already exists, it is updated; otherwise, a new vector is inserted.
        Metadata associated with the vector can also be provided.

        Returns:
            The vector ID. When none is given, it is derived by hashing the vector's string form.
        """
        vector_id = vector_id or str_to_hash(str(vector))
        key = self._generate_key(vector_id, namespace)
        bytes_vector = json.dumps(vector).encode("utf-8")

        mapping = {}
        mapping["vector"] = np.array(vector, dtype=np.float32).tobytes()
        mapping["vec_string"] = bytes_vector

        if namespace:
            mapping["namespace"] = namespace

        if meta:
            mapping["metadata"] = json.dumps(meta)

        self.client.hset(key, mapping=mapping)

        return vector_id

    def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
        """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

        Returns:
            If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
        """
        key = self._generate_key(vector_id, namespace)
        result = self.client.hgetall(key)
        if not result:
            # HGETALL replies with an empty mapping when the key does not exist.
            return None

        vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()  # pyright: ignore[reportIndexIssue] https://github.com/redis/redis-py/issues/2399
        meta = json.loads(result[b"metadata"]) if b"metadata" in result else None  # pyright: ignore[reportIndexIssue, reportOperatorIssue]

        return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)

    def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
        """Retrieves all vector entries from Redis that match the optional namespace.

        Returns:
            A list of `BaseVectorStoreDriver.Entry` objects.
        """
        pattern = f"{namespace}:*" if namespace else "*"
        prefix = self._get_doc_prefix(namespace)
        keys = self.client.keys(pattern)

        entries = []
        for key in keys:  # pyright: ignore[reportGeneralTypeIssues] https://github.com/redis/redis-py/issues/2399
            # KEYS returns fully qualified keys ("<namespace>:<id>"). Strip the prefix so that
            # load_entry does not prepend the namespace a second time.
            vector_id = key.decode("utf-8")[len(prefix) :]
            entry = self.load_entry(vector_id, namespace=namespace)
            if entry:
                entries.append(entry)

        return entries

    def query_vector(
        self,
        vector: list[float],
        *,
        count: Optional[int] = None,
        namespace: Optional[str] = None,
        include_vectors: bool = False,
        **kwargs,
    ) -> list[BaseVectorStoreDriver.Entry]:
        """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

        Results can be limited using the count parameter (default 10) and optionally filtered by a namespace.

        Returns:
            A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace.
        """
        search_query = import_optional_dependency("redis.commands.search.query")

        filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
        query_expression = (
            search_query.Query(f"{filter_expression}=>[KNN {count or 10} @vector $vector as score]")
            .sort_by("score")
            .return_fields("id", "score", "metadata", "vec_string")
            .paging(0, count or 10)
            .dialect(2)
        )

        query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

        results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]

        query_results = []
        for document in results:
            metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
            # Keys are "<namespace>:<vector_id>"; split only on the first colon so IDs that
            # themselves contain ':' are kept intact.
            doc_namespace, separator, doc_vector_id = document.id.partition(":")
            if not separator:
                doc_namespace, doc_vector_id = None, document.id
            vector_float_list = json.loads(document.vec_string) if include_vectors else None
            query_results.append(
                BaseVectorStoreDriver.Entry(
                    id=doc_vector_id,
                    vector=vector_float_list,
                    score=float(document.score),
                    meta=metadata,
                    namespace=doc_namespace,
                ),
            )
        return query_results

    def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str:
        """Generates a Redis key using the provided vector ID and optionally a namespace."""
        return f"{namespace}:{vector_id}" if namespace else vector_id

    def _get_doc_prefix(self, namespace: Optional[str] = None) -> str:
        """Get the document prefix based on the provided namespace."""
        return f"{namespace}:" if namespace else ""

    def delete_vector(self, vector_id: str) -> NoReturn:
        raise NotImplementedError(f"{self.__class__.__name__} does not support deletion.")
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • db = field(kw_only=True, default=0, metadata={'serializable': True}) class-attribute instance-attribute

  • host = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • index = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • password = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • port = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • username = field(kw_only=True, default='default', metadata={'serializable': False}) class-attribute instance-attribute

_generate_key(vector_id, namespace=None)

Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def _generate_key(self, vector_id: str, namespace: Optional[str] = None) -> str:
    """Build the Redis key for a vector.

    Args:
        vector_id: The vector's identifier.
        namespace: Optional namespace; when truthy the key becomes ``"<namespace>:<vector_id>"``.

    Returns:
        The namespaced key, or ``vector_id`` unchanged when no namespace is given.
    """
    if not namespace:
        return vector_id
    return f"{namespace}:{vector_id}"

_get_doc_prefix(namespace=None)

Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def _get_doc_prefix(self, namespace: Optional[str] = None) -> str:
    """Return the key prefix shared by all documents in ``namespace``.

    Args:
        namespace: Optional namespace.

    Returns:
        ``"<namespace>:"`` when a namespace is given, otherwise an empty string.
    """
    if not namespace:
        return ""
    return f"{namespace}:"

client()

Source Code in griptape/drivers/vector/redis_vector_store_driver.py
@lazy_property()
def client(self) -> Redis:
    """Lazily build the ``redis.Redis`` client from the driver's connection settings.

    ``decode_responses=False`` keeps replies as raw bytes. Other methods depend on this: they
    decode the float32 ``vector`` field with ``np.frombuffer`` and call ``.decode`` on key names.
    """
    connection_kwargs = {
        "host": self.host,
        "port": self.port,
        "db": self.db,
        "username": self.username,
        "password": self.password,
        "decode_responses": False,
    }
    return import_optional_dependency("redis").Redis(**connection_kwargs)

delete_vector(vector_id)

Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def delete_vector(self, vector_id: str) -> NoReturn:
    """Unsupported operation: this driver cannot delete vectors.

    Raises:
        NotImplementedError: Always.
    """
    driver_name = type(self).__name__
    raise NotImplementedError(f"{driver_name} does not support deletion.")

load_entries(*, namespace=None)

Retrieves all vector entries from Redis that match the optional namespace.

Returns

| Type | Description |
| --- | --- |
| list[Entry] | A list of BaseVectorStoreDriver.Entry objects. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorStoreDriver.Entry]:
    """Retrieves all vector entries from Redis that match the optional namespace.

    Returns:
        A list of `BaseVectorStoreDriver.Entry` objects.
    """
    pattern = f"{namespace}:*" if namespace else "*"
    keys = self.client.keys(pattern)

    entries = []
    for key in keys:  # pyright: ignore[reportGeneralTypeIssues] https://github.com/redis/redis-py/issues/2399
        entry = self.load_entry(key.decode("utf-8"), namespace=namespace)
        if entry:
            entries.append(entry)

    return entries

load_entry(vector_id, *, namespace=None)

Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

Returns

| Type | Description |
| --- | --- |
| Optional[Entry] | If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]:
    """Retrieves a specific vector entry from Redis based on its identifier and optional namespace.

    Args:
        vector_id: The vector's identifier, without any namespace prefix.
        namespace: Optional namespace used to build the Redis key.

    Returns:
        If the entry is found, it returns an instance of BaseVectorStoreDriver.Entry; otherwise, None is returned.
    """
    key = self._generate_key(vector_id, namespace)
    result = self.client.hgetall(key)

    # HGETALL returns an empty mapping for a missing key. Report that as "not found", as the
    # docstring promises, instead of failing with a KeyError on the absent ``vector`` field.
    if not result or b"vector" not in result:  # pyright: ignore[reportOperatorIssue]
        return None

    # The vector is stored as raw float32 bytes (see upsert_vector).
    vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist()  # pyright: ignore[reportIndexIssue] https://github.com/redis/redis-py/issues/2399
    meta = json.loads(result[b"metadata"]) if b"metadata" in result else None  # pyright: ignore[reportIndexIssue, reportOperatorIssue]

    return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace)

query_vector(vector, *, count=None, namespace=None, include_vectors=False, **kwargs)

Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

Results can be limited using the count parameter and optionally filtered by a namespace.

Returns

| Type | Description |
| --- | --- |
| list[Entry] | A list of BaseVectorStoreDriver.Entry objects, each encapsulating the retrieved vector, its similarity score, metadata, and namespace. |
Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def query_vector(
    self,
    vector: list[float],
    *,
    count: Optional[int] = None,
    namespace: Optional[str] = None,
    include_vectors: bool = False,
    **kwargs,
) -> list[BaseVectorStoreDriver.Entry]:
    """Performs a nearest neighbor search on Redis to find vectors similar to the provided input vector.

    Results can be limited using the count parameter and optionally filtered by a namespace.

    Args:
        vector: Query vector, sent to Redis as raw float32 bytes.
        count: Maximum number of neighbors to return (both the KNN size and the page size). Defaults to 10.
        namespace: When given, only documents whose ``namespace`` tag matches it are searched.
        include_vectors: When True, each entry's vector is decoded from the stored ``vec_string`` field.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        A list of BaseVectorStoreDriver.Entry objects. Each holds the retrieved vector (if requested),
        the ``score`` Redis reports for the KNN clause (results are sorted ascending by it),
        the metadata, and the namespace.
    """
    search_query = import_optional_dependency("redis.commands.search.query")

    top_k = count or 10
    filter_expression = f"(@namespace:{{{namespace}}})" if namespace else "*"
    query_expression = (
        search_query.Query(f"{filter_expression}=>[KNN {top_k} @vector $vector as score]")
        .sort_by("score")
        .return_fields("id", "score", "metadata", "vec_string")
        .paging(0, top_k)
        .dialect(2)
    )

    query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()}

    results = self.client.ft(self.index).search(query_expression, query_params).docs  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]

    query_results = []
    for document in results:
        metadata = json.loads(document.metadata) if hasattr(document, "metadata") else None
        # Keys are written as "<namespace>:<vector_id>" by _generate_key. Split on the FIRST colon
        # only, so vector IDs that themselves contain ":" come back intact. A separate name is used
        # so the ``namespace`` argument is not shadowed.
        doc_namespace, separator, vector_id = document.id.partition(":")
        if not separator:
            doc_namespace, vector_id = None, document.id
        vector_float_list = json.loads(document.vec_string) if include_vectors else None
        query_results.append(
            BaseVectorStoreDriver.Entry(
                id=vector_id,
                vector=vector_float_list,
                score=float(document.score),
                meta=metadata,
                namespace=doc_namespace,
            ),
        )
    return query_results

upsert_vector(vector, vector_id=None, namespace=None, meta=None, **kwargs)

Source Code in griptape/drivers/vector/redis_vector_store_driver.py
def upsert_vector(
    self,
    vector: list[float],
    vector_id: Optional[str] = None,
    namespace: Optional[str] = None,
    meta: Optional[dict] = None,
    **kwargs,
) -> str:
    """Write a vector into a Redis hash, creating it or overwriting its fields.

    The hash lives at ``_generate_key(vector_id, namespace)`` and receives these fields:
    ``vector`` (raw float32 bytes) and ``vec_string`` (UTF-8 bytes of the JSON-encoded list).
    ``namespace`` and ``metadata`` (JSON string) are added only when truthy. ``HSET`` does not
    remove fields that are absent from the mapping.

    Args:
        vector: The embedding values.
        vector_id: Optional ID; when falsy, one is derived by hashing ``str(vector)``.
        namespace: Optional namespace for the key and the ``namespace`` field.
        meta: Optional metadata dict, stored JSON-encoded.
        **kwargs: Accepted for interface compatibility; unused here.

    Returns:
        The vector ID that was written.
    """
    vector_id = vector_id or str_to_hash(str(vector))
    key = self._generate_key(vector_id, namespace)

    mapping = {
        "vector": np.array(vector, dtype=np.float32).tobytes(),
        "vec_string": json.dumps(vector).encode("utf-8"),
    }
    if namespace:
        mapping["namespace"] = namespace
    if meta:
        mapping["metadata"] = json.dumps(meta)

    self.client.hset(key, mapping=mapping)
    return vector_id

SnowflakeSqlDriver

Bases: BaseSqlDriver

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@define
class SnowflakeSqlDriver(BaseSqlDriver):
    """SQL driver that runs queries against Snowflake through SQLAlchemy.

    Attributes:
        get_connection: Zero-argument factory returning a ``SnowflakeConnection``. It is called once
            during validation and is also the SQLAlchemy engine's connection ``creator``.
        engine: Optional SQLAlchemy ``Engine`` (passed as ``engine=``). It presumably replaces the one
            the lazy ``engine`` property would build; confirm against ``lazy_property``.
    """

    get_connection: Callable[[], SnowflakeConnection] = field(kw_only=True)
    _engine: Optional[Engine] = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
    def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
        """Check that ``get_connection`` yields a SnowflakeConnection with a schema and database.

        Raises:
            ValueError: If the factory returns something else, or the connection lacks a schema or database.
        """
        snowflake_connection = get_connection()
        snowflake = import_optional_dependency("snowflake")

        if not isinstance(snowflake_connection, snowflake.connector.SnowflakeConnection):
            raise ValueError("The get_connection function must return a SnowflakeConnection")
        if not snowflake_connection.schema or not snowflake_connection.database:
            raise ValueError("Provide a schema and database for the Snowflake connection")

    @lazy_property()
    def engine(self) -> Engine:
        """Build a SQLAlchemy engine whose connections are produced by ``get_connection``.

        The URL only selects the ``snowflake`` dialect; connection details come from ``creator``.
        """
        return import_optional_dependency("sqlalchemy").create_engine(
            "snowflake://not@used/db",
            creator=self.get_connection,
        )

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in a ``BaseSqlDriver.RowResult``; None when there are no rows."""
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
        """Run ``query`` and return its rows as dicts, or None for statements that return no rows.

        No explicit commit is issued here.

        Raises:
            ValueError: If the execution result is None.
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is not None:
                if results.returns_rows:
                    return [dict(result._mapping) for result in results]
                return None
            raise ValueError("No results found")

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Return the table's ``(column_name, column_type)`` pairs rendered as a string.

        The table is reflected directly with ``autoload_with``. The legacy ``autoload=True`` flag
        was removed in SQLAlchemy 2.0 and is rejected there as an unknown keyword. Reflecting the
        whole database first with ``MetaData.reflect`` was unnecessary work.

        Returns:
            The stringified column list, or None if the table does not exist.
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        try:
            table = sqlalchemy.Table(table_name, sqlalchemy.MetaData(), schema=schema, autoload_with=self.engine)
        except sqlalchemy.exc.NoSuchTableError:
            return None
        return str([(c.name, c.type) for c in table.columns])
  • _engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False}) class-attribute instance-attribute

  • get_connection = field(kw_only=True) class-attribute instance-attribute

engine()

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@lazy_property()
def engine(self) -> Engine:
    """Lazily build a SQLAlchemy engine whose connections are produced by ``get_connection``.

    The URL is a placeholder that only selects the ``snowflake`` dialect; real connection
    details come from the ``creator`` callable.
    """
    create_engine = import_optional_dependency("sqlalchemy").create_engine
    return create_engine("snowflake://not@used/db", creator=self.get_connection)

execute_query(query)

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap every returned row in a ``BaseSqlDriver.RowResult``.

    Returns:
        The wrapped rows, or None when ``execute_query_raw`` yields nothing (None or empty).
    """
    raw_rows = self.execute_query_raw(query)
    if not raw_rows:
        return None
    return [BaseSqlDriver.RowResult(raw_row) for raw_row in raw_rows]

execute_query_raw(query)

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Any]]]:
    """Run ``query`` on a fresh connection and return its rows as plain dicts.

    Returns:
        One dict per row, or None for statements that return no rows. No explicit commit is issued.

    Raises:
        ValueError: If the execution result is None.
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as connection:
        cursor_result = connection.execute(sqlalchemy.text(query))
        if cursor_result is None:
            raise ValueError("No results found")
        if not cursor_result.returns_rows:
            return None
        return [dict(row._mapping) for row in cursor_result]

get_table_schema(table_name, schema=None)

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Return the table's ``(column_name, column_type)`` pairs rendered as a string.

    The table is reflected directly with ``autoload_with``. The legacy ``autoload=True`` flag
    was removed in SQLAlchemy 2.0 and is rejected there as an unknown keyword. Reflecting the
    whole database first with ``MetaData.reflect`` was unnecessary work.

    Returns:
        The stringified column list, or None if the table does not exist.
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")

    try:
        table = sqlalchemy.Table(table_name, sqlalchemy.MetaData(), schema=schema, autoload_with=self.engine)
    except sqlalchemy.exc.NoSuchTableError:
        return None
    return str([(c.name, c.type) for c in table.columns])

validate_get_connection(_, get_connection)

Source Code in griptape/drivers/sql/snowflake_sql_driver.py
@get_connection.validator  # pyright: ignore[reportFunctionMemberAccess]
def validate_get_connection(self, _: Attribute, get_connection: Callable[[], SnowflakeConnection]) -> None:
    """attrs validator: ensure ``get_connection`` yields a usable Snowflake connection.

    The factory is invoked once here to inspect the connection it returns.

    Raises:
        ValueError: If the result is not a SnowflakeConnection, or if it lacks a schema or database.
    """
    conn = get_connection()
    snowflake = import_optional_dependency("snowflake")

    is_snowflake_conn = isinstance(conn, snowflake.connector.SnowflakeConnection)
    if not is_snowflake_conn:
        raise ValueError("The get_connection function must return a SnowflakeConnection")

    has_schema_and_database = bool(conn.schema) and bool(conn.database)
    if not has_schema_and_database:
        raise ValueError("Provide a schema and database for the Snowflake connection")

SqlDriver

Bases: BaseSqlDriver

Source Code in griptape/drivers/sql/sql_driver.py
@define
class SqlDriver(BaseSqlDriver):
    """Generic SQLAlchemy-backed SQL driver.

    Attributes:
        engine_url: SQLAlchemy database URL passed to ``create_engine``.
        create_engine_params: Extra keyword arguments forwarded to ``create_engine``.
        engine: Optional SQLAlchemy ``Engine`` (passed as ``engine=``). It presumably replaces the one
            the lazy ``engine`` property would build; confirm against ``lazy_property``.
    """

    engine_url: str = field(kw_only=True)
    create_engine_params: dict = field(factory=dict, kw_only=True)
    _engine: Optional[Engine] = field(default=None, kw_only=True, alias="engine", metadata={"serializable": False})

    @lazy_property()
    def engine(self) -> Engine:
        """Build the SQLAlchemy engine from ``engine_url`` and ``create_engine_params``."""
        return import_optional_dependency("sqlalchemy").create_engine(self.engine_url, **self.create_engine_params)

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in a ``BaseSqlDriver.RowResult``; None when there are no rows."""
        rows = self.execute_query_raw(query)

        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Run ``query``, commit, and return its rows as plain dicts.

        The transaction is committed whether or not the statement returns rows. SQLAlchemy 2.x
        connections autobegin a transaction that is rolled back on close, so committing only
        row-less statements silently discarded row-returning writes such as ``INSERT ... RETURNING``.

        Returns:
            One dict per row, or None for statements that return no rows.

        Raises:
            ValueError: If the execution result is None.
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        with self.engine.connect() as con:
            results = con.execute(sqlalchemy.text(query))

            if results is None:
                raise ValueError("No result found")
            # Materialize the rows before committing so the cursor is fully consumed.
            rows = [dict(result._mapping) for result in results] if results.returns_rows else None
            con.commit()
            return rows

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Return the table's columns rendered as a string, or None if the table does not exist."""
        sqlalchemy_exc = import_optional_dependency("sqlalchemy.exc")

        try:
            return str(SqlDriver._get_table_schema(self.engine, table_name, schema))
        except sqlalchemy_exc.NoSuchTableError:
            return None

    @staticmethod
    @lru_cache
    def _get_table_schema(
        engine: Engine, table_name: str, schema: Optional[str] = None
    ) -> Optional[list[tuple[str, str]]]:
        """Inspect ``table_name`` and return ``(name, type)`` pairs for its columns.

        Results are memoized by ``lru_cache`` per ``(engine, table_name, schema)``, so later schema
        changes are not picked up.
        """
        sqlalchemy = import_optional_dependency("sqlalchemy")

        return [(col["name"], col["type"]) for col in sqlalchemy.inspect(engine).get_columns(table_name, schema=schema)]
  • _engine = field(default=None, kw_only=True, alias='engine', metadata={'serializable': False}) class-attribute instance-attribute

  • create_engine_params = field(factory=dict, kw_only=True) class-attribute instance-attribute

  • engine_url = field(kw_only=True) class-attribute instance-attribute

_get_table_schema(engine, table_name, schema=None) (cached, staticmethod)

Source Code in griptape/drivers/sql/sql_driver.py
@staticmethod
@lru_cache
def _get_table_schema(
    engine: Engine, table_name: str, schema: Optional[str] = None
) -> Optional[list[tuple[str, str]]]:
    """List ``(column_name, column_type)`` pairs for ``table_name`` using SQLAlchemy's inspector.

    Memoized with ``lru_cache`` on ``(engine, table_name, schema)``.
    """
    inspector = import_optional_dependency("sqlalchemy").inspect(engine)
    columns = inspector.get_columns(table_name, schema=schema)
    return [(column["name"], column["type"]) for column in columns]

engine()

Source Code in griptape/drivers/sql/sql_driver.py
@lazy_property()
def engine(self) -> Engine:
    """Lazily create the SQLAlchemy engine from ``engine_url`` plus ``create_engine_params``."""
    sqlalchemy = import_optional_dependency("sqlalchemy")
    return sqlalchemy.create_engine(self.engine_url, **self.create_engine_params)

execute_query(query)

Source Code in griptape/drivers/sql/sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap every returned row in a ``BaseSqlDriver.RowResult``.

    Returns:
        The wrapped rows, or None when ``execute_query_raw`` yields nothing (None or empty).
    """
    raw_rows = self.execute_query_raw(query)
    if not raw_rows:
        return None
    return [BaseSqlDriver.RowResult(raw_row) for raw_row in raw_rows]

execute_query_raw(query)

Source Code in griptape/drivers/sql/sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Run ``query``, commit, and return its rows as plain dicts.

    The transaction is committed whether or not the statement returns rows. SQLAlchemy 2.x
    connections autobegin a transaction that is rolled back on close, so committing only
    row-less statements silently discarded row-returning writes such as ``INSERT ... RETURNING``.

    Returns:
        One dict per row, or None for statements that return no rows.

    Raises:
        ValueError: If the execution result is None.
    """
    sqlalchemy = import_optional_dependency("sqlalchemy")

    with self.engine.connect() as con:
        results = con.execute(sqlalchemy.text(query))

        if results is None:
            raise ValueError("No result found")
        # Materialize the rows before committing so the cursor is fully consumed.
        rows = [dict(result._mapping) for result in results] if results.returns_rows else None
        con.commit()
        return rows

get_table_schema(table_name, schema=None)

Source Code in griptape/drivers/sql/sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe a table's columns as a string.

    Returns:
        ``str()`` of the cached ``(name, type)`` column list, or None if the table does not exist.
    """
    exc_module = import_optional_dependency("sqlalchemy.exc")

    try:
        columns = SqlDriver._get_table_schema(self.engine, table_name, schema)
    except exc_module.NoSuchTableError:
        return None
    return str(columns)

StableDiffusion3ControlNetImageGenerationPipelineDriver

Bases: StableDiffusion3ImageGenerationPipelineDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| controlnet_model | str | The ControlNet model to use for image generation. |
| controlnet_conditioning_scale | Optional[float] | The conditioning scale for the ControlNet model. Defaults to None. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
@define
class StableDiffusion3ControlNetImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models with ControlNet.

    For more information, see the HuggingFace documentation for the StableDiffusion3ControlNetPipeline:
        https://huggingface.co/docs/diffusers/en/api/pipelines/controlnet_sd3

    Attributes:
        controlnet_model: The ControlNet model to use for image generation.
        controlnet_conditioning_scale: The conditioning scale for the ControlNet model. Defaults to None.
    """

    controlnet_model: str = field(kw_only=True)
    controlnet_conditioning_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load the SD3 ControlNet pipeline (plus its ControlNet model) and optionally move it to ``device``."""
        sd3_controlnet_model = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
        sd3_controlnet_pipeline = import_optional_dependency(
            "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
        ).StableDiffusion3ControlNetPipeline

        pipeline_params = {}
        controlnet_pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype
            controlnet_pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # For both Stable Diffusion and ControlNet, models can be provided either
        # as a path to a local file or as a HuggingFace model repo name.
        # We use the from_single_file method if the model is a local file and the
        # from_pretrained method if the model is a local directory or hosted on HuggingFace.
        if os.path.isfile(self.controlnet_model):
            pipeline_params["controlnet"] = sd3_controlnet_model.from_single_file(
                self.controlnet_model, **controlnet_pipeline_params
            )
        else:
            pipeline_params["controlnet"] = sd3_controlnet_model.from_pretrained(
                self.controlnet_model, **controlnet_pipeline_params
            )

        if os.path.isfile(model):
            pipeline = sd3_controlnet_pipeline.from_single_file(model, **pipeline_params)
        else:
            pipeline = sd3_controlnet_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Pass the input image to the pipeline as ``control_image``.

        Raises:
            ValueError: If no image is provided; ControlNet pipelines require one.
        """
        if image is None:
            raise ValueError("Input image is required for ControlNet pipelines.")

        return {"control_image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        """Extend the base SD3 parameters for ControlNet.

        ``width``/``height`` are removed. Presumably the output size should follow the control
        image; TODO confirm. ``controlnet_conditioning_scale`` is added when set.
        """
        additional_params = super().make_additional_params(negative_prompts, device)

        # The parent only sets these keys when width/height are not None, so remove them
        # tolerantly instead of risking a KeyError.
        additional_params.pop("height", None)
        additional_params.pop("width", None)

        if self.controlnet_conditioning_scale is not None:
            additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

        return additional_params
  • controlnet_conditioning_scale = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • controlnet_model = field(kw_only=True) class-attribute instance-attribute

make_additional_params(negative_prompts, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    """Extend the base SD3 parameters for ControlNet.

    ``width``/``height`` are removed. Presumably the output size should follow the control
    image; TODO confirm. ``controlnet_conditioning_scale`` is added when set.
    """
    additional_params = super().make_additional_params(negative_prompts, device)

    # The parent only sets these keys when width/height are not None, so remove them
    # tolerantly instead of risking a KeyError.
    additional_params.pop("height", None)
    additional_params.pop("width", None)

    if self.controlnet_conditioning_scale is not None:
        additional_params["controlnet_conditioning_scale"] = self.controlnet_conditioning_scale

    return additional_params

make_image_param(image)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Map the input image onto the pipeline's ``control_image`` argument.

    Raises:
        ValueError: When ``image`` is None, because ControlNet needs a conditioning image.
    """
    if image is not None:
        return {"control_image": image}
    raise ValueError("Input image is required for ControlNet pipelines.")

prepare_pipeline(model, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_controlnet_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Load the SD3 ControlNet pipeline together with its ControlNet model.

    Both ``model`` and ``controlnet_model`` may be a local file path, loaded via
    ``from_single_file``, or a local directory / HuggingFace repo name, loaded via
    ``from_pretrained``. ``torch_dtype`` is applied to both loads. ``drop_t5_encoder``
    disables the third text encoder and tokenizer on the pipeline.
    """
    controlnet_cls = import_optional_dependency("diffusers.models.controlnet_sd3").SD3ControlNetModel
    pipeline_cls = import_optional_dependency(
        "diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet"
    ).StableDiffusion3ControlNetPipeline

    dtype_kwargs = {} if self.torch_dtype is None else {"torch_dtype": self.torch_dtype}
    pipeline_kwargs = dict(dtype_kwargs)
    if self.drop_t5_encoder:
        pipeline_kwargs["text_encoder_3"] = None
        pipeline_kwargs["tokenizer_3"] = None

    load_controlnet = (
        controlnet_cls.from_single_file if os.path.isfile(self.controlnet_model) else controlnet_cls.from_pretrained
    )
    pipeline_kwargs["controlnet"] = load_controlnet(self.controlnet_model, **dtype_kwargs)

    load_pipeline = pipeline_cls.from_single_file if os.path.isfile(model) else pipeline_cls.from_pretrained
    pipeline = load_pipeline(model, **pipeline_kwargs)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    if device is not None:
        pipeline.to(device)
    return pipeline

StableDiffusion3ImageGenerationPipelineDriver

Bases: BaseDiffusionImageGenerationPipelineDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| width | int | The width of the generated image. Defaults to 1024. Must be a multiple of 64. |
| height | int | The height of the generated image. Defaults to 1024. Must be a multiple of 64. |
| seed | Optional[int] | The random seed to use for image generation. If not provided, a random seed will be used. |
| guidance_scale | Optional[float] | The strength of the guidance loss. If not provided, the default value will be used. |
| steps | Optional[int] | The number of inference steps to use in image generation. If not provided, the default value will be used. |
| torch_dtype | Optional[dtype] | The torch data type to use for image generation. If not provided, the default value will be used. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
@define
class StableDiffusion3ImageGenerationPipelineDriver(BaseDiffusionImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 models.

    For more information, see the HuggingFace documentation for the StableDiffusion3Pipeline:
        https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_3

    Attributes:
        width: The width of the generated image. Defaults to 1024. Must be a multiple of 64.
        height: The height of the generated image. Defaults to 1024. Must be a multiple of 64.
        seed: The random seed to use for image generation. If not provided, a random seed will be used.
        guidance_scale: The strength of the guidance loss. If not provided, the default value will be used.
        steps: The number of inference steps to use in image generation. If not provided, the default value will be used.
        torch_dtype: The torch data type to use for image generation. If not provided, the default value will be used.
    """

    width: int = field(default=1024, kw_only=True, metadata={"serializable": True})
    height: int = field(default=1024, kw_only=True, metadata={"serializable": True})
    seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    guidance_scale: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})
    steps: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
    torch_dtype: Optional[torch.dtype] = field(default=None, kw_only=True, metadata={"serializable": True})
    enable_model_cpu_offload: bool = field(default=False, kw_only=True, metadata={"serializable": True})
    drop_t5_encoder: bool = field(default=False, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        sd3_pipeline = import_optional_dependency(
            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3"
        ).StableDiffusion3Pipeline

        pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # A model can be provided either as a path to a local file
        # or as a HuggingFace model repo name.
        if os.path.isfile(model):
            # If the model provided is a local file (not a directory),
            # we load it using the from_single_file method.
            pipeline = sd3_pipeline.from_single_file(model, **pipeline_params)
        else:
            # If the model is a local directory or hosted on HuggingFace,
            # we load it using the from_pretrained method.
            pipeline = sd3_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        # Move inference to particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        return None

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        torch_generator = import_optional_dependency("torch").Generator

        additional_params = {}
        if negative_prompts:
            additional_params["negative_prompt"] = ", ".join(negative_prompts)

        if self.width is not None:
            additional_params["width"] = self.width

        if self.height is not None:
            additional_params["height"] = self.height

        if self.seed is not None:
            additional_params["generator"] = [torch_generator(device=device).manual_seed(self.seed)]

        if self.guidance_scale is not None:
            additional_params["guidance_scale"] = self.guidance_scale

        if self.steps is not None:
            additional_params["num_inference_steps"] = self.steps

        return additional_params

    @property
    def output_image_dimensions(self) -> tuple[int, int]:
        return self.width, self.height
  • drop_t5_encoder = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • enable_model_cpu_offload = field(default=False, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • guidance_scale = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • height = field(default=1024, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • output_image_dimensions property

  • seed = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • steps = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • torch_dtype = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • width = field(default=1024, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

make_additional_params(negative_prompts, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    """Collect the optional SD3 pipeline call arguments, skipping any setting that is None.

    Negative prompts are joined with ``", "``. A seed becomes a one-element list holding a
    ``torch.Generator`` seeded on ``device``.
    """
    torch_generator = import_optional_dependency("torch").Generator

    params = {}
    if negative_prompts:
        params["negative_prompt"] = ", ".join(negative_prompts)

    for name, value in (("width", self.width), ("height", self.height)):
        if value is not None:
            params[name] = value

    if self.seed is not None:
        params["generator"] = [torch_generator(device=device).manual_seed(self.seed)]

    for name, value in (("guidance_scale", self.guidance_scale), ("num_inference_steps", self.steps)):
        if value is not None:
            params[name] = value

    return params

make_image_param(image)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Text-to-image generation takes no input image, so no image argument is produced.

    Returns:
        Always None; ``image`` is ignored.
    """
    return None

prepare_pipeline(model, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Load the Stable Diffusion 3 pipeline and optionally move it to ``device``.

    ``model`` may be a single local file, loaded via ``from_single_file``, or a local directory
    or HuggingFace repo name, loaded via ``from_pretrained``.
    """
    pipeline_cls = import_optional_dependency(
        "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3"
    ).StableDiffusion3Pipeline

    load_kwargs = {}
    if self.torch_dtype is not None:
        load_kwargs["torch_dtype"] = self.torch_dtype
    if self.drop_t5_encoder:
        load_kwargs["text_encoder_3"] = None
        load_kwargs["tokenizer_3"] = None

    loader = pipeline_cls.from_single_file if os.path.isfile(model) else pipeline_cls.from_pretrained
    pipeline = loader(model, **load_kwargs)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    if device is not None:
        pipeline.to(device)
    return pipeline

StableDiffusion3Img2ImgImageGenerationPipelineDriver

Bases: StableDiffusion3ImageGenerationPipelineDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| strength | Optional[float] | A value in [0.0, 1.0] that determines the strength of the initial image in the output. |
Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
@define
class StableDiffusion3Img2ImgImageGenerationPipelineDriver(StableDiffusion3ImageGenerationPipelineDriver):
    """Image generation model driver for Stable Diffusion 3 model image to image pipelines.

    For more information, see the HuggingFace documentation for the StableDiffusion3Img2ImgPipeline:
        https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

    Attributes:
        strength: Optional value in [0.0, 1.0] forwarded to the pipeline as ``strength``. In diffusers
            this controls how much the input image is transformed: higher values add more noise and
            preserve less of the initial image. When ``None``, no ``strength`` is passed and the
            pipeline's own default applies.
    """

    strength: Optional[float] = field(default=None, kw_only=True, metadata={"serializable": True})

    def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
        """Load and configure a ``StableDiffusion3Img2ImgPipeline``.

        Args:
            model: Either a path to a single local checkpoint file, or a local
                directory / HuggingFace model repo name.
            device: If not ``None``, the loaded pipeline is moved there via ``pipeline.to``.

        Returns:
            The loaded pipeline object.
        """
        sd3_img2img_pipeline = import_optional_dependency(
            "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
        ).StableDiffusion3Img2ImgPipeline

        pipeline_params = {}
        if self.torch_dtype is not None:
            pipeline_params["torch_dtype"] = self.torch_dtype

        if self.drop_t5_encoder:
            # Passing None for the third text encoder and its tokenizer omits the T5 components.
            pipeline_params["text_encoder_3"] = None
            pipeline_params["tokenizer_3"] = None

        # A model can be provided either as a path to a local file
        # or as a HuggingFace model repo name.
        if os.path.isfile(model):
            # If the model provided is a local file (not a directory),
            # we load it using the from_single_file method.
            pipeline = sd3_img2img_pipeline.from_single_file(model, **pipeline_params)
        else:
            # If the model is a local directory or hosted on HuggingFace,
            # we load it using the from_pretrained method.
            pipeline = sd3_img2img_pipeline.from_pretrained(model, **pipeline_params)

        if self.enable_model_cpu_offload:
            pipeline.enable_model_cpu_offload()

        # Move inference to particular device if requested.
        if device is not None:
            pipeline.to(device)

        return pipeline

    def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
        """Wrap the input image as the pipeline's ``image`` keyword argument.

        Raises:
            ValueError: If ``image`` is ``None``; image-to-image generation requires an input image.
        """
        if image is None:
            raise ValueError("Input image is required for image to image pipelines.")

        return {"image": image}

    def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
        """Build extra pipeline kwargs, adapting the parent's text-to-image params for img2img.

        Args:
            negative_prompts: Forwarded to the parent implementation.
            device: Forwarded to the parent implementation.

        Returns:
            The parent's params without ``height``/``width``, plus ``strength`` when set.
        """
        additional_params = super().make_additional_params(negative_prompts, device)

        # Explicit height and width params are not supported, but are instead inferred
        # from the input image. pop() with a default avoids a KeyError when the parent
        # did not include one of these keys.
        additional_params.pop("height", None)
        additional_params.pop("width", None)

        if self.strength is not None:
            additional_params["strength"] = self.strength

        return additional_params
  • strength = field(default=None, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

make_additional_params(negative_prompts, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def make_additional_params(self, negative_prompts: Optional[list[str]], device: Optional[str]) -> dict[str, Any]:
    """Build extra pipeline kwargs, adapting the parent's text-to-image params for img2img.

    Args:
        negative_prompts: Forwarded to the parent implementation.
        device: Forwarded to the parent implementation.

    Returns:
        The parent's params without ``height``/``width``, plus ``strength`` when set.
    """
    additional_params = super().make_additional_params(negative_prompts, device)

    # Explicit height and width params are not supported, but are instead inferred
    # from the input image. pop() with a default avoids a KeyError when the parent
    # did not include one of these keys.
    additional_params.pop("height", None)
    additional_params.pop("width", None)

    if self.strength is not None:
        additional_params["strength"] = self.strength

    return additional_params

make_image_param(image)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def make_image_param(self, image: Optional[Image]) -> Optional[dict[str, Image]]:
    """Wrap the required input image as the pipeline's ``image`` keyword argument.

    Raises:
        ValueError: If ``image`` is ``None``; image-to-image generation needs an input image.
    """
    if image is not None:
        return {"image": image}
    raise ValueError("Input image is required for image to image pipelines.")

prepare_pipeline(model, device)

Source Code in griptape/drivers/image_generation_pipeline/stable_diffusion_3_img_2_img_image_generation_pipeline_driver.py
def prepare_pipeline(self, model: str, device: Optional[str]) -> Any:
    """Load and configure a ``StableDiffusion3Img2ImgPipeline``.

    Args:
        model: Either a path to a single local checkpoint file, or a local
            directory / HuggingFace model repo name.
        device: If not ``None``, the loaded pipeline is moved there via ``pipeline.to``.

    Returns:
        The loaded pipeline object.
    """
    pipeline_class = import_optional_dependency(
        "diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img"
    ).StableDiffusion3Img2ImgPipeline

    load_kwargs = {}
    if self.torch_dtype is not None:
        load_kwargs["torch_dtype"] = self.torch_dtype
    if self.drop_t5_encoder:
        # Passing None for the third text encoder and its tokenizer omits the T5 components.
        load_kwargs.update(text_encoder_3=None, tokenizer_3=None)

    # A single checkpoint file is loaded with from_single_file; anything else
    # (a local directory or a HuggingFace repo name) goes through from_pretrained.
    loader = pipeline_class.from_single_file if os.path.isfile(model) else pipeline_class.from_pretrained
    pipeline = loader(model, **load_kwargs)

    if self.enable_model_cpu_offload:
        pipeline.enable_model_cpu_offload()

    # NOTE(review): diffusers warns against moving a CPU-offloaded pipeline with
    # `.to()`; confirm whether combining both options is intended.
    if device is not None:
        pipeline.to(device)

    return pipeline

TavilyWebSearchDriver

Bases: BaseWebSearchDriver

Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
@define
class TavilyWebSearchDriver(BaseWebSearchDriver):
    """Web search driver backed by the Tavily search API.

    Attributes:
        api_key: Tavily API key used to construct the client.
        params: Extra keyword arguments forwarded to ``TavilyClient.search`` on every call.
        client: Optionally provide a preconfigured ``TavilyClient``; otherwise the ``client``
            lazy property builds one from ``api_key``.
    """

    api_key: str = field(kw_only=True)
    params: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})
    _client: Optional[TavilyClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> TavilyClient:
        return import_optional_dependency("tavily").TavilyClient(self.api_key)

    def search(self, query: str, **kwargs) -> ListArtifact:
        """Run a Tavily search and wrap each result in a ``JsonArtifact``.

        Args:
            query: Search query string.
            **kwargs: Extra arguments for ``TavilyClient.search``. These take precedence over
                ``self.params``, which in turn take precedence over the default
                ``max_results=self.results_count``.

        Returns:
            A ``ListArtifact`` with one ``JsonArtifact`` per entry in the response's
            ``"results"`` list (empty if that key is missing).
        """
        # Merge into one dict so later sources override earlier ones. Unpacking them
        # separately in the call raised TypeError on any duplicate key
        # (e.g. ``max_results`` given in both ``params`` and ``kwargs``).
        search_kwargs = {"max_results": self.results_count, **self.params, **kwargs}
        response = self.client.search(query, **search_kwargs)
        results = response.get("results", [])
        return ListArtifact([JsonArtifact(result) for result in results])
  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(kw_only=True) class-attribute instance-attribute

  • params = field(factory=dict, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

client()

Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
@lazy_property()
def client(self) -> TavilyClient:
    """Build a ``TavilyClient`` from ``self.api_key``."""
    tavily = import_optional_dependency("tavily")
    return tavily.TavilyClient(self.api_key)

search(query, **kwargs)

Source Code in griptape/drivers/web_search/tavily_web_search_driver.py
def search(self, query: str, **kwargs) -> ListArtifact:
    """Run a Tavily search and wrap each result in a ``JsonArtifact``.

    Args:
        query: Search query string.
        **kwargs: Extra arguments for ``TavilyClient.search``. These take precedence over
            ``self.params``, which in turn take precedence over the default
            ``max_results=self.results_count``.

    Returns:
        A ``ListArtifact`` with one ``JsonArtifact`` per entry in the response's
        ``"results"`` list (empty if that key is missing).
    """
    # Merge into one dict so later sources override earlier ones. Unpacking them
    # separately in the call raised TypeError on any duplicate key
    # (e.g. ``max_results`` given in both ``params`` and ``kwargs``).
    search_kwargs = {"max_results": self.results_count, **self.params, **kwargs}
    response = self.client.search(query, **search_kwargs)
    results = response.get("results", [])
    return ListArtifact([JsonArtifact(result) for result in results])

TrafilaturaWebScraperDriver

Bases: BaseWebScraperDriver

Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
@define
class TrafilaturaWebScraperDriver(BaseWebScraperDriver):
    """Web scraper driver that downloads and extracts page text with trafilatura.

    Attributes:
        include_links: Forwarded to ``trafilatura.extract`` as ``include_links``.
        no_ssl: Forwarded to ``trafilatura.fetch_url`` as ``no_ssl``.
    """

    include_links: bool = field(default=True, kw_only=True)
    no_ssl: bool = field(default=False, kw_only=True)

    def fetch_url(self, url: str) -> str:
        """Download the raw page at ``url``.

        Raises:
            Exception: If trafilatura returns no page for ``url``.
        """
        trafilatura = import_optional_dependency("trafilatura")

        page = trafilatura.fetch_url(url, no_ssl=self.no_ssl)

        # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
        # the end result of page parsing is successful.
        logging.getLogger("trafilatura").setLevel(logging.FATAL)

        if page is None:
            raise Exception("can't access URL")

        return page

    def extract_page(self, page: str) -> TextArtifact:
        """Extract the main text of ``page`` into a ``TextArtifact``.

        Raises:
            Exception: If trafilatura extracts nothing from ``page``.
        """
        trafilatura = import_optional_dependency("trafilatura")
        config = trafilatura.settings.use_config()

        # This disables signal, so that trafilatura can work on any thread. The setting must
        # be applied to the config actually passed to ``extract``: ``use_config()`` builds a
        # fresh config on every call, so setting it anywhere else has no effect here.
        # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
        config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

        extracted_page = trafilatura.extract(
            page,
            include_links=self.include_links,
            output_format="json",
            config=config,
        )

        if not extracted_page:
            raise Exception("can't extract page")

        text = json.loads(extracted_page).get("text")

        return TextArtifact(text)
  • include_links = field(default=True, kw_only=True) class-attribute instance-attribute

  • no_ssl = field(default=False, kw_only=True) class-attribute instance-attribute

extract_page(page)

Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
def extract_page(self, page: str) -> TextArtifact:
    """Extract the main text of ``page`` into a ``TextArtifact``.

    Raises:
        Exception: If trafilatura extracts nothing from ``page``.
    """
    trafilatura = import_optional_dependency("trafilatura")
    config = trafilatura.settings.use_config()

    # This disables signal, so that trafilatura can work on any thread. The setting must
    # be applied to the config actually passed to ``extract``: ``use_config()`` builds a
    # fresh config on every call, so setting it anywhere else has no effect here.
    # More info: https://trafilatura.readthedocs.io/usage-python.html#disabling-signal
    config.set("DEFAULT", "EXTRACTION_TIMEOUT", "0")

    extracted_page = trafilatura.extract(
        page,
        include_links=self.include_links,
        output_format="json",
        config=config,
    )

    if not extracted_page:
        raise Exception("can't extract page")

    text = json.loads(extracted_page).get("text")

    return TextArtifact(text)

fetch_url(url)

Source Code in griptape/drivers/web_scraper/trafilatura_web_scraper_driver.py
def fetch_url(self, url: str) -> str:
    """Download the raw page at ``url``.

    Raises:
        Exception: If trafilatura returns no page for ``url``.
    """
    trafilatura = import_optional_dependency("trafilatura")

    # No trafilatura config is needed here: the signal-disabling EXTRACTION_TIMEOUT
    # setting only matters on the config passed to ``trafilatura.extract``.
    page = trafilatura.fetch_url(url, no_ssl=self.no_ssl)

    # Disable error logging in trafilatura as it sometimes logs errors from lxml, even though
    # the end result of page parsing is successful.
    logging.getLogger("trafilatura").setLevel(logging.FATAL)

    if page is None:
        raise Exception("can't access URL")

    return page

VoyageAiEmbeddingDriver

Bases: BaseEmbeddingDriver

Attributes

| Name | Type | Description |
| --- | --- | --- |
| `model` | `str` | VoyageAI embedding model name. Defaults to `voyage-large-2`. |
| `api_key` | `Optional[str]` | API key to pass directly. Defaults to the `VOYAGE_API_KEY` environment variable. |
| `tokenizer` | `VoyageAiTokenizer` | Optionally provide a custom `VoyageAiTokenizer`. |
| `client` | `Any` | Optionally provide a custom VoyageAI `Client`. |
| `input_type` | `str` | VoyageAI input type. Defaults to `document`. |
Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
@define
class VoyageAiEmbeddingDriver(BaseEmbeddingDriver):
    """VoyageAI Embedding Driver.

    Attributes:
        model: VoyageAI embedding model name. Defaults to `voyage-large-2`.
        api_key: API key to pass directly. Defaults to `VOYAGE_API_KEY` environment variable.
        tokenizer: Optionally provide custom `VoyageAiTokenizer`.
        client: Optionally provide custom VoyageAI `Client`.
        input_type: VoyageAI input type. Defaults to `document`.
    """

    DEFAULT_MODEL = "voyage-large-2"

    model: str = field(default=DEFAULT_MODEL, kw_only=True, metadata={"serializable": True})
    api_key: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": False})
    tokenizer: VoyageAiTokenizer = field(
        default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True),
        kw_only=True,
    )
    input_type: str = field(default="document", kw_only=True, metadata={"serializable": True})
    _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

    @lazy_property()
    def client(self) -> Any:
        voyageai = import_optional_dependency("voyageai")
        return voyageai.Client(api_key=self.api_key)

    def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
        """Embed a text or image artifact and return its embedding vector.

        Text artifacts go through ``try_embed_chunk``. Any other artifact's ``value`` is
        decoded with PIL and sent to ``client.multimodal_embed``.
        """
        if not isinstance(artifact, TextArtifact):
            pil_image = import_optional_dependency("PIL.Image")
            embed_multimodal = self.client.multimodal_embed
            image = pil_image.open(BytesIO(artifact.value))
            return embed_multimodal([[image]], model=self.model).embeddings[0]
        return self.try_embed_chunk(artifact.value, **kwargs)

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed one text chunk using ``self.model`` and ``self.input_type``; ``kwargs`` is not forwarded."""
        response = self.client.embed([chunk], model=self.model, input_type=self.input_type)
        return response.embeddings[0]
  • DEFAULT_MODEL = 'voyage-large-2' class-attribute instance-attribute

  • _client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) class-attribute instance-attribute

  • api_key = field(default=None, kw_only=True, metadata={'serializable': False}) class-attribute instance-attribute

  • input_type = field(default='document', kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • model = field(default=DEFAULT_MODEL, kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=Factory(lambda self: VoyageAiTokenizer(model=self.model, api_key=self.api_key), takes_self=True), kw_only=True) class-attribute instance-attribute

client()

Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
@lazy_property()
def client(self) -> Any:
    """Construct a VoyageAI ``Client`` with ``api_key=self.api_key``."""
    voyageai = import_optional_dependency("voyageai")
    return voyageai.Client(api_key=self.api_key)

try_embed_artifact(artifact, **kwargs)

Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact, **kwargs) -> list[float]:
    """Embed a text or image artifact and return its embedding vector.

    Text artifacts go through ``try_embed_chunk``. Any other artifact's ``value`` is
    decoded with PIL and sent to ``client.multimodal_embed``.
    """
    if not isinstance(artifact, TextArtifact):
        pil_image = import_optional_dependency("PIL.Image")
        embed_multimodal = self.client.multimodal_embed
        image = pil_image.open(BytesIO(artifact.value))
        return embed_multimodal([[image]], model=self.model).embeddings[0]
    return self.try_embed_chunk(artifact.value, **kwargs)

try_embed_chunk(chunk, **kwargs)

Source Code in griptape/drivers/embedding/voyageai_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed one text chunk using ``self.model`` and ``self.input_type``.

    ``kwargs`` is accepted but not forwarded to the client.
    """
    response = self.client.embed([chunk], model=self.model, input_type=self.input_type)
    return response.embeddings[0]

WebhookEventListenerDriver

Bases: BaseEventListenerDriver

Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
@define
class WebhookEventListenerDriver(BaseEventListenerDriver):
    """Event listener driver that POSTs event payloads as JSON to a webhook URL.

    Attributes:
        webhook_url: URL every payload is POSTed to.
        headers: Optional HTTP headers sent with each request.
    """

    webhook_url: str = field(kw_only=True)
    headers: Optional[dict] = field(default=None, kw_only=True)

    def try_publish_event_payload(self, event_payload: dict) -> None:
        self._post_json(event_payload)

    def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
        self._post_json(event_payload_batch)

    def _post_json(self, payload: dict | list[dict]) -> None:
        """POST ``payload`` as JSON to ``webhook_url``, raising on an HTTP error status."""
        requests.post(url=self.webhook_url, json=payload, headers=self.headers).raise_for_status()
  • headers = field(default=None, kw_only=True) class-attribute instance-attribute

  • webhook_url = field(kw_only=True) class-attribute instance-attribute

try_publish_event_payload(event_payload)

Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
def try_publish_event_payload(self, event_payload: dict) -> None:
    """POST a single event payload as JSON to ``webhook_url``, raising on an HTTP error status."""
    requests.post(url=self.webhook_url, json=event_payload, headers=self.headers).raise_for_status()

try_publish_event_payload_batch(event_payload_batch)

Source Code in griptape/drivers/event_listener/webhook_event_listener_driver.py
def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
    """POST a batch of event payloads as one JSON list to ``webhook_url``, raising on an HTTP error status."""
    requests.post(url=self.webhook_url, json=event_payload_batch, headers=self.headers).raise_for_status()

Could this page be better? Report a problem or suggest an addition!