BZOJ3053&HDU4347 The Closest M Points

2014.12.24 22:57 Wed| 3 visits oi_2015| 2015_刷题日常| Text

Solution

这才是真正的K-D Tree!我承认本蒟蒻的智商已经消耗殆尽了……

什么叫做把一百好几十行的代码愣是对着标程压到了90多行的感受!

开始WA好久,发现多组数据……

然后开始不停又WA又T……

另外发现询问的时候选择左右子树的过程中好像不需要估价函数,直接判断当前点在剩余范围的中点的哪一方向即可。估价函数还是那么的不好写,还怕写挂……。

开始怀疑人生……

Code

#include <queue>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define N 50005
#define sqr(x) ((x)*(x))
#define lson(x) (x->ch[0])
#define rson(x) (x->ch[1])
#define son(x,d) (x->ch[d])

int n,k,pos,m;

inline int read()
{
    int x=0,f=1;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
    return x*f;
}

struct point
{
    int v[5];
    point(){}
    point(bool flag)
    {
        for(int i=0;i<k;i++) v[i]=read();
    }
    inline void print()
    {
        for(int i=0;i<k;i++)
            printf("%d%c",v[i]," \n"[i==k-1]);
    }
    bool operator <(const point &p) const
    {
        return v[pos]<p.v[pos];
    }
    inline friend int dist(const point &p1,const point &p2)
    {
        int re=0;
        for(int i=0;i<k;i++) re+=sqr(p1.v[i]-p2.v[i]);
        return re;
    }
}p[N],ans[15];

struct kdtree
{
    point p;
    kdtree *ch[2];
    kdtree(){}
    kdtree(const point &_p):p(_p){}
}*root,*null=new kdtree();

kdtree *build(int l,int r,int _pos=0)
{
    if(l>r) return null;
    pos=_pos;
    int mid=(l+r)>>1;
    nth_element(p+l,p+mid,p+r+1);
    kdtree *re=new kdtree(p[mid]);
    lson(re)=build(l,mid-1,(_pos+1)%k);
    rson(re)=build(mid+1,r,(_pos+1)%k);
    return re;
}

priority_queue< pair<int,point> > q;
void querymin(kdtree *root,point a,int t,int _pos=0)
{
    pos=_pos;
    int d=root->p<a;
    if(son(root,d)!=null) querymin(son(root,d),a,t,(_pos+1)%k);
    q.push(make_pair(dist(root->p,a),root->p));
    while(int(q.size())>t) q.pop();
    if(son(root,d^1)!=null&&(int(q.size())<t||sqr(a.v[_pos]-root->p.v[_pos])<q.top().first))
        querymin(son(root,d^1),a,t,(_pos+1)%k);
}

int main()
{
    while(scanf("%d%d",&n,&k)!=EOF)
    {
        for(int i=1;i<=n;i++)
            p[i]=point(true);
        root=build(1,n); m=read();
        for(int i=1;i<=m;i++)
        {
            int x,cnt=0;
            point t(true);
            querymin(root,t,x=read());
            printf("the closest %d points are:\n", x);
            while(!q.empty()) ans[cnt++]=q.top().second, q.pop();
            while(cnt) ans[--cnt].print();
        }
    }
}