Skip to content

model_base

This module contains the base Model class.

Model

Parent model class.

Source code in happy_vllm/model/model_base.py
class Model:
    """Parent model class.

    Holds the engine client (or a mock model in test mode), its tokenizer, the
    OpenAI-compatible serving objects and a few tokenizer-based text utilities.
    """

    def __init__(self, **kwargs):
        '''Init. model class

        Every model/tokenizer/serving slot starts as None and the model is flagged as not loaded.

        Kwargs:
            app_name (str) : Name of the application, "happy_vllm" if not given
        '''
        self._model = None
        self._tokenizer = None
        self._model_conf = None
        self._model_explainer = None
        self.openai_serving_chat = None
        self.openai_serving_completion = None
        self.openai_serving_tokenization = None
        self._loaded = False
        self.app_name = kwargs.get('app_name', "happy_vllm")

    def is_model_loaded(self):
        """Return the state of the model (True once `loading` has completed)"""
        return self._loaded

    async def loading(self, async_engine_client: AsyncEngineRPCClient, args: Namespace, **kwargs):
        """Load the model, flag it as loaded and store the launch arguments

        Args:
            async_engine_client (AsyncEngineRPCClient) : The engine client to wrap
            args (Namespace) : The parsed launch arguments

        `self.launch_arguments` is set to `vars(args)` when `args.with_launch_arguments`
        is truthy, and to an empty dict otherwise.
        """
        await self._load_model(async_engine_client, args, **kwargs)
        self._loaded = True
        if args.with_launch_arguments:
            self.launch_arguments = vars(args)
        else:
            self.launch_arguments = {}

    async def _load_model(self, async_engine_client: AsyncEngineRPCClient, args: Namespace, **kwargs) -> None:
        """Set up the model, its tokenizer and the OpenAI serving objects

        When `args.model_name` is "TEST MODEL", a local test tokenizer and a mock model are
        used instead of `async_engine_client`.

        Args:
            async_engine_client (AsyncEngineRPCClient) : The engine client to wrap
            args (Namespace) : The parsed launch arguments
        """

        self._model_conf = {'model_name': args.model_name}

        logger.info(f"Loading the model from {args.model}")
        if args.model_name != "TEST MODEL":
            self._model = async_engine_client
            model_config = await self._model.get_model_config()
            # Define the tokenizer differently if we have an AsyncLLMEngine
            if isinstance(self._model, AsyncLLMEngine):
                tokenizer_tmp = self._model.engine.tokenizer
            else:
                tokenizer_tmp = self._model.tokenizer
            # A TokenizerGroup wraps the actual tokenizer
            if isinstance(tokenizer_tmp, TokenizerGroup): # type: ignore
                self._tokenizer = tokenizer_tmp.tokenizer # type: ignore
            else:
                self._tokenizer = tokenizer_tmp # type: ignore
            self._tokenizer_lmformatenforcer = build_token_enforcer_tokenizer_data(self._tokenizer)
            self.max_model_len = model_config.max_model_len # type: ignore
            # To take into account Mistral tokenizers, which may not expose `truncation_side`
            self.original_truncation_side = getattr(self._tokenizer, "truncation_side", "left") # type: ignore
            if args.disable_log_requests:
                request_logger = None
            else:
                request_logger = RequestLogger(max_log_len=args.max_log_len)
            self.openai_serving_chat = OpenAIServingChat(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
                                                        args.response_role,
                                                        lora_modules=args.lora_modules,
                                                        prompt_adapters=args.prompt_adapters,
                                                        request_logger=request_logger,
                                                        chat_template=args.chat_template,
                                                        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
                                                        enable_auto_tools=args.enable_auto_tool_choice,
                                                        tool_parser=args.tool_call_parser)
            self.openai_serving_completion = OpenAIServingCompletion(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
                                                                    lora_modules=args.lora_modules,
                                                                    prompt_adapters=args.prompt_adapters,
                                                                    request_logger=request_logger,
                                                                    return_tokens_as_token_ids=args.return_tokens_as_token_ids)
            self.openai_serving_tokenization = OpenAIServingTokenization(cast(AsyncEngineClient, self._model), model_config, [args.model_name],
                                                                        lora_modules=args.lora_modules,
                                                                        request_logger=request_logger,
                                                                        chat_template=args.chat_template)

        # For test purpose
        else:
            self.max_model_len = 2048
            self.original_truncation_side = 'right'
            self._tokenizer = AutoTokenizer.from_pretrained(utils.TEST_TOKENIZER_NAME,
                                                     cache_dir=os.environ["TEST_MODELS_DIR"], truncation_side=self.original_truncation_side)
            self._tokenizer_lmformatenforcer = build_token_enforcer_tokenizer_data(self._tokenizer)
            self._model = MockModel(self._tokenizer, self.app_name)
            self.openai_serving_tokenization = MockOpenAIServingTokenization(self._tokenizer)
        logger.info("Model loaded")

    def tokenize(self, text: str) -> List[int]:
        """Tokenizes a text

        Args:
            text (str) : The text to tokenize

        Returns:
            list : The list of token ids
        """
        return list(utils.proper_tokenization(self._tokenizer, text))

    def split_text(self, text: str, num_tokens_in_chunk: int = 200, separators: Union[list, None] = None) -> List[str]:
        '''Splits a text into smaller texts containing at least num_tokens_in_chunk tokens and ending with a separator.
        Note that the `separators` are matched on the tokenization of the strings, not on the strings themselves
        (which is why, for example, ' .' and '.' must be specified as two separate separators).
        The last chunk holds whatever remains and may be shorter or not end with a separator.

        Args:
            text (str) : The text to split

        Kwargs:
            num_tokens_in_chunk (int) : The minimal number of tokens in the chunk
            separators (list) : The separators marking the end of a sentence

        Returns:
            A list of non-empty strings, each containing at least num_tokens_in_chunk tokens and ending with a separator
            (except possibly the last one)
        '''
        if separators is None:
            separators = [".", "!", "?", "|", " .", " !", " ?", " |"]
        separators_tokens_ids = {utils.proper_tokenization(self._tokenizer, separator) for separator in separators}
        tokens = list(utils.proper_tokenization(self._tokenizer, text))
        # Index of the last token of every separator occurrence, in text order
        indices_separators = sorted(
            index
            for separator_tokens_ids in separators_tokens_ids
            for index in find_indices_sub_list_in_list(tokens, list(separator_tokens_ids))
        )

        chunks = []
        index_beginning_chunk = 0
        for index_separator in indices_separators:
            # Only cut at this separator once the current chunk is long enough
            if index_separator + 1 - index_beginning_chunk >= num_tokens_in_chunk:
                chunks.append(tokens[index_beginning_chunk:index_separator + 1])
                index_beginning_chunk = index_separator + 1
        chunks.append(tokens[index_beginning_chunk:])
        decoded_chunks = (utils.proper_decode(self._tokenizer, chunk) for chunk in chunks)
        return [element for element in decoded_chunks if element != ""]

    def extract_text_outside_truncation(self, text: str, truncation_side: Union[str, None] = None, max_length: Union[int, None] = None) -> str:
        """Extracts the part of the text dropped by truncation, i.e. the part the model will not see.

        The text is tokenized without special tokens, with truncation enabled and overflowing
        tokens returned. The first returned sequence is the kept part; the remaining sequences
        overflow. They are reversed when truncating on the left (presumably to put them back in
        text order), concatenated and decoded to a string.

        The tokenizer's `truncation_side` is set to `truncation_side` for the call and is always
        restored to `self.original_truncation_side` afterwards, even when nothing is truncated
        or tokenization raises.

        Args:
            text (str) : The text we want to parse
            truncation_side (str) : The side of the truncation, defaults to `self.original_truncation_side`
            max_length (int) : The length above which the text will be truncated, defaults to `self.max_model_len`

        Returns:
            The part of the text which will be dropped by the truncation (str), '' if nothing is dropped
        """
        if max_length is None:
            max_length = self.max_model_len
        if truncation_side is None:
            truncation_side = self.original_truncation_side
        self._tokenizer.truncation_side = truncation_side
        try:
            list_tokens = self._tokenizer(text, truncation=True, add_special_tokens=False, max_length=max_length,
                                          return_overflowing_tokens=True)['input_ids']
            # The first sequence is what is kept; everything after it overflows
            overflowing = list_tokens[1:]
            if not overflowing:
                return ''
            if truncation_side == 'left':
                overflowing = overflowing[::-1]
            truncated = [token_id for sequence in overflowing for token_id in sequence]
            return self._tokenizer.decode(truncated)
        finally:
            # Never leave the tokenizer on a non-default side, whatever the exit path
            self._tokenizer.truncation_side = self.original_truncation_side

