Skip to content

EstimatorConfig

sklearn_wrap.config.EstimatorConfig

Bases: BaseModel

Configuration for a scikit-learn compatible estimator.

Validates the structure of estimator configurations and can build instantiated estimators, convert existing estimators to configs, and serialize/deserialize to YAML.

Attributes

Name Type Description
estimator_class str

Fully qualified dotted import path, e.g. "sklearn.linear_model.Ridge".

params dict[str, Any]

Constructor parameters. Nested dicts with an estimator_class key are automatically converted to nested EstimatorConfig instances. Dicts with a __type__ key become _ClassRef markers resolved during build().

Examples

>>> config = EstimatorConfig(
...     estimator_class="sklearn.linear_model.Ridge",
...     params={"alpha": 1.0},
... )
>>> est = config.build()
>>> est.get_params()["alpha"]
1.0

Source Code

Show/Hide source
class EstimatorConfig(BaseModel):
    """Configuration for a scikit-learn compatible estimator.

    Validates the structure of estimator configurations and can build
    instantiated estimators, convert existing estimators to configs,
    and serialize/deserialize to YAML.

    Attributes
    ----------
    estimator_class : str
        Fully qualified dotted import path, e.g. ``"sklearn.linear_model.Ridge"``.
    params : dict[str, Any]
        Constructor parameters. Nested dicts with an ``estimator_class`` key are
        automatically converted to nested ``EstimatorConfig`` instances. Dicts
        with a ``__type__`` key become ``_ClassRef`` markers resolved during
        ``build()``.

    Examples
    --------
    >>> config = EstimatorConfig(
    ...     estimator_class="sklearn.linear_model.Ridge",
    ...     params={"alpha": 1.0},
    ... )
    >>> est = config.build()
    >>> est.get_params()["alpha"]
    1.0
    """

    estimator_class: str
    params: dict[str, Any] = Field(default_factory=dict)

    # ``params`` may hold non-pydantic objects after conversion, so arbitrary
    # types are allowed on the model.
    model_config = {"arbitrary_types_allowed": True}

    @field_validator("estimator_class")
    @classmethod
    def _validate_estimator_class(cls, v: str) -> str:
        """Validate that estimator_class is a dotted path with at least two segments.

        Parameters
        ----------
        v : str
            The dotted import path to validate.

        Returns
        -------
        str
            The validated path, unchanged.

        Raises
        ------
        ValueError
            If ``validate_dotted_path`` rejects ``v``.
        """
        try:
            validate_dotted_path(v, min_segments=2)
        except ValueError:
            # Replace the helper's message with one that names the field;
            # ``from None`` hides the original traceback context.
            raise ValueError(
                f"estimator_class must be a valid dotted import path with at least "
                f"two segments (e.g. 'sklearn.linear_model.Ridge'), got {v!r}"
            ) from None
        return v

    @model_validator(mode="before")
    @classmethod
    def _convert_nested(cls, data: Any) -> Any:
        """Convert nested dicts with ``estimator_class`` keys into EstimatorConfig.

        The conversion itself is delegated to ``_walk_params``. The incoming
        mapping is not modified: when ``params`` needs converting, a shallow
        copy of ``data`` carrying the converted ``params`` is returned.

        Parameters
        ----------
        data : Any
            The raw data to validate.

        Returns
        -------
        Any
            ``data`` itself when it is not a dict or has no dict ``params``;
            otherwise a new dict with ``params`` replaced by the converted
            value.
        """
        if not isinstance(data, dict):
            return data
        params = data.get("params")
        if not isinstance(params, dict):
            return data
        # Build a new outer dict rather than assigning into ``data`` so that
        # validating a caller-owned dict does not rewrite it as a side effect.
        # NOTE(review): whether ``_walk_params`` copies or mutates ``params``
        # is not visible here -- confirm.
        return {**data, "params": _walk_params(params)}

    def build(
        self,
        *,
        trusted_modules: frozenset[str] | None = None,
        validate_params: bool = True,
    ) -> Any:
        """Resolve the configuration into an instantiated estimator.

        Parameters
        ----------
        trusted_modules : frozenset[str] or None
            Allowed top-level packages for class resolution. When ``None``
            (the default), the value from the global configuration is used
            (see `set_config`).
        validate_params : bool, default=True
            If True, validate that the resolved parameter names match the
            constructor signature of the target class before instantiation.

        Returns
        -------
        estimator
            An instantiated scikit-learn compatible estimator.

        Examples
        --------
        >>> config = EstimatorConfig(
        ...     estimator_class="sklearn.linear_model.Ridge",
        ...     params={"alpha": 2.0},
        ... )
        >>> est = config.build()
        >>> est.alpha
        2.0

        See Also
        --------
        EstimatorConfig.from_estimator : Create a config from an existing estimator.
        EstimatorConfig.from_yaml : Load a config from a YAML file.
        set_config : Set global trusted modules configuration.
        """
        if trusted_modules is None:
            trusted_modules = get_config()["trusted_modules"]
        # The same trusted set is used for the top-level class and for
        # anything resolved inside ``params``.
        cls = _import_class(self.estimator_class, trusted_modules)
        resolved = _resolve_params(self.params, trusted_modules=trusted_modules)
        if validate_params:
            validate_class_params(cls, resolved)
        return cls(**resolved)

    @classmethod
    def from_estimator(cls, estimator: Any) -> EstimatorConfig:
        """Create a configuration from an existing estimator instance.

        Parameters
        ----------
        estimator : BaseEstimator
            A scikit-learn compatible estimator (must implement ``get_params``).

        Returns
        -------
        EstimatorConfig
            The configuration capturing the estimator's class and parameters.

        Examples
        --------
        >>> from sklearn.linear_model import Ridge
        >>> est = Ridge(alpha=3.0)
        >>> config = EstimatorConfig.from_estimator(est)
        >>> config.estimator_class
        'sklearn.linear_model._ridge.Ridge'
        >>> config.params["alpha"]
        3.0

        See Also
        --------
        EstimatorConfig.build : Instantiate an estimator from a config.
        EstimatorConfig.to_yaml : Serialize a config to YAML.
        """
        dotted = _class_to_dotted_path(type(estimator))
        # ``deep=False`` is requested here; the raw values are then passed
        # through ``_serialize_params`` before being stored.
        raw_params = estimator.get_params(deep=False)
        params = _serialize_params(raw_params)
        return cls(estimator_class=dotted, params=params)

    def to_yaml(self, path: str | Path) -> None:
        """Write the configuration to a YAML file.

        The file is opened in text mode with UTF-8 encoding, so the result
        does not depend on the platform's default locale encoding.

        Parameters
        ----------
        path : str or Path
            Destination file path.

        See Also
        --------
        EstimatorConfig.from_yaml : Load a config from a YAML file.
        """
        data = self.model_dump()
        with open(path, "w", encoding="utf-8") as f:
            # ``sort_keys=False`` keeps the dumped key order as produced by
            # ``model_dump``; block style is forced via ``default_flow_style``.
            yaml.dump(data, f, default_flow_style=False, sort_keys=False)

    @classmethod
    def from_yaml(cls, path: str | Path) -> EstimatorConfig:
        """Load a configuration from a YAML file.

        Supports YAML anchors, merge keys (``<<: *alias``), and the
        ``!include`` tag for multi-file composition.

        Parameters
        ----------
        path : str or Path
            Path to the YAML file.

        Returns
        -------
        EstimatorConfig
            The validated configuration.

        See Also
        --------
        EstimatorConfig.to_yaml : Write a config to a YAML file.
        EstimatorConfig.build : Instantiate an estimator from a config.
        """
        data = _load_yaml(path)
        return cls.model_validate(data)

