huggingface_hub_prompt_driver

logger = logging.getLogger(Defaults.logging_config.logger_name)
module-attribute

HuggingFaceHubPromptDriver

Bases: BasePromptDriver

Attributes
| Name | Type | Description |
|---|---|---|
| api_token | str | Hugging Face Hub API token. |
| use_gpu | str | Use GPU during model run. |
| model | str | Hugging Face Hub model name. |
| client | InferenceClient | Custom `InferenceApi`. |
| tokenizer | HuggingFaceTokenizer | Custom `HuggingFaceTokenizer`. |
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py

```python
@define
class HuggingFaceHubPromptDriver(BasePromptDriver):
    """Hugging Face Hub Prompt Driver.

    Attributes:
        api_token: Hugging Face Hub API token.
        use_gpu: Use GPU during model run.
        model: Hugging Face Hub model name.
        client: Custom `InferenceApi`.
        tokenizer: Custom `HuggingFaceTokenizer`.
    """

    api_token: str = field(kw_only=True, metadata={"serializable": True})
    max_tokens: int = field(default=250, kw_only=True, metadata={"serializable": True})
    model: str = field(kw_only=True, metadata={"serializable": True})
    structured_output_strategy: StructuredOutputStrategy = field(
        default="native", kw_only=True, metadata={"serializable": True}
    )
    tokenizer: HuggingFaceTokenizer = field(
        default=Factory(
            lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens),
            takes_self=True,
        ),
        kw_only=True,
    )
    _client: Optional[InferenceClient] = field(
        default=None, kw_only=True, alias="client", metadata={"serializable": False}
    )

    @lazy_property()
    def client(self) -> InferenceClient:
        return import_optional_dependency("huggingface_hub").InferenceClient(
            model=self.model,
            token=self.api_token,
        )

    @structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
    def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
        if value == "tool":
            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

        return value

    @observable
    def try_run(self, prompt_stack: PromptStack) -> Message:
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = self._base_params(prompt_stack)
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(
            prompt,
            **full_params,
        )
        logger.debug(response)
        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
        output_tokens = len(self.tokenizer.tokenizer.encode(response))  # pyright: ignore[reportArgumentType]

        return Message(
            content=response,
            role=Message.ASSISTANT_ROLE,
            usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
        )

    @observable
    def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
        prompt = self.prompt_stack_to_string(prompt_stack)
        full_params = {**self._base_params(prompt_stack), "stream": True}
        logger.debug(
            {
                "prompt": prompt,
                **full_params,
            }
        )

        response = self.client.text_generation(prompt, **full_params)

        input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))

        full_text = ""
        for token in response:
            logger.debug(token)
            full_text += token

            yield DeltaMessage(content=TextDeltaMessageContent(token, index=0))

        output_tokens = len(self.tokenizer.tokenizer.encode(full_text))  # pyright: ignore[reportArgumentType]
        yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens))

    def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
        return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]

    def _base_params(self, prompt_stack: PromptStack) -> dict:
        params = {
            "return_full_text": False,
            "max_new_tokens": self.max_tokens,
            **self.extra_params,
        }

        if prompt_stack.output_schema and self.structured_output_strategy == "native":
            # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
            output_schema = prompt_stack.to_output_json_schema()
            # Grammar does not support $schema and $id
            del output_schema["$schema"]
            del output_schema["$id"]
            params["grammar"] = {"type": "json", "value": output_schema}

        return params

    def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
        messages = []
        for message in prompt_stack.messages:
            if len(message.content) == 1:
                messages.append({"role": message.role, "content": message.to_text()})
            else:
                raise ValueError("Invalid input content length.")

        return messages

    def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
        messages = self._prompt_stack_to_messages(prompt_stack)
        tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

        if isinstance(tokens, list):
            return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
        raise ValueError("Invalid output type.")
```
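For orientation, here is a minimal instantiation sketch built from the fields above. The environment variable and model id are placeholders, not values the driver supplies.

```python
import os

from griptape.drivers.prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver

# api_token and model are the only required fields; max_tokens (250) and
# structured_output_strategy ("native") fall back to the defaults shown above.
driver = HuggingFaceHubPromptDriver(
    api_token=os.environ["HUGGINGFACE_HUB_ACCESS_TOKEN"],  # placeholder env var name
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
)
```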
_client = field(default=None, kw_only=True, alias='client', metadata={'serializable': False})
class-attribute instance-attribute

api_token = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

max_tokens = field(default=250, kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

model = field(kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

structured_output_strategy = field(default='native', kw_only=True, metadata={'serializable': True})
class-attribute instance-attribute

tokenizer = field(default=Factory(lambda self: HuggingFaceTokenizer(model=self.model, max_output_tokens=self.max_tokens), takes_self=True), kw_only=True)
class-attribute instance-attribute
__prompt_stack_to_tokens(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
def __prompt_stack_to_tokens(self, prompt_stack: PromptStack) -> list[int]:
    messages = self._prompt_stack_to_messages(prompt_stack)
    tokens = self.tokenizer.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)

    if isinstance(tokens, list):
        return tokens  # pyright: ignore[reportReturnType] According to the [docs](https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template), the return type is List[int].
    raise ValueError("Invalid output type.")
```
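For context, here is a standalone sketch of the underlying transformers call this method relies on; the model id is a placeholder, and the driver reaches the same tokenizer through self.tokenizer.tokenizer.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # placeholder model id

messages = [{"role": "user", "content": "Hello!"}]

# With tokenize=True the chat template is rendered and encoded in one step,
# so len(tokens) is the input token count the driver reports.
tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=True)
print(len(tokens))
```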
_base_params(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
def _base_params(self, prompt_stack: PromptStack) -> dict:
    params = {
        "return_full_text": False,
        "max_new_tokens": self.max_tokens,
        **self.extra_params,
    }

    if prompt_stack.output_schema and self.structured_output_strategy == "native":
        # https://huggingface.co/learn/cookbook/en/structured_generation#-constrained-decoding
        output_schema = prompt_stack.to_output_json_schema()
        # Grammar does not support $schema and $id
        del output_schema["$schema"]
        del output_schema["$id"]
        params["grammar"] = {"type": "json", "value": output_schema}

    return params
```
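To illustrate the shape of the result, here is a hedged sketch of the params dict when an output schema is set and structured_output_strategy is "native"; the schema itself is an arbitrary example, not something this driver emits.

```python
# Example schema, standing in for prompt_stack.to_output_json_schema().
output_schema = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "$id": "Output",
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

# The constrained-decoding grammar rejects $schema and $id, so they are removed.
del output_schema["$schema"]
del output_schema["$id"]

params = {
    "return_full_text": False,
    "max_new_tokens": 250,  # self.max_tokens
    "grammar": {"type": "json", "value": output_schema},
}
```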
_prompt_stack_to_messages(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
def _prompt_stack_to_messages(self, prompt_stack: PromptStack) -> list[dict]:
    messages = []
    for message in prompt_stack.messages:
        if len(message.content) == 1:
            messages.append({"role": message.role, "content": message.to_text()})
        else:
            raise ValueError("Invalid input content length.")

    return messages
```
client()
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
@lazy_property()
def client(self) -> InferenceClient:
    return import_optional_dependency("huggingface_hub").InferenceClient(
        model=self.model,
        token=self.api_token,
    )
```
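Because _client is exposed under the client alias, a preconfigured InferenceClient can be passed in to bypass this lazy property; a sketch under that assumption, with placeholder token, model id, and timeout.

```python
from huggingface_hub import InferenceClient

from griptape.drivers.prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver

# A custom client, e.g. with a longer timeout, takes the place of the lazily built one.
custom_client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token="hf_...", timeout=60)  # placeholder values

driver = HuggingFaceHubPromptDriver(
    api_token="hf_...",  # placeholder token
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    client=custom_client,
)
```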
prompt_stack_to_string(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
def prompt_stack_to_string(self, prompt_stack: PromptStack) -> str:
    return self.tokenizer.tokenizer.decode(self.__prompt_stack_to_tokens(prompt_stack))  # pyright: ignore[reportArgumentType]
```
try_run(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
    prompt = self.prompt_stack_to_string(prompt_stack)
    full_params = self._base_params(prompt_stack)
    logger.debug(
        {
            "prompt": prompt,
            **full_params,
        }
    )

    response = self.client.text_generation(
        prompt,
        **full_params,
    )
    logger.debug(response)
    input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))
    output_tokens = len(self.tokenizer.tokenizer.encode(response))  # pyright: ignore[reportArgumentType]

    return Message(
        content=response,
        role=Message.ASSISTANT_ROLE,
        usage=Message.Usage(input_tokens=input_tokens, output_tokens=output_tokens),
    )
```
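A usage sketch, assuming a driver instance like the one in the instantiation example above and assuming PromptStack.add_user_message is available for building the stack.

```python
from griptape.common import PromptStack

prompt_stack = PromptStack()
prompt_stack.add_user_message("Write a haiku about the sea.")  # add_user_message assumed available

message = driver.try_run(prompt_stack)

# The returned Message carries the generated text plus the token counts computed above.
print(message.to_text())
print(message.usage.input_tokens, message.usage.output_tokens)
```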
try_stream(prompt_stack)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
    prompt = self.prompt_stack_to_string(prompt_stack)
    full_params = {**self._base_params(prompt_stack), "stream": True}
    logger.debug(
        {
            "prompt": prompt,
            **full_params,
        }
    )

    response = self.client.text_generation(prompt, **full_params)

    input_tokens = len(self.__prompt_stack_to_tokens(prompt_stack))

    full_text = ""
    for token in response:
        logger.debug(token)
        full_text += token

        yield DeltaMessage(content=TextDeltaMessageContent(token, index=0))

    output_tokens = len(self.tokenizer.tokenizer.encode(full_text))  # pyright: ignore[reportArgumentType]
    yield DeltaMessage(usage=DeltaMessage.Usage(input_tokens=input_tokens, output_tokens=output_tokens))
```
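A hedged sketch of consuming the stream, under the same assumptions as the try_run example: content deltas arrive first, and the final DeltaMessage carries only usage.

```python
from griptape.common import PromptStack

prompt_stack = PromptStack()
prompt_stack.add_user_message("Write a haiku about the sea.")  # add_user_message assumed available

for delta in driver.try_stream(prompt_stack):
    if delta.content is not None:
        print(delta.content.text, end="")  # TextDeltaMessageContent exposes the token text
    else:
        # The last delta has no content and reports the token counts.
        print(f"\n[input={delta.usage.input_tokens}, output={delta.usage.output_tokens}]")
```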
validate_structured_output_strategy(_, value)
Source Code in griptape/drivers/prompt/huggingface_hub_prompt_driver.py
```python
@structured_output_strategy.validator  # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess]
def validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
    if value == "tool":
        raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")

    return value
```
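A small sketch showing the validator in action; since attrs validators run at construction time, passing the unsupported "tool" strategy raises immediately (token and model id are placeholders).

```python
from griptape.drivers.prompt.huggingface_hub_prompt_driver import HuggingFaceHubPromptDriver

try:
    HuggingFaceHubPromptDriver(
        api_token="hf_...",  # placeholder token
        model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
        structured_output_strategy="tool",
    )
except ValueError as error:
    print(error)  # HuggingFaceHubPromptDriver does not support `tool` structured output strategy.
```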