amazon_redshift_sql_driver
Bases:
BaseSqlDriver
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@define
class AmazonRedshiftSqlDriver(BaseSqlDriver):
    """SQL driver that runs statements through the boto3 ``redshift-data`` client.

    Exactly one of ``cluster_identifier`` or ``workgroup_name`` must be provided;
    this is enforced by ``validate_params``.

    Attributes:
        database: Sent as ``Database`` on every Data API call.
        session: boto3 session used to build the ``redshift-data`` client.
        cluster_identifier: Sent as ``ClusterIdentifier`` when set.
        workgroup_name: Sent as ``WorkgroupName`` when set.
        db_user: Sent as ``DbUser`` when set.
        database_credentials_secret_arn: Sent as ``SecretArn`` when set.
        wait_for_query_completion_sec: Seconds to sleep between ``describe_statement`` polls.
        client: Optional pre-built client (stored as ``_client``, not serialized);
            built from ``session`` by the ``client`` lazy property otherwise.
    """

    database: str = field(kw_only=True)
    session: boto3.Session = field(kw_only=True)
    cluster_identifier: Optional[str] = field(default=None, kw_only=True)
    workgroup_name: Optional[str] = field(default=None, kw_only=True)
    db_user: Optional[str] = field(default=None, kw_only=True)
    database_credentials_secret_arn: Optional[str] = field(default=None, kw_only=True)
    wait_for_query_completion_sec: float = field(default=0.3, kw_only=True)
    _client: Optional[RedshiftDataAPIServiceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> RedshiftDataAPIServiceClient:
        """Return a ``redshift-data`` client created from ``session``."""
        return self.session.client("redshift-data")

    @workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
        """Require exactly one of ``cluster_identifier`` / ``workgroup_name`` (truthiness check).

        Raises:
            ValueError: If neither or both are set.
        """
        if not self.cluster_identifier and not self.workgroup_name:
            raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
        if self.cluster_identifier and self.workgroup_name:
            raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")

    @classmethod
    def _process_rows_from_records(cls, records: list) -> list[list]:
        """Turn Data API ``Records`` into plain row lists of cell values.

        Each cell is assumed to be a single-key Data API ``Field`` dict (e.g. ``{"longValue": 1}``),
        whose one value is extracted. A NULL cell (``{"isNull": True}``) becomes ``None``
        instead of leaking the ``True`` flag as if it were the column value.
        """
        return [
            [None if cell.get("isNull") else next(iter(cell.values()), None) for cell in record]
            for record in records
        ]

    @classmethod
    def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
        """Zip each row with ``columns`` by position into a column-name -> value dict."""
        return [{column: r[idx] for idx, column in enumerate(columns)} for r in rows]

    @classmethod
    def _process_columns_from_column_metadata(cls, meta: list[dict[str, Any]]) -> list:
        """Return the ``"name"`` of every column-metadata entry, in order."""
        return [k["name"] for k in meta]

    @classmethod
    def _post_process(cls, meta: list[dict[str, Any]], records: list) -> list[dict[str, Any]]:
        """Convert ``ColumnMetadata`` + ``Records`` into one column-name -> value dict per row."""
        columns = cls._process_columns_from_column_metadata(meta)
        rows = cls._process_rows_from_records(records)
        return cls._process_cells_from_rows_and_columns(columns, rows)

    def _connection_kwargs(self) -> dict[str, str]:
        """Build the target/auth keyword arguments shared by every Data API call."""
        kwargs: dict[str, str] = {"Database": self.database}
        if self.workgroup_name:
            kwargs["WorkgroupName"] = self.workgroup_name
        if self.cluster_identifier:
            kwargs["ClusterIdentifier"] = self.cluster_identifier
        if self.db_user:
            kwargs["DbUser"] = self.db_user
        if self.database_credentials_secret_arn:
            kwargs["SecretArn"] = self.database_credentials_secret_arn
        return kwargs

    def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
        """Run ``query`` and wrap each row in ``BaseSqlDriver.RowResult``; ``None`` if no rows."""
        rows = self.execute_query_raw(query)
        if rows:
            return [BaseSqlDriver.RowResult(row) for row in rows]
        return None

    def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
        """Submit ``query``, poll until it leaves a pending state, and collect every result page.

        Returns:
            One column-name -> value dict per row when the statement ends ``FINISHED``;
            ``None`` for any other final status (e.g. ``FAILED`` or ``ABORTED``).
        """
        function_kwargs = {"Sql": query, **self._connection_kwargs()}
        response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
        response_id = response["Id"]

        statement = self.client.describe_statement(Id=response_id)
        while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
            time.sleep(self.wait_for_query_completion_sec)
            statement = self.client.describe_statement(Id=response_id)

        if statement["Status"] != "FINISHED":
            return None

        statement_result = self.client.get_statement_result(Id=response_id)
        results = list(statement_result.get("Records", []))
        while "NextToken" in statement_result:
            statement_result = self.client.get_statement_result(
                Id=response_id,
                NextToken=statement_result["NextToken"],
            )
            # Take the records from the page just fetched (not from the execute_statement response).
            results.extend(statement_result.get("Records", []))
        return self._post_process(statement_result["ColumnMetadata"], results)  # pyright: ignore[reportArgumentType]

    def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
        """Return the stringified list of column names reported by ``describe_table``."""
        function_kwargs = {**self._connection_kwargs(), "Table": table_name}
        if schema:
            function_kwargs["Schema"] = schema
        response = self.client.describe_table(**function_kwargs)  # pyright: ignore[reportArgumentType]
        return str([col["name"] for col in response["ColumnList"] if "name" in col])
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False}) (class-attribute, instance-attribute)
cluster_identifier = field(default=None, kw_only=True) (class-attribute, instance-attribute)
database = field(kw_only=True) (class-attribute, instance-attribute)
database_credentials_secret_arn = field(default=None, kw_only=True) (class-attribute, instance-attribute)
db_user = field(default=None, kw_only=True) (class-attribute, instance-attribute)
session = field(kw_only=True) (class-attribute, instance-attribute)
wait_for_query_completion_sec = field(default=0.3, kw_only=True) (class-attribute, instance-attribute)
workgroup_name = field(default=None, kw_only=True) (class-attribute, instance-attribute)
_post_process(meta, records) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _post_process(cls, meta: list[dict[str, Any]], records: list) -> list[dict[str, Any]]:
    """Convert a Data API result into one column-name -> value dict per row.

    Args:
        meta: ``ColumnMetadata`` entries, a list of dicts (previously mis-annotated
            as ``dict``); column names are taken from them.
        records: ``Records`` rows, each a list of cell dicts.

    Returns:
        One dict per row mapping each column name to that row's cell value.
    """
    column_names = cls._process_columns_from_column_metadata(meta)
    value_rows = cls._process_rows_from_records(records)
    return cls._process_cells_from_rows_and_columns(column_names, value_rows)
