base_embedding_driver

  • VectorOperation = Literal['query', 'upsert'] module-attribute

Bases: SerializableMixin, ExponentialBackoffMixin, ABC

Attributes

| Name | Type | Description |
| --- | --- | --- |
| model | str | The name of the model to use. |
| tokenizer | Optional[BaseTokenizer] | An instance of BaseTokenizer to use when calculating tokens. |
Source Code in griptape/drivers/embedding/base_embedding_driver.py
@define
class BaseEmbeddingDriver(SerializableMixin, ExponentialBackoffMixin, ABC):
    """Base Embedding Driver.

    Attributes:
        model: The name of the model to use.
        tokenizer: An instance of `BaseTokenizer` to use when calculating tokens.
        chunker: A `TextChunker` built from `tokenizer` after init, or `None` when no tokenizer is set.
    """

    model: str = field(kw_only=True, metadata={"serializable": True})
    tokenizer: Optional[BaseTokenizer] = field(default=None, kw_only=True)
    chunker: Optional[BaseChunker] = field(init=False)

    def __attrs_post_init__(self) -> None:
        """Derive `chunker` from the configured tokenizer (or `None` without one)."""
        self.chunker = TextChunker(tokenizer=self.tokenizer) if self.tokenizer else None

    def embed_text_artifact(
        self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Deprecated alias for `embed`; emits a `DeprecationWarning` and forwards the call."""
        warnings.warn(
            "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(artifact, vector_operation=vector_operation)

    def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Deprecated alias for `embed`; emits a `DeprecationWarning` and forwards the call."""
        warnings.warn(
            "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.embed(string, vector_operation=vector_operation)

    def embed(
        self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Embed a string, text artifact or image artifact.

        Strings whose token count exceeds the tokenizer's `max_input_tokens` are
        embedded chunk by chunk via `_embed_long_string`; other strings go straight
        to `try_embed_chunk`. Image artifacts are passed to `try_embed_artifact`.

        Args:
            value: The string or artifact to embed.
            vector_operation: Optional hint forwarded unchanged to the `try_embed_*` methods.

        Returns:
            The embedding vector.

        Raises:
            RuntimeError: If the retry loop ends without returning an embedding.
        """
        # Unwrap text artifacts before entering the retry loop. Calling `embed`
        # recursively from inside the loop would open a second retry loop nested
        # in this one, so a failing call would presumably be retried more often
        # than the configured policy intends.
        if isinstance(value, TextArtifact):
            value = value.to_text()

        for attempt in self.retrying():
            with attempt:
                if isinstance(value, str):
                    tokenizer = self.tokenizer
                    if tokenizer is not None and tokenizer.count_tokens(value) > tokenizer.max_input_tokens:
                        return self._embed_long_string(value, vector_operation=vector_operation)
                    return self.try_embed_chunk(value, vector_operation=vector_operation)
                if isinstance(value, ImageArtifact):
                    return self.try_embed_artifact(value, vector_operation=vector_operation)
        raise RuntimeError("Failed to embed string.")

    def try_embed_artifact(
        self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
    ) -> list[float]:
        """Embed an artifact; this base implementation only handles text artifacts.

        Raises:
            ValueError: If `artifact` is not a `TextArtifact`.
        """
        # TODO: Mark as abstract method for griptape 2.0
        if isinstance(artifact, TextArtifact):
            return self.try_embed_chunk(artifact.value, vector_operation=vector_operation)
        raise ValueError(f"{self.__class__.__name__} does not support embedding images.")

    @abstractmethod
    def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Embed a single chunk of text; concrete drivers must implement this."""
        # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead
        ...

    def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
        """Embeds a string that is too long to embed in one go.

        The string is split with `self.chunker`, each chunk is embedded, and the
        chunk embeddings are combined into a length-weighted average that is then
        scaled to unit length.

        Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb
        """
        chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

        chunk_embeddings = []
        chunk_lengths = []
        for chunk in chunks:
            chunk_embeddings.append(self.embed(chunk.value, vector_operation=vector_operation))
            chunk_lengths.append(len(chunk))

        # generate weighted averages
        averaged = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)

        # normalize length to 1; a zero vector has no direction, so return it
        # unchanged rather than dividing by zero and producing NaNs
        norm = np.linalg.norm(averaged)
        if norm == 0:
            return averaged.tolist()

        return (averaged / norm).tolist()
  • chunker = field(init=False) class-attribute instance-attribute

  • model = field(kw_only=True, metadata={'serializable': True}) class-attribute instance-attribute

  • tokenizer = field(default=None, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def __attrs_post_init__(self) -> None:
    """Build the chunker from the tokenizer, or leave it unset without one."""
    if self.tokenizer:
        self.chunker = TextChunker(tokenizer=self.tokenizer)
    else:
        self.chunker = None

_embed_long_string(string, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def _embed_long_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Embeds a string that is too long to embed in one go.

    The string is split with `self.chunker`, each chunk is embedded via
    `self.embed`, and the chunk embeddings are combined into a length-weighted
    average that is then scaled to unit length.

    Adapted from: https://github.com/openai/openai-cookbook/blob/683e5f5a71bc7a1b0e5b7a35e087f53cc55fceea/examples/Embedding_long_inputs.ipynb

    Args:
        string: The text to split and embed.
        vector_operation: Optional hint forwarded unchanged to `self.embed`.

    Returns:
        The combined embedding as a plain list of floats.
    """
    chunks = self.chunker.chunk(string)  # pyright: ignore[reportOptionalMemberAccess] In practice this is never None

    chunk_embeddings = []
    chunk_lengths = []
    for chunk in chunks:
        chunk_embeddings.append(self.embed(chunk.value, vector_operation=vector_operation))
        chunk_lengths.append(len(chunk))

    # generate weighted averages
    averaged = np.average(chunk_embeddings, axis=0, weights=chunk_lengths)

    # normalize length to 1; a zero vector has no direction, so return it
    # unchanged rather than dividing by zero and producing NaNs
    norm = np.linalg.norm(averaged)
    if norm == 0:
        return averaged.tolist()

    return (averaged / norm).tolist()

embed(value, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed(
    self, value: str | TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Embed a string, text artifact or image artifact.

    Strings whose token count exceeds the tokenizer's `max_input_tokens` are
    embedded chunk by chunk via `_embed_long_string`; other strings go straight
    to `try_embed_chunk`. Image artifacts are passed to `try_embed_artifact`.

    Args:
        value: The string or artifact to embed.
        vector_operation: Optional hint forwarded unchanged to the `try_embed_*` methods.

    Returns:
        The embedding vector.

    Raises:
        RuntimeError: If the retry loop ends without returning an embedding.
    """
    # Unwrap text artifacts before entering the retry loop. Calling `embed`
    # recursively from inside the loop would open a second retry loop nested
    # in this one, so a failing call would presumably be retried more often
    # than the configured policy intends.
    if isinstance(value, TextArtifact):
        value = value.to_text()

    for attempt in self.retrying():
        with attempt:
            if isinstance(value, str):
                tokenizer = self.tokenizer
                if tokenizer is not None and tokenizer.count_tokens(value) > tokenizer.max_input_tokens:
                    return self._embed_long_string(value, vector_operation=vector_operation)
                return self.try_embed_chunk(value, vector_operation=vector_operation)
            if isinstance(value, ImageArtifact):
                return self.try_embed_artifact(value, vector_operation=vector_operation)
    raise RuntimeError("Failed to embed string.")

embed_string(string, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_string(self, string: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Deprecated alias for `embed`.

    Emits a `DeprecationWarning` attributed to the caller, then forwards
    `string` and `vector_operation` to `self.embed` and returns its result.
    """
    deprecation_notice = "`BaseEmbeddingDriver.embed_string` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement."
    # stacklevel=2 points the warning at the code calling this method.
    warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)

    embedding = self.embed(string, vector_operation=vector_operation)
    return embedding

embed_text_artifact(artifact, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def embed_text_artifact(
    self, artifact: TextArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Deprecated alias for `embed`.

    Emits a `DeprecationWarning` attributed to the caller, then forwards
    `artifact` and `vector_operation` to `self.embed` and returns its result.
    """
    deprecation_notice = "`BaseEmbeddingDriver.embed_text_artifact` is deprecated and will be removed in a future release. `BaseEmbeddingDriver.embed` is a drop-in replacement."
    # stacklevel=2 points the warning at the code calling this method.
    warnings.warn(deprecation_notice, category=DeprecationWarning, stacklevel=2)

    embedding = self.embed(artifact, vector_operation=vector_operation)
    return embedding

try_embed_artifact(artifact, *, vector_operation=None)

Source Code in griptape/drivers/embedding/base_embedding_driver.py
def try_embed_artifact(
    self, artifact: TextArtifact | ImageArtifact, *, vector_operation: VectorOperation | None = None
) -> list[float]:
    """Embed an artifact; this implementation only handles text artifacts.

    A `TextArtifact` has its `value` passed to `try_embed_chunk`. Anything
    else is rejected with a `ValueError` naming the concrete driver class.
    """
    # TODO: Mark as abstract method for griptape 2.0
    if not isinstance(artifact, TextArtifact):
        raise ValueError(f"{self.__class__.__name__} does not support embedding images.")

    text = artifact.value
    return self.try_embed_chunk(text, vector_operation=vector_operation)

try_embed_chunk(chunk, *, vector_operation=None) abstractmethod

Source Code in griptape/drivers/embedding/base_embedding_driver.py
@abstractmethod
def try_embed_chunk(self, chunk: str, *, vector_operation: VectorOperation | None = None) -> list[float]:
    """Embed a single chunk of text.

    Declared abstract: this stub has no implementation of its own and
    concrete drivers are expected to provide one.
    """
    # TODO: Remove for griptape 2.0, subclasses should implement `try_embed_artifact` instead

Could this page be better? Report a problem or suggest an addition!