schemas
__all__ = ['BaseSchema', 'Bytes', 'PolymorphicSchema', 'PydanticModel', 'Union']
module-attribute
Bases:
Schema
Source Code in griptape/schemas/base_schema.py
class BaseSchema(Schema):
    """Base marshmallow schema that can generate sub-schemas from attrs classes."""

    class Meta:
        # marshmallow's INCLUDE policy: keys without a declared field are kept when loading.
        unknown = INCLUDE

    # Python type -> marshmallow field class, used as the fallback in _get_field_for_type.
    DATACLASS_TYPE_MAPPING = {
        **Schema.TYPE_MAPPING,
        dict: fields.Dict,
        bytes: Bytes,
        Any: fields.Raw,
        BaseModel: PydanticModel,
    }

    @classmethod
    def from_attrs_cls(
        cls,
        attrs_cls: type,
        *,
        types_overrides: Optional[dict[str, type]] = None,
        serializable_overrides: Optional[dict[str, bool]] = None,
    ) -> type:
        """Generate a Schema from an attrs class.

        A field is included only when ``serializable_overrides`` says so, or, failing that,
        when its ``serializable`` metadata is truthy (default False).

        Args:
            attrs_cls: An attrs class.
            types_overrides: A dictionary of types to override when resolving types.
            serializable_overrides: A dictionary of field names to whether they are serializable.

        Returns:
            A schema class named ``f"{attrs_cls.__name__}Schema"``; its ``post_load`` hook
            instantiates ``attrs_cls`` from the loaded data.
        """
        from marshmallow import post_load

        if serializable_overrides is None:
            serializable_overrides = {}

        class SubSchema(cls):
            @post_load
            def make_obj(self, data: Any, **kwargs) -> Any:
                # Map the serialized keys to their correct deserialization keys
                attrs_fields = attrs.fields_dict(attrs_cls)
                for key in list(data):
                    if key in attrs_fields:
                        attrs_field = attrs_fields[key]
                        if attrs_field.metadata.get("deserialization_key"):
                            data[attrs_field.metadata["deserialization_key"]] = data.pop(key)

                return attrs_cls(**data)

        cls._resolve_types(attrs_cls, types_override=types_overrides)

        # Named to avoid shadowing the marshmallow ``fields`` module.
        schema_fields = {}
        for attrs_field in attrs.fields(attrs_cls):
            field_key = attrs_field.alias or attrs_field.name
            if serializable_overrides.get(field_key, attrs_field.metadata.get("serializable", False)):
                schema_fields[field_key] = cls._get_field_for_type(
                    attrs_field.type,
                    serialization_key=attrs_field.metadata.get("serialization_key"),
                    types_overrides=types_overrides,
                )

        return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")

    @classmethod
    def _get_field_for_type(
        cls,
        field_type: type,
        serialization_key: Optional[str] = None,
        types_overrides: Optional[dict[str, type]] = None,
    ) -> fields.Field | fields.Nested:
        """Generate a marshmallow Field instance from a Python type.

        Args:
            field_type: A field type.
            serialization_key: The key to pull the data from before serializing.
            types_overrides: A dictionary of types to override when resolving types.

        Raises:
            ValueError: If ``field_type`` is a sequence type with no element type.
        """
        from griptape.schemas.polymorphic_schema import PolymorphicSchema

        field_class, args, optional = cls._get_field_type_info(field_type)

        if field_class is None:
            # A bare ``None`` annotation.
            return fields.Constant(None, allow_none=True)

        # Resolve TypeVars to their bound type
        if isinstance(field_class, TypeVar):
            field_class = field_class.__bound__
            if field_class is None:
                return fields.Raw(allow_none=optional, attribute=serialization_key)

        if cls._is_union(field_type):
            return cls._handle_union(
                field_type,
                optional=optional,
                serialization_key=serialization_key,
            )

        if attrs.has(field_class):
            # Classes listing ABC as a direct base use PolymorphicSchema, which picks the
            # concrete schema per object (dump) or per ``type`` value (load).
            if ABC in field_class.__bases__:
                schema = PolymorphicSchema(field_class, types_overrides=types_overrides)
            else:
                schema = cls.from_attrs_cls(field_class, types_overrides=types_overrides)
            return fields.Nested(schema, allow_none=optional, attribute=serialization_key)

        if cls._is_enum(field_type):
            return fields.String(allow_none=optional, attribute=serialization_key)

        if cls._is_list_sequence(field_class):
            if args:
                return cls._handle_list(
                    args[0],
                    optional=optional,
                    serialization_key=serialization_key,
                )
            raise ValueError(f"Missing type for list field: {field_type}")

        marshmallow_field_cls = cls.DATACLASS_TYPE_MAPPING.get(field_class, fields.Raw)
        return marshmallow_field_cls(allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_list(
        cls,
        list_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
    ) -> fields.Field:
        """Handle List Fields, including Union Types.

        Args:
            list_type: The element type of the list.
            optional: Whether the List can be none (also forwarded to a Union element field).
            serialization_key: The key to pull the data from before serializing.

        Returns:
            A marshmallow List field.
        """
        if cls._is_union(list_type):
            instance = cls._handle_union(
                list_type,
                optional=optional,
                serialization_key=serialization_key,
            )
        else:
            instance = cls._get_field_for_type(list_type, serialization_key=serialization_key)

        return fields.List(cls_or_instance=instance, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _handle_union(
        cls,
        union_type: type,
        *,
        optional: bool,
        serialization_key: Optional[str] = None,
    ) -> fields.Field:
        """Handle Union Fields, including Unions with List Types.

        The field accepts ``None`` when ``optional`` is set or when ``NoneType`` is a member.

        Args:
            union_type: The Union Type to handle.
            optional: Whether the Union can be None.
            serialization_key: The key to pull the data from before serializing.

        Returns:
            A marshmallow Union field.

        Raises:
            ValueError: If the union has no non-None members.
        """
        none_type = type(None)
        union_args = get_args(union_type)

        candidate_fields = [cls._get_field_for_type(arg) for arg in union_args if arg is not none_type]
        if not candidate_fields:
            raise ValueError(f"Unsupported UnionType field: {union_type}")

        # get_args reports None members as NoneType. The previous check built a list that
        # was truthy for every union, so every union was made nullable.
        if none_type in union_args:
            optional = True

        return UnionField(fields=candidate_fields, allow_none=optional, attribute=serialization_key)

    @classmethod
    def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
        """Get information about a field type.

        Args:
            field_type: A field type.

        Returns:
            ``(origin, args, optional)``. For a Union, the first non-None member is described
            and ``optional`` reports whether NoneType is any member. A Literal is reduced to
            the type of its first value.
        """
        origin = get_origin(field_type) or field_type
        args = get_args(field_type)
        optional = False

        if origin is Union:
            none_type = type(None)
            members = [arg for arg in args if arg is not none_type]
            # NoneType may appear in any position: Union[A, B, None], Union[None, A], ...
            optional = len(members) != len(args)
            origin, args, _ = cls._get_field_type_info(members[0])
        elif origin is Literal:
            origin = type(args[0])
            args = ()

        return origin, args, optional

    @classmethod
    def _resolve_types(cls, attrs_cls: type, types_override: Optional[dict[str, type]] = None) -> None:
        """Resolve types in an attrs class.

        Runs ``attrs.resolve_types`` with a ``localns`` mapping the names used in string
        annotations to their objects.

        Args:
            attrs_cls: An attrs class.
            types_override: A dictionary of types to override; merged last, so it wins on name clashes.
        """
        # NOTE(review): imports are local, presumably to avoid circular imports — confirm.
        from collections.abc import Sequence
        from typing import Any

        from pydantic import BaseModel
        from schema import Schema

        from griptape.artifacts import (
            ActionArtifact,
            AudioArtifact,
            BaseArtifact,
            BlobArtifact,
            BooleanArtifact,
            ErrorArtifact,
            GenericArtifact,
            ImageArtifact,
            ImageUrlArtifact,
            InfoArtifact,
            JsonArtifact,
            ListArtifact,
            TextArtifact,
        )
        from griptape.common import (
            BaseDeltaMessageContent,
            BaseMessageContent,
            Message,
            PromptStack,
            Reference,
            ToolAction,
        )
        from griptape.drivers.assistant import BaseAssistantDriver
        from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
        from griptape.drivers.embedding import BaseEmbeddingDriver
        from griptape.drivers.file_manager import BaseFileManagerDriver
        from griptape.drivers.image_generation import BaseImageGenerationDriver, BaseMultiModelImageGenerationDriver
        from griptape.drivers.image_generation_model import BaseImageGenerationModelDriver
        from griptape.drivers.memory.conversation import BaseConversationMemoryDriver
        from griptape.drivers.observability import BaseObservabilityDriver
        from griptape.drivers.prompt import BasePromptDriver
        from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
        from griptape.drivers.ruleset import BaseRulesetDriver
        from griptape.drivers.sql import BaseSqlDriver
        from griptape.drivers.structure_run import BaseStructureRunDriver
        from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
        from griptape.drivers.vector import BaseVectorStoreDriver
        from griptape.drivers.web_scraper import BaseWebScraperDriver
        from griptape.drivers.web_search import BaseWebSearchDriver
        from griptape.engines.rag import RagContext
        from griptape.events import EventListener
        from griptape.memory import TaskMemory
        from griptape.memory.meta import BaseMetaEntry
        from griptape.memory.structure import BaseConversationMemory, Run
        from griptape.memory.task.storage import BaseArtifactStorage
        from griptape.rules.base_rule import BaseRule
        from griptape.rules.ruleset import Ruleset
        from griptape.structures import Structure
        from griptape.tasks import BaseTask
        from griptape.tokenizers import BaseTokenizer
        from griptape.tools import BaseTool
        from griptape.utils import import_optional_dependency, is_dependency_installed

        if types_override is None:
            types_override = {}

        attrs.resolve_types(
            attrs_cls,
            localns={
                "Any": Any,
                "BasePromptDriver": BasePromptDriver,
                "BaseEmbeddingDriver": BaseEmbeddingDriver,
                "BaseVectorStoreDriver": BaseVectorStoreDriver,
                "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
                "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
                "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
                "BaseRulesetDriver": BaseRulesetDriver,
                "BaseImageGenerationDriver": BaseImageGenerationDriver,
                "BaseMultiModelImageGenerationDriver": BaseMultiModelImageGenerationDriver,
                "BaseImageGenerationModelDriver": BaseImageGenerationModelDriver,
                "BaseWebSearchDriver": BaseWebSearchDriver,
                "BaseWebScraperDriver": BaseWebScraperDriver,
                "BaseFileManagerDriver": BaseFileManagerDriver,
                "BaseSqlDriver": BaseSqlDriver,
                "BaseObservabilityDriver": BaseObservabilityDriver,
                "BaseAssistantDriver": BaseAssistantDriver,
                "BaseStructureRunDriver": BaseStructureRunDriver,
                "BaseArtifact": BaseArtifact,
                "BaseMetaEntry": BaseMetaEntry,
                "PromptStack": PromptStack,
                "EventListener": EventListener,
                "BaseMessageContent": BaseMessageContent,
                "BaseDeltaMessageContent": BaseDeltaMessageContent,
                "BaseTool": BaseTool,
                "BaseTask": BaseTask,
                "TextArtifact": TextArtifact,
                "ImageArtifact": ImageArtifact,
                "ImageUrlArtifact": ImageUrlArtifact,
                "ErrorArtifact": ErrorArtifact,
                "InfoArtifact": InfoArtifact,
                "JsonArtifact": JsonArtifact,
                "BlobArtifact": BlobArtifact,
                "BooleanArtifact": BooleanArtifact,
                "ListArtifact": ListArtifact,
                "AudioArtifact": AudioArtifact,
                "ActionArtifact": ActionArtifact,
                "GenericArtifact": GenericArtifact,
                "Usage": Message.Usage,
                "Structure": Structure,
                "BaseTokenizer": BaseTokenizer,
                "ToolAction": ToolAction,
                "Reference": Reference,
                "Run": Run,
                "Sequence": Sequence,
                "TaskMemory": TaskMemory,
                "State": BaseTask.State,
                "BaseConversationMemory": BaseConversationMemory,
                "BaseArtifactStorage": BaseArtifactStorage,
                "BaseRule": BaseRule,
                "Ruleset": Ruleset,
                "StructuredOutputStrategy": StructuredOutputStrategy,
                "RagContext": RagContext,
                # Third party modules (fall back to Any when the package is not installed)
                "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
                "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any,
                "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel if is_dependency_installed("google.generativeai") else Any,
                "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
                "Anthropic": import_optional_dependency("anthropic").Anthropic if is_dependency_installed("anthropic") else Any,
                "BedrockRuntimeClient": import_optional_dependency("mypy_boto3_bedrock_runtime").BedrockRuntimeClient if is_dependency_installed("mypy_boto3_bedrock_runtime") else Any,
                "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
                "Schema": Schema,
                "BaseModel": BaseModel,
                **types_override,
            },
        )

    @classmethod
    def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool:
        """Return True for Sequence classes other than str, bytes and tuple."""
        if isinstance(field_type, type):
            if issubclass(field_type, str) or issubclass(field_type, bytes) or issubclass(field_type, tuple):
                return False

            return issubclass(field_type, Sequence)
        return False

    @classmethod
    def _is_union(cls, field_type: type) -> bool:
        """Return True for the bare ``Union`` form or any ``Union[...]``/``Optional[...]``."""
        return field_type is Union or get_origin(field_type) is Union

    @classmethod
    def _is_enum(cls, field_type: type) -> bool:
        """Return True when ``field_type`` is an Enum subclass."""
        return isinstance(field_type, type) and issubclass(field_type, Enum)
DATACLASS_TYPE_MAPPING = {**Schema.TYPE_MAPPING, dict: fields.Dict, bytes: Bytes, Any: fields.Raw, BaseModel: PydanticModel}
class-attribute instance-attribute
Meta
Source Code in griptape/schemas/base_schema.py
class Meta:
    # marshmallow schema options. INCLUDE keeps input keys that have no declared
    # field in the loaded data instead of rejecting them.
    unknown = INCLUDE
unknown = INCLUDE
class-attribute instance-attribute
_get_field_for_type(field_type, serialization_key=None, types_overrides=None) classmethod
Generate a marshmallow Field instance from a Python type.
Parameters
Name | Type | Description | Default |
---|---|---|---|
field_type | type | A field type. | required |
serialization_key | Optional[str] | The key to pull the data from before serializing. | None |
types_overrides | Optional[dict[str, type]] | A dictionary of types to override when resolving types. | None |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _get_field_for_type(
    cls,
    field_type: type,
    serialization_key: Optional[str] = None,
    types_overrides: Optional[dict[str, type]] = None,
) -> fields.Field | fields.Nested:
    """Build the marshmallow field that (de)serializes values annotated as ``field_type``.

    Checks are applied in order:
    - a bare ``None`` annotation;
    - TypeVars (via their bound);
    - unions;
    - attrs classes (nested schemas);
    - enums;
    - list-like sequences;
    - finally the ``DATACLASS_TYPE_MAPPING`` lookup, which falls back to ``fields.Raw``.

    Args:
        field_type: A field type.
        serialization_key: The key to pull the data from before serializing.
        types_overrides: A dictionary of types to override when resolving types.

    Raises:
        ValueError: If ``field_type`` is a sequence type without an element type.
    """
    from griptape.schemas.polymorphic_schema import PolymorphicSchema

    resolved, type_args, nullable = cls._get_field_type_info(field_type)
    if resolved is None:
        return fields.Constant(None, allow_none=True)

    # Keyword arguments shared by every plain field built below.
    field_kwargs = {"allow_none": nullable, "attribute": serialization_key}

    # A TypeVar stands in for its bound; an unbound TypeVar accepts raw values.
    if isinstance(resolved, TypeVar):
        resolved = resolved.__bound__
        if resolved is None:
            return fields.Raw(**field_kwargs)

    if cls._is_union(field_type):
        return cls._handle_union(field_type, optional=nullable, serialization_key=serialization_key)

    if attrs.has(resolved):
        is_abstract_base = ABC in resolved.__bases__
        nested_schema = (
            PolymorphicSchema(resolved, types_overrides=types_overrides)
            if is_abstract_base
            else cls.from_attrs_cls(resolved, types_overrides=types_overrides)
        )
        return fields.Nested(nested_schema, **field_kwargs)

    if cls._is_enum(field_type):
        return fields.String(**field_kwargs)

    if cls._is_list_sequence(resolved):
        if not type_args:
            raise ValueError(f"Missing type for list field: {field_type}")
        return cls._handle_list(type_args[0], optional=nullable, serialization_key=serialization_key)

    mapped_field_cls = cls.DATACLASS_TYPE_MAPPING.get(resolved, fields.Raw)
    return mapped_field_cls(**field_kwargs)
_get_field_type_info(field_type) classmethod
Get information about a field type.
Parameters
Name | Type | Description | Default |
---|---|---|---|
field_type | type | A field type. | required |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _get_field_type_info(cls, field_type: type) -> tuple[type, tuple[type, ...], bool]:
    """Get information about a field type.

    Args:
        field_type: A field type.

    Returns:
        ``(origin, args, optional)``.
        - For a Union, the first non-None member is described, and ``optional`` reports
          whether NoneType is any member of the union.
        - A Literal is reduced to the type of its first value, with no args.
        - Anything else yields ``get_origin(field_type) or field_type`` and its ``get_args``.
    """
    origin = get_origin(field_type) or field_type
    args = get_args(field_type)
    optional = False

    if origin is Union:
        none_type = type(None)
        members = [arg for arg in args if arg is not none_type]
        # NoneType can sit anywhere in the member list (Union[A, B, None], Union[None, A]);
        # checking only the second position missed those cases.
        optional = len(members) != len(args)
        origin, args, _ = cls._get_field_type_info(members[0])
    elif origin is Literal:
        origin = type(args[0])
        args = ()

    return origin, args, optional
_handle_list(list_type, *, optional, serialization_key=None) classmethod
Handle List Fields, including Union Types.
Parameters
Name | Type | Description | Default |
---|---|---|---|
list_type | type | The List type to handle. | required |
optional | bool | Whether the List can be none. | required |
serialization_key | Optional[str] | The key to pull the data from before serializing. | None |
Returns
Type | Description |
---|---|
Field | A marshmallow List field. |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _handle_list(
    cls,
    list_type: type,
    *,
    optional: bool,
    serialization_key: Optional[str] = None,
) -> fields.Field:
    """Wrap the field for a list's element type in a marshmallow ``List`` field.

    Args:
        list_type: The element type of the list (``_get_field_for_type`` passes ``args[0]``).
        optional: Whether the List can be None; also forwarded when the element is a Union.
        serialization_key: The key to pull the data from before serializing.

    Returns:
        A marshmallow List field.
    """
    element_field = (
        cls._handle_union(list_type, optional=optional, serialization_key=serialization_key)
        if cls._is_union(list_type)
        else cls._get_field_for_type(list_type, serialization_key=serialization_key)
    )
    return fields.List(cls_or_instance=element_field, allow_none=optional, attribute=serialization_key)
_handle_union(union_type, *, optional, serialization_key=None) classmethod
Handle Union Fields, including Unions with List Types.
Parameters
Name | Type | Description | Default |
---|---|---|---|
union_type | type | The Union Type to handle. | required |
optional | bool | Whether the Union can be None. | required |
serialization_key | Optional[str] | The key to pull the data from before serializing. | None |
Returns
Type | Description |
---|---|
Field | A marshmallow Union field. |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _handle_union(
    cls,
    union_type: type,
    *,
    optional: bool,
    serialization_key: Optional[str] = None,
) -> fields.Field:
    """Handle Union Fields, including Unions with List Types.

    Each non-None member of the union gets its own candidate field. The result accepts
    ``None`` when ``optional`` is set by the caller or when ``NoneType`` is a member.

    Args:
        union_type: The Union Type to handle.
        optional: Whether the Union can be None.
        serialization_key: The key to pull the data from before serializing.

    Returns:
        A marshmallow Union field.

    Raises:
        ValueError: If the union has no non-None members.
    """
    none_type = type(None)
    union_args = get_args(union_type)

    candidate_fields = [cls._get_field_for_type(arg) for arg in union_args if arg is not none_type]
    if not candidate_fields:
        raise ValueError(f"Unsupported UnionType field: {union_type}")

    # get_args reports None members as NoneType, never as None. The previous check
    # (`[arg is None for ...]` tested for truthiness) was true for every union, so every
    # union was made nullable.
    if none_type in union_args:
        optional = True

    return UnionField(fields=candidate_fields, allow_none=optional, attribute=serialization_key)
_is_enum(field_type) classmethod
Source Code in griptape/schemas/base_schema.py
@classmethod
def _is_enum(cls, field_type: type) -> bool:
    """Return True when ``field_type`` is an Enum class (enum members and non-classes are False)."""
    if not isinstance(field_type, type):
        return False
    return issubclass(field_type, Enum)
_is_list_sequence(field_type) classmethod
Source Code in griptape/schemas/base_schema.py
@classmethod def _is_list_sequence(cls, field_type: type | _SpecialForm) -> bool: if isinstance(field_type, type): if issubclass(field_type, str) or issubclass(field_type, bytes) or issubclass(field_type, tuple): return False return issubclass(field_type, Sequence) return False
_is_union(field_type) classmethod
Source Code in griptape/schemas/base_schema.py
@classmethod
def _is_union(cls, field_type: type) -> bool:
    """Return True for the bare ``Union`` form or any parametrized ``Union[...]``/``Optional[...]``."""
    if field_type is Union:
        return True
    return get_origin(field_type) is Union
_resolve_types(attrs_cls, types_override=None) classmethod
Resolve types in an attrs class.
Parameters
Name | Type | Description | Default |
---|---|---|---|
attrs_cls | type | An attrs class. | required |
types_override | Optional[dict[str, type]] | A dictionary of types to override. | None |
Source Code in griptape/schemas/base_schema.py
@classmethod
def _resolve_types(cls, attrs_cls: type, types_override: Optional[dict[str, type]] = None) -> None:
    """Resolve types in an attrs class.

    Calls ``attrs.resolve_types`` on ``attrs_cls``, passing a ``localns`` that maps the
    names which may appear in string annotations to the actual objects.

    Args:
        attrs_cls: An attrs class.
        types_override: A dictionary of types to override. It is unpacked last into
            ``localns``, so its entries win over the built-in names on a clash.
    """
    # NOTE(review): these imports are function-local, presumably to avoid circular imports
    # between griptape.schemas and the modules below — confirm before hoisting them.
    from collections.abc import Sequence
    from typing import Any

    from pydantic import BaseModel
    # This Schema comes from the third-party `schema` package and is exposed to annotations
    # as "Schema" below; it is a different class from the marshmallow Schema.
    from schema import Schema

    from griptape.artifacts import (
        ActionArtifact,
        AudioArtifact,
        BaseArtifact,
        BlobArtifact,
        BooleanArtifact,
        ErrorArtifact,
        GenericArtifact,
        ImageArtifact,
        ImageUrlArtifact,
        InfoArtifact,
        JsonArtifact,
        ListArtifact,
        TextArtifact,
    )
    from griptape.common import (
        BaseDeltaMessageContent,
        BaseMessageContent,
        Message,
        PromptStack,
        Reference,
        ToolAction,
    )
    from griptape.drivers.assistant import BaseAssistantDriver
    from griptape.drivers.audio_transcription import BaseAudioTranscriptionDriver
    from griptape.drivers.embedding import BaseEmbeddingDriver
    from griptape.drivers.file_manager import BaseFileManagerDriver
    from griptape.drivers.image_generation import BaseImageGenerationDriver, BaseMultiModelImageGenerationDriver
    from griptape.drivers.image_generation_model import BaseImageGenerationModelDriver
    from griptape.drivers.memory.conversation import BaseConversationMemoryDriver
    from griptape.drivers.observability import BaseObservabilityDriver
    from griptape.drivers.prompt import BasePromptDriver
    from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy
    from griptape.drivers.ruleset import BaseRulesetDriver
    from griptape.drivers.sql import BaseSqlDriver
    from griptape.drivers.structure_run import BaseStructureRunDriver
    from griptape.drivers.text_to_speech import BaseTextToSpeechDriver
    from griptape.drivers.vector import BaseVectorStoreDriver
    from griptape.drivers.web_scraper import BaseWebScraperDriver
    from griptape.drivers.web_search import BaseWebSearchDriver
    from griptape.engines.rag import RagContext
    from griptape.events import EventListener
    from griptape.memory import TaskMemory
    from griptape.memory.meta import BaseMetaEntry
    from griptape.memory.structure import BaseConversationMemory, Run
    from griptape.memory.task.storage import BaseArtifactStorage
    from griptape.rules.base_rule import BaseRule
    from griptape.rules.ruleset import Ruleset
    from griptape.structures import Structure
    from griptape.tasks import BaseTask
    from griptape.tokenizers import BaseTokenizer
    from griptape.tools import BaseTool
    from griptape.utils import import_optional_dependency, is_dependency_installed

    if types_override is None:
        types_override = {}

    attrs.resolve_types(
        attrs_cls,
        # Annotation name -> object used to evaluate string annotations on attrs_cls.
        localns={
            "Any": Any,
            "BasePromptDriver": BasePromptDriver,
            "BaseEmbeddingDriver": BaseEmbeddingDriver,
            "BaseVectorStoreDriver": BaseVectorStoreDriver,
            "BaseTextToSpeechDriver": BaseTextToSpeechDriver,
            "BaseAudioTranscriptionDriver": BaseAudioTranscriptionDriver,
            "BaseConversationMemoryDriver": BaseConversationMemoryDriver,
            "BaseRulesetDriver": BaseRulesetDriver,
            "BaseImageGenerationDriver": BaseImageGenerationDriver,
            "BaseMultiModelImageGenerationDriver": BaseMultiModelImageGenerationDriver,
            "BaseImageGenerationModelDriver": BaseImageGenerationModelDriver,
            "BaseWebSearchDriver": BaseWebSearchDriver,
            "BaseWebScraperDriver": BaseWebScraperDriver,
            "BaseFileManagerDriver": BaseFileManagerDriver,
            "BaseSqlDriver": BaseSqlDriver,
            "BaseObservabilityDriver": BaseObservabilityDriver,
            "BaseAssistantDriver": BaseAssistantDriver,
            "BaseStructureRunDriver": BaseStructureRunDriver,
            "BaseArtifact": BaseArtifact,
            "BaseMetaEntry": BaseMetaEntry,
            "PromptStack": PromptStack,
            "EventListener": EventListener,
            "BaseMessageContent": BaseMessageContent,
            "BaseDeltaMessageContent": BaseDeltaMessageContent,
            "BaseTool": BaseTool,
            "BaseTask": BaseTask,
            "TextArtifact": TextArtifact,
            "ImageArtifact": ImageArtifact,
            "ImageUrlArtifact": ImageUrlArtifact,
            "ErrorArtifact": ErrorArtifact,
            "InfoArtifact": InfoArtifact,
            "JsonArtifact": JsonArtifact,
            "BlobArtifact": BlobArtifact,
            "BooleanArtifact": BooleanArtifact,
            "ListArtifact": ListArtifact,
            "AudioArtifact": AudioArtifact,
            "ActionArtifact": ActionArtifact,
            "GenericArtifact": GenericArtifact,
            "Usage": Message.Usage,
            "Structure": Structure,
            "BaseTokenizer": BaseTokenizer,
            "ToolAction": ToolAction,
            "Reference": Reference,
            "Run": Run,
            "Sequence": Sequence,
            "TaskMemory": TaskMemory,
            "State": BaseTask.State,
            "BaseConversationMemory": BaseConversationMemory,
            "BaseArtifactStorage": BaseArtifactStorage,
            "BaseRule": BaseRule,
            "Ruleset": Ruleset,
            "StructuredOutputStrategy": StructuredOutputStrategy,
            "RagContext": RagContext,
            # Third party modules: each resolves to Any when its package is not installed.
            "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
            "ClientV2": import_optional_dependency("cohere").ClientV2 if is_dependency_installed("cohere") else Any,
            "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel if is_dependency_installed("google.generativeai") else Any,
            "boto3": import_optional_dependency("boto3") if is_dependency_installed("boto3") else Any,
            "Anthropic": import_optional_dependency("anthropic").Anthropic if is_dependency_installed("anthropic") else Any,
            "BedrockRuntimeClient": import_optional_dependency("mypy_boto3_bedrock_runtime").BedrockRuntimeClient if is_dependency_installed("mypy_boto3_bedrock_runtime") else Any,
            "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
            "Schema": Schema,
            "BaseModel": BaseModel,
            # Unpacked last so caller-supplied overrides replace any entry above.
            **types_override,
        },
    )
from_attrs_cls(attrs_cls, *, types_overrides=None, serializable_overrides=None) classmethod
Generate a Schema from an attrs class.
Parameters
Name | Type | Description | Default |
---|---|---|---|
attrs_cls | type | An attrs class. | required |
types_overrides | Optional[dict[str, type]] | A dictionary of types to override when resolving types. | None |
serializable_overrides | Optional[dict[str, bool]] | A dictionary of field names to whether they are serializable. | None |
Source Code in griptape/schemas/base_schema.py
@classmethod
def from_attrs_cls(
    cls,
    attrs_cls: type,
    *,
    types_overrides: Optional[dict[str, type]] = None,
    serializable_overrides: Optional[dict[str, bool]] = None,
) -> type:
    """Build a schema class for ``attrs_cls`` from its serializable fields.

    A field is kept when ``serializable_overrides`` maps its key (alias or name) to a truthy
    value. Without an override, its ``serializable`` metadata decides (default False).

    Args:
        attrs_cls: An attrs class.
        types_overrides: A dictionary of types to override when resolving types.
        serializable_overrides: A dictionary of field names to whether they are serializable.

    Returns:
        A schema class named ``f"{attrs_cls.__name__}Schema"``; after loading it builds an
        ``attrs_cls`` instance from the loaded data.
    """
    from marshmallow import post_load

    overrides = {} if serializable_overrides is None else serializable_overrides

    class SubSchema(cls):
        @post_load
        def make_obj(self, data: Any, **kwargs) -> Any:
            # Rename loaded keys to the field's "deserialization_key" before constructing.
            attribute_map = attrs.fields_dict(attrs_cls)
            for loaded_key in list(data):
                attribute = attribute_map.get(loaded_key)
                if attribute is None:
                    continue
                target_key = attribute.metadata.get("deserialization_key")
                if target_key:
                    data[target_key] = data.pop(loaded_key)
            return attrs_cls(**data)

    cls._resolve_types(attrs_cls, types_override=types_overrides)

    schema_fields: dict = {}
    for attribute in attrs.fields(attrs_cls):
        public_name = attribute.alias or attribute.name
        default_flag = attribute.metadata.get("serializable", False)
        if not overrides.get(public_name, default_flag):
            continue
        schema_fields[public_name] = cls._get_field_for_type(
            attribute.type,
            serialization_key=attribute.metadata.get("serialization_key"),
            types_overrides=types_overrides,
        )

    return SubSchema.from_dict(schema_fields, name=f"{attrs_cls.__name__}Schema")
Bytes
Bases:
Field
Source Code in griptape/schemas/bytes_field.py
class Bytes(fields.Field):
    """marshmallow field that serializes ``bytes`` as a base64-encoded string."""

    def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> "str | None":
        # Like marshmallow's built-in fields, pass None through instead of letting
        # base64 raise TypeError on it.
        if value is None:
            return None
        return base64.b64encode(value).decode()

    def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
        try:
            return base64.b64decode(value)
        except (TypeError, ValueError) as err:
            # Malformed base64 raises binascii.Error (a ValueError subclass) and non-string
            # input raises TypeError. Re-raise as a field validation error.
            raise ValidationError("Invalid base64 data.") from err

    def _validate(self, value: Any) -> None:
        if not isinstance(value, bytes):
            raise ValidationError("Invalid input type.")
        # Keep running any validators passed via ``validate=``, which the base
        # Field._validate applies and a plain override would silently skip.
        super()._validate(value)
_deserialize(value, attr, data, **kwargs)
Source Code in griptape/schemas/bytes_field.py
def _deserialize(self, value: Any, attr: Any, data: Any, **kwargs) -> bytes:
    """Decode a base64 string (or ASCII bytes) back into ``bytes``.

    Raises:
        ValidationError: If ``value`` is not valid base64 input.
    """
    try:
        return base64.b64decode(value)
    except (TypeError, ValueError) as err:
        # Malformed base64 raises binascii.Error (a ValueError subclass) and non-string
        # input raises TypeError. Re-raise as a field validation error.
        raise ValidationError("Invalid base64 data.") from err
_serialize(value, attr, obj, **kwargs)
Source Code in griptape/schemas/bytes_field.py
def _serialize(self, value: Any, attr: Any, obj: Any, **kwargs) -> "str | None":
    """Encode ``bytes`` as a base64 string; ``None`` is passed through unchanged."""
    # Like marshmallow's built-in fields, pass None through instead of letting
    # base64 raise TypeError on it.
    if value is None:
        return None
    return base64.b64encode(value).decode()
_validate(value)
Source Code in griptape/schemas/bytes_field.py
def _validate(self, value: Any) -> None:
    """Reject non-``bytes`` values, then run any validators configured on the field.

    Raises:
        ValidationError: If ``value`` is not ``bytes`` or a configured validator fails.
    """
    if not isinstance(value, bytes):
        raise ValidationError("Invalid input type.")
    # Keep running any validators passed via ``validate=``, which the base
    # Field._validate applies and a plain override would silently skip.
    super()._validate(value)
PolymorphicSchema
Bases:
BaseSchema
Source Code in griptape/schemas/polymorphic_schema.py
class PolymorphicSchema(BaseSchema):
    """PolymorphicSchema is based on https://github.com/marshmallow-code/marshmallow-oneofschema.

    On dump, each object is serialized with the schema generated for its concrete class,
    and the class name is stored under ``type_field``. On load, the ``type_field`` value
    selects the schema via ``inner_class.get_schema``.
    """

    def __init__(self, inner_class: Any, types_overrides: Optional[dict[str, type]] = None, **kwargs) -> None:
        super().__init__(**kwargs)
        self.inner_class = inner_class
        self.types_overrides = types_overrides

    type_field = "type"
    type_field_remove = True

    def get_obj_type(self, obj: Any) -> Any:
        """Returns name of the schema during dump() calls, given the object being dumped."""
        return obj.__class__.__name__

    def get_data_type(self, data: Any) -> Any:
        """Returns name of the schema during load() calls, given the data being loaded. Defaults to looking up `type_field` in the data."""
        data_type = data.get(self.type_field)
        if self.type_field in data and self.type_field_remove:
            data.pop(self.type_field)
        return data_type

    def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
        """Dump one object, or each item of ``obj`` when ``many`` is true.

        Raises:
            ValidationError: With per-index messages if any item fails in ``many`` mode.
        """
        result_errors = {}
        many = self.many if many is None else bool(many)
        if not many:
            result = self._dump(obj, **kwargs)
        else:
            result = []
            for idx, item in enumerate(obj):
                try:
                    result.append(self._dump(item, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()
                    result.append(error.valid_data)

        if not result_errors:
            return result
        raise ValidationError(result_errors, data=obj, valid_data=result)  # pyright: ignore[reportArgumentType]

    def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
        # ``update_fields`` is not used by this implementation.
        obj_type = self.get_obj_type(obj)
        if not obj_type:
            # Raise rather than returning an (None, errors) tuple, which would have been
            # emitted as the dumped data.
            raise ValidationError({"_schema": f"Unknown object class: {obj.__class__.__name__}"})
        type_schema = BaseSchema.from_attrs_cls(obj.__class__, types_overrides=self.types_overrides)
        if not type_schema:
            raise ValidationError({"_schema": f"Unsupported object type: {obj_type}"})
        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        result = schema.dump(obj, many=False, **kwargs)
        if result is not None:
            result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]
        return result

    def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        """Load one payload, or each item of ``data`` when ``many`` is true.

        Raises:
            ValidationError: With the collected messages if any payload fails.
        """
        result_errors = {}
        many = self.many if many is None else bool(many)
        if partial is None:
            partial = self.partial

        if not many:
            try:
                result = self._load(data, partial=partial, unknown=unknown, **kwargs)
            except ValidationError as error:
                result_errors = error.normalized_messages()
                # A single payload's valid data is reported as-is, not wrapped in a list.
                result = error.valid_data
        else:
            result = []
            for idx, item in enumerate(data):
                try:
                    # Forward ``unknown`` so the policy applies to every item, as in the single-item path.
                    result.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
                except ValidationError as error:
                    result_errors[idx] = error.normalized_messages()
                    result.append(error.valid_data)

        if not result_errors:
            return result
        raise ValidationError(result_errors, data=data, valid_data=result)

    def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
        if not isinstance(data, dict):
            raise ValidationError({"_schema": f"Invalid data type: {data}"})

        # Copy so that removing the type key does not mutate the caller's dict.
        data = dict(data)
        unknown = unknown or self.unknown
        data_type = self.get_data_type(data)

        if data_type is None:
            raise ValidationError({self.type_field: ["Missing data for required field."]})

        type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name"))
        if not type_schema:
            raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

        schema = type_schema if isinstance(type_schema, Schema) else type_schema()

        schema.context.update(getattr(self, "context", {}))

        return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs)

    def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
        """Return validation error messages for ``data`` (empty dict when valid)."""
        try:
            self.load(data, many=many, partial=partial)
        except ValidationError as ve:
            return ve.messages
        return {}
inner_class = inner_class
instance-attribute
type_field = 'type'
class-attribute instance-attribute
type_field_remove = True
class-attribute instance-attribute
types_overrides = types_overrides
instance-attribute
__init__(inner_class, types_overrides=None, **kwargs)
Source Code in griptape/schemas/polymorphic_schema.py
def __init__(self, inner_class: Any, types_overrides: Optional[dict[str, type]] = None, **kwargs) -> None:
    """Create a polymorphic schema.

    Args:
        inner_class: Class whose ``get_schema`` is used to pick a schema on load.
        types_overrides: Forwarded to ``BaseSchema.from_attrs_cls`` when dumping.
        **kwargs: Passed through to the base schema constructor.
    """
    super().__init__(**kwargs)
    self.types_overrides = types_overrides
    self.inner_class = inner_class
_dump(obj, *, update_fields=True, **kwargs)
Source Code in griptape/schemas/polymorphic_schema.py
def _dump(self, obj: Any, *, update_fields: bool = True, **kwargs) -> Any:
    """Dump ``obj`` with the schema generated for its class and tag it with ``type_field``.

    ``update_fields`` is not used by this implementation.

    Raises:
        ValidationError: If no type name or schema can be determined for ``obj``.
    """
    obj_type = self.get_obj_type(obj)
    if not obj_type:
        # Raise rather than returning an (None, errors) tuple, which would have been
        # emitted as the dumped data.
        raise ValidationError({"_schema": f"Unknown object class: {obj.__class__.__name__}"})
    type_schema = BaseSchema.from_attrs_cls(obj.__class__, types_overrides=self.types_overrides)
    if not type_schema:
        raise ValidationError({"_schema": f"Unsupported object type: {obj_type}"})
    schema = type_schema if isinstance(type_schema, Schema) else type_schema()

    schema.context.update(getattr(self, "context", {}))

    result = schema.dump(obj, many=False, **kwargs)
    if result is not None:
        result[self.type_field] = obj_type  # pyright: ignore[reportArgumentType,reportCallIssue]
    return result
_load(data, *, partial=None, unknown=None, **kwargs)
Source Code in griptape/schemas/polymorphic_schema.py
def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Load one mapping using the schema chosen from its type field.

    Args:
        data: The raw mapping to load; anything that is not a ``dict`` is rejected.
        partial: Forwarded to the selected schema's ``load``.
        unknown: Unknown-field policy; falls back to ``self.unknown`` when falsy.
        **kwargs: Forwarded to the selected schema's ``load``.

    Returns:
        Whatever the selected schema's ``load`` returns.

    Raises:
        ValidationError: If ``data`` is not a dict, has no type value, or names a
            type for which ``inner_class.get_schema`` returns nothing.
    """
    if not isinstance(data, dict):
        raise ValidationError({"_schema": f"Invalid data type: {data}"})

    # Work on a shallow copy: get_data_type may pop the type field.
    payload = dict(data)
    effective_unknown = unknown or self.unknown

    data_type = self.get_data_type(payload)
    if data_type is None:
        raise ValidationError({self.type_field: ["Missing data for required field."]})

    found = self.inner_class.get_schema(data_type, module_name=payload.get("module_name"))
    if not found:
        raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

    # get_schema may hand back either an instance or a class.
    if isinstance(found, Schema):
        schema = found
    else:
        schema = found()
    schema.context.update(getattr(self, "context", {}))
    return schema.load(payload, many=False, partial=partial, unknown=effective_unknown, **kwargs)
dump(obj, *, many=None, **kwargs)
Source Code in griptape/schemas/polymorphic_schema.py
def dump(self, obj: Any, *, many: Any = None, **kwargs) -> Any:
    """Serialize ``obj``, or each element of it when ``many`` is true, via ``_dump``.

    Args:
        obj: A single object, or an iterable of objects when ``many`` is true.
        many: Whether ``obj`` is a collection. Defaults to ``self.many``.
        **kwargs: Forwarded to ``_dump``.

    Returns:
        The serialized object, or a list of serialized objects.

    Raises:
        ValidationError: When ``many`` is true and any element fails; messages
            are keyed by element index.
    """
    is_many = self.many if many is None else bool(many)
    if not is_many:
        return self._dump(obj, **kwargs)

    dumped = []
    failures = {}
    for position, element in enumerate(obj):
        try:
            dumped.append(self._dump(element, **kwargs))
        except ValidationError as err:
            failures[position] = err.normalized_messages()
            dumped.append(err.valid_data)

    if failures:
        raise ValidationError(failures, data=obj, valid_data=dumped)  # pyright: ignore[reportArgumentType]
    return dumped
get_data_type(data)
Source Code in griptape/schemas/polymorphic_schema.py
def get_data_type(self, data: Any) -> Any:
    """Return the schema name to use when loading ``data``.

    Reads ``data[self.type_field]``, giving ``None`` when the key is absent. When
    ``self.type_field_remove`` is true, the key is also popped from ``data`` in
    place, so it is not forwarded to the selected schema.
    """
    if self.type_field_remove and self.type_field in data:
        return data.pop(self.type_field)
    return data.get(self.type_field)
get_obj_type(obj)
Source Code in griptape/schemas/polymorphic_schema.py
def get_obj_type(self, obj: Any) -> Any:
    """Return the name written to the type field when ``obj`` is dumped.

    This is simply the name of the object's class.
    """
    obj_cls = obj.__class__
    return obj_cls.__name__
load(data, *, many=None, partial=None, unknown=None, **kwargs)
Source Code in griptape/schemas/polymorphic_schema.py
def load(self, data: Any, *, many: Any = None, partial: Any = None, unknown: Any = None, **kwargs) -> Any:
    """Deserialize ``data`` by dispatching each item through ``_load``.

    Args:
        data: A single mapping, or an iterable of mappings when ``many`` is true.
        many: Whether ``data`` is a collection. Defaults to ``self.many``.
        partial: Forwarded to ``_load``. Defaults to ``self.partial``.
        unknown: Unknown-field policy forwarded to ``_load`` for every item, in
            both the single-item and the ``many`` path.
        **kwargs: Forwarded to ``_load``.

    Returns:
        The loaded object, or a list of loaded objects when ``many`` is true.

    Raises:
        ValidationError: If loading fails. For ``many=True`` the messages are keyed
            by item index and ``valid_data`` holds one entry per item.
    """
    many = self.many if many is None else bool(many)
    if partial is None:
        partial = self.partial

    errors: Any = {}
    results: list = []
    if not many:
        try:
            return self._load(data, partial=partial, unknown=unknown, **kwargs)
        except ValidationError as error:
            errors = error.normalized_messages()
            # NOTE(review): valid_data is wrapped in a one-element list even for a
            # single item; kept as-is for compatibility with existing callers.
            results.append(error.valid_data)
    else:
        for idx, item in enumerate(data):
            try:
                # Forward ``unknown`` here too; it was previously dropped for many=True.
                results.append(self._load(item, partial=partial, unknown=unknown, **kwargs))
            except ValidationError as error:
                errors[idx] = error.normalized_messages()
                results.append(error.valid_data)

    if not errors:
        return results
    raise ValidationError(errors, data=data, valid_data=results)
validate(data, *, many=None, partial=None)
Source Code in griptape/schemas/polymorphic_schema.py
def validate(self, data: Any, *, many: Any = None, partial: Any = None) -> Any:  # pyright: ignore[reportIncompatibleMethodOverride]
    """Run ``load`` purely for its validation errors.

    Args:
        data: The data to validate.
        many: Forwarded to ``load``.
        partial: Forwarded to ``load``.

    Returns:
        The ``messages`` of the ``ValidationError`` raised by ``load``, or an empty
        dict when loading succeeds.
    """
    messages: Any = {}
    try:
        self.load(data, many=many, partial=partial)
    except ValidationError as error:
        messages = error.messages
    return messages
PydanticModel
Bases: Field

Source Code in griptape/schemas/pydantic_model_field.py
class PydanticModel(fields.Field):
    """Marshmallow field for pydantic models.

    Serialization emits ``value.model_dump()``. Deserialization wraps the raw value
    in a ``RootModel``; the original model class is not reconstructed.
    """

    def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
        # ``None`` passes through; otherwise emit the model's plain-dict form.
        return None if value is None else value.model_dump()

    def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> BaseModel:
        # Generic wrapper around the incoming value rather than a specific model type.
        wrapped = RootModel(value)
        return wrapped
_deserialize(value, attr, data, **kwargs)
Source Code in griptape/schemas/pydantic_model_field.py
def _deserialize(self, value: dict, attr: Any, data: Any, **kwargs) -> BaseModel:
    """Wrap the raw ``value`` in a pydantic ``RootModel``.

    The result is a generic root model; the original model class is not recovered.
    """
    wrapped = RootModel(value)
    return wrapped
_serialize(value, attr, obj, **kwargs)
Source Code in griptape/schemas/pydantic_model_field.py
def _serialize(self, value: Optional[BaseModel], attr: Any, obj: Any, **kwargs) -> Optional[dict]:
    """Return ``value.model_dump()``, or ``None`` when ``value`` is ``None``."""
    if value is not None:
        return value.model_dump()
    return None
Union
Bases: Field

Parameters
Name | Type | Description | Default |
---|---|---|---|
fields | list[Field] | The list of candidate fields to try. | required |
reverse_serialize_candidates | bool | Whether to try the candidates in reverse order when serializing. | False |
Source Code in griptape/schemas/union_field.py
class Union(marshmallow.fields.Field):
    """Field that accepts any one of multiple fields.

    Source: https://github.com/adamboche/python-marshmallow-union

    Each argument will be tried until one succeeds.

    Args:
        fields: The list of candidate fields to try.
        reverse_serialize_candidates: Whether to try the candidates in reverse order when serializing.
    """

    def __init__(
        self,
        fields: list[marshmallow.fields.Field],
        *,
        reverse_serialize_candidates: bool = False,
        **kwargs: Any,
    ) -> None:
        self._candidate_fields = fields
        self._reverse_serialize_candidates = reverse_serialize_candidates
        super().__init__(**kwargs)

    def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: Any) -> Any:
        """Serialize ``value`` with the first candidate field that succeeds.

        Candidates are tried in declaration order, or in reverse when
        ``reverse_serialize_candidates`` was set. A candidate raising ``TypeError``
        or ``ValueError`` is skipped and its message is recorded in the error store.

        Args:
            value: The value to be serialized.
            attr: The attribute or key to get from the object.
            obj: The object to pull the key from.
            kwargs: Field-specific keyword arguments. An ``error_store`` entry, if
                given, collects the candidate failures.

        Returns:
            The output of the first candidate whose ``_serialize`` succeeds.

        Raises:
            ExceptionGroupError: If every candidate raised ``TypeError`` or ``ValueError``.
        """
        # Build a fresh ErrorStore only when none (or None) was supplied.
        error_store = kwargs.pop("error_store", None)
        if error_store is None:
            error_store = marshmallow.error_store.ErrorStore()

        candidates = (
            list(reversed(self._candidate_fields)) if self._reverse_serialize_candidates else self._candidate_fields
        )
        for candidate_field in candidates:
            try:
                # pylint: disable=protected-access
                return candidate_field._serialize(value, attr, obj, error_store=error_store, **kwargs)
            except (TypeError, ValueError) as e:
                error_store.store_error({attr: str(e)})
        raise ExceptionGroupError("All serializers raised exceptions.", error_store.errors)

    def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any:
        """Deserialize ``value``.

        Args:
            value: The value to be deserialized.
            attr: The attribute/key in `data` to be deserialized.
            data: The raw input data passed to the `Schema.load`.
            kwargs: Field-specific keyword arguments.

        Returns:
            The result of the first candidate field whose ``deserialize`` succeeds.

        Raises:
            ValidationError: If an invalid value is passed or if a required value is missing.
        """
        errors = []
        for candidate_field in self._candidate_fields:
            try:
                return candidate_field.deserialize(value, attr, data, **kwargs)
            except marshmallow.exceptions.ValidationError as exc:
                errors.append(exc.messages)
        raise marshmallow.exceptions.ValidationError(message=errors, field_name=attr or "")
_candidate_fields = fields (instance-attribute)
_reverse_serialize_candidates = reverse_serialize_candidates (instance-attribute)
__init__(fields, *, reverse_serialize_candidates=False, **kwargs)
Source Code in griptape/schemas/union_field.py
def __init__(
    self,
    fields: list[marshmallow.fields.Field],
    *,
    reverse_serialize_candidates: bool = False,
    **kwargs: Any,
) -> None:
    """Create a union field.

    Args:
        fields: Candidate fields, tried in order until one succeeds.
        reverse_serialize_candidates: Try the candidates in reverse order when serializing.
        **kwargs: Forwarded to the parent ``Field`` constructor.
    """
    self._candidate_fields, self._reverse_serialize_candidates = fields, reverse_serialize_candidates
    super().__init__(**kwargs)
_deserialize(value, attr=None, data=None, **kwargs)
Deserialize `value`.
Parameters
Name | Type | Description | Default |
---|---|---|---|
value | Any | The value to be deserialized. | required |
attr | str | None | The attribute/key in data to be deserialized. | None |
data | Any | The raw input data passed to the Schema.load . | None |
kwargs | Any | Field-specific keyword arguments. | {} |
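Raises
Type | Description |
---|---|
ValidationError | If an invalid value is passed or if a required value is missing. |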
Source Code in griptape/schemas/union_field.py
def _deserialize(self, value: Any, attr: str | None = None, data: Any = None, **kwargs: Any) -> Any: """Deserialize ``value``. Args: value: The value to be deserialized. attr: The attribute/key in `data` to be deserialized. data: The raw input data passed to the `Schema.load`. kwargs: Field-specific keyword arguments. Raises: ValidationError: If an invalid value is passed or if a required value is missing. """ errors = [] for candidate_field in self._candidate_fields: try: return candidate_field.deserialize(value, attr, data, **kwargs) except marshmallow.exceptions.ValidationError as exc: errors.append(exc.messages) raise marshmallow.exceptions.ValidationError(message=errors, field_name=attr or "")
_serialize(value, attr, obj, **kwargs)
Pulls the value for the given key from the object, applies the field's formatting and returns the result.
Parameters
Name | Type | Description | Default |
---|---|---|---|
value | Any | The value to be serialized. | required |
attr | str | None | The attribute or key to get from the object. | required |
obj | str | The object to pull the key from. | required |
kwargs | Any | Field-specific keyword arguments. | {} |
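Raises
Type | Description |
---|---|
ExceptionGroupError | If every candidate field raises `TypeError` or `ValueError` while serializing. |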
Source Code in griptape/schemas/union_field.py
def _serialize(self, value: Any, attr: str | None, obj: Any, **kwargs: Any) -> Any:
    """Serialize ``value`` with the first candidate field that succeeds.

    Candidates are tried in declaration order, or in reverse when
    ``reverse_serialize_candidates`` was set. A candidate raising ``TypeError`` or
    ``ValueError`` is skipped and its message is recorded in the error store.

    Args:
        value: The value to be serialized.
        attr: The attribute or key to get from the object.
        obj: The object to pull the key from.
        kwargs: Field-specific keyword arguments. An ``error_store`` entry, if
            given, collects the candidate failures.

    Returns:
        The output of the first candidate whose ``_serialize`` succeeds.

    Raises:
        ExceptionGroupError: If every candidate raised ``TypeError`` or ``ValueError``.
    """
    # Build a fresh ErrorStore only when none (or None) was supplied.
    error_store = kwargs.pop("error_store", None)
    if error_store is None:
        error_store = marshmallow.error_store.ErrorStore()

    candidates = (
        list(reversed(self._candidate_fields)) if self._reverse_serialize_candidates else self._candidate_fields
    )
    for candidate_field in candidates:
        try:
            # pylint: disable=protected-access
            return candidate_field._serialize(value, attr, obj, error_store=error_store, **kwargs)
        except (TypeError, ValueError) as e:
            error_store.store_error({attr: str(e)})
    raise ExceptionGroupError("All serializers raised exceptions.", error_store.errors)
- On this page
- Bytes
- PolymorphicSchema
- PydanticModel
- Union
Could this page be better? Report a problem or suggest an addition!