This documentation is automatically generated by online-judge-tools/verification-helper
View the Project on GitHub maspypy/library
#include "other/count_seq_with_fixed_xor_value.hpp"
#include "mod/modint.hpp"

// Number of sequences in [0, LIM)^N whose xor equals X, modulo mint's modulus.
//
// Method (digit DP on hi = LIM - 1, the inclusive upper bound, from the top
// bit down). For every bit k at which hi has a 1, count the sequences in which
// every element agrees with hi on all bits above k and at least one element
// has a 0 at bit k:
//   * an element with bit k = 0 has 2^k free choices for its low k bits,
//   * an element with bit k = 1 is tight and has (low + 1) choices, where low
//     is hi restricted to bits below k,
//   * choosing bit k with the parity X requires is a binomial parity split
//     ((F + T)^N +/- (F - T)^N) / 2, minus the "all tight" term when its
//     parity also matches,
//   * because at least one element is free below bit k, the xor of the low k
//     bits is equidistributed, so dividing by 2^k fixes them to X's low bits.
// This term only counts while the bits above k, xor-ed N times from hi, still
// agree with X. Finally the single sequence (hi, hi, ..., hi) is added if its
// xor equals X.
template <typename mint>
mint count_seq_with_fixed_xor(ll N, ll LIM, ll X) {
  assert(LIM >= 1);
  const ll hi = LIM - 1;  // inclusive upper bound
  if (hi == 0) return (X == 0 ? 1 : 0);  // only the all-zero sequence exists
  const int bits = topbit(hi) + 1;
  // Values <= hi fit in `bits` bits, so their xor cannot have a higher bit
  // (this also rejects negative X).
  if (X >> bits) return 0;
  const mint half = inv<mint>(2);
  mint total = 0;
  bool prefix_ok = true;
  for (int k = bits - 1; k >= 0; --k) {
    const int hi_bit = hi >> k & 1;
    const int x_bit = X >> k & 1;
    if (hi_bit) {
      const ll low = hi - (hi >> k << k);   // bits of hi strictly below k
      const mint free_cnt = mint(2).pow(k);  // choices when bit k is 0
      const mint tight_cnt = low + 1;        // choices when bit k is 1
      const mint plus = (free_cnt + tight_cnt).pow(N);
      const mint minus = (free_cnt - tight_cnt).pow(N);
      // Sequences whose number of bit-k ones has parity x_bit.
      mint cur = (x_bit ? plus - minus : plus + minus) * half;
      // Exclude the case "every element has bit k = 1" (no free element).
      if ((N & 1) == x_bit) cur -= tight_cnt.pow(N);
      cur /= free_cnt;  // the low k bits are equidistributed
      total += cur;
    }
    // Stop once hi's bit k, repeated N times, disagrees with X at bit k.
    if (hi_bit * (N & 1) != x_bit) {
      prefix_ok = false;
      break;
    }
  }
  if (prefix_ok) total += mint(1);  // the sequence (hi, hi, ..., hi)
  return total;
}

