library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub maspypy/library

✔ mod/modular_subset_sum.hpp

Depends on: random/base.hpp, mod/modint61.hpp

Verified with

Code

#include "random/base.hpp"

#include "mod/modint61.hpp"


// Faster Deterministic Modular Subset Sum. arXiv preprint arXiv:2012.06062.

// Segment tree with cyclic shifts, used by the modular subset sum below.
// It stores an array of N = 2^n values (addressed by "logical" position) and a
// polynomial hash mod 2^61 - 1 of every aligned block of that array.
// shift(k) rotates the array by k; only levels whose block size does not divide
// k are rebuilt, i.e. 2^(n - j) - 1 node updates with j = lowbit(k)
// (assumed: index of k's lowest set bit), which is O(N / 2^j).

struct ShiftTree {
  using M61 = modint61;
  int delta;  // accumulated rotation, kept in [0, N)
  int N, n;   // N = 2^n leaves
  M61 base;   // hash base
  // dat[1] is the root; level k occupies indices [2^k, 2^(k+1)); leaves are [N, 2N).
  vc<M61> dat;
  // base_pow[k] = base^(2^(n-1-k)): the weight applied to a left child at level k
  // (a right child at level k covers 2^(n-1-k) leaves).
  vc<M61> base_pow;

  ShiftTree(int N, ll base) : delta(0), N(N), n(topbit(N)), base(base) {
    assert(N >= 1 && N == (1 << n));
    dat.assign(2 * N, 0);

    base_pow.assign(n, 1);
    // A single-leaf tree (n == 0) has no internal nodes and needs no powers;
    // without this guard base_pow[n - 1] would be written out of bounds.
    if (n > 0) {
      base_pow[n - 1] = base;
      for (int lv = n - 2; lv >= 0; --lv) {
        base_pow[lv] = base_pow[lv + 1] * base_pow[lv + 1];
      }
    }
  }

  // Bit (n - k) of delta. Physical slot q on level k holds logical block
  // (q + (delta >> (n - k))) mod 2^k; skew(k) is the extra offset of level k
  // relative to twice the offset of level k - 1.
  inline int skew(int k) { return (delta >> (n - k)) & 1; }

