# Source code for pysptools.skl.cv

#
#------------------------------------------------------------------------------
# Copyright (c) 2013-2017, Christian Therien
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#------------------------------------------------------------------------------
#
# cv.py - This file is part of the PySptools package.
#

from __future__ import print_function

import numpy as np
import sklearn.model_selection as ms


class HyperEstimatorCrossVal(object):
    """
    Do a cross validation on a hypercube or a concatenation of hypercubes.
    Use scikit-learn KFold and GridSearchCV.
    """

    def __init__(self, estimator, param_grid):
        """
        Create a new HyperEstimatorCrossVal.

        Parameters:
            estimator: `class name`
                One of HyperSVC, HyperRandomForestClassifier,
                HyperKNeighborsClassifier, HyperLogisticRegression,
                HyperGradientBoostingClassifier. The class itself is
                expected (not an instance): it is instantiated with no
                arguments each time a cross validation is run.

            param_grid: `dict`
                A dict of parameters to be cross validated.
                Ex. for HyperSVC:
                {'C': [10,20,30,50], 'gamma': [0.1,0.5,1.0,10.0]}.
        """
        self.estimator = estimator
        self.param_grid = param_grid
        # Number of folds used by the most recent run; reported by print().
        self.n_splits = 2
        # The GridSearchCV object; stays None until fit() or fit_cube() runs.
        self.gcv = None

    def fit_cube(self, M, mask, n_splits=2):
        """
        Do a cross validation on a hypercube.

        Parameters:
            M: `numpy array`
                A HSI cube (m x n x p).

            mask: `numpy array`
                A class map mask. It must hold one label per pixel of M.

            n_splits: `int [default 2]`
                Number of folds given to KFold.

        Raises:
            ValueError: if the mask does not hold exactly one label
                per pixel of the cube.
        """
        X = self._convert2D(M)
        Y = np.reshape(mask, mask.shape[0]*mask.shape[1])
        # Fail early with both shapes in the message instead of letting the
        # mismatch surface later as a generic sample-count error.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                'mask shape {} does not match the cube shape {}: '
                'got {} labels for {} pixels'.format(
                    mask.shape, M.shape, Y.shape[0], X.shape[0]))
        self._cross_val(X, Y, self.param_grid, n_splits=n_splits)

    def fit(self, X, y, n_splits=2):
        """
        Run the cross validation.

        Parameters:
            X: `numpy array`
                A vector (n_samples, n_features) where each element
                *n_features* is a spectrum.

            y: `numpy array`
                Target values (n_samples,). A zero value is the background.
                A value of one or more is a class value.

            n_splits: `int [default 2]`
                Number of folds given to KFold.
        """
        self._cross_val(X, y, self.param_grid, n_splits=n_splits)

    def _cross_val(self, X, Y, grid, n_splits=2):
        # Run a shuffled K-fold grid search and keep the search object
        # in self.gcv for get_best_params() and print().
        self.n_splits = n_splits
        kf = ms.KFold(n_splits=n_splits, shuffle=True)
        # refit=False: only the search results are wanted, so no final
        # estimator is trained on the whole data set.
        self.gcv = ms.GridSearchCV(self.estimator(), grid, cv=kf, refit=False)
        self.gcv.fit(X, Y)

    def _convert2D(self, M):
        # Flatten a (h, w, bands) cube into (h*w, bands): one row per pixel.
        h, w, numBands = M.shape
        return np.reshape(M, (w*h, numBands))

    def _fitted_search(self):
        # Return the GridSearchCV object, or explain why there is none yet.
        gcv = getattr(self, 'gcv', None)
        if gcv is None:
            raise AttributeError(
                'Cross validation has not been run yet; '
                'call fit() or fit_cube() first.')
        return gcv

    def get_best_params(self):
        """
        Returns: `dict`
            Dict of best match.

        Raises:
            AttributeError: if neither fit() nor fit_cube() has been run.
        """
        return self._fitted_search().best_params_

    def print(self, label='No title'):
        """
        Print a summary for the cross validation results.

        Parameters:
            label: `string`
                The test title.

        Raises:
            AttributeError: if neither fit() nor fit_cube() has been run.
        """
        gcv = self._fitted_search()
        params = gcv.cv_results_['params']
        scores = gcv.cv_results_['mean_test_score']
        stds = gcv.cv_results_['std_test_score']

        print('================================================================')
        print('Cross validation results for: {}'.format(label))
        print('Param grid:', self.param_grid)
        print('n splits:', self.n_splits)
        print('Shuffle: True')
        print('================================================================')
        print('Best score:', gcv.best_score_)
        print('Best params:', gcv.best_params_)
        print('================================================================')
        print('All scores')
        for p, sc, st in zip(params, scores, stds):
            print(p, ', score: '+str(sc), ', std: '+str(st))
        print('================================================================')
        print()