// Same count as count_seq_with_fixed_xor, for every N = 0, 1, ..., nmax at once:
// the returned vector has nmax + 1 entries and entry n is the answer for N = n.
// The per-bit powers are accumulated incrementally over n instead of calling
// pow for each n.
template <typename mint>
vc<mint> count_seq_with_fixed_xor_iota(ll nmax, ll LIM, ll X) {
  assert(LIM >= 1);
  const ll hi = LIM - 1;  // inclusive upper bound
  vc<mint> res(nmax + 1);
  if (hi == 0) {
    if (X == 0) fill(res.begin(), res.end(), mint(1));
    return res;
  }
  const int bits = topbit(hi) + 1;
  if (X >> bits) return res;
  // alive[n]: hi's bits above the current k, xor-ed n times, still match X.
  vc<char> alive(nmax + 1, 1);
  const mint half = inv<mint>(2);
  // Starts at 2^{-bits}; doubled at the top of each iteration, so it equals
  // 2^{-k} while bit k is processed.
  mint scale = half.pow(bits);
  for (int k = bits - 1; k >= 0; --k) {
    scale += scale;
    const int hi_bit = hi >> k & 1;
    const int x_bit = X >> k & 1;
    if (hi_bit) {
      const ll low = hi - (hi >> k << k);
      const mint free_cnt = mint(2).pow(k);
      const mint tight_cnt = low + 1;
      const mint sum = free_cnt + tight_cnt, diff = free_cnt - tight_cnt;
      // p_sum = sum^n, p_diff = diff^n, p_tight = tight_cnt^n for the current n.
      mint p_sum = 1, p_diff = 1, p_tight = 1;
      for (ll n = 0; n <= nmax; ++n) {
        if (alive[n]) {
          mint cur = (x_bit ? (p_sum - p_diff) : (p_sum + p_diff)) * half;
          if ((n & 1) == x_bit) cur -= p_tight;
          res[n] += cur * scale;
        }
        p_sum *= sum;
        p_diff *= diff;
        p_tight *= tight_cnt;
      }
    }
    for (ll n = 0; n <= nmax; ++n) {
      if (hi_bit * (n & 1) != x_bit) alive[n] = 0;
    }
  }
  for (ll n = 0; n <= nmax; ++n) {
    if (alive[n]) res[n] += mint(1);  // the sequence (hi, ..., hi) of length n
  }
  return res;
}
#line 2 "mod/modint_common.hpp"

// Detection idiom: has_mod<T> derives from std::true_type when the expression
// `x.get_mod()` is well-formed for a T rvalue, and from std::false_type
// otherwise.
struct has_mod_impl {
  template <class T>
  static auto check(T &&x) -> decltype(x.get_mod(), std::true_type{});
  template <class T>
  static auto check(...) -> std::false_type;
};

template <class T>
class has_mod : public decltype(has_mod_impl::check<T>(std::declval<T>())) {};

// Modular inverse of n, memoized in a table that is per mint type.
// n >= mod is first reduced modulo mod; inv(0) returns dat[0] == 0.
// The table is extended with inv(k) = inv(k*q - mod) * q, q = ceil(mod / k):
// k*q is congruent to k*q - mod, and 0 <= k*q - mod < k, so the index on the
// right is already filled in.
// NOTE(review): the recurrence needs k*q - mod != 0 (k does not divide mod),
// which holds when mod is prime; presumably a prime modulus is assumed.
// The modulus is read once into a function-local static, so this presumably
// targets mint types whose modulus never changes at run time.
template <typename mint>
mint inv(int n) {
  static const int mod = mint::get_mod();
  static vector<mint> dat = {0, 1};
  assert(0 <= n);
  if (n >= mod) n %= mod;
  while (len(dat) <= n) {
    int k = len(dat);
    int q = (mod + k - 1) / k;
    dat.eb(dat[k * q - mod] * mint::raw(q));
  }
  return dat[n];
}

// n! modulo mod, memoized. Requires 0 <= n < mod (asserted).
template <typename mint>
mint fact(int n) {
  static const int mod = mint::get_mod();
  assert(0 <= n && n < mod);
  static vector<mint> dat = {1, 1};
  while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * mint::raw(len(dat)));
  return dat[n];
}

// 1 / n! modulo mod, memoized (built from inv<mint>); returns 0 for n < 0.
template <typename mint>
mint fact_inv(int n) {
  static vector<mint> dat = {1, 1};
  if (n < 0) return mint(0);
  while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * inv<mint>(len(dat)));
  return dat[n];
}

// Product of fact_inv<mint>(x) over all arguments (mint(1) for an empty pack).
template <class mint, class... Ts>
mint fact_invs(Ts... xs) {
  return (mint(1) * ... * fact_inv<mint>(xs));
}

