This documentation is automatically generated by online-judge-tools/verification-helper
#include "ds/binary_trie.hpp"
#include "ds/node_pool.hpp"
// Binary trie over LOG-bit unsigned keys with compressed edges: every edge
// stores a run of `width` key bits in `val`, so inserting a new distinct key
// creates at most two nodes. Without persistence the pool therefore holds at
// most about 2 * (number of distinct keys) nodes.
// PERSISTENT = true: updates copy the nodes on the modified path (through
// Node_Pool::clone), so roots returned by earlier add() calls stay usable.
template <int LOG, bool PERSISTENT, typename UINT = u64,
          typename SIZE_TYPE = u32>
struct Binary_Trie {
  using T = SIZE_TYPE;
  static_assert(is_same_v<T, u32> || is_same_v<T, u64>);
  static_assert(0 < LOG && LOG <= numeric_limits<UINT>::digits);
  struct Node {
    int width;    // number of key bits on the edge entering this node
    UINT val;     // those key bits (low `width` bits)
    T cnt;        // total multiplicity of the keys in this subtree
    Node *l, *r;  // child whose next key bit is 0 / 1
  };
  Node_Pool<Node> pool;
  using np = Node *;

  // Drops every node at once via Node_Pool::reset; roots obtained before the
  // call must not be used afterwards (their storage gets reused).
  void reset() { pool.reset(); }
  // An empty trie is represented by nullptr.
  np new_root() { return nullptr; }
  // Adds `cnt` to the multiplicity of key `val` (requires val < 2^LOG) and
  // returns the root of the updated trie. Counts are unsigned (T), so the
  // update is modular arithmetic on T.
  np add(np root, UINT val, T cnt = 1) {
    if (!root) root = new_node(0, 0);
    assert(shr(val, LOG) == 0);
    return add_rec(root, LOG, val, cnt);
  }
  // Calls f(val, cnt) for every stored key in increasing key order (the
  // 0-child is visited before the 1-child). Leaves are reported even when
  // their count is 0.
  template <typename F>
  void enumerate(np root, F f) {
    auto dfs = [&](auto &dfs, np node, UINT val, int ht) -> void {
      if (ht == 0) {
        f(val, node->cnt);
        return;
      }
      for (np c : {node->l, node->r}) {
        if (c) dfs(dfs, c, shl(val, c->width) | c->val, ht - c->width);
      }
    };
    if (root) dfs(dfs, root, 0, LOG);
  }
  // k-th smallest (0-indexed, with multiplicity) of {x ^ xor_val}.
  // Requires k < total count.
  UINT kth(np root, T k, UINT xor_val) {
    assert(root && k < root->cnt);
    return kth_rec(root, 0, k, LOG, xor_val) ^ xor_val;
  }
  // Minimum of {x ^ xor_val}; requires a non-empty trie.
  UINT min(np root, UINT xor_val) {
    assert(root && root->cnt);
    return kth(root, 0, xor_val);
  }
  // Maximum of {x ^ xor_val}; requires a non-empty trie.
  UINT max(np root, UINT xor_val) {
    assert(root && root->cnt);
    return kth(root, root->cnt - 1, xor_val);
  }
  // Number of stored keys x (with multiplicity) with (x ^ xor_val) < upper.
  T prefix_count(np root, UINT upper, UINT xor_val) {
    if (!root) return 0;
    return prefix_count_rec(root, LOG, upper, xor_val, 0);
  }
  // Number of stored keys x (with multiplicity) with
  // lo <= (x ^ xor_val) < hi. An empty or reversed range yields 0 instead of
  // an unsigned wrap-around.
  T count(np root, UINT lo, UINT hi, UINT xor_val) {
    if (lo >= hi) return 0;
    return prefix_count(root, hi, xor_val) - prefix_count(root, lo, xor_val);
  }

