json_extraction_engine

Bases: BaseExtractionEngine

Source Code in griptape/engines/extraction/json_extraction_engine.py
@define
class JsonExtractionEngine(BaseExtractionEngine):
    """Extraction engine that asks the prompt driver for a JSON array shaped by ``template_schema``.

    Input text that does not fit within the prompt driver's input budget is
    split with ``chunker`` and sent in several requests. Every element of the
    JSON array found in each reply becomes one ``JsonArtifact``.
    """

    # Captures everything from the first "[" to the last "]" of the model
    # output (greedy match, DOTALL), i.e. the outermost JSON array candidate.
    JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

    # Serialized with json.dumps and rendered into the system prompt.
    template_schema: dict = field(kw_only=True)
    generate_system_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/system.j2")), kw_only=True)
    generate_user_template: J2 = field(default=Factory(lambda: J2("engines/extraction/json/user.j2")), kw_only=True)

    def extract_artifacts(
        self,
        artifacts: ListArtifact[TextArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
        **kwargs,
    ) -> ListArtifact[JsonArtifact]:
        """Extract JSON artifacts from ``artifacts``.

        Args:
            artifacts: Text artifacts whose values are joined and sent to the model.
            rulesets: Optional rulesets rendered into the system prompt.
            **kwargs: Ignored by this implementation.

        Returns:
            A ListArtifact of the extracted JsonArtifacts, joined with newlines.
        """
        return ListArtifact(
            self._extract_rec(cast("list[TextArtifact]", artifacts.value), [], rulesets=rulesets),
            item_separator="\n",
        )

    def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
        """Parse the JSON array embedded in a model reply into JsonArtifacts.

        Despite its name, this returns ``JsonArtifact`` objects, one per array element.

        Args:
            json_input: Raw text returned by the prompt driver.

        Returns:
            One JsonArtifact per element of the last ``JSON_PATTERN`` match,
            or an empty list when nothing matches.

        Raises:
            json.JSONDecodeError: If the matched text is not valid JSON.
        """
        # re.DOTALL is kept explicitly in case a subclass overrides JSON_PATTERN
        # without the inline (?s) flag.
        json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

        if json_matches:
            return [JsonArtifact(e) for e in json.loads(json_matches[-1])]
        return []

    def _extract_rec(
        self,
        artifacts: list[TextArtifact],
        extractions: list[JsonArtifact],
        *,
        rulesets: Optional[list[Ruleset]] = None,
    ) -> list[JsonArtifact]:
        """Run extraction over ``artifacts`` and append the results to ``extractions``.

        If the joined text fits in the prompt driver's input budget (leaving at
        least ``min_response_tokens``), it is sent in a single request.
        Otherwise the text is chunked and the first chunk is sent on its own.
        The remaining chunks are then re-joined and evaluated the same way.

        This is done in a loop rather than by recursion, so very long inputs
        cannot exhaust Python's recursion limit. The loop stops as soon as no
        chunks remain, instead of issuing an extra request with empty text.

        Args:
            artifacts: Text artifacts to extract from.
            extractions: Accumulator that is extended in place and returned.
            rulesets: Optional rulesets rendered into the system prompt.

        Returns:
            ``extractions``, extended with the newly extracted artifacts.
        """
        # The system prompt does not depend on the text, so render it only once.
        system_prompt = self.generate_system_template.render(
            json_template_schema=json.dumps(self.template_schema),
            rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
        )
        remaining = artifacts

        while True:
            remaining_text = self.chunk_joiner.join([a.value for a in remaining])
            user_prompt = self.generate_user_template.render(
                text=remaining_text,
            )

            if (
                self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
                >= self.min_response_tokens
            ):
                extractions.extend(self._run_extraction_prompt(system_prompt, user_prompt))
                return extractions

            # Too large for one request: send the first chunk on its own, then
            # re-evaluate whatever is left on the next iteration.
            chunks = self.chunker.chunk(remaining_text)
            partial_text = self.generate_user_template.render(
                text=chunks[0].value,
            )
            extractions.extend(self._run_extraction_prompt(system_prompt, partial_text))

            remaining = chunks[1:]
            if not remaining:
                # Every chunk has been sent; do not issue a request with empty text.
                return extractions

    def _run_extraction_prompt(self, system_prompt: str, user_prompt: str) -> list[JsonArtifact]:
        """Send one system/user prompt pair to the prompt driver and parse the JSON in its reply."""
        output = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ]
            )
        )
        return self.json_to_text_artifacts(output.value)
  • JSON_PATTERN = '(?s)[^\\[]*(\\[.*\\])' class-attribute instance-attribute

  • generate_system_template = field(default=Factory(lambda: J2('engines/extraction/json/system.j2')), kw_only=True) class-attribute instance-attribute

  • generate_user_template = field(default=Factory(lambda: J2('engines/extraction/json/user.j2')), kw_only=True) class-attribute instance-attribute

  • template_schema = field(kw_only=True) class-attribute instance-attribute

