griptape_cloud

  • __all__ = ['GriptapeCloudConversationMemoryDriver'] module-attribute

Bases: BaseConversationMemoryDriver

Attributes

NameTypeDescription
thread_idOptional[str]The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to retrieve the ID from the environment variable GT_CLOUD_THREAD_ID.
aliasOptional[str]The alias of the Thread to store the conversation memory in.
base_urlstrThe base URL of the Gen AI Builder API. Defaults to the value of the environment variable GT_CLOUD_BASE_URL or https://cloud.griptape.ai.
api_keystrThe API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable GT_CLOUD_API_KEY.

Raises

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
    """A driver for storing conversation memory in the Gen AI Builder.

    Attributes:
        thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to
            retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
        alias: The alias of the Thread to store the conversation memory in.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will
            attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is not provided.
    """

    thread_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    alias: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    _thread: Optional[dict] = field(default=None, init=False)

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    @property
    def thread(self) -> dict:
        """Try to get the Thread by ID, alias, or create a new one."""
        if self._thread is None:
            thread = None
            if self.thread_id is None:
                self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

            if self.thread_id is not None:
                res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
                if res.status_code == 200:
                    thread = res.json()

            # use name as 'alias' to get thread
            if thread is None and self.alias is not None:
                res = self._call_api("get", f"/threads?alias={self.alias}").json()
                if res.get("threads"):
                    thread = res["threads"][0]
                    self.thread_id = thread.get("thread_id")

            # no thread by name or thread_id
            if thread is None:
                data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
                thread = self._call_api("post", "/threads", data).json()
                self.thread_id = thread["thread_id"]
                self.alias = thread.get("alias")

            self._thread = thread

        return self._thread  # pyright: ignore[reportReturnType]

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        # serialize the run artifacts to json strings
        messages = [
            dict_merge(
                {
                    "input": run.input.to_json(),
                    "output": run.output.to_json(),
                    "metadata": {"run_id": run.id},
                },
                run.meta,
            )
            for run in runs
        ]

        body = dict_merge(
            {
                "messages": messages,
            },
            metadata,
        )

        # patch the Thread with the new messages and metadata
        # all old Messages are replaced with the new ones
        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
        self._call_api("patch", f"/threads/{thread_id}", body)
        self._thread = None

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        from griptape.memory.structure import Run

        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

        # get the Messages from the Thread
        messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

        # retrieve the Thread to get the metadata
        thread_response = self._call_api("get", f"/threads/{thread_id}").json()

        runs = [
            Run(
                **({"id": message["metadata"].pop("run_id", None)} if "run_id" in message.get("metadata") else {}),
                meta=message["metadata"],
                input=BaseArtifact.from_json(message["input"]),
                output=BaseArtifact.from_json(message["output"]),
            )
            for message in messages_response.get("messages", [])
        ]

        return runs, thread_response.get("metadata", {})

    def _get_url(self, path: str) -> str:
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
    ) -> requests.Response:
        res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
  • _thread = field(default=None, init=False) class-attribute instance-attribute

  • alias = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

  • api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY'])) class-attribute instance-attribute

  • base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai'))) class-attribute instance-attribute

  • headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False) class-attribute instance-attribute

  • thread property

  • thread_id = field(default=None, metadata={'serializable': True}) class-attribute instance-attribute

_call_api(method, path, json=None, *, raise_for_status=True)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _call_api(
    self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
    res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
    if raise_for_status:
        res.raise_for_status()
    return res

_get_url(path)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _get_url(self, path: str) -> str:
    path = path.lstrip("/")
    return griptape_cloud_url(self.base_url, f"api/{path}")

load()

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    from griptape.memory.structure import Run

    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

    # get the Messages from the Thread
    messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

    # retrieve the Thread to get the metadata
    thread_response = self._call_api("get", f"/threads/{thread_id}").json()

    runs = [
        Run(
            **({"id": message["metadata"].pop("run_id", None)} if "run_id" in message.get("metadata") else {}),
            meta=message["metadata"],
            input=BaseArtifact.from_json(message["input"]),
            output=BaseArtifact.from_json(message["output"]),
        )
        for message in messages_response.get("messages", [])
    ]

    return runs, thread_response.get("metadata", {})

store(runs, metadata)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    # serialize the run artifacts to json strings
    messages = [
        dict_merge(
            {
                "input": run.input.to_json(),
                "output": run.output.to_json(),
                "metadata": {"run_id": run.id},
            },
            run.meta,
        )
        for run in runs
    ]

    body = dict_merge(
        {
            "messages": messages,
        },
        metadata,
    )

    # patch the Thread with the new messages and metadata
    # all old Messages are replaced with the new ones
    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
    self._call_api("patch", f"/threads/{thread_id}", body)
    self._thread = None

validateapi_key(, value)

Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
    if value is None:
        raise ValueError(f"{self.__class__.__name__} requires an API key")
    return value

Could this page be better? Report a problem or suggest an addition!