cdq分治
这是一个用来解决多维限制问题的离线算法
由原yali巨佬cdq提出
【模板】三维偏序(陌上花开)
模板题
我们在学习树状数组时一定做过二维偏序的题目
逆序对就是二维偏序
那么三维偏序怎么用cdq分治解决呢
我们类比二维偏序,先按第一维排序
接下来进行分治
对于一个区间[l,r]
我们递归计算好了[l,mid],[mid+1,r]这两个区间的答案
并将这两个区间分别按y排好序了
由于我们只需考虑左区间对右区间的贡献
我们最终的目标是将整个区间也按y排序
我们用归并处理
当这一个位置是原左区间的数时,我们把它的z放进树状数组里
由于后面进来的右区间的数y比该数大,x当然也更大
故我们查一边树状数组就能得到一个右区间数的答案
处理完成后记得清空树状数组
#include<bits/stdc++.h>
using namespace std;
const int maxn=100005;
struct node{
int x,y,z,ans,v;
bool operator <(node i)const{
return x==i.x?(y==i.y?z<i.z:y<i.y):x<i.x;
}
}a[maxn],b[maxn];
int n,k;
int tree[200005];
int ans[maxn];
int read(){
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*y;
}
int lowbit(int x){
return x&-x;
}
void update(int x,int val){
while(x<=k){
tree[x]+=val;
x+=lowbit(x);
}
return ;
}
int query(int x){
int ans=0;
while(x){
ans+=tree[x];
x-=lowbit(x);
}
return ans;
}
void merge_sort(int l,int r,int mid){
int p=l,q=mid+1,len=l;
while(p<=mid&&q<=r){
if(a[p].y<=a[q].y){
update(a[p].z,a[p].v);
b[len++]=a[p++];
}
else{
a[q].ans+=query(a[q].z);
b[len++]=a[q++];
}
}
while(p<=mid){
update(a[p].z,a[p].v);
b[len++]=a[p++];
}
while(q<=r){
a[q].ans+=query(a[q].z);
b[len++]=a[q++];
}
for(int i=l;i<=mid;i++)update(a[i].z,-a[i].v);
for(int i=l;i<=r;i++)a[i]=b[i];
return ;
}
void cdq(int l,int r){
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid);
cdq(mid+1,r);
merge_sort(l,r,mid);
return ;
}
int main(){
n=read();k=read();
for(int i=1;i<=n;i++){
a[i].x=read();a[i].y=read();a[i].z=read();
a[i].ans=0;
a[i].v=1;
}
sort(a+1,a+n+1);
int p=1;
for(int i=2;i<=n;i++){
if(a[i].x==a[p].x&&a[i].y==a[p].y&&a[i].z==a[p].z)
a[p].v++;
else
a[++p]=a[i];
}
cdq(1,p);
for(int i=1;i<=p;i++)ans[a[i].ans+a[i].v-1]+=a[i].v;
for(int i=0;i<n;i++)
printf("%d\n",ans[i]);
return 0;
}
[CQOI2011]动态逆序对
我们可以用nlog的时间跑出来第一行的答案
接下来我们删掉一个点i,逆序对数应减去所有点对(i,j)的数量
(i,j)满足
<,>,>或,<,>
为该点被删去的时间(不会被删默认m+1)
很明显,这是个三维偏序问题
我们用刚刚学到的cdq就能解决
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=100005;
struct node{
int x,y,z,ans;
}a[maxn],b[maxn];
int n,m;
int tree[maxn];
int c[maxn],pos[maxn];
int ans[maxn];
int read(){
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*y;
}
bool cmp1(node i,node j){
return i.x==j.x?(i.y==j.y?i.z>j.z:i.y>j.y):i.x<j.x;
}
bool cmp2(node i,node j){
return i.x==j.x?(i.y==j.y?i.z>j.z:i.y<j.y):i.x>j.x;
}
int lowbit(int x){
return x&-x;
}
void update(int x,int val){
while(x<=n+1){
tree[x]+=val;
x+=lowbit(x);
}
return ;
}
int query(int x){
int ans=0;
while(x){
ans+=tree[x];
x-=lowbit(x);
}
return ans;
}
void merge_sort1(int l,int r,int mid){
int p=l,q=mid+1,len=l;
while(p<=mid&&q<=r){
if(a[p].y>a[q].y){
update(a[p].z,1);
b[len++]=a[p++];
}
else{
a[q].ans+=query(m+1)-query(a[q].z);
b[len++]=a[q++];
}
}
while(p<=mid){
update(a[p].z,1);
b[len++]=a[p++];
}
while(q<=r){
a[q].ans+=query(m+1)-query(a[q].z);
b[len++]=a[q++];
}
for(int i=l;i<=mid;i++)update(a[i].z,-1);
for(int i=l;i<=r;i++)a[i]=b[i];
return ;
}
void merge_sort2(int l,int r,int mid){
int p=l,q=mid+1,len=l;
while(p<=mid&&q<=r){
if(a[p].y<a[q].y){
update(a[p].z,1);
b[len++]=a[p++];
}
else{
a[q].ans+=query(m+1)-query(a[q].z);
b[len++]=a[q++];
}
}
while(p<=mid){
update(a[p].z,1);
b[len++]=a[p++];
}
while(q<=r){
a[q].ans+=query(m+1)-query(a[q].z);
b[len++]=a[q++];
}
for(int i=l;i<=mid;i++)update(a[i].z,-1);
for(int i=l;i<=r;i++)a[i]=b[i];
return ;
}
void cdq(int l,int r,bool flag){
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid,flag);
cdq(mid+1,r,flag);
if(!flag)merge_sort1(l,r,mid);
else merge_sort2(l,r,mid);
return ;
}
signed main(){
n=read();m=read();
for(int i=1;i<=n;i++){
a[i].x=i;a[i].y=read();a[i].z=m+1;
a[i].ans=0;
c[i]=a[i].y;
pos[c[i]]=i;
}
for(int i=1;i<=m;i++)a[pos[read()]].z=i;
sort(a+1,a+n+1,cmp1);
cdq(1,n,0);
sort(a+1,a+n+1,cmp2);
cdq(1,n,1);
for(int i=1;i<=n;i++)ans[a[i].z]=a[i].ans;
int ANS=0;
for(int i=1;i<=n;i++){
ANS+=i-1-query(c[i]);
update(c[i],1);
}
for(int i=1;i<=m;i++){
printf("%lld\n",ANS);
ANS-=ans[i];
}
return 0;
}
练习
[Violet]天使玩偶/SJY摆棋子
简单题
tips:
练习题WA了n次的教训,cdq分治的排序所有维度都要排到,不能只排一维
需要这种
bool cmp1(node i,node j){ return i.x==j.x?(i.y==j.y?i.z>j.z:i.y>j.y):i.x<j.x; }
下面这种会在递归到下层时出问题
bool cmp1(node i,node j){ return i.x<j.x; }
[SDOI2011]拦截导弹
导弹拦截拓展版
与大家都做过的入门级习题不同的是,这题不仅有高度还有速度
也就是是一个三维问题(分别是高度,速度,时间即输入顺序,下面我会分别用x,y,z表示这三维)
我们考虑借鉴只有二维时的思路
这一道题的第一问其实就是让我们求多了一维限制的最长不上升子序列
而第二问是求每枚导弹被拦截掉的概率,即每枚导弹在所有最长不上升子序列方案里的出现次数/总最长不上升子序列方案数
我们可以比较套路的定义一下dp
定义:
dp1[i]表示以第i枚导弹为结尾的方案中的最长不上升子序列长度
dp2[i]表示以第i枚导弹为开头的方案中的最长不上升子序列长度
f1[i]表示以第i枚导弹为结尾的最长不上升子序列方案总数
f2[i]表示以第i枚导弹为开头的最长不上升子序列方案总数
那么第一问的答案即max{dp1[i]}或者max{dp2[i]}
第二问第i枚导弹的答案即f1[i]*f2[i](乘法原理)且dp1[i]+dp2[i]=ans-1,ans为第一问答案,减1是i导弹被计算了两次
故我们只要把dp解决这个问题就结束了
显然O(n^2)的dp类比导弹拦截能很容易的得出
//dp1的代码,dp2对称一下容易得到
//这里默认z按从小到大排好了
for(int i=1;i<=n;i++){
for(int j=1;j<i;j++){
if(a[i].x<=a[j].x&&a[i].y<=a[j].y){
if(dp1[i]==dp1[j]+1)
f1[i]+=f1[j];
else if(dp1[i]<dp1[j]+1){
dp1[i]=dp1[j]+1;
f1[i]=f1[j];
}
}
}
}
这是一个三维偏序的相关问题
我们考虑用cdq分治来优化dp
这里由于要用到树状数组,我懒得离散化x和y所以我用z搞树状数组,用第一维排序
我们把按x排好序后发现只能左边的导弹的dp值贡献给右边
这里我们应该先把左边的dp值都算完了再去更新右边的
故这里的递归顺序应是左中右
这个地方特别注意
bool cmp1(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y>j.y):i.x>j.x;
}
bool cmp2(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y<j.y):i.x<j.x;
}
bool cmp3(node i,node j){
return i.y>j.y;
}
bool cmp4(node i,node j){
return i.y<j.y;
}
//顺序什么的可以自己搞定
void cdq(int l,int r,int opt){//opt==1是计算dp1,2是计算dp2
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid,opt);//先把左边算好
if(opt==1){
sort(a+l,a+mid+1,cmp3);
sort(a+mid+1,a+r+1,cmp3);
}
else{
sort(a+l,a+mid+1,cmp4);
sort(a+mid+1,a+r+1,cmp4);
}
merge_sort(l,r,opt);//把左边对右边的贡献加到右边上
//这样写而不是先去递归右边主要是怕右边的左半部分还每加上左边的贡献就去算对右边右半部分的贡献了
if(opt==1)sort(a+mid+1,a+r+1,cmp1);
else sort(a+mid+1,a+r+1,cmp2);
//记得排回来
cdq(mid+1,r,opt);
return ;
}
然后注意树状数组维护时算f会比较麻烦
因为f需要所有方案的和
放代码
#include<bits/stdc++.h>
using namespace std;
const int maxn=50005;
struct node{
int x,y,z;
int dp;
double f;
}a[maxn];
int n;
int tree1[maxn];
double tree2[maxn];
int dp1[maxn],dp2[maxn];
double f1[maxn],f2[maxn];
int read(){
int x=0,y=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
return x*y;
}
bool cmp1(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y>j.y):i.x>j.x;
}
bool cmp2(node i,node j){
return i.x==j.x?(i.y==j.y?i.z<j.z:i.y<j.y):i.x<j.x;
}
bool cmp3(node i,node j){
return i.y>j.y;
}
bool cmp4(node i,node j){
return i.y<j.y;
}
int lowbit(int x){
return x&-x;
}
void update(int x,int dp,double f){
while(x<=n){
if(tree1[x]<dp){
tree1[x]=dp;
tree2[x]=f;
}
else if(tree1[x]==dp)
tree2[x]+=f;
x+=lowbit(x);
}
return ;
}
int query1(int x){
int ans=0;
while(x){
ans=max(ans,tree1[x]);
x-=lowbit(x);
}
return ans;
}
double query2(int x,int dp){
double ans=0;
while(x){
if(tree1[x]==dp)ans+=tree2[x];
x-=lowbit(x);
}
return ans;
}
void del(int x){
while(x<=n){
tree1[x]=tree2[x]=0;
x+=lowbit(x);
}
return ;
}
bool check(int p,int q,int opt){
if(opt==1)return a[p].y>=a[q].y;
return a[p].y<=a[q].y;
}
void merge_sort(int l,int r,int opt){
int mid=l+((r-l)>>1);
int p=l,q=mid+1;
while(p<=mid&&q<=r){
if(check(p,q,opt)){
update(a[p].z,a[p].dp,a[p].f);
p++;
}
else{
int val=query1(a[q].z)+1;
if(a[q].dp<val){
a[q].dp=val;
a[q].f=query2(a[q].z,val-1);
}
else if(a[q].dp==val)
a[q].f+=query2(a[q].z,val-1);
q++;
}
}
while(q<=r){
int val=query1(a[q].z)+1;
if(a[q].dp<val){
a[q].dp=val;
a[q].f=query2(a[q].z,val-1);
}
else if(a[q].dp==val)
a[q].f+=query2(a[q].z,val-1);
q++;
}
for(int i=l;i<p;i++)del(a[i].z);
return ;
}
void cdq(int l,int r,int opt){
if(l==r)return ;
int mid=l+((r-l)>>1);
cdq(l,mid,opt);
if(opt==1){
sort(a+l,a+mid+1,cmp3);
sort(a+mid+1,a+r+1,cmp3);
}
else{
sort(a+l,a+mid+1,cmp4);
sort(a+mid+1,a+r+1,cmp4);
}
merge_sort(l,r,opt);
if(opt==1)sort(a+mid+1,a+r+1,cmp1);
else sort(a+mid+1,a+r+1,cmp2);
cdq(mid+1,r,opt);
return ;
}
int main(){
n=read();
for(int i=1;i<=n;i++){
a[i].x=read();a[i].y=read();a[i].z=i;
a[i].dp=1;a[i].f=1;
}
sort(a+1,a+n+1,cmp1);
cdq(1,n,1);
int ans=0;
for(int i=1;i<=n;i++){
dp1[a[i].z]=a[i].dp;
f1[a[i].z]=a[i].f;
ans=max(ans,dp1[a[i].z]);
}
printf("%d\n",ans);
for(int i=1;i<=n;i++){
a[i].z=n-a[i].z+1;
a[i].dp=1;a[i].f=1;
}
sort(a+1,a+n+1,cmp2);
cdq(1,n,2);
double k=0;
for(int i=1;i<=n;i++){
dp2[n-a[i].z+1]=a[i].dp;
f2[n-a[i].z+1]=a[i].f;
if(dp2[n-a[i].z+1]==ans)k+=f2[n-a[i].z+1];
}
for(int i=1;i<=n;i++){
if(dp1[i]+dp2[i]==ans+1)
printf("%.5lf ",f1[i]*f2[i]/k);
else
printf("0.00000 ");
}
return 0;
}