Bases: Structure

Source Code in griptape/structures/agent.py
@define
class Agent(Structure):
    """Structure that holds and runs a single task.

    When no task is supplied at construction time, a `PromptTask` is built from
    this Agent's `input`, `prompt_driver`, `tools`, `output_schema` and
    `max_meta_memory_entries` fields (see `_init_task`).
    """

    # Default input is a callable taking the task: it yields the first entry of
    # the task context's "args", or an empty TextArtifact when "args" is empty.
    input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
        default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
    )
    # `None` means "not provided"; `_init_task` then copies the value from the default prompt driver.
    stream: Optional[bool] = field(default=None, kw_only=True)
    # `None` means "not provided"; `_init_task` then derives one from the default prompt driver.
    prompt_driver: Optional[BasePromptDriver] = field(default=None, kw_only=True)
    output_schema: Optional[Union[Schema, type[BaseModel]]] = field(default=None, kw_only=True)
    tools: list[BaseTool] = field(factory=list, kw_only=True)
    max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
    # Any truthy value is rejected by `validate_fail_fast`.
    fail_fast: bool = field(default=False, kw_only=True)
    _tasks: list[Union[BaseTask, list[BaseTask]]] = field(
        factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
    )

    @fail_fast.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None:  # noqa: FBT001
        """Reject a truthy `fail_fast`.

        Raises:
            ValueError: If `fail_fast` is truthy.
        """
        if not fail_fast:
            return
        raise ValueError("Agents cannot fail fast, as they can only have 1 task.")

    @prompt_driver.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_prompt_driver(self, _: Attribute, prompt_driver: Optional[BasePromptDriver]) -> None:
        """Emit a `UserWarning` when both `prompt_driver` and `stream` are set."""
        if prompt_driver is None or self.stream is None:
            return
        message = "`Agent.prompt_driver` is set, but `Agent.stream` was provided. `Agent.stream` will be ignored. This will be an error in the future."
        warnings.warn(message, UserWarning, stacklevel=2)

    @_tasks.validator  # pyright: ignore[reportAttributeAccessIssue]
    def validate_tasks(self, _: Attribute, tasks: list) -> None:
        """Emit a `UserWarning` when tasks are supplied together with a `prompt_driver`."""
        if not tasks or self.prompt_driver is None:
            return
        message = "`Agent.tasks` is set, but `Agent.prompt_driver` was provided. `Agent.prompt_driver` will be ignored. This will be an error in the future."
        warnings.warn(message, UserWarning, stacklevel=2)

    def __attrs_post_init__(self) -> None:
        """Run the parent post-init, then create the default task if none exists."""
        super().__attrs_post_init__()

        if len(self.tasks) > 0:
            return
        self._init_task()

    @property
    def task(self) -> BaseTask:
        """The first entry of `self.tasks`."""
        return self.tasks[0]

    def add_task(self, task: BaseTask) -> BaseTask:
        """Replace whatever task is currently stored with `task` and return it."""
        # The stored list is emptied before `preprocess` runs, then the new task is appended.
        self._tasks.clear()
        task.preprocess(self)
        self._tasks.append(task)
        return task

    def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
        """Delegate to the parent `add_tasks` when at most one positional argument is given.

        Raises:
            ValueError: If more than one positional argument is passed.
        """
        if len(tasks) <= 1:
            return super().add_tasks(*tasks)
        raise ValueError("Agents can only have one task.")

    @observable
    def try_run(self, *args) -> Agent:
        """Run the Agent's task and return the Agent itself; `args` is not read here."""
        self.task.run()
        return self

    def _init_task(self) -> None:
        """Build a `PromptTask` from this Agent's fields and install it via `add_task`.

        Unset `stream` / `prompt_driver` values are filled in from
        `Defaults.drivers_config.prompt_driver` first.
        """
        # NOTE(review): `validators.disabled()` is presumably the attrs context manager, used so these
        # internal assignments do not re-trigger the field validators above -- confirm.
        if self.stream is None:
            with validators.disabled():
                self.stream = Defaults.drivers_config.prompt_driver.stream

        prompt_driver = self.prompt_driver
        if prompt_driver is None:
            with validators.disabled():
                # The default driver is copied with this Agent's `stream` setting applied.
                prompt_driver = evolve(Defaults.drivers_config.prompt_driver, stream=self.stream)
                self.prompt_driver = prompt_driver

        self.add_task(
            PromptTask(
                self.input,
                prompt_driver=prompt_driver,
                tools=self.tools,
                output_schema=self.output_schema,
                max_meta_memory_entries=self.max_meta_memory_entries,
            )
        )
  • _tasks = field(factory=list, kw_only=True, alias='tasks', metadata={'serializable': True}) class-attribute instance-attribute

  • fail_fast = field(default=False, kw_only=True) class-attribute instance-attribute

  • input = field(default=lambda task: task.full_context['args'][0] if task.full_context['args'] else TextArtifact(value='')) class-attribute instance-attribute

  • max_meta_memory_entries = field(default=20, kw_only=True) class-attribute instance-attribute

  • output_schema = field(default=None, kw_only=True) class-attribute instance-attribute

  • prompt_driver = field(default=None, kw_only=True) class-attribute instance-attribute

  • stream = field(default=None, kw_only=True) class-attribute instance-attribute

  • task property

  • tools = field(factory=list, kw_only=True) class-attribute instance-attribute

