Machine Learning Academy · Lesson

Training a Multinomial Naive Bayes Classifier

Learners will fit MultinomialNB on a spam/ham dataset, tune the alpha smoothing parameter, and evaluate with classification_report.

Why Multinomial Naive Bayes for Text?

Scikit-learn provides several Naive Bayes variants (it also ships ComplementNB and CategoricalNB); three are worth comparing here. MultinomialNB is designed for discrete count data, which is exactly what CountVectorizer produces: it models the probability of each word count given the class. BernoulliNB works with binary presence/absence features and can perform better on short documents. GaussianNB assumes each feature is continuous and normally distributed within each class, which suits numerical measurements rather than word counts. For text classification, MultinomialNB is usually the best starting point: word counts are non-negative integers, which matches the multinomial event model it assumes, even though the "naive" assumption that words occur independently given the class rarely holds exactly.
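
A quick way to see the practical differences between the three variants is to fit each one on the same count matrix. The sketch below reuses the toy spam/ham corpus from this lesson (illustrative only, not a benchmark). Two API details matter: BernoulliNB binarizes its input at the default threshold binarize=0.0, so any count above zero becomes 1, and GaussianNB does not accept sparse matrices, so it receives a dense copy via toarray().

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import BernoulliNB, GaussianNB, MultinomialNB

corpus = ['buy cheap meds now', 'hello friend meeting tomorrow',
          'click here for discount', 'project update next week']
labels = [1, 0, 1, 0]  # 1=spam, 0=ham
X = CountVectorizer().fit_transform(corpus)  # sparse matrix of word counts

models = {
    'MultinomialNB': MultinomialNB(),  # models the raw word counts
    'BernoulliNB': BernoulliNB(),      # counts > 0 are mapped to 1 (word present)
    'GaussianNB': GaussianNB(),        # treats each count as a continuous Gaussian feature
}

predictions = {}
for name, model in models.items():
    # GaussianNB rejects sparse input, so give it a dense copy of X
    X_in = X.toarray() if isinstance(model, GaussianNB) else X
    model.fit(X_in, labels)
    predictions[name] = model.predict(X_in)
    print(name, predictions[name])
```

On a corpus this small the three models may well agree; the differences (and the memory cost of densifying a large vocabulary for GaussianNB) show up on real text collections.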

from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import CountVectorizer

corpus = ['buy cheap meds now', 'hello friend meeting tomorrow',
          'click here for discount', 'project update next week']
labels = [1, 0, 1, 0]  # 1=spam, 0=ham

vec = CountVectorizer()
X = vec.fit_transform(corpus)

nb = MultinomialNB()
nb.fit(X, labels)
print('Classes:', nb.classes_)
print('Feature log probs shape:', nb.feature_log_prob_.shape)
# feature_log_prob_[class][feature] = log P(feature | class)
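
To make "models the probability of each word count given the class" concrete, the sketch below recomputes feature_log_prob_ by hand from the fitted feature_count_ attribute, assuming the default alpha=1.0 (Laplace smoothing). It then scores a new message the way MultinomialNB does: the class log prior plus the sum of count times log P(word | class). The phrase 'cheap discount now' is made up purely for illustration.

```python
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB

corpus = ['buy cheap meds now', 'hello friend meeting tomorrow',
          'click here for discount', 'project update next week']
labels = [1, 0, 1, 0]  # 1=spam, 0=ham

vec = CountVectorizer()
X = vec.fit_transform(corpus)
nb = MultinomialNB(alpha=1.0).fit(X, labels)

# P(word | class) = (count + alpha) / (total words in class + alpha * vocabulary size)
smoothed = nb.feature_count_ + 1.0
manual_log_prob = np.log(smoothed / smoothed.sum(axis=1, keepdims=True))
print(np.allclose(manual_log_prob, nb.feature_log_prob_))  # True

# Score a new message: log prior + sum over words of count * log P(word | class)
x_new = vec.transform(['cheap discount now'])
joint_log_likelihood = np.asarray(x_new @ nb.feature_log_prob_.T) + nb.class_log_prior_
manual_pred = nb.classes_[np.argmax(joint_log_likelihood, axis=1)]
print(manual_pred, nb.predict(x_new))  # [1] [1]
```

All three words in the new message appear only in spam training documents, so each contributes log(2/24) under spam versus log(1/24) under ham, and with equal priors the spam class wins.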

Loading and Splitting the 20 Newsgroups Dataset

The 20 Newsgroups dataset is a classic text classification benchmark. It contains approximately 18,000 newsgroup posts across 20 topics, from 'sci.space' to 'talk.politics.guns'. Scikit-learn provides it via fetch_20newsgroups(), whose remove parameter strips headers, footers, and quoted replies; these carry metadata (such as newsgroup names and signatures) that would make the task artificially easy. The dataset has no spam/ham labels, so instead we frame a binary task between two newsgroups and load the dataset's predefined train and test subsets.

from sklearn.datasets import fetch_20newsgroups

# Load two categories for binary classification; subset='train'/'test' selects the predefined split
cats = ['sci.space', 'rec.sport.hockey']
train = fetch_20newsgroups(subset='train', categories=cats,
                           remove=('headers', 'footers', 'quotes'))
test  = fetch_20newsgroups(subset='test',  categories=cats,
                           remove=('headers', 'footers', 'quotes'))

print('Train samples:', len(train.data))
print('Test  samples:', len(test.data))
print('Categories:', train.target_names)
print('Sample (first 200 chars):', train.data[0][:200])
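
With the data loaded, a natural next step toward this lesson's goals (tuning alpha and evaluating with classification_report) is a pipeline along the following lines. It is a sketch under assumptions: the helper name fit_and_report, the alpha grid, and the cv values are illustrative choices, not values prescribed by the lesson. It is exercised on a tiny made-up corpus so it runs offline; with the newsgroups data you would pass train.data, train.target, test.data, test.target, and train.target_names.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline


def fit_and_report(train_texts, train_y, test_texts, test_y, target_names,
                   alphas=(0.01, 0.1, 1.0), cv=5):
    """Hypothetical helper: tune alpha by cross-validation, then report on the test set."""
    # make_pipeline names each step after its lowercased class name,
    # so the smoothing parameter is addressed as 'multinomialnb__alpha'
    pipe = make_pipeline(CountVectorizer(), MultinomialNB())
    search = GridSearchCV(pipe, {'multinomialnb__alpha': list(alphas)}, cv=cv)
    search.fit(train_texts, train_y)  # the vectorizer is refit inside every CV fold
    preds = search.predict(test_texts)  # best pipeline, refit on all training data
    report = classification_report(test_y, preds, target_names=target_names)
    return search.best_params_['multinomialnb__alpha'], report


# Tiny made-up corpus, only to exercise the helper without downloading data
train_texts = ['buy cheap meds now', 'cheap discount click now',
               'click here for discount', 'buy now cheap offer',
               'hello friend meeting tomorrow', 'project update next week',
               'meeting notes for project', 'see you tomorrow friend']
train_y = [1, 1, 1, 1, 0, 0, 0, 0]  # 1=spam, 0=ham
test_texts = ['cheap meds click now', 'project meeting tomorrow']
test_y = [1, 0]

# cv=2 because each class has only 4 training examples here
best_alpha, report = fit_and_report(train_texts, train_y, test_texts, test_y,
                                    target_names=['ham', 'spam'], cv=2)
print('Best alpha:', best_alpha)
print(report)
```

Putting CountVectorizer inside the pipeline keeps the vocabulary from leaking across cross-validation folds, which matters once the grid search runs on the real training posts.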

All lessons in this course

  1. Bayes' Theorem in Plain Language
  2. Bag of Words: CountVectorizer and TfidfVectorizer
  3. Training a Multinomial Naive Bayes Classifier
  4. Laplace Smoothing and Zero-Probability Problem