csv_extraction_engine
Bases: BaseExtractionEngine
Source Code in griptape/engines/extraction/csv_extraction_engine.py
def _to_csv_line(cells: list) -> str:
    """Serialize ``cells`` as a single CSV record without its line terminator.

    ``csv.writer`` (default ``QUOTE_MINIMAL``) quotes any cell that contains the
    delimiter, the quote character, or a line-terminator character. A value such as
    ``Smith, John`` therefore stays in one column instead of splitting into two.
    """
    terminator = "\r\n"
    buffer = io.StringIO()
    csv.writer(buffer, lineterminator=terminator).writerow(cells)
    # writerow always appends exactly one terminator; drop it so callers get one line.
    return buffer.getvalue()[: -len(terminator)]


def _format_csv_header(column_names: list[str]) -> str:
    """Default ``format_header``: render ``column_names`` as one CSV header line."""
    return _to_csv_line(column_names)


def _format_csv_row(row: dict) -> str:
    """Default ``format_row``: render one ``csv.DictReader`` row as one CSV line.

    Falsy cell values (e.g. ``None``, which DictReader uses for missing trailing
    cells) become empty strings. When a parsed line has more cells than the header,
    DictReader stores the surplus as a list under the key ``None``. Those cells are
    appended individually rather than passed to ``str.join`` as a list, which would
    raise ``TypeError``.
    """
    cells: list = []
    for key, value in row.items():
        if key is None and isinstance(value, list):
            cells.extend(value)
        else:
            cells.append(value or "")
    return _to_csv_line(cells)