  // Physical index (on level k + 1) of the left child of physical node i on level k.
  inline int left(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 0 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Physical index (on level k + 1) of the right child of physical node i on level k.
  inline int right(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 1 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Physical index (on level k - 1) of the parent of physical node i on level k.
  inline int parent(int k, int i) {
    int mask = (1 << k) - 1;
    return (((i + skew(k)) & mask) + (1 << k)) / 2;
  }

  // Recompute node i on level k: hash = base_pow[k] * left + right.
  inline void update(int k, int i) {
    M61 b = base_pow[k];
    dat[i] = b * dat[left(k, i)] + dat[right(k, i)];
  }

  // Set logical position i (taken modulo N, negative values allowed) to x and
  // refresh all of its ancestors (n node updates).
  inline void set(int i, ll x) {
    const int p = (i % N + N) % N;  // normalise to [0, N) without overflow
    // Logical position p lives in physical leaf (p - delta) mod N.
    i = (p + N - delta) % N + N;
    dat[i] = x;
    int k = n;
    while (i != 1) {
      i = parent(k, i);
      --k;
      update(k, i);
    }
  }

  // Rotate the stored array so that the value at logical position p moves to
  // logical position p + k (mod N). k may be negative or >= N.
  void shift(int k) {
    k %= N;
    if (k < 0) k += N;
    if (k == 0) return;
    const int j = lowbit(k);
    delta = (delta + k) % N;
    // Blocks of size 2^(n - lv) with n - lv <= j move by a multiple of their own
    // size, so their stored hashes stay valid. Only levels lv < n - j change,
    // rebuilt bottom-up so children are ready before their parents.
    for (int lv = n - j - 1; lv >= 0; --lv) {
      for (int i = 1 << lv; i < (2 << lv); ++i) update(lv, i);
    }
  }

  // Enumerate the positions in [a, b) where T and Q differ; output-sensitive.
  // Node i of T and node j of Q (both on level k) cover logical range [x, y).
  // Results are appended to res in increasing order.
  static void find_differences(vc<int>& res, ShiftTree& T, ShiftTree& Q, int a,
                               int b, int k, int i, int j, int x, int y) {
    if (T.dat[i] == Q.dat[j]) return;               // equal blocks (up to hashing)
    if (std::max(a, x) >= std::min(b, y)) return;  // [x, y) misses [a, b)
    if (y == x + 1) {
      res.emplace_back(x);
      return;
    }
    const int z = (x + y) / 2;
    find_differences(res, T, Q, a, b, k + 1, T.left(k, i), Q.left(k, j), x, z);
    find_differences(res, T, Q, a, b, k + 1, T.right(k, i), Q.right(k, j), z, y);
  }

  // Logical positions in [a, b) whose leaf values differ between T and Q, in
  // increasing order. Both trees must have the same N and should share the same
  // base so that equal subtrees are pruned. A hash collision at an internal node
  // could (with negligible probability) hide a difference.
  static vc<int> diff(ShiftTree& T, ShiftTree& Q, int a, int b) {
    assert(T.N == Q.N);
    vc<int> res;
    find_differences(res, T, Q, a, b, 0, 1, 1, 0, T.N);
    return res;
  }
};

/*
計算量:(|vals| + mod) * log(mod)
・can(x) または [x] で bool を返す。
・restore(x) で復元。
コンストラクタには、(mod, vals) をわたす
*/
template <typename INT>
struct Modular_Subset_Sum {
  int mod;
  vc<INT>& vals;
  vc<int> par;

  Modular_Subset_Sum(int mod, vc<INT>& vals) : mod(mod), vals(vals) {
    for (auto&& x: vals) assert(0 <= x && x < mod);
    par.assign(mod, -1);

    const ll base = RNG(0, (1LL << 61) - 1);

    int k = 1;
    while ((1 << k) < 2 * mod) ++k;

    int L = 1 << k;
    assert(L >= 2 * mod);

    ShiftTree T1(L, base);
    ShiftTree T2(L, base);
    T1.set(0, 1);
    T2.set(0, 1);
    T2.set(L - mod, 1);

    auto bit_rev = [&](int i) -> int {
      int x = 0;
      FOR(k) {
        x = 2 * x + (i & 1);
        i >>= 1;
      }
      return x;
    };

    vc<vi> IDS(L);
    FOR(i, len(vals)) { IDS[vals[i]].eb(i); }

    FOR(i, 1, L) {
      int x = bit_rev(i);
      if (len(IDS[x]) == 0) continue;
      T2.shift(x - T2.delta);
      for (auto&& idx: IDS[x]) {
        auto diff = ShiftTree::diff(T1, T2, 0, mod);
        for (auto&& d: diff) {
          if (can(d)) continue;
          par[d] = idx;
          T1.set(d, 1);
          T2.set((d + x) % L, 1);
          T2.set((L + d + x - mod) % L, 1);
        }
      }
    }
  }

  bool can(int x) {
    if (x >= mod) return false;
    return (x == 0 || par[x] != -1);
  }

  bool operator[](int x) { return can(x); }
  vc<int> restore(int x) {
    assert(can(x));
    vc<int> res;
    while (x) {
      int i = par[x];
      res.eb(i);
      x -= vals[i];
      if (x < 0) x += mod;
    }
    reverse(all(res));
    return res;
  }
};
#line 2 "random/base.hpp"

// xorshift-style 64-bit generator (shifts 7 and 9), seeded once from the
// high-resolution clock. The state is a function-local static that is updated
// without synchronisation, so concurrent calls from several threads race.
u64 RNG_64() {
  static uint64_t x_
      = uint64_t(chrono::duration_cast<chrono::nanoseconds>(
                     chrono::high_resolution_clock::now().time_since_epoch())
                     .count())
        * 10150724397891781847ULL;
  x_ ^= x_ << 7;
  return x_ ^= x_ >> 9;
}

// Value in [0, lim), by plain modulo reduction (not exactly uniform unless lim
// divides 2^64). lim must be positive; lim == 0 would divide by zero.
u64 RNG(u64 lim) {
  assert(lim > 0);
  return RNG_64() % lim;
}

// Value in [l, r); requires l < r. The range width is computed in u64, so
// ranges wider than LLONG_MAX do not trigger signed overflow.
ll RNG(ll l, ll r) {
  assert(l < r);
  return l + RNG_64() % (u64(r) - u64(l));
}
#line 2 "mod/modint61.hpp"

// Arithmetic modulo the Mersenne prime 2^61 - 1.
// val is kept in canonical form [0, mod) by every constructor and operator.
struct modint61 {
  static constexpr u64 mod = (1ULL << 61) - 1;
  u64 val;
  constexpr modint61() : val(0ULL) {}
  constexpr modint61(u32 x) : val(x) {}
  constexpr modint61(u64 x) : val(x % mod) {}
  constexpr modint61(int x) : val((x < 0) ? (x + static_cast<ll>(mod)) : x) {}
  constexpr modint61(ll x)
      : val(((x %= static_cast<ll>(mod)) < 0) ? (x + static_cast<ll>(mod))
                                              : x) {}
  static constexpr u64 get_mod() { return mod; }
  modint61 &operator+=(const modint61 &a) {
    val += a.val;  // both operands < 2^61, so the u64 sum cannot overflow
    if (val >= mod) val -= mod;
    return *this;
  }
  modint61 &operator-=(const modint61 &a) {
    val = (val >= a.val) ? (val - a.val) : (val + mod - a.val);
    return *this;
  }
  modint61 &operator*=(const modint61 &a) {
    // 2^61 == 1 (mod p), so y = hi * 2^61 + lo == hi + lo (mod p). For
    // canonical inputs hi + lo < 2p, so one conditional subtraction suffices.
    const unsigned __int128 y = static_cast<unsigned __int128>(val) * a.val;
    val = (y >> 61) + (y & mod);
    if (val >= mod) val -= mod;
    return *this;
  }
  modint61 operator-() const { return modint61(val ? mod - val : u64(0)); }
  modint61 &operator/=(const modint61 &a) { return (*this *= a.inverse()); }
  modint61 operator+(const modint61 &p) const { return modint61(*this) += p; }
  modint61 operator-(const modint61 &p) const { return modint61(*this) -= p; }
  modint61 operator*(const modint61 &p) const { return modint61(*this) *= p; }
  modint61 operator/(const modint61 &p) const { return modint61(*this) /= p; }
  bool operator==(const modint61 &p) const { return val == p.val; }
  bool operator!=(const modint61 &p) const { return val != p.val; }
  // Multiplicative inverse by the extended Euclidean algorithm.
  modint61 inverse() const {
    // 0 has no inverse; without this check the loop silently returns 0.
    assert(val != 0);
    ll a = val, b = mod, u = 1, v = 0, t;
    while (b > 0) {
      t = a / b;
      swap(a -= t * b, b), swap(u -= t * v, v);
    }
    return modint61(u);
  }
  // this^n by binary exponentiation; n must be non-negative.
  modint61 pow(ll n) const {
    assert(n >= 0);
    modint61 ret(1), mul(val);
    while (n > 0) {
      if (n & 1) ret *= mul;
      mul *= mul, n >>= 1;
    }
    return ret;
  }
};

#ifdef FASTIO
// Read a modint61 as its raw representative. The input must already be
// reduced, i.e. lie in [0, 2^61 - 1).
void rd(modint61 &x) {
  fastio::rd(x.val);
  // val is unsigned, so only the upper bound needs checking.
  assert(x.val < modint61::mod);
}

// Write the canonical representative of x.
void wt(modint61 x) { fastio::wt(x.val); }
#endif
#line 3 "mod/modular_subset_sum.hpp"

// Faster Deterministic Modular Subset Sum. arXiv preprint arXiv:2012.06062.

// Segment tree with cyclic shifts, used by the modular subset sum below.
// It stores an array of N = 2^n values (addressed by "logical" position) and a
// polynomial hash mod 2^61 - 1 of every aligned block of that array.
// shift(k) rotates the array by k; only levels whose block size does not divide
// k are rebuilt, i.e. 2^(n - j) - 1 node updates with j = lowbit(k)
// (assumed: index of k's lowest set bit), which is O(N / 2^j).

struct ShiftTree {
  using M61 = modint61;
  int delta;  // accumulated rotation, kept in [0, N)
  int N, n;   // N = 2^n leaves
  M61 base;   // hash base
  // dat[1] is the root; level k occupies indices [2^k, 2^(k+1)); leaves are [N, 2N).
  vc<M61> dat;
  // base_pow[k] = base^(2^(n-1-k)): the weight applied to a left child at level k
  // (a right child at level k covers 2^(n-1-k) leaves).
  vc<M61> base_pow;

  ShiftTree(int N, ll base) : delta(0), N(N), n(topbit(N)), base(base) {
    assert(N >= 1 && N == (1 << n));
    dat.assign(2 * N, 0);

    base_pow.assign(n, 1);
    // A single-leaf tree (n == 0) has no internal nodes and needs no powers;
    // without this guard base_pow[n - 1] would be written out of bounds.
    if (n > 0) {
      base_pow[n - 1] = base;
      for (int lv = n - 2; lv >= 0; --lv) {
        base_pow[lv] = base_pow[lv + 1] * base_pow[lv + 1];
      }
    }
  }

  // Bit (n - k) of delta. Physical slot q on level k holds logical block
  // (q + (delta >> (n - k))) mod 2^k; skew(k) is the extra offset of level k
  // relative to twice the offset of level k - 1.
  inline int skew(int k) { return (delta >> (n - k)) & 1; }

  // Physical index (on level k + 1) of the left child of physical node i on level k.
  inline int left(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 0 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Physical index (on level k + 1) of the right child of physical node i on level k.
  inline int right(int k, int i) {
    int mask = (1 << (k + 1)) - 1;
    return ((2 * i + 1 - skew(k + 1)) & mask) + (1 << (k + 1));
  }

  // Physical index (on level k - 1) of the parent of physical node i on level k.
  inline int parent(int k, int i) {
    int mask = (1 << k) - 1;
    return (((i + skew(k)) & mask) + (1 << k)) / 2;
  }

  // Recompute node i on level k: hash = base_pow[k] * left + right.
  inline void update(int k, int i) {
    M61 b = base_pow[k];
    dat[i] = b * dat[left(k, i)] + dat[right(k, i)];
  }

  // Set logical position i (taken modulo N, negative values allowed) to x and
  // refresh all of its ancestors (n node updates).
  inline void set(int i, ll x) {
    const int p = (i % N + N) % N;  // normalise to [0, N) without overflow
    // Logical position p lives in physical leaf (p - delta) mod N.
    i = (p + N - delta) % N + N;
    dat[i] = x;
    int k = n;
    while (i != 1) {
      i = parent(k, i);
      --k;
      update(k, i);
    }
  }

  // Rotate the stored array so that the value at logical position p moves to
  // logical position p + k (mod N). k may be negative or >= N.
  void shift(int k) {
    k %= N;
    if (k < 0) k += N;
    if (k == 0) return;
    const int j = lowbit(k);
    delta = (delta + k) % N;
    // Blocks of size 2^(n - lv) with n - lv <= j move by a multiple of their own
    // size, so their stored hashes stay valid. Only levels lv < n - j change,
    // rebuilt bottom-up so children are ready before their parents.
    for (int lv = n - j - 1; lv >= 0; --lv) {
      for (int i = 1 << lv; i < (2 << lv); ++i) update(lv, i);
    }
  }

  // Enumerate the positions in [a, b) where T and Q differ; output-sensitive.
  // Node i of T and node j of Q (both on level k) cover logical range [x, y).
  // Results are appended to res in increasing order.
  static void find_differences(vc<int>& res, ShiftTree& T, ShiftTree& Q, int a,
                               int b, int k, int i, int j, int x, int y) {
    if (T.dat[i] == Q.dat[j]) return;               // equal blocks (up to hashing)
    if (std::max(a, x) >= std::min(b, y)) return;  // [x, y) misses [a, b)
    if (y == x + 1) {
      res.emplace_back(x);
      return;
    }
    const int z = (x + y) / 2;
    find_differences(res, T, Q, a, b, k + 1, T.left(k, i), Q.left(k, j), x, z);
    find_differences(res, T, Q, a, b, k + 1, T.right(k, i), Q.right(k, j), z, y);
  }

  // Logical positions in [a, b) whose leaf values differ between T and Q, in
  // increasing order. Both trees must have the same N and should share the same
  // base so that equal subtrees are pruned. A hash collision at an internal node
  // could (with negligible probability) hide a difference.
  static vc<int> diff(ShiftTree& T, ShiftTree& Q, int a, int b) {
    assert(T.N == Q.N);
    vc<int> res;
    find_differences(res, T, Q, a, b, 0, 1, 1, 0, T.N);
    return res;
  }
};

/*
計算量:(|vals| + mod) * log(mod)
・can(x) または [x] で bool を返す。
・restore(x) で復元。
コンストラクタには、(mod, vals) をわたす
*/
template <typename INT>
struct Modular_Subset_Sum {
  int mod;
  vc<INT>& vals;
  vc<int> par;

  Modular_Subset_Sum(int mod, vc<INT>& vals) : mod(mod), vals(vals) {
    for (auto&& x: vals) assert(0 <= x && x < mod);
    par.assign(mod, -1);

    const ll base = RNG(0, (1LL << 61) - 1);

    int k = 1;
    while ((1 << k) < 2 * mod) ++k;

    int L = 1 << k;
    assert(L >= 2 * mod);

    ShiftTree T1(L, base);
    ShiftTree T2(L, base);
    T1.set(0, 1);
    T2.set(0, 1);
    T2.set(L - mod, 1);

    auto bit_rev = [&](int i) -> int {
      int x = 0;
      FOR(k) {
        x = 2 * x + (i & 1);
        i >>= 1;
      }
      return x;
    };

    vc<vi> IDS(L);
    FOR(i, len(vals)) { IDS[vals[i]].eb(i); }

    FOR(i, 1, L) {
      int x = bit_rev(i);
      if (len(IDS[x]) == 0) continue;
      T2.shift(x - T2.delta);
      for (auto&& idx: IDS[x]) {
        auto diff = ShiftTree::diff(T1, T2, 0, mod);
        for (auto&& d: diff) {
          if (can(d)) continue;
          par[d] = idx;
          T1.set(d, 1);
          T2.set((d + x) % L, 1);
          T2.set((L + d + x - mod) % L, 1);
        }
      }
    }
  }

  bool can(int x) {
    if (x >= mod) return false;
    return (x == 0 || par[x] != -1);
  }

  bool operator[](int x) { return can(x); }
  vc<int> restore(int x) {
    assert(can(x));
    vc<int> res;
    while (x) {
      int i = par[x];
      res.eb(i);
      x -= vals[i];
      if (x < 0) x += mod;
    }
    reverse(all(res));
    return res;
  }
};
Back to top page