// Multinomial-style ratio head! / (tail_1! * tail_2! * ...).
template <typename mint, class Head, class... Tail>
mint multinomial(Head &&head, Tail &&... tail) {
  return fact<mint>(head) * fact_invs<mint>(std::forward<Tail>(tail)...);
}

// Binomial coefficient C(n, k) read from a Pascal's-triangle table that is
// grown on demand; H and W are its current row and column counts. Only
// additions are used (no inverses). Entries with k > n come out as 0.
// NOTE(review): n and k are not range-checked here; negative values would
// index out of bounds.
template <typename mint>
mint C_dense(int n, int k) {
  static vvc<mint> C;
  static int H = 0, W = 0;
  // Pascal's rule; row i - 1 must already contain columns j and j - 1.
  auto calc = [&](int i, int j) -> mint {
    if (i == 0) return (j == 0 ? mint(1) : mint(0));
    return C[i - 1][j] + (j ? C[i - 1][j - 1] : 0);
  };
  // Widen the existing rows top to bottom, so calc(i, j) sees row i - 1
  // already widened.
  if (W <= k) {
    FOR(i, H) {
      C[i].resize(k + 1);
      FOR(j, W, k + 1) { C[i][j] = calc(i, j); }
    }
    W = k + 1;
  }
  // Then append any missing rows at the current width.
  if (H <= n) {
    C.resize(n + 1);
    FOR(i, H, n + 1) {
      C[i].resize(W);
      FOR(j, W) { C[i][j] = calc(i, j); }
    }
    H = n + 1;
  }
  return C[n][k];
}

// Binomial coefficient C(n, k); requires n >= 0, returns 0 if k < 0 or k > n.
//   dense = true : taken from the Pascal's-triangle table (C_dense).
//   large = false: n! / (k! (n - k)!) via the factorial tables.
//   large = true : with k = min(k, n - k), computes n (n-1) ... (n-k+1) / k!,
//                  so factorial tables are only needed up to k, not n.
template <typename mint, bool large = false, bool dense = false>
mint C(ll n, ll k) {
  assert(n >= 0);
  if (k < 0 || n < k) return 0;
  if constexpr (dense) return C_dense<mint>(n, k);
  if constexpr (!large) return multinomial<mint>(n, k, n - k);
  k = min(k, n - k);
  mint x(1);
  FOR(i, k) x *= mint(n - i);
  return x * fact_inv<mint>(k);
}

// 1 / C(n, k). Requires 0 <= k <= n. Without `large` it uses the factorial
// tables directly; with `large` it inverts C<mint, 1>(n, k).
template <typename mint, bool large = false>
mint C_inv(ll n, ll k) {
  assert(n >= 0);
  assert(0 <= k && k <= n);
  if (!large) return fact_inv<mint>(n) * fact<mint>(k) * fact<mint>(n - k);
  return mint(1) / C<mint, 1>(n, k);
}

// [x^d](1-x)^{-n}
// The coefficient of x^d in (1 - x)^{-n}, i.e. C(n + d - 1, d).
// Returns 0 for d < 0; for n == 0 the series is 1, so the result is [d == 0].
template <typename mint, bool large = false, bool dense = false>
mint C_negative(ll n, ll d) {
  assert(n >= 0);
  if (d < 0) return mint(0);
  if (n == 0) { return (d == 0 ? mint(1) : mint(0)); }
  return C<mint, large, dense>(n + d - 1, d);
}
#line 3 "mod/modint.hpp"

// Integer modulo the compile-time modulus `mod`, which must be < 2^31
// (static_assert). Every constructor and operator keeps `val` in [0, mod).
template <int mod>
struct modint {
  static constexpr u32 umod = u32(mod);
  static_assert(umod < u32(1) << 31);
  u32 val;

