library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub maspypy/library

❓ poly/online/online_square.hpp

Depends on

Verified with

Code

#pragma once
#include "poly/ntt.hpp"

/*
query(i):a[i]] を与えて (a^2)[i] を得る。
2^{17}:52ms
2^{18}:107ms
2^{19}:237ms
2^{20}:499ms
*/
template <class mint>
struct Online_Square {
  vc<mint> f, h, b0, b1;
  vvc<mint> fm;
  int p;

  Online_Square() : p(0) { assert(mint::can_ntt()); }

  mint query(int i, mint f_i) {
    assert(i == p);
    f.eb(f_i);
    int z = __builtin_ctz(p + 2), w = 1 << z, s;
    if (p + 2 == w) {
      b0 = f, b0.resize(2 * w);
      ntt(b0, false);
      fm.eb(b0.begin(), b0.begin() + w);
      FOR(i, 2 * w) b0[i] *= b0[i];
      s = w - 2;
      h.resize(2 * s + 2);
    } else {
      b0.assign(f.end() - w, f.end()), b0.resize(2 * w);
      ntt(b0, false);
      FOR(i, 2 * w) b0[i] *= mint(2) * fm[z][i];
      s = w - 1;
    }
    ntt(b0, true);
    FOR(i, s + 1) h[p + i] += b0[s + i];
    return h[p++];
  }
};
#line 2 "poly/ntt.hpp"