_process_cells_from_rows_and_columns(columns, rows) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_cells_from_rows_and_columns(cls, columns: list, rows: list[list]) -> list[dict[str, Any]]:
    """Pair every row with ``columns`` by position.

    Returns:
        One dict per row, keyed by column name, holding the row value at the same index.
        Extra trailing row values are ignored; a row shorter than ``columns`` raises IndexError.
    """
    mapped_rows = []
    for row in rows:
        mapped = {}
        for position, column_name in enumerate(columns):
            mapped[column_name] = row[position]
        mapped_rows.append(mapped)
    return mapped_rows
_process_columns_from_column_metadata(meta) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_columns_from_column_metadata(cls, meta: list[dict[str, Any]]) -> list:
    """Extract column names from Data API ``ColumnMetadata``.

    Args:
        meta: List of column-metadata dicts, each with a ``"name"`` key. It was
            previously annotated as ``dict``, but iterating a dict would yield string
            keys on which ``["name"]`` fails.

    Returns:
        The column names, in order.
    """
    return [column["name"] for column in meta]
_process_rows_from_records(records) classmethod
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@classmethod
def _process_rows_from_records(cls, records: list) -> list[list]:
    """Turn Data API ``Records`` into plain row lists of cell values.

    Each cell is assumed to be a single-key Data API ``Field`` dict (e.g. ``{"longValue": 1}``),
    and its one value is extracted. A NULL cell (``{"isNull": True}``) becomes ``None``
    instead of leaking the ``True`` flag as the column value. An empty cell dict also
    yields ``None``.
    """
    return [
        [None if cell.get("isNull") else next(iter(cell.values()), None) for cell in record]
        for record in records
    ]
client()
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@lazy_property()
def client(self) -> RedshiftDataAPIServiceClient:
    """Create the Redshift Data API client from the configured boto3 ``session``.

    NOTE(review): wrapped in ``lazy_property``; presumably the client is built on first
    access and cached (likely in ``_client``). Confirm against lazy_property.
    """
    service_name = "redshift-data"
    return self.session.client(service_name)