Methods

build(*, trusted_modules=None, validate_params=True)

Resolve the configuration into an instantiated estimator.

Parameters
Name Type Description Default
trusted_modules frozenset[str] or None

Allowed top-level packages for class resolution. When None (the default), the value from the global configuration is used (see set_config).

None
validate_params bool

If True, validate that the resolved parameter names match the constructor signature of the target class before instantiation.

True
Returns
Type Description
estimator

An instantiated scikit-learn compatible estimator.

Examples
>>> config = EstimatorConfig(
...     estimator_class="sklearn.linear_model.Ridge",
...     params={"alpha": 2.0},
... )
>>> est = config.build()
>>> est.alpha
2.0
See Also

- EstimatorConfig.from_estimator: Create a config from an existing estimator.
- EstimatorConfig.from_yaml: Load a config from a YAML file.
- set_config: Set global trusted modules configuration.

Source Code
Show/Hide source
def build(
    self,
    *,
    trusted_modules: frozenset[str] | None = None,
    validate_params: bool = True,
) -> Any:
    """Instantiate the estimator described by this configuration.

    Parameters
    ----------
    trusted_modules : frozenset[str] or None
        Allowed top-level packages for class resolution. If ``None``
        (the default), the ``"trusted_modules"`` entry of the global
        configuration is used instead (see `set_config`).
    validate_params : bool, default=True
        When True, the resolved parameters are checked against the target
        class via ``validate_class_params`` before the class is called.

    Returns
    -------
    estimator
        An instantiated scikit-learn compatible estimator.

    Examples
    --------
    >>> config = EstimatorConfig(
    ...     estimator_class="sklearn.linear_model.Ridge",
    ...     params={"alpha": 2.0},
    ... )
    >>> est = config.build()
    >>> est.alpha
    2.0

    See Also
    --------
    EstimatorConfig.from_estimator : Create a config from an existing estimator.
    EstimatorConfig.from_yaml : Load a config from a YAML file.
    set_config : Set global trusted modules configuration.
    """
    # Only consult the global configuration when no explicit set was given.
    allowed = (
        get_config()["trusted_modules"]
        if trusted_modules is None
        else trusted_modules
    )
    # The same allow-list governs the top-level class and nested params.
    target = _import_class(self.estimator_class, allowed)
    kwargs = _resolve_params(self.params, trusted_modules=allowed)
    if validate_params:
        validate_class_params(target, kwargs)
    return target(**kwargs)