__init__(**kwargs)

Init. model class

Source code in happy_vllm/model/model_base.py
def __init__(self, **kwargs):
    '''Initialize an unloaded model.

    All model, tokenizer, configuration and serving slots start as None, and the
    loaded flag starts as False; they are populated later when the model is loaded.

    Kwargs:
        app_name (str) : Application name, "happy_vllm" when not provided
    '''
    self._model = self._tokenizer = self._model_conf = self._model_explainer = None
    self.openai_serving_chat = self.openai_serving_completion = self.openai_serving_tokenization = None
    self._loaded = False
    self.app_name = kwargs.get('app_name', "happy_vllm")

extract_text_outside_truncation(text, truncation_side=None, max_length=None)

Extracts the part of the prompt that is dropped by truncation and will therefore not be seen by the model. The prompt is first tokenized (without special tokens) with truncation enabled and overflowing tokens returned. The first returned sequence is the kept part; the remaining, overflowing sequences of token IDs are concatenated and decoded back to a string.

Parameters:

Name Type Description Default
text str

The text we want to parse

required
truncation_side str

The side of the truncation

None
max_length int

The length above which the text will be truncated

None

Returns:

Type Description
str

The part of the text which will be dropped by the truncation (str)

Source code in happy_vllm/model/model_base.py
def extract_text_outside_truncation(self, text: str, truncation_side: Union[str, None] = None, max_length: Union[int, None] = None) -> str:
    """Extracts the part of the text dropped by truncation, i.e. the part the model will not see.

    The text is tokenized without special tokens, with truncation enabled and overflowing
    tokens returned. The first returned sequence is the kept part; the remaining sequences
    overflow. They are reversed when truncating on the left (presumably to put them back in
    text order), concatenated and decoded to a string.

    The tokenizer's `truncation_side` is set to `truncation_side` for the call and is always
    restored to `self.original_truncation_side` afterwards, even when nothing is truncated
    or tokenization raises.

    Args:
        text (str) : The text we want to parse
        truncation_side (str) : The side of the truncation, defaults to `self.original_truncation_side`
        max_length (int) : The length above which the text will be truncated, defaults to `self.max_model_len`

    Returns:
        The part of the text which will be dropped by the truncation (str), '' if nothing is dropped
    """
    if max_length is None:
        max_length = self.max_model_len
    if truncation_side is None:
        truncation_side = self.original_truncation_side
    self._tokenizer.truncation_side = truncation_side
    try:
        list_tokens = self._tokenizer(text, truncation=True, add_special_tokens=False, max_length=max_length,
                                      return_overflowing_tokens=True)['input_ids']
        # The first sequence is what is kept; everything after it overflows
        overflowing = list_tokens[1:]
        if not overflowing:
            return ''
        if truncation_side == 'left':
            overflowing = overflowing[::-1]
        truncated = [token_id for sequence in overflowing for token_id in sequence]
        return self._tokenizer.decode(truncated)
    finally:
        # Never leave the tokenizer on a non-default side, whatever the exit path
        self._tokenizer.truncation_side = self.original_truncation_side

