Trie

字典树 Trie

基本了解

字典树是一个对于存储字符较为高效的数据结构

它利用了单词有可能有公共前缀的原理

把每一个字符放在树上的节点存储

对于我们想把cake,cat,code,trie这四个单词放入字典树中

我们会得到如下的一棵树

Trie

标红的节点是该单词的末尾,为了当我们查询单词时分辩可能to,tomorrow,类似这样的前缀单词出现

因为我们存的都是单词

而长度即为高度,理论上规模越大,重复的节点越多,越省空间

基本实现

我们主要有两个函数

ins(string s)负责将字符串插入字典树

find(string s)负责查找字符串

void ins(string s){
	int x=root;
    for(int i=0;i<s.size();i++){
		if(!trie[x][s[i]-'a'])
            trie[x][s[i]-'a']=++cnt;
        x=trie[x][s[i]-'a'];
    }
    tot[x]++;
    return ;
}
bool find(string s){
	int x=root;
    for(int i=0;i<s.size();i++){
		if(!trie[x][s[i]-'a'])return 0;
        x=trie[x][s[i]-'a'];
    }
    if(tot[x])return 1;
    return 0;
}

代码应该很容易理解

例题

于是他错误的点名开始了

板子题,处理一下find函数即可

#include<bits/stdc++.h>
using namespace std;
const int maxn=10005,maxl=55,maxs=35;
int cnt=1;
int ans;
int tree[maxn*maxl][maxs];
int sum[maxn*maxl];
bool flag[maxn*maxl];
int read(){
	int x=0,y=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*y;
}
void ins(string s){
	int x=1,len=s.size();
	for(int i=0;i<len;i++){
		int c=s[i]-'a';
		if(!tree[x][c]){
			cnt++;
			tree[x][c]=cnt;
		}
		x=tree[x][c];
	}
	flag[x]=1;
	return ;
}
void find(string s){
	int x=1,len=s.size();
	for(int i=0;i<len;i++){
		int c=s[i]-'a';
		if(!tree[x][c]){
			ans=0;
			break;
		}
		x=tree[x][c];
	}
	sum[x]++;
	if(flag[x]==0||ans==0)printf("WRONG\n");
	else if(sum[x]>1)printf("REPEAT\n");
	else printf("OK\n");
	return ;
}
int main()
{
	int n,m;
	n=read();
	for(int i=1;i<=n;i++){
		string s;
		cin>>s;
		ins(s);
	}
	m=read();
	for(int i=1;i<=m;i++){
		string s;
		cin>>s;
		ans=-1;
		find(s);
	}
	return 0;
}

[HNOI2004]L语言

我们要找到一个文本串能被分割成若干个字典中的单词的最长前缀

对于这样的一个问题我们很容易想到暴力做法

直接在字典树中爆搜整个文本串,从头开始,只要能搜出一个字典中的单词

就可以从在这个地方分割一下,递归下去,搜完之后回来继续循环

看能不能使这个单词更长一些

部分暴力代码如下

void find(int l,int r){
	if(l>r)return ;
	int x=root;
	int pos=0;
	for(int i=l;i<=r;i++){
		if(tot[x]){
			ans=max(ans,i);
			find(i,r);//可以从这里分出一个单词继续往后搜
		}
		if(tree[x][st[i]-'a'])
			x=tree[x][st[i]-'a'];
		else return ;
	}
	if(tot[x])ans=max(ans,r+1);
	return ;
}

实际上我们已经能拿到74分的好成绩

我们只需对暴力稍加优化,加上记忆化就可以通过本题

//f[i]记录这个文本串的第i个位置到size-1的位置能得到的满足条件的最长前缀的位置
int find(int l,int r){
	int w=root;
	int cur=l;
	for(int i=l;i<=r;i++){
		if(flag[w]){
			if(f[i])
				cur=max(cur,f[i]);//搜过了,直接拿上数据
			else
				cur=max(cur,find(i,r));//同暴力
		}
		if(tree[w][st[i]-'a'])
			w=tree[w][st[i]-'a'];
		else return f[l]=cur;//在当前分割单词的方案下没法更长了
	}
	if(flag[w])cur=r+1;//搜到最后了
	return f[l]=cur;
}

完整代码就不放了

[USACO12DEC]第一!First!

对于a,b两串的字典序比较

我们有两种情况

1.找到a,b第一个不同的位置比较即可

2.发现a是b的前缀或相反

题目希望我们判断给出的所有串中

能否通过改变字典序,使某些串成为所有串中字典序最小的串

我们需要找到某两个串第一个不同的位置

很容易想到字典树树上做类似lca的操作

把所有单词插入字典树中

对于我们当前判断的单词S走的路径是X-Y

那么由于字典树中的节点都是我们需要判断的单词

若我们想要S成为字典序最小的,根据之前的结论

我们要使图中X到A,B,C这三条边所代表的字母的字典序改成大于X-Y的字典序

对于每一条被判断的单词边我们都做如上处理

一定能保证满足题目所需条件

而对于判断是否可行我们可以建一个图来实现

对于每一个字母char1字典序需大于char2则char1向char2连一条有向边

最后判下环就可以知道了

可以着重看一下find函数

完整代码

