Data Mining: Implementing the Naive Bayes Algorithm

 

This is an assignment from my data mining course: implementing a Naive Bayes classifier. The training datasets are breast-cancer.data and segment.data, provided by the computer science department of the University of California. My implementation reaches a prediction accuracy of 0.72 on the categorical-attribute dataset and 0.79 on the continuous-attribute dataset.

The code is as follows:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.util.ArrayList;
import java.util.HashMap;


/**
 *
 * @author daq
 */

class Store{//key identifying P(Xi|Cj) for a categorical attribute
	int attr;//index of the attribute
	double attrValue;//value of the attribute
	double lable;//value of the associated class label
	
	@Override
	public int hashCode() {//hashCode built from attr, attrValue and lable
		final int prime = 31;
		int result = 1;
		result = prime * result + attr;
		long temp;
		temp = Double.doubleToLongBits(attrValue);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		temp = Double.doubleToLongBits(lable);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		return result;
	}
	
	@Override
	public boolean equals(Object obj) {//equals consistent with hashCode above
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Store other = (Store) obj;
		if (attr != other.attr)
			return false;
		if (Double.doubleToLongBits(attrValue) != Double
				.doubleToLongBits(other.attrValue))
			return false;
		if (Double.doubleToLongBits(lable) != Double
				.doubleToLongBits(other.lable))
			return false;
		return true;
	}
}

class Store2{//key identifying P(Xi|Cj) for a continuous attribute
	int attr;//index of the attribute
	double label;//value of the class label
	
	@Override
	public int hashCode() {//hashCode built from attr and label
		final int prime = 31;
		int result = 1;
		result = prime * result + attr;
		long temp;
		temp = Double.doubleToLongBits(label);
		result = prime * result + (int) (temp ^ (temp >>> 32));
		return result;
	}
	
	@Override
	public boolean equals(Object obj) {//equals consistent with hashCode above
		if (this == obj)
			return true;
		if (obj == null)
			return false;
		if (getClass() != obj.getClass())
			return false;
		Store2 other = (Store2) obj;
		if (attr != other.attr)
			return false;
		if (Double.doubleToLongBits(label) != Double
				.doubleToLongBits(other.label))
			return false;
		return true;
	}
}



public class NaiveBayes extends Classifier {

	boolean []myIsCategory;
	double [][]myFeatures;
	double []myLabels;
	
	ArrayList<Double> labelKinds=new ArrayList<Double>();//distinct values taken by the class label
	HashMap<Double,Double> labelKindsProp=new HashMap<Double,Double>();//prior probability of each label
	HashMap<Store,Double> attrLabelsProp=new HashMap<Store,Double>();//P(Xi|Cj) for categorical attributes
	HashMap<Integer,ArrayList<Double>> valueKinds=new HashMap<Integer,ArrayList<Double>>();//distinct values of each categorical attribute
	HashMap<Store2,Double> averageAttrs=new HashMap<Store2,Double>();//per-label mean of each attribute
	HashMap<Store2,Double> standDev=new HashMap<Store2,Double>();//per-label standard deviation of each attribute
	
    public NaiveBayes() {
    	
    }

    public void setLabelKinds(double[] labels){//collect the distinct label values
    	for(int i=0;i<labels.length;i++){
    		boolean flag=true;
    		for(int j=0;j<labelKinds.size();j++){
    			if(labelKinds.get(j)==labels[i]){
    				flag=false;
    				break;
    			}
    		}
    		if(flag==true)
    			this.labelKinds.add(labels[i]);
    	}
    }
    
    public void setLabelKindsProp(double[] labels){//compute each label's share of all tuples (the class prior)
    	int sum=labels.length;
    	for(int i=0;i<labelKinds.size();i++){
    		double label=labelKinds.get(i);
    		int num=0;
    		for(int j=0;j<labels.length;j++){
    			if(labels[j]==label)
    				num++;
    		}
    		labelKindsProp.put(label,num*1.0/sum);
    	}
    }
    
    public void setValueKinds(double[][] features){//for categorical attributes, record the distinct values that occur
    	int lineNums=features.length;
    	for(int i=0;i<features[0].length;i++){
    		if(myIsCategory[i]){
	    		for(int j=0;j<lineNums;j++){
	    			ArrayList<Double> values=valueKinds.get(i);
	    			if(values==null)
	    				values=new ArrayList<Double>();
	    			if(!values.contains(features[j][i])){
	    				values.add(features[j][i]);
	    				valueKinds.put(i,values);
	    			}
	    		}
    		}
    	}
    }
    
    public void setAttrLabelsProp(boolean[] isCategory, double[][] features){//for categorical attributes, compute P(Xi|Cj): the fraction of tuples with value Xi among those labeled Cj
    	for(int i=0;i<labelKinds.size();i++){
    		double label=labelKinds.get(i);
    		for(int j=0;j<features[0].length;j++){
    			if(myIsCategory[j]){
	    			ArrayList<Double> values=valueKinds.get(j);
	    			int num[]=new int[values.size()];
	    			for(int k=0;k<features.length;k++){
			    			for(int l=0;l<values.size();l++){
			    				double value=values.get(l);	
			    				if(features[k][j]==value&&myLabels[k]==label)
			    					num[l]++;
			    			}
	    			}
	    			double labelNum=labelKindsProp.get(label)*features.length;
	    			for(int k=0;k<num.length;k++){
	    				double valuetmp=values.get(k);
	    				double prop=num[k]*1.0/labelNum;
	    				Store store=new Store();
	    				store.attr=j;
	    				store.attrValue=valuetmp;
	    				store.lable=label;
	    				attrLabelsProp.put(store,prop);
	    			}
    			}
    		}
    	}
    }
    
    public double getAverage(double [][]data,int attr,double lable){//mean of an attribute over the tuples carrying a given label
    	int lineNums=data.length;
    	double sum=0;
    	int k=0;
    	for(int i=0;i<lineNums;i++){
    		if(myLabels[i]==lable){
    			sum+=data[i][attr];
    			k++;
    		}
    	}
    	return sum*1.0/k;
    }
    
    public void setAverage(){//compute and store, for every label, the mean of each attribute (only the continuous ones are used at prediction time)
    	for(int i=0;i<labelKinds.size();i++){
    		double label=labelKinds.get(i);
    		for(int j=0;j<myFeatures[0].length;j++){
    			double ave=getAverage(myFeatures,j, label);
    			Store2 store=new Store2();
    			store.attr=j;
    			store.label=label;
    			averageAttrs.put(store,ave);
    		}
    	}
    }
    
    public double getStanDev(double [][]data,int attr,double lable){//standard deviation of an attribute over the tuples carrying a given label
    	int lineNums=data.length;
    	double sum=0;
    	int k=0;
    	for(int i=0;i<lineNums;i++){
    		if(myLabels[i]==lable){
    			sum+=Math.pow(data[i][attr],2);
    			k++;
    		}
    	}
    	double temp=sum*1.0/k;
    	double res2=temp-Math.pow(getAverage(data,attr,lable),2);
    	return Math.sqrt(res2);
    }
    
    public void setStanDev(){//compute and store, for every label, the standard deviation of each attribute
    	for(int i=0;i<labelKinds.size();i++){
    		double label=labelKinds.get(i);
    		for(int j=0;j<myFeatures[0].length;j++){
    			double sd=getStanDev(myFeatures,j, label);
    			Store2 store=new Store2();
    			store.attr=j;
    			store.label=label;
    			standDev.put(store,sd);
    		}
    	}
    }
    
    public double getG(double x,double u,double m){//P(X|C) for a continuous attribute: Gaussian density with mean u and standard deviation m
    	double t=-Math.pow(x-u,2)/(2*Math.pow(m,2));
    	return Math.exp(t)/(Math.sqrt(2.0*Math.PI)*m);
    }
    
    public double getOverallAve(int attr){//overall mean of an attribute, used in preprocessing to replace NaN (missing) values
    	double sum=0;
    	int count=0;
    	for(int i=0;i<myFeatures.length;i++){
    		if(myFeatures[i][attr]==myFeatures[i][attr]){//skip NaN entries, otherwise the mean itself would become NaN
    			sum+=myFeatures[i][attr];
    			count++;
    		}
    	}
    	return count==0?0:sum/count;
    }
    
    public void preHandle(){//preprocessing: replace every NaN (missing) value with the overall mean of its attribute
		for(int i=0;i<myFeatures.length;i++){
			double line[]=myFeatures[i];
			for(int j=0;j<line.length;j++){
				if(line[j]!=line[j]){//NaN is the only double value that is not equal to itself
					myFeatures[i][j]=this.getOverallAve(j);
				}
			}
		}
	}
    
    @Override
    public void train(boolean[] isCategory, double[][] features, double[] labels) {
    	this.myIsCategory=isCategory;
    	this.myFeatures=features;
    	this.myLabels=labels;
    	this.preHandle();
    	this.setLabelKinds(labels);
    	this.setLabelKindsProp(labels);
    	this.setValueKinds(features);
    	this.setAttrLabelsProp(isCategory,features);
    	this.setAverage();
    	this.setStanDev();
    }

    @Override
    public double predict(double[] features) {
    	double resMax=0,resMaxIndex=0;
    	for(int i=0;i<this.labelKinds.size();i++){
    		double res=1;
    		double label=this.labelKinds.get(i);
    		double labelProp=this.labelKindsProp.get(label);
    		for(int j=0;j<features.length;j++){
    			double feature=features[j];
    			if(myIsCategory[j]){
    				Store store=new Store();
    				store.attr=j;
    				store.attrValue=feature;
    				store.lable=label;
    				if(attrLabelsProp.get(store)!=null){
    					double prop=attrLabelsProp.get(store);
    					res*=prop;
    				}else
    					res=0;//this attribute value never co-occurred with the label in training, so the product collapses to zero
    			}else{
    				Store2 s1=new Store2();
    				s1.attr=j;
    				s1.label=label;
    				double ave=this.averageAttrs.get(s1);
    				double stanDev=this.standDev.get(s1);
    				if(stanDev==0)
    					continue;
    				double g=getG(features[j],ave,stanDev);
    				res*=g;
    			}
    		}
    		res*=labelProp;
    		if(res>resMax){    			
    			resMax=res;
    			resMaxIndex=label;
    		}
    	}
    	return resMaxIndex;
    }
}

In this code, the train method builds the model and the predict method classifies a new tuple. The isCategory parameter of train records, for each attribute, whether it is continuous (false) or categorical (true); the features[][] array holds the training tuples, and the labels array holds the class label of each tuple. predict then returns the label C that maximizes P(C)·∏P(Xi|C), taking P(Xi|C) from the frequency table for categorical attributes and from the Gaussian density for continuous ones.
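
As a rough usage sketch (the abstract Classifier base class comes from the course framework and is not shown here; the tiny two-attribute dataset below is made up purely for illustration), training and prediction might be driven like this:

boolean[] isCategory = {true, false};   // first attribute categorical, second continuous
double[][] features = {
        {0, 1.2},
        {0, 1.5},
        {1, 3.1},
        {1, 2.8}
};
double[] labels = {0, 0, 1, 1};          // class label of each training tuple

NaiveBayes nb = new NaiveBayes();
nb.train(isCategory, features, labels);  // estimates the priors, the P(Xi|Cj) table, and the per-label means and standard deviations
double predicted = nb.predict(new double[]{1, 3.0});  // comes out as label 1 on this toy data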
