amazon_sagemaker_jumpstart
Adapted from the Griptape AI Framework documentation.
__all__ = ['AmazonSageMakerJumpstartEmbeddingDriver']
module-attribute
Bases: BaseEmbeddingDriver
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@define
class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
    """Embedding driver that invokes a SageMaker runtime endpoint.

    Attributes:
        session: boto3 session used to create the runtime client. By default a
            new ``Session`` is built from the optionally imported ``boto3``.
        endpoint: Sent as ``EndpointName`` on every invocation.
        custom_attributes: Sent as ``CustomAttributes``; defaults to
            ``"accept_eula=true"``.
        inference_component_name: Sent as ``InferenceComponentName`` only when
            it is not ``None``.
        _client: Optional runtime client, accepted by the initializer under the
            alias ``client``; marked as non-serializable.
    """

    session: boto3.Session = field(default=Factory(lambda: import_optional_dependency("boto3").Session()), kw_only=True)
    endpoint: str = field(kw_only=True, metadata={"serializable": True})
    custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
    inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
    _client: Optional[SageMakerRuntimeClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> SageMakerRuntimeClient:
        """Create a ``sagemaker-runtime`` client from ``session``.

        NOTE(review): caching of the result is presumably handled by the
        ``lazy_property`` decorator — confirm against its definition.
        """
        return self.session.client("sagemaker-runtime")

    def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
        """Embed ``chunk`` by invoking the configured endpoint.

        Args:
            chunk: Text sent to the model as ``text_inputs``.
            **kwargs: Not used by this implementation.

        Returns:
            The value stored under ``"embedding"`` in the decoded response. If
            its first element is itself a list, that first element is returned.

        Raises:
            ValueError: ``"invalid response from model"`` when the decoded
                response is not a JSON object containing ``"embedding"``;
                ``"model response is empty"`` when the embedding is empty.
        """
        payload = {"text_inputs": chunk, "mode": "embedding"}
        # InferenceComponentName is only sent when one has been configured.
        optional_params = (
            {"InferenceComponentName": self.inference_component_name}
            if self.inference_component_name is not None
            else {}
        )
        endpoint_response = self.client.invoke_endpoint(
            EndpointName=self.endpoint,
            ContentType="application/json",
            Body=json.dumps(payload).encode("utf-8"),
            CustomAttributes=self.custom_attributes,
            **optional_params,
        )
        response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

        # A JSON string or list can also satisfy `"embedding" in response`;
        # require a dict so such bodies raise ValueError instead of TypeError.
        if not isinstance(response, dict) or "embedding" not in response:
            raise ValueError("invalid response from model")

        embedding = response["embedding"]
        if not embedding:
            raise ValueError("model response is empty")
        # Unwrap a nested result so a flat vector is returned.
        if isinstance(embedding[0], list):
            return embedding[0]
        return embedding
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute
custom_attributes = field(default='accept_eula=true', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
endpoint = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
inference_component_name = field(default=None, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute
session = field(default=Factory(lambda: import_optional_dependency('boto3').Session()), kw_only=True)
class-attribute instance-attribute
client()
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@lazy_property()
def client(self) -> SageMakerRuntimeClient:
    """Build a ``sagemaker-runtime`` client from this driver's boto3 session.

    NOTE(review): any caching of the result is presumably handled by the
    ``lazy_property`` decorator — confirm against its definition.
    """
    runtime_client = self.session.client("sagemaker-runtime")
    return runtime_client
try_embed_chunk(chunk, **kwargs)
Source Code in griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
def try_embed_chunk(self, chunk: str, **kwargs) -> list[float]:
    """Embed ``chunk`` by invoking the configured endpoint.

    Args:
        chunk: Text sent to the model as ``text_inputs``.
        **kwargs: Not used by this implementation.

    Returns:
        The value stored under ``"embedding"`` in the decoded response. If its
        first element is itself a list, that first element is returned.

    Raises:
        ValueError: ``"invalid response from model"`` when the decoded response
            is not a JSON object containing ``"embedding"``;
            ``"model response is empty"`` when the embedding is empty.
    """
    payload = {"text_inputs": chunk, "mode": "embedding"}
    # InferenceComponentName is only sent when one has been configured.
    optional_params = (
        {"InferenceComponentName": self.inference_component_name}
        if self.inference_component_name is not None
        else {}
    )
    endpoint_response = self.client.invoke_endpoint(
        EndpointName=self.endpoint,
        ContentType="application/json",
        Body=json.dumps(payload).encode("utf-8"),
        CustomAttributes=self.custom_attributes,
        **optional_params,
    )
    response = json.loads(endpoint_response.get("Body").read().decode("utf-8"))

    # A JSON string or list can also satisfy `"embedding" in response`;
    # require a dict so such bodies raise ValueError instead of TypeError.
    if not isinstance(response, dict) or "embedding" not in response:
        raise ValueError("invalid response from model")

    embedding = response["embedding"]
    if not embedding:
        raise ValueError("model response is empty")
    # Unwrap a nested result so a flat vector is returned.
    if isinstance(embedding[0], list):
        return embedding[0]
    return embedding
- On this page
- client()
- try_embed_chunk(chunk, **kwargs)
Could this page be better? Report a problem or suggest an addition!