#include<bits/stdc++.h>
using namespace std;
const int maxn=30005;
struct Edge{
	int to;
	int nxt;
}e[26*26];
int len;
int head[26];
int root,cnt;
string st[maxn];
int trie[300005][26];
bool flag[300005];
int rd[26];
int ANS[maxn];
queue<int>q;
int read(){
	int x=0,y=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*y;
}
void add(int u,int v){
	if(u==v)return ;
	e[len].to=v;
	e[len].nxt=head[u];
	head[u]=len++;
	rd[v]++;
	return ;
}
void ins(string s){
	int w=root;
	for(int i=0;i<s.size();i++){
		if(!trie[w][s[i]-'a'])
			trie[w][s[i]-'a']=++cnt;
		w=trie[w][s[i]-'a'];
	}
	flag[w]=1;
	return ;
}
int find(string s){
	int w=root;
	len=0;
	memset(head,-1,sizeof(head));
	memset(rd,0,sizeof(rd));
	for(int i=0;i<s.size();i++){
		if(flag[w])return 0;
		for(int j=0;j<26;j++){
			if(j==s[i]-'a'||(!trie[w][j]))continue;
			add(j,s[i]-'a');//建边
		}
		w=trie[w][s[i]-'a'];
	}
	int tot=0;
	for(int i=0;i<26;i++)
		if(!rd[i]){
			q.push(i);
			tot++;
		}
	while(!q.empty()){
		int k=q.front();
		q.pop();
		for(int i=head[k];i!=-1;i=e[i].nxt){
			int tmp=e[i].to;
			rd[tmp]--;
			if(!rd[tmp]){
				q.push(tmp);
				tot++;
			}
		}
	}
    //拓扑判环
	if(tot==26)return 1;
	return 0;
}
int main(){
	int n;
	n=read();
	for(int i=1;i<=n;i++){
		cin>>st[i];
		ins(st[i]);
	}
	int ans=0;
	for(int i=1;i<=n;i++)
		if(find(st[i]))
			ANS[++ans]=i;
	printf("%d\n",ans);
	for(int i=1;i<=ans;i++)
		cout<<st[ANS[i]]<<endl;
	return 0;
}

[JSOI2009]电子字典

查找对于一个字符串s

一个字典里的单词与s编辑距离为1的数目

编辑距离为指将串a通过如下三种操作变成串b的最小次数

1.删除串中某个位置的字母;

2.添加一个字母到串中某个位置;

3.替换串中某一位置的一个字母为另一个字母;

判断s是否为字典里的单词很好搞

判断一个单词是否与询问单词编辑距离为1我们可以这么考虑

因为编辑距离只允许为1

所以两单词从头到尾的大部分一定相同

我们可以把字典放到trie里

直接去查这个询问串

如果这一步还可以走,那么就分别执行三个操作一试

不可一走了直接返回即可

还可以直接枚举要改哪个地方,用一个数组标记

查的时候按数组的标记做一些小改动即可

#include<bits/stdc++.h>
using namespace std;
const int maxn=10005,maxm=10005;
int root,cnt;
int trie[maxn*25][26];
bool flag[maxn*25],bj[maxn*25];
int ban[25],add[25],chan[25];
int read(){
	int x=0,y=1;
	char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')y=-1;ch=getchar();}
	while(ch>='0'&&ch<='9')x=(x<<3)+(x<<1)+(ch^48),ch=getchar();
	return x*y;
}
void ins(string s){
	int w=root;
	for(int i=0;i<s.size();i++){
		if(!trie[w][s[i]-'a'])
			trie[w][s[i]-'a']=++cnt;
		w=trie[w][s[i]-'a'];
	}
	flag[w]=1;
	return ;
}
int find(string s){
	int w=root;
	for(int i=0;i<s.size();i++){
		if(ban[i])continue;//这个节点被删除了
		if(add[i]>-1){
			if(!trie[w][add[i]])return 0;
			w=trie[w][add[i]];//这个地方添加了一个add[i]+'a'这个字母
		}
		if(chan[i+1]>-1){
			if(!trie[w][chan[i+1]])return 0;
			w=trie[w][chan[i+1]];//把这个字母修改为chan[i+1]+'a'
			continue;
		}
		if(!trie[w][s[i]-'a'])return 0;
		w=trie[w][s[i]-'a'];
	}
	if(add[s.size()]>-1){
		if(!trie[w][add[s.size()]])return 0;
		w=trie[w][add[s.size()]];
	}
    //最后一个位置后面也可以添加
	if(flag[w]&&(!bj[w])){bj[w]=1;return 1;}
    //不重复算
	return 0;
}
int main(){
	int n,m;
	n=read();m=read();
	for(int i=1;i<=n;i++){
		string st;
		cin>>st;
		ins(st);
	}
	memset(add,-1,sizeof(add));
	memset(chan,-1,sizeof(chan));
	for(int i=1;i<=m;i++){
		string st;
		cin>>st;
		memset(bj,0,sizeof(bj));
		if(find(st)){printf("-1\n");continue;}
		int ans=0;
		for(int j=0;j<st.size();j++){
			ban[j]=1;
			ans+=find(st);
			ban[j]=0;
		}
		for(int j=0;j<=st.size();j++){
			for(int k=0;k<26;k++){
				add[j]=k;
				ans+=find(st);
				add[j]=-1;
				if(j>0){
					chan[j]=k;
					ans+=find(st);
					chan[j]=-1;
				}
			}
		}
		printf("%d\n",ans);
	}
	return 0;
}