Classes
MLClassifier
Unified sklearn-compatible classifier for NMR metabolomics data.
Wraps four model families (Random Forest, SVM, XGBoost, Elastic Net) behind a single interface with built-in cross-validation, feature importance, and Plotly visualisations.
Attributes:
pipeline_ (Pipeline): Fitted sklearn Pipeline (scaler + model).
cv_results_ (dict): Cross-validated metrics stored after fit().
classes_ (np.ndarray): Unique class labels observed during fit().
Examples:
>>> import pandas as pd
>>> from metbit.ml import MLClassifier
>>> X = pd.DataFrame({"f1": [1, 2, 3, 4], "f2": [4, 3, 2, 1]})
>>> y = ["A", "A", "B", "B"]
>>> clf = MLClassifier(X, y, model="rf")
>>> clf.fit(cv=2)
>>> preds = clf.predict(X)
Methods
__init__(self, X: 'pd.DataFrame | np.ndarray', y: 'pd.Series | np.ndarray | list', model: str='rf', features_name: 'list | None'=None, scaling_method: str='pareto', random_state: int=42, **model_kwargs)
Initialise MLClassifier.
Args:
X: Feature matrix with shape (n_samples, n_features).
y: Target labels with length n_samples.
model: Model family to use. One of ``"rf"``, ``"svm"``, ``"xgb"``, or ``"elasticnet"``. Defaults to ``"rf"``.
features_name: Optional list of feature names. When *X* is a DataFrame, the column names are used automatically.
scaling_method: Scaling strategy. Currently StandardScaler is applied regardless of the value; the parameter is kept for API compatibility with the rest of metbit. Defaults to ``"pareto"``.
random_state: Random seed passed to the underlying estimator. Defaults to ``42``.
**model_kwargs: Additional keyword arguments forwarded verbatim to the underlying sklearn/XGBoost estimator.
Raises:
ValueError: If *model* is not one of the supported families.
Examples:
>>> clf = MLClassifier(X, y, model="svm", C=1.0)
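The ``scaling_method`` description distinguishes the default name, ``"pareto"``, from the standard scaling that is actually applied. As a rough illustration of how the two differ, the NumPy sketch below mean-centres each column and then divides by the standard deviation (standard scaling) or by its square root (Pareto scaling). The helper names ``standard_scale`` and ``pareto_scale`` are hypothetical and not part of metbit, and the population standard deviation (``ddof=0``) is an assumption chosen to match sklearn's StandardScaler.

```python
import numpy as np


def standard_scale(X):
    """Mean-centre each column and divide by its standard deviation."""
    X = np.asarray(X, dtype=float)
    return (X - X.mean(axis=0)) / X.std(axis=0)


def pareto_scale(X):
    """Mean-centre each column and divide by the square root of its standard deviation."""
    X = np.asarray(X, dtype=float)
    return (X - X.mean(axis=0)) / np.sqrt(X.std(axis=0))


# Two columns on very different scales (illustrative values only).
# Neither helper guards against zero-variance columns.
X_demo = np.array([[1.0, 100.0], [2.0, 300.0], [3.0, 500.0], [4.0, 700.0]])

X_std = standard_scale(X_demo)
X_par = pareto_scale(X_demo)

print(X_std.std(axis=0))  # every column ends up with unit variance
print(X_par.std(axis=0))  # columns keep sqrt(original std), so large peaks still weigh more
```

Under standard scaling every feature contributes equally, whereas Pareto scaling only partially shrinks high-variance features.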
fit(self, cv: int=5)
Fit the model on the full dataset and compute cross-validated metrics.
Cross-validation uses ``StratifiedKFold`` with *cv* folds. The following metrics are recorded in ``cv_results_``:
- ``accuracy_mean`` / ``accuracy_std``
- ``balanced_accuracy_mean`` / ``balanced_accuracy_std``
- ``roc_auc_mean`` / ``roc_auc_std``
Args:
cv: Number of cross-validation folds. Defaults to ``5``.
Returns: self, to allow method chaining.
Examples:
>>> clf = MLClassifier(X, y, model="rf").fit(cv=5)
>>> print(clf.cv_results_)
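To make the recorded metrics concrete, the sketch below reproduces the same kind of summary with plain scikit-learn: a scaler + model pipeline is scored with ``StratifiedKFold``, and the per-fold scores are reduced to mean and standard deviation under the key names listed above. This is an approximation under stated assumptions, not the metbit implementation: the synthetic data, the shuffled folds, and the choice of ``cross_validate`` are all illustrative.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic binary problem standing in for an NMR feature matrix.
X_demo, y_demo = make_classification(
    n_samples=120, n_features=10, n_informative=4, random_state=42
)

pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("model", RandomForestClassifier(n_estimators=50, random_state=42)),
])

# Stratified folds keep the class proportions roughly equal in every split.
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_validate(
    pipeline, X_demo, y_demo, cv=folds,
    scoring=["accuracy", "balanced_accuracy", "roc_auc"],
)

# Collapse per-fold scores into <metric>_mean / <metric>_std entries.
cv_results = {}
for metric in ("accuracy", "balanced_accuracy", "roc_auc"):
    fold_scores = scores[f"test_{metric}"]
    cv_results[f"{metric}_mean"] = float(np.mean(fold_scores))
    cv_results[f"{metric}_std"] = float(np.std(fold_scores))

# The final model is then fitted on the full dataset.
pipeline.fit(X_demo, y_demo)
print(cv_results)
```

Because the scaler sits inside the pipeline, it is refitted on the training portion of each fold, so no information from the held-out samples leaks into the scaling step.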
predict(self, X_new: 'pd.DataFrame | np.ndarray')
Predict class labels for *X_new*.
Args:
X_new: Feature matrix with shape (n_samples, n_features).
Returns: Predicted class labels as a 1-D array.
Examples:
>>> labels = clf.predict(X_test)
predict_proba(self, X_new: 'pd.DataFrame | np.ndarray')
Predict class-membership probabilities for *X_new*.
Args:
X_new: Feature matrix with shape (n_samples, n_features).
Returns: Probability matrix with shape (n_samples, n_classes).
Examples:
>>> proba = clf.predict_proba(X_test)
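A small sketch of how such a probability matrix is usually read: each row sums to 1, and the column with the highest probability gives the predicted label. The example assumes the columns follow the order of ``classes_``, which is the scikit-learn convention and is presumed (not confirmed by this reference) to carry over to MLClassifier. The arrays are hand-written stand-ins, not model output.

```python
import numpy as np

classes = np.array(["A", "B", "C"])  # stand-in for clf.classes_
proba = np.array([                   # stand-in for clf.predict_proba(X_test)
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
])

# Each row is a probability distribution over the classes.
row_sums = proba.sum(axis=1)

# The most probable column index maps back to a class label.
labels = classes[proba.argmax(axis=1)]
print(labels)  # ['A' 'C' 'B']
```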
get_feature_importance(self, top_n: int=30)
Return a DataFrame of the top-N most important features.
The importance source depends on the model family:
- **RF / XGB**: ``feature_importances_`` from the fitted estimator.
- **ElasticNet**: ``coef_`` (mean absolute value across classes for multi-class problems).
- **SVM (linear kernel)**: ``coef_``.
- **SVM (rbf / other kernel)**: permutation importance on training data (slower but kernel-agnostic).
Args:
top_n: Maximum number of features to return. Defaults to ``30``.
Returns: DataFrame with columns ``feature`` and ``importance``, sorted by ``importance`` descending.
Examples:
>>> df = clf.get_feature_importance(top_n=20)
>>> print(df.head())
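The permutation-importance branch is the least obvious of the four sources, so here is a minimal sketch of it using scikit-learn directly. It fits an RBF-kernel SVM, shuffles each feature in turn on the training data, and ranks features by the resulting drop in accuracy. The synthetic data, the ``bin_i`` feature names, and ``n_repeats=5`` are assumptions for illustration; the settings metbit uses internally are not stated in this reference.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.inspection import permutation_importance
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Synthetic data with hypothetical feature names.
X_arr, y_demo = make_classification(
    n_samples=100, n_features=8, n_informative=3, random_state=0
)
feature_names = [f"bin_{i}" for i in range(X_arr.shape[1])]
X_demo = pd.DataFrame(X_arr, columns=feature_names)

# An RBF SVM exposes no coef_, so importance must be estimated indirectly.
svm_pipeline = Pipeline([("scaler", StandardScaler()), ("model", SVC(kernel="rbf"))])
svm_pipeline.fit(X_demo, y_demo)

# Shuffle each column several times and measure the mean drop in accuracy.
result = permutation_importance(
    svm_pipeline, X_demo, y_demo, n_repeats=5, random_state=0
)

top_n = 5
importance_df = (
    pd.DataFrame({"feature": feature_names, "importance": result.importances_mean})
    .sort_values("importance", ascending=False)
    .head(top_n)
    .reset_index(drop=True)
)
print(importance_df)
```

Since the scores are computed on the data the model was trained on, they describe what the fitted model relies on rather than what generalises to new samples.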
plot_feature_importance(self, top_n: int=30, fig_height: int=700, fig_width: int=900, font_size: int=14)
Plot a horizontal bar chart of the top feature importances.
Args:
top_n: Number of top features to display. Defaults to ``30``.
fig_height: Figure height in pixels. Defaults to ``700``.
fig_width: Figure width in pixels. Defaults to ``900``.
font_size: Base font size for axis labels and tick text. Defaults to ``14``.
Returns: Plotly Figure object.
Examples:
>>> fig = clf.plot_feature_importance(top_n=20)
>>> fig.show()
plot_confusion_matrix(self, normalize: bool=True, fig_height: int=600, fig_width: int=700, font_size: int=14)
Plot a heatmap confusion matrix using training-set predictions.
Args:
normalize: Whether to normalize each row to sum to 1.0. Defaults to ``True``.
fig_height: Figure height in pixels. Defaults to ``600``.
fig_width: Figure width in pixels. Defaults to ``700``.
font_size: Base font size. Defaults to ``14``.
Returns: Plotly Figure object.
Examples:
>>> fig = clf.plot_confusion_matrix(normalize=True)
>>> fig.show()
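Row normalization, as described for ``normalize``, divides each row of the count matrix by the number of true samples in that class, so the diagonal reads as per-class recall. The NumPy sketch below shows the arithmetic on a made-up 2x2 count matrix; it assumes every class has at least one sample, since an empty row would cause a division by zero.

```python
import numpy as np

# Rows are true classes, columns are predicted classes (illustrative counts).
cm_counts = np.array([
    [8, 2],
    [3, 7],
])

# Divide each row by its total so that every row sums to 1.0.
row_totals = cm_counts.sum(axis=1, keepdims=True)
cm_normalized = cm_counts / row_totals
print(cm_normalized)
```

Keep in mind that the plot is built from training-set predictions, so the normalized diagonal will generally look better than performance on unseen samples.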
plot_roc(self, fig_height: int=600, fig_width: int=800, font_size: int=14)
Plot one-vs-rest ROC curves using cross-validated probability estimates.
Each class gets its own curve with AUC displayed in the legend. A micro-average ROC curve is also included.
Args:
fig_height: Figure height in pixels. Defaults to ``600``.
fig_width: Figure width in pixels. Defaults to ``800``.
font_size: Base font size. Defaults to ``14``.
Returns: Plotly Figure object.
Examples:
>>> fig = clf.plot_roc()
>>> fig.show()
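The numbers behind such a plot can be sketched with scikit-learn alone. The example below obtains out-of-fold probabilities with ``cross_val_predict``, binarizes the labels one-vs-rest, computes one ROC curve and AUC per class, and then pools all (label, score) pairs for the micro-average. The three-class synthetic data and the logistic-regression model are assumptions for illustration and are not taken from metbit.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.preprocessing import label_binarize

X_demo, y_demo = make_classification(
    n_samples=150, n_features=10, n_informative=5, n_classes=3, random_state=1
)
classes = np.unique(y_demo)

# Out-of-fold probabilities: each sample is scored by a model that never saw it.
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
proba = cross_val_predict(
    LogisticRegression(max_iter=1000), X_demo, y_demo,
    cv=folds, method="predict_proba",
)

# One indicator column per class; with only two classes label_binarize
# returns a single column, so this layout applies to 3+ classes.
y_onehot = label_binarize(y_demo, classes=classes)

roc_auc = {}
for idx, label in enumerate(classes):
    fpr, tpr, _ = roc_curve(y_onehot[:, idx], proba[:, idx])
    roc_auc[int(label)] = auc(fpr, tpr)

# Micro-average: treat every (sample, class) pair as one binary decision.
fpr_micro, tpr_micro, _ = roc_curve(y_onehot.ravel(), proba.ravel())
roc_auc["micro"] = auc(fpr_micro, tpr_micro)
print(roc_auc)
```

The ``fpr`` and ``tpr`` arrays from each call are what a plotting layer would draw as one curve, with the matching AUC shown in the legend.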
get_cv_results(self)
Return the stored cross-validation metrics.
Returns: Dictionary with keys ``accuracy_mean``, ``accuracy_std``, ``balanced_accuracy_mean``, ``balanced_accuracy_std``, ``roc_auc_mean``, and ``roc_auc_std``.
Raises:
RuntimeError: If ``fit()`` has not been called yet.
Examples:
>>> clf.fit()
>>> print(clf.get_cv_results())
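The fit-before-read contract can be illustrated with a tiny stand-in class. ``FitGuardDemo`` below is hypothetical and is not the metbit implementation; it only shows the general pattern of raising ``RuntimeError`` until ``fit()`` has populated the results, and the metric values are placeholders rather than real output.

```python
class FitGuardDemo:
    """Hypothetical stand-in that mimics the fit-before-read contract."""

    def __init__(self):
        self.cv_results_ = None  # nothing stored until fit() runs

    def fit(self):
        # Placeholder numbers, not real cross-validation output.
        self.cv_results_ = {"accuracy_mean": 0.9, "accuracy_std": 0.05}
        return self

    def get_cv_results(self):
        if self.cv_results_ is None:
            raise RuntimeError("Call fit() before get_cv_results().")
        return self.cv_results_


demo = FitGuardDemo()

# Reading results too early raises RuntimeError.
try:
    demo.get_cv_results()
    message = None
except RuntimeError as err:
    message = str(err)

# After fit(), the stored dictionary is returned.
results = demo.fit().get_cv_results()
print(message)
print(results)
```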