// In-place number-theoretic transform (NTT) of `a` over the modular type
// `mint`.
//
//   inverse == false : forward transform.
//   inverse == true  : inverse transform, with the 1/n factor applied inside
//                      (see the scaling loop), so a forward transform followed
//                      by an inverse one is meant to return the original `a`.
//
// Neither direction applies a bit-reversal permutation. The forward output is
// left in the order produced by the butterfly network, and the inverse
// consumes that same order. Pointwise products between two forward outputs
// are therefore fine.
// NOTE(review): the output order is presumably bit-reversed, as in AtCoder
// Library's butterfly; confirm before relying on a specific index layout.
//
// Preconditions (asserted):
//   - mint::can_ntt(), and rank2 = mint::ntt_info().fi != -1;
//   - len(a) <= 2^rank2, and len(a) is a power of two.
//
// Levels are processed two at a time (radix 4). A single radix-2 level is
// used when exactly one level remains.
template <class mint>
void ntt(vector<mint>& a, bool inverse) {
  assert(mint::can_ntt());
  const int rank2 = mint::ntt_info().fi;
  const int mod = mint::get_mod();
  // Twiddle tables (function-local statics, built on the first call):
  //   root[rank2] = mint::ntt_info().se, and root[i] = root[i + 1]^2;
  //   iroot[i]    = 1 / root[i];
  //   rate2/irate2, rate3/irate3: per-block twiddle multipliers for the
  //   radix-2 / radix-4 passes (consumed by the `rot *= ...` updates below).
  // NOTE(review): the arrays have 30 slots, so rank2 must be <= 29. This is
  // not asserted here.
  static array<mint, 30> root, iroot;
  static array<mint, 30> rate2, irate2;
  static array<mint, 30> rate3, irate3;

  assert(rank2 != -1 && len(a) <= (1 << max(0, rank2)));

  // NOTE(review): the tables are computed once, from the modulus seen on the
  // first call. If mint's modulus can change at runtime they would be stale.
  // The check-then-set on `prepared` is also not synchronized.
  static bool prepared = 0;
  if (!prepared) {
    prepared = 1;
    // Presumably ntt_info().se is a primitive 2^rank2-th root of unity, which
    // would make root[i] a primitive 2^i-th root (confirm against ntt_info).
    root[rank2] = mint::ntt_info().se;
    iroot[rank2] = mint(1) / root[rank2];
    FOR_R(i, rank2) {
      root[i] = root[i + 1] * root[i + 1];
      iroot[i] = iroot[i + 1] * iroot[i + 1];
    }
    mint prod = 1, iprod = 1;
    for (int i = 0; i <= rank2 - 2; i++) {
      rate2[i] = root[i + 2] * prod;
      irate2[i] = iroot[i + 2] * iprod;
      prod *= iroot[i + 2];
      iprod *= root[i + 2];
    }
    prod = 1, iprod = 1;
    for (int i = 0; i <= rank2 - 3; i++) {
      rate3[i] = root[i + 3] * prod;
      irate3[i] = iroot[i + 3] * iprod;
      prod *= iroot[i + 3];
      iprod *= root[i + 3];
    }
  }

  int n = int(a.size());
  // h = log2(n) once the assert below passes (assumes topbit returns the index
  // of the highest set bit).
  // NOTE(review): an empty `a` is not handled explicitly; what happens then
  // depends on topbit(0).
  int h = topbit(n);
  assert(n == 1 << h);
  if (!inverse) {
    // `len` = number of butterfly levels already applied (0 .. h).
    int len = 0;
    while (len < h) {
      if (h - len == 1) {
        // Exactly one level left: radix-2 butterflies of half-width p.
        int p = 1 << (h - len - 1);
        mint rot = 1;
        FOR(s, 1 << len) {
          int offset = s << (h - len);
          FOR(i, p) {
            auto l = a[i + offset];
            auto r = a[i + offset + p] * rot;
            a[i + offset] = l + r;
            a[i + offset + p] = l - r;
          }
          // ~s & -~s isolates the lowest zero bit of s, so the index is the
          // number of trailing ones of s. This advances rot to the twiddle
          // for block s + 1.
          rot *= rate2[topbit(~s & -~s)];
        }
        len++;
      } else {
        // At least two levels left: radix-4 butterflies of quarter-width p.
        // imag = root[2] (presumably a primitive 4th root of unity).
        int p = 1 << (h - len - 2);
        mint rot = 1, imag = root[2];
        for (int s = 0; s < (1 << len); s++) {
          mint rot2 = rot * rot;
          mint rot3 = rot2 * rot;
          int offset = s << (h - len);
          for (int i = 0; i < p; i++) {
            // Lazy reduction: the products a1..a3 are kept as raw u64 values
            // below mod^2. Adding multiples of mod2 = mod^2 keeps every
            // difference non-negative. All sums stay below 3 * mod^2 + mod,
            // which fits in 64 bits because mod fits in an int. The final
            // reduction presumably happens in mint's u64 constructor.
            u64 mod2 = u64(mod) * mod;
            u64 a0 = a[i + offset].val;
            u64 a1 = u64(a[i + offset + p].val) * rot.val;
            u64 a2 = u64(a[i + offset + 2 * p].val) * rot2.val;
            u64 a3 = u64(a[i + offset + 3 * p].val) * rot3.val;
            u64 a1na3imag = (a1 + mod2 - a3) % mod * imag.val;
            u64 na2 = mod2 - a2;
            a[i + offset] = a0 + a2 + a1 + a3;
            a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
            a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
            a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
          }
          // Same trailing-ones indexing as in the radix-2 pass.
          rot *= rate3[topbit(~s & -~s)];
        }
        len += 2;
      }
    }
  } else {
    // Apply the 1/n scaling up front. `len(a)` is evaluated here, before the
    // local `int len` below is declared.
    mint coef = mint(1) / mint(len(a));
    FOR(i, len(a)) a[i] *= coef;
    // Undo the levels in reverse order: `len` = number of levels still to undo.
    int len = h;
    while (len) {
      if (len == 1) {
        // One level left: inverse radix-2 butterflies. (mod + l - r) keeps
        // the difference non-negative before it is multiplied by irot.
        int p = 1 << (h - len);
        mint irot = 1;
        FOR(s, 1 << (len - 1)) {
          int offset = s << (h - len + 1);
          FOR(i, p) {
            u64 l = a[i + offset].val;
            u64 r = a[i + offset + p].val;
            a[i + offset] = l + r;
            a[i + offset + p] = (mod + l - r) * irot.val;
          }
          irot *= irate2[topbit(~s & -~s)];
        }
        len--;
      } else {
        // Inverse radix-4 butterflies. iimag = iroot[2] = 1 / root[2]. Each
        // bracketed factor below is non-negative and less than 4 * mod, so
        // every product stays below 4 * mod^2 < 2^64.
        int p = 1 << (h - len);
        mint irot = 1, iimag = iroot[2];
        FOR(s, (1 << (len - 2))) {
          mint irot2 = irot * irot;
          mint irot3 = irot2 * irot;
          int offset = s << (h - len + 2);
          for (int i = 0; i < p; i++) {
            u64 a0 = a[i + offset + 0 * p].val;
            u64 a1 = a[i + offset + 1 * p].val;
            u64 a2 = a[i + offset + 2 * p].val;
            u64 a3 = a[i + offset + 3 * p].val;
            u64 x = (mod + a2 - a3) * iimag.val % mod;
            a[i + offset] = a0 + a1 + a2 + a3;
            a[i + offset + 1 * p] = (a0 + mod - a1 + x) * irot.val;
            a[i + offset + 2 * p] = (a0 + a1 + 2 * mod - a2 - a3) * irot2.val;
            a[i + offset + 3 * p] = (a0 + 2 * mod - a1 - x) * irot3.val;
          }
          irot *= irate3[topbit(~s & -~s)];
        }
        len -= 2;
      }
    }
  }
}
#line 3 "poly/online/online_square.hpp"

/*
query(i):a[i]] を与えて (a^2)[i] を得る。
2^{17}:52ms
2^{18}:107ms
2^{19}:237ms
2^{20}:499ms
*/
template <class mint>
struct Online_Square {
  vc<mint> f, h, b0, b1;
  vvc<mint> fm;
  int p;

  Online_Square() : p(0) { assert(mint::can_ntt()); }

  mint query(int i, mint f_i) {
    assert(i == p);
    f.eb(f_i);
    int z = __builtin_ctz(p + 2), w = 1 << z, s;
    if (p + 2 == w) {
      b0 = f, b0.resize(2 * w);
      ntt(b0, false);
      fm.eb(b0.begin(), b0.begin() + w);
      FOR(i, 2 * w) b0[i] *= b0[i];
      s = w - 2;
      h.resize(2 * s + 2);
    } else {
      b0.assign(f.end() - w, f.end()), b0.resize(2 * w);
      ntt(b0, false);
      FOR(i, 2 * w) b0[i] *= mint(2) * fm[z][i];
      s = w - 1;
    }
    ntt(b0, true);
    FOR(i, s + 1) h[p + i] += b0[s + i];
    return h[p++];
  }
};
Back to top page