Module llmsearch.tuner

Core Tuner Class that adapts a PyTorch model to function as a scikit-learn estimator.

Sub-modules

llmsearch.tuner.tuner

Main Tuner Module containing LLMEstimatorWrapper Class & Tuner Class for scikit-learn

Classes

class Tuner (model, tokenizer, dataset, column_mapping, scorer, device, prompt_template=None, sample_preprocessor=None, tokenizer_encode_args=None, tokenizer_decode_args=None, batch_size=16, output_preproc=<function Tuner.<lambda>>, callbacks_after_inference=None, is_encoder_decoder=False, greater_is_better=True, seed=42, disable_batch_size_cache=False, disable_generation_param_checks=False, sample_ratio=0.3, tokenizer_max_length_quantile=0.9, custom_pred_function=None)

Tuner class that drives the search for generation hyperparameters.

Initializes the Tuner class and populates the dataset with _X & _y keys for the model to use. A minimal construction sketch follows the argument list below.

Args

model : nn.Module
model that has a .generate method
tokenizer : AutoTokenizer
tokenizer for the input
dataset : Dataset
dataset to perform search on
column_mapping : Dict[str, list]
should contain input_cols & eval_cols keys; input_cols lists the columns used in the prompt_template and eval_cols lists the columns used by the scorer. For each sample, the eval_cols are collected into a dict, and these dicts are passed to the scorer as a List[Dict]. eg - {'input_cols' : ["question"], 'eval_cols' : ['answer']}
scorer : Callable
A function that has this signature - (y_true: List, y_pred: List) -> float; takes in ground truth and predictions and returns a metric to optimize on. The eval_cols in column_mapping are passed in as a List[Dict]
device : str
device to run inference on, eg - cuda:0
prompt_template : str
template with placeholders for input_cols from the column_mapping argument, eg - "Question : How many days are there in a year?\nAnswer : 365\n\nQuestion : {question}\nAnswer : ". Not used when sample_preprocessor is not None
sample_preprocessor : Callable
Preprocessor function for a single example from the dataset; should have the signature (tokenizer, **kwargs) -> str, where the keyword arguments are the columns from input_cols & eval_cols in column_mapping. Not used when prompt_template is not None
tokenizer_encode_args : Dict, optional
Encoding key value arguments for the tokenizer. If None, it's initialized using the get_default_tokenizer_encode_args method. Defaults to None.
tokenizer_decode_args : Dict, optional
Decoding key value arguments for the tokenizer. Defaults to {'skip_special_tokens' : True}.
batch_size : int, optional
batch_size to run inference with; this gets dynamically halved if the inference function encounters OOM errors. Defaults to 16.
output_preproc : Callable, optional
Post-processing function for the output; by default it strips the output. Defaults to lambda x : x.strip().
callbacks_after_inference : List[Callable], optional
Callbacks to run after each inference. Useful for stopping criteria in generation. Defaults to None.
is_encoder_decoder : bool, optional
whether the model is an encoder-decoder model. Defaults to False.
greater_is_better : bool, optional
whether a greater value of the metric to optimize on is better. Defaults to True.
seed : int, optional
seed for reproducibility. Defaults to 42.
disable_batch_size_cache : bool, optional
If True, the pre-defined batch_size is used for each cross-validation run; this could lead to wasted computation time if OOM is raised by the inference function. Defaults to False.
disable_generation_param_checks : bool, optional
Disables the custom generation parameter checks, which sanity-check the parameters & produce warnings before doing generation. Defaults to False.
sample_ratio : float, optional
Sampling ratio of the dataset used to find suitable values for padding and truncation. Ignored if tokenizer_encode_args is not None. Defaults to 0.3.
tokenizer_max_length_quantile : float, optional
quantile of tokenized lengths at which to set max_length, computed from the dataset. Defaults to 0.9.
custom_pred_function : Union[Callable, None], optional
Custom inference function that overrides model_utils.run_inference if provided. Should take in two parameters - model inputs (List[str]) and model generation parameters (Dict) - and return a List[str] of outputs. Defaults to None.
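A minimal construction sketch - the model name, dataset contents, and exact_match scorer below are illustrative assumptions, not part of llmsearch:

```python
# A minimal sketch -- model name, dataset contents & scorer are illustrative
# assumptions, not part of llmsearch.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmsearch.tuner import Tuner

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical model choice
model = AutoModelForCausalLM.from_pretrained("gpt2")

dataset = Dataset.from_dict(
    {
        "question": ["How many days are there in a year?"],
        "answer": ["365"],
    }
)

def exact_match(y_true, y_pred):
    # Per the docs above, eval_cols arrive as a List[Dict], eg - [{'answer': '365'}]
    matches = [t["answer"].strip() == p.strip() for t, p in zip(y_true, y_pred)]
    return sum(matches) / len(matches)

tuner = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    column_mapping={"input_cols": ["question"], "eval_cols": ["answer"]},
    scorer=exact_match,
    device="cuda:0",
    prompt_template="Question : {question}\nAnswer : ",
)
```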

Methods

def get_default_tokenizer_encode_args(self, sample_ratio, tokenizer_encode_args)

Get default tokenizer encode arguments using the dataset. Sets padding & truncation to True and calculates max_length from the dataset.

Args

sample_ratio : float
Sampling ratio of the dataset used to find suitable values for padding and truncation. Ignored if tokenizer_encode_args is not None.
tokenizer_encode_args : Union[Dict, None]
Encoding key value arguments for the tokenizer. Returned as-is if not None; otherwise padding, truncation & max_length are derived from the dataset, with max_length taken at the quantile set by tokenizer_max_length_quantile. Defaults to None.
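An illustrative call on an initialized tuner - the result shown is an assumption about the shape of the return value, not guaranteed output:

```python
# Hypothetical usage on an initialized `tuner`; the result shape is an
# assumption based on the description above.
encode_args = tuner.get_default_tokenizer_encode_args(
    sample_ratio=0.3,
    tokenizer_encode_args=None,  # None -> derive defaults from the dataset
)
# eg - {'padding': True, 'truncation': True, 'max_length': 128}
```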
def get_score(self, generation_args, dataset=None)

Evaluate the score function on a dataset using the given generation arguments for the model. If dataset is None, the dataset provided at initialization is used; otherwise the passed dataset is preprocessed (_X & _y are populated) and used.

Args

generation_args : Dict
generation kwargs to perform inference with
dataset : Union[Dataset, Dict], optional
dataset to perform inference on. Defaults to None.

Returns

Tuple[float, List]
score, predictions
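A sketch of a call - the generation arguments below are illustrative; any valid model.generate kwargs apply:

```python
# Illustrative generation arguments - any valid `model.generate` kwargs work here.
generation_args = {
    "max_new_tokens": 32,
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 50,
}
score, predictions = tuner.get_score(generation_args)
print(score)        # float returned by the scorer
print(predictions)  # List of decoded model outputs
```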
def get_value_at_quantile(self, input_list, quantile)

Get value at a specific quantile

Args

input_list : List[str]
list of strings to run the tokenizer's encoding on
quantile : float
quantile at which to find the value

Returns

int
rounded value at quantile
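A rough equivalent of the computation, assuming the method encodes each string and takes the rounded quantile of the token lengths:

```python
import numpy as np

# Assumed equivalent of get_value_at_quantile, based on the description above.
def value_at_quantile(tokenizer, input_list, quantile):
    lengths = [len(tokenizer.encode(text)) for text in input_list]
    return round(float(np.quantile(lengths, quantile)))

# eg - a max_length candidate from the 0.9 quantile of prompt lengths:
# max_length = value_at_quantile(tokenizer, prompts, quantile=0.9)
```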
def preprocess_dataset(self, dataset)

Dataset preprocessor; preprocesses using the prompt_template or the sample_preprocessor function.

self.prompt_template - Useful for already processed datasets (text can be fed directly into the model)

self.sample_preprocessor - Useful for datasets that need to be preprocessed (eg - converted into chat format) before feeding into the model

Adds _X & _y keys to the dataset

Note : datasets.map is not used; the dataset is mapped with Python's built-in map instead, as datasets.map has a memory-related issue. TODO : raise an issue in the datasets repo.

Args

dataset : Dataset
dataset to preprocess
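A sketch of the effect on a single sample when a prompt_template is used - the shape of _y is an assumption based on the eval_cols description above:

```python
# Illustrative sample; the _y shape is an assumption based on the
# column_mapping docs above.
sample = {"question": "How many days are there in a year?", "answer": "365"}
prompt_template = "Question : {question}\nAnswer : "

sample["_X"] = prompt_template.format(question=sample["question"])  # model input
sample["_y"] = {"answer": sample["answer"]}                         # eval_cols dict
# tuner.preprocess_dataset(dataset) applies this mapping over every sample
```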