A Novel Local Patch Framework for
Fixing Supervised Learning Models
Yilei Wang1, Bingzheng Wei2, Jun Yan2, Yang Hu2,
Zhi-Hong Deng1, Zheng Chen2
1Peking University
2Microsoft Research Asia
Outline
Motivation & Background
Problem Definition & Algorithm Overview
Algorithm Details
Experiments - Classification
Experiments - Search Ranking
Conclusion
Motivation & Background
Supervised Learning:
The machine learning task of inferring a function from labeled training data
Prediction Error:
No matter how strong a learning model is, it will suffer from prediction errors
Causes: noise in the training data, dynamically changing data distribution, and weaknesses of the learner
Feedback from Users:
A good signal for learning models to find their limitations and then improve accordingly
Learning to Fix Errors from Failure Cases
Goal: automatically fix model prediction errors from failure cases in feedback data
Input:
A well-trained supervised model (we name it the Mother Model)
A collection of failure cases in the feedback dataset
Output:
A model that has learned to automatically fix the mother model's bugs on the failure cases
Previous Works
Model Retraining
Model Aggregation
Incremental Learning
Local Patching: from Global to Local
Learning models are generally optimized globally
Fixing old errors globally can introduce new prediction errors elsewhere
Our key idea: learning to fix the model locally using patches
Problem Definition
Our proposed Local Patch Framework (LPF) aims to learn a new model $g(x)$:

$$g(x) = f(x) + \sum_{i=1}^{N} \underbrace{K_i(x)}_{\text{patch domain}} \times \underbrace{P_i(x)}_{\text{patch model}}$$

$f(x)$: the original mother model
$P_i(x)$: the patch model
$K_i(x)$: a Gaussian weight defined by a centroid $z_i$ and a range $\sigma_i$:

$$K_i(x) = \exp\left[-\frac{\|x - z_i\|^2}{2\sigma_i^2}\right]$$

[Figure: plot of the Gaussian weight $K_i(x)$ against $x$ over $[0, 10]$]
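To make the combination concrete, here is a minimal sketch of evaluating the patched model $g(x)$; `mother`, `patches`, `centroids`, and `sigmas` are illustrative names for the components defined above, not the authors' code:

```python
import numpy as np

def gaussian_weight(x, z_i, sigma_i):
    """K_i(x) = exp(-||x - z_i||^2 / (2 * sigma_i^2))."""
    d2 = float(np.sum((x - z_i) ** 2))
    return np.exp(-d2 / (2.0 * sigma_i ** 2))

def patched_predict(x, mother, patches, centroids, sigmas):
    """g(x) = f(x) + sum_i K_i(x) * P_i(x)."""
    out = mother(x)  # the original mother model f(x)
    for P_i, z_i, s_i in zip(patches, centroids, sigmas):
        out += gaussian_weight(x, z_i, s_i) * P_i(x)  # locally weighted patch
    return out
```

Far from every centroid, all weights $K_i(x)$ vanish and $g(x)$ falls back to the mother model, which is what keeps the fix local.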
Algorithm Overview
Failure Case Collection
Learning Patch Regions / Failure Case Clustering
Cluster the failure cases into N groups through subspace learning, compute the centroid and range for every group, then define the patches
Learning Patch Models
Learn each patch model using only the data samples that are sufficiently close to the patch centroid (see the sketch below)
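Wired together, the three stages might look like the following sketch; `learn_patch_regions` and `fit_patch_model` are hypothetical helpers standing in for the steps detailed on the next slides:

```python
import numpy as np

def build_lpf(mother, X_feedback, y_feedback, n_patches):
    """End-to-end LPF pipeline sketch over a feedback set (X, y as arrays)."""
    # 1. Failure case collection: samples the mother model predicts wrongly.
    preds = np.array([mother(x) for x in X_feedback])
    failures = X_feedback[preds != y_feedback]
    # 2. Patch regions: cluster failures into n_patches groups via clustered
    #    metric learning, yielding a centroid, range, and metric per patch.
    centroids, sigmas, metrics = learn_patch_regions(failures, n_patches)
    # 3. Patch models: fit each patch on the samples close to its centroid.
    patches = [fit_patch_model(X_feedback, y_feedback, z, s)
               for z, s in zip(centroids, sigmas)]
    return patches, centroids, sigmas
```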
Algorithm Details
Learning Patch Region – Key Challenge
Failure cases may be distributed diffusely
[Figure: scatter of success cases and failure cases]
Small N → large patch ranges → many success cases will also be patched
Large N → small patch ranges → high computational complexity
How to make the trade-off?
Solution: Clustered Metric Learning
Our solution to the diffusion problem: metric learning
Learn a distance metric, i.e. a subspace, for the failure cases, such that similar failure cases aggregate while keeping distant from the success cases
Key idea of the patch model learning (red circles = failure cases; blue circles = success cases):
(Left): the cases in the original data space
(Middle): the cases mapped to the learned subspace
(Right): the failure cases repaired using a single patch
Metric Learning
Conditional distribution over pairs $x_i \neq x_j$:

$$P_A(j \mid i) = \frac{\exp\left(-\|A(x_i - x_j)\|^2\right)}{\sum_{k \neq i} \exp\left(-\|A(x_i - x_k)\|^2\right)}$$

Ideal distribution ($C_f$: failure cases; $C_s$: success cases; $G_i$: group of case $i$):

$$P_0(j \mid i) \propto \begin{cases} 1, & j \in C_f \wedge G_i = G_j \\ 0, & j \in C_f \wedge G_i \neq G_j \\ 0, & j \in C_s \end{cases}$$

Learn $A$ to satisfy $\forall j,\; P_A(j \mid i) \to P_0(j \mid i)$ for any $x_i \in C_f$:

$$\min_A \sum_{i \in C_f} \mathrm{KL}\left(P_0(j \mid i) \,\middle\|\, P_A(j \mid i)\right)$$

which is equivalent to maximizing

$$f(A) = \sum_{i \in C_f} \sum_{j \neq i} P_0(j \mid i) \log P_A(j \mid i)$$
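A minimal NumPy sketch of this objective, assuming `A` is a linear transform matrix and leaving the maximization of $f(A)$ to any gradient-based optimizer; all names are illustrative:

```python
import numpy as np

def ideal_distribution(groups, is_failure):
    """P_0(j|i): uniform over failures in the same group as i, zero otherwise."""
    n = len(groups)
    P0 = np.zeros((n, n))
    for i in range(n):
        if not is_failure[i]:
            continue
        same = [j for j in range(n)
                if j != i and is_failure[j] and groups[j] == groups[i]]
        for j in same:
            P0[i, j] = 1.0 / len(same)
    return P0

def p_A(A, X, i):
    """P_A(j|i): softmax over negative squared distances under the metric A."""
    d2 = np.sum(((X - X[i]) @ A.T) ** 2, axis=1)
    d2[i] = np.inf                               # exclude j = i
    logits = -d2
    logits -= logits[np.isfinite(logits)].max()  # numerical stability
    w = np.exp(logits)
    return w / w.sum()

def metric_objective(A, X, P0, failure_idx):
    """f(A) = sum over i in C_f, j != i of P_0(j|i) log P_A(j|i); maximize this."""
    total = 0.0
    for i in failure_idx:
        p = p_A(A, X, i)
        nz = P0[i] > 0
        total += np.sum(P0[i][nz] * np.log(p[nz] + 1e-12))
    return total
```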
Clustered Metric Learning
Algorithm:
1. Initialize each failure case with a random group
2. Repeat the following steps:
   a) For the given clusters, perform the metric learning step
   b) Update the centroids of the groups, and re-assign each failure case to its closest centroid

Local Patch Region:
For each cluster $i$, we define a corresponding patch with $z_i$ as its centroid and $\sigma_i^2$ as its variance

Gaussian weight: $K_i(x) = \exp\left(-\dfrac{\|x - z_i\|_{A_i}^2}{2\sigma_i^2}\right)$
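A sketch of this alternating loop; `metric_learning_step` is a hypothetical stand-in for one round of optimizing $f(A)$ from the previous slide, and the code assumes every group stays non-empty:

```python
import numpy as np

def clustered_metric_learning(Xf, n_groups, n_iters=10, seed=0):
    """Alternate metric learning with centroid reassignment over failures Xf."""
    rng = np.random.default_rng(seed)
    groups = rng.integers(n_groups, size=len(Xf))     # 1. random initial groups
    A = np.eye(Xf.shape[1])
    for _ in range(n_iters):                          # 2. repeat
        A = metric_learning_step(Xf, groups, A)       # (a) metric learning step
        Z = np.stack([Xf[groups == g].mean(axis=0)    # (b) update centroids
                      for g in range(n_groups)])
        d2 = np.sum(((Xf[:, None, :] - Z[None, :, :]) @ A.T) ** 2, axis=2)
        groups = np.argmin(d2, axis=1)                # re-assign to closest
    return groups, Z, A

def patch_weight(x, z_i, sigma_i, A_i):
    """K_i(x) = exp(-||A_i (x - z_i)||^2 / (2 sigma_i^2))."""
    d2 = float(np.sum((A_i @ (x - z_i)) ** 2))
    return np.exp(-d2 / (2.0 * sigma_i ** 2))
```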
Learning Patch Model
Objective:

$$\min_w \mathrm{Cost}(g(\cdot), l)$$

where $w$ are the patch model parameters and $l$ are the labels.

Parameter update:

$$w_k \to w_k - \eta \, \frac{\partial \mathrm{Cost}}{\partial g} \cdot \frac{\partial g}{\partial w_k}$$

For $\partial g / \partial w_k$, we have

$$\frac{\partial g}{\partial w_k} = K_k(x) \times \frac{\partial P_k(x)}{\partial w_k}$$

Note: $\partial P_k(x) / \partial w_k$ depends on the specific patch model.
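As a worked instance of the update rule, a hedged sketch of one SGD step, assuming a squared cost and a linear patch $P_k(x) = w \cdot x$; the slides leave $\partial P_k / \partial w_k$ open, so both choices are purely illustrative:

```python
import numpy as np

def sgd_patch_step(w, x, label, g_x, K_k_x, lr=0.01):
    """One SGD step on the k-th patch under Cost = 0.5 * (g(x) - l)^2.

    Chain rule: dCost/dw_k = (dCost/dg) * K_k(x) * dP_k/dw_k, and for a
    linear patch P_k(x) = w . x the last factor is simply x.
    """
    dcost_dg = g_x - label           # dCost/dg for the squared cost
    grad = dcost_dg * K_k_x * x      # gradient damped by the Gaussian weight
    return w - lr * grad
```

Because the gradient is scaled by $K_k(x)$, samples far from the patch centroid barely move the patch, which mirrors the local-fixing intent.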
Experiments
Experiments - Classification
Dataset
Randomly selected 3 UCI subsets: Spambase, Waveform, Optical Digit Recognition
Converted to binary classification datasets
~5000 instances in each dataset
Split: 60% training, 20% feedback, 20% test (a split sketch follows the baseline list)
Baseline Algorithms
SVM
Logistic Regression
SVM - retrained with training + feedback data
Logistic Regression - retrained with training + feedback data
SVM – Incremental Learning
Logistic Regression - Incremental Learning
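One straightforward way to reproduce the 60/20/20 split (scikit-learn shown for brevity, with sklearn's digits set as a stand-in for the UCI data):

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)   # stand-in for a UCI subset
# 60% training, then halve the remainder into 20% feedback / 20% test.
X_train, X_rest, y_train, y_rest = train_test_split(
    X, y, test_size=0.4, random_state=0)
X_feed, X_test, y_feed, y_test = train_test_split(
    X_rest, y_rest, test_size=0.5, random_state=0)
```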
Classification Accuracy

Classification accuracy on the feedback dataset:

Dataset    SVM      SVM+LPF   LR       LR+LPF
Spam       0.8230   0.8838    0.9055   0.9283
Wave       0.7270   0.8670    0.8600   0.8850
Optdigit   0.9066   0.9724    0.9306   0.9689
Classification accuracy on the test dataset:

Dataset    SVM      SVM-Retrain  SVM-IL   SVM+LPF  LR       LR-Retrain  LR-IL    LR+LPF
Spam       0.8196   0.8348       0.8478   0.8587   0.9152   0.9174      0.9185   0.9217
Wave       0.7530   0.7780       0.7850   0.8620   0.8460   0.8600      0.8770   0.8800
Optdigit   0.9101   0.9128       0.9217   0.9635   0.9332   0.9368      0.9388   0.9413
Classification – Case Coverage
Parameter Tuning
Number of Patches N
The best value is data-sensitive; in our experiments the best N is 2
Experiments – Search Ranking
Dataset
Data from a widely used commercial search engine
~14,126 <q, d> pairs
With 5-grade relevance labels
Metrics
NDCG@K, K ∈ {1, 3, 5} (a sketch of the metric follows)
Baseline Algorithms
GBDT
GBDT + IL
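For reference, a minimal sketch of one common NDCG@K variant (exponential gain, log2 discount); the slides do not pin down these details, so treat the formula choices as illustrative:

```python
import numpy as np

def ndcg_at_k(ranked_grades, k):
    """NDCG@K for one query; ranked_grades are 0-4 labels in model-ranked order."""
    g = np.asarray(ranked_grades, dtype=float)

    def dcg(scores):
        scores = scores[:k]
        discounts = np.log2(np.arange(2, len(scores) + 2))
        return np.sum((2.0 ** scores - 1.0) / discounts)

    ideal = dcg(np.sort(g)[::-1])    # DCG of the perfect ordering
    return dcg(g) / ideal if ideal > 0 else 0.0

# Example: a ranking that places a grade-2 document above a grade-3 one.
print(ndcg_at_k([3, 2, 3, 0, 1], k=3))   # ~0.96
```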
Experiment Results – Ranking

Metric    GBDT     GBDT+IL   GBDT+LPF
NDCG@1    0.9115   0.9122    0.9422
NDCG@3    0.8837   0.8910    0.9149
NDCG@5    0.8790   0.8873    0.9090
Experiment Results – Ranking (Cont.)
Conclusion
We proposed:
The local model fixing problem
A novel patch framework for fixing the failure cases in the feedback dataset from a local view
The experimental results demonstrate the effectiveness of our proposed Local Patch Framework
Thank you!