library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub maspypy/library

:heavy_check_mark: mod/mod_log_998244353.hpp

Depends on

Verified with

Code

#include "mod/discrete_log_998244353.hpp"
#include "mod/mod_inv.hpp"

// Discrete logarithm with an arbitrary base modulo p = 998244353.
//
// Returns n with a^n == b (mod p), or -1 when no such exponent exists.
// When a solution exists, the returned n lies in [0, (p - 1) / g), where
// g = gcd(log_r(a), p - 1) and r is the primitive root used by
// discrete_log_mod_998244353_primitive_root.
//
// NOTE(review): a and b are presumably expected to be nonzero modulo p; the
// underlying table lookup has no entry for 0 — confirm against callers.
int mod_log_998244353(int a, int b) {
  // Write a = r^x and b = r^y. Then a^n == b  <=>  n * x == y (mod p - 1).
  int x = discrete_log_mod_998244353_primitive_root(a);
  int y = discrete_log_mod_998244353_primitive_root(b);
  int m = 998244353 - 1;
  int g = gcd(x, m);
  // The linear congruence n * x == y (mod m) is solvable iff g divides y.
  if (y % g != 0) return -1;
  x /= g, y /= g, m /= g;
  // After the division gcd(x, m) == 1, so x must be inverted modulo the
  // reduced modulus m (not modulo g).
  return mod_inv(x, m) * y % m;
}
#line 2 "mod/modint_common.hpp"

// Overload set used by has_mod to detect a callable `get_mod()` member.
struct has_mod_impl {
  // Participates only when `value.get_mod()` is a well-formed expression.
  template <typename U>
  static auto check(U &&value)
      -> decltype(value.get_mod(), std::true_type{});
  // Variadic fallback, chosen for every type the overload above rejects.
  template <typename U>
  static std::false_type check(...);
};

// Type trait: derives from std::true_type when T has a `get_mod()` member
// callable on an rvalue of T, and from std::false_type otherwise.
template <typename T>
class has_mod
    : public decltype(has_mod_impl::check<T>(std::declval<T>())) {};

template <typename mint>
mint inv(int n) {
  static const int mod = mint::get_mod();
  static vector<mint> dat = {0, 1};
  assert(0 <= n);
  if (n >= mod) n %= mod;
  while (len(dat) <= n) {
    int k = len(dat);
    int q = (mod + k - 1) / k;
    dat.eb(dat[k * q - mod] * mint::raw(q));
  }
  return dat[n];
}

template <typename mint>
mint fact(int n) {
  static const int mod = mint::get_mod();
  assert(0 <= n && n < mod);
  static vector<mint> dat = {1, 1};
  while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * mint::raw(len(dat)));
  return dat[n];
}

template <typename mint>
mint fact_inv(int n) {
  static vector<mint> dat = {1, 1};
  if (n < 0) return mint(0);
  while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * inv<mint>(len(dat)));
  return dat[n];
}

template <class mint, class... Ts>
mint fact_invs(Ts... xs) {
  return (mint(1) * ... * fact_inv<mint>(xs));
}

// Multinomial-style quotient head! / (t1! * t2! * ...), computed as fact(head)
// times the inverse factorials of every tail argument.
template <typename mint, class Head, class... Tail>
mint multinomial(Head &&head, Tail &&... tail) {
  const mint numerator = fact<mint>(head);
  const mint inv_denominator = fact_invs<mint>(std::forward<Tail>(tail)...);
  return numerator * inv_denominator;
}

