数据挖掘之凝聚层次聚类算法 AGNES 学习笔记

层次聚类是一种很直观而且重要的算法。与 K-means 一样,和许多聚类方法相比,这些方法相对较老,但是它们仍然被广泛使用。

有两种产生层次聚类的基本方法。

  • 凝聚的:从点作为个体簇开始,每一步合并两个最接近的簇。这需要定义簇的邻近度概念。
  • 分裂的:从包含所有点的某个簇开始,每一步分裂一个簇,直到仅剩下单点簇。在这种情况下,我们需要确定每一步分裂哪个簇,以及如何分裂。

到目前为止,凝聚层次聚类技术最常见。

101335h3te1bxmbk86zeqe

邻近度定义

  • 单链

    邻近度定义为两个不同簇中任意两点之间的最短距离(最大相似度)

    singlelink sl

  • 全链

    邻近度定义为两个不同簇中任意两点之间的最长距离(最小相似度)

    completelink cl

  • 组平均

    邻近度定义为两个不同簇中所有点对邻近度的平均值

    groupaverage groupaverageformula ga

  • Ward

    邻近度定义为两个不同簇合并时导致的平方误差增量

  • 质心方法

    邻近度定义为两个不同簇质心之间的距离

不同邻近性度量之间的比较:

compare

算法伪码

输入:包含 n 个对象的数据集,目标聚类个数 m。

输出:所有生成的簇。

1. 如果需要,计算量邻近度矩阵
2. REPEAT
3.     合并最接近的两个簇
4.     更新邻近度矩阵,以反映新的簇与原来的簇之间的临近度
5. UNTIL 达到目标聚类个数

优缺点

  • 优点

    层次聚类最大的优点,就是它一次性地得到了整个聚类的过程,只要得到了下图那样的聚类树,想要分多少个簇都可以直接根据树结构来得到结果。

  • 缺点

    层次聚类的缺点是计算量比较大,因为要每次都要计算多个簇内所有数据点的两两距离。另外,由于层次聚类使用的是贪心算法,得到的显然只是局域最优,不一定就是全局最优。

示例代码

import org.junit.Test;

import java.util.Vector;

/**
 * Created by Silocean on 2016-12-24.
 */
public class SimpleAGNES {

    @Test
    public void test() {
        int clusterNum = 4;
        Vector<SimplePoint> points = getInputData();
        Vector<SimpleCluster> clusters = agnes(points, clusterNum);
        for (SimpleCluster cluster : clusters) {
            for (SimplePoint point : cluster.getPoints()) {
                System.out.println(point);
            }
            System.out.println("========" + cluster.getClusterName() + "========\n");
        }
    }

    /**
     * 初始化测试数据
     *
     * @return
     */
    private Vector<SimplePoint> getInputData() {
        Vector<SimplePoint> points = new Vector<>();

        points.add(new SimplePoint(2, 3));
        points.add(new SimplePoint(2, 4));
        points.add(new SimplePoint(1, 4));
        points.add(new SimplePoint(1, 3));
        points.add(new SimplePoint(2, 2));
        points.add(new SimplePoint(3, 2));

        points.add(new SimplePoint(8, 7));
        points.add(new SimplePoint(8, 6));
        points.add(new SimplePoint(7, 7));
        points.add(new SimplePoint(7, 6));
        points.add(new SimplePoint(8, 5));

        points.add(new SimplePoint(100, 2));

        points.add(new SimplePoint(8, 20));
        points.add(new SimplePoint(8, 19));
        points.add(new SimplePoint(7, 18));
        points.add(new SimplePoint(7, 17));
        points.add(new SimplePoint(8, 20));
        return points;
    }

    /**
     * AGNES
     *
     * @param points
     * @param clusterNum
     * @return
     */
    private Vector<SimpleCluster> agnes(Vector<SimplePoint> points, int clusterNum) {
        // 聚类结果集(初始时每个point是一个簇)
        Vector<SimpleCluster> clusters = initClusters(points);

        while (clusters.size() > clusterNum) { // 达到指定的簇个数时结束算法
            double min = Double.MAX_VALUE;

            int indexP = 0;
            int indexQ = 0;

            // 选择两个簇进行比较
            for (int i = 0; i < clusters.size(); i++) {
                for (int j = 0; j < clusters.size(); j++) {
                    if (i != j) {
                        SimpleCluster clusterP = clusters.get(i);
                        SimpleCluster clusterQ = clusters.get(j);

                        double clustersSimilarity = getTwoClustersSimilarityByGroupAverage(clusterP, clusterQ);
                        if (clustersSimilarity < min) {
                            min = clustersSimilarity;
                            indexP = i;
                            indexQ = j;
                        }
                    }
                }
            }
            // 合并两个距离最近的簇
            clusters = mergeClusters(clusters, indexP, indexQ);
        }
        return clusters;
    }

    /**
     * 两个簇之间的距离
     * 单链:两个不同簇中任意两点之间的最短距离
     *
     * @param clusterP
     * @param clusterQ
     * @return
     */
    private double getTwoClustersSimilarityBySingleLink(SimpleCluster clusterP, SimpleCluster clusterQ) {
        Vector<SimplePoint> pointsP = clusterP.getPoints();
        Vector<SimplePoint> pointsQ = clusterQ.getPoints();

        double min = Double.MAX_VALUE;
        // 比较两个簇中所有点
        for (SimplePoint pointP : pointsP) {
            for (SimplePoint pointQ : pointsQ) {
                double distance = getDistance(pointP, pointQ);
                if (distance < min) {
                    min = distance;
                }
            }
        }
        return min;
    }

