# Source code for deeptables.utils.dart_early_stopping

# -*- coding:utf-8 -*-

from operator import gt, lt
from sklearn.base import clone
from lightgbm.compat import range_
from lightgbm.callback import EarlyStopException


def _format_eval_result(value, show_stdv=True):
    """Format metric string."""
    if len(value) == 4:
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
    elif len(value) == 5:
        if show_stdv:
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
        else:
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
    else:
        raise ValueError("Wrong metric value")


def dart_early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    """Create a callback that activates early stopping and restores the best model.

    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds``
    round(s) to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them are checked, but the training
    dataset is always ignored. To check only the first metric, set
    ``first_metric_only`` to True.

    Each time a metric reaches a new best value, the booster is serialized
    with ``env.model.model_to_string()``. When training stops, either early
    or at the final iteration, that snapshot is loaded back with
    ``env.model.model_from_string()``. This makes ``env.model`` match the
    best iteration carried by the raised ``EarlyStopException``.
    NOTE(review): restoring a snapshot (rather than truncating to
    ``best_iteration``) is presumably needed because DART rescales earlier
    trees as new ones are added. Confirm against LightGBM's DART docs.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list = []
    best_model = {}  # metric index -> model string captured at that metric's best iteration
    cmp_op = []
    first_metric = ['']

    def _init(env):
        """Initialise per-metric state on the first callback invocation."""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')
        if verbose:
            msg = "Training until validation scores don't improve for {} rounds"
            print(msg.format(stopping_rounds))
        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            # eval_ret[3] is the "higher is better" flag: pick the comparison accordingly.
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _stop(env, template, eval_name_splitted, i):
        """Report, restore the best snapshot for metric ``i`` and stop training."""
        if verbose:
            print(template % (
                best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
            if first_metric_only:
                print("Evaluated only: {}".format(eval_name_splitted[-1]))
        # recover best model from string so env.model matches the reported best iteration
        env.model.model_from_string(best_model[i])
        raise EarlyStopException(best_iter[i], best_score_list[i])

    def _final_iteration_check(env, eval_name_splitted, i):
        """Stop (restoring the best model) when this is the last boosting round."""
        if env.iteration == env.end_iteration - 1:
            _stop(env, 'Did not meet early stopping. Best iteration is:\n[%d]\t%s',
                  eval_name_splitted, i)

    def _callback(env):
        if not cmp_op:
            _init(env)
        results = env.evaluation_result_list
        # Serialize the booster at most once per iteration and share the string
        # across every metric that improved in this iteration.
        snapshot = None
        for i, result in enumerate(results):
            score = result[2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = results
                if snapshot is None:
                    snapshot = env.model.model_to_string()
                best_model[i] = snapshot
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = result[1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            # train datasets for lgb.cv or sklearn wrapper (underlying lgb.train)
            # never trigger early stopping, only the final-iteration check.
            is_train_set = ((result[0] == "cv_agg" and eval_name_splitted[0] == "train")
                            or result[0] == env.model._train_data_name)
            if is_train_set:
                _final_iteration_check(env, eval_name_splitted, i)
                continue
            if env.iteration - best_iter[i] >= stopping_rounds:
                _stop(env, 'Early stopping, best iteration is:\n[%d]\t%s',
                      eval_name_splitted, i)
            _final_iteration_check(env, eval_name_splitted, i)

    _callback.order = 30
    return _callback