__attrs_post_init__()

Source Code in griptape/structures/agent.py
def __attrs_post_init__(self) -> None:
    """Run the parent post-init, then create the default task if none exists."""
    super().__attrs_post_init__()

    if len(self.tasks) > 0:
        return
    self._init_task()

_init_task()

Source Code in griptape/structures/agent.py
def _init_task(self) -> None:
    """Build a `PromptTask` from this object's fields and install it via `add_task`.

    Unset `stream` / `prompt_driver` values are filled in from
    `Defaults.drivers_config.prompt_driver` first.
    """
    # NOTE(review): `validators.disabled()` is presumably the attrs context manager, used so these
    # internal assignments do not re-trigger field validators -- confirm.
    if self.stream is None:
        with validators.disabled():
            self.stream = Defaults.drivers_config.prompt_driver.stream

    prompt_driver = self.prompt_driver
    if prompt_driver is None:
        with validators.disabled():
            # The default driver is copied with this object's `stream` setting applied.
            prompt_driver = evolve(Defaults.drivers_config.prompt_driver, stream=self.stream)
            self.prompt_driver = prompt_driver

    self.add_task(
        PromptTask(
            self.input,
            prompt_driver=prompt_driver,
            tools=self.tools,
            output_schema=self.output_schema,
            max_meta_memory_entries=self.max_meta_memory_entries,
        )
    )

add_task(task)

Source Code in griptape/structures/agent.py
def add_task(self, task: BaseTask) -> BaseTask:
    """Replace whatever task is currently stored with `task`.

    Args:
        task: The task to install; its `preprocess` method is called with this object.

    Returns:
        The same `task` that was passed in.
    """
    # Empty the stored list first, so it holds nothing while the new task is preprocessed.
    self._tasks.clear()

    task.preprocess(self)

    self._tasks.append(task)

    return task

add_tasks(*tasks)

Source Code in griptape/structures/agent.py
def add_tasks(self, *tasks: BaseTask | list[BaseTask]) -> list[BaseTask]:
    """Delegate to the parent `add_tasks` when at most one positional argument is given.

    Raises:
        ValueError: If more than one positional argument is passed.
    """
    if len(tasks) <= 1:
        return super().add_tasks(*tasks)
    raise ValueError("Agents can only have one task.")

try_run(*args)

Source Code in griptape/structures/agent.py
@observable
def try_run(self, *args) -> Agent:
    """Run the task exposed as `self.task` and return this object.

    `args` is accepted but not read inside this method.
    """
    self.task.run()

    return self

validate_fail_fast(_, fail_fast)

Source Code in griptape/structures/agent.py
@fail_fast.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_fail_fast(self, _: Attribute, fail_fast: bool) -> None:  # noqa: FBT001
    """Reject a truthy `fail_fast`.

    Raises:
        ValueError: If `fail_fast` is truthy.
    """
    if not fail_fast:
        return
    raise ValueError("Agents cannot fail fast, as they can only have 1 task.")

validate_prompt_driver(_, prompt_driver)

Source Code in griptape/structures/agent.py
@prompt_driver.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_prompt_driver(self, _: Attribute, prompt_driver: Optional[BasePromptDriver]) -> None:
    """Emit a `UserWarning` when both `prompt_driver` and `stream` are set."""
    if prompt_driver is None or self.stream is None:
        return
    message = "`Agent.prompt_driver` is set, but `Agent.stream` was provided. `Agent.stream` will be ignored. This will be an error in the future."
    warnings.warn(message, UserWarning, stacklevel=2)

validate_tasks(_, tasks)

Source Code in griptape/structures/agent.py
@_tasks.validator  # pyright: ignore[reportAttributeAccessIssue]
def validate_tasks(self, _: Attribute, tasks: list) -> None:
    """Emit a `UserWarning` when tasks are supplied together with a `prompt_driver`."""
    if not tasks or self.prompt_driver is None:
        return
    message = "`Agent.tasks` is set, but `Agent.prompt_driver` was provided. `Agent.prompt_driver` will be ignored. This will be an error in the future."
    warnings.warn(message, UserWarning, stacklevel=2)

Could this page be better? Report a problem or suggest an addition!