engines
__all__ = ['BaseEvalEngine', 'BaseExtractionEngine', 'BaseSummaryEngine', 'CsvExtractionEngine', 'EvalEngine', 'JsonExtractionEngine', 'PromptSummaryEngine', 'RagEngine']  (module-attribute)
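All of the engine classes listed in `__all__` are re-exported at the package level, so they can be imported directly from `griptape.engines`. A minimal import sketch:

```python
# Import the public engine classes re-exported by the griptape.engines package.
from griptape.engines import (
    CsvExtractionEngine,
    EvalEngine,
    JsonExtractionEngine,
    PromptSummaryEngine,
    RagEngine,
)
```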
BaseEvalEngine

Bases: ABC

Source Code in griptape/engines/eval/base_eval_engine.py
```python
@define
class BaseEvalEngine(ABC): ...
```
BaseExtractionEngine
Bases: ABC

Source Code in griptape/engines/extraction/base_extraction_engine.py
```python
@define
class BaseExtractionEngine(ABC):
    max_token_multiplier: float = field(default=0.5, kw_only=True)
    chunk_joiner: str = field(default="\n\n", kw_only=True)
    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
    )
    chunker: BaseChunker = field(
        default=Factory(
            lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )

    @max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None:
        if max_token_multiplier > 1:
            raise ValueError("has to be less than or equal to 1")
        if max_token_multiplier <= 0:
            raise ValueError("has to be greater than 0")

    @property
    def max_chunker_tokens(self) -> int:
        return round(self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier)

    @property
    def min_response_tokens(self) -> int:
        return round(
            self.prompt_driver.tokenizer.max_input_tokens
            - self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
        )

    def extract_text(
        self,
        text: str,
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact:
        return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)

    @abstractmethod
    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact: ...
```
chunk_joiner = field(default='\n\n', kw_only=True)  (class-attribute, instance-attribute)

chunker = field(default=Factory(lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), takes_self=True), kw_only=True)  (class-attribute, instance-attribute)

max_chunker_tokens  (property)

max_token_multiplier = field(default=0.5, kw_only=True)  (class-attribute, instance-attribute)

min_response_tokens  (property)

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True)  (class-attribute, instance-attribute)
extract_artifacts(artifacts, *, rulesets=None, **kwargs)  (abstractmethod)
Source Code in griptape/engines/extraction/base_extraction_engine.py
```python
@abstractmethod
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact: ...
```
extract_text(text, *, rulesets=None, **kwargs)
Source Code in griptape/engines/extraction/base_extraction_engine.py
```python
def extract_text(
    self,
    text: str,
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact:
    return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)
```
validate_max_token_multiplier(_, max_token_multiplier)
Source Code in griptape/engines/extraction/base_extraction_engine.py
```python
@max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_max_token_multiplier(self, _: Attribute, max_token_multiplier: int) -> None:
    if max_token_multiplier > 1:
        raise ValueError("has to be less than or equal to 1")
    if max_token_multiplier <= 0:
        raise ValueError("has to be greater than 0")
```
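To make the token budget concrete, here is a small arithmetic sketch of how `max_chunker_tokens` and `min_response_tokens` relate. The numbers are hypothetical; in practice they come from the configured prompt driver's tokenizer:

```python
# Illustrative values only; the real max_input_tokens comes from the tokenizer
# of the configured prompt driver.
max_input_tokens = 8000
max_token_multiplier = 0.5  # the default

max_chunker_tokens = round(max_input_tokens * max_token_multiplier)
min_response_tokens = round(max_input_tokens - max_input_tokens * max_token_multiplier)

print(max_chunker_tokens, min_response_tokens)  # 4000 4000: half for each chunk, half reserved for the response
```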
BaseSummaryEngine
Bases: ABC

Source Code in griptape/engines/summary/base_summary_engine.py
```python
@define
class BaseSummaryEngine(ABC):
    def summarize_text(self, text: str, *, rulesets: Optional[list[Ruleset]] = None) -> str:
        return self.summarize_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets).value

    @abstractmethod
    def summarize_artifacts(
        self,
        artifacts: ListArtifact,
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact: ...
```
summarize_artifacts(artifacts, *, rulesets=None)  (abstractmethod)
Source Code in griptape/engines/summary/base_summary_engine.py
```python
@abstractmethod
def summarize_artifacts(
    self,
    artifacts: ListArtifact,
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> TextArtifact: ...
```
summarize_text(text, *, rulesets=None)
Source Code in griptape/engines/summary/base_summary_engine.py
```python
def summarize_text(self, text: str, *, rulesets: Optional[list[Ruleset]] = None) -> str:
    return self.summarize_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets).value
```
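A subclass only has to implement `summarize_artifacts`; `summarize_text` then works automatically. The following is a minimal, hypothetical sketch (the class name and its naive "first sentence" strategy are invented for illustration, not part of Griptape):

```python
from typing import Optional

from attrs import define

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.engines import BaseSummaryEngine
from griptape.rules import Ruleset


@define
class FirstSentenceSummaryEngine(BaseSummaryEngine):
    """Toy engine that 'summarizes' by keeping the first sentence of each artifact."""

    def summarize_artifacts(
        self,
        artifacts: ListArtifact,
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact:
        first_sentences = [a.to_text().split(".")[0] for a in artifacts.value]
        return TextArtifact(". ".join(first_sentences))


engine = FirstSentenceSummaryEngine()
print(engine.summarize_text("Griptape has engines. They wrap prompt drivers."))
```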
CsvExtractionEngine
Bases: BaseExtractionEngine

Source Code in griptape/engines/extraction/csv_extraction_engine.py
```python
@define
class CsvExtractionEngine(BaseExtractionEngine):
    column_names: list[str] = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
    format_header: Callable[[list[str]], str] = field(
        default=Factory(lambda: lambda value: ",".join(value)), kw_only=True
    )
    format_row: Callable[[dict], str] = field(
        default=Factory(lambda: lambda value: ",".join([value or "" for value in value.values()])), kw_only=True
    )

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[TextArtifact]:
        return ListArtifact(
            self._extract_rec(
                cast("list[TextArtifact]", artifacts.value),
                [TextArtifact(self.format_header(self.column_names))],
                rulesets=rulesets,
            ),
            item_separator="\n",
        )

    def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
        rows = []

        with io.StringIO(text) as f:
            for row in csv.DictReader(f):
                rows.append(TextArtifact(self.format_row(row)))

        return rows

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        rows: list[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[TextArtifact]:
        artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
        system_prompt = self.generate_system_template.render(
            column_names=self.column_names,
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        user_prompt = self.generate_user_template.render(
            text=artifacts_text,
        )

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
            >= self.min_response_tokens
        ):
            rows.extend(
                self.text_to_csv_rows(
                    self.prompt_driver.run(
                        PromptStack(
                            messages=[
                                Message(system_prompt, role=Message.SYSTEM_ROLE),
                                Message(user_prompt, role=Message.USER_ROLE),
                            ]
                        )
                    ).value,
                ),
            )

            return rows
        chunks = self.chunker.chunk(artifacts_text)
        partial_text = self.generate_user_template.render(
            text=chunks[0].value,
        )

        rows.extend(
            self.text_to_csv_rows(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(partial_text, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return self._extract_rec(chunks[1:], rows, rulesets=rulesets)
```
column_names = field(kw_only=True)  (class-attribute, instance-attribute)

format_header = field(default=Factory(lambda: lambda value: ','.join(value)), kw_only=True)  (class-attribute, instance-attribute)

format_row = field(default=Factory(lambda: lambda value: ','.join([value or '' for value in value.values()])), kw_only=True)  (class-attribute, instance-attribute)

generate_system_template = field(default=Factory(lambda: J2('engines/extraction/csv/system.j2')), kw_only=True)  (class-attribute, instance-attribute)

generate_user_template = field(default=Factory(lambda: J2('engines/extraction/csv/user.j2')), kw_only=True)  (class-attribute, instance-attribute)
_extract_rec(artifacts, rows, *, rulesets=None)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
```python
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    rows: list[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
    artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
    system_prompt = self.generate_system_template.render(
        column_names=self.column_names,
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    user_prompt = self.generate_user_template.render(
        text=artifacts_text,
    )

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
        >= self.min_response_tokens
    ):
        rows.extend(
            self.text_to_csv_rows(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(user_prompt, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return rows
    chunks = self.chunker.chunk(artifacts_text)
    partial_text = self.generate_user_template.render(
        text=chunks[0].value,
    )

    rows.extend(
        self.text_to_csv_rows(
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ]
                )
            ).value,
        ),
    )

    return self._extract_rec(chunks[1:], rows, rulesets=rulesets)
```
extract_artifacts(artifacts, *, rulesets=None, **kwargs)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
```python
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[TextArtifact]:
    return ListArtifact(
        self._extract_rec(
            cast("list[TextArtifact]", artifacts.value),
            [TextArtifact(self.format_header(self.column_names))],
            rulesets=rulesets,
        ),
        item_separator="\n",
    )
```
text_to_csv_rows(text)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
```python
def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
    rows = []

    with io.StringIO(text) as f:
        for row in csv.DictReader(f):
            rows.append(TextArtifact(self.format_row(row)))

    return rows
```
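A hedged usage sketch: the column names and input text below are invented for illustration, and the engine falls back to `Defaults.drivers_config.prompt_driver` (so a configured default prompt driver and its credentials are assumed) unless a `prompt_driver` is passed explicitly:

```python
from griptape.engines import CsvExtractionEngine

# Example column names; `column_names` is a required keyword-only field.
engine = CsvExtractionEngine(column_names=["name", "age", "city"])

result = engine.extract_text(
    "Alice is 30 and lives in Paris. Bob is 25 and lives in Berlin."
)

# The result is a ListArtifact of CSV rows, led by the formatted header row.
print(result.to_text())
```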
EvalEngine
Bases: BaseEvalEngine, SerializableMixin

Source Code in griptape/engines/eval/eval_engine.py
```python
@define(kw_only=True)
class EvalEngine(BaseEvalEngine, SerializableMixin):
    id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
    name: str = field(
        default=Factory(lambda self: self.id, takes_self=True),
        metadata={"serializable": True},
    )
    criteria: Optional[str] = field(default=None, metadata={"serializable": True})
    evaluation_steps: Optional[list[str]] = field(default=None, metadata={"serializable": True})
    prompt_driver: BasePromptDriver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))
    generate_steps_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/system.j2")))
    generate_steps_user_template: J2 = field(default=Factory(lambda: J2("engines/eval/steps/user.j2")))
    generate_results_system_template: J2 = field(default=Factory(lambda: J2("engines/eval/results/system.j2")))
    generate_results_user_template: J2 = field(default=Factory(lambda: J2("engines/eval/results/user.j2")))

    @criteria.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_criteria(self, _: Attribute, value: Optional[str]) -> None:
        if value is None:
            if self.evaluation_steps is None:
                raise ValueError("either criteria or evaluation_steps must be specified")
            return

        if self.evaluation_steps is not None:
            raise ValueError("can't have both criteria and evaluation_steps specified")

        if not value:
            raise ValueError("criteria must not be empty")

    @evaluation_steps.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_evaluation_steps(self, _: Attribute, value: Optional[list[str]]) -> None:
        if value is None:
            if self.criteria is None:
                raise ValueError("either evaluation_steps or criteria must be specified")
            return

        if self.criteria is not None:
            raise ValueError("can't have both evaluation_steps and criteria specified")

        if not value:
            raise ValueError("evaluation_steps must not be empty")

    def evaluate(self, input: str, actual_output: str, **kwargs) -> tuple[float, str]:  # noqa: A002
        evaluation_params = {
            key.replace("_", " ").title(): value
            for key, value in {"input": input, "actual_output": actual_output, **kwargs}.items()
        }
        if self.evaluation_steps is None:
            # Need to disable validators to allow for both `criteria` and `evaluation_steps` to be set
            with validators.disabled():
                self.evaluation_steps = self._generate_steps(evaluation_params)

        return self._generate_results(evaluation_params)

    def _generate_steps(self, evaluation_params: dict[str, str]) -> list[str]:
        system_prompt = self.generate_steps_system_template.render(
            evaluation_params=", ".join(param for param in evaluation_params),
            criteria=self.criteria,
        )
        user_prompt = self.generate_steps_user_template.render()

        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
                output_schema=STEPS_SCHEMA,
            ),
        ).to_artifact()

        parsed_result = json.loads(result.value)

        return parsed_result["steps"]

    def _generate_results(self, evaluation_params: dict[str, str]) -> tuple[float, str]:
        system_prompt = self.generate_results_system_template.render(
            evaluation_params=", ".join(param for param in evaluation_params),
            evaluation_steps=self.evaluation_steps,
            evaluation_text="\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items()),
        )
        user_prompt = self.generate_results_user_template.render()

        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
                output_schema=RESULTS_SCHEMA,
            ),
        ).to_text()

        parsed_result = json.loads(result)

        # Better to have the LLM deal strictly with integers to avoid ambiguities with floating point precision.
        # We want the user to receive a float, however.
        score = float(parsed_result["score"]) / 10
        reason = parsed_result["reason"]

        return score, reason
```
criteria = field(default=None, metadata={'serializable': True})  (class-attribute, instance-attribute)

evaluation_steps = field(default=None, metadata={'serializable': True})  (class-attribute, instance-attribute)

generate_results_system_template = field(default=Factory(lambda: J2('engines/eval/results/system.j2')))  (class-attribute, instance-attribute)

generate_results_user_template = field(default=Factory(lambda: J2('engines/eval/results/user.j2')))  (class-attribute, instance-attribute)

generate_steps_system_template = field(default=Factory(lambda: J2('engines/eval/steps/system.j2')))  (class-attribute, instance-attribute)

generate_steps_user_template = field(default=Factory(lambda: J2('engines/eval/steps/user.j2')))  (class-attribute, instance-attribute)

id = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={'serializable': True})  (class-attribute, instance-attribute)

name = field(default=Factory(lambda self: self.id, takes_self=True), metadata={'serializable': True})  (class-attribute, instance-attribute)

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver))  (class-attribute, instance-attribute)
_generate_results(evaluation_params)
Source Code in griptape/engines/eval/eval_engine.py
```python
def _generate_results(self, evaluation_params: dict[str, str]) -> tuple[float, str]:
    system_prompt = self.generate_results_system_template.render(
        evaluation_params=", ".join(param for param in evaluation_params),
        evaluation_steps=self.evaluation_steps,
        evaluation_text="\n\n".join(f"{key}: {value}" for key, value in evaluation_params.items()),
    )
    user_prompt = self.generate_results_user_template.render()

    result = self.prompt_driver.run(
        PromptStack(
            messages=[
                Message(system_prompt, role=Message.SYSTEM_ROLE),
                Message(user_prompt, role=Message.USER_ROLE),
            ],
            output_schema=RESULTS_SCHEMA,
        ),
    ).to_text()

    parsed_result = json.loads(result)

    # Better to have the LLM deal strictly with integers to avoid ambiguities with floating point precision.
    # We want the user to receive a float, however.
    score = float(parsed_result["score"]) / 10
    reason = parsed_result["reason"]

    return score, reason
```
_generate_steps(evaluation_params)
Source Code in griptape/engines/eval/eval_engine.py
```python
def _generate_steps(self, evaluation_params: dict[str, str]) -> list[str]:
    system_prompt = self.generate_steps_system_template.render(
        evaluation_params=", ".join(param for param in evaluation_params),
        criteria=self.criteria,
    )
    user_prompt = self.generate_steps_user_template.render()

    result = self.prompt_driver.run(
        PromptStack(
            messages=[
                Message(system_prompt, role=Message.SYSTEM_ROLE),
                Message(user_prompt, role=Message.USER_ROLE),
            ],
            output_schema=STEPS_SCHEMA,
        ),
    ).to_artifact()

    parsed_result = json.loads(result.value)

    return parsed_result["steps"]
```
evaluate(input, actual_output, **kwargs)
Source Code in griptape/engines/eval/eval_engine.py
```python
def evaluate(self, input: str, actual_output: str, **kwargs) -> tuple[float, str]:  # noqa: A002
    evaluation_params = {
        key.replace("_", " ").title(): value
        for key, value in {"input": input, "actual_output": actual_output, **kwargs}.items()
    }
    if self.evaluation_steps is None:
        # Need to disable validators to allow for both `criteria` and `evaluation_steps` to be set
        with validators.disabled():
            self.evaluation_steps = self._generate_steps(evaluation_params)

    return self._generate_results(evaluation_params)
```
validate_criteria(_, value)
Source Code in griptape/engines/eval/eval_engine.py
```python
@criteria.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_criteria(self, _: Attribute, value: Optional[str]) -> None:
    if value is None:
        if self.evaluation_steps is None:
            raise ValueError("either criteria or evaluation_steps must be specified")
        return

    if self.evaluation_steps is not None:
        raise ValueError("can't have both criteria and evaluation_steps specified")

    if not value:
        raise ValueError("criteria must not be empty")
```
validate_evaluation_steps(_, value)
Source Code in griptape/engines/eval/eval_engine.py
```python
@evaluation_steps.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_evaluation_steps(self, _: Attribute, value: Optional[list[str]]) -> None:
    if value is None:
        if self.criteria is None:
            raise ValueError("either evaluation_steps or criteria must be specified")
        return

    if self.criteria is not None:
        raise ValueError("can't have both evaluation_steps and criteria specified")

    if not value:
        raise ValueError("evaluation_steps must not be empty")
```
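A hedged usage sketch: exactly one of `criteria` or `evaluation_steps` must be provided (see the validators above). With `criteria`, the engine first generates evaluation steps and then scores the output; the returned score is the model's integer score divided by 10. The question and answer strings are invented, and the default prompt driver from `Defaults.drivers_config` is assumed to be configured:

```python
from griptape.engines import EvalEngine

# Illustrative criteria; evaluation_steps could be passed instead, but not both.
engine = EvalEngine(
    criteria="Determine whether the actual output is factually correct given the input.",
)

score, reason = engine.evaluate(
    input="What is the boiling point of water at sea level?",
    actual_output="Water boils at 100 degrees Celsius at sea level.",
)

print(score)   # float between 0.0 and 1.0 (the model's integer score divided by 10)
print(reason)  # model-generated explanation of the score
```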
JsonExtractionEngine
Bases: BaseExtractionEngine

Source Code in griptape/engines/extraction/json_extraction_engine.py
```python
@define
class JsonExtractionEngine(BaseExtractionEngine):
    JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

    template_schema: dict = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[JsonArtifact]:
        return ListArtifact(
            self._extract_rec(cast("list[TextArtifact]", artifacts.value), [], rulesets=rulesets),
            item_separator="\n",
        )

    def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
        json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

        if json_matches:
            return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
        return []

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        extractions: list[JsonArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[JsonArtifact]:
        artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
        system_prompt = self.generate_system_template.render(
            json_template_schema=json.dumps(self.template_schema),
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        user_prompt = self.generate_user_template.render(
            text=artifacts_text,
        )

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
            >= self.min_response_tokens
        ):
            extractions.extend(
                self.json_to_text_artifacts(
                    self.prompt_driver.run(
                        PromptStack(
                            messages=[
                                Message(system_prompt, role=Message.SYSTEM_ROLE),
                                Message(user_prompt, role=Message.USER_ROLE),
                            ]
                        )
                    ).value
                ),
            )
            return extractions
        chunks = self.chunker.chunk(artifacts_text)
        partial_text = self.generate_user_template.render(
            text=chunks[0].value,
        )

        extractions.extend(
            self.json_to_text_artifacts(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(partial_text, role=Message.USER_ROLE),
                        ]
                    )
                ).value,
            ),
        )

        return self._extract_rec(chunks[1:], extractions, rulesets=rulesets)
```
JSON_PATTERN = '(?s)[^\\[]*(\\[.*\\])'  (class-attribute, instance-attribute)

generate_system_template = field(default=Factory(lambda: J2('engines/extraction/json/system.j2')), kw_only=True)  (class-attribute, instance-attribute)

generate_user_template = field(default=Factory(lambda: J2('engines/extraction/json/user.j2')), kw_only=True)  (class-attribute, instance-attribute)

template_schema = field(kw_only=True)  (class-attribute, instance-attribute)
_extract_rec(artifacts, extractions, *, rulesets=None)
Source Code in griptape/engines/extraction/json_extraction_engine.py
```python
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    extractions: list[JsonArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
    artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
    system_prompt = self.generate_system_template.render(
        json_template_schema=json.dumps(self.template_schema),
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    user_prompt = self.generate_user_template.render(
        text=artifacts_text,
    )

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
        >= self.min_response_tokens
    ):
        extractions.extend(
            self.json_to_text_artifacts(
                self.prompt_driver.run(
                    PromptStack(
                        messages=[
                            Message(system_prompt, role=Message.SYSTEM_ROLE),
                            Message(user_prompt, role=Message.USER_ROLE),
                        ]
                    )
                ).value
            ),
        )
        return extractions
    chunks = self.chunker.chunk(artifacts_text)
    partial_text = self.generate_user_template.render(
        text=chunks[0].value,
    )

    extractions.extend(
        self.json_to_text_artifacts(
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ]
                )
            ).value,
        ),
    )

    return self._extract_rec(chunks[1:], extractions, rulesets=rulesets)
```
extract_artifacts(artifacts, *, rulesets=None, **kwargs)
Source Code in griptape/engines/extraction/json_extraction_engine.py
```python
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[JsonArtifact]:
    return ListArtifact(
        self._extract_rec(cast("list[TextArtifact]", artifacts.value), [], rulesets=rulesets),
        item_separator="\n",
    )
```
json_to_text_artifacts(json_input)
Source Code in griptape/engines/extraction/json_extraction_engine.py
```python
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
    json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

    if json_matches:
        return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
    return []
```
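A hedged usage sketch: `template_schema` is rendered into the system prompt via `json.dumps`, so any JSON-serializable dict describing the desired output shape works. The schema and input text below are invented, and the default prompt driver from `Defaults.drivers_config` is assumed to be configured:

```python
from griptape.engines import JsonExtractionEngine

# Example schema describing the shape of the data to extract.
engine = JsonExtractionEngine(
    template_schema={"users": [{"name": "string", "age": "integer"}]},
)

result = engine.extract_text("Alice is 30 years old. Bob is 25 years old.")

for artifact in result.value:
    print(artifact.value)  # one extracted JSON object per JsonArtifact
```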
PromptSummaryEngine
Bases: BaseSummaryEngine

Source Code in griptape/engines/summary/prompt_summary_engine.py
```python
@define
class PromptSummaryEngine(BaseSummaryEngine):
    chunk_joiner: str = field(default="\n\n", kw_only=True)
    max_token_multiplier: float = field(default=0.5, kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/summary/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/summary/user.j2")), kw_only=True)
    prompt_driver: BasePromptDriver = field(
        default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True
    )
    chunker: BaseChunker = field(
        default=Factory(
            lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )

    @max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_allowlist(self, _: Attribute, max_token_multiplier: int) -> None:
        if max_token_multiplier > 1:
            raise ValueError("has to be less than or equal to 1")
        if max_token_multiplier <= 0:
            raise ValueError("has to be greater than 0")

    @property
    def max_chunker_tokens(self) -> int:
        return round(self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier)

    @property
    def min_response_tokens(self) -> int:
        return round(
            self.prompt_driver.tokenizer.max_input_tokens
            - self.prompt_driver.tokenizer.max_input_tokens * self.max_token_multiplier,
        )

    def summarize_artifacts(self, artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None) -> TextArtifact:
        return self.summarize_artifacts_rec(cast("list[TextArtifact]", artifacts.value), None, rulesets=rulesets)

    def summarize_artifacts_rec(
        self,
        artifacts: list[TextArtifact],
        summary: Optional[str] = None,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> TextArtifact:
        if not artifacts:
            if summary is None:
                raise ValueError("No artifacts to summarize")
            return TextArtifact(summary)

        artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts])
        system_prompt = self.generate_system_template.render(
            summary=summary,
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        user_prompt = self.generate_user_template.render(text=artifacts_text)

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
            >= self.min_response_tokens
        ):
            result = self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(user_prompt, role=Message.USER_ROLE),
                    ],
                ),
            ).to_artifact()

            if isinstance(result, TextArtifact):
                return result
            raise ValueError("Prompt driver did not return a TextArtifact")
        chunks = self.chunker.chunk(artifacts_text)
        partial_text = self.generate_user_template.render(text=chunks[0].value)

        return self.summarize_artifacts_rec(
            chunks[1:],
            self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(partial_text, role=Message.USER_ROLE),
                    ],
                ),
            ).value,
            rulesets=rulesets,
        )
```
chunk_joiner = field(default='\n\n', kw_only=True)  (class-attribute, instance-attribute)

chunker = field(default=Factory(lambda self: TextChunker(tokenizer=self.prompt_driver.tokenizer, max_tokens=self.max_chunker_tokens), takes_self=True), kw_only=True)  (class-attribute, instance-attribute)

generate_system_template = field(default=Factory(lambda: J2('engines/summary/system.j2')), kw_only=True)  (class-attribute, instance-attribute)

generate_user_template = field(default=Factory(lambda: J2('engines/summary/user.j2')), kw_only=True)  (class-attribute, instance-attribute)

max_chunker_tokens  (property)

max_token_multiplier = field(default=0.5, kw_only=True)  (class-attribute, instance-attribute)

min_response_tokens  (property)

prompt_driver = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver), kw_only=True)  (class-attribute, instance-attribute)
summarize_artifacts(artifacts, *, rulesets=None)
Source Code in griptape/engines/summary/prompt_summary_engine.py
```python
def summarize_artifacts(self, artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None) -> TextArtifact:
    return self.summarize_artifacts_rec(cast("list[TextArtifact]", artifacts.value), None, rulesets=rulesets)
```
summarize_artifacts_rec(artifacts, summary=None, rulesets=None)
Source Code in griptape/engines/summary/prompt_summary_engine.py
```python
def summarize_artifacts_rec(
    self,
    artifacts: list[TextArtifact],
    summary: Optional[str] = None,
    rulesets: Optional[list[Ruleset]] = None,
) -> TextArtifact:
    if not artifacts:
        if summary is None:
            raise ValueError("No artifacts to summarize")
        return TextArtifact(summary)

    artifacts_text = self.chunk_joiner.join([a.to_text() for a in artifacts])
    system_prompt = self.generate_system_template.render(
        summary=summary,
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    user_prompt = self.generate_user_template.render(text=artifacts_text)

    if (
        self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
        >= self.min_response_tokens
    ):
        result = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ],
            ),
        ).to_artifact()

        if isinstance(result, TextArtifact):
            return result
        raise ValueError("Prompt driver did not return a TextArtifact")
    chunks = self.chunker.chunk(artifacts_text)
    partial_text = self.generate_user_template.render(text=chunks[0].value)

    return self.summarize_artifacts_rec(
        chunks[1:],
        self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(partial_text, role=Message.USER_ROLE),
                ],
            ),
        ).value,
        rulesets=rulesets,
    )
```
validate_allowlist(_, max_token_multiplier)
Source Code in griptape/engines/summary/prompt_summary_engine.py
```python
@max_token_multiplier.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_allowlist(self, _: Attribute, max_token_multiplier: int) -> None:
    if max_token_multiplier > 1:
        raise ValueError("has to be less than or equal to 1")
    if max_token_multiplier <= 0:
        raise ValueError("has to be greater than 0")
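A hedged usage sketch: with no arguments, the engine uses `Defaults.drivers_config.prompt_driver` (so a configured default prompt driver is assumed); the input text is invented for illustration:

```python
from griptape.engines import PromptSummaryEngine

# Uses the default prompt driver unless prompt_driver is passed explicitly.
engine = PromptSummaryEngine()

summary = engine.summarize_text(
    "Griptape engines wrap prompt drivers with higher-level behavior. "
    "PromptSummaryEngine chunks long inputs and summarizes them recursively."
)
print(summary)
```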
RagEngine
Source Code in griptape/engines/rag/rag_engine.py
```python
@define(kw_only=True)
class RagEngine:
    query_stage: Optional[QueryRagStage] = field(default=None)
    retrieval_stage: Optional[RetrievalRagStage] = field(default=None)
    response_stage: Optional[ResponseRagStage] = field(default=None)

    def __attrs_post_init__(self) -> None:
        modules = []

        if self.query_stage is not None:
            modules.extend(self.query_stage.modules)
        if self.retrieval_stage is not None:
            modules.extend(self.retrieval_stage.modules)
        if self.response_stage is not None:
            modules.extend(self.response_stage.modules)

        module_names = [m.name for m in modules]

        if len(module_names) > len(set(module_names)):
            raise ValueError("module names have to be unique")

    def process_query(self, query: str) -> RagContext:
        return self.process(RagContext(query=query))

    def process(self, context: RagContext) -> RagContext:
        if self.query_stage:
            context = self.query_stage.run(context)
        if self.retrieval_stage:
            context = self.retrieval_stage.run(context)
        if self.response_stage:
            context = self.response_stage.run(context)
        return context
```
query_stage = field(default=None)  (class-attribute, instance-attribute)

response_stage = field(default=None)  (class-attribute, instance-attribute)

retrieval_stage = field(default=None)  (class-attribute, instance-attribute)
__attrs_post_init__()
Source Code in griptape/engines/rag/rag_engine.py
```python
def __attrs_post_init__(self) -> None:
    modules = []

    if self.query_stage is not None:
        modules.extend(self.query_stage.modules)
    if self.retrieval_stage is not None:
        modules.extend(self.retrieval_stage.modules)
    if self.response_stage is not None:
        modules.extend(self.response_stage.modules)

    module_names = [m.name for m in modules]

    if len(module_names) > len(set(module_names)):
        raise ValueError("module names have to be unique")
```
process(context)
Source Code in griptape/engines/rag/rag_engine.py
```python
def process(self, context: RagContext) -> RagContext:
    if self.query_stage:
        context = self.query_stage.run(context)
    if self.retrieval_stage:
        context = self.retrieval_stage.run(context)
    if self.response_stage:
        context = self.response_stage.run(context)
    return context
```
process_query(query)
Source Code in griptape/engines/rag/rag_engine.py
```python
def process_query(self, query: str) -> RagContext:
    return self.process(RagContext(query=query))
```
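Finally, a structural sketch of assembling a RagEngine from a retrieval and a response stage. The stage and module classes are assumed to live under `griptape.engines.rag.stages` and `griptape.engines.rag.modules`, and the vector store driver is left as a placeholder to be replaced with a configured driver that already holds data:

```python
from griptape.engines import RagEngine
from griptape.engines.rag.modules import PromptResponseRagModule, VectorStoreRetrievalRagModule
from griptape.engines.rag.stages import ResponseRagStage, RetrievalRagStage

# Placeholder: substitute any configured vector store driver with upserted content.
vector_store_driver = ...

engine = RagEngine(
    retrieval_stage=RetrievalRagStage(
        retrieval_modules=[VectorStoreRetrievalRagModule(vector_store_driver=vector_store_driver)],
    ),
    response_stage=ResponseRagStage(
        response_modules=[PromptResponseRagModule()],
    ),
)

# process_query wraps the query in a RagContext and runs each configured stage in order.
context = engine.process_query("What do Griptape engines do?")
print(context.outputs)
```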