template <typename mint>
mint C_dense(int n, int k) {
  assert(n >= 0);
  if (k < 0 || n < k) return 0;
  static vvc<mint> C;
  static int H = 0, W = 0;
  auto calc = [&](int i, int j) -> mint {
    if (i == 0) return (j == 0 ? mint(1) : mint(0));
    return C[i - 1][j] + (j ? C[i - 1][j - 1] : 0);
  };
  if (W <= k) {
    FOR(i, H) {
      C[i].resize(k + 1);
      FOR(j, W, k + 1) { C[i][j] = calc(i, j); }
    }
    W = k + 1;
  }
  if (H <= n) {
    C.resize(n + 1);
    FOR(i, H, n + 1) {
      C[i].resize(W);
      FOR(j, W) { C[i][j] = calc(i, j); }
    }
    H = n + 1;
  }
  return C[n][k];
}

// Binomial coefficient C(n, k) as a mint. Requires n >= 0; returns 0 when k is
// outside [0, n]. The template flags pick the strategy:
//   dense          -> lookup in the Pascal-triangle cache (C_dense);
//   !dense, !large -> cached factorials: n! / (k! * (n - k)!);
//   !dense, large  -> multiply min(k, n - k) numerator terms n, n - 1, ...
//                     and divide by the factorial of that count.
template <typename mint, bool large = false, bool dense = false>
mint C(ll n, ll k) {
  assert(n >= 0);
  if (k < 0 || n < k) return mint(0);
  if constexpr (dense) {
    return C_dense<mint>(n, k);
  } else if constexpr (!large) {
    return multinomial<mint>(n, k, n - k);
  } else {
    // C(n, k) == C(n, n - k): iterate over the shorter side.
    const ll terms = min(k, n - k);
    mint numerator(1);
    for (ll i = 0; i < terms; ++i) numerator *= mint(n - i);
    return numerator * fact_inv<mint>(terms);
  }
}

// Reciprocal of the binomial coefficient, 1 / C(n, k), as a mint.
// Requires 0 <= k <= n. With `large` set, the value is obtained by dividing 1
// by C<mint, true>(n, k); otherwise it is k! * (n - k)! / n! from the cached
// factorial tables.
template <typename mint, bool large = false>
mint C_inv(ll n, ll k) {
  assert(n >= 0);
  assert(0 <= k && k <= n);
  if (large) return mint(1) / C<mint, true>(n, k);
  const mint n_fact_inv = fact_inv<mint>(n);
  const mint k_fact = fact<mint>(k);
  return n_fact_inv * k_fact * fact<mint>(n - k);
}

// [x^d](1-x)^{-n}: the coefficient of x^d in (1 - x)^(-n), which is
// C(n + d - 1, d) for n > 0. Requires n >= 0; a negative d gives 0.
template <typename mint, bool large = false, bool dense = false>
mint C_negative(ll n, ll d) {
  assert(n >= 0);
  if (d < 0) return mint(0);
  if (n != 0) return C<mint, large, dense>(n + d - 1, d);
  // n == 0: the series is the constant 1, so only d == 0 contributes.
  return mint(d == 0 ? 1 : 0);
}
#line 3 "mod/modint.hpp"

// Integer modulo the compile-time constant `mod` (must be below 2^31).
// The residue is stored in `val`; every constructor reduces its argument
// into [0, mod).
template <int mod>
struct modint {
  static constexpr u32 umod = u32(mod);
  static_assert(umod < u32(1) << 31);
  u32 val;

  // Stores `v` as-is, without any reduction.
  static modint raw(u32 v) {
    modint res;
    res.val = v;
    return res;
  }

  constexpr modint() : val(0) {}
  // Unsigned sources: one remainder is enough.
  constexpr modint(u32 x) : val(x % umod) {}
  constexpr modint(u64 x) : val(x % umod) {}
  constexpr modint(u128 x) : val(x % umod) {}
  // Signed sources: a negative remainder is shifted up by mod.
  constexpr modint(int x) : val((x %= mod) < 0 ? x + mod : x) {}
  constexpr modint(ll x) : val((x %= mod) < 0 ? x + mod : x) {}
  constexpr modint(i128 x) : val((x %= mod) < 0 ? x + mod : x) {}

  // Orders by stored residue.
  bool operator<(const modint &other) const { return val < other.val; }

