Module llmsearch.utils.gen_utils
Generation-Related Utilities
Note: The generation type detection scripts are current as of transformers v4.31.0.
Functions
def check_if_gen_param_rules_satisfy(gen_params, rules)
-
Check if the provided generation parameters satisfy the specified rules.
Args
gen_params
:Dict
- Dictionary containing generation parameters.
rules
:Dict
- Dictionary containing rules for parameter validation.
Returns
bool
- True if all rules are satisfied, False otherwise.
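The rule format itself is not described here, so the following is only a minimal sketch of the idea. It assumes each rule maps a parameter name to a predicate; the name `rules_satisfied` and this rule shape are hypothetical and are not taken from llmsearch.

```python
from typing import Any, Callable, Dict

# Hypothetical rule shape: parameter name -> predicate over that parameter's value.
Rules = Dict[str, Callable[[Any], bool]]


def rules_satisfied(gen_params: Dict[str, Any], rules: Rules) -> bool:
    """Return True only if every rule's predicate accepts the matching value."""
    # A parameter missing from gen_params is passed to its predicate as None.
    return all(check(gen_params.get(name)) for name, check in rules.items())


# Illustrative rules for a beam-search-like configuration (assumed, not from the library).
beam_rules: Rules = {
    "num_beams": lambda v: v is not None and v > 1,
    "do_sample": lambda v: not v,
}

print(rules_satisfied({"num_beams": 4, "do_sample": False}, beam_rules))
print(rules_satisfied({"num_beams": 1}, beam_rules))
```

The first call satisfies both rules; the second fails the `num_beams` rule, so the overall result is False.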
def check_if_param_req_satisfy(gen_params, param, param_req)
-
Check if a specific generation parameter satisfies the specified requirement.
Args
gen_params
:Dict
- Dictionary containing generation parameters.
param
:str
- Name of the generation parameter to check.
param_req
:Dict
- Dictionary specifying the requirement for the parameter.
Returns
bool
- True if the requirement is satisfied, False otherwise.
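To illustrate a per-parameter check, the sketch below assumes a requirement of the form `{"op": ..., "value": ...}`. That format, the `_OPS` table, and the function name `param_req_satisfied` are assumptions made for this example; the structure llmsearch actually expects for `param_req` may differ.

```python
import operator
from typing import Any, Dict

# Hypothetical requirement format: {"op": <comparison name>, "value": <reference value>}
_OPS = {
    "eq": operator.eq,
    "ne": operator.ne,
    "gt": operator.gt,
    "ge": operator.ge,
    "lt": operator.lt,
    "le": operator.le,
}


def param_req_satisfied(gen_params: Dict[str, Any], param: str, param_req: Dict[str, Any]) -> bool:
    """Compare gen_params[param] against the requirement; a missing parameter fails."""
    if param not in gen_params:
        return False
    compare = _OPS[param_req["op"]]
    return bool(compare(gen_params[param], param_req["value"]))


print(param_req_satisfied({"num_beams": 4}, "num_beams", {"op": "gt", "value": 1}))
print(param_req_satisfied({"do_sample": True}, "do_sample", {"op": "eq", "value": False}))
```

In this sketch a parameter that is absent never satisfies a requirement; a real implementation might instead fall back to a default value.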
def check_sample_parameter(gen_params, gen_type_params)
-
Check if the do_sample parameter is set to True when other parameters that depend on sampling are present.
Args
gen_params
:Dict
- Dictionary containing generation parameters.
gen_type_params
:Dict
- Parameters to exclude from the check because they are already used by a generation type (for example, top_k in Contrastive Search Decoding).
Raises
UserWarning
- Warns if dependent generation parameters are present without 'do_sample' set to True.
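A check like this can be sketched with the standard `warnings` module. The list of sampling-dependent parameters, the `exclude` argument, and the function name below are assumptions for illustration only, not the library's actual definitions.

```python
import warnings
from typing import Any, Dict, Iterable, List

# Assumed list of parameters that only take effect when sampling is enabled.
SAMPLING_PARAMS = ("temperature", "top_k", "top_p")


def warn_if_sampling_params_unused(
    gen_params: Dict[str, Any], exclude: Iterable[str] = ()
) -> List[str]:
    """Warn when sampling-dependent parameters are set but do_sample is not True.

    Parameters listed in `exclude` are skipped, e.g. top_k when it is already
    consumed by contrastive search. Returns the dependent parameters found.
    """
    excluded = set(exclude)
    dependent = [p for p in SAMPLING_PARAMS if p in gen_params and p not in excluded]
    if dependent and not gen_params.get("do_sample", False):
        warnings.warn(
            f"{dependent} have no effect unless 'do_sample' is set to True.",
            UserWarning,
        )
    return dependent
```

For example, `warn_if_sampling_params_unused({"penalty_alpha": 0.6, "top_k": 4}, exclude=("top_k",))` stays silent, while `warn_if_sampling_params_unused({"temperature": 0.7})` emits a UserWarning.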
def get_sample_hyp_space(seed, max_new_tokens)
-
Get two sample hyperparameter spaces.
Args
seed
:int
- Random seed.
Returns
Tuple[List, List]
- The first item is a larger hyperparameter space that searches over individual generation types. The second item contains the top generation parameters as evaluated by oobabooga using Vicuna-13B with instruct prompts.
def identify_and_validate_gen_params(gen_params)
-
Identify and validate the generation type based on provided generation parameters.
Args
gen_params
:Dict
- Dictionary containing generation parameters.
Returns
str
- Name of the identified generation type.
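The sketch below shows one way such detection could work. It loosely follows the decoding-mode conditions used by `generate` in transformers, in simplified form: constrained and assisted decoding are ignored, and no validation is performed. The function name and the returned labels are hypothetical and may not match the names llmsearch returns.

```python
from typing import Any, Dict


def guess_generation_type(gen_params: Dict[str, Any]) -> str:
    """Infer a decoding strategy label from generation parameters (simplified)."""
    num_beams = gen_params.get("num_beams", 1)
    num_beam_groups = gen_params.get("num_beam_groups", 1)
    do_sample = gen_params.get("do_sample", False)
    top_k = gen_params.get("top_k")
    penalty_alpha = gen_params.get("penalty_alpha")

    if num_beams == 1:
        if do_sample:
            return "sampling"
        # Contrastive search needs both top_k > 1 and penalty_alpha > 0.
        if top_k is not None and top_k > 1 and penalty_alpha is not None and penalty_alpha > 0:
            return "contrastive_search"
        return "greedy_search"
    if num_beam_groups > 1:
        return "group_beam_search"
    return "beam_sample" if do_sample else "beam_search"


print(guess_generation_type({"top_k": 4, "penalty_alpha": 0.6}))
print(guess_generation_type({"num_beams": 4, "do_sample": True}))
```

Checking `num_beams` first keeps the single-beam and multi-beam strategies in separate branches, so each combination of parameters maps to exactly one label.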