• __all__ = ['BaseSchema', 'Bytes', 'PolymorphicSchema', 'PydanticModel', 'Union'] module-attribute

Bases: Schema
Source Code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Marshmallow schema base that can generate schemas from attrs classes.

    ``from_attrs_cls`` resolves an attrs class's annotations and maps every serializable
    attribute to a marshmallow field, using ``DATACLASS_TYPE_MAPPING`` plus dedicated
    handling for attrs classes, enums, sequences, unions, and TypeVars.
    """

    class Meta:
        # marshmallow option: keep undeclared input keys in the loaded data instead of raising.
        unknown = INCLUDE

    # marshmallow's default Python-type -> field mapping, extended with project field types.
    DATACLASS_TYPE_MAPPING = {
        **Schema.TYPE_MAPPING,
        dict: fields.Dict,
        bytes: Bytes,
        Any: fields.Raw,
        BaseModel: PydanticModel,
    }

    @classmethod
    def from_attrs_cls(
        cls,
        attrs_cls: type,
        *,
        types_overrides: Optional[dict[str, type]] = None,
        serializable_overrides: Optional[dict[str, bool]] = None,
    ) -> type:
        """Generate a Schema from an attrs class.

        Only attributes whose ``serializable`` metadata (or the matching entry in
        ``serializable_overrides``) is truthy become schema fields. Loading through the
        returned schema instantiates ``attrs_cls``.

        Args:
            attrs_cls: An attrs class.
            types_overrides: A dictionary of types to override when resolving types.
            serializable_overrides: A dictionary of field names to whether they are serializable.

        Returns:
            A new schema class, derived from ``cls``, named ``<attrs_cls name>Schema``.
        """
        from marshmallow import post_load

        if serializable_overrides is None:
            serializable_overrides = {}

        class SubSchema(cls):
            @post_load
            def make_obj(self, data: Any, **kwargs) -> Any:
                # Map the serialized keys to their correct deserialization keys.
                attrs_fields = attrs.fields_dict(attrs_cls)
                # Iterate over a snapshot because keys are renamed in place.
                for key in list(data):
                    attribute = attrs_fields.get(key)
                    if attribute is not None and attribute.metadata.get("deserialization_key"):
                        data[attribute.metadata["deserialization_key"]] = data.pop(key)

                return attrs_cls(**data)

        cls._resolve_types(attrs_cls, types_override=types_overrides)
        # Named ``schema_fields`` so the marshmallow ``fields`` module is not shadowed.
        schema_fields = {}
        for attribute in attrs.fields(attrs_cls):
            field_key = attribute.alias or attribute.name
            if serializable_overrides.get(field_key, attribute.metadata.get("serializable", False)):
                schema_fields[field_key] = cls._get_field_for_type(
                    attribute.type,
                    serialization_key=attribute.metadata.get("serialization_key"),
                    types_overrides=types_overrides,
                )
        return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

    @classmethod
    def _get_field_for_type(
        cls,
        field_type: type,
        serialization_key: Optional[str] = None,
        types_overrides: Optional[dict[str, type]] = None,
    ) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type.
            serialization_key: The key to pull the data from before serializing.
            types_overrides: A dictionary of types to override when resolving types. Forwarded to
                nested attrs classes, including those inside lists and unions.

        Returns:
            A marshmallow field for values of ``field_type``.

        Raises:
            ValueError: If ``field_type`` is a sequence type without an item type.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        # An annotation of ``None`` can only ever hold None.
        if field_class is None:
            return fields.Constant(None, allow_none=True)

        # Resolve TypeVars to their bound type; an unbound TypeVar accepts any value.
        if isinstance(field_class, TypeVar):
            field_class = field_class.__bound__
            if field_class is None:
                return fields.Raw(allow_none=optional, attribute=serialization_key)

        if cls._is_union(field_type):
            return cls._handle_union(
                field_type,
                optional=optional,
                serialization_key=serialization_key,
                types_overrides=types_overrides,
            )
        if attrs.has(field_class):
            # Abstract attrs bases go through PolymorphicSchema, which picks the concrete type at load time.
            if ABC in field_class.__bases__:
                schema = PolymorphicSchema(field_class, types_overrides=types_overrides)
            else:
                schema = cls.from_attrs_cls(field_class, types_overrides=types_overrides)
            return fields.Nested(schema, allow_none=optional, attribute=serialization_key)
        if cls._is_enum(field_type):
            return fields.String(allow_none=optional, attribute=serialization_key)
        if cls._is_list_sequence(field_class):
            if args:
                return cls._handle_list(
                    args[0],
                    optional=optional,
                    serialization_key=serialization_key,
                    types_overrides=types_overrides,
                )
            raise ValueError(f"Missing type for list field: {field_type}")
        field_class = cls.DATACLASS_TYPE_MAPPING.get(field_class, fields.Raw)

        return field_class(allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_list(
        cls,
        list_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
        types_overrides: Optional[dict[str, type]] = None,
    ) -> fields.Field:
        """Handle List Fields, including Union Types.

        Args:
            list_type: The type of the list's items.
            optional: Whether the List can be None.
            serialization_key: The key to pull the data from before serializing.
            types_overrides: A dictionary of types to override when resolving the item type.

        Returns:
            A marshmallow List field.
        """
        if cls._is_union(list_type):
            instance = cls._handle_union(
                list_type,
                optional=optional,
                serialization_key=serialization_key,
                types_overrides=types_overrides,
            )
        else:
            instance = cls._get_field_for_type(
                list_type,
                serialization_key=serialization_key,
                types_overrides=types_overrides,
            )
        return fields.List(cls_or_instance=instance, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_union(
        cls,
        union_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
        types_overrides: Optional[dict[str, type]] = None,
    ) -> fields.Field:
        """Handle Union Fields, including Unions with List Types.

        Args:
            union_type: The Union Type to handle.
            optional: Whether the Union can be None. Forced to True when ``NoneType`` is a
                member of ``union_type``.
            serialization_key: The key to pull the data from before serializing.
            types_overrides: A dictionary of types to override when resolving member types.

        Returns:
            A marshmallow Union field.

        Raises:
            ValueError: If the union has no non-None members.
        """
        none_type = type(None)
        members = get_args(union_type)
        candidate_fields = [
            cls._get_field_for_type(member, types_overrides=types_overrides)
            for member in members
            if member is not none_type
        ]
        if not candidate_fields:
            raise ValueError(f"Unsupported UnionType field: {union_type}")

        # get_args reports a ``None`` member as ``NoneType``; only then may the field be null.
        if any(member is none_type for member in members):
            optional = True

        return UnionField(fields=candidate_fields, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Args:
            field_type: A field type.

        Returns:
            ``(origin, args, optional)``.
            - For unions, ``origin``/``args`` describe the first non-None member and
              ``optional`` tells whether ``NoneType`` is a member.
            - For ``Literal`` types, ``origin`` is the type of the first literal value.
        """
        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        # _is_union also covers PEP 604 ``X | Y`` unions, whose origin is not ``typing.Union``.
        if cls._is_union(field_type):
            none_type = type(None)
            optional = any(arg is none_type for arg in args)
            non_none_args = [arg for arg in args if arg is not none_type]
            first_member = non_none_args[0] if non_none_args else args[0]

            origin, args, _ = cls._get_field_type_info(first_member)
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type, types_override: Optional[dict[str, type]] = None) -> None:
        """Resolve types in an attrs class.

        String annotations are evaluated against a namespace with three parts:
        - commonly referenced griptape types;
        - optional third-party types, which fall back to ``Any`` when the package is not installed;
        - ``types_override``, whose entries take precedence over the others.

        Args:
            attrs_cls: An attrs class.
            types_override: A dictionary of types to override.
        """
        from collections.abc import Sequence
        from typing import Any

        from pydantic import BaseModel
        from schema import Schema

        from griptape.artifacts import (
            ActionArtifact,
            AudioArtifact,
            BaseArtifact,
            BlobArtifact,
            BooleanArtifact,
            ErrorArtifact,
            GenericArtifact,
            ImageArtifact,
            ImageUrlArtifact,
            InfoArtifact,
            JsonArtifact,
            ListArtifact,
            TextArtifact,
        )
        from griptape.common import (
            BaseDeltaMessageContent,
            BaseMessageContent,
            Message,
            PromptStack,
            Reference,
            ToolAction,
        )
        from griptape.drivers.assistant import BaseAssistantDriver
        from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
        from griptape.drivers.embedding import BaseEmbeddingDriver
        from griptape.drivers.file_manager import BaseFileManagerDriver
        from griptape.drivers.image_generation import BaseImageGenerationDriver, BaseMultiModelImageGenerationDriver
        from griptape.drivers.image_generation_model import BaseImageGenerationModelDriver
        from griptape.drivers.memory.conversation import BaseConversationMemoryDriver
        from griptape.drivers.observability import BaseObservabilityDriver
        from griptape.drivers.prompt import BasePromptDriver
        from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
        from griptape.drivers.ruleset import BaseRulesetDriver
        from griptape.drivers.sql import BaseSqlDriver
        from griptape.drivers.structure_run import BaseStructureRunDriver
        from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
        from griptape.drivers.vector import BaseVectorStoreDriver
        from griptape.drivers.web_scraper import BaseWebScraperDriver
        from griptape.drivers.web_search import BaseWebSearchDriver
        from griptape.engines.rag import RagContext
        from griptape.events import EventListener
        from griptape.memory import TaskMemory
        from griptape.memory.meta import BaseMetaEntry
        from griptape.memory.structure import BaseConversationMemory, Run
        from griptape.memory.task.storage import BaseArtifactStorage
        from griptape.rules.base_rule import BaseRule
        from griptape.rules.ruleset import Ruleset
        from griptape.structures import Structure
        from griptape.tasks import BaseTask
        from griptape.tokenizers import BaseTokenizer
        from griptape.tools import BaseTool
        from griptape.utils import import_optional_dependency, is_dependency_installed

        if types_override is None:
            types_override = {}

        def optional_dependency(module_name: str, attr_name: Optional[str] = None) -> Any:
            # Substitute ``Any`` when the optional package is missing so annotations still resolve.
            if not is_dependency_installed(module_name):
                return Any
            module = import_optional_dependency(module_name)
            return module if attr_name is None else getattr(module, attr_name)

        attrs.resolve_types(
            attrs_cls,
            localns={
                "Any": Any,
                "BasePromptDriver": BasePromptDriver,
                "BaseEmbeddingDriver": BaseEmbeddingDriver,
                "BaseVectorStoreDriver": BaseVectorStoreDriver,
                "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
                "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BaseRulesetDriver": BaseRulesetDriver,
                "BaseImageGenerationDriver": BaseImageGenerationDriver,
                "BaseMultiModelImageGenerationDriver": BaseMultiModelImageGenerationDriver,
                "BaseImageGenerationModelDriver": BaseImageGenerationModelDriver,
                "BaseWebSearchDriver": BaseWebSearchDriver,
                "BaseWebScraperDriver": BaseWebScraperDriver,
                "BaseFileManagerDriver": BaseFileManagerDriver,
                "BaseSqlDriver": BaseSqlDriver,
                "BaseObservabilityDriver": BaseObservabilityDriver,
                "BaseAssistantDriver": BaseAssistantDriver,
                "BaseStructureRunDriver": BaseStructureRunDriver,
                "BaseArtifact": BaseArtifact,
                "BaseMetaEntry": BaseMetaEntry,
                "PromptStack": PromptStack,
                "EventListener": EventListener,
                "BaseMessageContent": BaseMessageContent,
                "BaseDeltaMessageContent": BaseDeltaMessageContent,
                "BaseTool": BaseTool,
                "BaseTask": BaseTask,
                "TextArtifact": TextArtifact,
                "ImageArtifact": ImageArtifact,
                "ImageUrlArtifact": ImageUrlArtifact,
                "ErrorArtifact": ErrorArtifact,
                "InfoArtifact": InfoArtifact,
                "JsonArtifact": JsonArtifact,
                "BlobArtifact": BlobArtifact,
                "BooleanArtifact": BooleanArtifact,
                "ListArtifact": ListArtifact,
                "AudioArtifact": AudioArtifact,
                "ActionArtifact": ActionArtifact,
                "GenericArtifact": GenericArtifact,
                "Usage": Message.Usage,
                "Structure": Structure,
                "BaseTokenizer": BaseTokenizer,
                "ToolAction": ToolAction,
                "Reference": Reference,
                "Run": Run,
                "Sequence": Sequence,
                "TaskMemory": TaskMemory,
                "State": BaseTask.State,
                "BaseConversationMemory": BaseConversationMemory,
                "BaseArtifactStorage": BaseArtifactStorage,
                "BaseRule": BaseRule,
                "Ruleset": Ruleset,
                "StructuredOutputStrategy": StructuredOutputStrategy,
                "RagContext": RagContext,
                # Third party modules
                "Client": optional_dependency("cohere", "Client"),
                "ClientV2": optional_dependency("cohere", "ClientV2"),
                "GenerativeModel": optional_dependency("google.generativeai", "GenerativeModel"),
                "boto3": optional_dependency("boto3"),
                "Anthropic": optional_dependency("anthropic", "Anthropic"),
                "BedrockRuntimeClient": optional_dependency("mypy_boto3_bedrock_runtime", "BedrockRuntimeClient"),
                "voyageai": optional_dependency("voyageai"),
                "Schema": Schema,
                "BaseModel": BaseModel,
                **types_override,
            },
        )

    @classmethod
    def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
        """Return True for sequence classes that serialize as lists (excludes str, bytes, tuple)."""
        if not isinstance(field_type, type):
            return False
        if issubclass(field_type, (str, bytes, tuple)):
            return False
        return issubclass(field_type, Sequence)

    @classmethod
    def _is_union(cls, field_type: type) -> bool:
        """Return True for ``Union``, ``Union[...]``/``Optional[...]`` and PEP 604 ``X | Y`` unions."""
        import types

        # types.UnionType exists on Python 3.10+; fall back to Union on older versions.
        union_forms = (Union, getattr(types, "UnionType", Union))
        origin = get_origin(field_type)
        return any(field_type is form or origin is form for form in union_forms)

    @classmethod
    def _is_enum(cls, field_type: type) -> bool:
        """Return True when ``field_type`` is an ``Enum`` class."""
        return isinstance(field_type, type) and issubclass(field_type, Enum)
  • DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw, BaseModel: PydanticModel} class-attribute instance-attribute

Meta

Source Code in griptape/schemas/base_schema.py
class Meta:
    # marshmallow `unknown` option: INCLUDE keeps input keys that are not declared
    # schema fields in the loaded data instead of raising a validation error.
    unknown = INCLUDE
  • unknown = INCLUDE class-attribute instance-attribute

_get_field_for_type(field_type, serialization_key=None, types_overrides=None) classmethod

Generate a marshmallow Field instance from a Python type.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `field_type` | `type` | A field type. | *required* |
| `serialization_key` | `Optional[str]` | The key to pull the data from before serializing. | `None` |
| `types_overrides` | `Optional[dict[str, type]]` | A dictionary of types to override when resolving types. | `None` |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _get_field_for_type(
    cls,
    field_type: type,
    serialization_key: Optional[str] = None,
    types_overrides: Optional[dict[str, type]] = None,
) -> fields.Field | fields.Nested:
    """Build the marshmallow field that (de)serializes values of ``field_type``.

    Args:
        field_type: The Python type annotation to translate.
        serialization_key: Attribute to read the value from when serializing.
        types_overrides: Name -> type overrides used when resolving nested attrs classes.

    Returns:
        A marshmallow field instance.

    Raises:
        ValueError: If ``field_type`` is a sequence type with no item type.
    """
    from griptape.schemas.polymorphic_schema import PolymorphicSchema

    resolved, type_args, allow_none = cls._get_field_type_info(field_type)

    # An annotation of ``None`` can only ever hold None.
    if resolved is None:
        return fields.Constant(None, allow_none=True)

    # A TypeVar stands in for its bound; an unbound TypeVar accepts any value.
    if isinstance(resolved, TypeVar):
        resolved = resolved.__bound__
        if resolved is None:
            return fields.Raw(allow_none=allow_none, attribute=serialization_key)

    if cls._is_union(field_type):
        return cls._handle_union(field_type, optional=allow_none, serialization_key=serialization_key)

    if attrs.has(resolved):
        nested_schema = (
            PolymorphicSchema(resolved, types_overrides=types_overrides)
            if ABC in resolved.__bases__
            else cls.from_attrs_cls(resolved, types_overrides=types_overrides)
        )
        return fields.Nested(nested_schema, allow_none=allow_none, attribute=serialization_key)

    if cls._is_enum(field_type):
        return fields.String(allow_none=allow_none, attribute=serialization_key)

    if cls._is_list_sequence(resolved):
        if not type_args:
            raise ValueError(f"Missing type for list field: {field_type}")
        return cls._handle_list(type_args[0], optional=allow_none, serialization_key=serialization_key)

    marshmallow_field = cls.DATACLASS_TYPE_MAPPING.get(resolved, fields.Raw)
    return marshmallow_field(allow_none=allow_none, attribute=serialization_key)

_get_field_type_info(field_type) classmethod

Get information about a field type.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `field_type` | `type` | A field type. | *required* |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
    """Get information about a field type.

    Args:
        field_type: A field type.

    Returns:
        ``(origin, args, optional)``.
        - For unions (``Union[...]``, ``Optional[...]`` and PEP 604 ``X | Y``), ``origin``
          and ``args`` describe the first non-None member, and ``optional`` is True when
          ``NoneType`` is any member.
        - For ``Literal`` types, ``origin`` is the type of the first literal value and
          ``args`` is empty.
    """
    import types

    # PEP 604 unions report ``types.UnionType`` (Python 3.10+) as their origin.
    union_origins = (Union, getattr(types, "UnionType", Union))
    none_type = type(None)

    origin = get_origin(field_type) or field_type
    args = get_args(field_type)
    optional = False

    if any(origin is union_origin for union_origin in union_origins):
        optional = any(arg is none_type for arg in args)
        non_none_args = [arg for arg in args if arg is not none_type]
        # Describe the union by its first non-None member (falls back to the first arg).
        first_member = non_none_args[0] if non_none_args else args[0]

        origin, args, _ = cls._get_field_type_info(first_member)
    elif origin is Literal:
        origin = type(args[0])
        args = ()

    return origin, args, optional

_handle_list(list_type, *, optional, serialization_key=None) classmethod

Handle List Fields, including Union Types.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `list_type` | `type` | The List type to handle. | *required* |
| `optional` | `bool` | Whether the List can be None. | *required* |
| `serialization_key` | `Optional[str]` | The key to pull the data from before serializing. | `None` |

Returns

| Type | Description |
| --- | --- |
| `Field` | A marshmallow List field. |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _handle_list(
    cls,
    list_type: type,
    *,
    optional: bool,
    serialization_key: Optional[str] = None,
) -> fields.Field:
    """Wrap the field for ``list_type`` items in a marshmallow ``List`` field.

    Args:
        list_type: The type of the list's items; may itself be a union.
        optional: Whether the list itself may be None.
        serialization_key: Attribute to read the value from when serializing.

    Returns:
        A marshmallow ``List`` field.
    """
    item_field = (
        cls._handle_union(list_type, optional=optional, serialization_key=serialization_key)
        if cls._is_union(list_type)
        else cls._get_field_for_type(list_type, serialization_key=serialization_key)
    )
    return fields.List(cls_or_instance=item_field, allow_none=optional, attribute=serialization_key)

_handle_union(union_type, *, optional, serialization_key=None) classmethod

Handle Union Fields, including Unions with List Types.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `union_type` | `type` | The Union Type to handle. | *required* |
| `optional` | `bool` | Whether the Union can be None. | *required* |
| `serialization_key` | `Optional[str]` | The key to pull the data from before serializing. | `None` |

Returns

| Type | Description |
| --- | --- |
| `Field` | A marshmallow Union field. |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _handle_union(
    cls,
    union_type: type,
    *,
    optional: bool,
    serialization_key: Optional[str] = None,
    types_overrides: Optional[dict[str, type]] = None,
) -> fields.Field:
    """Handle Union Fields, including Unions with List Types.

    One candidate field is built per non-None member of the union.

    Args:
        union_type: The Union Type to handle.
        optional: Whether the Union can be None. Forced to True when ``NoneType`` is a
            member of ``union_type``.
        serialization_key: The key to pull the data from before serializing.
        types_overrides: A dictionary of types to override when resolving member types.

    Returns:
        A marshmallow Union field.

    Raises:
        ValueError: If the union has no non-None members.
    """
    none_type = type(None)
    members = get_args(union_type)
    candidate_fields = [
        cls._get_field_for_type(member, types_overrides=types_overrides)
        for member in members
        if member is not none_type
    ]
    if not candidate_fields:
        raise ValueError(f"Unsupported UnionType field: {union_type}")

    # get_args reports a ``None`` member as ``NoneType``; only then may the field be null.
    if any(member is none_type for member in members):
        optional = True

    return UnionField(fields=candidate_fields, allow_none=optional, attribute=serialization_key)

_is_enum(field_type) classmethod

Source Code in griptape/schemas/base_schema.py
@classmethod
def _is_enum(cls, field_type: type) -> bool:
    """Return True when ``field_type`` is a class deriving from ``Enum``."""
    if not isinstance(field_type, type):
        return False
    return issubclass(field_type, Enum)

_is_list_sequence(field_type) classmethod

Source Code in griptape/schemas/base_schema.py
@classmethod
def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
    if isinstance(field_type, type):
        if issubclass(field_type, str) or issubclass(field_type, bytes) or issubclass(field_type, tuple):
            return False
        return issubclass(field_type, Sequence)
    return False

_is_union(field_type) classmethod

Source Code in griptape/schemas/base_schema.py
@classmethod
def _is_union(cls, field_type: type) -> bool:
    """Return True when ``field_type`` is a union annotation.

    Recognizes bare ``typing.Union``, subscripted ``Union[...]``/``Optional[...]``, and
    PEP 604 unions such as ``int | None``. On Python 3.10+ the origin of a PEP 604 union is
    ``types.UnionType`` rather than ``typing.Union``.
    """
    import types

    # types.UnionType only exists on Python 3.10+; fall back to Union so the tuple stays valid.
    union_forms = (Union, getattr(types, "UnionType", Union))
    origin = get_origin(field_type)
    return any(field_type is form or origin is form for form in union_forms)

_resolve_types(attrs_cls, types_override=None) classmethod

Resolve types in an attrs class.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `attrs_cls` | `type` | An attrs class. | *required* |
| `types_override` | `Optional[dict[str, type]]` | A dictionary of types to override. | `None` |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _resolve_types(cls, attrs_cls: type, types_override: Optional[dict[str, type]] = None) -> None:
    """Resolve the string annotations of ``attrs_cls`` into real types.

    The class is passed to ``attrs.resolve_types`` with a local namespace made of:
    - commonly referenced griptape types;
    - optional third-party types, which become ``Any`` when the package is not installed;
    - ``types_override``, whose entries replace any same-named entry.

    Args:
        attrs_cls: An attrs class whose field types should be resolved.
        types_override: Extra or replacement name -> type entries for the namespace.
    """
    from collections.abc import Sequence
    from typing import Any

    from pydantic import BaseModel
    from schema import Schema

    from griptape.artifacts import (
        ActionArtifact,
        AudioArtifact,
        BaseArtifact,
        BlobArtifact,
        BooleanArtifact,
        ErrorArtifact,
        GenericArtifact,
        ImageArtifact,
        ImageUrlArtifact,
        InfoArtifact,
        JsonArtifact,
        ListArtifact,
        TextArtifact,
    )
    from griptape.common import (
        BaseDeltaMessageContent,
        BaseMessageContent,
        Message,
        PromptStack,
        Reference,
        ToolAction,
    )
    from griptape.drivers.assistant import BaseAssistantDriver
    from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
    from griptape.drivers.embedding import BaseEmbeddingDriver
    from griptape.drivers.file_manager import BaseFileManagerDriver
    from griptape.drivers.image_generation import BaseImageGenerationDriver, BaseMultiModelImageGenerationDriver
    from griptape.drivers.image_generation_model import BaseImageGenerationModelDriver
    from griptape.drivers.memory.conversation import BaseConversationMemoryDriver
    from griptape.drivers.observability import BaseObservabilityDriver
    from griptape.drivers.prompt import BasePromptDriver
    from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
    from griptape.drivers.ruleset import BaseRulesetDriver
    from griptape.drivers.sql import BaseSqlDriver
    from griptape.drivers.structure_run import BaseStructureRunDriver
    from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
    from griptape.drivers.vector import BaseVectorStoreDriver
    from griptape.drivers.web_scraper import BaseWebScraperDriver
    from griptape.drivers.web_search import BaseWebSearchDriver
    from griptape.engines.rag import RagContext
    from griptape.events import EventListener
    from griptape.memory import TaskMemory
    from griptape.memory.meta import BaseMetaEntry
    from griptape.memory.structure import BaseConversationMemory, Run
    from griptape.memory.task.storage import BaseArtifactStorage
    from griptape.rules.base_rule import BaseRule
    from griptape.rules.ruleset import Ruleset
    from griptape.structures import Structure
    from griptape.tasks import BaseTask
    from griptape.tokenizers import BaseTokenizer
    from griptape.tools import BaseTool
    from griptape.utils import import_optional_dependency, is_dependency_installed

    def optional_dependency(module_name: str, attr_name: Optional[str] = None) -> Any:
        # Missing optional packages resolve to ``Any`` so their annotations still evaluate.
        if not is_dependency_installed(module_name):
            return Any
        module = import_optional_dependency(module_name)
        return module if attr_name is None else getattr(module, attr_name)

    namespace = {
        "Any": Any,
        "BasePromptDriver": BasePromptDriver,
        "BaseEmbeddingDriver": BaseEmbeddingDriver,
        "BaseVectorStoreDriver": BaseVectorStoreDriver,
        "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
        "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
        "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
        "BaseRulesetDriver": BaseRulesetDriver,
        "BaseImageGenerationDriver": BaseImageGenerationDriver,
        "BaseMultiModelImageGenerationDriver": BaseMultiModelImageGenerationDriver,
        "BaseImageGenerationModelDriver": BaseImageGenerationModelDriver,
        "BaseWebSearchDriver": BaseWebSearchDriver,
        "BaseWebScraperDriver": BaseWebScraperDriver,
        "BaseFileManagerDriver": BaseFileManagerDriver,
        "BaseSqlDriver": BaseSqlDriver,
        "BaseObservabilityDriver": BaseObservabilityDriver,
        "BaseAssistantDriver": BaseAssistantDriver,
        "BaseStructureRunDriver": BaseStructureRunDriver,
        "BaseArtifact": BaseArtifact,
        "BaseMetaEntry": BaseMetaEntry,
        "PromptStack": PromptStack,
        "EventListener": EventListener,
        "BaseMessageContent": BaseMessageContent,
        "BaseDeltaMessageContent": BaseDeltaMessageContent,
        "BaseTool": BaseTool,
        "BaseTask": BaseTask,
        "TextArtifact": TextArtifact,
        "ImageArtifact": ImageArtifact,
        "ImageUrlArtifact": ImageUrlArtifact,
        "ErrorArtifact": ErrorArtifact,
        "InfoArtifact": InfoArtifact,
        "JsonArtifact": JsonArtifact,
        "BlobArtifact": BlobArtifact,
        "BooleanArtifact": BooleanArtifact,
        "ListArtifact": ListArtifact,
        "AudioArtifact": AudioArtifact,
        "ActionArtifact": ActionArtifact,
        "GenericArtifact": GenericArtifact,
        "Usage": Message.Usage,
        "Structure": Structure,
        "BaseTokenizer": BaseTokenizer,
        "ToolAction": ToolAction,
        "Reference": Reference,
        "Run": Run,
        "Sequence": Sequence,
        "TaskMemory": TaskMemory,
        "State": BaseTask.State,
        "BaseConversationMemory": BaseConversationMemory,
        "BaseArtifactStorage": BaseArtifactStorage,
        "BaseRule": BaseRule,
        "Ruleset": Ruleset,
        "StructuredOutputStrategy": StructuredOutputStrategy,
        "RagContext": RagContext,
        # Third party modules
        "Client": optional_dependency("cohere", "Client"),
        "ClientV2": optional_dependency("cohere", "ClientV2"),
        "GenerativeModel": optional_dependency("google.generativeai", "GenerativeModel"),
        "boto3": optional_dependency("boto3"),
        "Anthropic": optional_dependency("anthropic", "Anthropic"),
        "BedrockRuntimeClient": optional_dependency("mypy_boto3_bedrock_runtime", "BedrockRuntimeClient"),
        "voyageai": optional_dependency("voyageai"),
        "Schema": Schema,
        "BaseModel": BaseModel,
    }
    # Caller overrides win over every default entry above.
    namespace.update({} if types_override is None else types_override)

    attrs.resolve_types(attrs_cls, localns=namespace)

from_attrs_cls(attrs_cls, *, types_overrides=None, serializable_overrides=None) classmethod

Generate a Schema from an attrs class.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `attrs_cls` | `type` | An attrs class. | *required* |
| `types_overrides` | `Optional[dict[str, type]]` | A dictionary of types to override when resolving types. | `None` |
| `serializable_overrides` | `Optional[dict[str, bool]]` | A dictionary of field names to whether they are serializable. | `None` |
Source Code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(
    cls,
    attrs_cls: type,
    *,
    types_overrides: Optional[dict[str, type]] = None,
    serializable_overrides: Optional[dict[str, bool]] = None,
) -> type:
    """Build a marshmallow schema class for ``attrs_cls``.

    An attribute becomes a schema field when ``serializable_overrides`` says so for its
    key (alias, else name), or otherwise when its ``serializable`` metadata is truthy.
    Loading through the resulting schema constructs an ``attrs_cls`` instance.

    Args:
        attrs_cls: The attrs class to describe.
        types_overrides: Name -> type overrides used while resolving annotations.
        serializable_overrides: Per-field-key overrides of the ``serializable`` metadata.

    Returns:
        A new schema class, derived from ``cls``, named ``<attrs_cls name>Schema``.
    """
    from marshmallow import post_load

    overrides = {} if serializable_overrides is None else serializable_overrides

    class SubSchema(cls):
        @post_load
        def make_obj(self, data: Any, **kwargs) -> Any:
            # Rename loaded keys whose attrs attribute declares a ``deserialization_key``.
            attrs_by_name = attrs.fields_dict(attrs_cls)
            # Walk a snapshot of the keys because ``data`` is mutated in place.
            for loaded_key in list(data):
                attribute = attrs_by_name.get(loaded_key)
                if attribute is None:
                    continue
                target_key = attribute.metadata.get("deserialization_key")
                if target_key:
                    data[target_key] = data.pop(loaded_key)
            return attrs_cls(**data)

    cls._resolve_types(attrs_cls, types_override=types_overrides)

    # Named ``schema_fields`` so the marshmallow ``fields`` module is not shadowed.
    schema_fields = {}
    for attribute in attrs.fields(attrs_cls):
        key = attribute.alias or attribute.name
        if not overrides.get(key, attribute.metadata.get("serializable", False)):
            continue
        schema_fields[key] = cls._get_field_for_type(
            attribute.type,
            serialization_key=attribute.metadata.get("serialization_key"),
            types_overrides=types_overrides,
        )
    return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

Bytes

Bases: Field
Source Code in griptape/schemas/bytes_field.py
class Bytes(fields.Field):
    """marshmallow field that stores ``bytes`` values as base64-encoded strings."""

    def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> "str | None":
        # marshmallow hands None straight to _serialize (e.g. for allow_none fields).
        # Mirror the built-in fields and emit None instead of failing inside b64encode.
        if value is None:
            return None
        return base64.b64encode(value).decode()

    def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
        try:
            return base64.b64decode(value)
        except (TypeError, ValueError) as error:
            # binascii.Error (bad padding/characters) subclasses ValueError; a non-string
            # input raises TypeError. Report both as a field validation error so
            # Schema.load collects them instead of propagating a raw exception.
            raise ValidationError("Invalid base64-encoded bytes.") from error

    def _validate(self, value: Any) -> None:
        # Runs on the deserialized value; anything other than bytes is rejected.
        if not isinstance(value, bytes):
            raise ValidationError("Invalid input type.")

_deserialize(value, attr, data, **kwargs)

Source Code in griptape/schemas/bytes_field.py
def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
    """Decode a base64 string (or ASCII bytes) back into raw ``bytes``.

    Raises:
        ValidationError: If ``value`` is not valid base64 input.
    """
    try:
        return base64.b64decode(value)
    except (TypeError, ValueError) as error:
        # binascii.Error (bad padding/characters) subclasses ValueError; a non-string
        # input raises TypeError. Surface both as a field validation error.
        raise ValidationError("Invalid base64-encoded bytes.") from error

_serialize(value, attr, obj, **kwargs)

Source Code in griptape/schemas/bytes_field.py
def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> "str | None":
    """Encode ``bytes`` as a base64 ASCII string; None passes through unchanged."""
    # marshmallow hands None straight to _serialize (e.g. for allow_none fields).
    # b64encode(None) would raise TypeError.
    if value is None:
        return None
    return base64.b64encode(value).decode()

_validate(value)

Source Code in griptape/schemas/bytes_field.py
def _validate(self, value: Any) -> None:
    if not isinstance(value, bytes):
        raise ValidationError("Invalid input type.")

PolymorphicSchema

Bases: BaseSchema

Source Code in griptape/schemas/polymorphic_schema.py
class PolymorphicSchema(BaseSchema):
    """PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

    Dumping: each object is serialized with a schema built by ``BaseSchema.from_attrs_cls`` for the
    object's own class, and the class name is written under ``type_field``.
    Loading: the value under ``type_field`` (plus an optional ``module_name`` key) is passed to
    ``inner_class.get_schema`` to select the schema that loads the rest of the payload.
    """

    def __init__(self, inner_class: Any, types_overrides: Optional[dict[str, type]] = None, **kwargs) -> None:
        super().__init__(**kwargs)

        # Object exposing ``get_schema(data_type, module_name=...)``; used by ``_load``.
        self.inner_class = inner_class
        # Forwarded to ``BaseSchema.from_attrs_cls`` when building per-object dump schemas.
        self.types_overrides = types_overrides

    # Key that carries the concrete type name in serialized data.
    type_field = "type"
    # When True, ``get_data_type`` removes ``type_field`` from the payload before it is loaded.
    type_field_remove = True

    def get_obj_type(self, obj: Any) -> Any:
        """Returns name of the schema during dump() calls, given the object being dumped."""
        return obj.__class__.__name__

    def get_data_type(self, data: Any) -> Any:
        """Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up `type_field` in the data."""
        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)
        return data_type

    def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
        """Serialize ``obj``, or each element of ``obj`` when ``many`` is truthy.

        Args:
            obj: The object (or iterable of objects when ``many``) to dump.
            many: Overrides ``self.many`` when not ``None``.
            kwargs: Forwarded to ``_dump``.

        Raises:
            ValidationError: In ``many`` mode, if any element fails; messages are keyed by index and
                ``valid_data`` holds the partial results.
        """
        many = self.many if many is None else bool(many)
        if not many:
            return self._dump(obj, **kwargs)

        result_data = []
        result_errors = {}
        for idx, o in enumerate(obj):
            try:
                result_data.append(self._dump(o, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=obj, valid_data=result_data)  # pyright: ignore[reportArgumentType]
        return result_data

    def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
        """Dump one object with a schema generated from its class and tag it with ``type_field``.

        ``update_fields`` is accepted for signature compatibility with marshmallow-oneofschema and
        is not used.

        Raises:
            ValidationError: If no type name or schema can be determined for ``obj``.
        """
        obj_type = self.get_obj_type(obj)
        if not obj_type:
            # Raise (as ``_load`` does) rather than returning an error tuple that would end up
            # embedded in the dump output.
            raise ValidationError({"_schema": f"Unknown object class: {obj.__class__.__name__}"})

        type_schema = BaseSchema.from_attrs_cls(obj.__class__, types_overrides=self.types_overrides)
        if not type_schema:
            raise ValidationError({"_schema": f"Unsupported object type: {obj_type}"})

        # ``from_attrs_cls`` returns a Schema class; instantiate it unless an instance was given.
        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        # Share this schema's context with the per-type schema.
        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)

        if result is not None:
            result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]

        return result

    def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Deserialize ``data``, or each element of ``data`` when ``many`` is truthy.

        Args:
            data: A mapping (or iterable of mappings when ``many``) to load.
            many: Overrides ``self.many`` when not ``None``.
            partial: Overrides ``self.partial`` when not ``None``.
            unknown: Unknown-field policy forwarded to every ``_load`` call.
            kwargs: Forwarded to ``_load``.

        Raises:
            ValidationError: If loading fails. In ``many`` mode, messages are keyed by index and
                ``valid_data`` holds the partial results.
        """
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial

        if not many:
            try:
                return self._load(data, partial=partial, unknown=unknown, **kwargs)
            except ValidationError as error:
                # For a single item, ``valid_data`` is that item's partial result, not a list.
                raise ValidationError(
                    error.normalized_messages(), data=data, valid_data=error.valid_data
                ) from error

        result_data = []
        result_errors = {}
        for idx, item in enumerate(data):
            try:
                # ``unknown`` must be forwarded here too, or it is silently ignored for many=True.
                result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
            except ValidationError as error:
                result_errors[idx] = error.normalized_messages()
                result_data.append(error.valid_data)

        if result_errors:
            raise ValidationError(result_errors, data=data, valid_data=result_data)
        return result_data

    def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Load one mapping by dispatching on its ``type_field`` to the matching schema.

        Raises:
            ValidationError: If ``data`` is not a dict, lacks ``type_field``, or names a type that
                ``inner_class.get_schema`` cannot resolve.
        """
        if not isinstance(data, dict):
            raise ValidationError({"_schema": f"Invalid data type: {data}"})

        # Copy so that removing ``type_field`` does not mutate the caller's dict.
        data = dict(data)
        unknown = unknown or self.unknown
        data_type = self.get_data_type(data)

        if data_type is None:
            raise ValidationError({self.type_field: ["Missing data for required field."]})

        type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name"))
        if not type_schema:
            raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

    def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Return the error messages from loading ``data``, or ``{}`` if it loads cleanly."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}
  • inner_class = inner_class instance-attribute

  • type_field = 'type' class-attribute instance-attribute

  • type_field_remove = True class-attribute instance-attribute

  • types_overrides = types_overrides instance-attribute

`__init__(inner_class, types_overrides=None, **kwargs)`

Source Code in griptape/schemas/polymorphic_schema.py
def __init__(self, inner_class: Any, types_overrides: Optional[dict[str, type]] = None, **kwargs) -> None:
    super().__init__(**kwargs)

    self.inner_class = inner_class
    self.types_overrides = types_overrides

_dump(obj, *, update_fields=True, **kwargs)

Source Code in griptape/schemas/polymorphic_schema.py
def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
    """Dump one object with a schema generated from its class and tag it with ``type_field``.

    ``update_fields`` is accepted for signature compatibility with marshmallow-oneofschema and is
    not used.

    Raises:
        ValidationError: If no type name or schema can be determined for ``obj``.
    """
    obj_type = self.get_obj_type(obj)
    if not obj_type:
        # Raise (as ``_load`` does) rather than returning an error tuple that would end up
        # embedded in the dump output.
        raise ValidationError({"_schema": f"Unknown object class: {obj.__class__.__name__}"})

    type_schema = BaseSchema.from_attrs_cls(obj.__class__, types_overrides=self.types_overrides)
    if not type_schema:
        raise ValidationError({"_schema": f"Unsupported object type: {obj_type}"})

    # Instantiate the generated Schema class unless an instance was returned.
    schema = type_schema if isinstance(type_schema, Schema) else type_schema()

    # Share this schema's context with the per-type schema.
    schema.context.update(getattr(self, "context", {}))

    result = schema.dump(obj, many=False, **kwargs)

    if result is not None:
        result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]

    return result

_load(data, *, partial=None, unknown=None, **kwargs)

Source Code in griptape/schemas/polymorphic_schema.py
def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Load a single mapping using the schema selected by its ``type_field`` value.

    Raises:
        ValidationError: If ``data`` is not a dict, lacks ``type_field``, or names a type that
            ``inner_class.get_schema`` cannot resolve.
    """
    if not isinstance(data, dict):
        raise ValidationError({"_schema": f"Invalid data type: {data}"})

    # Work on a shallow copy; get_data_type may pop the type key from it.
    payload = dict(data)
    effective_unknown = unknown or self.unknown
    type_name = self.get_data_type(payload)
    if type_name is None:
        raise ValidationError({self.type_field: ["Missing data for required field."]})

    resolved = self.inner_class.get_schema(type_name, module_name=payload.get("module_name"))
    if not resolved:
        raise ValidationError({self.type_field: [f"Unsupported value: {type_name}"]})

    if isinstance(resolved, Schema):
        schema = resolved
    else:
        schema = resolved()
    schema.context.update(getattr(self, "context", {}))
    return schema.load(payload, many=False, partial=partial, unknown=effective_unknown, **kwargs)

dump(obj, *, many=None, **kwargs)

Source Code in griptape/schemas/polymorphic_schema.py
def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
    """Serialize ``obj``; in ``many`` mode, serialize every element and gather per-index errors.

    Raises:
        ValidationError: In ``many`` mode, if any element fails; ``valid_data`` holds the partial
            results.
    """
    if many is None:
        many = self.many
    else:
        many = bool(many)

    if not many:
        return self._dump(obj, **kwargs)

    collected = []
    failures = {}
    for position, item in enumerate(obj):
        try:
            collected.append(self._dump(item, **kwargs))
        except ValidationError as err:
            failures[position] = err.normalized_messages()
            collected.append(err.valid_data)

    if failures:
        raise ValidationError(failures, data=obj, valid_data=collected)  # pyright: ignore[reportArgumentType]
    return collected

get_data_type(data)

Source Code in griptape/schemas/polymorphic_schema.py
def get_data_type(self, data: Any) -> Any:
    """Return the value stored under ``type_field`` in ``data`` (``None`` if absent).

    When ``type_field_remove`` is set, the key is also removed from ``data`` in place.
    """
    key = self.type_field
    if key in data and self.type_field_remove:
        return data.pop(key)
    return data.get(key)

get_obj_type(obj)

Source Code in griptape/schemas/polymorphic_schema.py
def get_obj_type(self, obj: Any) -> Any:
    """Return the class name of ``obj``; dump() uses it as the type tag."""
    klass = obj.__class__
    return klass.__name__

load(data, *, many=None, partial=None, unknown=None, **kwargs)

Source Code in griptape/schemas/polymorphic_schema.py
def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Deserialize ``data``, or each element of ``data`` when ``many`` is truthy.

    Args:
        data: A mapping (or iterable of mappings when ``many``) to load.
        many: Overrides ``self.many`` when not ``None``.
        partial: Overrides ``self.partial`` when not ``None``.
        unknown: Unknown-field policy forwarded to every ``_load`` call.
        kwargs: Forwarded to ``_load``.

    Raises:
        ValidationError: If loading fails. In ``many`` mode, messages are keyed by index and
            ``valid_data`` holds the partial results.
    """
    many = self.many if many is None else bool(many)
    if partial is None:
        partial = self.partial

    if not many:
        try:
            return self._load(data, partial=partial, unknown=unknown, **kwargs)
        except ValidationError as error:
            # For a single item, ``valid_data`` is that item's partial result, not a list.
            raise ValidationError(
                error.normalized_messages(), data=data, valid_data=error.valid_data
            ) from error

    result_data = []
    result_errors = {}
    for idx, item in enumerate(data):
        try:
            # ``unknown`` must be forwarded here too, or it is silently ignored for many=True.
            result_data.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
        except ValidationError as error:
            result_errors[idx] = error.normalized_messages()
            result_data.append(error.valid_data)

    if result_errors:
        raise ValidationError(result_errors, data=data, valid_data=result_data)
    return result_data

validate(data, *, many=None, partial=None)

Source Code in griptape/schemas/polymorphic_schema.py
def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Attempt a load of ``data`` and report its error messages; ``{}`` means it is valid."""
    try:
        self.load(data, many=many, partial=partial)
    except ValidationError as err:
        messages = err.messages
    else:
        messages = {}
    return messages

PydanticModel

Bases: Field
Source Code in griptape/schemas/pydantic_model_field.py
class PydanticModel(fields.Field):
    """Marshmallow field for pydantic models: dumps via ``model_dump()``, loads into a ``RootModel``."""

    def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
        # ``None`` passes through unchanged; any model is converted to a plain dict.
        return None if value is None else value.model_dump()

    def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> BaseModel:
        # NOTE(review): the dumped dict carries no record of the source model class, so the
        # payload is wrapped in a generic RootModel rather than rebuilt as its original type.
        wrapped = RootModel(value)
        return wrapped

_deserialize(value, attr, data, **kwargs)

Source Code in griptape/schemas/pydantic_model_field.py
def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> BaseModel:
    """Wrap the loaded ``value`` in a generic pydantic ``RootModel``."""
    wrapped = RootModel(value)
    return wrapped

_serialize(value, attr, obj, **kwargs)

Source Code in griptape/schemas/pydantic_model_field.py
def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
    """Return ``value.model_dump()``, or ``None`` when ``value`` is ``None``."""
    return None if value is None else value.model_dump()

Union

Bases: Field

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `fields` | `list[Field]` | The list of candidate fields to try. | required |
| `reverse_serialize_candidates` | `bool` | Whether to try the candidates in reverse order when serializing. | `False` |
Source Code in griptape/schemas/union_field.py
class Union(marshmallow.fields.Field):
    """Field that accepts any one of multiple fields.

    Source: https://github.com/adamboche/python-marshmallow-union
    Each argument will be tried until one succeeds.

    Args:
        fields: The list of candidate fields to try.
        reverse_serialize_candidates: Whether to try the candidates in reverse order when serializing.
    """

    def __init__(
        self,
        fields: list[marshmallow.fields.Field],
        *,
        reverse_serialize_candidates: bool = False,
        **kwargs: Any,
    ) -> None:
        self._candidate_fields = fields
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: Any) -> Any:
        """Serialize ``value`` with the first candidate field that does not raise.

        Args:
            value: The value to be serialized.
            attr: The attribute or key to get from the object.
            obj: The object to pull the key from.
            kwargs: Field-specific keyword arguments. An ``error_store`` entry, if present, is
                used to collect per-candidate failures.

        Returns:
            The output of the first candidate whose ``_serialize`` succeeds.

        Raises:
            ExceptionGroupError: If every candidate raised ``TypeError`` or ``ValueError``; the
                collected messages are passed along.
        """
        # Only build a fresh ErrorStore when the caller did not supply one (or supplied None).
        error_store = kwargs.pop("error_store", None)
        if error_store is None:
            error_store = marshmallow.error_store.ErrorStore()
        candidates = (
            reversed(self._candidate_fields) if self._reverse_serialize_candidates else self._candidate_fields
        )

        for candidate_field in candidates:
            try:
                # pylint: disable=protected-access
                return candidate_field._serialize(value, attr, obj, error_store=error_store, **kwargs)
            except (TypeError, ValueError) as e:
                error_store.store_error({attr: str(e)})

        raise ExceptionGroupError("All serializers raised exceptions.", error_store.errors)

    def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any:
        """Deserialize ``value`` with the first candidate field that accepts it.

        Candidates are tried in declaration order.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in `data` to be deserialized.
            data: The raw input data passed to the `Schema.load`.
            kwargs: Field-specific keyword arguments.

        Raises:
            ValidationError: If every candidate rejects ``value``; messages from each candidate
                are collected in order.
        """
        errors = []
        for candidate_field in self._candidate_fields:
            try:
                return candidate_field.deserialize(value, attr, data, **kwargs)
            except marshmallow.exceptions.ValidationError as exc:
                errors.append(exc.messages)

        raise marshmallow.exceptions.ValidationError(message=errors, field_name=attr or "")
  • _candidate_fields = fields instance-attribute

  • _reverse_serialize_candidates = reverse_serialize_candidates instance-attribute

`__init__(fields, *, reverse_serialize_candidates=False, **kwargs)`

Source Code in griptape/schemas/union_field.py
def __init__(
    self,
    fields: list[marshmallow.fields.Field],
    *,
    reverse_serialize_candidates: bool = False,
    **kwargs: Any,
) -> None:
    """Remember the candidate fields and their ordering flag, then run the base Field setup."""
    self._candidate_fields, self._reverse_serialize_candidates = fields, reverse_serialize_candidates
    super().__init__(**kwargs)

_deserialize(value, attr=None, data=None, **kwargs)

Deserialize value.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `Any` | The value to be deserialized. | required |
| `attr` | `str \| None` | The attribute/key in `data` to be deserialized. | `None` |
| `data` | `Any` | The raw input data passed to `Schema.load`. | `None` |
| `kwargs` | `Any` | Field-specific keyword arguments. | `{}` |

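Raises

| Type | Description |
|---|---|
| `ValidationError` | If an invalid value is passed or if a required value is missing. |
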
Source Code in griptape/schemas/union_field.py
def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any:
    """Deserialize ``value``.

    Args:
        value: The value to be deserialized.
        attr: The attribute/key in `data` to be deserialized.
        data: The raw input data passed to the `Schema.load`.
        kwargs: Field-specific keyword arguments.

    Raises:
        ValidationError: If an invalid value is passed or if a required value is missing.
    """
    errors = []
    for candidate_field in self._candidate_fields:
        try:
            return candidate_field.deserialize(value, attr, data, **kwargs)
        except marshmallow.exceptions.ValidationError as exc:
            errors.append(exc.messages)

    raise marshmallow.exceptions.ValidationError(message=errors, field_name=attr or "")

_serialize(value, attr, obj, **kwargs)

Pulls the value for the given key from the object, applies the field's formatting and returns the result.

Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `value` | `Any` | The value to be serialized. | required |
| `attr` | `str \| None` | The attribute or key to get from the object. | required |
| `obj` | `str` | The object to pull the key from. | required |
| `kwargs` | `Any` | Field-specific keyword arguments. | `{}` |

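Raises

| Type | Description |
|---|---|
| `ExceptionGroupError` | If every candidate field raised `TypeError` or `ValueError` while serializing. |
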
Source Code in griptape/schemas/union_field.py
def _serialize(self, value: Any, attr: str | None, obj: str, **kwargs: Any) -> Any:
    """Pulls the value for the given key from the object, applies the field's formatting and returns the result.

    Args:
        value: The value to be serialized.
        attr: The attribute or key to get from the object.
        obj: The object to pull the key from.
        kwargs: Field-specific keyword arguments.

    Raises:
        marshmallow.exceptions.ValidationError: In case of formatting problem
    """
    error_store = kwargs.pop("error_store", marshmallow.error_store.ErrorStore())
    fields = (
        list(reversed(self._candidate_fields)) if self._reverse_serialize_candidates else self._candidate_fields
    )

    for candidate_field in fields:
        try:
            # pylint: disable=protected-access
            return candidate_field._serialize(value, attr, obj, error_store=error_store, **kwargs)
        except (TypeError, ValueError) as e:
            error_store.store_error({attr: str(e)})

    raise ExceptionGroupError("All serializers raised exceptions.", error_store.errors)

Could this page be better? Report a problem or suggest an addition!