使用auto-sklearn demo AutoML

上一篇说明了Auto ML的基本概念, 本篇我们就来使用auto sklearn实作看看 Auto ML怎麽操作.

请到github下载 notebook

  • notebook: cardiovascular_disease_prediction_notebook_automl.ipynb
  • dateset: cardio_train.csv

我们接续上一个心血管疾病范例改写, 原本是使用XGBoost演算法, 这次我们改用auto skleran来进行训练.

首先我们需要安装 scikit learn与auto sklearn

pip install -U scikit-learn auto-sklearn

然後我们import所需要的module

import autosklearn
import sklearn.metrics
import autosklearn.classification

把autosklean的版本印出来看看

print('autosklearn: %s' % autosklearn.__version__)

接下来我们要呼叫auto sklearn所提供的演算法物件, 选择classification的AutoSklearnClassifier物件, 参数说明如下

  • time_left_for_this_task: 对每一种模型最多的训练时间, 预设值是3600秒, 数值愈高模型的准确度会愈高, 但所需要的时间就比较长. 在我们的demo中时间设定短一点(300秒), 这样不会等待太久, 同时也能看到执行的效果.
  • per_run_time_limit: 每次执行(run)的最高秒数, 预设值是time_left_for_this_task参数的十分之一(360秒). 在我们的demo中也设定为300秒的十分之一, 也就是30秒
# define search
model = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=300, per_run_time_limit=30)

执行训练. 这里的训练资料直接使用心血管疾病资料集所切分好的X(features)与y(target), 然後呼叫fit()就会开始执行.

# perform the search
model.fit(X_train, y_train)

在等待一段时间完成训练之後, 印出统计资料看一下.

# summarize
print(model.sprint_statistics())

这次的统计结果如下.

  • 有11种演算法被纳入评估
  • 有3种演算法成功完成训练, 有8种演算法在限制的时间无法完成
  • 最好的演算法取得的验证分数为 0.73, metric是 accuracy
auto-sklearn results:
  Dataset name: 77c66f3c-0fdc-11ec-8906-8b133cd412c7
  Metric: accuracy
  Best validation score: 0.732738
  Number of target algorithm runs: 11
  Number of successful target algorithm runs: 3
  Number of crashed target algorithm runs: 0
  Number of target algorithms that exceeded the time limit: 8
  Number of target algorithms that exceeded the memory limit: 0

前面的统计计表说明有3种演算法完成训练, 那我们来看一下是哪三种演算法

print(model.leaderboard())

结果显示最好的演算法是gradient_boosting.

          rank  ensemble_weight               type      cost   duration
model_id                                                               
4            1             0.84  gradient_boosting  0.267262  21.289428
8            2             0.08                mlp  0.354654  27.291203
10           3             0.08                sgd  0.356439   7.165989

最後可以印出每固模型的详细资料

print(model.show_models())

这次的详细资料如下:

[(0.840000, SimpleClassificationPipeline({'balancing:strategy': 'weighting', 'classifier:__choice__': 'gradient_boosting', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'no_preprocessing', 'classifier:gradient_boosting:early_stop': 'off', 'classifier:gradient_boosting:l2_regularization': 1.0945814167023392e-10, 'classifier:gradient_boosting:learning_rate': 0.11042628136263043, 'classifier:gradient_boosting:loss': 'auto', 'classifier:gradient_boosting:max_bins': 255, 'classifier:gradient_boosting:max_depth': 'None', 'classifier:gradient_boosting:max_leaf_nodes': 30, 'classifier:gradient_boosting:min_samples_leaf': 22, 'classifier:gradient_boosting:scoring': 'loss', 'classifier:gradient_boosting:tol': 1e-07, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.05141281638752715},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'mlp', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'no_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'feature_agglomeration', 'classifier:mlp:activation': 'tanh', 'classifier:mlp:alpha': 0.05476322473700896, 'classifier:mlp:batch_size': 'auto', 'classifier:mlp:beta_1': 0.9, 'classifier:mlp:beta_2': 0.999, 'classifier:mlp:early_stopping': 'valid', 'classifier:mlp:epsilon': 1e-08, 'classifier:mlp:hidden_layer_depth': 1, 'classifier:mlp:learning_rate_init': 0.012698439797907473, 'classifier:mlp:n_iter_no_change': 32, 'classifier:mlp:num_nodes_per_layer': 136, 'classifier:mlp:shuffle': 'True', 'classifier:mlp:solver': 'adam', 'classifier:mlp:tol': 0.0001, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.07441872802099897, 'feature_preprocessor:feature_agglomeration:affinity': 'manhattan', 'feature_preprocessor:feature_agglomeration:linkage': 'average', 'feature_preprocessor:feature_agglomeration:n_clusters': 264, 'feature_preprocessor:feature_agglomeration:pooling_func': 'max', 'classifier:mlp:validation_fraction': 0.1},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'sgd', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'most_frequent', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'select_percentile_classification', 'classifier:sgd:alpha': 1.6992296128865824e-07, 'classifier:sgd:average': 'True', 'classifier:sgd:fit_intercept': 'True', 'classifier:sgd:learning_rate': 'optimal', 'classifier:sgd:loss': 'log', 'classifier:sgd:penalty': 'l1', 'classifier:sgd:tol': 1.535384699341134e-05, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.24471105740962484, 'feature_preprocessor:select_percentile_classification:percentile': 39.91903776071659, 'feature_preprocessor:select_percentile_classification:score_func': 'f_classif'},
dataset_properties={
  'task': 1,
  'sparse': False,
  'multilabel': False,
  'multiclass': False,
  'target_type': 'classification',
  'signed': False})),
]

这样就可以使用由auto sklearn所训练出来的模型进行部署, 然後提供推论的结果. 但以目前来说还不建议直接将auto ML训练出来的模型放在正式环境上做部署. 比较建议的方式是先使用auto ML工具产出具有基本准确度的模型与参数, 然後由人工进行参数的调整训练模型以再次提高模型的准确度, 这样会是比较好的方式.


<<:  使用Vertex汇出的模型 | ML#Day26

>>:  Angular 深入浅出三十天:表单与测试 Day19 - 与 Cypress 的初次见面(下)

IOS 菜菜菜鸟30天挑战 Day-30 结语+转图小技巧

我实在没想到我能坚持连续30天不间断发文,对我来说真的是一大挑战,因为我不是一个经得起坚持的人,这次...

Day 17 | 今天是Coroutiones的 Dispatcher

读到Dispatcher有种越来越难的感觉QQ 这些太高深的东西对於小萌新来说真的好杀热情阿 估计今...

[Python 爬虫这样学,一定是大拇指拉!] DAY15 - Cookie / Session

相信 Cookie 与 Session 很多人常常搞不清楚,今天就用超级简单的方式来让大家了解。 C...

【Day 06】从零开始的 Line Chatbot-浅谈 Django

前几天在做建立专案的时候,好像看到一个不是很懂的东西-Django。 今天不谈 Chatbot,我们...

Flutter基础介绍与实作-Day22 旅游笔记的实作(3)

我们今天要接续昨天的划分4个区域开始,我们今天先从北部开始吧! 一样先来建立资料夹 lib/scar...