`
wangqisen
  • 浏览: 47104 次
文章分类
社区版块
存档分类
最新评论

数据挖掘之RandomForeast算法

 
阅读更多

RandomForest算法,精髓之处在于在建立决策树的时候,在每个节点进行属性选取时,是随机地选取部分属性,从中进行最优属性的选取,而不是在全部的所有属性中进行选择。建立了决策树森林之后,每次都要对这些不同的决策树进行预测,选出其中被预测最多的那个类别来作为最终的预测类别。在有5棵决策树时,我得出的对于离散属性的预测准确度为0.73,对于连续属性的预测准确度为0.96.

下面是我的RandomForest算法的代码,


/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

/**
 *
 * @author daq
 */
public class RandomForest extends Classifier {

	private int K=5;
	private ArrayList<DecisionTree> trees=new ArrayList<DecisionTree>();
	private HashMap<Double,Integer> map=null;
	private double newFeatures[][]=null;
	private double newLabels[]=null;
    public RandomForest() {
    }

    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
    	for(int i=0;i<K;i++){
    		DecisionTree tree=new DecisionTree();
    		produceNewFeaturesLabels(features,labels);
    		tree.train(isCategory,newFeatures, newLabels);
    		trees.add(tree);
    		newFeatures=null;
    		newLabels=null;
    	}
    	DecisionTree tree=new DecisionTree(); 
    	tree.train(isCategory, features, labels);
    	trees.add(tree);
    }
    
    public void produceNewFeaturesLabels(double[][] features, double[] labels){
    	int size=features.length;
    	newFeatures=new double[size][];
    	newLabels=new double[size];
    	int length=features[0].length;
    	for(int i=0;i<size;i++){
    		int ran=(int) (Math.random()*size);
    		newFeatures[i]=Arrays.copyOf(features[ran],length);
    		newLabels[i]=labels[ran];
    	}
    }

    @Override
    public double predict(double[] features) {
    	map=new HashMap<Double, Integer>();
       for(int i=0;i<K;i++){
        	DecisionTree tree=trees.get(i);
        	double label=tree.predict(features);
        	if(map.get(label)==null)
        		map.put(label,1);
        	else
        		map.put(label, map.get(label)+1);
        }
        double maxIndex=0;
        int max=-1;
        Iterator<Entry<Double,Integer>> ite=map.entrySet().iterator();
        while(ite.hasNext()){
        	Entry entry=ite.next();
        	if((Integer)entry.getValue()>max){
        		maxIndex=(Double)entry.getKey();
        		max=(Integer)entry.getValue();
        	}
        }
        return maxIndex;
    }
    	
}


分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics