`
wangqisen
  • 浏览: 46966 次
文章分类
社区版块
存档分类
最新评论

数据挖掘之AdaBoost算法

 
阅读更多

这个算法的精髓在于：虽然每次所用的决策树的构建方式相同，但每次所用的训练元组并不相同。每经过一次训练，那些被预测错误的元组的权重会加大，使得它们在下次训练中更容易被选中；这样几轮训练下来抽样会比较均匀，使得对易错元组的预测效果比较好。在 K=8 时，得出的离散型属性的预测准确度为 0.73，连续型为 0.95。

下面是我的代码:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

import auxiliary.DecisionTree.DecisionTreeNode;


/**
 *
 * @author daq
 */
/**
 * AdaBoost ensemble classifier over {@link DecisionTree2} base learners.
 *
 * <p>Each round draws a weighted bootstrap sample of the training tuples,
 * trains a tree on it, and — if the tree's error on that sample is at most
 * 0.5 — keeps the tree with voting weight log((1 - m) / m). Weights of
 * correctly classified tuples are then shrunk by m / (1 - m) and the whole
 * distribution is renormalized, so hard tuples are drawn more often in later
 * rounds (the classic AdaBoost.M1 scheme).
 *
 * @author daq
 */
public class AdaBoost extends Classifier {

	/** Number of boosting rounds. */
	private final int K = 8;

	// Full training set, copied from the arguments to train().
	private double[][] newFeatures = null;
	private double[] newLabels = null;

	// Held-out split. NOTE(review): trainSize == features.length below, so
	// these arrays are always empty; kept only for compatibility.
	private double[][] testFeatures = null;
	private double[] testLabels = null;

	// Weighted bootstrap sample used to train the current round's tree.
	private double[][] selectedFeatures = null;
	private double[] selectedLabels = null;

	/** Base learners accepted so far (sample error <= 0.5). */
	ArrayList<DecisionTree2> trees = new ArrayList<DecisionTree2>();
	/** Voting weight log((1 - m) / m) for each accepted tree. */
	ArrayList<Double> treeWeights = new ArrayList<Double>();
	/** Sampling weight of each training tuple, keyed by its row index. */
	private HashMap<Integer, Double> qualityMap = new HashMap<Integer, Double>();
	/** Row indices drawn (with replacement) for the current round. */
	private ArrayList<Integer> choosenLines = new ArrayList<Integer>();
	/** Positions within the current sample that the tree predicted correctly. */
	private ArrayList<Integer> truePredictedLines = new ArrayList<Integer>();

	public AdaBoost() {
	}

	/** Initializes every training tuple's sampling weight to the uniform 1/N. */
	public void setQualityMap() {
		double quality = 1.0 / newFeatures.length;
		for (int i = 0; i < newFeatures.length; i++) {
			qualityMap.put(i, quality);
		}
	}

	/**
	 * Runs K boosting rounds over the given training data.
	 *
	 * @param isCategory per-attribute flag: true = categorical, false = continuous
	 * @param features   training tuples, one row per example
	 * @param labels     class label of each row
	 */
	@Override
	public void train(boolean[] isCategory, double[][] features, double[] labels) {
		int trainSize = features.length;
		int preSize = features.length - trainSize; // always 0 — see field note

		selectedFeatures = new double[trainSize][];
		selectedLabels = new double[trainSize];

		newFeatures = Arrays.copyOfRange(features, 0, trainSize);
		newLabels = Arrays.copyOfRange(labels, 0, trainSize);

		testFeatures = Arrays.copyOfRange(features, trainSize, features.length);
		testLabels = Arrays.copyOfRange(labels, trainSize, features.length);

		setQualityMap();
		for (int i = 1; i <= K; i++) {
			// BUG FIX: clear per-round state up front. The original cleared
			// these lists only in the accepted branch, so a rejected round
			// (m > 0.5) left its indices in choosenLines and the next
			// buildSelectedFeaturesLabels() overran selectedFeatures with an
			// ArrayIndexOutOfBoundsException.
			choosenLines.clear();
			truePredictedLines.clear();

			produceNewLines();
			buildSelectedFeaturesLabels();
			DecisionTree2 tree = new DecisionTree2();
			tree.train(isCategory, selectedFeatures, selectedLabels);
			double m = computeM(tree);
			if (m > 0.5) {
				continue; // tree is worse than chance on its sample: discard it
			}
			// Clamp a perfect round so log((1 - m) / m) and m / (1 - m)
			// stay finite instead of poisoning the weights with NaN.
			if (m == 0) {
				m = 1e-10;
			}
			trees.add(tree);
			treeWeights.add(Math.log((1 - m) / m));
			updateQuality(m);
		}
	}

	/**
	 * Shrinks the weight of every correctly classified tuple by m / (1 - m)
	 * (&lt; 1 since m &lt; 0.5), then rescales all weights so the total mass
	 * is unchanged. Misclassified tuples thus become relatively more likely
	 * to be drawn next round.
	 *
	 * @param m the accepted tree's weighted error on its sample
	 */
	public void updateQuality(double m) {
		double w = m / (1 - m);

		// Total weight before the update; used to renormalize afterwards.
		// (Replaces the original's unchecked HashMap clone.)
		double oldSum = 0;
		for (double v : qualityMap.values()) {
			oldSum += v;
		}

		for (int i = 0; i < selectedFeatures.length; i++) {
			int line = choosenLines.get(i);
			if (truePredictedLines.contains(i)) {
				qualityMap.put(line, qualityMap.get(line) * w);
			}
		}

		double newSum = 0;
		for (double v : qualityMap.values()) {
			newSum += v;
		}

		double ratio = oldSum / newSum;
		// Value-only mutation during iteration is safe; no structural change.
		for (Entry<Integer, Double> entry : qualityMap.entrySet()) {
			entry.setValue(entry.getValue() * ratio);
		}
	}

	/**
	 * Computes the tree's error rate on the current sample and records, as a
	 * side effect, which sample positions were predicted correctly
	 * (into {@link #truePredictedLines}).
	 *
	 * @return fraction of sample tuples the tree misclassifies
	 */
	public double computeM(DecisionTree2 tree) {
		int size = selectedFeatures.length;
		int wrong = 0;
		for (int i = 0; i < size; i++) {
			double res = tree.predict(selectedFeatures[i]);
			if (res != selectedLabels[i]) {
				wrong++;
			} else {
				truePredictedLines.add(i);
			}
		}
		return wrong * 1.0 / size;
	}

	/**
	 * Draws N row indices with replacement via roulette-wheel selection over
	 * the current weight distribution, appending them to {@link #choosenLines}.
	 */
	public void produceNewLines() {
		double sum = 0;
		for (double v : qualityMap.values()) {
			sum += v;
		}
		int lineNums = newFeatures.length;
		for (int i = 0; i < lineNums; i++) {
			double target = Math.random() * sum;
			double running = 0;
			for (Entry<Integer, Double> entry : qualityMap.entrySet()) {
				running += entry.getValue();
				if (running >= target) {
					choosenLines.add(entry.getKey());
					break;
				}
			}
		}
	}

	/**
	 * Materializes the drawn row indices into the selectedFeatures /
	 * selectedLabels arrays, deep-copying each feature row so later
	 * mutation of the sample cannot corrupt the original training data.
	 */
	public void buildSelectedFeaturesLabels() {
		int k = 0;
		int lineSize = newFeatures[0].length;
		for (Integer line : choosenLines) {
			selectedFeatures[k] = Arrays.copyOf(newFeatures[line], lineSize);
			selectedLabels[k] = newLabels[line];
			k++;
		}
	}

	/**
	 * Predicts by weighted majority vote: each accepted tree votes for its
	 * predicted label with its voting weight; the label with the largest
	 * total wins.
	 */
	@Override
	public double predict(double[] features) {
		HashMap<Double, Double> votes = new HashMap<Double, Double>();
		for (int i = 0; i < trees.size(); i++) {
			double label = trees.get(i).predict(features);
			Double prev = votes.get(label);
			votes.put(label, (prev == null ? 0 : prev) + treeWeights.get(i));
		}
		double best = 0, bestLabel = 0;
		for (Entry<Double, Double> entry : votes.entrySet()) {
			if (entry.getValue() > best) {
				best = entry.getValue();
				bestLabel = entry.getKey();
			}
		}
		return bestLabel;
	}
}


分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics