A Novel Local Patch Framework for Fixing Supervised Learning Models
Yilei Wang (1), Bingzheng Wei (2), Jun Yan (2), Yang Hu (2), Zhi-Hong Deng (1), Zheng Chen (2)
(1) Peking University  (2) Microsoft Research Asia

Outline
- Motivation & Background
- Problem Definition & Algorithm Overview
- Algorithm Details
- Experiments - Classification
- Experiments - Search Ranking
- Conclusion

Motivation & Background
- Supervised Learning: the machine learning task of inferring a function from labeled training data.
- Prediction Error: no matter how strong a learning model is, it will suffer from prediction errors, due to noise in the training data, dynamically changing data distributions, and weaknesses of the learner.
- Feedback from Users: a good signal for learning models to find their limitations and then improve accordingly.

Learning to Fix Errors from Failure Cases
Goal: automatically fix model prediction errors from failure cases in feedback data.
Input:
- A well-trained supervised model (we name it the Mother Model)
- A collection of failure cases from a feedback dataset
Output:
- A model that has learned to automatically fix the mother model's bugs on the failure cases

Previous Works
- Model Retraining
- Model Aggregation
- Incremental Learning

Local Patching: from Global to Local
Learning models are generally optimized globally, so fixing old prediction errors can introduce new ones elsewhere.
Our key idea: learn to fix the model locally using patches.

Problem Definition
Our proposed Local Patch Framework (LPF) aims to learn a new model

  g(x) = f(x) + \sum_{i=1}^{N} K_i(x) \, P_i(x)

where
- f(x): the original mother model
- P_i(x): the i-th patch model
- K_i(x): the patch domain, a Gaussian weight defined by a centroid z_i and a range \sigma_i:

  K_i(x) = \exp\!\left( -\frac{\lVert x - z_i \rVert^2}{2\sigma_i^2} \right)

[Figure: plot of the Gaussian weight K_i(x) around its centroid.]

Algorithm Overview
1. Failure case collection.
2. Learning patch regions / failure case clustering: cluster the failure cases into N groups through subspace learning, compute the centroid and range for every group, then define the patches.
3. Learning patch models: learn each patch model using only the data samples that are sufficiently close to the patch centroid.

Algorithm Details

Learning Patch Regions - Key Challenge
Failure cases may be distributed diffusely among the success cases.
[Figure: failure cases scattered among success cases.]
- Small N → large patch ranges → many success cases will also be patched.
- Large N → small patch ranges → high computational complexity.
How to make the trade-off?

Solution: Clustered Metric Learning
Our solution to the diffusion problem: metric learning.
Learn a distance metric, i.e., a subspace, for the failure cases such that similar failure cases aggregate while keeping distant from the success cases.
[Figure: key idea of patch model learning (red circles = failure cases, blue circles = success cases). Left: the cases in the original data space. Middle: the cases mapped to the learned subspace. Right: the failure cases repaired using a single patch.]

Metric Learning
(C_f denotes the failure cases, C_s the success cases, and G_i the group of case x_i.)
Conditional distribution over pairs x_j \neq x_i:

  P_A(j \mid i) = \frac{\exp(-\lVert A(x_i - x_j) \rVert^2)}{\sum_{k \neq i} \exp(-\lVert A(x_i - x_k) \rVert^2)}

Ideal distribution:

  P_0(j \mid i) \propto \begin{cases} 1, & j \in C_f \wedge G_i = G_j \\ 0, & j \in C_f \wedge G_i \neq G_j \\ 0, & j \in C_s \end{cases}

Learn A to satisfy \forall j, P_A(j \mid i) \to P_0(j \mid i) for any x_i \in C_f:

  \min_A \sum_{i \in C_f} \mathrm{KL}\big( P_0(j \mid i) \,\Vert\, P_A(j \mid i) \big)

which is equivalent to maximizing

  f(A) = \sum_{i \in C_f} \sum_{j \neq i} P_0(j \mid i) \log P_A(j \mid i)

Clustered Metric Learning
Algorithm:
1. Initialize each failure case with a random group.
2. Repeat the following steps:
   a) For the given clusters, perform the metric learning step.
   b) Update the centroids of the groups and re-assign each failure case to its closest centroid.
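To make the alternating procedure above concrete, here is a minimal, illustrative Python sketch of the clustered metric learning loop; it is not from the original slides. It assumes a single full linear map A shared across groups, a numerical-gradient ascent step on f(A), and Euclidean reassignment in the learned space; the function names (clustered_metric_learning, soft_neighbor_probs, objective) are placeholders.

```python
# Illustrative sketch only (not the authors' implementation).
import numpy as np

def soft_neighbor_probs(A, X):
    """P_A(j|i) = exp(-||A(x_i - x_j)||^2) / sum_{k != i} exp(-||A(x_i - x_k)||^2)."""
    Z = X @ A.T                                        # map every case into the learned space
    d2 = ((Z[:, None, :] - Z[None, :, :]) ** 2).sum(-1)
    np.fill_diagonal(d2, np.inf)                       # exclude j == i from the softmax
    e = np.exp(-d2)
    return e / (e.sum(axis=1, keepdims=True) + 1e-12)

def objective(A, X_fail, X_succ, groups):
    """f(A) = sum_{i in C_f} sum_{j != i} P0(j|i) log P_A(j|i), with P0 uniform over
    failure cases in the same group and zero on other failures and all successes."""
    X = np.vstack([X_fail, X_succ])
    P = soft_neighbor_probs(A, X)
    n_f = len(X_fail)
    same = (groups[:, None] == groups[None, :]).astype(float)
    np.fill_diagonal(same, 0.0)
    P0 = same / np.maximum(same.sum(axis=1, keepdims=True), 1e-12)
    return (P0 * np.log(P[:n_f, :n_f] + 1e-12)).sum()

def clustered_metric_learning(X_fail, X_succ, n_groups, iters=10, lr=0.05):
    rng = np.random.default_rng(0)
    groups = rng.integers(n_groups, size=len(X_fail))  # step 1: random initial groups
    A = np.eye(X_fail.shape[1])
    for _ in range(iters):
        # step 2a: metric learning for the current grouping (numerical gradient ascent)
        grad, eps = np.zeros_like(A), 1e-4
        for r in range(A.shape[0]):
            for c in range(A.shape[1]):
                Ap, Am = A.copy(), A.copy()
                Ap[r, c] += eps; Am[r, c] -= eps
                grad[r, c] = (objective(Ap, X_fail, X_succ, groups)
                              - objective(Am, X_fail, X_succ, groups)) / (2 * eps)
        A += lr * grad
        # step 2b: recompute centroids and re-assign each failure case to the closest one
        Zf = X_fail @ A.T
        centroids = np.stack([Zf[groups == g].mean(0) if (groups == g).any()
                              else Zf[rng.integers(len(Zf))] for g in range(n_groups)])
        groups = np.argmin(((Zf[:, None, :] - centroids[None]) ** 2).sum(-1), axis=1)
    return A, groups
```

In the actual framework the metric could equally be learned per group (giving the A_i used in the patch weights below), and the gradient of f(A) has a closed form; numerical differentiation is used here only to keep the sketch short.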
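Before the patch-region and patch-model details on the following slides, a minimal sketch of how the patched prediction g(x) = f(x) + \sum_i K_i(x) P_i(x) from the Problem Definition could be evaluated once the centroids, ranges, and patch models are learned. This is an illustration under assumed interfaces, not the authors' code; mother_model, patches, gaussian_weight, and patched_prediction are hypothetical names.

```python
# Illustrative sketch only (not the authors' implementation).
import numpy as np

def gaussian_weight(x, z_i, sigma_i):
    """Patch domain K_i(x) = exp(-||x - z_i||^2 / (2 * sigma_i^2))."""
    return np.exp(-np.sum((x - z_i) ** 2) / (2.0 * sigma_i ** 2))

def patched_prediction(x, mother_model, patches):
    """g(x) = f(x) + sum_i K_i(x) * P_i(x).

    mother_model: callable f(x) -> float (the original model)
    patches:      list of (patch_model, z_i, sigma_i), patch_model a callable P_i(x)
    """
    g = mother_model(x)
    for patch_model, z_i, sigma_i in patches:
        g += gaussian_weight(x, z_i, sigma_i) * patch_model(x)
    return g
```

Far from every patch centroid the weights K_i(x) vanish and g(x) falls back to the mother model's output, while near a centroid the corresponding patch dominates the correction; this is what keeps the fix local.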
Local Patch Region
For each cluster i, we define a corresponding patch, taking the cluster's center as its centroid z_i and the cluster's spread as its variance \sigma_i^2.
Gaussian weight (under the learned metric A_i):

  K_i(x) = \exp\!\left( -\frac{\lVert x - z_i \rVert_{A_i}^2}{2\sigma_i^2} \right)

Learning Patch Model
Objective:

  \min_{w} \; \mathrm{Cost}(g(\cdot), l)

where w are the patch model parameters and l are the labels.
Parameter update:

  w_k \leftarrow w_k - \eta \, \frac{\partial \mathrm{Cost}}{\partial g(\cdot)} \frac{\partial g}{\partial w_k}

For \partial g / \partial w_k we have

  \frac{\partial g}{\partial w_k} = K_k(x) \times \frac{\partial p_k(x)}{\partial w_k}

Note: \partial p_k(x) / \partial w_k depends on the specific patch model.

Experiments

Experiments - Classification
Dataset:
- Randomly selected 3 UCI subsets: Spambase, Waveform, Optical Digit Recognition
- Converted to binary classification datasets
- ~5000 instances in each dataset
- Split into 60% training, 20% feedback, 20% test
Baseline algorithms:
- SVM
- Logistic Regression
- SVM retrained with training + feedback data
- Logistic Regression retrained with training + feedback data
- SVM with Incremental Learning
- Logistic Regression with Incremental Learning

Classification Accuracy
Classification accuracy on the feedback dataset:

            SVM      SVM+LPF   LR       LR+LPF
  Spam      0.8230   0.8838    0.9055   0.9283
  Wave      0.7270   0.8670    0.8600   0.8850
  Optdigit  0.9066   0.9724    0.9306   0.9689

Classification accuracy on the test dataset:

            SVM      SVM-Retrain  SVM-IL   SVM+LPF   LR       LR-Retrain  LR-IL    LR+LPF
  Spam      0.8196   0.8348       0.8478   0.8587    0.9152   0.9174      0.9185   0.9217
  Wave      0.7530   0.7780       0.7850   0.8620    0.8460   0.8600      0.8770   0.8800
  Optdigit  0.9101   0.9128       0.9217   0.9635    0.9332   0.9368      0.9388   0.9413

Classification - Case Coverage

Parameter Tuning
Number of patches N: data sensitive; in our experiments the best N is 2.

Experiments - Search Ranking
Dataset:
- Data from a commonly used commercial search engine
- ~14,126 <q, d> pairs with 5-grade labels
Metrics: NDCG@K for K in {1, 3, 5}
Baseline algorithms:
- GBDT
- GBDT + Incremental Learning (IL)

Experiment Results - Ranking

            GBDT     GBDT+IL   GBDT+LPF
  nDCG@1    0.9115   0.9122    0.9422
  nDCG@3    0.8837   0.8910    0.9149
  nDCG@5    0.8790   0.8873    0.9090

Experiment Results - Ranking (Cont.)

Conclusion
We proposed:
- The local model fixing problem
- A novel patch framework for fixing the failure cases in the feedback dataset from a local view
The experimental results demonstrate the effectiveness of our proposed Local Patch Framework.

Thank you!