cohere
Adapted from the Griptape AI Framework documentation.
# Public API of this module: the only name exported by ``from ... import *``.
__all__ = ["CohereRerankDriver"]
module-attribute
Bases: `BaseRerankDriver`
Source code in `griptape/drivers/rerank/cohere_rerank_driver.py`
@define(kw_only=True)
class CohereRerankDriver(BaseRerankDriver):
    """Rerank driver that orders text artifacts using Cohere's rerank endpoint.

    Attributes:
        model: Name of the Cohere rerank model; defaults to ``"rerank-english-v3.0"``.
        top_n: Forwarded unchanged to ``client.rerank`` as ``top_n``; ``None`` by default.
        api_key: Cohere API key used to build the default ``client``.
            NOTE(review): marked serializable, so it may be included in serialized
            driver output — confirm this is intended.
        client: Cohere client. By default it is built at construction time as
            ``cohere.Client(self.api_key)``, importing ``cohere`` through
            ``import_optional_dependency``.
    """

    model: str = field(default="rerank-english-v3.0", metadata={"serializable": True})
    top_n: Optional[int] = field(default=None)
    api_key: str = field(metadata={"serializable": True})
    client: Client = field(
        default=Factory(lambda self: import_optional_dependency("cohere").Client(self.api_key), takes_self=True),
    )

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Rerank ``artifacts`` by relevance to ``query``.

        Falsy artifacts are dropped and artifacts are de-duplicated by text before the
        request. When several artifacts share the same text, the last one wins. This is
        done because Cohere errors out when passed "empty" documents or no documents at all.

        Args:
            query: Query the artifacts are ranked against.
            artifacts: Candidate artifacts to rerank.

        Returns:
            The artifacts in the order Cohere returns them. ``self.top_n`` is forwarded to
            Cohere to bound the number of results. Returns an empty list, without calling
            the API, when no non-empty artifacts remain.
        """
        # Cohere errors out if passed "empty" documents or no documents at all.
        # Key by the text itself (not str(hash(text))) so distinct texts can never collide.
        unique_artifacts = {artifact.to_text(): artifact for artifact in artifacts if artifact}
        if not unique_artifacts:
            return []

        documents = list(unique_artifacts)
        candidates = list(unique_artifacts.values())
        response = self.client.rerank(
            model=self.model,
            query=query,
            documents=documents,
            return_documents=True,
            top_n=self.top_n,
        )
        # Each result's ``index`` points into ``documents``. Mapping through it avoids
        # depending on the echoed document text matching the submitted text exactly.
        return [candidates[result.index] for result in response.results]
api_key = field(metadata={'serializable': True})
class-attribute instance-attribute
client = field(default=Factory(lambda self: import_optional_dependency('cohere').Client(self.api_key), takes_self=True))
class-attribute instance-attribute
model = field(default='rerank-english-v3.0', metadata={'serializable': True})
class-attribute instance-attribute
top_n = field(default=None)
class-attribute instance-attribute
run(query, artifacts)
Source code in `griptape/drivers/rerank/cohere_rerank_driver.py`
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Rerank ``artifacts`` by relevance to ``query`` via ``self.client.rerank``.

    Falsy artifacts are dropped and artifacts are de-duplicated by text before the
    request. When several artifacts share the same text, the last one wins. This is done
    because Cohere errors out when passed "empty" documents or no documents at all.

    Args:
        query: Query the artifacts are ranked against.
        artifacts: Candidate artifacts to rerank.

    Returns:
        The artifacts in the order Cohere returns them. ``self.top_n`` is forwarded to
        Cohere to bound the number of results. Returns an empty list, without calling the
        API, when no non-empty artifacts remain.
    """
    # Cohere errors out if passed "empty" documents or no documents at all.
    # Key by the text itself (not str(hash(text))) so distinct texts can never collide.
    unique_artifacts = {artifact.to_text(): artifact for artifact in artifacts if artifact}
    if not unique_artifacts:
        return []

    documents = list(unique_artifacts)
    candidates = list(unique_artifacts.values())
    response = self.client.rerank(
        model=self.model,
        query=query,
        documents=documents,
        return_documents=True,
        top_n=self.top_n,
    )
    # Each result's ``index`` points into ``documents``. Mapping through it avoids depending
    # on the echoed document text matching the submitted text exactly.
    return [candidates[result.index] for result in response.results]
- On this page
- run(query, artifacts)
Could this page be better? Report a problem or suggest an addition!