  modint &operator+=(const modint &p) {
    val += p.val;
    if (val >= umod) val -= umod;
    return *this;
  }
  modint &operator-=(const modint &p) {
    // Add the complement of p so the intermediate value never goes negative.
    val += umod - p.val;
    if (val >= umod) val -= umod;
    return *this;
  }
  modint &operator*=(const modint &p) {
    // Widen to 64 bits before multiplying.
    val = u32(u64(val) * p.val % umod);
    return *this;
  }
  modint &operator/=(const modint &p) { return *this *= p.inverse(); }

  modint operator-() const {
    if (val == 0) return modint::raw(0);
    return modint::raw(umod - val);
  }
  modint operator+(const modint &p) const {
    modint res = *this;
    res += p;
    return res;
  }
  modint operator-(const modint &p) const {
    modint res = *this;
    res -= p;
    return res;
  }
  modint operator*(const modint &p) const {
    modint res = *this;
    res *= p;
    return res;
  }
  modint operator/(const modint &p) const {
    modint res = *this;
    res /= p;
    return res;
  }
  bool operator==(const modint &p) const { return val == p.val; }
  bool operator!=(const modint &p) const { return val != p.val; }

  // Multiplicative inverse via the extended Euclidean algorithm on
  // (val, mod); `u` tracks the coefficient of val.
  modint inverse() const {
    int a = val, b = mod, u = 1, v = 0;
    while (b > 0) {
      const int t = a / b;
      a -= t * b;
      swap(a, b);
      u -= t * v;
      swap(u, v);
    }
    return modint(u);
  }

  // this^n by binary exponentiation; n must be non-negative.
  modint pow(ll n) const {
    assert(n >= 0);
    modint acc(1), base(val);
    for (; n > 0; n >>= 1) {
      if (n & 1) acc *= base;
      base *= base;
    }
    return acc;
  }

  static constexpr int get_mod() { return mod; }

  // Returns (n, r) where r is a 2^n-th root of 1 modulo `mod`, for the
  // moduli listed below; {-1, -1} for any other modulus.
  static constexpr pair<int, int> ntt_info() {
    switch (mod) {
      case 120586241: return {20, 74066978};
      case 167772161: return {25, 17};
      case 469762049: return {26, 30};
      case 754974721: return {24, 362};
      case 880803841: return {23, 211};
      case 943718401: return {22, 663003469};
      case 998244353: return {23, 31};
      case 1004535809: return {21, 582313106};
      case 1012924417: return {21, 368093570};
      default: return {-1, -1};
    }
  }
  // True when ntt_info() has an entry for this modulus.
  static constexpr bool can_ntt() { return ntt_info().fi != -1; }
};

// I/O hooks for modint, compiled only when FASTIO is defined.
#ifdef FASTIO
// Lets fastio::rd fill the stored residue directly, then reduces it modulo
// `mod` so the stored value ends up below `mod`.
// NOTE(review): presumably the input is a non-negative integer that fits in
// the u32 field — confirm against fastio::rd.
template <int mod>
void rd(modint<mod> &x) {
  fastio::rd(x.val);
  x.val %= mod;
  // assert(0 <= x.val && x.val < mod);
}
// Forwards the stored residue to fastio::wt.
template <int mod>
void wt(modint<mod> x) {
  fastio::wt(x.val);
}
#endif

// Shorthand aliases for the moduli 1000000007 and 998244353.
using modint107 = modint<1000000007>;
using modint998 = modint<998244353>;
#line 2 "mod/discrete_log_998244353.hpp"

