上一篇说明了Auto ML的基本概念, 本篇我们就来使用auto sklearn实作看看 Auto ML怎麽操作.
请到github下载 notebook
我们接续上一个心血管疾病范例改写, 原本是使用XGBoost演算法, 这次我们改用auto skleran来进行训练.
首先我们需要安装 scikit learn与auto sklearn
pip install -U scikit-learn auto-sklearn
然後我们import所需要的module
import autosklearn
import sklearn.metrics
import autosklearn.classification
把autosklean的版本印出来看看
print('autosklearn: %s' % autosklearn.__version__)
接下来我们要呼叫auto sklearn所提供的演算法物件, 选择classification的AutoSklearnClassifier物件, 参数说明如下
time_left_for_this_task
参数的十分之一(360秒). 在我们的demo中也设定为300秒的十分之一, 也就是30秒# define search
model = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=300, per_run_time_limit=30)
执行训练. 这里的训练资料直接使用心血管疾病资料集所切分好的X(features)与y(target), 然後呼叫fit()就会开始执行.
# perform the search
model.fit(X_train, y_train)
在等待一段时间完成训练之後, 印出统计资料看一下.
# summarize
print(model.sprint_statistics())
这次的统计结果如下.
auto-sklearn results:
Dataset name: 77c66f3c-0fdc-11ec-8906-8b133cd412c7
Metric: accuracy
Best validation score: 0.732738
Number of target algorithm runs: 11
Number of successful target algorithm runs: 3
Number of crashed target algorithm runs: 0
Number of target algorithms that exceeded the time limit: 8
Number of target algorithms that exceeded the memory limit: 0
前面的统计计表说明有3种演算法完成训练, 那我们来看一下是哪三种演算法
print(model.leaderboard())
结果显示最好的演算法是gradient_boosting
.
rank ensemble_weight type cost duration
model_id
4 1 0.84 gradient_boosting 0.267262 21.289428
8 2 0.08 mlp 0.354654 27.291203
10 3 0.08 sgd 0.356439 7.165989
最後可以印出每固模型的详细资料
print(model.show_models())
这次的详细资料如下:
[(0.840000, SimpleClassificationPipeline({'balancing:strategy': 'weighting', 'classifier:__choice__': 'gradient_boosting', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'no_preprocessing', 'classifier:gradient_boosting:early_stop': 'off', 'classifier:gradient_boosting:l2_regularization': 1.0945814167023392e-10, 'classifier:gradient_boosting:learning_rate': 0.11042628136263043, 'classifier:gradient_boosting:loss': 'auto', 'classifier:gradient_boosting:max_bins': 255, 'classifier:gradient_boosting:max_depth': 'None', 'classifier:gradient_boosting:max_leaf_nodes': 30, 'classifier:gradient_boosting:min_samples_leaf': 22, 'classifier:gradient_boosting:scoring': 'loss', 'classifier:gradient_boosting:tol': 1e-07, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.05141281638752715},
dataset_properties={
'task': 1,
'sparse': False,
'multilabel': False,
'multiclass': False,
'target_type': 'classification',
'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'mlp', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'no_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'mean', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'feature_agglomeration', 'classifier:mlp:activation': 'tanh', 'classifier:mlp:alpha': 0.05476322473700896, 'classifier:mlp:batch_size': 'auto', 'classifier:mlp:beta_1': 0.9, 'classifier:mlp:beta_2': 0.999, 'classifier:mlp:early_stopping': 'valid', 'classifier:mlp:epsilon': 1e-08, 'classifier:mlp:hidden_layer_depth': 1, 'classifier:mlp:learning_rate_init': 0.012698439797907473, 'classifier:mlp:n_iter_no_change': 32, 'classifier:mlp:num_nodes_per_layer': 136, 'classifier:mlp:shuffle': 'True', 'classifier:mlp:solver': 'adam', 'classifier:mlp:tol': 0.0001, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.07441872802099897, 'feature_preprocessor:feature_agglomeration:affinity': 'manhattan', 'feature_preprocessor:feature_agglomeration:linkage': 'average', 'feature_preprocessor:feature_agglomeration:n_clusters': 264, 'feature_preprocessor:feature_agglomeration:pooling_func': 'max', 'classifier:mlp:validation_fraction': 0.1},
dataset_properties={
'task': 1,
'sparse': False,
'multilabel': False,
'multiclass': False,
'target_type': 'classification',
'signed': False})),
(0.080000, SimpleClassificationPipeline({'balancing:strategy': 'none', 'classifier:__choice__': 'sgd', 'data_preprocessing:categorical_transformer:categorical_encoding:__choice__': 'one_hot_encoding', 'data_preprocessing:categorical_transformer:category_coalescence:__choice__': 'minority_coalescer', 'data_preprocessing:numerical_transformer:imputation:strategy': 'most_frequent', 'data_preprocessing:numerical_transformer:rescaling:__choice__': 'standardize', 'feature_preprocessor:__choice__': 'select_percentile_classification', 'classifier:sgd:alpha': 1.6992296128865824e-07, 'classifier:sgd:average': 'True', 'classifier:sgd:fit_intercept': 'True', 'classifier:sgd:learning_rate': 'optimal', 'classifier:sgd:loss': 'log', 'classifier:sgd:penalty': 'l1', 'classifier:sgd:tol': 1.535384699341134e-05, 'data_preprocessing:categorical_transformer:category_coalescence:minority_coalescer:minimum_fraction': 0.24471105740962484, 'feature_preprocessor:select_percentile_classification:percentile': 39.91903776071659, 'feature_preprocessor:select_percentile_classification:score_func': 'f_classif'},
dataset_properties={
'task': 1,
'sparse': False,
'multilabel': False,
'multiclass': False,
'target_type': 'classification',
'signed': False})),
]
这样就可以使用由auto sklearn所训练出来的模型进行部署, 然後提供推论的结果. 但以目前来说还不建议直接将auto ML训练出来的模型放在正式环境上做部署. 比较建议的方式是先使用auto ML工具产出具有基本准确度的模型与参数, 然後由人工进行参数的调整训练模型以再次提高模型的准确度, 这样会是比较好的方式.
>>: Angular 深入浅出三十天:表单与测试 Day19 - 与 Cypress 的初次见面(下)
我实在没想到我能坚持连续30天不间断发文,对我来说真的是一大挑战,因为我不是一个经得起坚持的人,这次...
读到Dispatcher有种越来越难的感觉QQ 这些太高深的东西对於小萌新来说真的好杀热情阿 估计今...
相信 Cookie 与 Session 很多人常常搞不清楚,今天就用超级简单的方式来让大家了解。 C...
前几天在做建立专案的时候,好像看到一个不是很懂的东西-Django。 今天不谈 Chatbot,我们...
我们今天要接续昨天的划分4个区域开始,我们今天先从北部开始吧! 一样先来建立资料夹 lib/scar...