数据挖掘之 DBSCAN 算法学习笔记

DBSCAN (Density-Based Spatial Clustering of Applications with Noise) 是一种基于密度的聚类算法,基于密度的聚类寻找被低密度区域分离的高密度区域。

图片2

算法基本定义

DBSCAN 中定义使用基于中心的方法定义密度,在基于中心的方法中,数据集中特定点的密度通过对该点 Eps 半径之内的点计数(包括点本身)来估计。

密度的基于中心的方法使得我们可以将点分类为:

  1. 核心点:稠密区域内部的点,该点的 Eps 邻域(半径)内至少包含 MinPts 个点。
  2. 边界点:稠密区域边缘上的点,该点在核心点的邻域内,但不是核心点。
  3. 噪声点:稀疏区域中的点,既不是核心点也不是边界点。

图片3

直接密度可达:给定一个点集合 D,如果 p 在 q 的 Eps 邻域内,而 q 是一个核心点,则称点 p 从点 q 出发时是直接密度可达的(directly density-reachable)。

密度可达:如果存在一个点链 p1,p2,…,pn,p1=q,pn=p 对于 pi∈D(1≤i≤n),pi+1 是从 pi 关于 Eps 和 MinPts 直接密度可达的,则点 p 是从点 q 关于 Eps 和 Minpts 密度可达的(density-reachable)。

密度相连:如果存在点 O∈D,使点 p 和 q 都是从 O 关于 Eps 和 MinPts 密度可达的,那么点 p 到 q 是关于 Eps 和 MinPts 密度相连的(density-connected)。

图片4

算法伪码

输入:包含 n 个对象的数据集,半径 Eps,最少数目 MinPts。

输出:所有生成的簇,达到密度要求。

1. REPEAT
2.     从数据库中抽取一个未处理过的点
3.     IF 抽出的点是核心点 
4.         THEN 找出所有从该点密度可达的对象,形成一个簇
5.     ELSE 抽出的点是边缘点(边界点或噪声点)
6.         THEN 跳出本次循环,寻找下一点
6. UNTIL 所有点都被处理

时间复杂度和空间复杂度

DBSCAN 的基本时间复杂度是 O(n * 找出 Eps 邻域中的点所需要的时间),其中 n 是点的个数。在最坏的情况下,时间复杂度是 O(n^2)。然而,在低维空间,有一些数据结构,如 kd 树,可以有效地检索特定点给定距离内的所有点,时间复杂度可以降低到O(nlogn)。即便对于高维数据,DBSCAN 的空间也是 O(n),因为对于每个点,它只需要维持少量数据,即簇标号和每个点是核心点、边界点还是噪声点的标识。

优缺点

  • 优点:
    1. 基于密度定义,相对抗噪音
    2. 能处理任意形状和大小的簇
    3. 不用事先决定要分成几类(K-means 就需要事先定义簇的个数)
  • 缺点
    1. 输入参数敏感,确定参数 Eps,MinPts 困难,若选取不当,将造成聚类质量下降。
    2. 由于在 DBSCAN 算法中,变量 Eps,MinPts 是全局惟一的,当空间聚类的密度不均匀、聚类间距离相差很大时,聚类质量较差。
    3. 对于高维数据,密度定义更加困难。

示例代码

import org.junit.Test;

import java.util.Vector;

/**
 * Created by Silocean on 2016-12-08.
 */
public class SimpleDBSCAN {

    private double eps = 1; // 半径
    private int minpts = 1; // 指定半径内点数阈值

    private Vector<SimplePoint> points = new Vector<>(); // 输入数据
    private Vector<SimplePoint> visitedPoints = new Vector<>(); // 访问过的点
    private Vector<SimplePoint> neighbours = new Vector<>(); // 单个cluster
    private Vector<Vector<SimplePoint>> clusters = new Vector<>(); // 聚类结果集(clusters)

    /**
     * 初始化测试数据
     */
    private void initData() {
        points.add(new SimplePoint(1, 1));
        points.add(new SimplePoint(2, 0));
        points.add(new SimplePoint(2, 1));
        points.add(new SimplePoint(1, 4));
        points.add(new SimplePoint(3, 3));
        points.add(new SimplePoint(3, 4));
        points.add(new SimplePoint(4, 2));
        points.add(new SimplePoint(4, 3));
        points.add(new SimplePoint(4, 4));
    }

    @Test
    public void test() {
        dbscan();
        for (Vector<SimplePoint> cluster : clusters) {
            for (SimplePoint point : cluster) {
                System.out.println(point);
            }
            System.out.println("==========");
        }
    }

    /**
     * DBSCAN
     */
    private void dbscan() {
        initData();
        for (SimplePoint point : points) {
            if (!isVisited(point)) { // 如果该点没有被访问过
                visit(point); // 标记为访问过
                neighbours = getNeighbours(point);
                if (neighbours.size() >= minpts) { // 如果邻域内的点数超过阈值,对其中每个点分别做进一步判断
                    int index = 0;
                    while (index < neighbours.size()) {
                        SimplePoint p = neighbours.get(index);
                        if (!isVisited(p)) {
                            visit(p);
                            Vector<SimplePoint> neigh = getNeighbours(p);
                            if (neigh.size() >= minpts) { // 如果邻域内某个点的邻域中的点数也超过阈值,加入到cluster中
                                neighbours = merge(neighbours, neigh);
                            }
                        }
                        index++;
                    }
                    clusters.add(neighbours);
                }
            }
        }
    }

    /**
     * 判断该点是否被访问过
     *
     * @param point
     * @return
     */
    private boolean isVisited(SimplePoint point) {
        return visitedPoints.contains(point);
    }

    /**
     * 访问该点
     *
     * @param point
     */
    private void visit(SimplePoint point) {
        visitedPoints.add(point);
    }

    /**
     * 获取某个点指定邻域内的所有点(初级版)
     * 进阶版可以使用kd-tree来降低时间复杂度(待实现)
     *
     * @param point
     * @return
     */
    private Vector<SimplePoint> getNeighbours(SimplePoint point) {
        Vector<SimplePoint> neighbours = new Vector<>();
        for (SimplePoint stayPoint : points) {
            if (getDistance(stayPoint, point) <= eps) {
                neighbours.add(stayPoint);
            }
        }
        return neighbours;
    }

    /**
     * 计算两点之间距离
     *
     * @param p1
     * @param p2
     * @return
     */
    private double getDistance(SimplePoint p1, SimplePoint p2) {
        return Math.sqrt((p1.x - p2.x) * (p1.x - p2.x) + (p1.y - p2.y) * (p1.y - p2.y));
    }

    /**
     * 将b中的所有点合并到a中
     *
     * @param a
     * @param b
     * @return
     */
    private Vector<SimplePoint> merge(Vector<SimplePoint> a, Vector<SimplePoint> b) {
        for (SimplePoint point : b) {
            if (!a.contains(point)) {
                a.add(point);
            }
        }
        return a;
    }

}

class SimplePoint {
    int x;
    int y;

    public SimplePoint(int x, int y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public String toString() {
        return "SimplePoint{" +
                "x=" + x +
                ", y=" + y +
                '}';
    }
}

输出结果为:

SimplePoint{x=1, y=1}
SimplePoint{x=2, y=1}
SimplePoint{x=2, y=0}
=====================
SimplePoint{x=1, y=4}
=====================
SimplePoint{x=3, y=3}
SimplePoint{x=3, y=4}
SimplePoint{x=4, y=3}
SimplePoint{x=4, y=4}
SimplePoint{x=4, y=2}
=====================

参考资料

  • 《数据挖掘导论》