Module llmsearch.tuner
Core Tuner Class that adapts a PyTorch model to function as a scikit-learn estimator.
Sub-modules
llmsearch.tuner.tuner - Main Tuner Module containing the LLMEstimatorWrapper class & the Tuner class for scikit-learn
Classes
class Tuner (model, tokenizer, dataset, column_mapping, scorer, device, prompt_template=None, sample_preprocessor=None, tokenizer_encode_args=None, tokenizer_decode_args=None, batch_size=16, output_preproc=<function Tuner.<lambda>>, callbacks_after_inference=None, is_encoder_decoder=False, greater_is_better=True, seed=42, disable_batch_size_cache=False, disable_generation_param_checks=False, sample_ratio=0.3, tokenizer_max_length_quantile=0.9, custom_pred_function=None)
Tuner Class which drives the search for the generation hyperparameters
Initializes the Tuner Class and populates the dataset with _X & _y keys for the model to use.

Args

model : nn.Module - model that has a .generate method
tokenizer : AutoTokenizer - tokenizer for the input
dataset : Dataset - dataset to perform the search on
column_mapping : Dict[str, list] - should contain input_cols & eval_cols keys. input_cols should contain the columns to be used in the prompt_template & eval_cols should contain the columns to be used in the scorer; all eval columns are passed to the scorer function as a List[Dict] in its second argument. eg - {'input_cols' : ["question"], 'eval_cols' : ['answer']}
scorer : Callable - a function with the signature (y_true: List, y_pred: List) -> float; takes in ground truth and predictions and returns a metric to optimize on. The eval_cols in column_mapping are passed in as the second argument as a List[Dict]
device : str - device to run inference on, eg - cuda:0
prompt_template : str - template with placeholders for input_cols from the column_mapping argument, eg - "Question : How many days are there in a year?\nAnswer : 365\n\nQuestion : {question}\nAnswer : ". Not used when sample_preprocessor is not None
sample_preprocessor : Callable - preprocessor function for a single example from the dataset; should have the signature (tokenizer, **kwargs) -> str, where the keyword arguments are the columns from input_cols & eval_cols in column_mapping. Not used when prompt_template is not None
tokenizer_encode_args : Dict, optional - encoding keyword arguments for the tokenizer. If None, it is initialized using the get_default_tokenizer_encode_args method. Defaults to None.
tokenizer_decode_args : Dict, optional - decoding keyword arguments for the tokenizer. Defaults to {'skip_special_tokens' : True}.
batch_size : int, optional - batch size to run inference with; this gets dynamically halved if the inference function encounters OOM errors. Defaults to 16.
output_preproc : Callable, optional - post-processing function for the output; by default it strips the output. Defaults to lambda x : x.strip().
callbacks_after_inference : List[Callable], optional - callbacks to run after each inference; useful for stopping criteria in generation. Defaults to None.
is_encoder_decoder : bool, optional - whether the model is an encoder-decoder model, False if not. Defaults to False.
greater_is_better : bool, optional - whether a greater value of the metric to optimize on is better. Defaults to True.
seed : int, optional - seed for reproducibility. Defaults to 42.
disable_batch_size_cache : bool, optional - if True, the pre-defined batch_size is used for each cross-validation run; this could lead to wasted computation time if the inference function raises OOM. Defaults to False.
disable_generation_param_checks : bool, optional - disables the custom generation parameter checks, which sanity-check the parameters & produce warnings before doing generation. Defaults to False.
sample_ratio : float, optional - sampling ratio of the dataset used to find the ideal values for padding and truncation. The argument is unused if tokenizer_encode_args is not None. Defaults to 0.3.
tokenizer_max_length_quantile : float, optional - quantile at which to find a value for max_length based on the dataset. Defaults to 0.9.
custom_pred_function : Union[Callable, None], optional - overrides the inference function (model_utils.run_inference) if present. Should take in two parameters - model inputs (List[str]) and model generation parameters (Dict) - and return a List[str] of outputs. Defaults to None.
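A minimal usage sketch; the model choice, toy dataset, and exact-match scorer below are illustrative assumptions, not part of this reference:

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmsearch.tuner import Tuner

model_id = "gpt2"  # illustrative; any model with a .generate method works
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token  # gpt2 has no pad token

# Toy dataset with one input column ("question") and one eval column ("answer")
dataset = Dataset.from_dict({
    "question": ["How many days are there in a year?", "How many legs does a spider have?"],
    "answer": ["365", "8"],
})

# Exact-match accuracy. Assumption: the eval-column dicts arrive on the
# ground-truth side as a List[Dict]; adjust if your scorer receives them differently.
def exact_match(y_true, y_pred):
    return sum(t["answer"] == p.strip() for t, p in zip(y_true, y_pred)) / len(y_true)

tuner = Tuner(
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
    column_mapping={"input_cols": ["question"], "eval_cols": ["answer"]},
    scorer=exact_match,
    device="cuda:0",
    prompt_template="Question : {question}\nAnswer : ",
)
```

From here, individual sets of generation hyperparameters can be evaluated with get_score (see below).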
Methods
def get_default_tokenizer_encode_args(self, sample_ratio, tokenizer_encode_args)

Get default input tokenizer arguments using the dataset. Sets padding & truncation to True and calculates max_length as tokenizer arguments.

Args

sample_ratio : float - sampling ratio of the dataset used to find the ideal values for padding and truncation. The argument is unused if tokenizer_encode_args is not None.
tokenizer_length_percentile : float - quantile at which to find a value for max_length based on the dataset.
tokenizer_encode_args : Union[Dict, None] - encoding keyword arguments for the tokenizer. Returned as-is if this is not None. Defaults to None.
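As a rough illustration of the idea (a hypothetical helper, not the library's internal implementation), such defaults could be derived like this:

```python
import random
from typing import Dict, List

def default_encode_args(tokenizer, texts: List[str], sample_ratio: float = 0.3,
                        quantile: float = 0.9, seed: int = 42) -> Dict:
    # Sample a fraction of the dataset to keep tokenization cheap
    rng = random.Random(seed)
    sample = rng.sample(texts, max(1, int(len(texts) * sample_ratio)))
    # Token lengths of the sampled texts
    lengths = sorted(len(tokenizer.encode(t)) for t in sample)
    # Value at the requested quantile, rounded to an index into the sorted lengths
    idx = min(len(lengths) - 1, round(quantile * (len(lengths) - 1)))
    return {"padding": True, "truncation": True, "max_length": lengths[idx]}
```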
def get_score(self, generation_args, dataset=None)

Evaluate the score function, using the given generation arguments for the model, on a provided dataset or the initialized one. If dataset is None the initialized dataset is used; otherwise the given dataset is preprocessed (_X & _y are populated) and used.

Args

generation_args : Dict - generation kwargs to perform inference with
dataset : Union[Dataset, Dict], optional - dataset to perform inference on. Defaults to None.

Returns

Tuple[float, List] - score, predictions
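For example, assuming the tuner instance from the earlier sketch, a single set of generation parameters can be scored like this:

```python
generation_args = {
    "max_new_tokens": 16,
    "do_sample": True,
    "temperature": 0.7,
    "top_k": 50,
}
score, predictions = tuner.get_score(generation_args)
print(f"score = {score:.3f}")
```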
def get_value_at_quantile(self, input_list, quantile)

Get the value at a specific quantile.

Args

input_list : List[str] - list of strings on which to run the tokenizer's encoding
quantile : float - quantile at which to find the value

Returns

int - rounded value at the quantile
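In effect this is a quantile over tokenized lengths; a rough standalone equivalent (illustrative, not the method's source):

```python
import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative
input_list = ["How many days are there in a year?", "What is 2 + 2?"]
token_lengths = [len(tokenizer.encode(t)) for t in input_list]
print(round(float(np.quantile(token_lengths, 0.9))))  # rounded value at the 0.9 quantile
```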
def preprocess_dataset(self, dataset)

Dataset preprocessor; preprocesses using either the prompt_template or the sample_preprocessor function.

self.prompt_template - useful for already-processed datasets (text can be fed directly into the model)
self.sample_preprocessor - useful for datasets that need to be preprocessed (e.g. converted into a chat format) before being fed into the model

Adds _X & _y keys to the dataset.

Note : datasets.map is not used; a traditional map is used to map the dataset instead, as datasets.map has a memory-related issue. TODO : issue to be raised in the datasets repo.

Args

dataset : Dataset - dataset to preprocess
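To make the contract concrete, here is a hypothetical re-implementation of the idea (not the library's code): each example's input_cols are formatted into the prompt to form _X, and its eval_cols are collected into a dict to form _y.

```python
from typing import Dict, List

def preprocess(examples: List[Dict], prompt_template: str,
               column_mapping: Dict[str, List[str]]) -> List[Dict]:
    processed = []
    for ex in examples:
        # _X: the prompt template with the input columns substituted in
        x = prompt_template.format(**{c: ex[c] for c in column_mapping["input_cols"]})
        # _y: dict of the eval columns, later handed to the scorer
        y = {c: ex[c] for c in column_mapping["eval_cols"]}
        processed.append({**ex, "_X": x, "_y": y})
    return processed

rows = preprocess(
    [{"question": "How many days are there in a year?", "answer": "365"}],
    "Question : {question}\nAnswer : ",
    {"input_cols": ["question"], "eval_cols": ["answer"]},
)
print(rows[0]["_X"], rows[0]["_y"])
```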