Module llmsearch.utils.gen_utils

Generation-Related Utilities

Note: The generation type detection logic is up to date as of transformers v4.31.0.

Functions

def check_if_gen_param_rules_satisfy(gen_params, rules)

Check if the provided generation parameters satisfy the specified rules.

Args

gen_params : Dict
Dictionary containing generation parameters.
rules : Dict
Dictionary containing rules for parameter validation.

Returns

bool
True if all rules are satisfied, False otherwise.
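The "all rules must hold" behaviour can be illustrated with a minimal sketch. It assumes a hypothetical rule format in which `rules` maps a parameter name to the exact value it must take. The real rule schema used by llmsearch is not documented here, and `check_rules_sketch` is an illustrative name, not the library's function.

```python
from typing import Any, Dict


def check_rules_sketch(gen_params: Dict[str, Any], rules: Dict[str, Any]) -> bool:
    """Return True only if every rule is satisfied.

    Hypothetical rule format: each key is a parameter name and each value is the
    value that parameter must have. A parameter missing from gen_params fails.
    """
    return all(
        name in gen_params and gen_params[name] == expected
        for name, expected in rules.items()
    )


# Hypothetical rule set describing plain beam search (no sampling).
beam_rules = {"do_sample": False, "num_beams": 4}
print(check_rules_sketch({"do_sample": False, "num_beams": 4, "max_new_tokens": 64}, beam_rules))  # True
print(check_rules_sketch({"do_sample": True, "num_beams": 4}, beam_rules))  # False
```

Because `all()` over an empty rule set returns True, an empty `rules` dict is trivially satisfied in this sketch.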
def check_if_param_req_satisfy(gen_params, param, param_req)

Check if a specific generation parameter satisfies the specified requirement.

Args

gen_params : Dict
Dictionary containing generation parameters.
param : str
Name of the generation parameter to check.
param_req : Dict
Dictionary specifying the requirement for the parameter.

Returns

bool
True if the requirement is satisfied, False otherwise.
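A single-parameter check can be sketched as below. The requirement keys `required`, `value`, `min` and `max` are assumptions made for illustration; the actual `param_req` schema in llmsearch may differ. `check_param_req_sketch` is likewise a hypothetical name.

```python
from typing import Any, Dict


def check_param_req_sketch(gen_params: Dict[str, Any], param: str, param_req: Dict[str, Any]) -> bool:
    """Check one generation parameter against a requirement dict.

    Hypothetical requirement keys:
      required -> the parameter must be present
      value    -> the parameter must equal this value
      min/max  -> inclusive bounds on the parameter's value
    """
    if param not in gen_params:
        # A missing parameter passes only if the requirement does not demand it.
        return not param_req.get("required", False)
    value = gen_params[param]
    if "value" in param_req and value != param_req["value"]:
        return False
    if "min" in param_req and value < param_req["min"]:
        return False
    if "max" in param_req and value > param_req["max"]:
        return False
    return True


params = {"num_beams": 4, "penalty_alpha": 0.6}
print(check_param_req_sketch(params, "num_beams", {"min": 2}))  # True
print(check_param_req_sketch(params, "penalty_alpha", {"max": 0.5}))  # False
print(check_param_req_sketch(params, "top_k", {"required": True, "min": 2}))  # False
```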
def check_sample_parameter(gen_params, gen_type_params)

Check if the do_sample parameter is set to True when other parameters that are dependent on sampling are present.

Args

gen_params : Dict
Dictionary containing generation parameters.
gen_type_params : Dict
Parameters to exclude from the check because a generation type already uses them (e.g. top_k in contrastive search decoding).

Raises

UserWarning
Warns if dependent generation parameters are present without 'do_sample' set to True.
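The warning logic can be sketched with the standard `warnings` module. The `SAMPLING_PARAMS` tuple is an assumption: it lists transformers parameters that act as sampling logits warpers and are therefore ignored unless `do_sample=True`. The list llmsearch actually checks may differ, and `check_sample_parameter_sketch` is a hypothetical name.

```python
import warnings
from typing import Any, Dict, Optional

# Assumed set of parameters that only influence decoding when do_sample=True.
SAMPLING_PARAMS = ("temperature", "top_k", "top_p", "typical_p", "epsilon_cutoff", "eta_cutoff")


def check_sample_parameter_sketch(
    gen_params: Dict[str, Any], gen_type_params: Optional[Dict[str, Any]] = None
) -> None:
    """Warn when sampling-only parameters are set without do_sample=True."""
    # Iterating a dict yields its keys, so these are the parameter names to skip.
    excluded = set(gen_type_params or {})
    dependent = [p for p in SAMPLING_PARAMS if p in gen_params and p not in excluded]
    if dependent and gen_params.get("do_sample") is not True:
        warnings.warn(
            f"{dependent} are set but do_sample is not True; they will be ignored.",
            UserWarning,
        )


# No warning: top_k is claimed by contrastive search.
check_sample_parameter_sketch({"top_k": 4, "penalty_alpha": 0.6}, {"top_k": 4})
# Emits a UserWarning: temperature has no effect without sampling.
check_sample_parameter_sketch({"temperature": 0.7})
```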
def get_sample_hyp_space(seed, max_new_tokens)

Get two sample hyperparameter spaces.

Args

seed : int
Random seed.
max_new_tokens : int
Maximum number of new tokens to generate.

Returns

Tuple[List, List]
The first item is a larger hyperparameter space that searches over individual generation types. The second item holds the top generation parameters as evaluated by oobabooga using Vicuna-13B with instruct prompts.
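To make the `Tuple[List, List]` shape concrete, here is a sketch that assumes each list follows the scikit-learn `param_grid` convention. Under that convention, the list holds dicts that map a parameter name to its candidate values, and each dict is expanded independently. The grid contents below are placeholders, not the values `get_sample_hyp_space` returns, and how `seed` is applied is not shown. `expand_hyp_space` is a hypothetical helper.

```python
import itertools
from typing import Any, Dict, List


def expand_hyp_space(hyp_space: List[Dict[str, List[Any]]]) -> List[Dict[str, Any]]:
    """Expand a list of grids into concrete generation-parameter dicts.

    Each grid is expanded independently via a Cartesian product of its
    candidate-value lists; the results of all grids are concatenated.
    """
    configs = []
    for grid in hyp_space:
        names = sorted(grid)
        for values in itertools.product(*(grid[n] for n in names)):
            configs.append(dict(zip(names, values)))
    return configs


# Placeholder space with one grid per generation type (illustrative values only).
max_new_tokens = 128
hyp_space = [
    {"do_sample": [True], "temperature": [0.7, 1.0], "top_p": [0.9], "max_new_tokens": [max_new_tokens]},
    {"do_sample": [False], "num_beams": [2, 4], "max_new_tokens": [max_new_tokens]},
]
print(len(expand_hyp_space(hyp_space)))  # 4
```

Keeping one grid per generation type avoids evaluating meaningless combinations such as `temperature` with pure beam search.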
def identify_and_validate_gen_params(gen_params)

Identify and validate the generation type based on provided generation parameters.

Args

gen_params : Dict
Dictionary containing generation parameters.

Returns

str
Name of the identified generation type.
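A simplified sketch of generation type detection follows. It is modelled on the mode checks in transformers v4.31's `generate()`, including the contrastive search condition, the `num_beam_groups` rules, and the rejection of sampling with diverse beam search. Assisted generation is not covered, and constrained beam search is simplified. This is not llmsearch's implementation, and the returned type names are hypothetical labels.

```python
from typing import Any, Dict

# Defaults mirror transformers' GenerationConfig defaults for the fields used here.
DEFAULTS = {"num_beams": 1, "num_beam_groups": 1, "do_sample": False, "top_k": 50, "penalty_alpha": None}


def identify_gen_type_sketch(gen_params: Dict[str, Any]) -> str:
    """Return a (hypothetical) generation type name, raising on invalid combos."""
    p = {**DEFAULTS, **gen_params}
    num_beams, groups, do_sample = p["num_beams"], p["num_beam_groups"], p["do_sample"]
    if num_beams < 1 or groups < 1:
        raise ValueError("num_beams and num_beam_groups must be >= 1")
    if groups > num_beams:
        raise ValueError("num_beam_groups must be <= num_beams")
    if p.get("constraints") is not None or p.get("force_words_ids") is not None:
        return "constrained_beam_search"
    # Contrastive search: no beams, no sampling, top_k > 1 and penalty_alpha > 0.
    if (
        num_beams == 1
        and not do_sample
        and p["top_k"] is not None and p["top_k"] > 1
        and p["penalty_alpha"] is not None and p["penalty_alpha"] > 0
    ):
        return "contrastive_search"
    if num_beams == 1:
        return "sample" if do_sample else "greedy"
    if groups > 1:
        if do_sample:
            raise ValueError("diverse (group) beam search cannot be used with do_sample=True")
        return "group_beam_search"
    return "beam_sample" if do_sample else "beam_search"


print(identify_gen_type_sketch({"penalty_alpha": 0.6, "top_k": 4}))  # contrastive_search
print(identify_gen_type_sketch({"num_beams": 4, "do_sample": True}))  # beam_sample
```

Because `top_k` defaults to 50, setting only `penalty_alpha > 0` already selects contrastive search in this sketch, as it does in transformers.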