This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub maspypy/library
#include "mod/modular_subset_sum.hpp"
#include "random/base.hpp"
#include "mod/modint61.hpp"

// Faster Deterministic Modular Subset Sum. arXiv preprint arXiv:2012.06062.
// Segment tree with cyclic shift, used for modular subset sum.
// Each internal node on level k stores base_pow[k] * (left child) + (right child)
// over modint61, i.e. a polynomial hash of the leaves below it, so two trees built
// with the same base can be compared subtree by subtree.
// shift(k) rebuilds levels 0 .. n - lowbit(k) - 1, i.e. 2^(n - lowbit(k)) - 1 node updates.
struct ShiftTree {
  using M61 = modint61;
  int delta;  // accumulated shift; set(p, x) writes leaf N + (p - delta) mod N
  int N, n;   // N == 2^n leaves
  M61 base;
  vc<M61> dat;       // dat[1] is the root; dat[N .. 2N) are the leaves
  vc<M61> base_pow;  // base_pow[k] == base^(2^(n-1-k)): weight of the left child on level k

  ShiftTree(int N, ll base) : delta(0), N(N), n(topbit(N)), base(base) {
    assert(N == (1 << n));
    dat.assign(2 * N, 0);
    base_pow.assign(n, 1);
    // With N == 1 (n == 0) there is no internal level, and base_pow[n - 1] would be out of range.
    if (n > 0) {
      base_pow[n - 1] = base;
      FOR_R(i, n - 1) base_pow[i] = base_pow[i + 1] * base_pow[i + 1];
    }
  }

  // Rotation bit (0 or 1) applied to node indices on level k; taken from delta.
  inline int skew(int k) { return (delta >> (n - k)) & 1; }

