K-Dimensional Tree

2014.12.24 11:05 Wed| 10 visits oi_2015| 2015_算法笔记| Text

K-D Tree

概述

K-D Tree主要能够解决 K 维空间中的距离最值(k大值)问题。

比如,给你平面(2维空间)上的 n 个点,多次询问,每次询问到某个给定点 $(x,y)$ 曼哈顿距离(欧几里德距离)最小的点。

如果暴力枚举的话显然是 $\mathcal O(n^2)$ 的,无法承受。而K-D Tree可以用以优化暴力。

构造

首先确定一个维度做为基准,找出这个维度坐标是目前所有点这个维度坐标中位数的点,将其作为根,其左子树保存的是所有这个维度坐标小于这个点该维度坐标的点,其右子树保存的是所有这个维度坐标大于这个点该维度坐标的点。对于这个点,我们要保存的是以这个点为根的所有的点的所有维度的坐标范围,即每一个维度坐标的上下界。随后对于两颗子树递归进行划分。

容易发现,KDTree的空间复杂度为 $\mathcal O(n)$ 。

维度的确定

一种说法是每一次选择方差最小的那一维。但实践中,我们只需要依次枚举维度进行划分即可。例如三维空间中先按照 $x$ 坐标划分,再按照 $y$ 坐标划分,再按照 $z$ 坐标划分,再按照 $x$ 坐标划分,以此类推。

代码实现

寻找中位数点可以利用STL函数 nth_element 完成,这样建树的复杂度就为 $O(nlogn)$ 。

另外由于是优化暴力,所以常数优化非常关键,因此K-D Tree的主体需要用指针实现,否则很容易被Karp-De-Chant。

处理询问

对于询问,从树的根节点开始,首先用这个点对应的平面上的点更新答案。现在考虑左儿子和右儿子,我们分别计算出询问点到达两个矩形的估价,记作$distl,distr$。优先遍历曼哈顿距离比较小的儿子的子树。之后,考虑另外一个儿子,显然当以另外一个儿子为根的子树可能存在更优的答案时我们才会去遍历它。这等价于 $disl(disr)<nowans$ ,当满足这个条件时我们才去遍历另一颗子树。

一般情况下这个剪枝非常靠谱,它只需要 $\mathcal O(\sqrt{n})$ 的时间就能完成一次询问!

不过这是为什么呢?至今没有找到证明,不过确实是卡不掉的。

处理插入

我们只需要从根节点开始向下插入就好了。首先更新当前点的各维度上下界,随后按照当前的维度看一下应该是插入到左子树还是右子树,若对应子树为空直接插入,否则再递归插入即可。

这样一次插入的复杂度为$\mathcal O(logn)$。但是如果随着插入树变得不平衡怎么办?一种方法是每隔若干次操作对树进行暴力重构,或者像替罪羊树一样设立一个平衡因子。但是不进行处理很多情况下也没关系。在插入的点是随机的情况下可以姑且认为每次插入为 $\mathcal O(logn)$ 。

距离估价

我们要求的经常是曼哈顿距离的最值以及欧几里德距离的最值。下面列举了处理这些问题时的估价。

曼哈顿距离最小值:到空间边界或内部的最小曼哈顿距离。

曼哈顿距离最大值:到空间端点的最大曼哈顿距离。

欧几里德距离最小值:当前点划分依据的维度坐标与边界距离最小值的平方。

欧几里德距离最大值:每一维度坐标与边界距离最大值的平方和。

K-D Tree的时间复杂度:若求出第K大单组询问的时间为$O(\sqrt{n}logk)$

代码实现

Source: BZOJ2648&&BZOJ2716

#include <bits/stdc++.h>
using namespace std;
#define N 1000005
#define inf 0x3f3f3f3f
#define lson(x) (x->ch[0])
#define rson(x) (x->ch[1])
#define son(x,d) (x->ch[d])

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

int pos;
struct point
{
    int x,y;
    point(){}
    point(int _x,int _y):x(_x),y(_y) {}
    point(bool flag):x(read()),y(read()) {}
    bool operator <(const point &p) const
    {
        return pos?y<p.y:x<p.x;
    }
    inline friend ostream &operator <<(ostream &out, const point &p)
    {
        out<<'('<<p.x<<','<<p.y<<')';
        return out;
    }
    inline friend int dist(const point &p1,const point &p2)
    {
        return abs(p1.x-p2.x)+abs(p1.y-p2.y);
    }
};

struct kdtree
{
    point p;
    kdtree *ch[2];
    int l,r,u,d;
    kdtree(){}
    kdtree(const point &_p):p(_p) { l=r=_p.x; u=d=_p.y; }
    void *operator new(size_t size);
    inline void update(int c)
    {
        l=min(l,ch[c]->l); r=max(r,ch[c]->r);
        u=max(u,ch[c]->u); d=min(d,ch[c]->d);
    }
    inline int mindist(const point &_p)
    {
        int re=0;
        if(_p.x<l) re+=l-_p.x;
        else if(_p.x>r) re+=_p.x-r;
        if(_p.y>u) re+=_p.y-u;
        else if(_p.y<d) re+=d-_p.y;
        return re;
    }
}newnode[N],*null=new kdtree();
void *kdtree::operator new(size_t size)
{
    static kdtree *P=newnode;
    return P++;
}

int n,m;
point p[N];

kdtree *build(int l,int r,int _pos=0)
{
    if(l>r) return null;
    pos=_pos;
    int mid=(l+r)>>1;
    nth_element(p+l,p+mid,p+r+1);
    kdtree *re=new kdtree(p[mid]);
    lson(re)=build(l,mid-1,_pos^1);
    rson(re)=build(mid+1,r,_pos^1);
    if(lson(re)!=null) re->update(0);
    if(rson(re)!=null) re->update(1);
    return re;
}

void insert(kdtree* &root,point p,int _pos=0)
{
    if(root==null)
    {
        root=new kdtree(p);
        lson(root)=rson(root)=null;
        return;
    }
    pos=_pos;
    if(p<root->p) insert(lson(root),p,_pos^1), root->update(0);
    else insert(rson(root),p,_pos^1), root->update(1);
}

int ans;
void query(kdtree *root,point p)
{
    ans=min(ans,dist(root->p,p));
    int dist[2]={lson(root)==null?inf:lson(root)->mindist(p),
        rson(root)==null?inf:rson(root)->mindist(p)};
    int d=dist[0]>dist[1];
    if(son(root,d)!=null) query(son(root,d),p);
    if(son(root,d^1)!=null&&ans>dist[d^1])
        query(son(root,d^1),p);
}

int main()
{
    cin>>n>>m;
    for(int i=1;i<=n;i++) p[i]=point(true);
    kdtree *root=build(1,n);
    for(int opt,i=1;i<=m;i++)
    {
        opt=read()-1;
        if(!opt) insert(root,point(true));
        else ans=inf,query(root,point(true)),printf("%d\n",ans);
    }
}