Hyperparameter Tuning#
🔹 1. Gaussian Naive Bayes (GaussianNB)#
Used when features are continuous and follow (roughly) a normal distribution.
Hyperparameters:
var_smoothing:Adds a small constant to the variance to prevent division by zero (numerical stability).
Default: \(10^{-9}\).
Tuning: Test values like \(10^{-9}, 10^{-8}, 10^{-7}, \dots\).
Effect: Too small → unstable estimates; too large → overly smoothed probabilities.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.naive_bayes import GaussianNB

# Load the iris dataset: 4 continuous features, suitable for GaussianNB.
X, y = load_iris(return_X_y=True)

# Hold out 30% of the data for a final test-set evaluation.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# var_smoothing is searched on a log scale (1e-9 ... 1e-2), since its
# useful values span several orders of magnitude.
param_grid = {'var_smoothing': np.logspace(-9, -2, 8)}

# 5-fold cross-validated grid search maximizing accuracy.
grid = GridSearchCV(GaussianNB(), param_grid, cv=5, scoring='accuracy')
grid.fit(X_train, y_train)

print("Best Parameters:", grid.best_params_)
print("Best CV Accuracy:", grid.best_score_)
print("Test Accuracy:", grid.score(X_test, y_test))
Best Parameters: {'var_smoothing': np.float64(1e-09)}
Best CV Accuracy: 0.9333333333333333
Test Accuracy: 0.9777777777777777
🔹 2. Multinomial Naive Bayes (MultinomialNB)#
Used for discrete features (e.g., word counts in text classification).
Hyperparameters:
alpha (Laplace/Lidstone smoothing): Controls how much smoothing is applied to avoid zero probabilities.
Default: 1.0. Tuning: usually tested between 0.01 and 10.
Effect:
Small alpha (close to 0): the model relies heavily on observed frequencies (risk of overfitting).
Large alpha: more smoothing, may underfit.
fit_prior: Whether to learn class priors from the data.
Can be set to True (default) or False (if you want uniform priors).
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

# Load a two-category subset of the 20 newsgroups text dataset.
# Headers/footers/quotes are stripped so the model learns from body text only.
categories = ['alt.atheism', 'sci.space']
newsgroups = fetch_20newsgroups(subset='train', categories=categories,
                                remove=('headers', 'footers', 'quotes'))
X, y = newsgroups.data, newsgroups.target

# Pipeline: raw text -> word counts -> MultinomialNB.
pipeline = Pipeline([
    ('vect', CountVectorizer()),
    ('nb', MultinomialNB())
])

# Search the smoothing strength and whether class priors are learned
# from the data (False means uniform priors).
param_grid = {
    'nb__alpha': [0.01, 0.1, 1, 5, 10],
    'nb__fit_prior': [True, False]
}

