griptape_cloud
# Only the driver class is part of this module's public API.
__all__ = ["GriptapeCloudConversationMemoryDriver"]
module-attribute
Bases:
BaseConversationMemoryDriver
Attributes
Name | Type | Description |
---|---|---|
thread_id | Optional[str] | The ID of the Thread to store the conversation memory in. If not provided, the driver will attempt to retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`. |
alias | Optional[str] | The alias of the Thread to store the conversation memory in. |
base_url | str | The base URL of the Gen AI Builder API. Defaults to the value of the environment variable `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. |
api_key | str | The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. |
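Raises
Type | Description |
---|---|
ValueError | If `api_key` is `None`. |
KeyError | If `api_key` is omitted and the environment variable `GT_CLOUD_API_KEY` is not set (raised by the default factory's `os.environ` lookup). |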
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@define(kw_only=True)
class GriptapeCloudConversationMemoryDriver(BaseConversationMemoryDriver):
    """A driver for storing conversation memory in the Gen AI Builder.

    Attributes:
        thread_id: The ID of the Thread to store the conversation memory in. If not provided, the driver will
            attempt to retrieve the ID from the environment variable `GT_CLOUD_THREAD_ID`.
        alias: The alias of the Thread to store the conversation memory in.
        base_url: The base URL of the Gen AI Builder API. Defaults to the value of the environment variable
            `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
        api_key: The API key to use for authenticating with the Gen AI Builder API. If not provided, the driver
            will attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.

    Raises:
        ValueError: If `api_key` is `None`.
        KeyError: If `api_key` is omitted and `GT_CLOUD_API_KEY` is not set (the default factory reads
            `os.environ` directly, so this fires before the validator runs).
    """

    thread_id: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    alias: Optional[str] = field(
        default=None,
        metadata={"serializable": True},
    )
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    # Derived from `api_key`; not settable through the constructor.
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        init=False,
    )
    # Cache for the resolved Thread payload; cleared by `store` so the next access re-fetches it.
    _thread: Optional[dict] = field(default=None, init=False)

    @api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
        """Reject a `None` API key with a `ValueError` naming the driver class."""
        if value is None:
            raise ValueError(f"{self.__class__.__name__} requires an API key")
        return value

    @property
    def thread(self) -> dict:
        """Try to get the Thread by ID, alias, or create a new one.

        Resolution order:
            1. `thread_id` (falling back to the `GT_CLOUD_THREAD_ID` environment variable), fetched with a GET
               that tolerates non-200 responses.
            2. The first Thread returned by an alias lookup, if `alias` is set.
            3. A newly created Thread, named after `alias` or a random hex UUID.

        `thread_id` and `alias` are updated to match the resolved Thread, and the result is cached in `_thread`.

        Returns:
            The Thread payload as returned by the API.
        """
        if self._thread is None:
            from urllib.parse import quote

            thread = None
            if self.thread_id is None:
                self.thread_id = os.getenv("GT_CLOUD_THREAD_ID")

            if self.thread_id is not None:
                # A non-200 response is not fatal: fall through to the alias lookup or Thread creation.
                res = self._call_api("get", f"/threads/{self.thread_id}", raise_for_status=False)
                if res.status_code == 200:
                    thread = res.json()

            # use name as 'alias' to get thread
            if thread is None and self.alias is not None:
                # Percent-encode the alias so characters such as `&`, `#`, `+` or spaces cannot break or
                # truncate the query string.
                encoded_alias = quote(self.alias, safe="")
                res = self._call_api("get", f"/threads?alias={encoded_alias}").json()
                if res.get("threads"):
                    thread = res["threads"][0]
                    self.thread_id = thread.get("thread_id")

            # no thread by name or thread_id
            if thread is None:
                data = {"name": uuid.uuid4().hex} if self.alias is None else {"name": self.alias, "alias": self.alias}
                thread = self._call_api("post", "/threads", data).json()
                self.thread_id = thread["thread_id"]
                self.alias = thread.get("alias")

            self._thread = thread

        return self._thread  # pyright: ignore[reportReturnType]

    def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
        """Replace the Thread's messages with `runs` and merge `metadata` into the PATCH body.

        Args:
            runs: Runs to persist; each becomes one message with JSON-serialized input/output and its id stored
                under `metadata.run_id`, merged with the Run's own `meta`.
            metadata: Extra fields merged into the request body alongside `messages`.
        """
        # serialize the run artifacts to json strings
        messages = [
            dict_merge(
                {
                    "input": run.input.to_json(),
                    "output": run.output.to_json(),
                    "metadata": {"run_id": run.id},
                },
                run.meta,
            )
            for run in runs
        ]

        body = dict_merge(
            {
                "messages": messages,
            },
            metadata,
        )

        # patch the Thread with the new messages and metadata
        # all old Messages are replaced with the new ones
        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id
        self._call_api("patch", f"/threads/{thread_id}", body)
        # Drop the cached Thread so later reads observe the updated state.
        self._thread = None

    def load(self) -> tuple[list[Run], dict[str, Any]]:
        """Load the Runs and metadata stored in the Thread.

        Returns:
            A tuple of (runs rebuilt from the Thread's messages, the Thread's `metadata` or `{}`).
        """
        from griptape.memory.structure import Run

        thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

        # get the Messages from the Thread
        messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

        # retrieve the Thread to get the metadata
        thread_response = self._call_api("get", f"/threads/{thread_id}").json()

        runs = []
        for message in messages_response.get("messages", []):
            # A message without a `metadata` key (or with a null value) previously crashed on
            # `"run_id" in None`; treat it as empty metadata instead.
            message_metadata = message.get("metadata") or {}
            # `run_id` is moved out of the metadata and becomes the Run's id, so it is not duplicated in `meta`.
            run_kwargs = {"id": message_metadata.pop("run_id")} if "run_id" in message_metadata else {}
            runs.append(
                Run(
                    **run_kwargs,
                    meta=message_metadata,
                    input=BaseArtifact.from_json(message["input"]),
                    output=BaseArtifact.from_json(message["output"]),
                )
            )

        return runs, thread_response.get("metadata", {})

    def _get_url(self, path: str) -> str:
        """Build the absolute API URL for `path` (leading slashes are ignored) under the `api/` prefix."""
        path = path.lstrip("/")
        return griptape_cloud_url(self.base_url, f"api/{path}")

    def _call_api(
        self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
    ) -> requests.Response:
        """Send an authenticated request to the Gen AI Builder API.

        Args:
            method: HTTP method passed to `requests.request`.
            path: API path, resolved via `_get_url`.
            json: Optional JSON request body.
            raise_for_status: When True, call `Response.raise_for_status()` before returning.

        Returns:
            The raw `requests.Response`.
        """
        res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
        if raise_for_status:
            res.raise_for_status()
        return res
_thread = field(default=None, init=False)
class-attribute instance-attribute
alias = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
class-attribute instance-attribute
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
class-attribute instance-attribute
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), init=False)
class-attribute instance-attribute
thread
property
thread_id = field(default=None, metadata={'serializable': True})
class-attribute instance-attribute
_call_api(method, path, json=None, *, raise_for_status=True)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _call_api(
    self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
    """Send an authenticated request to the Gen AI Builder API.

    Args:
        method: HTTP method forwarded to `requests.request`.
        path: API path, turned into an absolute URL by `_get_url`.
        json: Optional JSON request body.
        raise_for_status: When True, call `Response.raise_for_status()` before returning.

    Returns:
        The raw `requests.Response`.
    """
    url = self._get_url(path)
    response = requests.request(method, url, json=json, headers=self.headers)
    if not raise_for_status:
        return response
    response.raise_for_status()
    return response
_get_url(path)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def _get_url(self, path: str) -> str:
    """Return the absolute URL for an API `path`, prefixed with `api/`; leading slashes on `path` are dropped."""
    relative_path = path.lstrip("/")
    return griptape_cloud_url(self.base_url, "api/" + relative_path)
load()
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def load(self) -> tuple[list[Run], dict[str, Any]]:
    """Load the Runs and metadata stored in the Thread.

    Returns:
        A tuple of (runs rebuilt from the Thread's messages, the Thread's `metadata` or `{}`).
    """
    from griptape.memory.structure import Run

    thread_id = self.thread["thread_id"] if self.thread_id is None else self.thread_id

    # get the Messages from the Thread
    messages_response = self._call_api("get", f"/threads/{thread_id}/messages").json()

    # retrieve the Thread to get the metadata
    thread_response = self._call_api("get", f"/threads/{thread_id}").json()

    runs = []
    for message in messages_response.get("messages", []):
        # A message without a `metadata` key (or with a null value) previously crashed on
        # `"run_id" in None`; treat it as empty metadata instead.
        message_metadata = message.get("metadata") or {}
        # `run_id` is moved out of the metadata and becomes the Run's id, so it is not duplicated in `meta`.
        run_kwargs = {"id": message_metadata.pop("run_id")} if "run_id" in message_metadata else {}
        runs.append(
            Run(
                **run_kwargs,
                meta=message_metadata,
                input=BaseArtifact.from_json(message["input"]),
                output=BaseArtifact.from_json(message["output"]),
            )
        )

    return runs, thread_response.get("metadata", {})
store(runs, metadata)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
def store(self, runs: list[Run], metadata: dict[str, Any]) -> None:
    """Overwrite the Thread's messages with `runs` and merge `metadata` into the PATCH body.

    Args:
        runs: Runs to persist. Each becomes one message holding the JSON-serialized input and output, with the
            Run's id under `metadata.run_id`, merged with the Run's own `meta`.
        metadata: Extra fields merged into the request body next to `messages`.
    """
    serialized_messages = []
    for run in runs:
        base_message = {
            "input": run.input.to_json(),
            "output": run.output.to_json(),
            "metadata": {"run_id": run.id},
        }
        serialized_messages.append(dict_merge(base_message, run.meta))

    payload = dict_merge({"messages": serialized_messages}, metadata)

    # The PATCH replaces every existing Message in the Thread with the ones above.
    target_thread_id = self.thread_id if self.thread_id is not None else self.thread["thread_id"]
    self._call_api("patch", f"/threads/{target_thread_id}", payload)
    # Invalidate the cached Thread so the next access re-fetches it.
    self._thread = None
validate_api_key(_, value)
Source Code in griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py
@api_key.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
    """Raise `ValueError` when the API key is `None`; otherwise hand the value back unchanged."""
    if value is not None:
        return value
    raise ValueError(f"{self.__class__.__name__} requires an API key")
← Prev
griptape_cloud_conversation_memory_driver
↑ Up
conversation
Next →
local_conversation_memory_driver
Could this page be better? Report a problem or suggest an addition!