_extract_rec(artifacts, extractions, *, rulesets=None)

Source Code in griptape/engines/extraction/json_extraction_engine.py
def _extract_rec(
    self,
    artifacts: list[TextArtifact],
    extractions: list[JsonArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
) -> list[JsonArtifact]:
    """Run extraction over ``artifacts`` and append the results to ``extractions``.

    If the joined text fits in the prompt driver's input budget (leaving at
    least ``min_response_tokens``), it is sent in a single request. Otherwise
    the text is chunked and the first chunk is sent on its own. The remaining
    chunks are then re-joined and evaluated the same way.

    This is done in a loop rather than by recursion, so very long inputs
    cannot exhaust Python's recursion limit. The loop stops as soon as no
    chunks remain, instead of issuing an extra request with empty text.

    Args:
        artifacts: Text artifacts to extract from.
        extractions: Accumulator that is extended in place and returned.
        rulesets: Optional rulesets rendered into the system prompt.

    Returns:
        ``extractions``, extended with the newly extracted artifacts.
    """
    # The system prompt does not depend on the text, so render it only once.
    system_prompt = self.generate_system_template.render(
        json_template_schema=json.dumps(self.template_schema),
        rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
    )

    def run_prompt(user_prompt: str) -> list[JsonArtifact]:
        # Send the shared system prompt plus one user prompt; parse the JSON reply.
        output = self.prompt_driver.run(
            PromptStack(
                messages=[
                    Message(system_prompt, role=Message.SYSTEM_ROLE),
                    Message(user_prompt, role=Message.USER_ROLE),
                ]
            )
        )
        return self.json_to_text_artifacts(output.value)

    remaining = artifacts

    while True:
        remaining_text = self.chunk_joiner.join([a.value for a in remaining])
        user_prompt = self.generate_user_template.render(
            text=remaining_text,
        )

        if (
            self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
            >= self.min_response_tokens
        ):
            extractions.extend(run_prompt(user_prompt))
            return extractions

        # Too large for one request: send the first chunk on its own, then
        # re-evaluate whatever is left on the next iteration.
        chunks = self.chunker.chunk(remaining_text)
        partial_text = self.generate_user_template.render(
            text=chunks[0].value,
        )
        extractions.extend(run_prompt(partial_text))

        remaining = chunks[1:]
        if not remaining:
            # Every chunk has been sent; do not issue a request with empty text.
            return extractions

extract_artifacts(artifacts, *, rulesets=None, **kwargs)

Source Code in griptape/engines/extraction/json_extraction_engine.py
def extract_artifacts(
    self,
    artifacts: ListArtifact[TextArtifact],
    *,
    rulesets: Optional[list[Ruleset]] = None,
    **kwargs,
) -> ListArtifact[JsonArtifact]:
    """Extract JSON artifacts from a list of text artifacts.

    Args:
        artifacts: Text artifacts whose values are joined and sent to the model.
        rulesets: Optional rulesets rendered into the system prompt.
        **kwargs: Ignored by this implementation.

    Returns:
        A ListArtifact holding the extracted JsonArtifacts, joined with newlines.
    """
    source_artifacts = cast("list[TextArtifact]", artifacts.value)
    extracted = self._extract_rec(source_artifacts, [], rulesets=rulesets)
    return ListArtifact(extracted, item_separator="\n")

json_to_text_artifacts(json_input)

Source Code in griptape/engines/extraction/json_extraction_engine.py
def json_to_text_artifacts(self, json_input: str) -> list[JsonArtifact]:
    """Parse the JSON array embedded in a model reply into JsonArtifacts.

    Despite its name, this returns ``JsonArtifact`` objects, one per array element.

    Args:
        json_input: Raw text returned by the prompt driver.

    Returns:
        One JsonArtifact per element of the last ``JSON_PATTERN`` match, or an
        empty list when nothing matches.

    Raises:
        json.JSONDecodeError: If the matched text is not valid JSON.
    """
    # re.DOTALL is passed explicitly in case JSON_PATTERN lacks the inline (?s) flag.
    candidates = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)
    if not candidates:
        return []

    parsed_items = json.loads(candidates[-1])
    return [JsonArtifact(item) for item in parsed_items]

Could this page be better? Report a problem or suggest an addition!