全局加一(维护异或和)
这个不太好表述啊~
01-trie数
是指字符集为 $\{0,1\}$ 的 Trie 树。
01-trie树
可以用来维护一堆数字的异或和,支持修改(删除+重新插入),和全部维护值加一。
<!-- more -->
如果要维护异或和,需要按值从低位到高位建立trie
。
一个约定:文中说当前节点往上指当前节点到根这条路径,当前节点往下指当前结点的子树。
插入&删除
如果要维护异或和,我们只需要知道某一位上0
和1
个数的奇偶性即可,也就是对于数字1
来说,当且仅当这一位上数字1
的个数为奇数时,这一位上的数字才是1
。
对于每一个节点,我们需要记录以下三个量:
ch[o][0/1]
指节点o
的两个儿子,ch[o][0]
指下一位是0
,同理ch[o][1]
指下一位是0
。w[o]
指节点o
到其父亲节点这条边上数值的数量(权值)。每插入一个数字x
,x
二进制拆分后在trie
树上 路径的权值都会+1
。xorv[o]
指以o
为根的子树维护的异或和。
具体维护结点的代码如下所示。
void maintain(int o){
w[o] = xorv[o] = 0;
if(ch[o][0]){ w[o] += w[ch[o][0]]; xorv[o] ^= xorv[ch[o][0]] << 1; }
if(ch[o][1]){ w[o] += w[ch[o][1]]; xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1); }
//w[o] = w[o] & 1;
//只需知道奇偶性即可,不需要具体的值。当然这句话删掉也可以,因为上文就只利用了他的奇偶性。
}
插入和删除的代码非常相似。
需要注意的地方就是:
-
这里的
MAXH
指trie
的深度,也就是强制让每一个叶子节点到根的距离为MAXH
。对于一些比较小的值,可能有时候不需要建立这么深(例如:如果插入数字4
,分解成二进制后为100
, 从根开始插入001
这三位即可),但是我们强制插入MAXH
位。这样做的目的是为了便于全局+1
时处理进位。例如:如果原数字是3
(11
),++之后变成4
(100
),如果当初插入3
时只插入了2
位,那这里的进位就没了。 -
插入和删除,只需要修改叶子节点的
w[]
即可,在回溯的过程中一路维护即可。
namespace trie{
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;
int mknode(){ ++tot; ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0; return tot;}
void maintain(int o){
w[o] = xorv[o] = 0;
if(ch[o][0]){ w[o] += w[ch[o][0]]; xorv[o] ^= xorv[ch[o][0]] << 1; }
if(ch[o][1]){ w[o] += w[ch[o][1]]; xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1); }
w[o] = w[o] & 1;
}
void insert(int &o, int x, int dp){
if(!o) o = mknode();
if(dp > MAXH) return (void)(w[o] ++);
insert(ch[o][ x&1 ], x >> 1, dp + 1);
maintain(o);
}
}
全局加一
void addall(int o){
swap(ch[o][0], ch[o][1]);
if(ch[o][0]) addall(ch[o][0]);
maintain(o);
}
不知道你能不能直接看懂呢?
我们思考一下二进制意义下+1
是如何操作的。
我们只需要从低位到高位开始找第一个出现的0
,把它变成1
,然后这个位置后面的1
都变成0
即可。
下面给出几个例子感受一下。(括号内的数字表示其对应的十进制数字)
1000 (10) + 1 = 1001 (11)
10011 (19) + 1 = 10100 (20)
11111 (31) + 1 = 100000 (32)
10101 (21) + 1 = 10110 (22)
100000000111111(16447) + 1 = 100000001000000(16448)
回顾一下w[o]
的定义:w[o]
指节点o
到其父亲节点这条边上数值的数量(权值)。
有没有感觉这个定义有点怪呢?如果在父亲结点存储到两个儿子的这条边的边权也许会更接近于习惯。但是在这里,在交换左右儿子的时候,在儿子结点存储到父亲这条边的距离,显然更加方便。
01-trie
合并
指的是将上述的两个01-trie
进行合并,同时合并维护的信息。
我来编写此文字的初衷就是,合并01-trie
的文字好像比较少?其实合并01-trie
和合并线段树的思路非常相似。可以搜索合并线段树
来学习如何合并01-trie
。
其实合并trie
非常简单,就是考虑一下我们有一个int marge(int a, int b)
函数,这个函数传入两个trie
树位于同一相对位置的结点编号,然后合并完成后返回合并完成的结点编号。
考虑怎么实现?
int marge(int a, int b){
if(!a) return b; // 如果a没有这个位置上的结点,返回b
if(!b) return a; // 如果b没有这个位置上的结点,返回a
// 如果a, b都健在,那就把b的信息合并到a上,然后递归操作。
// 如果需要的合并是将a, b合并到一棵新树上,这里可以新建结点,然后进行合并。这里的代码实现仅仅是将b的信息合并到a上。
w[a] = w[a] + w[b];
xorv[a] ^= xorv[b];
//不要使用maintain,maintain是根据a的两个儿子的数值进行信息合并,而这里需要a b两个节点进行信息合并
ch[a][0] = marge(ch[a][0], ch[b][0]);
ch[a][1] = marge(ch[a][1], ch[b][1]);
return a;
}
顺便说一句,其实trie
都可以合并,换句话说,trie
合并不仅仅限于01-trie
。
int n;
int V[_];
int debug = 0;
int cnt = 0;
namespace trie{
const int MAXH = 21;
int ch[_ * (MAXH + 1)][2], w[_ * (MAXH + 1)], xorv[_ * (MAXH + 1)];
int tot = 0;
int mknode(){ ++tot; ch[tot][1] = ch[tot][0] = w[tot] = xorv[tot] = 0; return tot;}
void maintain(int o){
w[o] = xorv[o] = 0;
if(ch[o][0]){ w[o] += w[ch[o][0]]; xorv[o] ^= xorv[ch[o][0]] << 1; }
if(ch[o][1]){ w[o] += w[ch[o][1]]; xorv[o] ^= (xorv[ch[o][1]] << 1) | (w[ch[o][1]] & 1); }
w[o] = w[o] & 1;
}
void insert(int &o, int x, int dp){
if(!o) o = mknode();
if(dp > MAXH) return (void)(w[o] ++);
insert(ch[o][ x&1 ], x >> 1, dp + 1);
maintain(o);
}
// #define errnum 90189
int marge(int a, int b){
// if(debug == 90189) cerr << "a = " << a << " b = " << b << endl;
cnt++;
if(!a) return b;
if(!b) return a;
w[a] = w[a] + w[b];
xorv[a] ^= xorv[b];
//不要使用maintain,maintain是根据a的两个儿子的数值进行信息合并,而这里需要a b两个节点进行信息合并
ch[a][0] = marge(ch[a][0], ch[b][0]);
ch[a][1] = marge(ch[a][1], ch[b][1]);
return a;
}
void addall(int o){
swap(ch[o][0], ch[o][1]);
if(ch[o][0]) addall(ch[o][0]);
maintain(o);
}
}
int rt[_];
long long Ans = 0;
vector<int>E[_];
void dfs0(int o){
for(int i = 0;i < E[o].size();i++){
int node = E[o][i];
dfs0(node);
rt[o] = trie::marge(rt[o], rt[node]);
}
trie::addall(rt[o]);
trie::insert(rt[o], V[o], 0);
Ans += trie::xorv[rt[o]];
}
int main()
{
#ifdef LOCAL_JUDGE
// freopen("out.txt", "w", stdout);
#endif
// freopen("in.in", "r", stdin);
clock_t c1 = clock();
n = read();
for(int i = 1;i <= n;i++) V[i] = read();
for(int i = 2;i <= n;i++) E[read()].push_back(i);
dfs0(1);
printf("%lld", Ans);
std::cerr << "\n\nTime: " << clock() - c1 << " ms" << std::endl;
return 0;
}