@define
class CsvExtractionEngine(BaseExtractionEngine):
    """Extraction engine that prompts a model for CSV rows matching ``column_names``.

    ``extract_artifacts`` returns a ``ListArtifact`` of ``TextArtifact`` lines. The
    first line is the header produced by ``format_header``; it is followed by one line
    per row parsed from the model responses. ``prompt_driver``, ``chunker``,
    ``chunk_joiner`` and ``min_response_tokens`` are used here but are not defined in
    this class (presumably inherited from ``BaseExtractionEngine``).
    """

    # Passed to the system template and formatted into the output header line.
    column_names: list[str] = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/csv/user.j2")), kw_only=True)
    # The default formatters quote cells via csv.writer, so commas, quotes and newlines
    # inside values survive the parse -> re-serialize round trip.
    format_header: Callable[[list[str]], str] = field(default=Factory(lambda: _format_csv_header), kw_only=True)
    format_row: Callable[[dict], str] = field(default=Factory(lambda: _format_csv_row), kw_only=True)

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[TextArtifact]:
        """Extract CSV rows from ``artifacts``.

        Args:
            artifacts: Text artifacts whose values are joined and sent to the model.
            rulesets: Optional rulesets rendered into the system prompt.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``ListArtifact`` (items separated by ``"\\n"``): the header line
            followed by the extracted row lines.
        """
        return ListArtifact(
            self._extract_rec(
                cast("list[TextArtifact]", artifacts.value),
                [TextArtifact(self.format_header(self.column_names))],
                rulesets=rulesets,
            ),
            item_separator="\n",
        )

    def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
        """Parse model output ``text`` as CSV and format each record with ``format_row``.

        ``csv.DictReader`` treats the first line of ``text`` as the field names. That
        line is consumed rather than returned (the templates presumably ask the model
        to emit a header line; not visible here).
        """
        with io.StringIO(text) as f:
            return [TextArtifact(self.format_row(row)) for row in csv.DictReader(f)]

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        rows: list[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[TextArtifact]:
        """Append rows extracted from ``artifacts`` to ``rows`` (in place) and return ``rows``.

        If the joined text fits in one request (the tokenizer reports at least
        ``min_response_tokens`` input tokens left), it is sent in one prompt.
        Otherwise the text is chunked and only the first chunk is sent. The remaining
        chunks are re-joined and handled the same way on the next iteration.

        This is implemented as a loop, so very long inputs do not add one stack frame
        per chunk. The name is kept for compatibility with existing callers.
        """
        # Depends only on column_names and rulesets, so render it once, not per chunk.
        system_prompt = self.generate_system_template.render(
            column_names=self.column_names,
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        remaining = artifacts
        while True:
            artifacts_text = self.chunk_joiner.join([a.value for a in remaining])
            user_prompt = self.generate_user_template.render(text=artifacts_text)
            fits = (
                self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
                >= self.min_response_tokens
            )
            if not fits:
                # Too large: send just the first chunk now and carry the rest forward.
                chunks = self.chunker.chunk(artifacts_text)
                user_prompt = self.generate_user_template.render(text=chunks[0].value)
                remaining = chunks[1:]
            response = self.prompt_driver.run(
                PromptStack(
                    messages=[
                        Message(system_prompt, role=Message.SYSTEM_ROLE),
                        Message(user_prompt, role=Message.USER_ROLE),
                    ]
                )
            )
            rows.extend(self.text_to_csv_rows(response.value))
            if fits:
                return rows
column_names = field(kw_only=True) — class-attribute, instance-attribute
format_header = field(default=Factory(lambda: lambda value: ','.join(value)), kw_only=True) — class-attribute, instance-attribute
format_row = field(default=Factory(lambda: lambda value: ','.join([value or '' for value in value.values()])), kw_only=True) — class-attribute, instance-attribute
generate_system_template = field(default=Factory(lambda: J2('engines/extraction/csv/system.j2')), kw_only=True) — class-attribute, instance-attribute
generate_user_template = field(default=Factory(lambda: J2('engines/extraction/csv/user.j2')), kw_only=True) — class-attribute, instance-attribute
_extract_rec(artifacts, rows, *, rulesets=None)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    rows: list[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
    """Append rows extracted from ``artifacts`` to ``rows`` (in place) and return ``rows``.

    If the joined text fits in one request (the tokenizer reports at least
    ``min_response_tokens`` input tokens left), it is sent in one prompt. Otherwise
    the text is chunked and only the first chunk is sent. The remaining chunks are
    re-joined and handled the same way on the next iteration.

    This is implemented as a loop, so very long inputs do not add one stack frame per
    chunk. The name is kept for compatibility with existing callers.
    """
    # Depends only on column_names and rulesets, so render it once, not per chunk.
    system_prompt = self.generate_system_template.render(
        column_names=self.column_names,
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )
    remaining = artifacts
    while True:
        artifacts_text = self.chunk_joiner.join([a.value for a in remaining])
        user_prompt = self.generate_user_template.render(text=artifacts_text)
        fits = (
            self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
            >= self.min_response_tokens
        )
        if not fits:
            # Too large: send just the first chunk now and carry the rest forward.
            chunks = self.chunker.chunk(artifacts_text)
            user_prompt = self.generate_user_template.render(text=chunks[0].value)
            remaining = chunks[1:]
        response = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ]
            )
        )
        rows.extend(self.text_to_csv_rows(response.value))
        if fits:
            return rows
extract_artifacts(artifacts, *, rulesets=None, **kwargs)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[TextArtifact]:
    """Extract CSV rows from the given text artifacts.

    Args:
        artifacts: Text artifacts whose values are joined and sent to the model.
        rulesets: Optional rulesets forwarded to ``_extract_rec``.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        A ``ListArtifact`` joined with ``"\\n"``: the ``format_header`` line for
        ``column_names`` first, then the extracted row lines.
    """
    source_artifacts = cast("list[TextArtifact]", artifacts.value)
    # Seed the accumulator with the header so it becomes the first output line.
    seeded_rows = [TextArtifact(self.format_header(self.column_names))]
    extracted = self._extract_rec(source_artifacts, seeded_rows, rulesets=rulesets)
    return ListArtifact(extracted, item_separator="\n")
text_to_csv_rows(text)
Source Code in griptape/engines/extraction/csv_extraction_engine.py
def text_to_csv_rows(self, text: str) -> list[TextArtifact]:
    """Parse ``text`` as CSV and wrap each formatted record in a ``TextArtifact``.

    ``csv.DictReader`` uses the first line of ``text`` as field names, so that line
    is consumed rather than returned. Every following record is passed to
    ``self.format_row`` as a dict.
    """
    with io.StringIO(text) as stream:
        return [TextArtifact(self.format_row(record)) for record in csv.DictReader(stream)]
Could this page be better? Report a problem or suggest an addition!