class sklearn.cluster.MeanShift(bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=None)
Mean shift clustering using a flat kernel.
Mean shift clustering aims to discover “blobs” in a smooth density of samples. It is a centroid-based algorithm, which works by updating candidates for centroids to be the mean of the points within a given region. These candidates are then filtered in a post-processing stage to eliminate near-duplicates to form the final set of centroids.
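The description above can be made concrete with a small from-scratch sketch. It is plain NumPy under simplifying assumptions: every sample is a seed, convergence uses a fixed relative tolerance, and a candidate is dropped as a near-duplicate when it lies within one bandwidth of an already kept centroid. The function name `flat_kernel_mean_shift` and its `tol` parameter are hypothetical, and this is not scikit-learn's implementation.

```python
import numpy as np


def flat_kernel_mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Hypothetical flat-kernel mean shift sketch; returns (centers, labels)."""
    X = np.asarray(X, dtype=float)
    candidates = []  # one (intensity, center) pair per seed
    for seed in X:
        center = seed.copy()
        for _ in range(max_iter):
            # Flat kernel: every sample within `bandwidth` gets equal weight.
            in_window = np.linalg.norm(X - center, axis=1) <= bandwidth
            new_center = X[in_window].mean(axis=0)
            shift = np.linalg.norm(new_center - center)
            center = new_center
            if shift < tol * bandwidth:
                break
        # Intensity: number of samples inside the final window.
        intensity = int((np.linalg.norm(X - center, axis=1) <= bandwidth).sum())
        candidates.append((intensity, center))

    # Post-processing: visit candidates from densest to sparsest and keep one
    # only if it is farther than `bandwidth` from every centroid kept so far.
    centers = []
    for _, c in sorted(candidates, key=lambda t: -t[0]):
        if all(np.linalg.norm(c - kept) > bandwidth for kept in centers):
            centers.append(c)
    centers = np.array(centers)

    # Assign each sample to its nearest surviving centroid.
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    return centers, dists.argmin(axis=1)


X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
centers, labels = flat_kernel_mean_shift(X, bandwidth=2)
print(centers)  # two centroids, near (1.33, 0.67) and (3.33, 6.0)
print(labels)   # [0 0 0 1 1 1]
```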
Seeding can be performed using a binning technique for scalability (enabled with bin_seeding=True).
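As a rough illustration on a six-point toy dataset, the snippet below fits the estimator once with every sample as a seed (the default) and once with `bin_seeding=True`. scikit-learn uses the bandwidth as the bin size, so samples are snapped to a grid of that cell size and one seed is placed per occupied cell. On data this small both settings should recover the same two groups; the benefit of binning only shows up on larger inputs.

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])

# Default: every sample is used as a seed.
full = MeanShift(bandwidth=2).fit(X)
# Binned seeding: one seed per grid cell (cell size = bandwidth) that holds
# at least `min_bin_freq` samples.
binned = MeanShift(bandwidth=2, bin_seeding=True).fit(X)

print(len(full.cluster_centers_), len(binned.cluster_centers_))
print(full.labels_, binned.labels_)
```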
Read more in the User Guide.
Scalability:
Because this implementation uses a flat kernel and a Ball Tree to look up members of each kernel, the complexity will tend towards O(T*n*log(n)) in lower dimensions, with n the number of samples and T the number of points. In higher dimensions the complexity will tend towards O(T*n^2).
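Each kernel lookup is a radius query against a tree built once over the samples. The snippet below reproduces one such query and one mean-shift update with the public `NearestNeighbors` class using `algorithm="ball_tree"`; it illustrates the lookup and is not the estimator's internal code. The starting point (the first sample) is an arbitrary choice.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]], dtype=float)
bandwidth = 2.0

# Build the Ball Tree once; every kernel lookup is then a radius query.
nbrs = NearestNeighbors(radius=bandwidth, algorithm="ball_tree").fit(X)

center = X[0]
# radius_neighbors returns one array of neighbor indices per query point.
members = nbrs.radius_neighbors([center], return_distance=False)[0]
new_center = X[members].mean(axis=0)  # one mean-shift update of the candidate

print(sorted(members.tolist()))  # [0, 1, 2]
print(new_center)
```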
Scalability can be boosted by using fewer seeds, for example by using a higher value of min_bin_freq in the get_bin_seeds function.
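The effect of `min_bin_freq` on the number of seeds can be checked directly with `get_bin_seeds`. The toy data and `bin_size=1.0` are illustrative assumptions: the six samples fall into three bins holding 3, 2 and 1 samples, so raising `min_bin_freq` discards the sparsest bins first. `MeanShift` applies its own `min_bin_freq` at this step when `bin_seeding=True`.

```python
import numpy as np
from sklearn.cluster import get_bin_seeds

X = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.0],  # 3 samples in bin (0, 0)
              [5.0, 5.0], [5.1, 5.0],              # 2 samples in bin (5, 5)
              [9.0, 9.0]])                         # 1 sample in bin (9, 9)

# Number of seeds kept for increasing minimum bin populations.
counts = {m: len(get_bin_seeds(X, bin_size=1.0, min_bin_freq=m)) for m in (1, 2, 3)}
print(counts)  # {1: 3, 2: 2, 3: 1}
```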
Note that the estimate_bandwidth function is much less scalable than the mean shift algorithm and will be the bottleneck if it is used.
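One way to keep the bandwidth estimate affordable is to run it on a random subset through the `n_samples` argument of `estimate_bandwidth`. The blob centers, subset size and `quantile` below are arbitrary choices for illustration, not recommended defaults.

```python
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=1500, centers=[[0, 0], [10, 10], [-10, 10]],
                  cluster_std=1.0, random_state=0)

# Estimate on a 500-sample random subset to cap the cost of the neighbor search.
bandwidth = estimate_bandwidth(X, quantile=0.3, n_samples=500, random_state=0)

ms = MeanShift(bandwidth=bandwidth, bin_seeding=True).fit(X)
print(bandwidth, len(ms.cluster_centers_))
```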
>>> from sklearn.cluster import MeanShift
>>> import numpy as np
>>> X = np.array([[1, 1], [2, 1], [1, 0],
...               [4, 7], [3, 5], [3, 6]])
>>> clustering = MeanShift(bandwidth=2).fit(X)
>>> clustering.labels_
array([1, 1, 1, 0, 0, 0])
>>> clustering.predict([[0, 0], [5, 5]])
array([1, 0])
>>> clustering
MeanShift(bandwidth=2, bin_seeding=False, cluster_all=True, min_bin_freq=1,
     n_jobs=None, seeds=None)
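The `seeds` and `cluster_all` parameters from the signature interact as sketched below. With explicit seeds, a sample that ends up farther than `bandwidth` from every centroid is an orphan: `cluster_all=False` labels it `-1`, while the default `cluster_all=True` assigns it to the nearest centroid. The outlier at (20, 20) and the seed positions are made up for this illustration.

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6],
              [20, 20]])  # last sample is far from both groups

# Illustrative explicit seeds: one inside each group, none near (20, 20).
seeds = np.array([[1.0, 1.0], [3.0, 6.0]])

strict = MeanShift(bandwidth=2, seeds=seeds, cluster_all=False).fit(X)
print(strict.labels_)   # the far sample is an orphan and gets label -1

lenient = MeanShift(bandwidth=2, seeds=seeds, cluster_all=True).fit(X)
print(lenient.labels_)  # the far sample joins its nearest centroid
```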
| Method | Description |
|---|---|
| fit(X[, y]) | Perform clustering. |
| fit_predict(X[, y]) | Performs clustering on X and returns cluster labels. |
| get_params([deep]) | Get parameters for this estimator. |
| predict(X) | Predict the closest cluster each sample in X belongs to. |
| set_params(**params) | Set the parameters of this estimator. |
__init__(bandwidth=None, seeds=None, bin_seeding=False, min_bin_freq=1, cluster_all=True, n_jobs=None)
fit(X, y=None)
Perform clustering.
fit_predict(X, y=None)
Performs clustering on X and returns cluster labels.
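Since the algorithm is deterministic for a fixed bandwidth and seeding, `fit_predict(X)` should agree with `fit(X).labels_`, as this small check on toy data suggests:

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])

labels_fp = MeanShift(bandwidth=2).fit_predict(X)
labels_fit = MeanShift(bandwidth=2).fit(X).labels_
print(np.array_equal(labels_fp, labels_fit))  # True
```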
get_params(deep=True)
Get parameters for this estimator.
predict(X)
Predict the closest cluster each sample in X belongs to.
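Prediction amounts to a nearest-centroid lookup against the fitted `cluster_centers_` attribute. The sketch below recomputes it by hand with Euclidean distances (an assumption about the metric, consistent with how the flat kernel measures distance) and compares the result with `predict`; the query points are arbitrary.

```python
import numpy as np
from sklearn.cluster import MeanShift

X = np.array([[1, 1], [2, 1], [1, 0], [4, 7], [3, 5], [3, 6]])
ms = MeanShift(bandwidth=2).fit(X)

new = np.array([[0, 0], [5, 5], [2, 3]])
pred = ms.predict(new)

# Distance from each new sample to each centroid, then pick the closest.
dists = np.linalg.norm(new[:, None, :] - ms.cluster_centers_[None, :, :], axis=2)
manual = dists.argmin(axis=1)
print(pred, manual)
```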
set_params(**params)
Set the parameters of this estimator.
The method works on simple estimators as well as on nested objects (such as pipelines). The latter have parameters of the form <component>__<parameter> so that it’s possible to update each component of a nested object.
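For a nested object, the `<component>__<parameter>` convention looks like this. The step names `scale` and `ms` and the choice of `StandardScaler` are illustrative; the last two lines also show that `get_params(deep=False)` omits the nested keys.

```python
from sklearn.cluster import MeanShift
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([("scale", StandardScaler()), ("ms", MeanShift(bandwidth=2))])

# Nested parameter names follow <component>__<parameter>.
pipe.set_params(ms__bandwidth=1.5, ms__cluster_all=False)

print(pipe.named_steps["ms"].bandwidth)                 # 1.5
print("ms__bandwidth" in pipe.get_params())             # True (deep=True default)
print("ms__bandwidth" in pipe.get_params(deep=False))   # False
```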
© 2007–2018 The scikit-learn developers
Licensed under the 3-clause BSD License.
http://scikit-learn.org/stable/modules/generated/sklearn.cluster.MeanShift.html