シンプルな問題だが、かなり難しい。
セグメント木を多用する解法（部分木への区間加算と、オイラーツアー上のLCA）。
// Tree of cities whose roads have a separate cost in each direction.
//
// Input, as read below: "n m", then n-1 lines "p q" (roads, numbered 1..n-1
// in input order), then m commands:
//   "I r s t"  set the costs of road r. For a road joining a < b, this code
//              treats s as the cost of travelling a -> b and t as b -> a.
//   any other command letter followed by "x y": print the cost of travelling
//              from city x to city y (letter not checked here; presumably the
//              problem's query letter, TODO confirm against the statement).
//
// Technique: the cost of root -> v (per direction) is kept as a point query
// over pre-order positions. Changing a road's cost is a range add over the
// child endpoint's subtree. x -> y is then answered through LCA(x, y), found
// with a max-level segment tree over the Euler tour.
#include <cstdio>
#include <cstddef>
#include <vector>
#include <algorithm>

struct road { int a, b, s, t; };

// Segment-tree geometry (1-indexed heaps, node k has children 2k and 2k+1):
//  - pre-order trees (nest, ju): leaves [n17, n18), one per city;
//  - Euler-tour tree (rmq):      leaves [n18, n19), one per tour step.
// With at most MAX_CITY cities, n <= 2^17 and the tour (2n - 1 steps) fits in 2^18 leaves.
const int n2 = 17, n17 = 1 << n2, n18 = n17 << 1, n19 = n18 << 1;
const int MAX_CITY = 100000;

// Level of the root (city 1). Each step away from the root lowers the level by
// one, so shallower cities have larger levels. It exceeds any possible depth so
// every visited city has level >= 1, strictly above the sentinel city 0 (level 0).
const int ROOT_LEVEL = MAX_CITY + 1;

std::vector<road> e;               // e[1..n-1]: roads with a < b and current costs s, t; e[0] unused
std::vector<int> g[MAX_CITY + 1];  // adjacency lists
bool used[MAX_CITY + 1];           // visited flags for dfs
int tra[MAX_CITY + 1];             // tra[v]: pre-order leaf index of v in nest/ju
int city[MAX_CITY + 1];            // city[v]: rmq leaf of v's first Euler-tour occurrence
int rmq[n19];                      // node = city with the highest level in its leaf range
int nest[n18];                     // leaf = level of the city at that pre-order slot, node = max
int ju[2][n18];                    // range-add trees: [0] child->parent, [1] parent->child costs
int ti = n17;                      // next free pre-order leaf
int ri = n18;                      // next free Euler-tour leaf

// Level of city v. City 0 is a sentinel: tra[0] and nest[0] are never
// assigned, so its level is 0, below every visited city.
static inline int level(int v) { return nest[tra[v]]; }

// One city on the explicit DFS stack.
struct DfsFrame {
    int v;             // city being explored
    int childLevel;    // level to give v's children
    std::size_t next;  // next index into g[v] to examine
};

// Marks v visited at level d and assigns its pre-order slot and its first
// Euler-tour position (the position the next tour() call will fill).
static void enter(int v, int d) {
    used[v] = true;
    tra[v] = ti;
    nest[ti] = d;
    ti++;
    city[v] = ri;
}

// Appends city v to the Euler tour in the rmq leaves.
static void tour(int v) {
    rmq[ri] = v;
    ri++;
}

// Depth-first traversal from x (at level d). It fills tra/nest in pre-order and
// writes the Euler tour: each city is appended when entered and again after
// each of its children finishes. It uses an explicit stack instead of recursion
// so a path-shaped tree of MAX_CITY cities needs no deep call chain.
void dfs(int x, int d) {
    if (!used[x]) enter(x, d);
    tour(x);
    std::vector<DfsFrame> frames;
    DfsFrame start = {x, d - 1, 0};
    frames.push_back(start);
    while (!frames.empty()) {
        const int v = frames.back().v;
        const int childLevel = frames.back().childLevel;
        std::size_t i = frames.back().next;
        while (i < g[v].size() && used[g[v][i]]) i++;
        if (i < g[v].size()) {
            const int c = g[v][i];
            frames.back().next = i + 1;  // write back before push_back invalidates references
            enter(c, childLevel);
            tour(c);
            DfsFrame child = {c, childLevel - 1, 0};
            frames.push_back(child);
        } else {
            frames.pop_back();
            if (!frames.empty()) tour(frames.back().v);  // back in the parent
        }
    }
}

// Builds rmq's internal nodes bottom-up. Each node keeps the child city with
// the higher level (the right one on ties).
void rmq_init() {
    for (int node = n18 - 1; node >= 1; node--) {
        const int lc = rmq[2 * node];
        const int rc = rmq[2 * node + 1];
        rmq[node] = level(lc) > level(rc) ? lc : rc;
    }
}

// Builds nest's internal nodes bottom-up as the max level in each range.
void nest_init() {
    for (int node = n17 - 1; node >= 1; node--)
        nest[node] = std::max(nest[2 * node], nest[2 * node + 1]);
}

