Neterukun's Library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Neterukun1993/Library

✔ 任意 MOD 畳み込み
(NumberTheory/Convolution/arbitrary_mod_convolve.py)

使い方

arbitrary_mod_convolve(a: Sequence[int], b: Sequence[int], p: int) -> List[int]
長さ $N$ の数列 $a$ と長さ $M$ の数列 $b$ について、$c_k = \left(\sum_{i + j = k} a_i b_j\right) \bmod p$ となる長さ $N + M - 1$ の数列 $c$ を返す。計算量 $O((N + M) \log (N + M))$。

Verified with

Code

# Three NTT-friendly primes, each of the form c * 2**k + 1, paired with 3 as
# the generator used to build roots of unity.  The power of two in each
# prime bounds the largest power-of-two transform length it supports.
M1, R1 = 5 * 2 ** 25 + 1, 3    # 167772161
M2, R2 = 7 * 2 ** 26 + 1, 3    # 469762049
M3, R3 = 73 * 2 ** 24 + 1, 3   # 1224736769


def MOD1():
    """Return the first NTT prime (167772161)."""
    return M1


def ROOT1():
    """Return the generator paired with MOD1()."""
    return R1


def MOD2():
    """Return the second NTT prime (469762049)."""
    return M2


def ROOT2():
    """Return the generator paired with MOD2()."""
    return R2


def MOD3():
    """Return the third NTT prime (1224736769)."""
    return M3


def ROOT3():
    """Return the generator paired with MOD3()."""
    return R3


def _ntt(a, h, MOD, ROOT):
    roots = [pow(ROOT(), (MOD() - 1) >> i, MOD()) for i in range(h + 1)]
    for i in range(h):
        m = 1 << (h - i - 1)
        for j in range(1 << i):
            w = 1
            j *= 2 * m
            for k in range(m):
                a[j + k], a[j + k + m] = \
                    (a[j + k] + a[j + k + m]) % MOD(), \
                    (a[j + k] - a[j + k + m]) * w % MOD()
                w *= roots[h - i]
                w %= MOD()


def _intt(a, h, MOD, ROOT):
    roots = [pow(ROOT(), (MOD() - 1) >> i, MOD()) for i in range(h + 1)]
    iroots = [pow(r, MOD() - 2, MOD()) for r in roots]
    for i in range(h):
        m = 1 << i
        for j in range(1 << (h - i - 1)):
            w = 1
            j *= 2 * m
            for k in range(m):
                a[j + k], a[j + k + m] = \
                    (a[j + k] + a[j + k + m] * w) % MOD(), \
                    (a[j + k] - a[j + k + m] * w) % MOD()
                w *= iroots[i + 1]
                w %= MOD()
    inv = pow(1 << h, MOD() - 2, MOD())
    for i in range(1 << h):
        a[i] *= inv
        a[i] %= MOD()


def _ntt_convolve(a, b, MOD, ROOT):
    """Convolve ``a`` and ``b`` modulo the NTT prime ``MOD()``.

    Both inputs are zero-padded to the smallest power-of-two length ``n``
    that holds every output coefficient.  They are transformed with ``_ntt``,
    multiplied pointwise, and transformed back with ``_intt``.  ``n`` must
    divide ``MOD() - 1`` for the roots of unity to exist.

    Args:
        a: sequence of ints.
        b: sequence of ints.
        MOD: zero-argument callable returning the prime modulus.
        ROOT: zero-argument callable returning the generator used to derive
            roots of unity modulo MOD().

    Returns:
        A list of length ``n``.  Its first ``len(a) + len(b) - 1`` entries
        are the convolution coefficients reduced into ``[0, MOD())``, and
        the remaining entries are zero.
    """
    mod = MOD()
    # The transform must hold every output coefficient.  The len(a)/len(b)
    # terms only matter when the other operand is empty; they keep each
    # padded list from being longer than n.
    size = max(len(a) + len(b) - 1, len(a), len(b), 1)
    # Smallest h with 2**h >= size.  Using size.bit_length() here would double
    # the transform whenever size is already a power of two.  That halves the
    # speed and can exceed the 2-adic order of MOD() - 1.
    h = (size - 1).bit_length()
    n = 1 << h
    fa = list(a) + [0] * (n - len(a))
    fb = list(b) + [0] * (n - len(b))
    _ntt(fa, h, MOD, ROOT)
    _ntt(fb, h, MOD, ROOT)
    prod = [u * v % mod for u, v in zip(fa, fb)]
    _intt(prod, h, MOD, ROOT)
    return prod


def arbitrary_mod_convolve(a, b, p):
    """Return the linear convolution of ``a`` and ``b`` modulo ``p``.

    ``c[k] = (sum of a[i] * b[j] over i + j == k) % p`` for
    ``0 <= k < N + M - 1``, where ``N = len(a)`` and ``M = len(b)``.
    ``p`` does not need to be prime or NTT-friendly.

    The exact integer convolution is computed modulo MOD1(), MOD2() and
    MOD3() with three NTTs and rebuilt with Garner's algorithm.  That is
    exact only while every true coefficient is below
    MOD1() * MOD2() * MOD3() (roughly 9.6e25), i.e. while
    ``min(N, M) * (p - 1) ** 2`` stays below that product.  The transform
    length is also limited by MOD3(), whose order of two is 2**24, so
    ``N + M - 1`` must stay around 2**24 or below.

    Args:
        a: sequence of ints of any sign or size (reduced modulo ``p`` first).
        b: sequence of ints of any sign or size (reduced modulo ``p`` first).
        p: positive modulus for the result.

    Returns:
        List of ``max(N + M - 1, 0)`` ints in ``[0, p)``.

    Raises:
        ZeroDivisionError: if ``p == 0``.
    """
    # Reduce first.  Garner's reconstruction below is only valid when every
    # true coefficient lies in [0, M1*M2*M3).  Negative or >= p inputs would
    # silently yield wrong residues.
    a = [v % p for v in a]
    b = [v % p for v in b]
    m1, m2, m3 = MOD1(), MOD2(), MOD3()

    x = _ntt_convolve(a, b, MOD1, ROOT1)
    y = _ntt_convolve(a, b, MOD2, ROOT2)
    z = _ntt_convolve(a, b, MOD3, ROOT3)

    inv1_2 = pow(m1, m2 - 2, m2)        # m1^-1 mod m2 (Fermat; m2 is prime)
    inv12_3 = pow(m1 * m2, m3 - 2, m3)  # (m1*m2)^-1 mod m3 (m3 is prime)
    m12 = m1 * m2 % p

    size = len(a) + len(b) - 1
    res = [0] * size
    for i in range(size):
        # Write c = r1 + m1*v1 + m1*m2*v2 with 0 <= v1 < m2 and 0 <= v2 < m3.
        r1 = x[i]
        v1 = (y[i] - r1) * inv1_2 % m2
        r12 = r1 + m1 * v1              # c mod (m1 * m2), exactly
        v2 = (z[i] - r12) * inv12_3 % m3
        res[i] = (r12 + m12 * v2) % p
    return res
Traceback (most recent call last):
  File "/opt/hostedtoolcache/Python/3.12.4/x64/lib/python3.12/site-packages/onlinejudge_verify/documentation/build.py", line 71, in _render_source_code_stat
    bundled_code = language.bundle(stat.path, basedir=basedir, options={'include_paths': [basedir]}).decode()
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/hostedtoolcache/Python/3.12.4/x64/lib/python3.12/site-packages/onlinejudge_verify/languages/python.py", line 96, in bundle
    raise NotImplementedError
NotImplementedError
Back to top page