namespace DISCRETE_LOG_998 {
const int A = 32768;
const int B = 30464;
const int r = 3;
const int mod = 998244353;

u32 rpow_0[A + 1];
u32 rpow_1[A + 1];
u32 AX[4 * B + 1];
u32 AI[4 * B + 1];
u32 BX[4 * B + 1];
u32 BI[4 * B + 1];

u32 get_pow_30(u32 n) { return u64(rpow_1[n / A]) * rpow_0[n % A] % mod; }
u32 get_pow(u64 n) { return get_pow_30(n % (mod - 1)); }
u32 H(u32 x) { return x >> 13; }; // hash func

void __attribute__((constructor)) init_table() {
  rpow_0[0] = rpow_1[0] = 1;
  FOR(i, A) rpow_0[i + 1] = u64(rpow_0[i]) * r % mod;
  FOR(i, A) rpow_1[i + 1] = u64(rpow_1[i]) * rpow_0[A] % mod;
  FOR(i, B) {
    u32 x = get_pow_30(A * i);
    int k = H(x);
    while (AX[k]) ++k;
    AX[k] = x, AI[k] = i;
  }
  FOR(i, A) {
    u32 x = get_pow_30(B * i);
    int k = H(x);
    while (BX[k]) ++k;
    BX[k] = x, BI[k] = i;
  }
}

// 掛け算 17 回 + hashmap 2 回
// 10^7 回 0.6 sec
int discrete_log_mod_998244353_primitive_root(modint998 a) {
  // a^A は 1 の B 乗根なので pow(r, xA) と書ける
  modint998 b = a;
  FOR(15) b *= b;
  int k = H(b.val);
  while (AX[k] != b.val) ++k;
  int x = AI[k];
  // ar^{-x} は 1 の A 乗根なので pow(r, yB) と書ける
  a *= get_pow_30(mod - 1 - x);
  k = H(a.val);
  while (BX[k] != a.val) ++k;
  int y = BI[k];
  return x + y * B;
}
} // namespace DISCRETE_LOG_998
using DISCRETE_LOG_998::discrete_log_mod_998244353_primitive_root;
#line 2 "mod/mod_inv.hpp"

// Modular inverse that also works for 64-bit (ll) arguments.
//
// Returns x in [0, |mod|) such that (val * x - 1) is a multiple of mod;
// this assumes gcd(val, mod) == 1. mod == 0 is special-cased and yields 0.
ll mod_inv(ll val, ll mod) {
  if (mod == 0) return 0;
  mod = abs(mod);
  val %= mod;
  if (val < 0) val += mod;
  // Extended Euclid on (val, mod): `coef` tracks the coefficient of val in
  // the current remainder.
  ll rem = val, next_rem = mod;
  ll coef = 1, next_coef = 0;
  while (next_rem > 0) {
    const ll q = rem / next_rem;
    rem -= q * next_rem;
    swap(rem, next_rem);
    coef -= q * next_coef;
    swap(coef, next_coef);
  }
  return coef < 0 ? coef + mod : coef;
}
#line 3 "mod/mod_log_998244353.hpp"

// Discrete logarithm with an arbitrary base modulo p = 998244353.
//
// Returns n with a^n == b (mod p), or -1 when no such exponent exists.
// When a solution exists, the returned n lies in [0, (p - 1) / g), where
// g = gcd(log_r(a), p - 1) and r is the primitive root used by
// discrete_log_mod_998244353_primitive_root.
//
// NOTE(review): a and b are presumably expected to be nonzero modulo p; the
// underlying table lookup has no entry for 0 — confirm against callers.
int mod_log_998244353(int a, int b) {
  // Write a = r^x and b = r^y. Then a^n == b  <=>  n * x == y (mod p - 1).
  int x = discrete_log_mod_998244353_primitive_root(a);
  int y = discrete_log_mod_998244353_primitive_root(b);
  int m = 998244353 - 1;
  int g = gcd(x, m);
  // The linear congruence n * x == y (mod m) is solvable iff g divides y.
  if (y % g != 0) return -1;
  x /= g, y /= g, m /= g;
  // After the division gcd(x, m) == 1, so x must be inverted modulo the
  // reduced modulus m (not modulo g).
  return mod_inv(x, m) * y % m;
}
Back to top page