  // Wraps v without reducing it; the caller must guarantee v < mod.
  static modint raw(u32 v) {
    modint x;
    x.val = v;
    return x;
  }
  constexpr modint() : val(0) {}
  constexpr modint(u32 x) : val(x % umod) {}
  constexpr modint(u64 x) : val(x % umod) {}
  constexpr modint(u128 x) : val(x % umod) {}
  // Signed inputs: C++ `%` keeps the sign of x, so a negative remainder is
  // shifted up by mod.
  constexpr modint(int x) : val((x %= mod) < 0 ? x + mod : x){};
  constexpr modint(ll x) : val((x %= mod) < 0 ? x + mod : x){};
  constexpr modint(i128 x) : val((x %= mod) < 0 ? x + mod : x){};
  // Orders by the canonical representative (presumably so modint can be used
  // as a key in ordered containers).
  bool operator<(const modint &other) const { return val < other.val; }
  modint &operator+=(const modint &p) {
    if ((val += p.val) >= umod) val -= umod;
    return *this;
  }
  modint &operator-=(const modint &p) {
    // Adds umod - p.val to stay unsigned; the sum is < 2 * umod < 2^32.
    if ((val += umod - p.val) >= umod) val -= umod;
    return *this;
  }
  modint &operator*=(const modint &p) {
    // Widen to 64 bits so the product cannot overflow before the reduction.
    val = u64(val) * p.val % umod;
    return *this;
  }
  modint &operator/=(const modint &p) {
    *this *= p.inverse();
    return *this;
  }
  modint operator-() const { return modint::raw(val ? mod - val : u32(0)); }
  modint operator+(const modint &p) const { return modint(*this) += p; }
  modint operator-(const modint &p) const { return modint(*this) -= p; }
  modint operator*(const modint &p) const { return modint(*this) *= p; }
  modint operator/(const modint &p) const { return modint(*this) /= p; }
  bool operator==(const modint &p) const { return val == p.val; }
  bool operator!=(const modint &p) const { return val != p.val; }
  // Multiplicative inverse via the extended Euclidean algorithm on (val, mod).
  // Meaningful only when gcd(val, mod) == 1; for val == 0 it returns 0 rather
  // than failing.
  modint inverse() const {
    int a = val, b = mod, u = 1, v = 0, t;
    while (b > 0) {
      t = a / b;
      swap(a -= t * b, b), swap(u -= t * v, v);
    }
    return modint(u);
  }
  // this^n by binary exponentiation; n must be non-negative (asserted).
  modint pow(ll n) const {
    assert(n >= 0);
    modint ret(1), mul(val);
    while (n > 0) {
      if (n & 1) ret *= mul;
      mul *= mul;
      n >>= 1;
    }
    return ret;
  }
  static constexpr int get_mod() { return mod; }
  // (n, r): r is a 2^n-th root of unity modulo mod. Only the NTT-friendly
  // primes listed below are known; any other modulus yields {-1, -1}.
  static constexpr pair<int, int> ntt_info() {
    if (mod == 120586241) return {20, 74066978};
    if (mod == 167772161) return {25, 17};
    if (mod == 469762049) return {26, 30};
    if (mod == 754974721) return {24, 362};
    if (mod == 880803841) return {23, 211};
    if (mod == 943718401) return {22, 663003469};
    if (mod == 998244353) return {23, 31};
    if (mod == 1004535809) return {21, 836905998};
    if (mod == 1045430273) return {20, 363};
    if (mod == 1051721729) return {20, 330};
    if (mod == 1053818881) return {20, 2789};
    return {-1, -1};
  }
  // True iff ntt_info() knows this modulus.
  static constexpr bool can_ntt() { return ntt_info().fi != -1; }
};

#ifdef FASTIO
// fastio hooks, compiled only when FASTIO is defined: `rd` reads a raw integer
// into val and reduces it modulo mod; `wt` writes val.
template <int mod>
void rd(modint<mod> &x) {
  fastio::rd(x.val);
  x.val %= mod;
  // assert(0 <= x.val && x.val < mod);
}
template <int mod>
void wt(modint<mod> x) {
  fastio::wt(x.val);
}
#endif
using modint107 = modint<1000000007>;
using modint998 = modint<998244353>;
#line 2 "other/count_seq_with_fixed_xor_value.hpp"