is_model_loaded()

return the state of the model

Source code in happy_vllm/model/model_base.py
def is_model_loaded(self):
    """Tell whether the model has finished loading.

    Returns:
        The current value of the `_loaded` flag
    """
    loaded_flag = self._loaded
    return loaded_flag

loading(async_engine_client, args, **kwargs) async

load the model

Source code in happy_vllm/model/model_base.py
async def loading(self, async_engine_client: AsyncEngineRPCClient, args: Namespace, **kwargs):
    """Load the model, then flag it as loaded and record the launch arguments.

    Args:
        async_engine_client (AsyncEngineRPCClient) : The engine client forwarded to `_load_model`
        args (Namespace) : The parsed launch arguments

    `self.launch_arguments` becomes `vars(args)` when `args.with_launch_arguments`
    is truthy, and an empty dict otherwise.
    """
    await self._load_model(async_engine_client, args, **kwargs)
    self._loaded = True
    self.launch_arguments = vars(args) if args.with_launch_arguments else {}

split_text(text, num_tokens_in_chunk=200, separators=None)

Splits a text into smaller texts, each containing at least num_tokens_in_chunk tokens and ending with a separator (except possibly the last one, which holds whatever remains). Note that the separators are matched on the tokenization of the strings, not on the strings themselves (which is why, for example, ' .' and '.' must be specified as two separate separators).

