Abstractions for model training with Podium

Models

This module contains ML models.

Pipeline

This package contains components for the automatic training and evaluation of ML models.

Model selection

Finds the best combination of training and model hyperparameters out of the given hyperparameter grids. The method uses a simple grid search that evaluates every possible combination of the given hyperparameters, and is based on scikit-learn's sklearn.model_selection.GridSearchCV. Each hyperparameter combination is scored with k-fold cross-validation: for each fold, the model is trained on the remaining folds and evaluated on the held-out fold. The final score for a set of hyperparameters is the mean of the scores across all folds.

Parameters
  • experiment (Experiment) – Experiment defining the training and prediction procedure to be optimised.

  • dataset (Dataset) – Dataset to be used in the hyperparameter search.

  • score_fun (callable) – Function used to score a hyperparameter set.

  • model_param_grid (Dict or Iterable of Dicts) – The model parameter grid. Combinations taken from this grid are passed to the model's __init__ function. Either a dictionary with parameter names (strings) as keys and lists of parameter settings to try as values, or a list of such dictionaries, in which case the grids spanned by each dictionary in the list are explored. This enables searching over any sequence of parameter settings.

  • trainer_param_grid (Dict or Iterable of Dicts) – The trainer parameter grid. Combinations taken from this grid are passed to the trainer's train function. Either a dictionary with parameter names (strings) as keys and lists of parameter settings to try as values, or a list of such dictionaries, in which case the grids spanned by each dictionary in the list are explored. This enables searching over any sequence of parameter settings.

  • n_splits (int) – Number of folds to be used in cross-validation.

  • greater_is_better (bool) – Whether score_fun is a score function (the default), meaning higher is better, or a loss function, meaning lower is better.

  • print_progress (bool) – Whether to print progress. Progress is printed to sys.stderr.
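The grid semantics described for model_param_grid and trainer_param_grid (a single dictionary spans the Cartesian product of its value lists; a list of dictionaries spans the union of the grids of its members) can be sketched with the standard library as below. This is an illustrative sketch, not Podium's or scikit-learn's code: the helper `expand_param_grid` and the parameter names `hidden_size`, `dropout` and `max_epoch` are hypothetical, and pairing every model combination with every trainer combination is an assumption drawn from "all possible combinations" in the description.

```python
import itertools
from typing import Any, Dict, Iterable, List, Union

ParamGrid = Union[Dict[str, List[Any]], Iterable[Dict[str, List[Any]]]]


def expand_param_grid(grid: ParamGrid) -> List[Dict[str, Any]]:
    """Hypothetical helper: expand a grid into the parameter dicts it spans."""
    # A single dict spans one grid; a list of dicts spans the union of their grids.
    grids = [grid] if isinstance(grid, dict) else list(grid)
    combinations = []
    for single_grid in grids:
        keys = sorted(single_grid)
        # Cartesian product of the value lists, one dict per combination.
        for values in itertools.product(*(single_grid[k] for k in keys)):
            combinations.append(dict(zip(keys, values)))
    return combinations


# Hypothetical parameter names, for illustration only.
model_param_grid = [
    {"hidden_size": [64, 128], "dropout": [0.1, 0.3]},
    {"hidden_size": [256]},
]
trainer_param_grid = {"max_epoch": [5, 10]}

model_combos = expand_param_grid(model_param_grid)      # 2*2 + 1 = 5 combinations
trainer_combos = expand_param_grid(trainer_param_grid)  # 2 combinations
search_space = list(itertools.product(model_combos, trainer_combos))
print(len(search_space))  # 10
```

Because every pair is trained n_splits times, the cost of the search grows with the product of the two grid sizes, which is worth keeping in mind when adding values to either grid.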

Model validation

This package contains modules used for model validation.