griptape_cloud_prompt_driver
logger = logging.getLogger(Defaults.logging_config.logger_name)
module-attribute
Bases: BasePromptDriver
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@define
class GriptapeCloudPromptDriver(BasePromptDriver):
    """Prompt driver that sends prompt stacks to the Griptape Cloud chat API over HTTP.

    Attributes:
        model: Optional model name; only included in the request when truthy.
        base_url: Service root. Read from the GT_CLOUD_BASE_URL environment
            variable, falling back to "https://cloud.griptape.ai".
        api_key: Read from the GT_CLOUD_API_KEY environment variable; the
            default factory raises KeyError when that variable is not set.
        headers: Request headers; defaults to a Bearer Authorization header
            built from api_key.
        tokenizer: Defaults to a SimpleTokenizer using 4 characters per token,
            2000 max input tokens, and this driver's max_tokens as the output cap.
        use_native_tools: Forwarded to the service in "driver_configuration".
        structured_output_strategy: Forwarded to the service in
            "driver_configuration"; defaults to "native".
    """

    # NOTE: base_url and api_key are not keyword-only, so their declaration
    # order is part of the constructor's positional interface.
    model: Optional[str] = field(default=None, kw_only=True)
    base_url: str = field(
        default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
    )
    api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
    headers: dict = field(
        default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
        kw_only=True,
    )
    tokenizer: BaseTokenizer = field(
        default=Factory(
            lambda self: SimpleTokenizer(
                characters_per_token=4,
                max_input_tokens=2000,
                max_output_tokens=self.max_tokens,
            ),
            takes_self=True,
        ),
        kw_only=True,
    )
    use_native_tools: bool = field(default=True, kw_only=True)
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native",
        kw_only=True,
        metadata={"serializable": True},
    )

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        """Send the prompt stack to "api/chat/messages" and return the parsed Message.

        Args:
            prompt_stack: The prompt stack to serialize into the request body.

        Returns:
            The Message built from the JSON response body.

        Raises:
            requests.HTTPError: Propagated from raise_for_status() on an error response.
        """
        endpoint = griptape_cloud_url(self.base_url, "api/chat/messages")
        payload = self._base_params(prompt_stack)
        logger.debug(payload)

        http_response = requests.post(endpoint, headers=self.headers, json=payload)
        http_response.raise_for_status()

        body = http_response.json()
        logger.debug(body)
        return Message.from_dict(body)

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        """Stream delta messages from "api/chat/messages/stream".

        Args:
            prompt_stack: The prompt stack to serialize into the request body.

        Yields:
            One DeltaMessage per non-empty response line that starts with "data:".

        Raises:
            requests.HTTPError: Propagated from raise_for_status() on an error response.
        """
        endpoint = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
        payload = self._base_params(prompt_stack)
        logger.debug(payload)

        with requests.post(endpoint, headers=self.headers, json=payload, stream=True) as http_response:
            http_response.raise_for_status()
            for raw_line in http_response.iter_lines():
                # Blank lines carry no payload.
                if not raw_line:
                    continue
                text = raw_line.decode("utf-8")
                # Only "data:" lines are parsed; everything else is ignored.
                if not text.startswith("data:"):
                    continue
                event_data = text.removeprefix("data:").strip()
                logger.debug(event_data)
                yield DeltaMessage.from_json(event_data)

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        """Build the JSON request body used by both try_run and try_stream.

        Args:
            prompt_stack: Source of the messages, tools and optional output schema.

        Returns:
            A dict with "messages", "tools", an optional "output_schema"
            (only when the prompt stack has one), and "driver_configuration".
        """
        params: dict = {
            "messages": prompt_stack.to_dict()["messages"],
            "tools": self.__to_griptape_tools(prompt_stack.tools),
        }
        if prompt_stack.output_schema:
            params["output_schema"] = prompt_stack.to_output_json_schema()

        # "model" is omitted entirely when it is unset.
        driver_configuration: dict = {}
        if self.model:
            driver_configuration["model"] = self.model
        driver_configuration["max_tokens"] = self.max_tokens
        driver_configuration["use_native_tools"] = self.use_native_tools
        driver_configuration["temperature"] = self.temperature
        driver_configuration["structured_output_strategy"] = self.structured_output_strategy
        driver_configuration["extra_params"] = self.extra_params

        params["driver_configuration"] = driver_configuration
        return params

    def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
        """Convert tools into the list of dicts sent under the "tools" key.

        Args:
            tools: The tools to describe.

        Returns:
            One dict per tool with its "name" and an "activities" list; each
            activity entry holds the activity's "name", "description" and
            "json_schema".
        """
        serialized_tools: list[dict] = []
        for tool in tools:
            tool_entry: dict = {"name": tool.name, "activities": []}
            for activity in tool.activities():
                tool_entry["activities"].append(
                    {
                        "name": activity.__name__,
                        "description": tool.activity_description(activity),
                        "json_schema": tool.to_activity_json_schema(activity, "Schema"),
                    }
                )
            serialized_tools.append(tool_entry)
        return serialized_tools