  // Index of the left child (on level k + 1) of node i on level k.
  inline int left(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 0 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Index of the right child (on level k + 1) of node i on level k.
  inline int right(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 1 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Index of the parent (on level k - 1) of node i on level k.
  inline int parent(int k, int i) {
    int mask = (1 << k) - 1;
    return (((i + skew(k)) & mask) + (1 << k)) / 2;
  }

  // Recompute node i on level k from its two children.
  inline void update(int k, int i) {
    M61 b = base_pow[k];
    dat[i] = b * dat[left(k, i)] + dat[right(k, i)];
  }

  // Write x at logical position i (expected in [0, N)) and refresh all ancestors.
  inline void set(int i, ll x) {
    i = (i + N - delta) % N + N;
    dat[i] = x;
    int k = n;
    while (i != 1) {
      i = parent(k, i);
      --k;
      update(k, i);
    }
  }

  // Cyclically move the logical contents k positions forward (k may be negative).
  // Only levels shallower than n - lowbit(k) change shape, so only they are rebuilt.
  void shift(int k) {
    k %= N;
    if (k < 0) k += N;
    if (k == 0) return;
    int j = lowbit(k);
    delta = (delta + k) % N;
    // Rebuild levels n - j - 1 down to 0 (children before parents).
    FOR_R(lv, n - j) {
      FOR3(i, 1 << lv, 2 << lv) update(lv, i);
    }
  }

  // Enumerate the positions in [a, b) where T and Q differ; output-sensitive.
  // Node i of T and node j of Q are both on level k and both cover logical range [x, y).
  // Positions are appended to res in increasing order.
  static void find_differences(vc<int>& res, ShiftTree& T, ShiftTree& Q, int a, int b, int k, int i, int j, int x, int y) {
    if (T.dat[i] == Q.dat[j]) return;    // equal hashes: treat subtrees as identical
    if (max(a, x) >= min(b, y)) return;  // no overlap with [a, b)
    if (y == x + 1) {
      res.eb(x);
      return;
    }
    int z = (x + y) / 2;
    find_differences(res, T, Q, a, b, k + 1, T.left(k, i), Q.left(k, j), x, z);
    find_differences(res, T, Q, a, b, k + 1, T.right(k, i), Q.right(k, j), z, y);
  }

  // Logical positions in [a, b) where T and Q differ (up to hash collisions), increasing.
  static vc<int> diff(ShiftTree& T, ShiftTree& Q, int a, int b) {
    assert(T.N == Q.N);
    assert(T.base == Q.base);  // hashes are only comparable under the same base
    vc<int> res;
    find_differences(res, T, Q, a, b, 0, 1, 1, 0, T.N);
    return res;
  }
};

/*
Complexity: (|vals| + mod) * log(mod)
- can(x) or [x] returns a bool.
- restore(x) reconstructs a subset.
Pass (mod, vals) to the constructor.
*/
template <typename INT>
struct Modular_Subset_Sum {
  int mod;
  vc<INT>& vals;  // held by reference, not copied: must outlive this object (restore reads it)
  vc<int> par;    // par[d]: index into vals of the item that first made d reachable, or -1

  Modular_Subset_Sum(int mod, vc<INT>& vals) : mod(mod), vals(vals) {
    for (auto&& x: vals) assert(0 <= x && x < mod);
    par.assign(mod, -1);
    // Bases 0 and 1 make the polynomial hash degenerate, so draw from [2, 2^61 - 1).
    const ll base = RNG(2, (1LL << 61) - 1);
    int k = 1;
    while ((1 << k) < 2 * mod) ++k;
    int L = 1 << k;
    assert(L >= 2 * mod);
    // T1 marks the reachable set S. T2 marks s and s - mod (cyclically in [0, L)) for
    // each s in S, so once T2 is shifted by x its window [0, mod) holds (S + x) mod mod.
    ShiftTree T1(L, base);
    ShiftTree T2(L, base);
    T1.set(0, 1);
    T2.set(0, 1);
    T2.set(L - mod, 1);
    // Reverse the low k bits of i.
    auto bit_rev = [&](int i) -> int {
      int x = 0;
      FOR(k) {
        x = 2 * x + (i & 1);
        i >>= 1;
      }
      return x;
    };
    // IDS[v]: indices i with vals[i] == v.
    vc<vi> IDS(L);
    FOR(i, len(vals)) { IDS[vals[i]].eb(i); }
    // Values are visited in bit-reversed order, intended to keep successive shift() calls cheap.
    // Value 0 is never visited (bit_rev(i) != 0 for i >= 1); it cannot create new sums anyway.
    FOR(i, 1, L) {
      int x = bit_rev(i);
      if (len(IDS[x]) == 0) continue;
      T2.shift(x - T2.delta);
      for (auto&& idx: IDS[x]) {
        auto diff = ShiftTree::diff(T1, T2, 0, mod);
        for (auto&& d: diff) {
          if (can(d)) continue;  // d is in S but not in S + x
          par[d] = idx;
          T1.set(d, 1);
          T2.set((d + x) % L, 1);
          T2.set((L + d + x - mod) % L, 1);
        }
      }
    }
  }

  // Whether x is a subset sum modulo mod; false for x outside [0, mod).
  bool can(int x) {
    if (x < 0 || x >= mod) return false;
    return (x == 0 || par[x] != -1);
  }

  bool operator[](int x) { return can(x); }

  // Indices into vals of a subset whose sum is congruent to x modulo mod (empty for x == 0).
  vc<int> restore(int x) {
    assert(can(x));
    vc<int> res;
    while (x) {
      int i = par[x];
      res.eb(i);
      x -= vals[i];
      if (x < 0) x += mod;
    }
    reverse(all(res));
    return res;
  }
};
#line 2 "random/base.hpp" u64 RNG_64() { static u64 x_ = u64(chrono::duration_cast<chrono::nanoseconds>(chrono::high_resolution_clock::now().time_since_epoch()).count()) * 10150724397891781847ULL; x_ ^= x_ << 7; return x_ ^= x_ >> 9; } u64 RNG(u64 lim) { return RNG_64() % lim; } ll RNG(ll l, ll r) { return l + RNG_64() % (r - l); } #line 2 "mod/modint61.hpp" struct modint61 { static constexpr u64 mod = (1ULL << 61) - 1; u64 val; constexpr modint61() : val(0ULL) {} constexpr modint61(u32 x) : val(x) {} constexpr modint61(u64 x) : val(x % mod) {} constexpr modint61(int x) : val((x < 0) ? (x + static_cast<ll>(mod)) : x) {} constexpr modint61(ll x) : val(((x %= static_cast<ll>(mod)) < 0) ? (x + static_cast<ll>(mod)) : x) {} static constexpr u64 get_mod() { return mod; } modint61 &operator+=(const modint61 &a) { val = ((val += a.val) >= mod) ? (val - mod) : val; return *this; } modint61 &operator-=(const modint61 &a) { val = ((val -= a.val) >= mod) ? (val + mod) : val; return *this; } modint61 &operator*=(const modint61 &a) { const unsigned __int128 y = static_cast<unsigned __int128>(val) * a.val; val = (y >> 61) + (y & mod); val = (val >= mod) ? (val - mod) : val; return *this; } modint61 operator-() const { return modint61(val ? 
mod - val : u64(0)); } modint61 &operator/=(const modint61 &a) { return (*this *= a.inverse()); } modint61 operator+(const modint61 &p) const { return modint61(*this) += p; } modint61 operator-(const modint61 &p) const { return modint61(*this) -= p; } modint61 operator*(const modint61 &p) const { return modint61(*this) *= p; } modint61 operator/(const modint61 &p) const { return modint61(*this) /= p; } bool operator<(const modint61 &other) const { return val < other.val; } bool operator==(const modint61 &p) const { return val == p.val; } bool operator!=(const modint61 &p) const { return val != p.val; } modint61 inverse() const { ll a = val, b = mod, u = 1, v = 0, t; while (b > 0) { t = a / b; swap(a -= t * b, b), swap(u -= t * v, v); } return modint61(u); } modint61 pow(ll n) const { assert(n >= 0); modint61 ret(1), mul(val); while (n > 0) { if (n & 1) ret *= mul; mul *= mul, n >>= 1; } return ret; } }; #ifdef FASTIO void rd(modint61 &x) { fastio::rd(x.val); assert(0 <= x.val && x.val < modint61::mod); } void wt(modint61 x) { fastio::wt(x.val); } #endif #line 3 "mod/modular_subset_sum.hpp" // Faster Deterministic Modular Subset Sum. arXiv preprint arXiv:2012.06062. 
// modular subset sum のための、シフト付きセグ木 // shift には 2^(N-k) 時間かかる struct ShiftTree { using M61 = modint61; int delta; int N, n; M61 base; vc<M61> dat; vc<M61> base_pow; ShiftTree(int N, ll base) : delta(0), N(N), n(topbit(N)), base(base) { assert(N == (1 << n)); dat.assign(2 * N, 0); base_pow.assign(n, 1); base_pow[n - 1] = base; FOR_R(i, n - 1) base_pow[i] = base_pow[i + 1] * base_pow[i + 1]; } inline int skew(int k) { return (delta >> (n - k)) & 1; } inline int left(int k, int i) { int mask = (1 << (k + 1)) - 1; return ((2 * i + 0 - skew(k + 1)) & mask) + (1 << (k + 1)); } inline int right(int k, int i) { int mask = (1 << (k + 1)) - 1; return ((2 * i + 1 - skew(k + 1)) & mask) + (1 << (k + 1)); } inline int parent(int k, int i) { int mask = (1 << k) - 1; return (((i + skew(k)) & mask) + (1 << k)) / 2; } inline void update(int k, int i) { M61 b = base_pow[k]; dat[i] = b * dat[left(k, i)] + dat[right(k, i)]; } inline void set(int i, ll x) { i = (i + N - delta) % N + N; dat[i] = x; int k = n; while (i != 1) { i = parent(k, i); --k; update(k, i); } } void shift(int k) { k %= N; if (k < 0) k += N; if (k == 0) return; int j = lowbit(k); delta = (delta + k) % N; FOR_R(k, n - j) { FOR3(i, 1 << k, 2 << k) update(k, i); } } // [a,b) における difference の列挙。output sensitive。 // T のノード i、Q のノード j が (x,y) を指すとする。 static void find_differences(vc<int>& res, ShiftTree& T, ShiftTree& Q, int a, int b, int k, int i, int j, int x, int y) { if (T.dat[i] == Q.dat[j]) return; if (max(a, x) >= min(b, y)) return; if (y == x + 1) { res.eb(x); return; } int z = (x + y) / 2; find_differences(res, T, Q, a, b, k + 1, T.left(k, i), Q.left(k, j), x, z); find_differences(res, T, Q, a, b, k + 1, T.right(k, i), Q.right(k, j), z, y); } static vc<int> diff(ShiftTree& T, ShiftTree& Q, int a, int b) { assert(T.N == Q.N); vc<int> res; find_differences(res, T, Q, a, b, 0, 1, 1, 0, T.N); return res; } }; /* 計算量:(|vals| + mod) * log(mod) ・can(x) または [x] で bool を返す。 ・restore(x) で復元。 コンストラクタには、(mod, vals) をわたす */ 
template <typename INT> struct Modular_Subset_Sum { int mod; vc<INT>& vals; vc<int> par; Modular_Subset_Sum(int mod, vc<INT>& vals) : mod(mod), vals(vals) { for (auto&& x: vals) assert(0 <= x && x < mod); par.assign(mod, -1); const ll base = RNG(0, (1LL << 61) - 1); int k = 1; while ((1 << k) < 2 * mod) ++k; int L = 1 << k; assert(L >= 2 * mod); ShiftTree T1(L, base); ShiftTree T2(L, base); T1.set(0, 1); T2.set(0, 1); T2.set(L - mod, 1); auto bit_rev = [&](int i) -> int { int x = 0; FOR(k) { x = 2 * x + (i & 1); i >>= 1; } return x; }; vc<vi> IDS(L); FOR(i, len(vals)) { IDS[vals[i]].eb(i); } FOR(i, 1, L) { int x = bit_rev(i); if (len(IDS[x]) == 0) continue; T2.shift(x - T2.delta); for (auto&& idx: IDS[x]) { auto diff = ShiftTree::diff(T1, T2, 0, mod); for (auto&& d: diff) { if (can(d)) continue; par[d] = idx; T1.set(d, 1); T2.set((d + x) % L, 1); T2.set((L + d + x - mod) % L, 1); } } } } bool can(int x) { if (x >= mod) return false; return (x == 0 || par[x] != -1); } bool operator[](int x) { return can(x); } vc<int> restore(int x) { assert(can(x)); vc<int> res; while (x) { int i = par[x]; res.eb(i); x -= vals[i]; if (x < 0) x += mod; } reverse(all(res)); return res; } };