    /**
     * 两个簇之间的距离
     * 全链:两个不同簇中任意两点之间的最长距离
     *
     * @param clusterP
     * @param clusterQ
     * @return
     */
    private double getTwoClustersSimilarityByCompleteLink(SimpleCluster clusterP, SimpleCluster clusterQ) {
        Vector<SimplePoint> pointsP = clusterP.getPoints();
        Vector<SimplePoint> pointsQ = clusterQ.getPoints();

        double max = Double.MIN_VALUE;
        // 比较两个簇中所有点
        for (SimplePoint pointP : pointsP) {
            for (SimplePoint pointQ : pointsQ) {
                double distance = getDistance(pointP, pointQ);
                if (distance > max) {
                    max = distance;
                }
            }
        }
        return max;
    }

    /**
     * 两个簇之间的距离
     * 组平均:两个不同簇中所有点对邻近度的平均值
     *
     * @param clusterP
     * @param clusterQ
     * @return
     */
    private double getTwoClustersSimilarityByGroupAverage(SimpleCluster clusterP, SimpleCluster clusterQ) {
        Vector<SimplePoint> pointsP = clusterP.getPoints();
        Vector<SimplePoint> pointsQ = clusterQ.getPoints();

        double sum = 0;
        double num = pointsP.size() * pointsQ.size();
        // 比较两个簇中所有点
        for (SimplePoint pointP : pointsP) {
            for (SimplePoint pointQ : pointsQ) {
                double distance = getDistance(pointP, pointQ);
                sum += distance;
            }
        }
        return sum / num;
    }

    /**
     * 合并两个距离最近的簇
     * (把q中的点全部添加到p中,并从clusters中删除q)
     *
     * @param clusters
     * @param indexP
     * @param indexQ
     * @return
     */
    private Vector<SimpleCluster> mergeClusters(Vector<SimpleCluster> clusters, int indexP, int indexQ) {
        SimpleCluster clusterP = clusters.get(indexP);
        SimpleCluster clusterQ = clusters.get(indexQ);

        Vector<SimplePoint> pointsP = clusterP.getPoints();
        Vector<SimplePoint> pointsQ = clusterQ.getPoints();

        for (SimplePoint point : pointsQ) {
            pointsP.add(point);
        }

        clusterP.setPoints(pointsP);
        clusters.remove(indexQ);
        return clusters;
    }

    /**
     * 计算两点之间距离
     *
     * @param p1
     * @param p2
     * @return
     */
    private double getDistance(SimplePoint p1, SimplePoint p2) {
        return Math.sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y));
    }

    /**
     * 初始化聚簇(每个point都是一个单独的簇)
     *
     * @param points
     * @return
     */
    private Vector<SimpleCluster> initClusters(Vector<SimplePoint> points) {
        Vector<SimpleCluster> initialClusters = new Vector<>();
        for (int i = 0; i < points.size(); i++) {
            Vector<SimplePoint> tmpPoints = new Vector<>();
            tmpPoints.add(points.get(i));

            SimpleCluster tmpCluster = new SimpleCluster();
            tmpCluster.setPoints(tmpPoints);
            tmpCluster.setClusterName("SimpleCluster:" + i);

            initialClusters.add(tmpCluster);
        }
        return initialClusters;
    }


}

class SimpleCluster {
    private Vector<SimplePoint> points = new Vector<>(); // 类簇中的样本点
    private String clusterName;

    public Vector<SimplePoint> getPoints() {
        return points;
    }

    public void setPoints(Vector<SimplePoint> points) {
        this.points = points;
    }

    public String getClusterName() {
        return clusterName;
    }

    public void setClusterName(String clusterName) {
        this.clusterName = clusterName;
    }
}

public class SimplePoint {

    double x;
    double y;

    public SimplePoint(double x, double y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public String toString() {
        return "SimplePoint{" +
                "x=" + x +
                ", y=" + y +
                '}';
    }
}

输出结果为:

SimplePoint{x=2.0, y=2.0}
SimplePoint{x=3.0, y=2.0}
SimplePoint{x=2.0, y=3.0}
SimplePoint{x=2.0, y=4.0}
SimplePoint{x=1.0, y=4.0}
SimplePoint{x=1.0, y=3.0}
========SimpleCluster:4========

SimplePoint{x=8.0, y=7.0}
SimplePoint{x=8.0, y=6.0}
SimplePoint{x=7.0, y=7.0}
SimplePoint{x=7.0, y=6.0}
SimplePoint{x=8.0, y=5.0}
========SimpleCluster:6========

SimplePoint{x=100.0, y=2.0}
========SimpleCluster:11========

SimplePoint{x=7.0, y=18.0}
SimplePoint{x=7.0, y=17.0}
SimplePoint{x=8.0, y=20.0}
SimplePoint{x=8.0, y=20.0}
SimplePoint{x=8.0, y=19.0}
========SimpleCluster:14========

参考资料