api_key = field(default=Factory(lambda: os.environ['GT_CLOUD_API_KEY']))
class-attribute instance-attribute
base_url = field(default=Factory(lambda: os.getenv('GT_CLOUD_BASE_URL', 'https://cloud.griptape.ai')))
class-attribute instance-attribute
headers = field(default=Factory(lambda self: {'Authorization': f'Bearer {self.api_key}'}, takes_self=True), kw_only=True)
class-attribute instance-attribute
model = field(default=None, kw_only=True)
class-attribute instance-attribute
structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
tokenizer = field(default=Factory(lambda self: SimpleTokenizer(characters_per_token=4, max_input_tokens=2000, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
class-attribute instance-attribute
use_native_tools = field(default=True, kw_only=True)
class-attribute instance-attribute
__to_griptape_tools(tools)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
def __to_griptape_tools(self, tools: list[BaseTool]) -> list[dict]:
    """Convert tools into the list of dicts sent under the "tools" key.

    Args:
        tools: The tools to describe.

    Returns:
        One dict per tool with its "name" and an "activities" list; each
        activity entry holds the activity's "name", "description" and
        "json_schema".
    """
    serialized_tools: list[dict] = []
    for tool in tools:
        tool_entry: dict = {"name": tool.name, "activities": []}
        for activity in tool.activities():
            tool_entry["activities"].append(
                {
                    "name": activity.__name__,
                    "description": tool.activity_description(activity),
                    "json_schema": tool.to_activity_json_schema(activity, "Schema"),
                }
            )
        serialized_tools.append(tool_entry)
    return serialized_tools
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
def _base_params(self, prompt_stack: PromptStack) -> dict:
    """Build the JSON request body used by both try_run and try_stream.

    Args:
        prompt_stack: Source of the messages, tools and optional output schema.

    Returns:
        A dict with "messages", "tools", an optional "output_schema"
        (only when the prompt stack has one), and "driver_configuration".
    """
    params: dict = {
        "messages": prompt_stack.to_dict()["messages"],
        "tools": self.__to_griptape_tools(prompt_stack.tools),
    }
    if prompt_stack.output_schema:
        params["output_schema"] = prompt_stack.to_output_json_schema()

    # "model" is omitted entirely when it is unset.
    driver_configuration: dict = {}
    if self.model:
        driver_configuration["model"] = self.model
    driver_configuration["max_tokens"] = self.max_tokens
    driver_configuration["use_native_tools"] = self.use_native_tools
    driver_configuration["temperature"] = self.temperature
    driver_configuration["structured_output_strategy"] = self.structured_output_strategy
    driver_configuration["extra_params"] = self.extra_params

    params["driver_configuration"] = driver_configuration
    return params
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    """Send the prompt stack to "api/chat/messages" and return the parsed Message.

    Args:
        prompt_stack: The prompt stack to serialize into the request body.

    Returns:
        The Message built from the JSON response body.

    Raises:
        requests.HTTPError: Propagated from raise_for_status() on an error response.
    """
    endpoint = griptape_cloud_url(self.base_url, "api/chat/messages")
    payload = self._base_params(prompt_stack)
    logger.debug(payload)

    http_response = requests.post(endpoint, headers=self.headers, json=payload)
    http_response.raise_for_status()

    body = http_response.json()
    logger.debug(body)
    return Message.from_dict(body)
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/griptape_cloud_prompt_driver.py
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    """Stream delta messages from "api/chat/messages/stream".

    Args:
        prompt_stack: The prompt stack to serialize into the request body.

    Yields:
        One DeltaMessage per non-empty response line that starts with "data:".

    Raises:
        requests.HTTPError: Propagated from raise_for_status() on an error response.
    """
    endpoint = griptape_cloud_url(self.base_url, "api/chat/messages/stream")
    payload = self._base_params(prompt_stack)
    logger.debug(payload)

    with requests.post(endpoint, headers=self.headers, json=payload, stream=True) as http_response:
        http_response.raise_for_status()
        for raw_line in http_response.iter_lines():
            # Blank lines carry no payload.
            if not raw_line:
                continue
            text = raw_line.decode("utf-8")
            # Only "data:" lines are parsed; everything else is ignored.
            if not text.startswith("data:"):
                continue
            event_data = text.removeprefix("data:").strip()
            logger.debug(event_data)
            yield DeltaMessage.from_json(event_data)
Could this page be better? Report a problem or suggest an addition!