local
Adapted from the Griptape AI Framework documentation.
__all__ = ['LocalRerankDriver']
module-attribute
Bases: `BaseRerankDriver`, `FuturesExecutorMixin`
Source Code in griptape/drivers/rerank/local_rerank_driver.py
def _cosine_relatedness(x, y) -> float:
    """Return the cosine similarity of vectors ``x`` and ``y``.

    Returns 0.0 when the product of the two norms is zero (e.g. either vector is all
    zeros). Cosine similarity is undefined there, and dividing by zero would produce
    a NaN score that cannot be ranked.
    """
    denominator = norm(x) * norm(y)
    if not denominator:
        return 0.0
    return dot(x, y) / denominator


@define(kw_only=True)
class LocalRerankDriver(BaseRerankDriver, FuturesExecutorMixin):
    """Rerank artifacts locally by the similarity of their embeddings to the query's.

    Attributes:
        calculate_relatedness: Callable ``(query_embedding, artifact_embedding) -> score``;
            higher scores rank first. Defaults to cosine similarity, which scores 0.0
            when a norm product is zero.
        embedding_driver: Driver used to embed both the query and the artifacts.
            Defaults to the embedding driver of ``Defaults.drivers_config``.
    """

    calculate_relatedness: Callable = field(default=_cosine_relatedness)
    embedding_driver: BaseEmbeddingDriver = field(
        kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={"serializable": True}
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Return ``artifacts`` ordered from most to least related to ``query``.

        The query is embedded once; each artifact is embedded as a separate task on the
        executor returned by ``self.create_futures_executor()``. Each artifact is scored
        with ``calculate_relatedness(query_embedding, artifact_embedding)``; artifacts
        with equal scores keep their input order.

        Args:
            query: Text to rank the artifacts against.
            artifacts: Artifacts to rerank. The list itself is not modified.

        Returns:
            A new list of the same artifacts, highest score first. Artifacts whose score
            is NaN are placed after all others.
        """
        query_embedding = self.embedding_driver.embed(query, vector_operation="query")

        with self.create_futures_executor() as futures_executor:
            artifact_embeddings = execute_futures_list(
                [
                    futures_executor.submit(
                        with_contextvars(self.embedding_driver.embed_text_artifact), a, vector_operation="upsert"
                    )
                    for a in artifacts
                ],
            )

        # NOTE(review): pairing by zip assumes execute_futures_list returns results in
        # submission order — TODO confirm against its implementation.
        artifacts_and_relatednesses = [
            (artifact, self.calculate_relatedness(query_embedding, artifact_embedding))
            for artifact, artifact_embedding in zip(artifacts, artifact_embeddings)
        ]

        def ranking_key(pair):
            score = pair[1]
            # NaN compares False against everything, which breaks the total order that
            # list.sort relies on and can scramble the ranking of *other* artifacts.
            # Valid scores sort by value; NaN scores sort after all of them.
            is_nan = score != score  # NaN is the only value unequal to itself
            return (not is_nan, 0.0 if is_nan else score)

        # list.sort is stable even with reverse=True, so ties keep their input order.
        artifacts_and_relatednesses.sort(key=ranking_key, reverse=True)

        return [artifact for artifact, _ in artifacts_and_relatednesses]
calculate_relatedness = field(default=lambda x, y: dot(x, y) / (norm(x) * norm(y)))
class-attribute instance-attribute
embedding_driver = field(kw_only=True, default=Factory(lambda: Defaults.drivers_config.embedding_driver), metadata={'serializable': True})
class-attribute instance-attribute
run(query, artifacts)
Source Code in griptape/drivers/rerank/local_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Return ``artifacts`` ordered from most to least related to ``query``.

    The query is embedded once; each artifact is embedded as a separate task on the
    executor returned by ``self.create_futures_executor()``. Each artifact is scored
    with ``calculate_relatedness(query_embedding, artifact_embedding)``; artifacts
    with equal scores keep their input order.

    Args:
        query: Text to rank the artifacts against.
        artifacts: Artifacts to rerank. The list itself is not modified.

    Returns:
        A new list of the same artifacts, highest score first. Artifacts whose score
        is NaN are placed after all others.
    """
    query_embedding = self.embedding_driver.embed(query, vector_operation="query")

    with self.create_futures_executor() as futures_executor:
        artifact_embeddings = execute_futures_list(
            [
                futures_executor.submit(
                    with_contextvars(self.embedding_driver.embed_text_artifact), a, vector_operation="upsert"
                )
                for a in artifacts
            ],
        )

    # NOTE(review): pairing by zip assumes execute_futures_list returns results in
    # submission order — TODO confirm against its implementation.
    artifacts_and_relatednesses = [
        (artifact, self.calculate_relatedness(query_embedding, artifact_embedding))
        for artifact, artifact_embedding in zip(artifacts, artifact_embeddings)
    ]

    def ranking_key(pair):
        score = pair[1]
        # NaN compares False against everything, which breaks the total order that
        # list.sort relies on and can scramble the ranking of *other* artifacts.
        # Valid scores sort by value; NaN scores sort after all of them.
        is_nan = score != score  # NaN is the only value unequal to itself
        return (not is_nan, 0.0 if is_nan else score)

    # list.sort is stable even with reverse=True, so ties keep their input order.
    artifacts_and_relatednesses.sort(key=ranking_key, reverse=True)

    return [artifact for artifact, _ in artifacts_and_relatednesses]
- On this page
- run(query, artifacts)
Could this page be better? Report a problem or suggest an addition!