写在前面
树形DP一般就是$f[i][j]$表示以第$i$个结点为根的子树中,$j$个儿子或者价值为$j$,在这道题里面,$j$表示在这个子树中选取$j$个叶子节点。
在AC这道题之前,看过无数题解,这句话是笔者认为写的最精彩的。
状态
考虑一钟状态设计方式$f[i][j]$表示在以$i$为根的子树中,选取满足$j$个“叶子”,获得的最大价值。
方程
$$f[i][j] = \max_{k = 0}^{k = Size_{exNode}}(f[i][j - k] + f[exNode][k]) - W_x$$
- $f[i][j]$表示在以$i$为根的子树中,选择子树中的$j$个叶子结点获得的最大价值。
- $Size_k$表示子树$k$的大小。
- $exNode$表示$i$的儿子结点。
- $W_x$表示边权。
就是说,选择一部分的$j$(选择满足叶子结点数)分配给$exNode$ $k$个,其余分配给其他结点。这样获得的最大利润。
思考,如果让儿子结点$exNode$为根的子树中,“满足”$k$个叶子结点(用户端)的需求,那么会获得$f[exNode][k]$的利润,再减去信号从$i$传到$exNode$的花费$W_x$,再加上其他结点获得的利润$f[i][j - k]$,就能得到让$exNode$,中的$k$个叶子结点满足的利润。
为什么是分组背包
分组背包指在很多组中只能选一种物品,获得的最大利润
回到本题,组就是每一棵子树,物品其实并不是叶子结点,而是叶子节点的满足个数。
比如,在以结点$u$为根的子树中(第$u$组),有一些物品,这些物品是:
- 满足以结点$u$为根的子树中,$1$个叶子结点的需求。
- 满足以结点$u$为根的子树中,$2$个叶子结点的需求。
- 满足以结点$u$为根的子树中,$3$个叶子结点的需求。
- 满足以结点$u$为根的子树中,$4$个叶子结点的需求。
- 满足以结点$u$为根的子树中,$Size_v$个叶子结点的需求。
每一组的各个元素都是互相矛盾只能选一种的。
这就是分组背包。
对于每一个结点都要进行分组背包。
对于扩展结点时,假如当前结点时$now$, 扩展节点是$exNode$。
应该使用$Size_{exNode}$的每一个可能满足的取值更新$f[now][j]$(取最大值(利润一定要最大啊))
int &sum = lef[now];//表示子树大小
for(_R int i = head[now];i;i = edge[i].nxt) {
int exNode = edge[i].node;
int t = dfs(exNode); sum += t;
for(_R int j = lef[now];j >= 0;j--) {
for(_R int k = 0;k <= lef[exNode];k++) {
if(j - k < 0) continue;
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[exNode][k] - edge[i].w);
}
关于for(int j = lef[now];j >= 0;j--)
倒序枚举的原因同0-1背包
滚动数组优化后的方法。
防止多次转移,也就是防止在同一组里选重复的元素。
即, 保证每一次$dp[now][j - k]$的取值都是上一个儿子结点更新完答案剩下的,而不是当前美剧道德儿子结点更新答案之后的指。
最后的最后
题目没有问能获得的最大利润,而是希望更多的用户能够用上电视,于是只要电视台不亏本,就可以满足。
从打到小枚举$f[1][i]$找一个最大的$i$使$f[1][i] >= 0$
Codes
#include<cstdio>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#define _R register
#define inf 0x7fffffff
using namespace std;
const int _ = 3100;
inline int read()
{
char c = getchar(); int sign = 1; int x = 0;
while(c > '9' || c < '0') { if(c=='-')sign = -1; c = getchar(); }
while(c <= '9' && c >= '0') { x *= 10; x += c - '0'; c = getchar(); }
return x * sign;
}
int n, m;
int nodeVal[_];
int dp[_][_];
int head[_];
struct edge{
int node;
int w;
int nxt;
}edge[_];
int tot = 0;
void add(int u, int v, int w){
edge[++tot].nxt = head[u];
head[u] = tot;
edge[tot].node = v;
edge[tot].w = w;
}
int lef[_];
int dfs(int now)
{
if(now >= n - m + 1) return dp[now][1] = nodeVal[now], lef[now] = 1;
int &sum = lef[now];
for(_R int i = head[now];i;i = edge[i].nxt) {
int exNode = edge[i].node;
int t = dfs(exNode); sum += t;
for(_R int j = lef[now];j >= 0;j--) {
for(_R int k = 0;k <= lef[exNode];k++) {
if(j - k < 0) continue;
dp[now][j] = max(dp[now][j], dp[now][j - k] + dp[exNode][k] - edge[i].w);
}
}
}
return lef[now];
}
int main()
{
n = read(), m = read();
for(_R int i = 1;i <= n - m;i++){
int k = read();
for(_R int j = 1;j <= k;j++){
int A = read();
int C = read();
add(i, A, C);
}
}
for(_R int i = n - m + 1;i <= n;i++){
nodeVal[i] = read();
}
memset(dp, -100, sizeof(dp));
for(_R int i = 1;i <= n;i++) dp[i][0] = 0;
dfs(1);
for(_R int i = m;i >= 0;i--)
if(dp[1][i] >= 0){
return printf("%d\n", i), 0;
}
return 0;
}