`
wangqisen
  • 浏览: 47024 次
文章分类
社区版块
存档分类
最新评论

k-means算法

 
阅读更多

下面是我对该算法的实现（Lloyd 迭代：先随机选取 K 个互不相同的样本作为初始中心，然后交替执行“把每个样本分配到最近的中心”和“把每个中心更新为其所属样本的均值”，直到样本的分配不再变化）：

/**
 * Plain k-means clustering (Lloyd iterations) over dense {@code double} feature rows.
 *
 * <p>Before clustering, every NaN feature value is replaced by the mean of the non-NaN values in
 * the same column (0.0 if the whole column is NaN). Distances are squared Euclidean distances.
 * The caller's feature array is never modified.
 *
 * <p>An instance keeps per-call working state and is not safe for concurrent use, but it may be
 * reused for any number of sequential {@link #train} calls.
 */
public class Kmeans {

    /** Upper bound on update/assign rounds, so a degenerate input can never loop forever. */
    private static final int MAX_ITERATIONS = 10000;

    /** Number of clusters requested by the current {@code train} call. */
    private int K;
    /** Number of attributes (columns) per sample. */
    private int colsNum;
    /** Number of samples (rows). */
    private int rowsNum;
    /** Current cluster centers: {@code kMedians[k]} is the center of cluster k. */
    private double[][] kMedians = null;
    /** Private deep copy of the training data, with NaNs imputed by {@link #preHandle()}. */
    private double[][] myFeatures = null;
    /** {@code assignment[i]} is the cluster of row i, or -1 before the first assignment pass. */
    private int[] assignment = null;
    /** Random source used to pick the initial centers. */
    private final java.util.Random random;

    /** Creates a clusterer whose initial centers are chosen with an unseeded random generator. */
    public Kmeans() {
        this(new java.util.Random());
    }

    /**
     * Creates a clusterer with a seeded random generator, so repeated {@link #train} calls on
     * the same data pick the same initial centers.
     */
    public Kmeans(long seed) {
        this(new java.util.Random(seed));
    }

    private Kmeans(java.util.Random random) {
        this.random = random;
    }

    /**
     * Clusters {@code features} into {@code K} groups.
     *
     * @param features       input samples, {@code features[i]} is sample i; every row must have the
     *                       same length. NaN values are imputed with the column mean. Not modified.
     * @param K              number of clusters, {@code 1 <= K <= features.length}
     * @param clusterCenters output, at least K rows of at least {@code features[0].length} values;
     *                       {@code clusterCenters[k]} receives the k-th cluster center
     * @param clusterIndex   output, at least {@code features.length} entries;
     *                       {@code clusterIndex[i]} receives the cluster of sample i
     * @throws IllegalArgumentException if the inputs are empty/ragged, K is out of range, the output
     *                                  arrays are too small, or there are fewer than K distinct samples
     */
    public void train(double[][] features, int K, double[][] clusterCenters, int[] clusterIndex) {
        if (features == null || features.length == 0 || features[0] == null) {
            throw new IllegalArgumentException("features must contain at least one row");
        }
        if (K <= 0 || K > features.length) {
            throw new IllegalArgumentException(
                    "K must be between 1 and " + features.length + ", got " + K);
        }
        int rows = features.length;
        int cols = features[0].length;
        checkOutputs(clusterCenters, clusterIndex, K, rows, cols);

        this.rowsNum = rows;
        this.colsNum = cols;
        this.K = K;
        // Deep copy: imputing NaNs must not leak into the caller's arrays.
        this.myFeatures = new double[rows][];
        for (int i = 0; i < rows; i++) {
            if (features[i] == null || features[i].length != cols) {
                throw new IllegalArgumentException("row " + i + " must have " + cols + " columns");
            }
            myFeatures[i] = Arrays.copyOf(features[i], cols);
        }
        this.kMedians = new double[K][cols];
        // Fresh per call, so the instance can be trained again.
        this.assignment = new int[rows];
        Arrays.fill(assignment, -1);

        try {
            preHandle();
            init();
            // The first pass always reports a change because every row starts unassigned (-1).
            boolean converged = setCluster();
            for (int iter = 0; !converged && iter < MAX_ITERATIONS; iter++) {
                reSetMedian();
                converged = setCluster();
            }
            for (int k = 0; k < K; k++) {
                System.arraycopy(kMedians[k], 0, clusterCenters[k], 0, cols);
            }
            System.arraycopy(assignment, 0, clusterIndex, 0, rows);
        } finally {
            // Drop references to the working data, even if training failed.
            kMedians = null;
            myFeatures = null;
            assignment = null;
        }
    }

    /** Verifies that the caller-supplied output buffers are large enough. */
    private static void checkOutputs(double[][] clusterCenters, int[] clusterIndex,
                                     int k, int rows, int cols) {
        if (clusterCenters == null || clusterCenters.length < k) {
            throw new IllegalArgumentException("clusterCenters must have at least " + k + " rows");
        }
        for (int i = 0; i < k; i++) {
            if (clusterCenters[i] == null || clusterCenters[i].length < cols) {
                throw new IllegalArgumentException(
                        "clusterCenters[" + i + "] must have at least " + cols + " columns");
            }
        }
        if (clusterIndex == null || clusterIndex.length < rows) {
            throw new IllegalArgumentException("clusterIndex must have at least " + rows + " entries");
        }
    }

    /**
     * Replaces every NaN in the working data with its column's mean over the non-NaN values.
     * A column with no non-NaN values is filled with 0.0; a constant column cannot affect distances.
     */
    public void preHandle() {
        // Compute all column means once, from the data before any replacement.
        double[] means = new double[colsNum];
        for (int j = 0; j < colsNum; j++) {
            double mean = getAverage(j);
            means[j] = Double.isNaN(mean) ? 0.0 : mean;
        }
        for (double[] row : myFeatures) {
            for (int j = 0; j < colsNum; j++) {
                if (Double.isNaN(row[j])) {
                    row[j] = means[j];
                }
            }
        }
    }

    /**
     * Returns the mean of the non-NaN values in column {@code attr} of the working data,
     * or NaN if the column has no such values.
     */
    public double getAverage(int attr) {
        double sum = 0;
        int count = 0;
        for (int i = 0; i < rowsNum; i++) {
            double value = myFeatures[i][attr];
            if (!Double.isNaN(value)) {
                sum += value;
                count++;
            }
        }
        return count == 0 ? Double.NaN : sum / count;
    }

    /**
     * Assigns every row to its nearest center (lowest index wins ties).
     *
     * @return true if no row changed cluster, i.e. the algorithm has converged
     */
    public boolean setCluster() {
        boolean unchanged = true;
        for (int i = 0; i < rowsNum; i++) {
            int nearest = nearestCenter(myFeatures[i]);
            if (nearest != assignment[i]) {
                unchanged = false;
            }
            assignment[i] = nearest;
        }
        return unchanged;
    }

    /** Returns the index of the center closest to {@code row}. */
    private int nearestCenter(double[] row) {
        int best = 0;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (int k = 0; k < kMedians.length; k++) {
            double d = computeDistance(row, kMedians[k]);
            if (d < bestDistance) {
                best = k;
                bestDistance = d;
            }
        }
        return best;
    }

    /**
     * Moves each center to the mean of the rows currently assigned to it. A cluster with no rows
     * keeps its previous center instead of becoming NaN (0/0).
     */
    public void reSetMedian() {
        double[][] sum = new double[K][colsNum];
        int[] num = new int[K];
        for (int i = 0; i < rowsNum; i++) {
            int cluster = assignment[i];
            if (cluster < 0) {
                continue; // not assigned yet
            }
            num[cluster]++;
            double[] row = myFeatures[i];
            for (int j = 0; j < colsNum; j++) {
                sum[cluster][j] += row[j];
            }
        }
        for (int k = 0; k < K; k++) {
            if (num[k] == 0) {
                continue;
            }
            for (int j = 0; j < colsNum; j++) {
                kMedians[k][j] = sum[k][j] / num[k];
            }
        }
    }

    /** Returns the squared Euclidean distance between {@code row} and {@code median}. */
    public double computeDistance(double[] row, double[] median) {
        double sum = 0;
        for (int i = 0; i < colsNum; i++) {
            double diff = row[i] - median[i];
            sum += diff * diff;
        }
        return sum;
    }

    /**
     * Chooses K initial centers: distinct, NaN-free rows picked in uniformly random order.
     *
     * @throws IllegalArgumentException if fewer than K distinct NaN-free rows exist; the
     *                                  original rejection-sampling loop spun forever in that case
     */
    public void init() {
        // Fisher-Yates shuffle of the row indices, then take the first K acceptable rows.
        int[] order = new int[rowsNum];
        for (int i = 0; i < rowsNum; i++) {
            order[i] = i;
        }
        for (int i = rowsNum - 1; i > 0; i--) {
            int j = random.nextInt(i + 1);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
        }
        int chosen = 0;
        for (int p = 0; p < rowsNum && chosen < K; p++) {
            double[] candidate = myFeatures[order[p]];
            // Compare only against centers already chosen, not the zero-filled unused slots.
            if (containsNaN(candidate) || equalsAnyCenter(candidate, chosen)) {
                continue;
            }
            kMedians[chosen++] = Arrays.copyOf(candidate, colsNum);
        }
        if (chosen < K) {
            throw new IllegalArgumentException("need at least K=" + K
                    + " distinct samples without NaN, found " + chosen);
        }
    }

    /** Returns true if any value in {@code row} is NaN. */
    private static boolean containsNaN(double[] row) {
        for (double v : row) {
            if (Double.isNaN(v)) {
                return true;
            }
        }
        return false;
    }

    /** Returns true if {@code row} equals (by {@code ==} per value) one of the first {@code count} centers. */
    private boolean equalsAnyCenter(double[] row, int count) {
        for (int k = 0; k < count; k++) {
            double[] center = kMedians[k];
            boolean same = true;
            for (int j = 0; j < colsNum && same; j++) {
                same = center[j] == row[j];
            }
            if (same) {
                return true;
            }
        }
        return false;
    }

}


分享到:
评论

相关推荐

Global site tag (gtag.js) - Google Analytics