
BaseClassWrapper

sklearn_wrap.base.BaseClassWrapper

Bases: BaseEstimator

Base class for wrapping classes into scikit-learn estimators.

Inheriting from this class provides default implementations of:

  • setting and getting parameters used by GridSearchCV and friends;
  • textual and HTML representation displayed in terminals and IDEs;
  • estimator serialization;
  • parameter validation;
  • data validation;
  • metadata routing.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **params | | The keyword argument matching _estimator_name provides the class to wrap (optional when _estimator_default_class is set). Remaining keyword arguments are passed as constructor parameters to the wrapped class. | {} |

Examples

>>> import numpy as np
>>> from sklearn_wrap.base import BaseClassWrapper, _fit_context
>>>
>>> # Define a simple estimator class to wrap
>>> class SimpleRegressor:
...     def __init__(self, multiplier=1.0, offset=0.0):
...         self.multiplier = multiplier
...         self.offset = offset
...
...     def fit(self, X, y):
...         return self
...
...     def predict(self, X):
...         return np.full(X.shape[0], self.multiplier) + self.offset
>>>
>>> # Wrap it with BaseClassWrapper and use _fit_context
>>> class MyEstimator(BaseClassWrapper):
...     _estimator_name = "regressor"
...     _estimator_base_class = object
...
...     @_fit_context(prefer_skip_nested_validation=True)
...     def fit(self, X, y=None):
...         # instantiate() is called automatically by decorator
...         self.instance_.fit(X, y)
...         return self
...
...     def predict(self, X):
...         return self.instance_.predict(X)
>>>
>>> # Use it like any sklearn estimator with parameter management
>>> estimator = MyEstimator(regressor=SimpleRegressor, multiplier=2.0, offset=1.0)
>>> params = estimator.get_params()
>>> params["multiplier"]
2.0
>>> params["offset"]
1.0
>>> params["regressor"]
<class '...SimpleRegressor'>
>>> X = np.array([[1, 2], [2, 3], [3, 4]])
>>> y = np.array([1, 0, 1])
>>> estimator.fit(X, y).predict(X)
array([3., 3., 3.])
>>> # Parameters can be updated via set_params
>>> estimator.set_params(multiplier=3.0, offset=0.5).fit(X, y).predict(X)
array([3.5, 3.5, 3.5])

See Also

_fit_context : Decorator for automatic instantiation during fit.

References

  1. Scikit-learn developer guide: conventions for estimator compatibility.
  2. Scikit-learn BaseEstimator: the parent class providing get_params / set_params.

Source Code

class BaseClassWrapper(BaseEstimator, metaclass=abc.ABCMeta):
    """Base class for wrapping classes into scikit-learn estimators.

    Inheriting from this class provides default implementations of:

    - setting and getting parameters used by `GridSearchCV` and friends;
    - textual and HTML representation displayed in terminals and IDEs;
    - estimator serialization;
    - parameter validation;
    - data validation;
    - metadata routing.


    Parameters
    ----------
    **params
        The keyword argument matching ``_estimator_name`` provides the class
        to wrap (optional when ``_estimator_default_class`` is set).
        Remaining keyword arguments are passed as constructor parameters to
        the wrapped class.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn_wrap.base import BaseClassWrapper, _fit_context
    >>>
    >>> # Define a simple estimator class to wrap
    >>> class SimpleRegressor:
    ...     def __init__(self, multiplier=1.0, offset=0.0):
    ...         self.multiplier = multiplier
    ...         self.offset = offset
    ...
    ...     def fit(self, X, y):
    ...         return self
    ...
    ...     def predict(self, X):
    ...         return np.full(X.shape[0], self.multiplier) + self.offset
    >>>
    >>> # Wrap it with BaseClassWrapper and use _fit_context
    >>> class MyEstimator(BaseClassWrapper):
    ...     _estimator_name = "regressor"
    ...     _estimator_base_class = object
    ...
    ...     @_fit_context(prefer_skip_nested_validation=True)
    ...     def fit(self, X, y=None):
    ...         # instantiate() is called automatically by decorator
    ...         self.instance_.fit(X, y)
    ...         return self
    ...
    ...     def predict(self, X):
    ...         return self.instance_.predict(X)
    >>>
    >>> # Use it like any sklearn estimator with parameter management
    >>> estimator = MyEstimator(regressor=SimpleRegressor, multiplier=2.0, offset=1.0)
    >>> params = estimator.get_params()
    >>> params["multiplier"]
    2.0
    >>> params["offset"]
    1.0
    >>> params["regressor"]  # doctest: +ELLIPSIS
    <class '...SimpleRegressor'>
    >>> X = np.array([[1, 2], [2, 3], [3, 4]])
    >>> y = np.array([1, 0, 1])
    >>> estimator.fit(X, y).predict(X)
    array([3., 3., 3.])
    >>> # Parameters can be updated via set_params
    >>> estimator.set_params(multiplier=3.0, offset=0.5).fit(X, y).predict(X)
    array([3.5, 3.5, 3.5])

    See Also
    --------
    _fit_context : Decorator for automatic instantiation during fit.

    References
    ----------
    1. [Scikit-learn developer guide](https://scikit-learn.org/stable/developers/develop.html):
            conventions for estimator compatibility.
    2. [Scikit-learn BaseEstimator](https://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html):
            the parent class providing `get_params` / `set_params`.
    """

    _required_parameters: list[str] = []
    _estimator_name: str | None = None
    _estimator_base_class = None
    _estimator_default_class: type | None = None
    _parameter_constraints: dict[str, list] = {}  # For validating parameter types

    def __init_subclass__(cls, **kwargs):
        """Set ``_required_parameters`` automatically for subclasses.

        When a subclass defines ``_estimator_name`` and optionally
        ``_estimator_default_class``, this hook populates
        ``_required_parameters`` so that scikit-learn utilities (e.g.
        ``clone``) know which constructor arguments are mandatory.

        See Also
        --------
        BaseClassWrapper.__init__ : Constructor that consumes the required parameter.
        """
        super().__init_subclass__(**kwargs)
        name = getattr(cls, "_estimator_name", None)
        if isinstance(name, str):
            has_default = getattr(cls, "_estimator_default_class", None) is not None
            cls._required_parameters = [] if has_default else [name]

    def __init__(self, **params):
        name = self._estimator_name
        if not isinstance(name, str):
            raise ValueError("Class should define a static `_estimator_name`.")

        if name not in params:
            default_cls = self._estimator_default_class
            if default_cls is not None:
                params[name] = default_cls
            else:
                raise TypeError(f"{self.__class__.__name__}.__init__() missing required keyword argument: '{name}'")
        estimator_class = params.pop(name)

        self.estimator_class = self._validate_estimator_class(estimator_class)
        self.params = self._validate_estimator_params(params)

        # Validate parameter constraints (including nested wrappers)
        for param_name, param_value in self.params.items():
            if param_value is not REQUIRED_PARAM_VALUE and param_value is not None:
                self._validate_nested_wrapper_param(param_name, param_value)

    @property
    def estimator_name(self) -> str:
        """Get the name of the wrapped estimator type.

        Returns
        -------
        str
            The estimator name.

        See Also
        --------
        BaseClassWrapper.estimator_base_class : The required base class for wrapped estimators.
        """
        if not isinstance(self._estimator_name, str):
            raise ValueError("Class should define a static `_estimator_name`.")

        return self._estimator_name

    @property
    def estimator_base_class(self) -> type:
        """Get the required base class for the wrapped estimator.

        Returns
        -------
        type
            The base class.

        See Also
        --------
        BaseClassWrapper.estimator_name : The name key for the wrapped estimator.
        """
        if self._estimator_base_class is None:
            raise ValueError("Class should define a static `_estimator_base_class`.")

        return self._estimator_base_class

    def _validate_estimator_class(self, estimator_class: type) -> type:
        """Validate the estimator class.

        Parameters
        ----------
        estimator_class : type
            The estimator class to validate.

        Returns
        -------
        type
            The validated estimator class.

        See Also
        --------
        BaseClassWrapper._validate_estimator_params : Validates parameter names and defaults.
        BaseClassWrapper._validate_params : Full validation combining class and params.
        """
        if not inspect.isclass(estimator_class):
            raise TypeError(
                f"{self._estimator_name} parameter for estimator "
                f"{self.__class__.__name__} is not a class. It is {estimator_class!r}."
            )

        if not issubclass(estimator_class, self.estimator_base_class):
            base_class = self.estimator_base_class
            base_class_name = f"{base_class.__module__}.{base_class.__qualname__}"
            raise ValueError(
                f"Invalid {self._estimator_name} class {estimator_class.__name__!r} for estimator "
                f"{self.__class__.__name__!r}. Valid estimator class should be derived from "
                f"{base_class_name}."
            )

        return estimator_class

    def _validate_estimator_params(self, params: dict, *, validate_nested: bool = True):
        """Validate estimator parameters.

        Check the estimator parameter names and set the omitted ones
        to their default value as per the ``estimator_class``
        constructor.

        Parameters
        ----------
        params : dict
            Dictionary of estimator parameters. Keys should be base parameter
            names (without ``"__"`` for nested params).
        validate_nested : bool, default=True
            If False, skip validation and only return the params as-is.
            Used internally when processing already-split nested parameters.

        Returns
        -------
        dict
            Validated dictionary of estimator parameters.

        See Also
        --------
        BaseClassWrapper._validate_estimator_class : Validates the estimator class itself.
        BaseClassWrapper._validate_nested_wrapper_param : Validates nested wrapper constraints.
        """
        if not validate_nested:
            return params

        # Wrapper-specific: reject double-underscore parameter names
        for param_name in params:
            if "__" in param_name:
                raise ValueError(
                    f"Parameter name {param_name!r} cannot contain '__' (double underscore). "
                    f"This delimiter is reserved for nested parameter syntax."
                )

        return validate_class_params(self.estimator_class, params)

    def _validate_nested_wrapper_param(self, param_name: str, param_value: Any) -> None:
        """Validate a parameter value that should be a BaseClassWrapper.

        Checks parameter constraints defined in _parameter_constraints to ensure
        wrapped estimators have the correct base class.

        Parameters
        ----------
        param_name : str
            Name of the parameter being validated.
        param_value : Any
            The parameter value to validate.

        Raises
        ------
        TypeError
            If the value is not a BaseClassWrapper when required.
        ValueError
            If the wrapped estimator_class doesn't inherit from expected base class.

        See Also
        --------
        BaseClassWrapper._validate_estimator_params : Validates parameter names and defaults.
        BaseClassWrapper._validate_params : Full validation combining class and params.
        """
        if param_name not in self._parameter_constraints:
            return

        constraints = self._parameter_constraints[param_name]
        for constraint in constraints:
            # Check if constraint specifies a required wrapper base class
            if isinstance(constraint, dict) and "wrapper_base_class" in constraint:
                required_base = constraint["wrapper_base_class"]

                # Value must be a BaseClassWrapper
                if not isinstance(param_value, BaseClassWrapper):
                    raise TypeError(
                        f"Parameter {param_name!r} must be a BaseClassWrapper instance, "
                        f"got {type(param_value).__name__!r}."
                    )

                # Check the wrapped estimator_class inheritance
                if not issubclass(param_value.estimator_class, required_base):
                    raise ValueError(
                        f"Parameter {param_name!r} must wrap an estimator class derived from "
                        f"{required_base.__module__}.{required_base.__qualname__}, "
                        f"but got {param_value.estimator_class.__name__} which derives from "
                        f"{param_value.estimator_class.__bases__}."
                    )

    def _validate_params(self):
        """Validate types and values of constructor parameters.

        The expected type and values must be defined in the ``_parameter_constraints``
        class attribute, which is a dictionary ``param_name: list of constraints``.

        See Also
        --------
        BaseClassWrapper._validate_estimator_class : Validates the estimator class.
        BaseClassWrapper._validate_estimator_params : Validates parameter names and defaults.
        BaseClassWrapper._validate_nested_wrapper_param : Validates nested wrapper constraints.
        """
        self._validate_estimator_class(self.estimator_class)
        self._validate_estimator_params(self.params)

        # Validate nested wrapper parameters according to constraints
        for param_name, param_value in self.params.items():
            if param_value is not REQUIRED_PARAM_VALUE and param_value is not None:
                self._validate_nested_wrapper_param(param_name, param_value)

    def __sklearn_is_fitted__(self) -> bool:
        """Check if the estimator has been fitted.

        This method is used by sklearn's check_is_fitted() and _is_fitted() to
        determine if an estimator has been fitted.

        Checks for fitted attributes (attributes ending with '_' excluding 'instance_'),
        which is the sklearn convention for fitted attributes. Also checks for the
        `_fitted` internal flag for backward compatibility.

        Returns
        -------
        bool
            True if the estimator has fitted attributes,
            False otherwise.

        See Also
        --------
        BaseClassWrapper.instantiate : Creates the wrapped instance (does not mark as fitted).
        _fit_context : Decorator that sets the fitted state after successful fit.
        """
        # Check internal _fitted flag first (for backward compatibility)
        if getattr(self, "_fitted", False):
            return True

        # Check for fitted attributes (excluding instance_)
        fitted_attrs = [v for v in vars(self) if v.endswith("_") and not v.startswith("__") and v != "instance_"]
        return len(fitted_attrs) > 0

    def instantiate(self) -> "BaseClassWrapper":
        """Validate parameters and create an instance.

        Returns
        -------
        self

        See Also
        --------
        _fit_context : Decorator that calls instantiate automatically during fit.
        BaseClassWrapper._validate_params : Parameter validation called by this method.
        """
        self._validate_params()

        for param_name, param_value in self.params.items():
            # Identity check: `==` can be ambiguous for array-valued parameters
            if param_value is REQUIRED_PARAM_VALUE:
                raise ValueError(f"Class {self.estimator_class.__name__!r} requires parameter {param_name!r}.")

        self.instance_ = self.estimator_class(**self.params)

        # Reset fitted flag when creating a new instance
        self._fitted = False

        return self

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : bool, default=True
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : dict
            Parameter names mapped to their values.

        Notes
        -----
        The estimator class is always returned under the ``_estimator_name`` key
        (e.g. ``"regressor"``, ``"classifier"``). This ensures that
        ``sklearn.base.clone()`` can reconstruct the wrapper correctly, since
        ``clone()`` passes the dict returned by ``get_params(deep=False)`` as
        keyword arguments to the constructor.

        See Also
        --------
        BaseClassWrapper.set_params : Set parameters on this estimator.
        """
        out = {}
        for key, value in self.params.items():
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                deep_items = value.get_params().items()
                # Exclude the estimator class parameter from nested params
                # to prevent roundtrip issues (estimator class can't be set via set_params)
                if isinstance(value, BaseClassWrapper):
                    estimator_name = value._estimator_name
                    deep_items = [(k, v) for k, v in deep_items if k != estimator_name]
                out.update((key + "__" + k, val) for k, val in deep_items)
            out[key] = value

        out[self._estimator_name] = self.estimator_class

        return out

    def set_params(self, **params: object) -> "BaseClassWrapper":
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as ``Pipeline``). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Parameters
        ----------
        **params : dict
            Estimator parameters.

        Returns
        -------
        self : estimator instance
            Estimator instance.

        See Also
        --------
        BaseClassWrapper.get_params : Get parameters for this estimator.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self

        # Check if trying to change estimator class
        if self._estimator_name in params:
            raise ValueError(
                f"Cannot change estimator class via set_params. "
                f"The '{self._estimator_name}' parameter cannot be set. Redeclare the "
                f"estimator class by creating a new instance of {self.__class__.__name__}."
            )

        # Step 1: Split parameters into simple and nested BEFORE validation.
        # Nested keys ("<component>__<parameter>") would otherwise be rejected by
        # the double-underscore check in _validate_estimator_params.
        simple_params = {}
        nested_params = defaultdict(dict)  # grouped by prefix
        has_nested = False

        for full_key, value in params.items():
            base_key, delim, sub_key = full_key.partition("__")
            if delim:  # Contains "__", so it's a nested parameter
                nested_params[base_key][sub_key] = value
                has_nested = True
            else:
                simple_params[base_key] = value

        # Step 2: Validate only the base/simple parameter names
        if simple_params:
            self._validate_estimator_params(simple_params)

        # Step 3: Validate base keys for nested params exist
        if has_nested:
            for base_key in nested_params:
                if base_key not in self.params:
                    raise ValueError(
                        f"Invalid parameter {base_key!r} for estimator {self}. "
                        f"Valid parameters are: {list(self.params.keys())!r}."
                    )

        # Step 4: Update simple parameters and validate type constraints
        for key, value in simple_params.items():
            # Validate nested wrapper parameters against _parameter_constraints
            if value is not None and value is not REQUIRED_PARAM_VALUE:
                self._validate_nested_wrapper_param(key, value)
            self.params[key] = value

        # Step 5: Recursively set nested parameters
        for base_key, sub_params in nested_params.items():
            nested_obj = self.params[base_key]
            if not hasattr(nested_obj, "set_params"):
                raise AttributeError(
                    f"Cannot set nested parameters on {base_key!r}. "
                    f"Object of type {type(nested_obj).__name__!r} does not have a set_params method."
                )
            nested_obj.set_params(**sub_params)

        return self

Methods

estimator_name property

Get the name of the wrapped estimator type.

Returns

| Type | Description |
| --- | --- |
| str | The estimator name. |

See Also

BaseClassWrapper.estimator_base_class : The required base class for wrapped estimators.

estimator_base_class property

Get the required base class for the wrapped estimator.

Returns

| Type | Description |
| --- | --- |
| type | The base class. |

See Also

BaseClassWrapper.estimator_name : The name key for the wrapped estimator.

__init_subclass__(**kwargs)

Set _required_parameters automatically for subclasses.

When a subclass defines _estimator_name and optionally _estimator_default_class, this hook populates _required_parameters so that scikit-learn utilities (e.g. clone) know which constructor arguments are mandatory.

See Also

BaseClassWrapper.__init__ : Constructor that consumes the required parameter.
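Examples

The hook's effect can be illustrated without the library. The following is a minimal sketch, assuming a stand-in base class (`WrapperBase`, a hypothetical name) that copies only the `__init_subclass__` logic shown in the source; it is not the real `BaseClassWrapper`.

```python
class WrapperBase:
    """Hypothetical stand-in for BaseClassWrapper; only the subclass hook is reproduced."""

    _required_parameters: list = []
    _estimator_name = None
    _estimator_default_class = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        name = getattr(cls, "_estimator_name", None)
        if isinstance(name, str):
            has_default = getattr(cls, "_estimator_default_class", None) is not None
            # The estimator key is mandatory only when no default class exists
            cls._required_parameters = [] if has_default else [name]


class NeedsClass(WrapperBase):
    _estimator_name = "regressor"


class HasDefault(WrapperBase):
    _estimator_name = "regressor"
    _estimator_default_class = dict  # any class works for this illustration


print(NeedsClass._required_parameters)   # ['regressor']
print(HasDefault._required_parameters)   # []
print(WrapperBase._required_parameters)  # [] (the hook does not run for the base itself)
```

Because the list is assigned on each subclass, the base class attribute is never mutated in place.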

Source Code
def __init_subclass__(cls, **kwargs):
    """Set ``_required_parameters`` automatically for subclasses.

    When a subclass defines ``_estimator_name`` and optionally
    ``_estimator_default_class``, this hook populates
    ``_required_parameters`` so that scikit-learn utilities (e.g.
    ``clone``) know which constructor arguments are mandatory.

    See Also
    --------
    BaseClassWrapper.__init__ : Constructor that consumes the required parameter.
    """
    super().__init_subclass__(**kwargs)
    name = getattr(cls, "_estimator_name", None)
    if isinstance(name, str):
        has_default = getattr(cls, "_estimator_default_class", None) is not None
        cls._required_parameters = [] if has_default else [name]

__sklearn_is_fitted__()

Check if the estimator has been fitted.

This method is used by sklearn's check_is_fitted() and _is_fitted() to determine if an estimator has been fitted.

Checks for fitted attributes (attributes ending with '_' excluding 'instance_'), which is the sklearn convention for fitted attributes. Also checks for the _fitted internal flag for backward compatibility.

Returns

| Type | Description |
| --- | --- |
| bool | True if the estimator has fitted attributes, False otherwise. |

See Also

BaseClassWrapper.instantiate : Creates the wrapped instance (does not mark as fitted).

_fit_context : Decorator that sets the fitted state after successful fit.
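Examples

A minimal sketch of the check in isolation, assuming a hypothetical stand-in class (`FittedCheckDemo`) that reproduces only the logic of this method. It shows that `instance_` alone does not count as fitted state, while any other trailing-underscore attribute or the `_fitted` flag does.

```python
class FittedCheckDemo:
    """Hypothetical stand-in that reproduces only the fitted-state check."""

    def __sklearn_is_fitted__(self) -> bool:
        # The internal flag short-circuits the attribute scan
        if getattr(self, "_fitted", False):
            return True
        # Otherwise look for trailing-underscore attributes other than instance_
        return any(
            name.endswith("_") and not name.startswith("__") and name != "instance_"
            for name in vars(self)
        )


demo = FittedCheckDemo()
state_fresh = demo.__sklearn_is_fitted__()         # no attributes yet

demo.instance_ = object()                          # what instantiate() would set
state_instantiated = demo.__sklearn_is_fitted__()  # instance_ is ignored

demo.coef_ = [1.0]                                 # a learned attribute
state_fitted = demo.__sklearn_is_fitted__()

flagged = FittedCheckDemo()
flagged._fitted = True                             # flag alone is sufficient
state_flagged = flagged.__sklearn_is_fitted__()

print(state_fresh, state_instantiated, state_fitted, state_flagged)
# False False True True
```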

Source Code
def __sklearn_is_fitted__(self) -> bool:
    """Check if the estimator has been fitted.

    This method is used by sklearn's check_is_fitted() and _is_fitted() to
    determine if an estimator has been fitted.

    Checks for fitted attributes (attributes ending with '_' excluding 'instance_'),
    which is the sklearn convention for fitted attributes. Also checks for the
    `_fitted` internal flag for backward compatibility.

    Returns
    -------
    bool
        True if the estimator has fitted attributes,
        False otherwise.

    See Also
    --------
    BaseClassWrapper.instantiate : Creates the wrapped instance (does not mark as fitted).
    _fit_context : Decorator that sets the fitted state after successful fit.
    """
    # Check internal _fitted flag first (for backward compatibility)
    if getattr(self, "_fitted", False):
        return True

    # Check for fitted attributes (excluding instance_)
    fitted_attrs = [v for v in vars(self) if v.endswith("_") and not v.startswith("__") and v != "instance_"]
    return len(fitted_attrs) > 0

instantiate()

Validate parameters and create an instance.

Returns

| Type | Description |
| --- | --- |
| self | The wrapper itself, with instance_ set to the new wrapped instance. |

See Also

_fit_context : Decorator that calls instantiate automatically during fit.

BaseClassWrapper._validate_params : Parameter validation called by this method.
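Examples

The required-parameter check relies on a sentinel (`REQUIRED_PARAM_VALUE`) whose definition is not shown on this page, and on `validate_class_params`, which is also not shown. The sketch below is an assumption about how such a mechanism could work, using `inspect.signature`; `REQUIRED`, `fill_defaults`, `instantiate`, and `Scaler` are hypothetical names, not part of the library.

```python
import inspect

REQUIRED = object()  # hypothetical sentinel standing in for REQUIRED_PARAM_VALUE


def fill_defaults(cls, params):
    """Hypothetical stand-in for validate_class_params: check names, fill defaults."""
    filled = {}
    for name, p in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or p.kind in (p.VAR_POSITIONAL, p.VAR_KEYWORD):
            continue
        if name in params:
            filled[name] = params[name]
        elif p.default is inspect.Parameter.empty:
            filled[name] = REQUIRED  # no default: must be supplied before use
        else:
            filled[name] = p.default
    unknown = set(params) - set(filled)
    if unknown:
        raise ValueError(f"Invalid parameters {sorted(unknown)!r} for {cls.__name__!r}.")
    return filled


def instantiate(cls, params):
    """Refuse to build the object while any parameter still holds the sentinel."""
    filled = fill_defaults(cls, params)
    for name, value in filled.items():
        if value is REQUIRED:
            raise ValueError(f"Class {cls.__name__!r} requires parameter {name!r}.")
    return cls(**filled)


class Scaler:
    def __init__(self, factor, offset=0.0):
        self.factor = factor
        self.offset = offset


scaler = instantiate(Scaler, {"factor": 2.0})
print(scaler.factor, scaler.offset)  # 2.0 0.0

try:
    instantiate(Scaler, {})
except ValueError as exc:
    print(exc)  # Class 'Scaler' requires parameter 'factor'.
```

Deferring the error to instantiation time lets a wrapper be constructed first and completed later through set_params.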

Source Code
def instantiate(self) -> "BaseClassWrapper":
    """Validate parameters and create an instance.

    Returns
    -------
    self

    See Also
    --------
    _fit_context : Decorator that calls instantiate automatically during fit.
    BaseClassWrapper._validate_params : Parameter validation called by this method.
    """
    self._validate_params()

    for param_name, param_value in self.params.items():
        # Identity check: `==` can be ambiguous for array-valued parameters
        if param_value is REQUIRED_PARAM_VALUE:
            raise ValueError(f"Class {self.estimator_class.__name__!r} requires parameter {param_name!r}.")

    self.instance_ = self.estimator_class(**self.params)

    # Reset fitted flag when creating a new instance
    self._fitted = False

    return self

get_params(deep=True)

Get parameters for this estimator.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| deep | bool | If True, will return the parameters for this estimator and contained subobjects that are estimators. | True |

Returns

| Name | Type | Description |
| --- | --- | --- |
| params | dict | Parameter names mapped to their values. |

Notes

The estimator class is always returned under the _estimator_name key (e.g. "regressor", "classifier"). This ensures that sklearn.base.clone() can reconstruct the wrapper correctly, since clone() passes the dict returned by get_params(deep=False) as keyword arguments to the constructor.

See Also

BaseClassWrapper.set_params : Set parameters on this estimator.
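Examples

The round trip described in the Notes can be sketched without scikit-learn. This is a simplified illustration, assuming a hypothetical `MiniWrapper` that keeps only the constructor and get_params behaviour relevant to reconstruction; the final line mimics the way clone() rebuilds an estimator from `get_params(deep=False)`.

```python
class MiniWrapper:
    """Hypothetical, stripped-down wrapper: stores a class plus its parameters."""

    _estimator_name = "regressor"

    def __init__(self, **params):
        # The class to wrap arrives under the _estimator_name key
        self.estimator_class = params.pop(self._estimator_name)
        self.params = params

    def get_params(self, deep=True):
        out = dict(self.params)
        # Always expose the wrapped class so the constructor call can be replayed
        out[self._estimator_name] = self.estimator_class
        return out


class Toy:
    def __init__(self, multiplier=1.0):
        self.multiplier = multiplier


original = MiniWrapper(regressor=Toy, multiplier=2.0)

# Rebuild from the parameter dict alone, as clone() does
rebuilt = type(original)(**original.get_params(deep=False))

print(rebuilt.estimator_class is Toy)  # True
print(rebuilt.params)                  # {'multiplier': 2.0}
```

If get_params omitted the "regressor" key, the replayed constructor call would fail because the class to wrap would be missing.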

Source Code
def get_params(self, deep: bool = True) -> dict[str, Any]:
    """Get parameters for this estimator.

    Parameters
    ----------
    deep : bool, default=True
        If True, will return the parameters for this estimator and
        contained subobjects that are estimators.

    Returns
    -------
    params : dict
        Parameter names mapped to their values.

    Notes
    -----
    The estimator class is always returned under the ``_estimator_name`` key
    (e.g. ``"regressor"``, ``"classifier"``). This ensures that
    ``sklearn.base.clone()`` can reconstruct the wrapper correctly, since
    ``clone()`` passes the dict returned by ``get_params(deep=False)`` as
    keyword arguments to the constructor.

    See Also
    --------
    BaseClassWrapper.set_params : Set parameters on this estimator.
    """
    out = {}
    for key, value in self.params.items():
        if deep and hasattr(value, "get_params") and not isinstance(value, type):
            deep_items = value.get_params().items()
            # Exclude the estimator class parameter from nested params
            # to prevent roundtrip issues (estimator class can't be set via set_params)
            if isinstance(value, BaseClassWrapper):
                estimator_name = value._estimator_name
                deep_items = [(k, v) for k, v in deep_items if k != estimator_name]
            out.update((key + "__" + k, val) for k, val in deep_items)
        out[key] = value

    out[self._estimator_name] = self.estimator_class

    return out

set_params(**params)

Set the parameters of this estimator.

The method works on simple estimators as well as on nested objects (such as Pipeline). The latter have parameters of the form <component>__<parameter> so that it's possible to update each component of a nested object.

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| **params | dict | Estimator parameters. | {} |

Returns

| Name | Type | Description |
| --- | --- | --- |
| self | estimator instance | Estimator instance. |

See Also

BaseClassWrapper.get_params : Get parameters for this estimator.
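Examples

The `<component>__<parameter>` convention is resolved by splitting each key on its first double underscore. A minimal sketch of that splitting step, assuming a free function (`split_params`, a hypothetical name) in place of the method body:

```python
from collections import defaultdict


def split_params(params):
    """Separate plain keys from nested '<component>__<parameter>' keys."""
    simple = {}
    nested = defaultdict(dict)  # grouped by component name
    for full_key, value in params.items():
        # partition splits on the FIRST "__" only; deeper levels stay in sub_key
        base_key, delim, sub_key = full_key.partition("__")
        if delim:
            nested[base_key][sub_key] = value
        else:
            simple[base_key] = value
    return simple, dict(nested)


simple, nested = split_params(
    {"offset": 0.5, "base__alpha": 0.1, "base__inner__beta": 2}
)
print(simple)  # {'offset': 0.5}
print(nested)  # {'base': {'alpha': 0.1, 'inner__beta': 2}}
```

Each group in `nested` is then forwarded to the corresponding component's own set_params, which repeats the split for any remaining `__` levels.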

Source Code
def set_params(self, **params: object) -> "BaseClassWrapper":
    """Set the parameters of this estimator.

    The method works on simple estimators as well as on nested objects
    (such as ``Pipeline``). The latter have parameters of the form
    ``<component>__<parameter>`` so that it's possible to update each
    component of a nested object.

    Parameters
    ----------
    **params : dict
        Estimator parameters.

    Returns
    -------
    self : estimator instance
        Estimator instance.

    See Also
    --------
    BaseClassWrapper.get_params : Get parameters for this estimator.
    """
    if not params:
        # Simple optimization to gain speed (inspect is slow)
        return self

    # Check if trying to change estimator class
    if self._estimator_name in params:
        raise ValueError(
            f"Cannot change estimator class via set_params. "
            f"The '{self._estimator_name}' parameter cannot be set. Redeclare the "
            f"estimator class by creating a new instance of {self.__class__.__name__}."
        )

    # Step 1: Split parameters into simple and nested BEFORE validation.
    # Nested keys ("<component>__<parameter>") would otherwise be rejected by
    # the double-underscore check in _validate_estimator_params.
    simple_params = {}
    nested_params = defaultdict(dict)  # grouped by prefix
    has_nested = False

    for full_key, value in params.items():
        base_key, delim, sub_key = full_key.partition("__")
        if delim:  # Contains "__", so it's a nested parameter
            nested_params[base_key][sub_key] = value
            has_nested = True
        else:
            simple_params[base_key] = value

    # Step 2: Validate only the base/simple parameter names
    if simple_params:
        self._validate_estimator_params(simple_params)

    # Step 3: Validate base keys for nested params exist
    if has_nested:
        for base_key in nested_params:
            if base_key not in self.params:
                raise ValueError(
                    f"Invalid parameter {base_key!r} for estimator {self}. "
                    f"Valid parameters are: {list(self.params.keys())!r}."
                )

    # Step 4: Update simple parameters and validate type constraints
    for key, value in simple_params.items():
        # Validate nested wrapper parameters against _parameter_constraints
        if value is not None and value is not REQUIRED_PARAM_VALUE:
            self._validate_nested_wrapper_param(key, value)
        self.params[key] = value

    # Step 5: Recursively set nested parameters
    for base_key, sub_params in nested_params.items():
        nested_obj = self.params[base_key]
        if not hasattr(nested_obj, "set_params"):
            raise AttributeError(
                f"Cannot set nested parameters on {base_key!r}. "
                f"Object of type {type(nested_obj).__name__!r} does not have a set_params method."
            )
        nested_obj.set_params(**sub_params)

    return self