 private:
  static constexpr int BITS = numeric_limits<UINT>::digits;
  // Shifts that yield 0 for amounts >= BITS instead of undefined behaviour.
  // They are reached when LOG == BITS (e.g. 64-bit keys stored in u64).
  static constexpr UINT shl(UINT x, int k) {
    return k >= BITS ? UINT(0) : UINT(x << k);
  }
  static constexpr UINT shr(UINT x, int k) {
    return k >= BITS ? UINT(0) : UINT(x >> k);
  }
  // Value with the low k bits set.
  static constexpr UINT mask(int k) {
    return k >= BITS ? UINT(~UINT(0)) : UINT((UINT(1) << k) - 1);
  }
  // Allocates a childless node with count 0.
  np new_node(int width, UINT val) {
    np c = pool.create();
    c->l = c->r = nullptr;
    c->width = width, c->val = val, c->cnt = 0;
    return c;
  }
  // Private copy of c when persistent; c itself otherwise.
  np clone(np c) {
    if (!c || !PERSISTENT) return c;
    return pool.clone(c);
  }
  // Inserts the low `ht` bits of `val` below `root` (which has `ht` key bits
  // below it), adding `cnt` along the path. Returns the (possibly copied) root.
  np add_rec(np root, int ht, UINT val, T cnt) {
    root = clone(root);
    root->cnt += cnt;
    if (ht == 0) return root;
    bool go_r = (val >> (ht - 1)) & 1;
    np &slot = (go_r ? root->r : root->l);
    np c = slot;
    if (!c) {
      // No child on this side: one leaf edge carries all remaining bits.
      c = new_node(ht, val);
      c->cnt = cnt;
      slot = c;
      return root;
    }
    int w = c->width;
    if ((val >> (ht - w)) == c->val) {
      // The whole edge matches: descend through it.
      slot = add_rec(c, ht - w, val & mask(ht - w), cnt);
      return root;
    }
    // Split the edge after the common prefix of length `same`. `same` >= 1
    // because c already sits on the side selected by the first bit.
    int same = w - 1 - topbit((val >> (ht - w)) ^ c->val);
    np n = new_node(same, c->val >> (w - same));
    n->cnt = c->cnt + cnt;
    c = clone(c);
    c->width = w - same;
    c->val = c->val & mask(w - same);
    np leaf = new_node(ht - same, val & mask(ht - same));
    leaf->cnt = cnt;
    if ((val >> (ht - same - 1)) & 1) {
      n->l = c, n->r = leaf;
    } else {
      n->l = leaf, n->r = c;
    }
    slot = n;
    return root;
  }
  // Descends toward the k-th smallest key under xor_val. `val` holds the key
  // bits consumed so far; `ht` key bits remain below `root`.
  UINT kth_rec(np root, UINT val, T k, int ht, UINT xor_val) {
    if (ht == 0) return val;
    np left = root->l, right = root->r;
    // A set xor bit at this level swaps which child holds the smaller values.
    if ((xor_val >> (ht - 1)) & 1) swap(left, right);
    T sl = (left ? left->cnt : 0);
    np c = left;
    if (k >= sl) c = right, k -= sl;
    return kth_rec(c, shl(val, c->width) | c->val, k, ht - c->width, xor_val);
  }
  // Counts keys below `root` whose xor with xor_val is < LIM. `val` holds the
  // key bits above this node; `ht` key bits remain below it.
  T prefix_count_rec(np root, int ht, UINT LIM, UINT xor_val, UINT val) {
    // Bits above position ht are shared (after xor) by the whole subtree.
    UINT now = shl(val, ht) ^ xor_val;
    UINT lim_hi = shr(LIM, ht), now_hi = shr(now, ht);
    if (lim_hi > now_hi) return root->cnt;      // whole subtree is < LIM
    if (ht == 0 || lim_hi < now_hi) return 0;   // whole subtree is >= LIM
    T res = 0;
    for (np c : {root->l, root->r}) {
      if (c) {
        res += prefix_count_rec(c, ht - c->width, LIM, xor_val,
                                shl(val, c->width) | c->val);
      }
    }
    return res;
  }
};
#line 1 "ds/node_pool.hpp"
template <class Node>
// Object pool for `Node`. Slots are carved out of chunks of CHUNK_SIZE slots.
// Slots released with destroy() go onto an intrusive free list and are reused
// first. reset() forgets every allocation at once (no destructors are run)
// and rewinds to the first chunk; chunks allocated earlier are reused
// afterwards instead of allocating new ones.
struct Node_Pool {
  // Storage for exactly one Node. While the slot is free, the same bytes hold
  // the free-list link instead. Alignment is max(Node, pointer), so this is
  // well-formed even for Node types with weaker alignment than a pointer.
  struct Slot {
    union {
      Slot* next;
      alignas(Node) unsigned char storage[sizeof(Node)];
    };
  };
  using np = Node*;
  static constexpr int CHUNK_SIZE = 1 << 16;
  std::vector<std::unique_ptr<Slot[]>> chunks;
  Slot* cur = nullptr;        // chunk currently being carved
  int cur_used = 0;           // slots already handed out from `cur`
  Slot* free_head = nullptr;  // head of the list of destroyed slots
  int cur_chunk = 0;          // index of `cur` inside `chunks`

  Node_Pool() { alloc_chunk(); }

