Module llmsearch.utils.gen_utils
Generation Related Utilities
Note: The generation type detection scripts are current as of transformers v4.31.0.
Functions
def check_if_gen_param_rules_satisfy(gen_params, rules)
Check if the provided generation parameters satisfy the specified rules.
Args
gen_params:Dict- Dictionary containing generation parameters.
rules:Dict- Dictionary containing rules for parameter validation.
Returns
bool- True if all rules are satisfied, False otherwise.
def check_if_param_req_satisfy(gen_params, param, param_req)
Check if a specific generation parameter satisfies the specified requirement.
Args
gen_params:Dict- Dictionary containing generation parameters.
param:str- Name of the generation parameter to check.
param_req:Dict- Dictionary specifying the requirement for the parameter.
Returns
bool- True if the requirement is satisfied, False otherwise.
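The relationship between the two checks above can be sketched as follows. This is a minimal illustration, not the module's implementation: the requirement schema (the `equals`, `greater_than`, and `default` keys) and the function names `param_req_satisfied` and `rules_satisfied` are hypothetical, since the actual structure of `rules` and `param_req` is not documented here.

```python
from typing import Any, Dict


def param_req_satisfied(gen_params: Dict[str, Any], param: str, param_req: Dict[str, Any]) -> bool:
    """Check one parameter against one requirement (hypothetical schema)."""
    # Assumption: "default" supplies the value when the parameter is absent.
    value = gen_params.get(param, param_req.get("default"))
    # Assumption: "equals" demands an exact value.
    if "equals" in param_req and value != param_req["equals"]:
        return False
    # Assumption: "greater_than" demands a strict lower bound.
    if "greater_than" in param_req:
        if value is None or not value > param_req["greater_than"]:
            return False
    return True


def rules_satisfied(gen_params: Dict[str, Any], rules: Dict[str, Dict[str, Any]]) -> bool:
    """Return True only if every per-parameter requirement holds."""
    return all(
        param_req_satisfied(gen_params, param, param_req)
        for param, param_req in rules.items()
    )


# Illustrative rule set for a beam-search-like generation type.
beam_search_rules = {
    "num_beams": {"greater_than": 1, "default": 1},
    "do_sample": {"equals": False, "default": False},
}
print(rules_satisfied({"num_beams": 4}, beam_search_rules))
```

Expressing the rule-level check as `all(...)` over the per-parameter check keeps the two functions in the same roles the documentation describes: one validates a single requirement, the other aggregates them.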
def check_sample_parameter(gen_params, gen_type_params)
Check if the do_sample parameter is set to True when other parameters that depend on sampling are present.
Args
gen_params:Dict- Dictionary containing generation parameters.
gen_type_params:Dict- Parameters to exclude from the check because they are already used by a generation type (for example, top_k in Contrastive Search Decoding).
Raises
UserWarning- Warns if dependent generation parameters are present without 'do_sample' set to True.
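The behaviour described above might look roughly like the sketch below. The function name `warn_if_sampling_params_unused`, the tuple `SAMPLING_DEPENDENT_PARAMS`, and the returned list are assumptions made for illustration. The module's actual list of sampling-dependent parameters is not shown in this documentation.

```python
import warnings
from typing import Any, Dict, List

# Assumption: an illustrative subset of parameters that only matter when sampling.
SAMPLING_DEPENDENT_PARAMS = ("temperature", "top_k", "top_p")


def warn_if_sampling_params_unused(
    gen_params: Dict[str, Any], gen_type_params: Dict[str, Any]
) -> List[str]:
    """Warn when sampling-only parameters are set but do_sample is not True."""
    if gen_params.get("do_sample", False):
        return []
    # Skip parameters the detected generation type already uses on its own,
    # such as top_k in contrastive search.
    offending = [
        name
        for name in SAMPLING_DEPENDENT_PARAMS
        if name in gen_params and name not in gen_type_params
    ]
    if offending:
        warnings.warn(
            f"{offending} have no effect unless 'do_sample' is set to True.",
            UserWarning,
        )
    return offending


# top_k belongs to contrastive search here, so only temperature is flagged.
contrastive_params = {"penalty_alpha": 0.6, "top_k": 4}
warn_if_sampling_params_unused({**contrastive_params, "temperature": 0.7}, contrastive_params)
```

Returning the offending names is a convenience of this sketch that makes the check easy to test. The documented function only states that it warns.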
def get_sample_hyp_space(seed, max_new_tokens)
Get two sample hyperparameter spaces.
Args
seed:int- Random seed.
max_new_tokens:int- Maximum number of new tokens to generate.
Returns
Tuple[List, List]- The first item is a larger hyperparameter space that searches over individual generation types. The second item contains the top generation parameters as evaluated by oobabooga using Vicuna-13B with instruct prompts.
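To make the idea of a per-generation-type search space concrete, here is a rough sketch of how such a space could be represented and expanded. Everything in it is an assumption: the names `sketch_hyp_space`, `expand`, and `generation_seed`, the list-of-grids layout, and all candidate values are illustrative only and are not taken from the library or from the oobabooga evaluation.

```python
from itertools import product
from typing import Any, Dict, Iterator, List


def sketch_hyp_space(seed: int, max_new_tokens: int) -> List[Dict[str, List[Any]]]:
    """Build a list of grids, one per generation type (hypothetical layout)."""
    # Values shared by every grid so each candidate is seeded and length-bounded.
    shared = {"generation_seed": [seed], "max_new_tokens": [max_new_tokens]}
    return [
        # Contrastive search candidates (illustrative values only).
        {"penalty_alpha": [0.5, 0.6], "top_k": [4, 8], **shared},
        # Multinomial sampling candidates (illustrative values only).
        {"do_sample": [True], "temperature": [0.7, 1.0], "top_p": [0.9, 0.95], **shared},
    ]


def expand(grid: Dict[str, List[Any]]) -> Iterator[Dict[str, Any]]:
    """Yield every concrete parameter combination in one grid."""
    keys = list(grid)
    for values in product(*(grid[key] for key in keys)):
        yield dict(zip(keys, values))


space = sketch_hyp_space(seed=42, max_new_tokens=64)
candidates = [candidate for grid in space for candidate in expand(grid)]
print(len(candidates))
```

Keeping one grid per generation type avoids mixing parameters that belong to different decoding strategies in the same candidate.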
def identify_and_validate_gen_params(gen_params)
Identify and validate the generation type based on provided generation parameters.
Args
gen_params:Dict- Dictionary containing generation parameters.
Returns
str- Name of the identified generation type.
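As an illustration of what identifying a generation type from parameters involves, the sketch below follows the commonly documented transformers conventions: `num_beams` and `do_sample` select between greedy decoding, sampling, and beam search, while `penalty_alpha > 0` together with `top_k > 1` selects contrastive search. The function name `identify_generation_type`, the returned strings, the defaults, and the validation error are assumptions. The module's actual names and the full set of types it detects are not listed here.

```python
from typing import Any, Dict


def identify_generation_type(gen_params: Dict[str, Any]) -> str:
    """Map generation parameters to a generation type name (hypothetical names)."""
    # Assumption: absent parameters fall back to these simple defaults.
    num_beams = gen_params.get("num_beams", 1)
    do_sample = gen_params.get("do_sample", False)
    penalty_alpha = gen_params.get("penalty_alpha")
    top_k = gen_params.get("top_k")

    # Minimal validation step for this sketch.
    if not isinstance(num_beams, int) or num_beams < 1:
        raise ValueError("num_beams must be a positive integer")

    # Contrastive search: a single beam, no sampling, penalty_alpha > 0 and top_k > 1.
    if (
        num_beams == 1
        and not do_sample
        and penalty_alpha is not None
        and penalty_alpha > 0
        and top_k is not None
        and top_k > 1
    ):
        return "contrastive_search"
    if num_beams == 1:
        return "multinomial_sampling" if do_sample else "greedy_decoding"
    return "beam_search_multinomial_sampling" if do_sample else "beam_search"


print(identify_generation_type({"penalty_alpha": 0.6, "top_k": 4}))
```

The contrastive search check comes first because its parameters would otherwise be classified as plain greedy decoding.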