Module llmsearch.tuner
Core Tuner Class that adapts a PyTorch model to function as a scikit-learn estimator.
Sub-modules
llmsearch.tuner.tuner
- Main Tuner module containing the `LLMEstimatorWrapper` & `Tuner` classes for scikit-learn
Classes
class Tuner (model, tokenizer, dataset, column_mapping, scorer, device, prompt_template=None, sample_preprocessor=None, tokenizer_encode_args=None, tokenizer_decode_args=None, batch_size=16, output_preproc=<function Tuner.<lambda>>, callbacks_after_inference=None, is_encoder_decoder=False, greater_is_better=True, seed=42, disable_batch_size_cache=False, disable_generation_param_checks=False, sample_ratio=0.3, tokenizer_max_length_quantile=0.9, custom_pred_function=None)
- Tuner class which drives the search for the generation hyperparameters.

Initializes the Tuner class and populates the dataset with `_X` & `_y` keys for the model to use.

Args

model : `nn.Module` - model that has a `.generate` method
tokenizer : `AutoTokenizer` - tokenizer for the input
dataset : `Dataset` - dataset to perform the search on
column_mapping : `Dict[str, list]` - should contain the `input_cols` & `eval_cols` keys; `input_cols` should contain the columns to be used in the `prompt_template` & `eval_cols` should contain the columns to be used in the `scorer`. All `eval_cols` are passed in as a dict as the second argument to the `scorer` function, eg - `{'input_cols' : ["question"], 'eval_cols' : ['answer']}`
scorer : `Callable` - a function with the signature `(y_true: List, y_pred: List) -> float` that takes in the ground truth and the predictions and returns a metric to optimize on; the `eval_cols` in `column_mapping` are passed in as the second argument as a `List[Dict]`
device : `str` - device to run inference on, eg - `cuda:0`
prompt_template : `str` - template with placeholders for the `input_cols` from the `column_mapping` argument, eg - `"Question : How many days are there in a year?\nAnswer : 365\n\nQuestion : {question}\nAnswer : "`; not used when `sample_preprocessor` is not `None`
sample_preprocessor : `Callable` - preprocessor function for a single example from the `dataset`; should have the signature `(tokenizer, **kwargs) -> str`, where the keyword arguments are the columns from the `input_cols` & `eval_cols` in `column_mapping`; not used when `prompt_template` is not `None`
tokenizer_encode_args : `Dict`, optional - encoding keyword arguments for the `tokenizer`. If `None`, it is initialized using the `get_default_tokenizer_encode_args` method. Defaults to `None`.
tokenizer_decode_args : `Dict`, optional - decoding keyword arguments for the `tokenizer`. Defaults to `{'skip_special_tokens' : True}`.
batch_size : `int`, optional - batch size to run inference with; this gets dynamically halved if the inference function encounters OOM errors. Defaults to `16`.
output_preproc : `Callable`, optional - post-processing function for the output; by default it strips the output. Defaults to `lambda x : x.strip()`.
callbacks_after_inference : `List[Callable]`, optional - callbacks to run after each inference; useful for stopping criteria in generation. Defaults to `None`.
is_encoder_decoder : `bool`, optional - whether the model is an encoder-decoder model, `False` if not. Defaults to `False`.
greater_is_better : `bool`, optional - whether a greater value of the metric being optimized is better. Defaults to `True`.
seed : `int`, optional - seed for reproducibility. Defaults to `42`.
disable_batch_size_cache : `bool`, optional - if `True`, the pre-defined `batch_size` is used for each cross-validation run; this could lead to wasted computation time if OOM is raised by the inference function. Defaults to `False`.
disable_generation_param_checks : `bool`, optional - disables the custom generation parameter checks, which sanity-check the parameters & produce warnings before generation. Defaults to `False`.
sample_ratio : `float`, optional - sampling ratio of the `dataset` used to find ideal values for padding and truncation. Has no effect if `tokenizer_encode_args` is not `None`. Defaults to `0.3`.
tokenizer_max_length_quantile : `float`, optional - quantile at which to find a value for `max_length` based on the dataset. Defaults to `0.9`.
custom_pred_function : `Union[Callable, None]`, optional - overrides the inference function (`model_utils.run_inference`) if present; should take in two parameters - the model inputs (`List[str]`) and the model generation parameters (`Dict`) - and return a `List[str]` of outputs. Defaults to `None`.
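Below is a minimal construction sketch. The model name, toy dataset, and metric are illustrative assumptions, not part of the API; the scorer here assumes the `eval_cols` dicts arrive in `y_true` - swap the two arguments if your version passes them second, as noted in the Args above.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmsearch.tuner import Tuner

model = AutoModelForCausalLM.from_pretrained("gpt2")  # any model with a .generate method
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

dataset = Dataset.from_dict(
    {
        "question": ["How many days are there in a year?", "What is 2 + 2?"],
        "answer": ["365", "4"],
    }
)

def exact_match(y_true, y_pred):
    # Assumption: y_true carries the eval_cols as a List[Dict],
    # eg - [{'answer': '365'}], and y_pred is the List[str] of model outputs.
    gold = [row["answer"] for row in y_true]
    return sum(g.strip() == p.strip() for g, p in zip(gold, y_pred)) / len(gold)

tuner = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    column_mapping={"input_cols": ["question"], "eval_cols": ["answer"]},
    scorer=exact_match,
    device="cuda:0",  # or "cpu" if no GPU is available
    prompt_template="Question : {question}\nAnswer : ",
)
```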
Methods
def get_default_tokenizer_encode_args(self, sample_ratio, tokenizer_encode_args)
- Get default input tokenizer arguments using the dataset. Sets `padding` & `truncation` to `True` and calculates `max_length` as tokenizer arguments.

Args

sample_ratio : `float` - sampling ratio of the `dataset` used to find ideal values for padding and truncation. Has no effect if `tokenizer_encode_args` is not `None`.
tokenizer_length_percentile : `float` - percentile at which to find a value for `max_length` based on the dataset.
tokenizer_encode_args : `Union[Dict, None]` - encoding keyword arguments for the `tokenizer`. Returned unchanged if not `None`. Defaults to `None`.
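A hedged sketch of deriving the defaults, reusing the `tuner` from the construction example above; the printed dict is illustrative of the description, not a guaranteed return shape.

```python
# Derive padding/truncation/max_length from a sample of the dataset
encode_args = tuner.get_default_tokenizer_encode_args(
    sample_ratio=0.3,
    tokenizer_encode_args=None,  # let the method compute the defaults
)
print(encode_args)  # eg - {'padding': True, 'truncation': True, 'max_length': 64}
```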
def get_score(self, generation_args, dataset=None)
- Evaluate the score function on a dataset using some generation arguments for the model. If `dataset` is `None`, the initialized dataset is used; otherwise the passed `dataset` is preprocessed (`_X` & `_y` are populated) and used.

Args

generation_args : `Dict` - generation kwargs to perform inference with
dataset : `Union[Dataset, Dict]`, optional - dataset to perform inference on. Defaults to `None`.

Returns

`Tuple[float, List]` - score, predictions
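For example, evaluating one set of generation parameters on the initialized dataset (the kwargs below are ordinary `transformers` `.generate` arguments, chosen arbitrarily):

```python
generation_args = {
    "max_new_tokens": 32,
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 50,
}
# No dataset passed, so the initialized dataset is used
score, predictions = tuner.get_score(generation_args)
print(score, predictions[:2])
```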
def get_value_at_quantile(self, input_list, quantile)
- Get the value at a specific quantile.

Args

input_list : `List[str]` - list of strings on which to run the `tokenizer`'s encoding
quantile : `float` - quantile at which to find the value

Returns

`int` - rounded value at the quantile
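A hedged sketch - this appears to be the helper used to derive `max_length` from token lengths; the sample prompts below are illustrative.

```python
samples = [
    "Question : How many days are there in a year?\nAnswer : ",
    "Question : What is 2 + 2?\nAnswer : ",
]
# Rounded token-length value at the 0.9 quantile of the encoded samples
max_length = tuner.get_value_at_quantile(input_list=samples, quantile=0.9)
print(max_length)
```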
def preprocess_dataset(self, dataset)
- Dataset preprocessor; preprocesses using the `prompt_template` or the `sample_preprocessor` function.

self.prompt_template - useful for already-processed datasets (text that can be directly fed into the model)
self.sample_preprocessor - useful for datasets that need to be preprocessed (eg - converted into a chat format) before feeding into the model

Adds the `_X` & `_y` keys to the dataset.

Note : `datasets.map` is not used; a traditional map is used instead, as `datasets.map` has a memory-related issue. TODO : issue to be raised in the `datasets` repo.

Args

dataset : `Dataset` - dataset to preprocess
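A hedged sketch, assuming the method returns the processed dataset (it may instead populate the keys in place):

```python
from datasets import Dataset

raw = Dataset.from_dict({"question": ["What is 2 + 2?"], "answer": ["4"]})
processed = tuner.preprocess_dataset(raw)
# processed now carries the _X (model-ready text) & _y (eval_cols) keys
```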