// Number of sequences in [0, LIM)^N whose xor equals X.
// Digit DP over the bits of LIM - 1 (the inclusive bound), top bit first.
// For each set bit k, count the sequences that match LIM - 1 above bit k and
// have at least one element with bit k = 0. That element is free below bit k,
// so the low bits are equidistributed; hence the division by 2^k. `a`/`b` end
// up as the even/odd-parity halves of the binomial split at bit k. The final
// +1 is the sequence (LIM-1, ..., LIM-1), added when its xor equals X.
template <typename mint>
mint count_seq_with_fixed_xor(ll N, ll LIM, ll X) {
  assert(LIM >= 1);
  --LIM; // closed: LIM is now the inclusive upper bound
  if (LIM == 0) return (X == 0 ? 1 : 0);
  int LOG = topbit(LIM) + 1;
  if (X >> LOG) return 0;
  mint res = 0;
  bool ok = 1;  // bits of X above k are still reachable by the all-tight prefix
  FOR_R(k, LOG) {
    int LIM1 = LIM >> k & 1;
    int X1 = X >> k & 1;
    if (LIM1) {
      ll mk = LIM - (LIM >> k << k);  // bits of LIM below k
      mint a = mint(2).pow(k), b = mk + 1;
      tie(a, b) = mp(a + b, a - b);
      a = a.pow(N), b = b.pow(N);
      tie(a, b) = mp(a + b, a - b);
      a *= inv<mint>(2), b *= inv<mint>(2);
      mint now = (X1 ? b : a);
      // Remove the case where every element has bit k = 1 (nothing free).
      if ((N & 1) == X1) now -= mint(mk + 1).pow(N);
      now /= mint(2).pow(k);
      res += now;
    }
    if (LIM1 * (N & 1) != X1) {
      ok = 0;
      break;
    }
  }
  if (ok) res += mint(1);
  return res;
}

// Number of sequences in [0, LIM)^N whose xor equals X, for N = 0,1,...,nmax.
// Entry n of the returned vector (size nmax + 1) is the answer for N = n.
// Same DP as count_seq_with_fixed_xor, with powers accumulated over n.
// px2 starts at 2^{-LOG} and is doubled each iteration, so it equals 2^{-k}
// while bit k is processed.
template <typename mint>
vc<mint> count_seq_with_fixed_xor_iota(ll nmax, ll LIM, ll X) {
  assert(LIM >= 1);
  --LIM; // closed: LIM is now the inclusive upper bound
  vc<mint> res(nmax + 1);
  if (LIM == 0) {
    if (X == 0) fill(all(res), mint(1));
    return res;
  }
  int LOG = topbit(LIM) + 1;
  if (X >> LOG) return res;
  vc<bool> ok(nmax + 1, 1);
  mint x2 = inv<mint>(2);
  mint px2 = x2.pow(LOG);
  FOR_R(k, LOG) {
    px2 += px2;
    int LIM1 = LIM >> k & 1;
    int X1 = X >> k & 1;
    if (LIM1) {
      ll mk = LIM - (LIM >> k << k);
      mint a = mint(2).pow(k), b = mk + 1;
      tie(a, b) = mp(a + b, a - b);
      // pa = a^n, pb = b^n, pc = (mk + 1)^n for the current n.
      mint pa = 1, pb = 1, pc = 1;
      FOR(n, nmax + 1) {
        if (ok[n]) {
          mint x = (X1 ? (pa - pb) : (pa + pb)) * x2;
          if ((n & 1) == X1) x -= pc;
          res[n] += x * px2;
        }
        pa *= a, pb *= b, pc *= mint(mk + 1);
      }
    }
    FOR(n, nmax + 1) {
      if (LIM1 * (n & 1) != X1) { ok[n] = 0; }
    }
  }
  FOR(n, nmax + 1) if (ok[n]) res[n] += mint(1);
  return res;
}