ちょうどK個選ぶ部分和問題
JAG春コンのB問題の解法が面白かったので僕なりにまとめておこうと思います。
この解説を参考にしました。
問題(要点を抜き出したバージョン)
N 個の数が与えられる。その中からちょうど K 個選んで作れる和を列挙せよ。
制約
与えられる N 個の数の和を M とすると、
- K ≦ N ≦ 600
- M ≦ 180000
普通のDPをすると O(NKM) になって TLE です。(bitsetでもなんとか間に合うらしいです)
こんなにシンプルだし、有名問題なのに、なんと O(NM) のDPがあるらしいです。
解法
まず数を K 個と N-K 個に分けます。これをそれぞれ Aグループ、Bグループと呼ぶことにします。
考えうる全ての選び方は、以下のような操作で実現できます。
1. 最初に Aグループを全て選ぶ(K個選択中)
2. 次に、Bグループの数を1つ追加する(K+1個選択中)
3. Aグループから1つ数を取り除き、2.に戻る(K個選択中)
コードにすると、
bool selected[N]; for (i = 0; i < K; i++) selected[i] = true; for (i = 0, j = K; j < N; j++) { if (j番目の数を選ぶ) { selected[j] = true; while(true) { if (i番目の数を取り除く) { selected[i] = false; break; } i++; } } }
みたいなイメージ。
こんな方針をDPに落とし込むとこうなります。
- dp0[i][j][s] = Aグループを i 個、Bグループを j 個見終わった時点で数を K 個選んでいて和が s になるように出来るか
- dp1[i][j][s] = Aグループを i 個、Bグループを j 個見終わった時点で数を K+1 個選んでいて和が s になるように出来るか
「上のコードでのi,jの値が i,jとなっているときに選んでいる数の和をsに出来るか」みたいなtrue/falseのDPです。(厳密には少し違って、このDPでは、数を追加したり削除したりしないならいつでも自由にi++かj++をできます)
遷移は、「if (b) a = true」を「update(a,b)」と書くことにして、
- update(dp0[i][j+1][s], dp0[i][j][s])
- update(dp0[i+1][j][s], dp0[i][j][s])
- update(dp1[i][j+1][s+x[j]], dp0[i][j][s])
- update(dp1[i+1][j][s], dp1[i][j][s])
- update(dp1[i][j+1][s], dp1[i][j][s])
- update(dp0[i+1][j][s-x[i]], dp1[i][j][s])
最終的に dp[K][N][s] がtrueなら s が作れます。
ただ、このままでは O(N^2 M) なので計算量を落としたいです。
- update(dp0[i][j+1][s], dp0[i][j][s])
- update(dp0[i+1][j][s], dp0[i][j][s])
- update(dp1[i+1][j][s], dp1[i][j][s])
- update(dp1[i][j+1][s], dp1[i][j][s])
という遷移があり、「trueの領域は i,j に対して単調」となっています。例えば、trueの領域は以下の図のような領域になっていたりするでしょう。
そこで、
- dp0[j][s] = さっきのDPテーブルで、dp0[i][j][s]==true となるような最小の i
- dp1[j][s] = さっきのDPテーブルで、dp1[i][j][s]==true となるような最小の i
というDPを考えます。つまり、下図のオレンジの部分を記録するという感じ。
遷移は、「a = min(a,b)」を「update(a,b)」と書くことにして、
の3つは分かりやすい(さっきのDPの遷移の1,3,5番目とほぼ同じ)。あとは、
- for (i = dp1[j][s]; i < K; i++) update(dp0[j][s-x[i]], i)
があればよい。しかし、これだとまだO(NKM)なのでもう一工夫したい。
試す i の範囲が [dp1[j][s], K) から [dp1[j][s], dp1[j-1][s]) になりました。こうすると、計算量は O(NM) となります。
なぜ、試す i の範囲をこういう風に狭めても良いのでしょうか?
i が dp1[j-1][s] 以上の場合は、
- dp1[j][s] → dp0[j][s-x[i]]
という遷移をしなくても、
- dp1[j-1][s] → dp0[j-1][s-x[i]] → dp0[j][s-x[i]]
という遷移によって更新できるはずだから、わざわざ dp1[j-1][s] 以上の i について遷移をしなくてもちゃんと更新されてくれるからです。
ちなみに、試す i というのは下の図のオレンジの部分に相当します。
ここを試せば、それより上の部分は試さなくても良いという感じでしょうか。
この図を見れば計算量が O(NM) になるのも納得できるのではないかと思います。試す場所の合計は、各 s に対して高々 K 箇所しかないからです。
実装
元の問題は他の要素があって大変そうだったので、ストレートにDPだけの練習問題をNPCA Judgeに置かせていただきました。(今judge止まってるっぽいのでデータも上げておきます)
僕のコードです。
#include <iostream> #include <algorithm> #include <numeric> using namespace std; const int MAX_N = 605; const int MAX_M = 180305; const int INF = 1000; int X[MAX_N]; int dp0[MAX_N][MAX_M]; int dp1[MAX_N][MAX_M]; inline void update(int& a, int b) { a = min(a,b);} int main(){ int N, K; // input cin >> N >> K; for (int i = 0; i < N; ++i) cin >> X[i]; int m = accumulate(X, X+N, 0); // initialize for (int j = K; j <= N; ++j) for(int s = 0; s <= m; ++s) { dp0[j][s] = dp1[j][s] = INF; } dp0[K][accumulate(X, X+K, 0)] = 0; // DP for (int j = K; j <= N; ++j) { for (int s = 0; s <= m; ++s) { if (dp1[j][s] != INF) { update(dp1[j+1][s], dp1[j][s]); int l = dp1[j][s], r = K; if (j > 0) update(r, dp1[j-1][s]); for (int i = l; i < r; ++i) { update(dp0[j][s-X[i]], i+1); } } } for (int s = 0; s <= m; ++s) { if (dp0[j][s] != INF) { update(dp0[j+1][s], dp0[j][s]); update(dp1[j+1][s+X[j]], dp0[j][s]); } } } // output string ans; for (int s = 0; s <= m; ++s) { ans += (dp0[N][s] == INF) ? '0' : '1'; } cout << ans << endl; return 0; }
さらに、このコードのようにメモリを再利用をして空間計算量を落とすこともできます。
#include <iostream> #include <algorithm> #include <numeric> using namespace std; const int MAX_N = 605; const int MAX_M = 180305; const int INF = 1000; int X[MAX_N]; int dp0[MAX_M]; int dp1[MAX_M]; int dpr[MAX_M]; inline void update(int& a, int b) { a = min(a,b);} int main(){ int N, K; // input cin >> N >> K; for (int i = 0; i < N; ++i) cin >> X[i]; int m = accumulate(X, X+N, 0); // initialize for(int s = 0; s <= m; ++s) dp0[s] = dp1[s] = INF, dpr[s] = K; dp0[accumulate(X, X+K, 0)] = 0; // DP for (int j = K; j < N; ++j) { for (int s = 0; s <= m; ++s) update(dpr[s], dp1[s]); for (int s = 0; s <= m; ++s) update(dp1[s+X[j]], dp0[s]); for (int s = 0; s <= m; ++s) { if (dp1[s] != INF) { for (int i = dp1[s]; i < dpr[s]; ++i) { update(dp0[s-X[i]], i+1); } } } } // output string ans; for (int s = 0; s <= m; ++s) { ans += (dp0[s] == INF) ? '0' : '1'; } cout << ans << endl; return 0; }
感想
すごい(小並感)
boolのDPはやっぱり効率が悪いんだなぁ、とか思った。