nvidia_nim_rerank_driver
Adapted from the Griptape AI Framework documentation.
Bases: BaseRerankDriver
Source Code in griptape/drivers/rerank/nvidia_nim_rerank_driver.py
@define(kw_only=True)
class NvidiaNimRerankDriver(BaseRerankDriver):
    """Nvidia Rerank Driver.

    Reranks text artifacts by POSTing them to an NVIDIA NIM ``/v1/ranking``
    endpoint and returning them in the order given by the response.

    Attributes:
        model: Model name sent as ``"model"`` in the request body.
        base_url: Service base URL. Any trailing ``/`` is stripped before
            ``/v1/ranking`` is appended.
        truncate: Sent verbatim as ``"truncate"`` in the request body
            (``"NONE"`` or ``"END"``).
        headers: HTTP headers sent with every ranking request.
        timeout: Forwarded to ``requests.post`` as ``timeout`` (seconds).
            The default ``None`` applies no timeout. Set it to bound how long
            ``run`` can block on an unresponsive service.
    """

    model: str = field()
    base_url: str = field()
    truncate: Literal["NONE", "END"] = field(default="NONE")
    headers: dict = field(factory=dict)
    # Without a timeout, `requests` can block forever waiting on the server.
    timeout: "float | None" = field(default=None)

    def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
        """Rerank ``artifacts`` against ``query`` using the NIM ranking endpoint.

        Args:
            query: Text the passages are ranked against.
            artifacts: Artifacts to rank. Each artifact's ``value`` is sent as one passage.

        Returns:
            The artifacts, listed in the order of the response's ``rankings``
            entries. Each returned artifact's ``meta`` is updated in place with
            ``logit`` and ``usage`` (``None`` when an entry has no ``usage``).
            An empty ``artifacts`` list returns ``[]`` without making a request.

        Raises:
            requests.HTTPError: If the service responds with an error status.
            requests.Timeout: If ``timeout`` is set and the request exceeds it.
        """
        if not artifacts:
            return []

        response = requests.post(
            url=f"{self.base_url.rstrip('/')}/v1/ranking",
            json=self._get_body(query, artifacts),
            headers=self.headers,
            timeout=self.timeout,
        )
        response.raise_for_status()

        ranked_artifacts = []
        for ranking in response.json()["rankings"]:
            # `index` is treated as a position in `artifacts`; `_get_body` sends
            # the passages in that same order.
            artifact = artifacts[ranking["index"]]
            artifact.meta.update({"logit": ranking["logit"], "usage": ranking.get("usage")})
            ranked_artifacts.append(artifact)

        return ranked_artifacts

    def _get_body(self, query: str, artifacts: list[TextArtifact]) -> dict:
        """Build the JSON request body for the ``/v1/ranking`` call.

        Args:
            query: Text placed under ``query.text``.
            artifacts: Artifacts whose ``value`` becomes one passage each, in order.

        Returns:
            A dict with ``model``, ``query``, ``passages`` and ``truncate`` keys.
        """
        return {
            "model": self.model,
            "query": {"text": query},
            "passages": [{"text": artifact.value} for artifact in artifacts],
            "truncate": self.truncate,
        }
base_url = field() — class-attribute, instance-attribute
headers = field(factory=dict) — class-attribute, instance-attribute
model = field() — class-attribute, instance-attribute
truncate = field(default='NONE') — class-attribute, instance-attribute
_get_body(query, artifacts)
Source Code in griptape/drivers/rerank/nvidia_nim_rerank_driver.py
def _get_body(self, query: str, artifacts: list[TextArtifact]) -> dict:
    """Assemble the payload posted to the NIM ``/v1/ranking`` endpoint.

    Args:
        query: Text stored under ``query.text``.
        artifacts: Artifacts whose ``value`` becomes one passage each, in order.

    Returns:
        A dict containing, in order, ``model``, ``query``, ``passages`` and ``truncate``.
    """
    payload = {"model": self.model, "query": {"text": query}}
    payload["passages"] = [{"text": item.value} for item in artifacts]
    payload["truncate"] = self.truncate
    return payload
run(query, artifacts)
Source Code in griptape/drivers/rerank/nvidia_nim_rerank_driver.py
def run(self, query: str, artifacts: list[TextArtifact]) -> list[TextArtifact]:
    """Rank ``artifacts`` against ``query`` using the NIM ranking service.

    Args:
        query: Text the passages are ranked against.
        artifacts: Artifacts to rank. Each artifact's ``value`` is sent as a passage.

    Returns:
        The artifacts in the order of the response's ``rankings`` entries.
        Each returned artifact's ``meta`` gains ``logit`` and ``usage`` keys
        (the artifact is mutated in place). Returns ``[]`` immediately, with
        no request, when ``artifacts`` is empty.

    Raises:
        requests.HTTPError: If the service responds with an error status.
    """
    if not artifacts:
        return []

    endpoint = self.base_url.rstrip("/") + "/v1/ranking"
    payload = self._get_body(query, artifacts)
    reply = requests.post(url=endpoint, json=payload, headers=self.headers)
    reply.raise_for_status()

    ordered = []
    for entry in reply.json()["rankings"]:
        # `index` selects the artifact at that position in the input list.
        picked = artifacts[entry["index"]]
        picked.meta.update({"logit": entry["logit"], "usage": entry.get("usage")})
        ordered.append(picked)
    return ordered
- On this page
- _get_body(query, artifacts)
- run(query, artifacts)
Could this page be better? Report a problem or suggest an addition!