2.13 Cross Validation in Machine Learning
- Cross Validation is a statistical method used to evaluate the performance of a machine learning model.
- It helps us understand how well the model will perform on unseen or new data.
- When a model is trained using a dataset, it may perform very well on that dataset but fail when predicting new data. Cross validation helps prevent this problem.
Cross validation helps to:
- Evaluate model performance more accurately
- Prevent overfitting
- Select the best machine learning model
- Improve model reliability and generalization
Example:
A model shows 95% accuracy on training data, but when tested on new data it gives only 65% accuracy. This means the model memorized the training data instead of learning general patterns. Cross validation helps detect such problems.
Overfitting
Overfitting occurs when the model learns the training data too well, including noise and unnecessary details.
As a result, the model performs:
- Very well on training data
- Poorly on new or unseen data
Example:
A student memorizes answers instead of understanding concepts.
The student performs well in practice but fails when the questions change.
Underfitting
Underfitting occurs when the model fails to learn patterns from the training data.
As a result, the model performs poorly on both:
- Training data
- Test data
Example:
A student studies only a few topics and cannot answer most questions.
The basic process of cross validation:
- Divide the dataset into multiple subsets (folds).
- Train the model on some subsets.
- Test the model on the remaining subset.
- Repeat the process multiple times.
- Calculate the average performance score.
This gives a more reliable estimate of model performance.
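The basic process above can be sketched in plain Python. The "model" here is a deliberately simple stand-in (it just predicts the mean of its training labels); the fold splitting, held-out testing, and score averaging are the parts cross validation actually specifies.

```python
def make_folds(data, k):
    """Split data into k roughly equal folds (round-robin)."""
    folds = [[] for _ in range(k)]
    for i, item in enumerate(data):
        folds[i % k].append(item)
    return folds

def cross_validate(data, k, train_fn, score_fn):
    """Train on k-1 folds, test on the held-out fold, average the scores."""
    folds = make_folds(data, k)
    scores = []
    for i in range(k):
        test = folds[i]                                      # held-out fold
        train = [x for j, f in enumerate(folds) if j != i for x in f]
        model = train_fn(train)                              # train on the rest
        scores.append(score_fn(model, test))                 # evaluate
    return sum(scores) / k                                   # average score

# Toy run: the "model" is the training mean; the score is
# negative mean squared error (higher is better).
labels = [3, 5, 4, 6, 5, 7, 4, 6, 5, 5]
mean_model = lambda train: sum(train) / len(train)
neg_mse = lambda m, test: -sum((y - m) ** 2 for y in test) / len(test)
print(cross_validate(labels, 5, mean_model, neg_mse))
```

Any real model and metric can be plugged in through `train_fn` and `score_fn`; only the splitting and averaging logic stays the same.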
General Cross Validation Algorithm
- Divide the dataset into training and testing sets.
- Train the model on the training data.
- Validate the model on the test data.
- Repeat the process several times depending on the method used.
- Calculate the average accuracy or error.
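One concrete way to realize these steps is repeated random train/test splitting. The sketch below is illustrative only: the mean-predictor "model", the mean absolute error metric, and the 80/20 ratio are assumptions, not part of the general algorithm.

```python
import random

def average_error(data, rounds=5, train_ratio=0.8, seed=0):
    """Repeat a random train/test split several times and average the error."""
    rng = random.Random(seed)
    errors = []
    for _ in range(rounds):
        shuffled = data[:]
        rng.shuffle(shuffled)                     # 1. divide into train/test
        cut = int(len(shuffled) * train_ratio)
        train, test = shuffled[:cut], shuffled[cut:]
        model = sum(train) / len(train)           # 2. "train" (mean model)
        err = sum(abs(y - model) for y in test) / len(test)  # 3. validate
        errors.append(err)                        # 4. repeat
    return sum(errors) / rounds                   # 5. average the error

values = [4, 5, 6, 5, 7, 4, 6, 5, 8, 5]
print(round(average_error(values), 3))
```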
Types of Cross Validation
There are several cross validation techniques:
- Holdout Validation
- Leave-One-Out Cross Validation (LOOCV)
- Stratified Cross Validation
- K-Fold Cross Validation
1. Holdout Validation
Holdout validation is the simplest cross validation technique.
In this method, the dataset is divided into two parts:
- Training dataset
- Testing dataset
Common splits include 80/20 and 70/30 (training/testing).
Example
Suppose we have 1000 data samples.
- Training data = 800 samples
- Testing data = 200 samples
Steps:
- Train the model using the 800 samples.
- Test the model using the 200 samples.
- Evaluate performance.
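A minimal sketch of this split in plain Python; the shuffle seed and the 80/20 ratio are illustrative choices:

```python
import random

def holdout_split(data, train_ratio=0.8, seed=0):
    """Shuffle the data, then split it into training and testing sets."""
    rng = random.Random(seed)
    shuffled = data[:]
    rng.shuffle(shuffled)               # shuffle so the split is random
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]

samples = list(range(1000))             # 1000 data samples, as in the example
train, test = holdout_split(samples)
print(len(train), len(test))            # 800 training, 200 testing
```

Shuffling before splitting matters: if the data is ordered (for example, by class), taking the first 80% directly would give a biased training set.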
Advantages
- Simple
- Fast to implement
Disadvantages
- Model performance depends heavily on the dataset split
- Some important data may not be used during training
2. Leave-One-Out Cross Validation (LOOCV)
Leave-One-Out Cross Validation is a special case of cross validation where:
- One data point is used as the test set
- All remaining data points are used as the training set
This process is repeated for every data point in the dataset.
Example
Suppose a dataset contains 5 data points.
The model is trained 5 times: each time, 4 points are used for training and the remaining point is used for testing.
Final performance = average of all 5 results.
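LOOCV can be sketched as follows; the 5 data points and the mean-predictor "model" are made up for illustration:

```python
def leave_one_out(data):
    """Yield (train, test_point) pairs, one per data point."""
    for i in range(len(data)):
        train = data[:i] + data[i + 1:]   # all points except one
        yield train, data[i]              # the held-out test point

# 5 data points -> the model is trained 5 times
points = [2, 4, 6, 8, 10]
errors = []
for train, test in leave_one_out(points):
    prediction = sum(train) / len(train)  # stand-in "model": training mean
    errors.append(abs(test - prediction))
print(sum(errors) / len(errors))          # average of all 5 results
```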
Advantages
- Uses almost the entire dataset for training
- Low bias
Disadvantages
- Very computationally expensive
- High variance: a single result can swing widely if the test sample is an outlier
3. Stratified Cross Validation
Definition
Stratified Cross Validation is mainly used for classification problems, especially when the dataset is imbalanced.
It ensures that each fold contains the same class distribution as the original dataset.
Example
Suppose we have a dataset of 100 emails: 80 Spam and 20 Not Spam.
If we divide it into 5 folds using the stratified method, each fold will contain approximately:
- 16 Spam
- 4 Not Spam
This keeps the class balance consistent in each fold.
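A simple way to build such folds is to group the sample indices by class and deal each class out round-robin. This sketch assumes the 80 Spam / 20 Not Spam example above:

```python
from collections import defaultdict

def stratified_folds(labels, k):
    """Assign sample indices to k folds, preserving class proportions."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)       # group indices by class
    folds = [[] for _ in range(k)]
    for indices in by_class.values():
        for pos, idx in enumerate(indices):
            folds[pos % k].append(idx)    # deal each class round-robin
    return folds

labels = ["Spam"] * 80 + ["NotSpam"] * 20   # imbalanced dataset
folds = stratified_folds(labels, 5)
for fold in folds:
    spam = sum(labels[i] == "Spam" for i in fold)
    print(spam, len(fold) - spam)           # 16 Spam, 4 Not Spam per fold
```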
Advantages
- Works well with imbalanced datasets
- Provides better model evaluation
4. K-Fold Cross Validation
K-Fold Cross Validation is one of the most widely used cross validation methods.
In this technique:
- The dataset is divided into K equal parts (folds).
- One fold is used for testing.
- The remaining K-1 folds are used for training.
- This process repeats K times.
Example (5-Fold Cross Validation)
The dataset is divided into 5 folds. Each fold is used once for testing while the remaining 4 folds are used for training.
Final performance = average of all 5 results.
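The 5-fold scheme can be sketched as index bookkeeping, assuming the dataset size divides evenly by K:

```python
def kfold_indices(n, k):
    """Return (train_indices, test_indices) pairs for k-fold CV."""
    fold_size = n // k
    pairs = []
    for i in range(k):
        # the i-th contiguous block is the test fold this round
        test = list(range(i * fold_size, (i + 1) * fold_size))
        train = [j for j in range(n) if j not in test]
        pairs.append((train, test))
    return pairs

for train, test in kfold_indices(10, 5):
    print("train:", train, "test:", test)
```

Every index appears in exactly one test fold, so each sample is used for testing exactly once and for training K-1 times.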
Advantages
- Uses the entire dataset efficiently
- Provides more reliable results
- Reduces overfitting risk
Example of Cross Validation in Real Life
Suppose we are building a model to predict student exam results.
Dataset contains:
- Study hours
- Attendance
- Assignment scores
Using K-Fold Cross Validation, we train the model multiple times and check if predictions remain consistent.
If the model performs well in all folds, it means the model generalizes well to new data.
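As a hedged sketch of this scenario: the student records, the pass/fail labels, and the threshold "model" below are all made up for illustration. The point is the per-fold consistency check, not the model itself.

```python
# Hypothetical records: (study_hours, attendance %, assignment score, passed)
students = [
    (2, 60, 50, 0), (8, 90, 85, 1), (5, 75, 70, 1), (1, 50, 40, 0),
    (7, 85, 80, 1), (3, 65, 55, 0), (9, 95, 90, 1), (4, 70, 60, 0),
    (6, 80, 75, 1), (2, 55, 45, 0),
]

def train(rows):
    """Stand-in model: a study-hours threshold (midpoint of class means)."""
    passed = [r[0] for r in rows if r[3] == 1]
    failed = [r[0] for r in rows if r[3] == 0]
    return (sum(passed) / len(passed) + sum(failed) / len(failed)) / 2

def accuracy(threshold, rows):
    """Fraction of rows where 'hours >= threshold' matches the true label."""
    return sum((r[0] >= threshold) == (r[3] == 1) for r in rows) / len(rows)

# 5-fold cross validation: is accuracy consistent across all folds?
k = 5
scores = []
for i in range(k):
    test_rows = students[i::k]                 # every k-th record as test fold
    train_rows = [r for j, r in enumerate(students) if j % k != i]
    scores.append(accuracy(train(train_rows), test_rows))
print(scores, sum(scores) / k)
```

If the per-fold scores were high but wildly different from each other, that would suggest the model's performance depends on which students it happened to see, i.e. it does not generalize.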