from_estimator(estimator) classmethod

Create a configuration from an existing estimator instance.

Parameters
Name Type Description Default
estimator BaseEstimator

A scikit-learn compatible estimator (must implement get_params).

required
Returns
Type Description
EstimatorConfig

The configuration capturing the estimator's class and parameters.

Examples
>>> from sklearn.linear_model import Ridge
>>> est = Ridge(alpha=3.0)
>>> config = EstimatorConfig.from_estimator(est)
>>> config.estimator_class
'sklearn.linear_model._ridge.Ridge'
>>> config.params["alpha"]
3.0
See Also

- EstimatorConfig.build: Instantiate an estimator from a config.
- EstimatorConfig.to_yaml: Serialize a config to YAML.

Source Code
Show/Hide source
@classmethod
def from_estimator(cls, estimator: Any) -> EstimatorConfig:
    """Build a configuration that mirrors an existing estimator instance.

    Parameters
    ----------
    estimator : BaseEstimator
        A scikit-learn compatible estimator (must implement ``get_params``).

    Returns
    -------
    EstimatorConfig
        A configuration holding the dotted path of the estimator's class
        and its serialized parameters.

    Examples
    --------
    >>> from sklearn.linear_model import Ridge
    >>> est = Ridge(alpha=3.0)
    >>> config = EstimatorConfig.from_estimator(est)
    >>> config.estimator_class
    'sklearn.linear_model._ridge.Ridge'
    >>> config.params["alpha"]
    3.0

    See Also
    --------
    EstimatorConfig.build : Instantiate an estimator from a config.
    EstimatorConfig.to_yaml : Serialize a config to YAML.
    """
    class_path = _class_to_dotted_path(type(estimator))
    # Parameters are fetched with ``deep=False`` and serialized before
    # being handed to the model constructor.
    return cls(
        estimator_class=class_path,
        params=_serialize_params(estimator.get_params(deep=False)),
    )

to_yaml(path)

Write the configuration to a YAML file.

Parameters
Name Type Description Default
path str or Path

Destination file path.

required
See Also

EstimatorConfig.from_yaml : Load a config from a YAML file.

Source Code
Show/Hide source
def to_yaml(self, path: str | Path) -> None:
    """Write the configuration to a YAML file.

    The file is opened in text mode with UTF-8 encoding, so the result
    does not depend on the platform's default locale encoding.

    Parameters
    ----------
    path : str or Path
        Destination file path.

    See Also
    --------
    EstimatorConfig.from_yaml : Load a config from a YAML file.
    """
    data = self.model_dump()
    with open(path, "w", encoding="utf-8") as f:
        # ``sort_keys=False`` keeps the dumped key order as produced by
        # ``model_dump``; block style is forced via ``default_flow_style``.
        yaml.dump(data, f, default_flow_style=False, sort_keys=False)

from_yaml(path) classmethod

Load a configuration from a YAML file.

Supports YAML anchors, merge keys (<<: *alias), and the !include tag for multi-file composition.

Parameters
Name Type Description Default
path str or Path

Path to the YAML file.

required
Returns
Type Description
EstimatorConfig

The validated configuration.

See Also

- EstimatorConfig.to_yaml: Write a config to a YAML file.
- EstimatorConfig.build: Instantiate an estimator from a config.

Source Code
Show/Hide source
@classmethod
def from_yaml(cls, path: str | Path) -> EstimatorConfig:
    """Read a YAML file and validate it as a configuration.

    Supports YAML anchors, merge keys (``<<: *alias``), and the
    ``!include`` tag for multi-file composition.

    Parameters
    ----------
    path : str or Path
        Path to the YAML file.

    Returns
    -------
    EstimatorConfig
        The validated configuration.

    See Also
    --------
    EstimatorConfig.to_yaml : Write a config to a YAML file.
    EstimatorConfig.build : Instantiate an estimator from a config.
    """
    # Loading is delegated to ``_load_yaml``; the parsed data then goes
    # through the model's normal validation.
    return cls.model_validate(_load_yaml(path))