  // Constructs a Node in a fresh slot, forwarding `args` to its constructor.
  template <class... Args>
  np create(Args&&... args) {
    Slot* s = new_slot();
    return ::new (s) Node(std::forward<Args>(args)...);
  }
  // Copy-constructs a new Node from *x (x must be non-null).
  np clone(const np x) {
    assert(x);
    Slot* s = new_slot();
    return ::new (s) Node(*x);  // copy constructor
  }
  // Runs ~Node() and pushes the slot onto the free list; no-op for nullptr.
  void destroy(np x) {
    if (!x) return;
    x->~Node();
    auto s = reinterpret_cast<Slot*>(x);
    s->next = free_head;
    free_head = s;
  }
  // Forgets every allocation (destructors are NOT run) and rewinds to the
  // first chunk. Later chunks stay owned and are reused as allocation grows.
  void reset() {
    free_head = nullptr;
    if (!chunks.empty()) {
      cur_chunk = 0;
      cur = chunks[0].get();
      cur_used = 0;
    }
  }

 private:
  // Appends a fresh chunk and makes it current.
  void alloc_chunk() {
    chunks.emplace_back(std::make_unique<Slot[]>(CHUNK_SIZE));
    cur_chunk = static_cast<int>(chunks.size()) - 1;
    cur = chunks.back().get();
    cur_used = 0;
  }
  // Advances to the next chunk: reuses one kept from before reset() when
  // available, otherwise allocates a new one.
  void next_chunk() {
    if (cur_chunk + 1 < static_cast<int>(chunks.size())) {
      cur = chunks[++cur_chunk].get();
      cur_used = 0;
    } else {
      alloc_chunk();
    }
  }
  // Pops the free list if non-empty; otherwise takes the next unused slot.
  Slot* new_slot() {
    if (free_head) {
      Slot* s = free_head;
      free_head = free_head->next;
      return s;
    }
    if (cur_used == CHUNK_SIZE) next_chunk();
    return &cur[cur_used++];
  }
};
#line 2 "ds/binary_trie.hpp"
// Binary trie over LOG-bit unsigned keys with compressed edges: every edge
// stores a run of `width` key bits in `val`, so inserting a new distinct key
// creates at most two nodes. Without persistence the pool therefore holds at
// most about 2 * (number of distinct keys) nodes.
// PERSISTENT = true: updates copy the nodes on the modified path (through
// Node_Pool::clone), so roots returned by earlier add() calls stay usable.
template <int LOG, bool PERSISTENT, typename UINT = u64,
          typename SIZE_TYPE = u32>
struct Binary_Trie {
  using T = SIZE_TYPE;
  static_assert(is_same_v<T, u32> || is_same_v<T, u64>);
  static_assert(0 < LOG && LOG <= numeric_limits<UINT>::digits);
  struct Node {
    int width;    // number of key bits on the edge entering this node
    UINT val;     // those key bits (low `width` bits)
    T cnt;        // total multiplicity of the keys in this subtree
    Node *l, *r;  // child whose next key bit is 0 / 1
  };
  Node_Pool<Node> pool;
  using np = Node *;

  // Drops every node at once via Node_Pool::reset; roots obtained before the
  // call must not be used afterwards (their storage gets reused).
  void reset() { pool.reset(); }
  // An empty trie is represented by nullptr.
  np new_root() { return nullptr; }
  // Adds `cnt` to the multiplicity of key `val` (requires val < 2^LOG) and
  // returns the root of the updated trie. Counts are unsigned (T), so the
  // update is modular arithmetic on T.
  np add(np root, UINT val, T cnt = 1) {
    if (!root) root = new_node(0, 0);
    assert(shr(val, LOG) == 0);
    return add_rec(root, LOG, val, cnt);
  }
  // Calls f(val, cnt) for every stored key in increasing key order (the
  // 0-child is visited before the 1-child). Leaves are reported even when
  // their count is 0.
  template <typename F>
  void enumerate(np root, F f) {
    auto dfs = [&](auto &dfs, np node, UINT val, int ht) -> void {
      if (ht == 0) {
        f(val, node->cnt);
        return;
      }
      for (np c : {node->l, node->r}) {
        if (c) dfs(dfs, c, shl(val, c->width) | c->val, ht - c->width);
      }
    };
    if (root) dfs(dfs, root, 0, LOG);
  }
  // k-th smallest (0-indexed, with multiplicity) of {x ^ xor_val}.
  // Requires k < total count.
  UINT kth(np root, T k, UINT xor_val) {
    assert(root && k < root->cnt);
    return kth_rec(root, 0, k, LOG, xor_val) ^ xor_val;
  }
  // Minimum of {x ^ xor_val}; requires a non-empty trie.
  UINT min(np root, UINT xor_val) {
    assert(root && root->cnt);
    return kth(root, 0, xor_val);
  }
  // Maximum of {x ^ xor_val}; requires a non-empty trie.
  UINT max(np root, UINT xor_val) {
    assert(root && root->cnt);
    return kth(root, root->cnt - 1, xor_val);
  }
  // Number of stored keys x (with multiplicity) with (x ^ xor_val) < upper.
  T prefix_count(np root, UINT upper, UINT xor_val) {
    if (!root) return 0;
    return prefix_count_rec(root, LOG, upper, xor_val, 0);
  }
  // Number of stored keys x (with multiplicity) with
  // lo <= (x ^ xor_val) < hi. An empty or reversed range yields 0 instead of
  // an unsigned wrap-around.
  T count(np root, UINT lo, UINT hi, UINT xor_val) {
    if (lo >= hi) return 0;
    return prefix_count(root, hi, xor_val) - prefix_count(root, lo, xor_val);
  }