# 5-fold cross-validated grid search maximizing accuracy.
grid = GridSearchCV(pipeline, param_grid, cv=5, scoring='accuracy')
grid.fit(X, y)
print("Best Parameters:", grid.best_params_)
print("Best CV Accuracy:", grid.best_score_)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
Cell In[2], line 8
6 # Load subset of text dataset
7 categories = ['alt.atheism', 'sci.space']
----> 8 newsgroups = fetch_20newsgroups(subset='train', categories=categories, remove=('headers','footers','quotes'))
9 X, y = newsgroups.data, newsgroups.target
11 # Pipeline: vectorizer + NB
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\_param_validation.py:218, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
212 try:
213 with config_context(
214 skip_parameter_validation=(
215 prefer_skip_nested_validation or global_skip_validation
216 )
217 ):
--> 218 return func(*args, **kwargs)
219 except InvalidParameterError as e:
220 # When the function is just a wrapper around an estimator, we allow
221 # the function to delegate validation to the estimator, but we replace
222 # the name of the estimator by the name of the function in the error
223 # message to avoid confusion.
224 msg = re.sub(
225 r"parameter of \w+ must be",
226 f"parameter of {func.__qualname__} must be",
227 str(e),
228 )
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\datasets\_twenty_newsgroups.py:322, in fetch_20newsgroups(data_home, subset, categories, shuffle, random_state, remove, download_if_missing, return_X_y, n_retries, delay)
320 if download_if_missing:
321 logger.info("Downloading 20news dataset. This may take a few minutes.")
--> 322 cache = _download_20newsgroups(
323 target_dir=twenty_home,
324 cache_path=cache_path,
325 n_retries=n_retries,
326 delay=delay,
327 )
328 else:
329 raise OSError("20Newsgroups dataset not found")
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\datasets\_twenty_newsgroups.py:95, in _download_20newsgroups(target_dir, cache_path, n_retries, delay)
90 os.remove(archive_path)
92 # Store a zipped pickle
93 cache = dict(
94 train=load_files(train_path, encoding="latin1"),
---> 95 test=load_files(test_path, encoding="latin1"),
96 )
97 compressed_content = codecs.encode(pickle.dumps(cache), "zlib_codec")
98 with open(cache_path, "wb") as f:
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\utils\_param_validation.py:191, in validate_params.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
189 global_skip_validation = get_config()["skip_parameter_validation"]
190 if global_skip_validation:
--> 191 return func(*args, **kwargs)
193 func_sig = signature(func)
195 # Map *args/**kwargs to the function signature
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\site-packages\sklearn\datasets\_base.py:309, in load_files(container_path, description, categories, load_content, shuffle, encoding, decode_error, random_state, allowed_extensions)
307 data = []
308 for filename in filenames:
--> 309 data.append(Path(filename).read_bytes())
310 if encoding is not None:
311 data = [d.decode(encoding, decode_error) for d in data]
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\pathlib.py:1020, in Path.read_bytes(self)
1016 def read_bytes(self):
1017 """
1018 Open the file in bytes mode, read it, and close the file.
1019 """
-> 1020 with self.open(mode='rb') as f:
1021 return f.read()
File c:\Users\sangouda\AppData\Local\Programs\Python\Python312\Lib\pathlib.py:1014, in Path.open(self, mode, buffering, encoding, errors, newline)
1012 if "b" not in mode:
1013 encoding = io.text_encoding(encoding)
-> 1014 return io.open(self, mode, buffering, encoding, errors, newline)
KeyboardInterrupt:
🔹 3. Bernoulli Naive Bayes (BernoulliNB)#
Used for binary/boolean features (e.g., whether a word appears in a document).
Hyperparameters:
alpha: same role as in MultinomialNB.
binarize: Threshold for turning feature values into binary (0/1).
Default: 0.0 (all non-zero values are set to 1). Tuning: try different thresholds based on the data distribution.
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import BernoulliNB
from sklearn.pipeline import Pipeline

# Load a two-category subset of the 20 newsgroups text dataset.
categories = ['rec.sport.baseball', 'sci.space']
newsgroups = fetch_20newsgroups(subset='train', categories=categories,
                                remove=('headers', 'footers', 'quotes'))
X, y = newsgroups.data, newsgroups.target

# Pipeline: raw word counts -> BernoulliNB.
# NOTE: do NOT set CountVectorizer(binary=True) here. BernoulliNB's own
# `binarize` parameter performs the thresholding; with pre-binarized 0/1
# counts every threshold in (0, 1) behaves like 0.0 and thresholds >= 1.0
# zero out all features, making the binarize grid meaningless.
pipeline = Pipeline([
    ('vect', CountVectorizer()),
    ('nb', BernoulliNB())
])

# Search smoothing, the count-binarization threshold, and prior learning.
# `binarize=None` is omitted: it tells BernoulliNB the input is already
# binary, which raw word counts are not.
param_grid = {
    'nb__alpha': [0.01, 0.1, 1, 5, 10],
    'nb__binarize': [0.0, 0.5, 1.0],
    'nb__fit_prior': [True, False]
}

# 5-fold cross-validated grid search maximizing accuracy.
grid = GridSearchCV(pipeline, param_grid, cv=5, scoring='accuracy')
grid.fit(X, y)
print("Best Parameters:", grid.best_params_)
print("Best CV Accuracy:", grid.best_score_)
🔹 4. Hyperparameter Tuning Process#
We usually apply:
GridSearchCV → exhaustively checks all parameter combinations.
RandomizedSearchCV → samples combinations randomly (faster).
Cross-validation (StratifiedKFold) → ensures fair evaluation, especially on imbalanced datasets.