execute_query(query)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query(self, query: str) -> Optional[list[BaseSqlDriver.RowResult]]:
    """Run ``query`` and wrap each returned row in ``BaseSqlDriver.RowResult``.

    Returns:
        A list of ``RowResult`` objects, or ``None`` when ``execute_query_raw``
        produces ``None`` or an empty list.
    """
    raw_rows = self.execute_query_raw(query)
    if not raw_rows:
        return None
    return [BaseSqlDriver.RowResult(raw_row) for raw_row in raw_rows]
execute_query_raw(query)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any]]]]:
    """Submit ``query`` through the Data API and collect every page of results.

    The statement is polled with ``describe_statement`` (sleeping
    ``wait_for_query_completion_sec`` between polls) while its status is
    SUBMITTED, PICKED or STARTED.

    Returns:
        One column-name -> value dict per row when the statement ends ``FINISHED``;
        ``None`` for any other final status (e.g. ``FAILED`` or ``ABORTED``).
    """
    function_kwargs = {"Sql": query, "Database": self.database}
    if self.workgroup_name:
        function_kwargs["WorkgroupName"] = self.workgroup_name
    if self.cluster_identifier:
        function_kwargs["ClusterIdentifier"] = self.cluster_identifier
    if self.db_user:
        function_kwargs["DbUser"] = self.db_user
    if self.database_credentials_secret_arn:
        function_kwargs["SecretArn"] = self.database_credentials_secret_arn

    response = self.client.execute_statement(**function_kwargs)  # pyright: ignore[reportArgumentType]
    response_id = response["Id"]

    statement = self.client.describe_statement(Id=response_id)
    while statement["Status"] in ["SUBMITTED", "PICKED", "STARTED"]:
        time.sleep(self.wait_for_query_completion_sec)
        statement = self.client.describe_statement(Id=response_id)

    if statement["Status"] != "FINISHED":
        return None

    statement_result = self.client.get_statement_result(Id=response_id)
    results = list(statement_result.get("Records", []))
    while "NextToken" in statement_result:
        statement_result = self.client.get_statement_result(
            Id=response_id,
            NextToken=statement_result["NextToken"],
        )
        # Take the records from the page just fetched (not from the execute_statement response).
        results.extend(statement_result.get("Records", []))
    return self._post_process(statement_result["ColumnMetadata"], results)  # pyright: ignore[reportArgumentType]
get_table_schema(table_name, schema=None)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Optional[str]:
    """Describe ``table_name`` via the Data API and return its column names as a string.

    Only truthy optional settings (schema, workgroup, cluster, user, secret ARN) are forwarded.

    Returns:
        ``str()`` of the list of column names, skipping entries without a ``"name"`` key.
    """
    optional_settings = {
        "Schema": schema,
        "WorkgroupName": self.workgroup_name,
        "ClusterIdentifier": self.cluster_identifier,
        "DbUser": self.db_user,
        "SecretArn": self.database_credentials_secret_arn,
    }
    request = {"Database": self.database, "Table": table_name}
    request.update({key: value for key, value in optional_settings.items() if value})
    description = self.client.describe_table(**request)  # pyright: ignore[reportArgumentType]
    column_names = [column["name"] for column in description["ColumnList"] if "name" in column]
    return str(column_names)
validate_params(_, workgroup_name)
Source Code in griptape/drivers/sql/amazon_redshift_sql_driver.py
@workgroup_name.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_params(self, _: Attribute, workgroup_name: Optional[str]) -> None:
    """Require exactly one of ``cluster_identifier`` and ``workgroup_name`` to be set.

    Registered as the validator for the ``workgroup_name`` field; both values are read
    from ``self`` rather than from the ``workgroup_name`` argument. Empty strings count
    as unset because the check uses truthiness.

    Raises:
        ValueError: If neither value is set, or if both are.
    """
    has_cluster = bool(self.cluster_identifier)
    has_workgroup = bool(self.workgroup_name)
    if not (has_cluster or has_workgroup):
        raise ValueError("Provide a value for one of `cluster_identifier` or `workgroup_name`")
    if has_cluster and has_workgroup:
        raise ValueError("Provide a value for either `cluster_identifier` or `workgroup_name`, but not both")
- On this page
- _post_process(meta, records) classmethod
- _process_cells_from_rows_and_columns(columns, rows) classmethod
- _process_columns_from_column_metadata(meta) classmethod
- _process_rows_from_records(records) classmethod
- client()
- execute_query(query)
- execute_query_raw(query)
- get_table_schema(table_name, schema=None)
- validate_params(_, workgroup_name)
Could this page be better? Report a problem or suggest an addition!