 private:
  static constexpr int BITS = numeric_limits<UINT>::digits;
  // Shifts that yield 0 for amounts >= BITS instead of undefined behaviour.
  // They are reached when LOG == BITS (e.g. 64-bit keys stored in u64).
  static constexpr UINT shl(UINT x, int k) {
    return k >= BITS ? UINT(0) : UINT(x << k);
  }
  static constexpr UINT shr(UINT x, int k) {
    return k >= BITS ? UINT(0) : UINT(x >> k);
  }
  // Value with the low k bits set.
  static constexpr UINT mask(int k) {
    return k >= BITS ? UINT(~UINT(0)) : UINT((UINT(1) << k) - 1);
  }
  // Allocates a childless node with count 0.
  np new_node(int width, UINT val) {
    np c = pool.create();
    c->l = c->r = nullptr;
    c->width = width, c->val = val, c->cnt = 0;
    return c;
  }
  // Private copy of c when persistent; c itself otherwise.
  np clone(np c) {
    if (!c || !PERSISTENT) return c;
    return pool.clone(c);
  }
  // Inserts the low `ht` bits of `val` below `root` (which has `ht` key bits
  // below it), adding `cnt` along the path. Returns the (possibly copied) root.
  np add_rec(np root, int ht, UINT val, T cnt) {
    root = clone(root);
    root->cnt += cnt;
    if (ht == 0) return root;
    bool go_r = (val >> (ht - 1)) & 1;
    np &slot = (go_r ? root->r : root->l);
    np c = slot;
    if (!c) {
      // No child on this side: one leaf edge carries all remaining bits.
      c = new_node(ht, val);
      c->cnt = cnt;
      slot = c;
      return root;
    }
    int w = c->width;
    if ((val >> (ht - w)) == c->val) {
      // The whole edge matches: descend through it.
      slot = add_rec(c, ht - w, val & mask(ht - w), cnt);
      return root;
    }
    // Split the edge after the common prefix of length `same`. `same` >= 1
    // because c already sits on the side selected by the first bit.
    int same = w - 1 - topbit((val >> (ht - w)) ^ c->val);
    np n = new_node(same, c->val >> (w - same));
    n->cnt = c->cnt + cnt;
    c = clone(c);
    c->width = w - same;
    c->val = c->val & mask(w - same);
    np leaf = new_node(ht - same, val & mask(ht - same));
    leaf->cnt = cnt;
    if ((val >> (ht - same - 1)) & 1) {
      n->l = c, n->r = leaf;
    } else {
      n->l = leaf, n->r = c;
    }
    slot = n;
    return root;
  }
  // Descends toward the k-th smallest key under xor_val. `val` holds the key
  // bits consumed so far; `ht` key bits remain below `root`.
  UINT kth_rec(np root, UINT val, T k, int ht, UINT xor_val) {
    if (ht == 0) return val;
    np left = root->l, right = root->r;
    // A set xor bit at this level swaps which child holds the smaller values.
    if ((xor_val >> (ht - 1)) & 1) swap(left, right);
    T sl = (left ? left->cnt : 0);
    np c = left;
    if (k >= sl) c = right, k -= sl;
    return kth_rec(c, shl(val, c->width) | c->val, k, ht - c->width, xor_val);
  }
  // Counts keys below `root` whose xor with xor_val is < LIM. `val` holds the
  // key bits above this node; `ht` key bits remain below it.
  T prefix_count_rec(np root, int ht, UINT LIM, UINT xor_val, UINT val) {
    // Bits above position ht are shared (after xor) by the whole subtree.
    UINT now = shl(val, ht) ^ xor_val;
    UINT lim_hi = shr(LIM, ht), now_hi = shr(now, ht);
    if (lim_hi > now_hi) return root->cnt;      // whole subtree is < LIM
    if (ht == 0 || lim_hi < now_hi) return 0;   // whole subtree is >= LIM
    T res = 0;
    for (np c : {root->l, root->r}) {
      if (c) {
        res += prefix_count_rec(c, ht - c->width, LIM, xor_val,
                                shl(val, c->width) | c->val);
      }
    }
    return res;
  }
};