Detecting Data Drift in Machine Learning

Hemant Rawat
8 min read · Feb 11, 2024

--

ABSTRACT:

In the world of machine learning, models are trained on historical data to make predictions on new, unseen data. However, the underlying assumption here is that the data distribution remains unchanged over time. In reality, data often changes, leading to a phenomenon known as “data drift.”

Data drift can significantly impact the performance and reliability of machine learning models. To ensure the continued accuracy and effectiveness of models, it is essential to detect and adapt to data drift.

Let’s explore the concept of distribution shift, its significance, and the various types of shifts that can affect ML systems differently. We’ll also explore methods for detecting covariate shift and concept drift.

INTRODUCTION:

Machine learning models are typically trained offline by gathering datasets for training purposes, and then deployed for use. However, during deployment, the data seen by the model may differ from the training data due to real-world changes over time.

For instance, tissue images may differ across various hospitals, just as land use identification may vary between Europe and the Americas [1], [2].

In-distribution (same dataset) vs. out-of-distribution data discrepancies.

Understanding Data Drift:

Data drift refers to the phenomenon where the statistical properties of the training and production data deviate over time. These changes can occur due to various factors, such as changes in user behavior, data collection processes, or external factors influencing the data source. Data drift can result in a significant performance degradation of machine learning models as they become mismatched with the new data.

Types of Distribution Shifts:

In machine learning, we work with the joint distribution over input variables of features x and target variables or labels y, denoted as p(x,y).

When training a model to predict y from x, the model tries to learn the conditional distribution of y given x, written p(y|x).

Depending on how exactly the joint distribution shifts, this can have different effects on our goal of learning to predict y from x.

The joint distribution can shift in different ways, and each kind of shift has different consequences, so it is important to distinguish between them.

Covariate Shift:

Covariate shift is a specific type of data drift that occurs when the input features’ distribution changes, while the target variable distribution remains the same. In other words, the relationship between the input variables and the output remains constant, but the distribution of the input variables themselves shifts.

Consider a machine learning model that predicts housing prices based on features like location, size, and age. If the distribution of these features in the training data differs significantly from the distribution in the production data, the model’s performance may suffer due to covariate shift.

p(x,y) = p(y|x)p(x)

p(y|x) stays constant.

p(x) may change.

To define covariate shift, we factorize our joint distribution p(x,y) into the product of p(y∣x), which is the part our model aims to learn, multiplied by p(x), representing the marginal distribution over the features. Covariate shift occurs when p(y∣x) remains constant while p(x) is allowed to change. Even though the component our model focuses on remains steady, covariate shift can still have detrimental effects.
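To make this concrete, here is a minimal sketch in Python (my own illustration using NumPy and scikit-learn; the housing numbers are made up) that simulates covariate shift in the earlier housing example: the relationship p(y|x) between size and price is held fixed, while the distribution p(x) of sizes shifts between training and production, and the model’s error degrades anyway.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)

def price(size):
    # Fixed relationship p(y|x): price depends on size in the same way everywhere.
    return 50 + 3 * size + rng.normal(0, 5, size=size.shape)

# Training data: p(x) centred on smaller houses.
x_train = rng.normal(60, 10, size=(2000, 1))
y_train = price(x_train[:, 0])

# Production data: p(x) has shifted towards larger houses; p(y|x) is unchanged.
x_prod = rng.normal(120, 10, size=(2000, 1))
y_prod = price(x_prod[:, 0])

model = DecisionTreeRegressor(max_depth=5).fit(x_train, y_train)

print("train MAE:", np.abs(model.predict(x_train) - y_train).mean())
print("prod  MAE:", np.abs(model.predict(x_prod) - y_prod).mean())
# The production error is much larger: the model never saw inputs in this
# region of feature space, even though p(y|x) itself never changed.
```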

Detecting Covariate Shift:

Let’s consider a scenario where we have reference data generated by the distribution p(x) and query data generated by the distribution p’(x). We have samples X = {x1, x2, …, xn} and X’ = {x’1, …, x’n}, and our objective is to determine whether p = p’.

Detecting Covariate Shift presents challenges due to several factors:

· High-dimensional data

· Unstructured inputs, such as text or images

· Potentially limited number of data points

General Framework:

Given samples X = {x1,…,xn} and X’ = {x’1,…,x’n}

1. Feature extraction:

Feature extraction involves transforming raw inputs into informative features. To prevent bias, you can use a model trained on data from p, as long as X is held-out data so it does not skew the results. A simple and effective option is to rely on a generic pretrained feature extractor, such as ResNet for images or BERT for text.
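As an illustration of the pretrained-extractor route (a sketch assuming PyTorch and torchvision ≥ 0.13 are available; the choice of ResNet-18 and the preprocessing pipeline are my assumptions, not something the article prescribes), the classification head can be replaced so the network returns penultimate-layer features:

```python
import torch
import torch.nn as nn
import torchvision.models as models

# Load a pretrained ResNet and replace the classification head with identity,
# so the network outputs 512-dimensional penultimate-layer features.
weights = models.ResNet18_Weights.DEFAULT
extractor = models.resnet18(weights=weights)
extractor.fc = nn.Identity()
extractor.eval()

preprocess = weights.transforms()  # the resize/normalize pipeline these weights expect

@torch.no_grad()
def extract_features(images):
    # images: an iterable of PIL images (reference or query data)
    batch = torch.stack([preprocess(img) for img in images])
    return extractor(batch)  # shape (n, 512), used as inputs to the two-sample test
```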

2. Perform a statistical hypothesis test

Use a two-sample test to determine whether the two samples X and X’ were generated by the same distribution.

Given samples X = {x1,…,xn} and X’ = {x’1,…,x’n}

· Null hypothesis: H0: p = p’

· Alternative hypothesis: H1: p ≠ p’

· Reject H0 if observing (X, X’) is unlikely under the assumption H0

o Choose a test statistic T(X, X’)

o Plugging in your samples gives the observed value t

o Compute the p-value: P_H0[T(X, X’) ≥ t]

o If the p-value is smaller than the significance level α, reject H0 (the significance level is a user parameter and controls the false-positive rate)
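One simple instantiation of this framework is a per-feature two-sample test. The sketch below (my construction, using the Kolmogorov-Smirnov test mentioned further down, with a Bonferroni correction) shows the statistic, p-value, and significance-level pieces in code:

```python
import numpy as np
from scipy.stats import ks_2samp

def detect_shift_ks(X, X_query, alpha=0.05):
    """Per-feature two-sample KS test; reject H0 (p = p') if any feature drifts."""
    n_features = X.shape[1]
    p_values = np.array(
        [ks_2samp(X[:, j], X_query[:, j]).pvalue for j in range(n_features)]
    )
    # Bonferroni correction keeps the overall false-positive rate near alpha.
    reject = p_values < alpha / n_features
    return reject.any(), p_values

# Example: three features, only the second one has shifted.
rng = np.random.default_rng(1)
X = rng.normal(0, 1, size=(500, 3))
X_query = rng.normal(0, 1, size=(500, 3))
X_query[:, 1] += 0.5
print(detect_shift_ks(X, X_query))
```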

Maximum Mean Discrepancy — Intuition

· Test statistic: MMD(X, X’) = ‖x̄ − x̄’‖²

Compute the means of the two samples, x̄ and x̄’, and take the squared difference between them.

This can be a useful statistic: if there is no shift, the two sample means will be very similar, whereas under a distribution shift the two means can be quite dissimilar.

The issue is that the mean alone can be misleading: two different distributions can share the same mean. To address this, we apply a feature function that makes the mean more informative. For example, if x is a real number, we might choose a feature function that concatenates x and its square. Applying this feature function to the data points and then averaging gives a mean that also carries information about the expected squares, capturing second-order information and enabling better differentiation between the distributions.

An alternative approach is to apply the kernel trick [3].
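Here is a minimal sketch of the kernel MMD statistic from [3], assuming an RBF kernel and the common median-distance heuristic for the bandwidth (the heuristic is my default choice, not something the article prescribes):

```python
import numpy as np
from scipy.spatial.distance import cdist

def mmd_rbf(X, X_query, bandwidth=None):
    """Biased estimate of the squared MMD between samples from p and p'."""
    Z = np.vstack([X, X_query])
    d2 = cdist(Z, Z, "sqeuclidean")
    if bandwidth is None:
        # Median heuristic: set the kernel width from the median pairwise distance.
        bandwidth = np.sqrt(np.median(d2[d2 > 0]) / 2)
    K = np.exp(-d2 / (2 * bandwidth ** 2))
    n = len(X)
    Kxx, Kyy, Kxy = K[:n, :n], K[n:, n:], K[:n, n:]
    return Kxx.mean() + Kyy.mean() - 2 * Kxy.mean()
```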

Finally, the MMD test needs a way to approximate p-values:

o We need P_H0[T(X, X’) ≥ t]

o The exact distribution of the test statistic is intractable

o Assuming H0: p = p’, X and X’ are exchangeable

o We can ‘permute’ the pooled data without changing the distribution of T(X, X’)

o Permutation test

§ Perform a number of random permutations

§ Count how often the permuted statistic T(X̃, X̃’) ≥ t

§ This ratio approximates the p-value [4]
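Putting the pieces together, the permutation test can be sketched as follows (a generic version: the statistic argument would typically be the mmd_rbf function sketched above, but any two-sample statistic works):

```python
import numpy as np

def permutation_pvalue(X, X_query, statistic, n_permutations=500, seed=0):
    """Approximate P_H0[T(X, X') >= t] by permuting the pooled sample."""
    rng = np.random.default_rng(seed)
    t_observed = statistic(X, X_query)
    Z = np.vstack([X, X_query])
    n = len(X)
    exceed = 0
    for _ in range(n_permutations):
        perm = rng.permutation(len(Z))              # shuffle the pooled data
        t_perm = statistic(Z[perm[:n]], Z[perm[n:]])
        exceed += t_perm >= t_observed              # permuted statistic at least as extreme
    return (exceed + 1) / (n_permutations + 1)      # add-one estimate avoids a p-value of 0

# Usage (statistic would be e.g. the mmd_rbf function sketched earlier):
# p_value = permutation_pvalue(X, X_query, statistic=mmd_rbf)
# reject_h0 = p_value < 0.05
```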

Pros of MMD: no assumptions on the distributions, no model training needed, very powerful (given a good kernel).

Cons of MMD: compute/memory cost is quadratic, O((n+n’)²), and the choice of kernel introduces hyperparameters.

Other alternative tests include the Kolmogorov-Smirnov test and domain-classifier approaches [5].
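The domain-classifier idea from [5] can be sketched as follows (scikit-learn assumed; the AUC reading is my heuristic): train a classifier to distinguish reference from query data, and treat better-than-chance performance as evidence of a shift.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score

def domain_classifier_auc(X, X_query):
    """Cross-validated AUC of a classifier separating reference from query data."""
    Z = np.vstack([X, X_query])
    domain = np.concatenate([np.zeros(len(X)), np.ones(len(X_query))])
    clf = LogisticRegression(max_iter=1000)
    return cross_val_score(clf, Z, domain, cv=5, scoring="roc_auc").mean()

# An AUC near 0.5 means the classifier cannot tell the samples apart (no detectable shift);
# an AUC well above 0.5 suggests the distributions differ.
# auc = domain_classifier_auc(X, X_query)
```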

Addressing Covariate Shift:

Detecting and addressing covariate shift is crucial for maintaining model performance. Here are a few techniques to handle covariate shift:

1. Rebalancing the Data: If the input features’ distribution has shifted, re-weighting or re-sampling the training data to align with the new distribution can help mitigate the covariate shift (a re-weighting sketch follows this list).

2. Feature Selection/Engineering: Analyzing the features contributing to the covariate shift and modifying them or selecting more robust features can reduce the impact of the shift.

3. Transfer Learning: Leveraging transfer learning techniques allows models to utilize knowledge from previous domains to adapt to the new data distribution effectively.

4. Ensemble Methods: Combining multiple models trained on different datasets or with different feature representations can help improve the model’s robustness to covariate shift.
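For the re-weighting idea in point 1, a common recipe (my sketch, not the article’s exact method) is to estimate importance weights w(x) ≈ p’(x)/p(x) with a probabilistic domain classifier and retrain the task model with those sample weights:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import GradientBoostingRegressor

def covariate_shift_weights(X_train, X_prod):
    """Estimate importance weights w(x) ~ p'(x)/p(x) for the training points."""
    Z = np.vstack([X_train, X_prod])
    domain = np.concatenate([np.zeros(len(X_train)), np.ones(len(X_prod))])
    clf = LogisticRegression(max_iter=1000).fit(Z, domain)
    p_prod = clf.predict_proba(X_train)[:, 1]       # P(production | x) for training points
    w = p_prod / np.clip(1.0 - p_prod, 1e-6, None)  # the odds ratio approximates p'(x)/p(x)
    return w * len(w) / w.sum()                     # normalise to a mean weight of 1

# Retrain the task model with the estimated weights:
# w = covariate_shift_weights(X_train, X_prod)
# model = GradientBoostingRegressor().fit(X_train, y_train, sample_weight=w)
```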

Concept Drift:

Another type of shift commonly observed in machine learning is concept drift. Concept drift refers to a situation where the relationship between input features and the target variable changes over time. Unlike covariate shift, where the input feature distribution changes while the target distribution remains the same, concept drift implies that the relationship between the input features and the target variable evolves.

Concept drift can occur due to various reasons, such as changes in user preferences, external factors influencing the target variable, or shifts in the underlying data-generating process. This phenomenon poses a significant challenge for machine learning models because the assumptions made during training no longer hold true, leading to decreased model performance.

In concept shift, p(y|x) changes: the mechanism that generates the label or target variable y from x is different, so the fundamental relationship between the two variables has changed.

Detecting a shift in p(y|x) requires labeled data (x, y).

The MMD test can be extended to conditional distributions [6].

Detecting and handling concept drift requires continuous monitoring of the model’s performance and adaptation to the changing patterns. Here are a few techniques used to address concept drift:

1. Monitoring Performance Metrics: By tracking performance metrics, such as accuracy, precision, recall, or F1-score, over time, we can identify significant drops or fluctuations that may indicate concept drift. Sudden or gradual changes in these metrics can trigger actions to address the drift (a minimal monitoring sketch follows this list).

2. Window-Based Approaches: Instead of considering the entire historical data, window-based approaches limit the analysis to a recent subset of data. By updating the model using only the most recent window of data, the model can adapt to the most relevant concept at a given time.

3. Ensemble Methods: Ensemble methods, such as stacking or boosting, can be employed to combine multiple models trained on different windows or snapshots of data. By leveraging the diversity of these models, they can collectively adapt to concept drift and maintain better performance.

4. Online Learning: Online learning algorithms are specifically designed to handle streaming data where concept drift is prevalent. These algorithms continuously update the model using new instances as they arrive, allowing the model to adapt to the evolving data distribution.

5. Change Detection Algorithms: Change detection algorithms monitor the incoming data stream and identify points or intervals where significant changes occur. This can help identify when concept drift happens, allowing for timely model updates.
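As a minimal illustration of points 1 and 5 (my sketch; the one-sided two-proportion z-test and the window sizes are assumptions, not the article’s prescription), one can compare the error rate in a recent window of production data against a baseline window and flag a statistically significant increase:

```python
import numpy as np
from scipy.stats import norm

def error_rate_drift(baseline_errors, recent_errors, alpha=0.05):
    """Flag drift if the recent error rate is significantly higher than the baseline.

    baseline_errors / recent_errors: arrays of 0/1 flags (1 = prediction was wrong).
    Uses a one-sided two-proportion z-test.
    """
    n1, n2 = len(baseline_errors), len(recent_errors)
    p1, p2 = np.mean(baseline_errors), np.mean(recent_errors)
    p_pool = (p1 * n1 + p2 * n2) / (n1 + n2)
    se = np.sqrt(p_pool * (1 - p_pool) * (1 / n1 + 1 / n2))
    if se == 0:
        return False
    z = (p2 - p1) / se
    p_value = 1 - norm.cdf(z)   # one-sided: is the recent error rate higher?
    return p_value < alpha

# Example: 10% errors in the baseline window vs. 20% in the recent window.
rng = np.random.default_rng(0)
print(error_rate_drift(rng.binomial(1, 0.10, 1000), rng.binomial(1, 0.20, 500)))
```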

Addressing concept drift is an ongoing challenge in machine learning, and there is no one-size-fits-all solution. The choice of techniques depends on the specific problem domain, the available resources, and the severity of the concept drift. Continuous monitoring, adaptation, and model retraining are key to maintaining model performance in the face of concept drift.

Conclusion

Detecting and mitigating data drift, including covariate shift and concept drift, is crucial for maintaining the accuracy and reliability of machine learning models over time. By employing suitable detection techniques and addressing the underlying causes of drift, we can ensure that our models adapt to the evolving data distribution and continue to provide valuable insights and predictions. Being aware of covariate shift and concept drift and utilizing appropriate strategies to handle them will help us build robust and resilient machine learning systems.

Knowing the type of shift is important because it decides how the machine learning model needs to be treated.

For covariate shift: the existing predictive model can still work well for both the old and the new data, since p(y|x) is unchanged.

For concept shift: a new model is required. Training a new model does not mean that the old model or the old data is useless; both can still be very valuable for representation learning or as a basis for transfer learning.

There is another type of shift, called label shift, which is best explained in a classification setting.

p(x,y) = p(x|y)p(y)

p(x|y) stays constant; p(y) may change.
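When labels (or reliable proxy labels) are available, a quick check for label shift is a test on the class frequencies; here is a minimal sketch (my construction, using a chi-square test on label counts):

```python
import numpy as np
from scipy.stats import chi2_contingency

def label_shift_pvalue(y_reference, y_query):
    """Test whether the class frequencies p(y) differ between two labelled samples."""
    classes = np.union1d(y_reference, y_query)
    counts = np.array([
        [(np.asarray(y_reference) == c).sum() for c in classes],
        [(np.asarray(y_query) == c).sum() for c in classes],
    ])
    _, p_value, _, _ = chi2_contingency(counts)
    return p_value  # a small p-value suggests p(y) has changed

# print(label_shift_pvalue(y_reference=[0, 0, 1, 1, 2], y_query=[0, 2, 2, 2, 1]))
```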

REFERENCES:

[1] https://arxiv.org/abs/2012.07421

[2] https://www-cs-faculty.stanford.edu/people/jure/pubs/wilds-icml21.pdf

[3] https://jmlr.org/papers/volume13/gretton12a/gretton12a.pdf

[4] https://github.com/awslabs/Renate

[5] https://arxiv.org/abs/2206.08843

[6] https://arxiv.org/abs/2203.08644