Parameters:

Name Type Description Default
text str

The text to split

required
Kwargs

num_tokens_in_chunk (int) : The minimal number of tokens in the chunk
separators (list) : The separators marking the end of a sentence

Returns:

Type Description
List[str]

A list of strings each string containing at least num_tokens_in_chunk tokens and ending by a separator

Source code in happy_vllm/model/model_base.py
def split_text(self, text: str, num_tokens_in_chunk: int = 200, separators: Union[list, None] = None) -> List[str]:
    '''Cut a text into chunks of at least num_tokens_in_chunk tokens, each ending on a separator.

    Separators are matched on their token ids rather than on raw characters, so variants that
    tokenize differently (for example ' .' versus '.') must each be listed. The final chunk
    collects the remaining tokens and may be shorter or lack a trailing separator.

    Args:
        text (str) : The text to split

    Kwargs:
        num_tokens_in_chunk (int) : The minimal number of tokens in the chunk
        separators (list) : The separators marking the end of a sentence

    Returns:
        The decoded chunks, with empty strings removed
    '''
    if separators is None:
        separators = [".", "!", "?", "|", " .", " !", " ?", " |"]
    separator_patterns = {utils.proper_tokenization(self._tokenizer, sep) for sep in separators}
    token_ids = list(utils.proper_tokenization(self._tokenizer, text))
    # Positions of the last token of each separator occurrence, in ascending order
    cut_candidates = sorted(
        position
        for pattern in separator_patterns
        for position in find_indices_sub_list_in_list(token_ids, list(pattern))
    )

    pieces = []
    start = 0
    for position in cut_candidates:
        end = position + 1
        if end - start >= num_tokens_in_chunk:
            pieces.append(token_ids[start:end])
            start = end
    pieces.append(token_ids[start:])
    decoded = (utils.proper_decode(self._tokenizer, piece) for piece in pieces)
    return [chunk for chunk in decoded if chunk != ""]

tokenize(text)

Tokenizes a text

Parameters:

Name Type Description Default
text str

The text to tokenize

required

Returns:

Name Type Description
list List[int]

The list of token ids

Source code in happy_vllm/model/model_base.py
def tokenize(self, text: str) -> List[int]:
    """Convert a text into token ids with the model's tokenizer.

    Args:
        text (str) : The text to tokenize

    Returns:
        list : The token ids, materialized as a list
    """
    token_ids = utils.proper_tokenization(self._tokenizer, text)
    return list(token_ids)

find_indices_sub_list_in_list(big_list, sub_list)

Find the indices at which a sub list occurs in a bigger list; each returned index is the position of the last element of an occurrence. For example, if big_list = [3, 4, 1, 2, 3, 4, 5, 6, 3, 4] and sub_list = [3, 4], the result will be [1, 5, 9].

Parameters:

Name Type Description Default
big_list list

The list in which we want to find the sub_list

required
sub_list list

The list we want the indices of in the big_list

required

Returns:

Name Type Description
list list

The list of indices of where the sub_list is in the big_list

Source code in happy_vllm/model/model_base.py
def find_indices_sub_list_in_list(big_list: list, sub_list: list) -> list:
    """Locate every occurrence of sub_list inside big_list.

    Each returned index is the position in big_list of the LAST element of an
    occurrence. For example, with big_list = [3, 4, 1, 2, 3, 4, 5, 6, 3, 4] and
    sub_list = [3, 4], the result is [1, 5, 9]. Occurrences may overlap, and an
    empty sub_list matches at every index.

    Args:
        big_list (list) : The list in which we want to find the sub_list
        sub_list (list): The list we want the indices of in the big_list

    Returns:
        list : The ascending end indices of the occurrences of sub_list in big_list
    """
    width = len(sub_list)
    # Compare the window ending at `end` with sub_list; windows that would start
    # before index 0 can never match, since they are shorter than sub_list
    return [
        end
        for end in range(len(big_list))
        if big_list[end - width + 1:end + 1] == sub_list
    ]