%23%20%2F%2F%2F%20script%0A%23%20requires-python%20%3D%20%22%3E%3D3.11%22%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22numpy%22%2C%0A%23%20%20%20%20%20%22scikit-learn%22%2C%0A%23%20%20%20%20%20%22sklearn-wrap%22%2C%0A%23%20%20%20%20%20%22xgboost%22%2C%0A%23%20%5D%0A%23%20%2F%2F%2F%0A%22%22%22%0A%23%20How%20to%20Wrap%20XGBoost%0A%0AWrap%20XGBoost's%20low-level%20Booster%20API%20into%20a%20scikit-learn%20compatible%20estimator.%0A%22%22%22%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.23.2%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20import%20xgboost%20as%20xgb%0A%0A%20%20%20%20from%20sklearn_wrap%20import%20BaseClassWrapper%0A%0A%20%20%20%20return%20BaseClassWrapper%2C%20np%2C%20xgb%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20This%20guide%20wraps%20XGBoost's%20procedural%20%60xgb.train()%60%20API%20using%20an%20adapter%0A%20%20%20%20class%2C%20then%20adds%20nested%20callback%20control%20via%20the%20%60__%60%20syntax.%0A%0A%20%20%20%20**Prerequisites**%20-%20Familiarity%20with%20%5Bfirst_wrapper.py%5D(first_wrapper.py).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%201.%20Create%20the%20Adapter%20and%20Wrapper%0A%0A%20%20%20%20Build%20an%20adapter%20class%20around%20%60xgb.train()%60%2C%20then%20wrap%20it%20with%20nested%20callback%20support.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(BaseClassWrapper%2C%20xgb)%3A%0A%20%20%20%20from%20sklearn_wrap%20import%20base%20as%20skw_base%0A%0A%20%20%20%20class%20XGBoostCallbackWrapper(BaseClassWrapper)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Wrap%20XGBoost%20callback%20classes.%0A%0A%20%20%20%20%20%20%20%20Enables%20nested%20parameter%20control%20over%20callbacks%20using%20the%20__%20syntax.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%0A%20%20%20%20%20%20%20%20_estimator_name%20%3D%20%22callback%22%0A%20%20%20%20%20%20%20%20_estimator_base_class%20%3D%20xgb.callback.TrainingCallback%20%20%23%20XGBoost%20callbacks%20inherit%20from%20TrainingCallback%0A%0A%20%20%20%20%23%20Create%20an%20adapter%20class%20that%20wraps%20xgb.train()%20as%20a%20normal%20class%0A%20%20%20%20class%20XGBoostTrainer%3A%0A%20%20%20%20%20%20%20%20%22%22%22Adapter%20class%20that%20wraps%20XGBoost's%20train()%20function%20as%20a%20trainable%20class.%0A%0A%20%20%20%20%20%20%20%20This%20demonstrates%20how%20to%20wrap%20procedural%20APIs%20by%20creating%20an%20adapter%20class.%0A%20%20%20%20%20%20%20%20%22%22%22%0A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20num_boost_round%3D100%2C%20callbacks%3DNone%2C%20**xgb_params)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.num_boost_round%20%3D%20num_boost_round%0A%20%20%20%20%20%20%20%20%20%20%20%20self.callbacks%20%3D%20callbacks%0A%20%20%20%20%20%20%20%20%20%20%20%20self.xgb_params%20%3D%20xgb_params%0A%0A%20%20%20%20%20%20%20%20def%20fit_model(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dtrain%20%3D%20xgb.DMatrix(X%2C%20label%3Dy)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Instantiate%20wrapped%20callbacks%20if%20provided%0A%20%20%20%20%20%20%20%20%20%20%20%20callback_instances%20%3D%20None%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20self.callbacks%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20if%20isinstance(self.callbacks%2C%20list)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20callback_instances%20%3D%20%5B%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20cb.instance_%20if%20hasattr(cb%2C%20'instance_')%20else%20cb.instantiate().instance_%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20for%20cb%20in%20self.callbacks%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20cb%20%3D%20self.callbacks%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20callback_instances%20%3D%20%5Bcb.instance_%20if%20hasattr(cb%2C%20'instance_')%20else%20cb.instantiate().instance_%5D%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.booster_%20%3D%20xgb.train(%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20self.xgb_params%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20dtrain%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20num_boost_round%3Dself.num_boost_round%2C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20callbacks%3Dcallback_instances%0A%20%20%20%20%20%20%20%20%20%20%20%20)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20predict_output(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20dtest%20%3D%20xgb.DMatrix(X)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.booster_.predict(dtest)%0A%0A%20%20%20%20class%20XGBoostTrainerWrapper(BaseClassWrapper)%3A%0A%20%20%20%20%20%20%20%20%22%22%22Wrap%20XGBoost's%20training%20process%20with%20nested%20callback%20support.%0A%0A%20%20%20%20%20%20%20%20This%20wrapper%20demonstrates%3A%0A%20%20%20%20%20%20%20%201.%20Wrapping%20an%20adapter%20class%20(XGBoostTrainer)%20that%20bridges%20procedural%20APIs%0A%20%20%20%20%20%20%20%202.%20Nested%20wrappers%20for%20callbacks%20with%20automatic%20parameter%20handling%0A%20%20%20%20%20%20%20%203.%20No%20need%20to%20override%20__init__%2Fget_params%2Fset_params%20-%20it%20all%20works%20automatically!%0A%20%20%20%20%20%20%20%204.%20Using%20_parameter_constraints%20to%20validate%20nested%20wrapper%20parameters%0A%20%20%20%20%20%20%20%205.%20Using%20_estimator_default_class%20to%20avoid%20passing%20the%20class%20every%20time%0A%20%20%20%20%20%20%20%20%22%22%22%0A%0A%20%20%20%20%20%20%20%20_estimator_name%20%3D%20%22trainer%22%0A%20%20%20%20%20%20%20%20_estimator_base_class%20%3D%20object%0A%20%20%20%20%20%20%20%20_estimator_default_class%20%3D%20XGBoostTrainer%0A%20%20%20%20%20%20%20%20_parameter_constraints%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Validates%20that%20callbacks%20is%20None%20or%20a%20BaseClassWrapper%20wrapping%20a%20class%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20that%20inherits%20from%20xgb.callback.TrainingCallback%20(the%20base%20for%20XGBoost%20callbacks)%0A%20%20%20%20%20%20%20%20%20%20%20%20%22callbacks%22%3A%20%5BNone%2C%20%7B%22wrapper_base_class%22%3A%20xgb.callback.TrainingCallback%7D%5D%0A%20%20%20%20%20%20%20%20%7D%0A%0A%20%20%20%20%20%20%20%20%40skw_base._fit_context(prefer_skip_nested_validation%3DTrue)%0A%20%20%20%20%20%20%20%20def%20fit(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instance_.fit_model(X%2C%20y)%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20predict(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.instance_.predict_output(X)%0A%0A%20%20%20%20return%20XGBoostCallbackWrapper%2C%20XGBoostTrainerWrapper%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%202.%20Train%20with%20Interactive%20Parameters%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20def%20create_slider(start%2C%20stop%2C%20value%2C%20label%2C%20step%3DNone%2C%20**kwargs)%3A%0A%20%20%20%20%20%20%20%20params%20%3D%20%7B%22start%22%3A%20start%2C%20%22stop%22%3A%20stop%2C%20%22value%22%3A%20value%2C%20%22label%22%3A%20label%2C%20%22show_value%22%3A%20True%2C%20**kwargs%7D%0A%20%20%20%20%20%20%20%20if%20step%20is%20not%20None%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20params%5B%22step%22%5D%20%3D%20step%0A%20%20%20%20%20%20%20%20return%20mo.ui.slider(**params)%0A%0A%20%20%20%20return%20(create_slider%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(create_slider%2C%20mo)%3A%0A%20%20%20%20depth_slider%20%3D%20create_slider(1%2C%2010%2C%203%2C%20%22Max%20Depth%22)%0A%20%20%20%20eta_slider%20%3D%20create_slider(0.01%2C%200.5%2C%200.1%2C%20%22Learning%20Rate%20(eta)%22%2C%20step%3D0.01)%0A%20%20%20%20mo.hstack(%5Bdepth_slider%2C%20eta_slider%5D%2C%20justify%3D%22space-around%22)%0A%20%20%20%20return%20depth_slider%2C%20eta_slider%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20generate_regression_data(n_samples%3D300%2C%20n_features%3D2%2C%20noise%3D20%2C%20test_size%3D0.3%2C%20random_state%3D42%2C%20**kwargs)%3A%0A%20%20%20%20from%20sklearn.datasets%20import%20make_regression%0A%20%20%20%20from%20sklearn.model_selection%20import%20train_test_split%0A%20%20%20%20X%2C%20y%20%3D%20make_regression(n_samples%3Dn_samples%2C%20n_features%3Dn_features%2C%20noise%3Dnoise%2C%20random_state%3Drandom_state%2C%20**kwargs)%0A%20%20%20%20return%20train_test_split(X%2C%20y%2C%20test_size%3Dtest_size%2C%20random_state%3Drandom_state)%0A%0A%0A%40app.cell%0Adef%20_(%0A%20%20%20%20XGBoostCallbackWrapper%2C%0A%20%20%20%20XGBoostTrainerWrapper%2C%0A%20%20%20%20depth_slider%2C%0A%20%20%20%20eta_slider%2C%0A%20%20%20%20xgb%2C%0A)%3A%0A%20%20%20%20%23%20Create%20wrapped%20evaluation%20monitor%20callback%20(doesn't%20require%20validation%20set)%0A%20%20%20%20eval_callback%20%3D%20XGBoostCallbackWrapper(%0A%20%20%20%20%20%20%20%20callback%3Dxgb.callback.EvaluationMonitor%2C%0A%20%20%20%20%20%20%20%20period%3D10%2C%0A%20%20%20%20%20%20%20%20show_stdv%3DFalse%0A%20%20%20%20)%0A%0A%20%20%20%20%23%20Use%20single%20callback%20(not%20list)%20to%20demonstrate%20nested%20parameter%20syntax%0A%20%20%20%20wrapper%20%3D%20XGBoostTrainerWrapper(%0A%20%20%20%20%20%20%20%20num_boost_round%3D50%2C%0A%20%20%20%20%20%20%20%20callbacks%3Deval_callback%2C%20%20%23%20Single%20callback%20for%20nested%20params%20demo%0A%20%20%20%20%20%20%20%20max_depth%3Ddepth_slider.value%2C%0A%20%20%20%20%20%20%20%20eta%3Deta_slider.value%2C%0A%20%20%20%20%20%20%20%20objective%3D%22reg%3Asquarederror%22%0A%20%20%20%20)%0A%0A%20%20%20%20X_train%2C%20X_test%2C%20y_train%2C%20y_test%20%3D%20generate_regression_data(n_features%3D5%2C%20noise%3D20)%0A%20%20%20%20wrapper.fit(X_train%2C%20y_train)%0A%0A%20%20%20%20y_pred_train%20%3D%20wrapper.predict(X_train)%0A%20%20%20%20y_pred_test%20%3D%20wrapper.predict(X_test)%0A%20%20%20%20X_plot%20%3D%20None%0A%20%20%20%20y_pred_plot%20%3D%20None%0A%20%20%20%20return%20wrapper%2C%20y_pred_test%2C%20y_pred_train%2C%20y_test%2C%20y_train%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(np)%3A%0A%20%20%20%20def%20calculate_r2_score(y_true%2C%20y_pred)%3A%0A%20%20%20%20%20%20%20%20return%201%20-%20np.mean((y_true%20-%20y_pred)%20**%202)%20%2F%20np.var(y_true)%0A%0A%20%20%20%20return%20(calculate_r2_score%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(calculate_r2_score)%3A%0A%20%20%20%20def%20calculate_train_test_scores(y_train%2C%20y_pred_train%2C%20y_test%2C%20y_pred_test)%3A%0A%20%20%20%20%20%20%20%20return%20(calculate_r2_score(y_train%2C%20y_pred_train)%2C%20calculate_r2_score(y_test%2C%20y_pred_test))%0A%0A%20%20%20%20return%20(calculate_train_test_scores%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(%0A%20%20%20%20calculate_train_test_scores%2C%0A%20%20%20%20depth_slider%2C%0A%20%20%20%20eta_slider%2C%0A%20%20%20%20mo%2C%0A%20%20%20%20y_pred_test%2C%0A%20%20%20%20y_pred_train%2C%0A%20%20%20%20y_test%2C%0A%20%20%20%20y_train%2C%0A)%3A%0A%20%20%20%20train_r2%2C%20test_r2%20%3D%20calculate_train_test_scores(%0A%20%20%20%20%20%20%20%20y_train%2C%20y_pred_train%2C%20y_test%2C%20y_pred_test%0A%20%20%20%20)%0A%0A%20%20%20%20mo.md(%0A%20%20%20%20%20%20%20%20f%22%22%22%0A%20%20%20%20%20%20%20%20%23%23%203.%20Results%0A%0A%20%20%20%20%20%20%20%20**Train%20R%C2%B2%3A**%20%7Btrain_r2%3A.3f%7D%0A%0A%20%20%20%20%20%20%20%20**Test%20R%C2%B2%3A**%20%7Btest_r2%3A.3f%7D%0A%0A%20%20%20%20%20%20%20%20Max%20Depth%3A%20%7Bdepth_slider.value%7D%2C%20Learning%20Rate%3A%20%7Beta_slider.value%3A.2f%7D%0A%20%20%20%20%20%20%20%20%22%22%22%0A%20%20%20%20)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%204.%20HTML%20Representation%0A%0A%20%20%20%20Wrapped%20XGBoost%20models%20display%20correctly.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%23%20XGBoost%20Wrapper%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(wrapper)%3A%0A%20%20%20%20wrapper%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%205.%20Control%20Callback%20Parameters%0A%0A%20%20%20%20Use%20the%20%60__%60%20syntax%20to%20modify%20callback%20settings%20without%20recreating%20the%20wrapper.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(wrapper)%3A%0A%20%20%20%20%23%20Demonstrate%20nested%20parameter%20access%0A%20%20%20%20params%20%3D%20wrapper.get_params(deep%3DTrue)%0A%0A%20%20%20%20%23%20Show%20callback-related%20parameters%0A%20%20%20%20callback_params%20%3D%20%7Bk%3A%20str(v)%20for%20k%2C%20v%20in%20params.items()%20if%20'callbacks'%20in%20k%7D%0A%20%20%20%20callback_params%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%23%20Modify%20Nested%20Parameters%0A%0A%20%20%20%20Change%20the%20callback%20period%20with%20nested%20syntax%3A%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(wrapper)%3A%0A%20%20%20%20%23%20Use%20nested%20parameter%20syntax%20to%20change%20callback%20settings%0A%20%20%20%20wrapper_modified%20%3D%20wrapper.set_params(callbacks__period%3D5)%0A%0A%20%20%20%20%23%20Verify%20the%20change%20-%20automatic%20parameter%20handling%20works!%0A%20%20%20%20modified_period%20%3D%20wrapper_modified.get_params(deep%3DTrue)%5B'callbacks__period'%5D%0A%20%20%20%20modified_period%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20**More%20examples%3A**%20%5Bserialization.py%5D(serialization.py)%20%7C%20%5Byaml_config.py%5D(yaml_config.py)%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
1da73321d80978f453572d28fef3e56d