In this tutorial, we will introduce how to use sklearn's grid search (GridSearchCV) to find the best model in machine learning.
1. Import library
import numpy as np
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline

# Set random seed
np.random.seed(0)
2. Load Iris dataset
# Load data
iris = datasets.load_iris()
X = iris.data
y = iris.target
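If you want to confirm the data loaded correctly, a quick sanity check is to print the shapes (the Iris dataset has 150 samples with 4 features each):

# Inspect the feature matrix and target vector
print(X.shape)  # (150, 4)
print(y.shape)  # (150,)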
3. Create pipeline
# Create a pipeline with a single 'classifier' step;
# the RandomForestClassifier here is only a placeholder that the search will replace
pipe = Pipeline([('classifier', RandomForestClassifier())])
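The step name 'classifier' matters: it is how the search space will address both the estimator itself and its hyperparameters. If you are unsure which parameter names are available, you can list them with get_params(), as in this small sketch:

# List the parameter names that GridSearchCV can tune,
# e.g. 'classifier', 'classifier__n_estimators', ...
for name in sorted(pipe.get_params().keys()):
    print(name)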
4. Create model selection search space
We will search over this space to find the best model. Hyperparameters of a pipeline step are referenced by the step name followed by a double underscore, for example classifier__C.
# Create space of candidate learning algorithms and their hyperparameters
# Note: the liblinear solver supports both the 'l1' and 'l2' penalties
search_space = [{'classifier': [LogisticRegression(solver='liblinear')],
                 'classifier__penalty': ['l1', 'l2'],
                 'classifier__C': np.logspace(0, 4, 10)},
                {'classifier': [RandomForestClassifier()],
                 'classifier__n_estimators': [10, 100, 1000],
                 'classifier__max_features': [1, 2, 3]}]
5. Create grid search
# Create grid search
clf = GridSearchCV(pipe, search_space, cv=5, verbose=0)
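GridSearchCV accepts a few other useful arguments. For example, this optional sketch runs the candidate fits in parallel and makes the scoring metric explicit (the values below are illustrative choices, not required by this tutorial):

# Evaluate candidates by accuracy and use all CPU cores
clf = GridSearchCV(pipe, search_space, cv=5, verbose=0,
                   scoring='accuracy', n_jobs=-1)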
6. Use grid search to fit model
# Fit grid search
# fit() returns the fitted GridSearchCV object itself, so best_model is the same object as clf
best_model = clf.fit(X, y)
7. View the best model
# View best model
best_model.best_estimator_.get_params()['classifier']
You may get the following best model:
LogisticRegression(C=7.7426368268112693, class_weight=None, dual=False, fit_intercept=True, intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1, penalty='l1', random_state=None, solver='liblinear', tol=0.0001, verbose=0, warm_start=False)
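Besides the estimator itself, the fitted grid search also stores the best cross-validated score and the winning hyperparameter combination. A short sketch (the printed values will depend on your sklearn version):

# Best mean cross-validated accuracy and the hyperparameters that achieved it
print(best_model.best_score_)
print(best_model.best_params_)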
8. Use best model to predict
# Predict target vector
best_model.predict(X)
You may get the result:
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])
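Here we predict on the same data we trained on, which is only a sanity check. You can also pass new observations to the fitted grid search; the feature values below are made-up measurements used purely for illustration:

# Predict the class of a new, unseen observation
# (sepal length, sepal width, petal length, petal width)
new_observation = [[5.0, 3.4, 1.5, 0.2]]
print(best_model.predict(new_observation))  # predicted class label for this sample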