// Adds st[0] / st[1] to node of ju[0] / ju[1].
static inline void bump(int node, const int st[2]) {
    ju[0][node] += st[0];
    ju[1][node] += st[1];
}

// Adds st[i] to ju[i] over the pre-order range of the subtree rooted at leaf x,
// where d is the level of the city at x. The range runs from x up to, but not
// including, the next leaf whose level is >= d. Unused trailing leaves have
// level 0, so the last subtree's range simply runs to the end. Afterwards
// sum(y, i) has grown by st[i] exactly for leaves y in that range.
//
// Phase 1 climbs from x. "pending" means the current node begins at x and has
// not been added yet. Phase 2 descends into the right sibling that contains
// the range's end, adding every left child that lies wholly before that end.
void add(const int st[2], int x, int d) {
    bool pending = true;
    while (x > 1) {
        if (x & 1) {
            // x is a right child: its parent also covers leaves before the
            // range start, so x (if still pending) is added on its own.
            if (pending) {
                bump(x, st);
                pending = false;
            }
        } else if (nest[x + 1] < d) {
            // The whole right sibling is deeper than d, hence inside the range.
            // While pending, the parent will cover both x and x + 1 later.
            if (!pending) bump(x + 1, st);
        } else {
            // The range ends inside the right sibling.
            if (pending) bump(x, st);
            x++;
            break;
        }
        x >>= 1;
    }
    if (x <= 1) {
        // Reached the root without finding the end of the range.
        if (pending) bump(1, st);
        return;
    }
    while (x < n17) {
        x <<= 1;
        if (nest[x] < d) {
            bump(x, st);
            x++;
        }
    }
}

// Returns the highest-level city among rmq leaves [a, b] inside node x, whose
// range spans d leaves (call with x = 1, d = n18). When a and b are the first
// Euler-tour positions of two cities, this is their lowest common ancestor.
// Returns the sentinel 0 when [a, b] does not meet x's range.
int lca(int a, int b, int x, int d) {
    const int firstOnLevel = n18 / d;           // index of the leftmost node on x's level
    const int l = n18 + (x - firstOnLevel) * d; // x covers leaves [l, r]
    const int r = l + d - 1;
    if (a > r || b < l) return 0;
    if (a <= l && b >= r) return rmq[x];
    const int pl = lca(a, b, 2 * x, d >> 1);
    const int pr = lca(a, b, 2 * x + 1, d >> 1);
    return level(pl) > level(pr) ? pl : pr;
}

// Point query: sum of every value added to ju[i] over ranges containing leaf x.
int sum(int x, int i) {
    int k = 0;
    for (; x >= 1; x >>= 1) k += ju[i][x];
    return k;
}

int main() {
    int n, m;
    if (scanf("%d%d", &n, &m) != 2) return 0;
    road none = {0, 0, 0, 0};
    e.push_back(none);  // roads are 1-based
    for (int i = 0; i < n - 1; i++) {
        int p, q;
        if (scanf("%d%d", &p, &q) != 2) return 0;
        g[p].push_back(q);
        g[q].push_back(p);
        if (p > q) std::swap(p, q);
        road rd = {p, q, 1, 1};  // every road starts with cost 1 both ways
        e.push_back(rd);
    }
    dfs(1, ROOT_LEVEL);
    rmq_init();
    nest_init();

    int st[2] = {1, 1};
    for (int i = 1; i < n; i++) {
        int a = e[i].a, b = e[i].b;
        if (level(a) > level(b)) std::swap(a, b);  // a := child endpoint
        add(st, tra[a], level(a));
    }

    for (int i = 0; i < m; i++) {
        char c;
        // Leading space skips the newline left after the previous line
        // (also when n == 1 and no road line was read).
        if (scanf(" %c", &c) != 1) break;
        if (c == 'I') {
            int r;
            if (scanf("%d%d%d", &r, &st[0], &st[1]) != 3) break;
            int a = e[r].a, b = e[r].b;
            // Turn the new costs into deltas and store the new costs.
            st[0] -= e[r].s;
            st[1] -= e[r].t;
            e[r].s += st[0];
            e[r].t += st[1];
            // ju[0] holds child->parent costs, ju[1] parent->child costs, both
            // over the child's subtree. If a is the parent, s (a -> b) is the
            // downward direction, so swap the deltas and make a the child.
            if (level(a) > level(b)) {
                std::swap(st[0], st[1]);
                std::swap(a, b);
            }
            add(st, tra[a], level(a));
        } else {
            int x, y;
            if (scanf("%d%d", &x, &y) != 2) break;
            int cx = city[x], cy = city[y];
            if (cx > cy) std::swap(cx, cy);
            const int par = lca(cx, cy, 1, n18);
            // Upward from x to par, then downward from par to y.
            // NOTE(review): accumulated in int; confirm the problem's cost
            // bounds keep path sums below INT_MAX.
            printf("%d\n", sum(tra[x], 0) - sum(tra[par], 0) + sum(tra[y], 1) - sum(tra[par], 1));
        }
    }
    return 0;
}