%23%20%2F%2F%2F%20script%0A%23%20requires-python%20%3D%20%22%3E%3D3.11%22%0A%23%20dependencies%20%3D%20%5B%0A%23%20%20%20%20%20%22numpy%22%2C%0A%23%20%20%20%20%20%22plotly%22%2C%0A%23%20%20%20%20%20%22scikit-learn%22%2C%0A%23%20%20%20%20%20%22sklearn-wrap%22%2C%0A%23%20%5D%0A%23%20%2F%2F%2F%0A%22%22%22%0A%23%20How%20to%20Use%20GridSearchCV%20with%20Wrappers%0A%0AThis%20notebook%20shows%20how%20to%20run%20%60GridSearchCV%60%20on%20a%20wrapped%20estimator%0Ato%20find%20optimal%20hyperparameters%20automatically.%0A%22%22%22%0A%0Aimport%20marimo%0A%0A__generated_with%20%3D%20%220.23.2%22%0Aapp%20%3D%20marimo.App(width%3D%22medium%22)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20marimo%20as%20mo%0A%0A%20%20%20%20return%20(mo%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_()%3A%0A%20%20%20%20import%20numpy%20as%20np%0A%20%20%20%20from%20sklearn.model_selection%20import%20GridSearchCV%0A%0A%20%20%20%20from%20sklearn_wrap%20import%20BaseClassWrapper%0A%0A%20%20%20%20return%20BaseClassWrapper%2C%20GridSearchCV%2C%20np%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20In%20this%20tutorial%20we%20wrap%20a%20custom%20k-Nearest%20Neighbors%20classifier%20and%20hand%20it%20to%0A%20%20%20%20%60GridSearchCV%60.%20We%20will%20define%20a%20parameter%20grid%2C%20run%20cross-validated%20search%2C%0A%20%20%20%20and%20inspect%20both%20the%20results%20and%20the%20HTML%20representation%20of%20the%20meta-estimator.%0A%0A%20%20%20%20**Prerequisites**%20-%20Familiarity%20with%20%5Bfirst_wrapper.py%5D(first_wrapper.py)%20and%20%5Bparameter_interface.py%5D(parameter_interface.py).%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%201.%20A%20Custom%20KNN%20Classifier%0A%0A%20%20%20%20We%20start%20with%20a%20simple%20classifier%20that%20uses%20its%20own%20method%20names.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(np)%3A%0A%20%20%20%20class%20KNNClassifier%3A%0A%20%20%20%20%20%20%20%20%22%22%22KNN%20classifier%20without%20sklearn%20conventions.%22%22%22%0A%0A%20%20%20%20%20%20%20%20def%20__init__(self%2C%20n_neighbors%3D3%2C%20distance_metric%3D%22euclidean%22)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self._k_value%20%3D%20n_neighbors%0A%20%20%20%20%20%20%20%20%20%20%20%20self._metric_type%20%3D%20distance_metric%0A%0A%20%20%20%20%20%20%20%20def%20train_classifier(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Store%20training%20data%20(not%20'fit').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20self._training_features%20%3D%20X%0A%20%20%20%20%20%20%20%20%20%20%20%20self._training_labels%20%3D%20y%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20classify(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%22%22%22Classify%20samples%20(not%20'predict').%22%22%22%0A%20%20%20%20%20%20%20%20%20%20%20%20predictions%20%3D%20%5B%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20for%20x%20in%20X%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20distances%20%3D%20self._calculate_distances(x)%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nearest_indices%20%3D%20np.argsort(distances)%5B%3A%20self._k_value%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20nearest_labels%20%3D%20self._training_labels%5Bnearest_indices%5D%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20predictions.append(np.bincount(nearest_labels).argmax())%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20np.array(predictions)%0A%0A%20%20%20%20%20%20%20%20def%20_calculate_distances(self%2C%20x)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20if%20self._metric_type%20%3D%3D%20%22euclidean%22%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20np.sqrt(((self._training_features%20-%20x)%20**%202).sum(axis%3D1))%0A%20%20%20%20%20%20%20%20%20%20%20%20elif%20self._metric_type%20%3D%3D%20%22manhattan%22%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20return%20np.abs(self._training_features%20-%20x).sum(axis%3D1)%0A%20%20%20%20%20%20%20%20%20%20%20%20else%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20raise%20ValueError(f%22Unknown%20metric%3A%20%7Bself._metric_type%7D%22)%0A%0A%20%20%20%20return%20(KNNClassifier%2C)%0A%0A%0A%40app.cell%0Adef%20_(BaseClassWrapper)%3A%0A%20%20%20%20class%20KNNWrapper(BaseClassWrapper)%3A%0A%20%20%20%20%20%20%20%20_estimator_name%20%3D%20%22knn%22%0A%20%20%20%20%20%20%20%20_estimator_base_class%20%3D%20object%0A%0A%20%20%20%20%20%20%20%20def%20fit(self%2C%20X%2C%20y)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instantiate()%0A%20%20%20%20%20%20%20%20%20%20%20%20self.instance_.train_classifier(X%2C%20y)%0A%0A%20%20%20%20%20%20%20%20%20%20%20%20%23%20Mark%20estimator%20as%20fitted%20for%20sklearn%20compatibility%0A%20%20%20%20%20%20%20%20%20%20%20%20self.fitted_%20%3D%20True%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self%0A%0A%20%20%20%20%20%20%20%20def%20predict(self%2C%20X)%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20return%20self.instance_.classify(X)%0A%0A%20%20%20%20return%20(KNNWrapper%2C)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%202.%20Running%20GridSearchCV%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20generate_classification_data(n_samples%3D300%2C%20n_features%3D2%2C%20n_classes%3D2%2C%20test_size%3D0.3%2C%20random_state%3D42%2C%20**kwargs)%3A%0A%20%20%20%20from%20sklearn.datasets%20import%20make_classification%0A%20%20%20%20from%20sklearn.model_selection%20import%20train_test_split%0A%20%20%20%20%23%20Ensure%20n_informative%20is%20sufficient%20for%20n_classes%0A%20%20%20%20n_informative%20%3D%20max(n_features%2C%20n_classes)%0A%20%20%20%20X%2C%20y%20%3D%20make_classification(n_samples%3Dn_samples%2C%20n_features%3Dn_informative%2C%20n_classes%3Dn_classes%2C%20n_informative%3Dn_informative%2C%20n_redundant%3D0%2C%20random_state%3Drandom_state%2C%20**kwargs)%0A%20%20%20%20stratify%20%3D%20y%20if%20n_classes%20%3E%201%20else%20None%0A%20%20%20%20return%20train_test_split(X%2C%20y%2C%20test_size%3Dtest_size%2C%20random_state%3Drandom_state%2C%20stratify%3Dstratify)%0A%0A%0A%40app.cell%0Adef%20_(GridSearchCV%2C%20KNNClassifier%2C%20KNNWrapper)%3A%0A%20%20%20%20%23%20Generate%20data%0A%20%20%20%20X_train%2C%20X_test%2C%20y_train%2C%20y_test%20%3D%20generate_classification_data(n_samples%3D200%2C%20n_classes%3D3)%0A%0A%20%20%20%20%23%20Define%20parameter%20grid%0A%20%20%20%20param_grid%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22n_neighbors%22%3A%20%5B3%2C%205%2C%207%2C%209%5D%2C%0A%20%20%20%20%20%20%20%20%22distance_metric%22%3A%20%5B%22euclidean%22%2C%20%22manhattan%22%5D%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20%23%20Create%20wrapper%20and%20run%20grid%20search%0A%20%20%20%20wrapper%20%3D%20KNNWrapper(knn%3DKNNClassifier)%0A%20%20%20%20grid_search%20%3D%20GridSearchCV(wrapper%2C%20param_grid%2C%20cv%3D3%2C%20scoring%3D%22accuracy%22%2C%20return_train_score%3DTrue)%0A%20%20%20%20grid_search.fit(X_train%2C%20y_train)%0A%0A%20%20%20%20%23%20Extract%20results%0A%20%20%20%20best_params%20%3D%20grid_search.best_params_%0A%20%20%20%20best_score%20%3D%20grid_search.best_score_%0A%20%20%20%20test_score%20%3D%20grid_search.score(X_test%2C%20y_test)%0A%20%20%20%20cv_results%20%3D%20grid_search.cv_results_%0A%20%20%20%20best_estimator%20%3D%20grid_search.best_estimator_%0A%20%20%20%20return%20(%0A%20%20%20%20%20%20%20%20best_estimator%2C%0A%20%20%20%20%20%20%20%20best_params%2C%0A%20%20%20%20%20%20%20%20best_score%2C%0A%20%20%20%20%20%20%20%20cv_results%2C%0A%20%20%20%20%20%20%20%20grid_search%2C%0A%20%20%20%20%20%20%20%20test_score%2C%0A%20%20%20%20)%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(best_params%2C%20best_score%2C%20mo%2C%20test_score)%3A%0A%20%20%20%20mo.md(f%22%22%22%0A%20%20%20%20%23%23%203.%20Results%0A%0A%20%20%20%20**Best%20Parameters%3A**%20%60%7Bbest_params%7D%60%0A%0A%20%20%20%20**Best%20CV%20Score%3A**%20%7Bbest_score%3A.3f%7D%0A%0A%20%20%20%20**Test%20Score%3A**%20%7Btest_score%3A.3f%7D%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.function(hide_code%3DTrue)%0Adef%20create_comparison_bars(categories%2C%20values_dict%2C%20title%2C%20yaxis_title%3D%22Score%22%2C%20colors%3DNone%2C%20error_bars%3DNone%2C%20**layout_kwargs)%3A%0A%20%20%20%20import%20plotly.graph_objects%20as%20go%0A%20%20%20%20fig%20%3D%20go.Figure()%0A%20%20%20%20default_colors%20%3D%20%5B%22lightblue%22%2C%20%22lightcoral%22%2C%20%22lightgreen%22%2C%20%22lightyellow%22%5D%0A%20%20%20%20for%20i%2C%20(name%2C%20values)%20in%20enumerate(values_dict.items())%3A%0A%20%20%20%20%20%20%20%20color%20%3D%20colors.get(name)%20if%20colors%20else%20default_colors%5Bi%20%25%20len(default_colors)%5D%0A%20%20%20%20%20%20%20%20trace_kwargs%20%3D%20%7B%22name%22%3A%20name%2C%20%22x%22%3A%20categories%2C%20%22y%22%3A%20values%2C%20%22marker%22%3A%20dict(color%3Dcolor)%2C%20%22text%22%3A%20%5Bf%22%7Bv%3A.4f%7D%22%20for%20v%20in%20values%5D%2C%20%22textposition%22%3A%20%22outside%22%7D%0A%20%20%20%20%20%20%20%20if%20error_bars%20and%20name%20in%20error_bars%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20trace_kwargs%5B%22error_y%22%5D%20%3D%20dict(type%3D%22data%22%2C%20array%3Derror_bars%5Bname%5D%2C%20visible%3DTrue)%0A%20%20%20%20%20%20%20%20fig.add_trace(go.Bar(**trace_kwargs))%0A%20%20%20%20fig.update_layout(title%3Dtitle%2C%20yaxis_title%3Dyaxis_title%2C%20barmode%3D%22group%22%2C%20height%3D400%2C%20**layout_kwargs)%0A%20%20%20%20return%20fig%0A%0A%0A%40app.cell%0Adef%20_(cv_results%2C%20np)%3A%0A%20%20%20%20%23%20Extract%20top%205%20configurations%0A%20%20%20%20sorted_indices%20%3D%20np.argsort(cv_results%5B%22rank_test_score%22%5D)%5B%3A5%5D%0A%0A%20%20%20%20categories%20%3D%20%5B%0A%20%20%20%20%20%20%20%20f%22k%3D%7Bcv_results%5B'param_n_neighbors'%5D%5Bi%5D%7D%2C%20%7Bcv_results%5B'param_distance_metric'%5D%5Bi%5D%5B%3A3%5D%7D%22%0A%20%20%20%20%20%20%20%20for%20i%20in%20sorted_indices%0A%20%20%20%20%5D%0A%0A%20%20%20%20values_dict%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22Train%22%3A%20%5Bcv_results%5B%22mean_train_score%22%5D%5Bi%5D%20for%20i%20in%20sorted_indices%5D%2C%0A%20%20%20%20%20%20%20%20%22CV%22%3A%20%5Bcv_results%5B%22mean_test_score%22%5D%5Bi%5D%20for%20i%20in%20sorted_indices%5D%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20error_bars%20%3D%20%7B%0A%20%20%20%20%20%20%20%20%22CV%22%3A%20%5Bcv_results%5B%22std_test_score%22%5D%5Bi%5D%20for%20i%20in%20sorted_indices%5D%2C%0A%20%20%20%20%7D%0A%0A%20%20%20%20fig%20%3D%20create_comparison_bars(%0A%20%20%20%20%20%20%20%20categories%2C%0A%20%20%20%20%20%20%20%20values_dict%2C%0A%20%20%20%20%20%20%20%20%22Top%205%20Configurations%22%2C%0A%20%20%20%20%20%20%20%20yaxis_title%3D%22Accuracy%22%2C%0A%20%20%20%20%20%20%20%20error_bars%3Derror_bars%2C%0A%20%20%20%20)%0A%20%20%20%20fig%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%204.%20HTML%20Representation%0A%0A%20%20%20%20Notice%20that%20%60GridSearchCV%60%20and%20the%20best%20estimator%20both%20render%20correctly%20-%20the%0A%20%20%20%20wrapper%20is%20fully%20transparent%20to%20sklearn's%20display%20machinery.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%23%20GridSearchCV%20Object%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(grid_search)%3A%0A%20%20%20%20grid_search%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%23%20Best%20Estimator%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0A%40app.cell%0Adef%20_(best_estimator)%3A%0A%20%20%20%20best_estimator%0A%20%20%20%20return%0A%0A%0A%40app.cell(hide_code%3DTrue)%0Adef%20_(mo)%3A%0A%20%20%20%20mo.md(%22%22%22%0A%20%20%20%20%23%23%20What%20We%20Built%0A%0A%20%20%20%20We%20ran%20a%20full%20%60GridSearchCV%60%20on%20a%20wrapped%20estimator%20with%20zero%20extra%20glue%20code.%0A%20%20%20%20The%20parameter%20interface%20we%20explored%20in%20%5Bparameter_interface.py%5D(parameter_interface.py)%0A%20%20%20%20is%20what%20makes%20this%20possible%20-%20%60GridSearchCV%60%20uses%20%60get_params()%60%20and%20%60set_params()%60%0A%20%20%20%20behind%20the%20scenes.%0A%0A%20%20%20%20Next%3A%20%5Bnested_wrappers.py%5D(nested_wrappers.py)%20extends%20this%20to%20estimators%20that%0A%20%20%20%20contain%20other%20estimators.%0A%20%20%20%20%22%22%22)%0A%20%20%20%20return%0A%0A%0Aif%20__name__%20%3D%3D%20%22__main__%22%3A%0A%20%20%20%20app.run()%0A
3b7ae34fc87a300cb5052c40894cb0b5