%23%20%2F%2F%2F%20script%0A%23%20requires-python%20%3D%20%22%3E%3D3.11%22%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22numpy%22%2C%0A%23%20%20%20%20%20%22scikit-learn%22%2C%0A%23%20%20%20%20%20%22sklearn-wrap%22%2C%0A%23%20%5D%0A%23%20%2F%2F%2F%0A%22%22%22%0A%23%20How%20to%20Serialize%20Estimators%0A%0ASave%20and%20load%20wrapped%20estimators%2C%20pipelines%2C%20and%20GridSearchCV%20objects%20with%20%60joblib%60.%0A%22%22%22%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.23.2%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20tempfile%0A%20%20%20%20from%20pathlib%20import%20Path%0A%0A%20%20%20%20from%20sklearn.model_selection%20import%20GridSearchCV%0A%20%20%20%20from%20sklearn.pipeline%20import%20Pipeline%0A%20%20%20%20from%20sklearn.preprocessing%20import%20StandardScaler%0A%20%20%20%20import%20joblib%0A%0A%20%20%20%20from%20sklearn_wrap%20import%20BaseClassWrapper%0A%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20BaseClassWrapper%2C%0A%20%20%20%20%20%20%20%20GridSearchCV%2C%0A%20%20%20%20%20%20%20%20Pipeline%2C%0A%20%20%20%20%20%20%20%20StandardScaler%2C%0A%20%20%20%20%20%20%20%20joblib%2C%0A%20%20%20%20%20%20%20%20np%2C%0A%20%20%20%20%20%20%20%20tempfile%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20This%20guide%20covers%20saving%20and%20loading%20wrapped%20estimators%20with%20%60joblib%60%20-%0A%20%20%20%20individually%2C%20inside%20pipelines%2C%20and%20as%20part%20of%20%60GridSearchCV%60%20results.%0A%0A%20%20%20%20**Prerequisites**%20-%20Familiarity%20with%20%5Bfirst_wrapper.py%5D(first_wrapper.py).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%201.%20Save%20and%20Load%20a%20Single%20Estimator%0A%0A%20%20%20%20Use%20%60joblib.dump()%60%20and%20%60joblib.load()%60%20exactly%20as%20with%20any%20sklearn%20estimator.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(np)%3A%0A%20%20%20%20class%20SimpleRegressor%3A%0A%20%20%20%20%20%20%20%20%22%22%22A%20simple%20regressor%20without%20sklearn%20conventions.%22%22%22%0A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20multiplier%3D1.0)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self._scale_factor%20%3D%20multiplier%0A%0A%20%20%20%20%20%20%20%20def%20train_model(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Train%20by%20computing%20scaled%20mean%20(not%20'fit').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20self._computed_value%20%3D%20y.mean()%20*%20self._scale_factor%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20generate_predictions(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Generate%20predictions%20(not%20'predict').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20np.full(X.shape%5B0%5D%2C%20self._computed_value)%0A%0A%20%20%20%20return%20(SimpleRegressor%2C)%0A%0A%0A%40app.cell%0Adef%20_(BaseClassWrapper)%3A%0A%20%20%20%20class%20SimpleWrapper(BaseClassWrapper)%3A%0A%20%20%20%20%20%20%20%20_estimator_name%20%3D%20%22regressor%22%0A%20%20%20%20%20%20%20%20_estimator_base_class%20%3D%20object%0A%0A%20%20%20%20%20%20%20%20def%20fit(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instantiate()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instance_.train_model(X%2C%20y)%0A%20%20%20%20%20%20%20%20%20%20%20%20self.fitted_%20%3D%20True%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20predict(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.instance_.generate_predictions(X)%0A%0A%20%20%20%20return%20(SimpleWrapper%2C)%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20generate_regression_data(n_samples%3D100%2C%20n_features%3D1%2C%20noise%3D1.0%2C%20random_state%3D42)%3A%0A%20%20%20%20from%20sklearn.datasets%20import%20make_regression%0A%20%20%20%20from%20sklearn.model_selection%20import%20train_test_split%0A%20%20%20%20X%2C%20y%20%3D%20make_regression(n_samples%3Dn_samples%2C%20n_features%3Dn_features%2C%20noise%3Dnoise%2C%20random_state%3Drandom_state)%0A%20%20%20%20return%20train_test_split(X%2C%20y%2C%20test_size%3D0.3%2C%20random_state%3Drandom_state)%0A%0A%0A%40app.cell%0Adef%20_(SimpleRegressor%2C%20SimpleWrapper%2C%20np)%3A%0A%20%20%20%20%23%20Train%20and%20fit%0A%20%20%20%20X_train%2C%20X_test%2C%20y_train%2C%20y_test%20%3D%20generate_regression_data()%0A%0A%20%20%20%20estimator%20%3D%20SimpleWrapper(regressor%3DSimpleRegressor%2C%20multiplier%3D2.0)%0A%20%20%20%20estimator.fit(X_train%2C%20y_train)%0A%0A%20%20%20%20original_predictions%20%3D%20estimator.predict(X_test)%0A%20%20%20%20original_score%20%3D%20np.mean((original_predictions%20-%20y_test)%20**%202)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20X_test%2C%0A%20%20%20%20%20%20%20%20X_train%2C%0A%20%20%20%20%20%20%20%20estimator%2C%0A%20%20%20%20%20%20%20%20original_predictions%2C%0A%20%20%20%20%20%20%20%20original_score%2C%0A%20%20%20%20%20%20%20%20y_test%2C%0A%20%20%20%20%20%20%20%20y_train%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell%0Adef%20_(estimator%2C%20joblib%2C%20tempfile)%3A%0A%20%20%20%20%23%20Save%20to%20file%0A%20%20%20%20temp_dir%20%3D%20tempfile.mkdtemp()%0A%20%20%20%20estimator_path%20%3D%20f%22%7Btemp_dir%7D%2Festimator.pkl%22%0A%20%20%20%20joblib.dump(estimator%2C%20estimator_path)%0A%0A%20%20%20%20%23%20Load%20from%20file%0A%20%20%20%20loaded_estimator%20%3D%20joblib.load(estimator_path)%0A%20%20%20%20return%20loaded_estimator%2C%20temp_dir%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20X_test%2C%0A%20%20%20%20loaded_estimator%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20np%2C%0A%20%20%20%20original_predictions%2C%0A%20%20%20%20original_score%2C%0A%20%20%20%20y_test%2C%0A)%3A%0A%20%20%20%20%23%20Verify%20loaded%20estimator%20works%0A%20%20%20%20loaded_predictions%20%3D%20loaded_estimator.predict(X_test)%0A%20%20%20%20loaded_score%20%3D%20np.mean((loaded_predictions%20-%20y_test)%20**%202)%0A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20%23%23%23%20Estimator%20Serialization%0A%0A%20%20%20%20-%20Original%20MSE%3A%20%7Boriginal_score%3A.2f%7D%0A%20%20%20%20-%20Loaded%20MSE%3A%20%7Bloaded_score%3A.2f%7D%0A%20%20%20%20-%20Predictions%20match%3A%20%7Bnp.allclose(original_predictions%2C%20loaded_predictions)%7D%0A%0A%20%20%20%20The%20loaded%20estimator%20produces%20identical%20predictions.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%202.%20Save%20and%20Load%20a%20Pipeline%0A%0A%20%20%20%20Pipelines%20containing%20wrapped%20estimators%20persist%20all%20preprocessing%20steps.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20Pipeline%2C%0A%20%20%20%20SimpleRegressor%2C%0A%20%20%20%20SimpleWrapper%2C%0A%20%20%20%20StandardScaler%2C%0A%20%20%20%20X_test%2C%0A%20%20%20%20X_train%2C%0A%20%20%20%20y_train%2C%0A)%3A%0A%20%20%20%20%23%20Create%20and%20fit%20pipeline%0A%20%20%20%20pipeline%20%3D%20Pipeline(%5B%0A%20%20%20%20%20%20%20%20(%22scaler%22%2C%20StandardScaler())%2C%0A%20%20%20%20%20%20%20%20(%22regressor%22%2C%20SimpleWrapper(regressor%3DSimpleRegressor%2C%20multiplier%3D1.5))%0A%20%20%20%20%5D)%0A%0A%20%20%20%20pipeline.fit(X_train%2C%20y_train)%0A%20%20%20%20pipeline_predictions%20%3D%20pipeline.predict(X_test)%0A%20%20%20%20return%20pipeline%2C%20pipeline_predictions%0A%0A%0A%40app.cell%0Adef%20_(joblib%2C%20pipeline%2C%20temp_dir)%3A%0A%20%20%20%20%23%20Save%20and%20load%20pipeline%0A%20%20%20%20pipeline_path%20%3D%20f%22%7Btemp_dir%7D%2Fpipeline.pkl%22%0A%20%20%20%20joblib.dump(pipeline%2C%20pipeline_path)%0A%20%20%20%20loaded_pipeline%20%3D%20joblib.load(pipeline_path)%0A%20%20%20%20return%20(loaded_pipeline%2C)%0A%0A%0A%40app.cell%0Adef%20_(X_test%2C%20loaded_pipeline)%3A%0A%20%20%20%20%23%20Verify%20pipeline%0A%20%20%20%20loaded_pipeline_predictions%20%3D%20loaded_pipeline.predict(X_test)%0A%20%20%20%20return%20(loaded_pipeline_predictions%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20loaded_pipeline%2C%0A%20%20%20%20loaded_pipeline_predictions%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20np%2C%0A%20%20%20%20pipeline_predictions%2C%0A)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20%23%23%23%20Pipeline%20Serialization%0A%0A%20%20%20%20-%20Predictions%20match%3A%20%7Bnp.allclose(pipeline_predictions%2C%20loaded_pipeline_predictions)%7D%0A%20%20%20%20-%20Pipeline%20steps%20preserved%3A%20%7Blist(loaded_pipeline.named_steps.keys())%7D%0A%0A%20%20%20%20The%20loaded%20pipeline%20maintains%20all%20preprocessing%20steps%20and%20wrapped%20estimator.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%203.%20Save%20and%20Load%20GridSearchCV%20Results%0A%0A%20%20%20%20Save%20the%20complete%20search%20state%20including%20the%20best%20estimator%20and%20all%20CV%20results.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(GridSearchCV%2C%20SimpleRegressor%2C%20SimpleWrapper%2C%20X_test%2C%20X_train%2C%20y_train)%3A%0A%20%20%20%20%23%20Run%20grid%20search%0A%20%20%20%20grid%20%3D%20GridSearchCV(%0A%20%20%20%20%20%20%20%20SimpleWrapper(regressor%3DSimpleRegressor)%2C%0A%20%20%20%20%20%20%20%20param_grid%3D%7B%22multiplier%22%3A%20%5B0.5%2C%201.0%2C%201.5%2C%202.0%5D%7D%2C%0A%20%20%20%20%20%20%20%20cv%3D3%2C%0A%20%20%20%20%20%20%20%20scoring%3D%22neg_mean_squared_error%22%0A%20%20%20%20)%0A%0A%20%20%20%20grid.fit(X_train%2C%20y_train)%0A%20%20%20%20grid_predictions%20%3D%20grid.predict(X_test)%0A%20%20%20%20best_params%20%3D%20grid.best_params_%0A%20%20%20%20return%20best_params%2C%20grid%2C%20grid_predictions%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(grid%2C%20joblib%2C%20temp_dir)%3A%0A%20%20%20%20%23%20Save%20and%20load%20grid%20search%0A%20%20%20%20grid_path%20%3D%20f%22%7Btemp_dir%7D%2Fgrid.pkl%22%0A%20%20%20%20joblib.dump(grid%2C%20grid_path)%0A%20%20%20%20loaded_grid%20%3D%20joblib.load(grid_path)%0A%20%20%20%20return%20(loaded_grid%2C)%0A%0A%0A%40app.cell%0Adef%20_(X_test%2C%20loaded_grid)%3A%0A%20%20%20%20%23%20Verify%20grid%20search%0A%20%20%20%20loaded_grid_predictions%20%3D%20loaded_grid.predict(X_test)%0A%20%20%20%20return%20(loaded_grid_predictions%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20best_params%2C%0A%20%20%20%20grid_predictions%2C%0A%20%20%20%20loaded_grid%2C%0A%20%20%20%20loaded_grid_predictions%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20np%2C%0A)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20%23%23%23%20GridSearchCV%20Serialization%0A%0A%20%20%20%20-%20Best%20params%20preserved%3A%20%60%7Bloaded_grid.best_params_%7D%60%0A%20%20%20%20-%20Original%20best%20params%3A%20%60%7Bbest_params%7D%60%0A%20%20%20%20-%20Predictions%20match%3A%20%7Bnp.allclose(grid_predictions%2C%20loaded_grid_predictions)%7D%0A%20%20%20%20-%20CV%20results%20available%3A%20%7Blen(loaded_grid.cv_results_%5B'params'%5D)%7D%20configs%20tested%0A%0A%20%20%20%20The%20loaded%20GridSearchCV%20retains%20best%20estimator%20and%20all%20cross-validation%20results.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20**More%20examples%3A**%20%5Bxgboost_wrapper.py%5D(xgboost_wrapper.py)%20%7C%20%5Byaml_config.py%5D(yaml_config.py)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
689cc7dfefbed8d321cfe34c0e192976