%23%20%2F%2F%2F%20script%0A%23%20requires-python%20%3D%20%22%3E%3D3.11%22%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22numpy%22%2C%0A%23%20%20%20%20%20%22plotly%22%2C%0A%23%20%20%20%20%20%22scikit-learn%22%2C%0A%23%20%20%20%20%20%22sklearn-wrap%22%2C%0A%23%20%5D%0A%23%20%2F%2F%2F%0A%22%22%22%0A%23%20Your%20First%20Wrapper%0A%0AIn%20this%20notebook%2C%20we%20wrap%20a%20custom%20Python%20class%20into%20a%20scikit-learn%20compatible%0Aestimator%20using%20BaseClassWrapper.%0A%22%22%22%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.23.2%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%0A%20%20%20%20from%20sklearn_wrap%20import%20BaseClassWrapper%0A%0A%20%20%20%20return%20BaseClassWrapper%2C%20np%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20In%20this%20notebook%2C%20we%20wrap%20a%20custom%20Python%20class%20into%20a%20scikit-learn%0A%20%20%20%20compatible%20estimator%20using%20%60BaseClassWrapper%60.%20We%20bridge%20non-standard%20method%0A%20%20%20%20names%20to%20sklearn's%20%60fit%60%2F%60predict%60%20interface%20and%20explore%20the%20result%0A%20%20%20%20interactively.%0A%0A%20%20%20%20**Prerequisites%3A**%20None%20-%20this%20is%20the%20starting%20point.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%201.%20The%20Pattern%0A%0A%20%20%20%20We%20wrap%20any%20Python%20class%20for%20sklearn%20in%203%20steps%3A%0A%0A%20%20%20%201.%20Inherit%20from%20%60BaseClassWrapper%60%0A%20%20%20%202.%20Set%20%60_estimator_name%60%20and%20%60_estimator_base_class%60%0A%20%20%20%203.%20Implement%20%60fit()%60%20and%20%60predict()%60%20(calling%20%60instantiate()%60%20first)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%202.%20Custom%20Polynomial%20Regressor%0A%0A%20%20%20%20Let's%20define%20a%20non-sklearn%20class%20that%20implements%20polynomial%20regression%20with%20gradient%20descent.%0A%20%20%20%20It%20uses%20its%20own%20method%20names%20and%20doesn't%20inherit%20from%20BaseEstimator.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(np)%3A%0A%20%20%20%20class%20PolynomialRegressor%3A%0A%20%20%20%20%20%20%20%20%22%22%22Custom%20polynomial%20regressor%20without%20sklearn%20conventions.%0A%0A%20%20%20%20%20%20%20%20Uses%20train%2Fcompute_predictions%20instead%20of%20fit%2Fpredict.%0A%20%20%20%20%20%20%20%20Stores%20parameters%20with%20different%20internal%20names.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20degree%3D1%2C%20learning_rate%3D0.01)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Internal%20attributes%20use%20different%20names%20than%20parameters%0A%20%20%20%20%20%20%20%20%20%20%20%20self._poly_degree%20%3D%20degree%0A%20%20%20%20%20%20%20%20%20%20%20%20self._step_size%20%3D%20learning_rate%0A%20%20%20%20%20%20%20%20%20%20%20%20self._weights%20%3D%20None%0A%0A%20%20%20%20%20%20%20%20def%20train(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Train%20the%20model%20(not%20'fit').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20X_poly%20%3D%20self._create_poly_features(X)%0A%20%20%20%20%20%20%20%20%20%20%20%20n_samples%2C%20n_features%20%3D%20X_poly.shape%0A%20%20%20%20%20%20%20%20%20%20%20%20self._weights%20%3D%20np.zeros(n_features)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Gradient%20descent%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20_%20in%20range(1000)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20y_pred%20%3D%20X_poly%20%40%20self._weights%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20gradient%20%3D%20-2%20*%20X_poly.T%20%40%20(y%20-%20y_pred)%20%2F%20n_samples%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self._weights%20-%3D%20self._step_size%20*%20gradient%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20compute_predictions(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Compute%20predictions%20(not%20'predict').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20self._weights%20is%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(%22Must%20train%20before%20predict%22)%0A%20%20%20%20%20%20%20%20%20%20%20%20X_poly%20%3D%20self._create_poly_features(X)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20X_poly%20%40%20self._weights%0A%0A%20%20%20%20%20%20%20%20def%20_create_poly_features(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20X%20%3D%20np.asarray(X).reshape(-1%2C%201)%20if%20X.ndim%20%3D%3D%201%20else%20X%0A%20%20%20%20%20%20%20%20%20%20%20%20features%20%3D%20%5Bnp.ones((X.shape%5B0%5D%2C%201))%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20d%20in%20range(1%2C%20self._poly_degree%20%2B%201)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20features.append(X**d)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20np.hstack(features)%0A%0A%20%20%20%20return%20(PolynomialRegressor%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%203.%20Wrap%20It%20for%20sklearn%0A%0A%20%20%20%20Now%20we%20bridge%20the%20custom%20class%20to%20sklearn's%20interface.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(BaseClassWrapper)%3A%0A%20%20%20%20class%20PolyWrapper(BaseClassWrapper)%3A%0A%20%20%20%20%20%20%20%20_estimator_name%20%3D%20%22poly_regressor%22%0A%20%20%20%20%20%20%20%20_estimator_base_class%20%3D%20object%0A%0A%20%20%20%20%20%20%20%20def%20fit(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instantiate()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instance_.train(X%2C%20y)%20%20%23%20Call%20train()%2C%20not%20fit()%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Mark%20estimator%20as%20fitted%20for%20sklearn%20compatibility%0A%20%20%20%20%20%20%20%20%20%20%20%20self.fitted_%20%3D%20True%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20predict(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.instance_.compute_predictions(X)%20%20%23%20Call%20compute_predictions()%0A%0A%20%20%20%20return%20(PolyWrapper%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%204.%20Interactive%20Demo%0A%0A%20%20%20%20Let's%20adjust%20hyperparameters%20and%20see%20results%20in%20real-time.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20def%20create_slider(start%2C%20stop%2C%20value%2C%20label%2C%20step%3DNone%2C%20**kwargs)%3A%0A%20%20%20%20%20%20%20%20params%20%3D%20%7B%22start%22%3A%20start%2C%20%22stop%22%3A%20stop%2C%20%22value%22%3A%20value%2C%20%22label%22%3A%20label%2C%20%22show_value%22%3A%20True%2C%20**kwargs%7D%0A%20%20%20%20%20%20%20%20if%20step%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22step%22%5D%20%3D%20step%0A%20%20%20%20%20%20%20%20return%20mo.ui.slider(**params)%0A%0A%20%20%20%20return%20(create_slider%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(create_slider%2C%20mo)%3A%0A%20%20%20%20degree_slider%20%3D%20create_slider(1%2C%205%2C%202%2C%20%22Polynomial%20Degree%22)%0A%20%20%20%20lr_slider%20%3D%20create_slider(0.001%2C%200.1%2C%200.01%2C%20%22Learning%20Rate%22%2C%20step%3D0.001)%0A%20%20%20%20mo.hstack(%5Bdegree_slider%2C%20lr_slider%5D%2C%20justify%3D%22space-around%22)%0A%20%20%20%20return%20degree_slider%2C%20lr_slider%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20generate_regression_data(n_samples%3D300%2C%20n_features%3D2%2C%20noise%3D20%2C%20test_size%3D0.3%2C%20random_state%3D42%2C%20**kwargs)%3A%0A%20%20%20%20from%20sklearn.datasets%20import%20make_regression%0A%20%20%20%20from%20sklearn.model_selection%20import%20train_test_split%0A%20%20%20%20X%2C%20y%20%3D%20make_regression(n_samples%3Dn_samples%2C%20n_features%3Dn_features%2C%20noise%3Dnoise%2C%20random_state%3Drandom_state%2C%20**kwargs)%0A%20%20%20%20return%20train_test_split(X%2C%20y%2C%20test_size%3Dtest_size%2C%20random_state%3Drandom_state)%0A%0A%0A%40app.cell%0Adef%20_(PolyWrapper%2C%20PolynomialRegressor%2C%20degree_slider%2C%20lr_slider%2C%20np)%3A%0A%20%20%20%20%23%20Create%20wrapper%20with%20slider%20values%0A%20%20%20%20wrapper%20%3D%20PolyWrapper(%0A%20%20%20%20%20%20%20%20poly_regressor%3DPolynomialRegressor%2C%0A%20%20%20%20%20%20%20%20degree%3Ddegree_slider.value%2C%0A%20%20%20%20%20%20%20%20learning_rate%3Dlr_slider.value%2C%0A%20%20%20%20)%0A%0A%20%20%20%20X_train%2C%20X_test%2C%20y_train%2C%20y_test%20%3D%20generate_regression_data(n_features%3D1%2C%20noise%3D15)%0A%20%20%20%20wrapper.fit(X_train%2C%20y_train)%0A%0A%20%20%20%20%23%20Make%20predictions%0A%20%20%20%20y_pred_train%20%3D%20wrapper.predict(X_train)%0A%20%20%20%20y_pred_test%20%3D%20wrapper.predict(X_test)%0A%20%20%20%20X_plot%20%3D%20np.linspace(X_train.min()%2C%20X_train.max()%2C%20200).reshape(-1%2C%201)%0A%20%20%20%20y_pred_plot%20%3D%20wrapper.predict(X_plot)%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20X_plot%2C%0A%20%20%20%20%20%20%20%20X_test%2C%0A%20%20%20%20%20%20%20%20X_train%2C%0A%20%20%20%20%20%20%20%20wrapper%2C%0A%20%20%20%20%20%20%20%20y_pred_plot%2C%0A%20%20%20%20%20%20%20%20y_pred_test%2C%0A%20%20%20%20%20%20%20%20y_pred_train%2C%0A%20%20%20%20%20%20%20%20y_test%2C%0A%20%20%20%20%20%20%20%20y_train%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(np)%3A%0A%20%20%20%20def%20calculate_r2_score(y_true%2C%20y_pred)%3A%0A%20%20%20%20%20%20%20%20return%201%20-%20np.mean((y_true%20-%20y_pred)%20**%202)%20%2F%20np.var(y_true)%0A%0A%20%20%20%20return%20(calculate_r2_score%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(calculate_r2_score)%3A%0A%20%20%20%20def%20calculate_train_test_scores(y_train%2C%20y_pred_train%2C%20y_test%2C%20y_pred_test)%3A%0A%20%20%20%20%20%20%20%20return%20(calculate_r2_score(y_train%2C%20y_pred_train)%2C%20calculate_r2_score(y_test%2C%20y_pred_test))%0A%0A%20%20%20%20return%20(calculate_train_test_scores%2C)%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20create_regression_scatter(X_train%2C%20y_train%2C%20X_test%2C%20y_test%2C%20X_plot%2C%20y_pred_plot%2C%20train_score%2C%20test_score%2C%20title_prefix%3D%22%22%2C%20**layout_kwargs)%3A%0A%20%20%20%20import%20plotly.graph_objects%20as%20go%0A%20%20%20%20fig%20%3D%20go.Figure()%0A%20%20%20%20fig.add_trace(go.Scatter(x%3DX_train.flatten()%2C%20y%3Dy_train%2C%20mode%3D%22markers%22%2C%20name%3D%22Training%20Data%22%2C%20marker%3Ddict(size%3D8%2C%20color%3D%22lightblue%22%2C%20line%3Ddict(width%3D1%2C%20color%3D%22darkblue%22))))%0A%20%20%20%20fig.add_trace(go.Scatter(x%3DX_test.flatten()%2C%20y%3Dy_test%2C%20mode%3D%22markers%22%2C%20name%3D%22Test%20Data%22%2C%20marker%3Ddict(size%3D8%2C%20color%3D%22lightcoral%22%2C%20line%3Ddict(width%3D1%2C%20color%3D%22darkred%22))))%0A%20%20%20%20fig.add_trace(go.Scatter(x%3DX_plot.flatten()%2C%20y%3Dy_pred_plot%2C%20mode%3D%22lines%22%2C%20name%3D%22Model%20Prediction%22%2C%20line%3Ddict(color%3D%22green%22%2C%20width%3D3)))%0A%20%20%20%20title%20%3D%20f%22Train%20R%C2%B2%20%3D%20%7Btrain_score%3A.3f%7D%2C%20Test%20R%C2%B2%20%3D%20%7Btest_score%3A.3f%7D%22%0A%20%20%20%20if%20title_prefix%3A%0A%20%20%20%20%20%20%20%20title%20%3D%20f%22%7Btitle_prefix%7D%3Cbr%3E%7Btitle%7D%22%0A%20%20%20%20fig.update_layout(title%3Dtitle%2C%20xaxis_title%3D%22Feature%22%2C%20yaxis_title%3D%22Target%22%2C%20height%3D500%2C%20showlegend%3DTrue%2C%20**layout_kwargs)%0A%20%20%20%20return%20fig%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20X_plot%2C%0A%20%20%20%20X_test%2C%0A%20%20%20%20X_train%2C%0A%20%20%20%20calculate_train_test_scores%2C%0A%20%20%20%20degree_slider%2C%0A%20%20%20%20lr_slider%2C%0A%20%20%20%20y_pred_plot%2C%0A%20%20%20%20y_pred_test%2C%0A%20%20%20%20y_pred_train%2C%0A%20%20%20%20y_test%2C%0A%20%20%20%20y_train%2C%0A)%3A%0A%20%20%20%20train_r2%2C%20test_r2%20%3D%20calculate_train_test_scores(y_train%2C%20y_pred_train%2C%20y_test%2C%20y_pred_test)%0A%0A%20%20%20%20fig%20%3D%20create_regression_scatter(%0A%20%20%20%20%20%20%20%20X_train%2C%0A%20%20%20%20%20%20%20%20y_train%2C%0A%20%20%20%20%20%20%20%20X_test%2C%0A%20%20%20%20%20%20%20%20y_test%2C%0A%20%20%20%20%20%20%20%20X_plot%2C%0A%20%20%20%20%20%20%20%20y_pred_plot%2C%0A%20%20%20%20%20%20%20%20train_r2%2C%0A%20%20%20%20%20%20%20%20test_r2%2C%0A%20%20%20%20%20%20%20%20title_prefix%3Df%22Polynomial%20Regression%20(degree%3D%7Bdegree_slider.value%7D%2C%20lr%3D%7Blr_slider.value%3A.3f%7D)%22%2C%0A%20%20%20%20)%0A%20%20%20%20fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%205.%20HTML%20Representation%0A%0A%20%20%20%20Notice%20that%20wrapped%20estimators%20display%20nicely%20in%20interactive%20environments.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(wrapper)%3A%0A%20%20%20%20wrapper%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%20What%20We%20Built%0A%0A%20%20%20%20We%20wrapped%20a%20custom%20%60PolynomialRegressor%60%20class%20into%20an%20sklearn-compatible%20estimator%0A%20%20%20%20using%20%60BaseClassWrapper%60.%20Along%20the%20way%2C%20we%3A%0A%0A%20%20%20%20-%20Defined%20%60_estimator_name%60%20and%20%60_estimator_base_class%60%20as%20required%20class%20attributes%0A%20%20%20%20-%20Bridged%20custom%20method%20names%20(%60train%60%2F%60compute_predictions%60)%20to%20sklearn's%20%60fit%60%2F%60predict%60%0A%20%20%20%20-%20Got%20free%20parameter%20validation%20and%20HTML%20representation%0A%0A%20%20%20%20**Next%20steps%3A**%0A%0A%20%20%20%20-%20The%20parameter%20interface%20in%20depth%3A%0A%20%20%20%20%20%20%5BView%5D(%2Fexamples%2Fparameter_interface%2F)%20%C2%B7%20%5BOpen%20in%20marimo%5D(%2Fexamples%2Fparameter_interface%2Fedit%2F)%0A%20%20%20%20-%20The%20fit%20context%20decorator%3A%0A%20%20%20%20%20%20%5BView%5D(%2Fexamples%2Ffit_context%2F)%20%C2%B7%20%5BOpen%20in%20marimo%5D(%2Fexamples%2Ffit_context%2Fedit%2F)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
91b52a8f9614feaf246593ab064c6dda