hly1204's library


Radix-3 Schönhage's Algorithm (a.k.a. Radix-3 Schönhage–Strassen Algorithm) (standalone_test/convolution/convolution_mod_2_64.radix_3_schoenhage.test.cpp)

This algorithm was published by Schönhage alone, yet it is often named after both Schönhage and Strassen. The Schönhage–Strassen algorithm proper is for integers, whereas the algorithm described here multiplies polynomials.

The idea comes from Bernstein’s paper; the pseudocode is adapted from AECF, which gives the Radix-2 version.

I wrote this post because there are almost no resources or code examples that describe the steps of this algorithm in detail. It may be easy for mathematicians, but it is not so easy for programmers who actually want to implement it.

If you have read the blog Fast convolution for 64-bit integers on CodeForces, the idea is not new: it takes the radix-3 FFT from Eric Dubois, Anastasios N. Venetsanopoulos, A new algorithm for the radix-3 FFT, IEEE Transactions on Acoustics, Speech, and Signal Processing 26 (1978), 222–225, and combines it with Schönhage’s trick.

Radix-3 Schönhage’s trick

Given polynomials $a, b \in R\lbrack x\rbrack _ {\lt 2n}$, where $1/3 \in R$ and $n$ is a power of $3$, Schönhage’s trick gives an algorithm that computes $ab \bmod{(x^{2n} + x^n + 1)}$ in time $O(n \log(n) \log(\log(n)))$ (in the arithmetic model). I omit the proof of the time complexity here; it can be found in Mateer’s paper.
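For the ring used by the code further down, $R = \mathbb{Z} / 2^{64}\mathbb{Z}$, the requirement $1/3 \in R$ holds because $3$ is odd. A minimal sketch, assuming that ring; the name `InverseMod2Pow64` is only illustrative, and the inverse is found by Newton iteration, which doubles the number of correct low bits per step:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative helper (hypothetical name): inverse of an odd a modulo 2^64.
// Newton's iteration x <- x * (2 - a * x) doubles the number of correct low bits.
constexpr std::uint64_t InverseMod2Pow64(std::uint64_t a) {
    std::uint64_t x = 1;  // correct modulo 2 because a is odd
    for (int i = 0; i < 6; ++i) x *= 2 - a * x;  // 1 -> 2 -> 4 -> ... -> 64 bits
    return x;
}

// Unsigned arithmetic wraps modulo 2^64, so this checks 3 * (1/3) = 1 in R.
static_assert(InverseMod2Pow64(3) * 3 == 1, "3 is a unit modulo 2^64");
```

The same idea works for any odd number, so every power of $3$ is invertible as well.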

\[\begin{array}{ll} & \textbf{Algorithm}\operatorname{\mathsf{Schoenhage}}(a, b, n) \\ & \textbf{Input}\text{: } a, b \in R\lbrack x\rbrack _ {\lt 2n}, n = 3^k, k \in \mathbb{N} \\ & \textbf{Output}\text{: } ab \bmod{(x^{2n} + x^n + 1)} \\ 1 & \textbf{if } n \leq 3 \textbf{ then} \\ 2 & \qquad \text{Compute } ab \text{ using Karatsuba's algorithm.} \\ 3 & \qquad \textbf{return } ab \bmod{(x^{2n} + x^n + 1)} \\ 4 & \textbf{endif} \\ 5 & d \gets 3^{\lceil k / 2 \rceil}, \delta \gets n / d \\ 6 & \hat{a}(x, y) \gets a \bmod{(x^{d} - y)}, \hat{b}(x, y) \gets b \bmod{(x^{d} - y)} \\ 7 & \text{Let } \hat{a}, \hat{b} \in (R\lbrack x\rbrack / (x^{2d} + x^d + 1)) \lbrack y\rbrack / (y^{2\delta} + y^{\delta} + 1) \text{ by doing some zero paddings.} \\ 8 & (\hat{a_0}, \hat{a_1}) \gets \left(\hat{a} \bmod{(y^{\delta} - x^d)}, \hat{a} \bmod{(y^{\delta} - x^{2d})}\right), (\hat{b_0}, \hat{b_1}) \gets \left(\hat{b} \bmod{(y^{\delta} - x^d)}, \hat{b} \bmod{(y^{\delta} - x^{2d})}\right) \\ 9 & \text{Apply Radix-3 FFT to } \hat{a_0}, \hat{a_1}, \hat{b_0}, \hat{b_1} \\ 10 & \textbf{for } i \gets 0 \textbf{ to } \delta - 1 \textbf{ do} \\ 11 & \qquad \hat{c_0} \bmod{(y - x^{e_d(i)})} \gets \operatorname{\mathsf{Schoenhage}}\left(\hat{a_0} \bmod{(y - x^{e_d(i)})}, \hat{b_0} \bmod{(y - x^{e_d(i)})}, d\right) \\ 12 & \qquad \hat{c_1} \bmod{(y - x^{e_{2d}(i)})} \gets \operatorname{\mathsf{Schoenhage}}\left(\hat{a_1} \bmod{(y - x^{e_{2d}(i)})}, \hat{b_1} \bmod{(y - x^{e_{2d}(i)})}, d\right) \\ 13 & \textbf{endfor} \\ 14 & \text{Apply Radix-3 IFFT to restore } \hat{c_0}, \hat{c_1} \\ 15 & \text{Use FFT trick to restore } \hat{a}\hat{b} \text{ from } \hat{c_0}, \hat{c_1} \\ 16 & \textbf{return } \left((\hat{a}\hat{b})\left(x, x^d\right)\right) \bmod{(x^{2n} + x^n + 1)} \end{array}\]
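Ln 5 of the pseudocode chooses the split $n = d\delta$. The following sketch (the name `SplitPowerOf3` is made up for illustration) computes that split for $n = 3^k$ and makes it easy to see that $\delta \leq d \leq 3\delta$ always holds, which is the precondition $\delta \leq d$ that the radix-3 FFT needs later:

```cpp
#include <cassert>
#include <utility>

// Illustrative helper (hypothetical name): for n = 3^k return (d, delta)
// with d = 3^ceil(k / 2) and delta = n / d, as in Ln 5 of the pseudocode.
std::pair<int, int> SplitPowerOf3(int n) {
    int k = 0;
    for (int c = 1; c < n; c *= 3) ++k;  // k = log_3(n) when n is a power of 3
    int d = 1;
    for (int i = 0; i < (k + 1) / 2; ++i) d *= 3;  // (k + 1) / 2 == ceil(k / 2)
    return {d, n / d};
}
```

For an even exponent the two factors are equal, e.g. $81 = 9 \cdot 9$; for an odd exponent $d = 3\delta$, e.g. $243 = 27 \cdot 9$.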

We first give the algorithm for Ln 8 of the pseudocode; Ln 15 is its inverse, so it is solved at the same time. Note that for $A \in R\lbrack x\rbrack / (x^{2d} + x^d + 1)$ we can compute $A(x) x^N$ for any integer $N$ in place in $O(\min{(\lvert N \rvert, d)})$ operations, because memory movement is free in the arithmetic model (e.g. transposing a matrix is free). This multiplication is also invertible, since $x^{3d} \equiv 1 \pmod{(x^{2d} + x^d + 1)}$.

For $D(x) = L + Hx^d \in R\lbrack x\rbrack / (x^{2d} + x^d + 1)$ and $\deg L \lt d, \deg H \lt d$, we have

\[\begin{aligned} (L + Hx^d)x^d &= Lx^d + Hx^{2d} &= -H + (L - H)x^d \\ (L + Hx^d)x^{2d} &= Lx^{2d} + H &= -L + H - Lx^d \end{aligned}\]
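The two identities above can be turned into an in-place routine. The following is a minimal sketch, assuming coefficients in $\mathbb{Z} / 2^{64}\mathbb{Z}$ stored from low to high degree; the function names are illustrative and the routine follows the same rotate-then-fix-up idea as the full listing in the Code section:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// Multiply a (2 * d coefficients) by x^n in R[x] / (x^(2d) + x^d + 1), 0 <= n <= d.
void MulByXPowAtMostD(std::vector<u64>& a, int d, int n) {
    assert(0 <= n && n <= d && static_cast<int>(a.size()) == 2 * d);
    // Shift every coefficient up by n; the top n coefficients wrap to the front.
    std::rotate(a.begin(), a.end() - n, a.end());
    // A wrapped coefficient c at index i stands for c * x^(2d + i),
    // and x^(2d + i) = -x^(d + i) - x^i.
    for (int i = 0; i < n; ++i) {
        a[i] = -a[i];
        a[i + d] += a[i];
    }
}

// Multiply by x^n for any integer n, using x^(3d) = 1.
void MulByXPow(std::vector<u64>& a, int d, long long n) {
    n %= 3LL * d;
    if (n < 0) n += 3LL * d;
    for (; n >= d; n -= d) MulByXPowAtMostD(a, d, d);
    if (n > 0) MulByXPowAtMostD(a, d, static_cast<int>(n));
}
```

With $d = 2$, $L = 1 + 2x$ and $H = 3 + 4x$, multiplying by $x^d$ gives $-H + (L - H)x^d$, i.e. the coefficients $(-3, -4, -2, -2)$ modulo $2^{64}$, and multiplying by $x^{N}$ followed by $x^{-N}$ restores the input.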

So we could give the pseudocode for Ln 8.

\[\begin{array}{ll} & \textbf{Algorithm}\operatorname{\mathsf{Ln8}}(a, d, \delta) \\ & \textbf{Input}\text{: } a \in (R\lbrack x\rbrack / (x^{2d} + x^d + 1))\lbrack y\rbrack / (y^{2\delta} + y^{\delta} + 1) \\ & \textbf{Output}\text{: } (a \bmod{(y^{\delta} - x^d)}, a \bmod{(y^{\delta} - x^{2d})}) \\ 1 & (a_0, a_1) \gets (a \bmod{y^{\delta}}, (a - a_0) / y^{\delta}) \\ 2 & (b_0, b_1) \gets (0, 0) \\ 3 & \textbf{for } i \gets 0 \textbf{ to } \delta - 1 \textbf{ do} \\ 4 & \qquad (L_{a_0}, H_{a_0}) \gets ((\lbrack y^i \rbrack a_0) \bmod{x^d}, ((\lbrack y^i \rbrack a_0) - (\lbrack y^i \rbrack a_0) \bmod{x^d}) / x^d) \\ 5 & \qquad (L_{a_1}, H_{a_1}) \gets ((\lbrack y^i \rbrack a_1) \bmod{x^d}, ((\lbrack y^i \rbrack a_1) - (\lbrack y^i \rbrack a_1) \bmod{x^d}) / x^d) \\ 6 & \qquad \begin{bmatrix}L_0 \\ H_0 \\ L_1 \\ H_1\end{bmatrix} \gets \begin{bmatrix} 1 & 0 & 0 & -1 \\ 0 & 1 & 1 & -1 \\ 1 & 0 & -1 & 1 \\ 0 & 1 & -1 & 0 \end{bmatrix} \begin{bmatrix}L_{a_0} \\ H_{a_0} \\ L_{a_1} \\ H_{a_1}\end{bmatrix} \\ 7 & \qquad (b_0, b_1) \gets (b_0 + (L_0 + H_0x^d)y^i, b_1 + (L_1 + H_1x^d)y^i) \\ 8 & \textbf{endfor} \\ 9 & \textbf{return } (b_0, b_1) \end{array}\]

Note that

\[\begin{bmatrix} 1 & 0 & 0 & -1 \\ 0 & 1 & 1 & -1 \\ 1 & 0 & -1 & 1 \\ 0 & 1 & -1 & 0 \end{bmatrix}^{-1} = \frac{1}{3} \begin{bmatrix} 1 & 1 & 2 & -1 \\ -1 & 2 & 1 & 1 \\ -1 & 2 & 1 & -2 \\ -2 & 1 & 2 & -1 \end{bmatrix}\]

so Ln 15 is solved automatically. Now the only remaining problem is the Radix-3 FFT trick.
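To make the pair Ln 8 / Ln 15 concrete, here is a small round-trip sketch on a single coefficient position. It assumes $R = \mathbb{Z} / 2^{64}\mathbb{Z}$, where $1/3$ is the constant `kInv3` below; the names `SplitForward` and `SplitInverse` are illustrative only. The rows are copied from the two matrices above:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using u64  = std::uint64_t;
using Vec4 = std::array<u64, 4>;  // (L_{a_0}, H_{a_0}, L_{a_1}, H_{a_1})

// 1 / 3 modulo 2^64: 3 * 0xAAAAAAAAAAAAAAAB = 2^65 + 1.
constexpr u64 kInv3 = 0xAAAAAAAAAAAAAAABULL;

// Forward map of Ln 8 for one coefficient position: returns (L_0, H_0, L_1, H_1).
Vec4 SplitForward(const Vec4& v) {
    return {v[0] - v[3], v[1] + v[2] - v[3], v[0] - v[2] + v[3], v[1] - v[2]};
}

// Inverse map used by Ln 15: the inverse matrix, including the factor 1 / 3.
Vec4 SplitInverse(const Vec4& w) {
    return {(w[0] + w[1] + 2 * w[2] - w[3]) * kInv3,
            (-w[0] + 2 * w[1] + w[2] + w[3]) * kInv3,
            (-w[0] + 2 * w[1] + w[2] - 2 * w[3]) * kInv3,
            (-2 * w[0] + w[1] + 2 * w[2] - w[3]) * kInv3};
}
```

Applying `SplitInverse` to the result of `SplitForward` returns the original four values, since all arithmetic wraps modulo $2^{64}$.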

Radix-3 FFT trick

Bernstein showed that the Radix-3 FFT trick is

\[\begin{aligned} S\lbrack y \rbrack / (y^{3N} - t^3) &\to S\lbrack y \rbrack / (y^N - t) \\ &\times S\lbrack y \rbrack / (y^N - \omega t) \\ &\times S\lbrack y \rbrack / (y^N - \omega^2 t) \end{aligned}\]

where $\omega^2 + \omega + 1 = 0$, and the map is invertible if $3t^2$ is invertible (that is, if $1 / (3t^2)$ exists in $S$). In our case $S = R\lbrack x\rbrack / (x^{2d} + x^d + 1)$, so we can take $\omega = x^{d}$. Since $t$ is some power of $x$ and $1 / 3 \in R$, $3t^2$ is indeed invertible.
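As a quick check of why the three factors fit together, here is the factorization written out with $Y = y^N$. It only uses $\omega^2 + \omega + 1 = 0$, which also gives $\omega^3 = 1$:

\[\begin{aligned} (Y - t)(Y - \omega t)(Y - \omega^2 t) &= Y^3 - (1 + \omega + \omega^2)\,t\,Y^2 + (\omega + \omega^2 + \omega^3)\,t^2\,Y - \omega^3 t^3 \\ &= Y^3 - t^3 \end{aligned}\]

With $\omega = x^d$ in $S$, the relation $\omega^3 - 1 = (\omega - 1)(\omega^2 + \omega + 1) = 0$ is exactly $x^{3d} = 1$.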

\[\begin{array}{ll} & \textbf{Algorithm}\operatorname{\mathsf{FFT}}(a, d, \delta, E) \\ & \textbf{Input}\text{: } a \in (R\lbrack x\rbrack / (x^{2d} + x^d + 1))\lbrack y\rbrack / (y^{\delta} - x^E), \delta \leq d, \delta \mid E \\ & \textbf{Output}\text{: } \begin{bmatrix}a \bmod{(y - x^{e_E(0)})} & \cdots & a \bmod{(y - x^{e_E(\delta - 1)})}\end{bmatrix} \\ 1 & \textbf{if } \delta = 1 \textbf{ then} \\ 2 & \qquad \textbf{return } \begin{bmatrix}a\end{bmatrix} \\ 3 & \textbf{endif} \\ 4 & a_0 \gets a \bmod{(y^{\delta / 3} - x^{E / 3})} \\ 5 & a_1 \gets a \bmod{(y^{\delta / 3} - x^{E / 3 + d})} \\ 6 & a_2 \gets a \bmod{(y^{\delta / 3} - x^{E / 3 + 2d})} \\ 7 & \textbf{return } \begin{bmatrix}\operatorname{\mathsf{FFT}}(a_0, d, \delta / 3, E / 3) & \operatorname{\mathsf{FFT}}(a_1, d, \delta / 3, E / 3 + d) & \operatorname{\mathsf{FFT}}(a_2, d, \delta / 3, E / 3 + 2d)\end{bmatrix} \end{array}\]

If we set $a := (A_0(x) + A_1(x)x^d) + (B_0(x) + B_1(x)x^d)y^{\delta / 3} + (C_0(x) + C_1(x)x^d)y^{2\delta / 3}$, then

\[\begin{aligned} a \bmod{(y^{\delta / 3} - x^{E / 3})} &= (A_0(x) + A_1(x)x^d) &+ (B_0(x) + B_1(x)x^d)x^{E / 3} &+ (C_0(x) + C_1(x)x^d)x^{2E / 3} \\ a \bmod{(y^{\delta / 3} - x^{E / 3 + d})} &= (A_0(x) + A_1(x)x^d) &+ (B_0(x) + B_1(x)x^d)x^{E / 3 + d} &+ (C_0(x) + C_1(x)x^d)x^{2E / 3 + 2d} \\ a \bmod{(y^{\delta / 3} - x^{E / 3 + 2d})} &= (A_0(x) + A_1(x)x^d) &+ (B_0(x) + B_1(x)x^d)x^{E / 3 + 2d} &+ (C_0(x) + C_1(x)x^d)x^{2E / 3 + 4d} \end{aligned}\]

so we can compute

\[\begin{aligned} (A_0(x) + A_1(x)x^d) &\gets (A_0(x) + A_1(x)x^d) \\ (B_0(x) + B_1(x)x^d) &\gets (B_0(x) + B_1(x)x^d)x^{E / 3} \\ (C_0(x) + C_1(x)x^d) &\gets (C_0(x) + C_1(x)x^d)x^{2E / 3} \end{aligned}\]

first. This can be done in place. Then we compute

\[\begin{bmatrix} A_0(x) \\ A_1(x) \\ B_0(x) \\ B_1(x) \\ C_0(x) \\ C_1(x) \end{bmatrix} \gets \begin{bmatrix} 1 & 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 & 1 \\ 1 & 0 & 0 & -1 & -1 & 1 \\ 0 & 1 & 1 & -1 & -1 & 0 \\ 1 & 0 & -1 & 1 & 0 & -1 \\ 0 & 1 & -1 & 0 & 1 & -1 \end{bmatrix} \begin{bmatrix} A_0(x) \\ A_1(x) \\ B_0(x) \\ B_1(x) \\ C_0(x) \\ C_1(x) \end{bmatrix}\]

and note that

\[\begin{bmatrix} 1 & 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 & 1 \\ 1 & 0 & 0 & -1 & -1 & 1 \\ 0 & 1 & 1 & -1 & -1 & 0 \\ 1 & 0 & -1 & 1 & 0 & -1 \\ 0 & 1 & -1 & 0 & 1 & -1 \end{bmatrix}^{-1} = \frac{1}{3} \begin{bmatrix} 1 & 0 & 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 & 0 & 1 \\ 1 & 0 & -1 & 1 & 0 & -1 \\ 0 & 1 & -1 & 0 & 1 & -1 \\ 1 & 0 & 0 & -1 & -1 & 1 \\ 0 & 1 & 1 & -1 & -1 & 0 \end{bmatrix}\]

which gives us the implementation of IFFT.
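The inverse relation can be checked mechanically with plain integers. This sketch (names such as `kForward` and `Multiply` are illustrative) multiplies the butterfly matrix by the integer matrix on the right-hand side and confirms that the product is $3I$, which is why the factor $1/3$ is all the IFFT needs in addition:

```cpp
#include <array>
#include <cassert>

using Mat6 = std::array<std::array<int, 6>, 6>;

// Forward butterfly matrix, copied from the text.
constexpr Mat6 kForward = {{{1, 0, 1, 0, 1, 0},
                            {0, 1, 0, 1, 0, 1},
                            {1, 0, 0, -1, -1, 1},
                            {0, 1, 1, -1, -1, 0},
                            {1, 0, -1, 1, 0, -1},
                            {0, 1, -1, 0, 1, -1}}};

// Three times the inverse: the integer matrix next to the factor 1 / 3.
constexpr Mat6 kInverseTimes3 = {{{1, 0, 1, 0, 1, 0},
                                  {0, 1, 0, 1, 0, 1},
                                  {1, 0, -1, 1, 0, -1},
                                  {0, 1, -1, 0, 1, -1},
                                  {1, 0, 0, -1, -1, 1},
                                  {0, 1, 1, -1, -1, 0}}};

// Plain 6 x 6 integer matrix product.
Mat6 Multiply(const Mat6& p, const Mat6& q) {
    Mat6 r{};  // zero-initialized
    for (int i = 0; i < 6; ++i)
        for (int j = 0; j < 6; ++j)
            for (int k = 0; k < 6; ++k) r[i][j] += p[i][k] * q[k][j];
    return r;
}

// True if m equals s times the identity matrix.
bool IsScaledIdentity(const Mat6& m, int s) {
    for (int i = 0; i < 6; ++i)
        for (int j = 0; j < 6; ++j)
            if (m[i][j] != (i == j ? s : 0)) return false;
    return true;
}
```

Note that the inverse butterfly is the forward one with the last two row pairs swapped, so an implementation can reuse the same additions and postpone all divisions by $3$ to a single multiplication at the end.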

To reduce the number of multiplications ($\sharp \times$), we should switch to Karatsuba’s algorithm once $n$ is too small to be handled by the recursion, or sometimes switch to the Nussbaumer trick as suggested by Bernstein. I learned this from Nachia’s code https://judge.yosupo.jp/submission/351055.
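For small $n$ it helps to have a plain reference to compare against. The following is a schoolbook sketch, not the Karatsuba variant used in the listing below: it assumes coefficients in $\mathbb{Z} / 2^{64}\mathbb{Z}$, and the name `NaiveMulMod` is made up for illustration. It multiplies two inputs of $2n$ coefficients and reduces with $x^{2n} = -x^n - 1$:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using u64 = std::uint64_t;

// Schoolbook product of a and b (2 * n coefficients each, low to high),
// reduced modulo x^(2n) + x^n + 1. Quadratic time; meant as a reference only.
std::vector<u64> NaiveMulMod(const std::vector<u64>& a, const std::vector<u64>& b, int n) {
    std::vector<u64> full(4 * n, 0);  // the unreduced product has degree < 4n - 1
    for (int i = 0; i < 2 * n; ++i)
        for (int j = 0; j < 2 * n; ++j) full[i + j] += a[i] * b[j];
    // Eliminate high coefficients from the top down:
    // x^k = x^(k - 2n) * x^(2n) = -x^(k - n) - x^(k - 2n).
    for (int k = 4 * n - 1; k >= 2 * n; --k) {
        full[k - n] -= full[k];
        full[k - 2 * n] -= full[k];
    }
    full.resize(2 * n);
    return full;
}
```

For example, with $n = 3$ the product $x^5 \cdot x^4 = x^9$ reduces to $1$, because $x^{3n} = 1$ modulo $x^{2n} + x^n + 1$.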

References

  1. Daniel J. Bernstein. “Multidigit multiplication for mathematicians.” Accepted to Advances in Applied Mathematics, but withdrawn by author to prevent irreparable mangling by Academic Press. url: https://cr.yp.to/papers.html#m3
  2. Alin Bostan, Frédéric Chyzak, Marc Giusti, Romain Lebreton, Grégoire Lecerf, Bruno Salvy et Éric Schost. Algorithmes Efficaces en Calcul Formel. 686 pages. Imprimé par CreateSpace. Aussi disponible en version électronique. Palaiseau : Frédéric Chyzak (auto-édit.), sept. 2017. isbn : 979-10-699-0947-2. url: https://hal.science/AECF/
  3. Todd Mateer. “Fast Fourier Transform Algorithms with Applications.” (2008). All Dissertations. 231. url: https://open.clemson.edu/all_dissertations/231

Code

// competitive-verifier: PROBLEM https://judge.yosupo.jp/problem/convolution_mod_2_64

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <utility>
#include <vector>

using ull = unsigned long long;

// Inverse of an odd a modulo 2^64 by Newton iteration (correct bits double each step)
constexpr ull InvMod(int a) {
    ull res = 1;
    for (int i = 0; i < 6; ++i) res *= 2ULL - a * res;
    return res;
}

bool IsPowOf3(int a) {
    int b = 1;
    while (b < a) b *= 3;
    return a == b;
}

int PowOf3(int e) {
    for (int x = 3, res = 1;; x *= x) {
        if (e & 1) res *= x;
        if ((e /= 2) == 0) return res;
    }
}

int Log3Ceil(int a) {
    int e = 0;
    for (int c = 1; c < a; c *= 3) ++e;
    return e;
}

int Log3Floor(int a) {
    const int e = Log3Ceil(a);
    return a == PowOf3(e) ? e : e - 1;
}

// Compute a * x^n mod (x^(2*d) + x^d + 1)
// Note that x^(3*d) = 1 mod (x^(2*d) + x^d + 1)
void MultipliedByXToTheN(ull a[], int d, int n) {
    if ((n %= d * 3) < 0) n += d * 3;
    const auto n_leq_d = [](ull a[], int d, int n) {
        assert(n <= d);
        std::rotate(a, a + d * 2 - n, a + d * 2);
        for (int i = 0; i < n; ++i) a[i + d] += (a[i] = -a[i]);
    };
    for (; n >= d; n -= d) n_leq_d(a, d, d);
    if (n) n_leq_d(a, d, n);
}

void FFT3(ull a[], int d, int delta, int E) {
    assert(delta <= d);
    assert(E % delta == 0);
    if (delta == 1) return;
    const int n = d * 2 * (delta / 3);
    for (int i = 0; i < delta / 3; ++i) {
        ull *const b[] = {a + i * d * 2, a + i * d * 2 + n, a + i * d * 2 + n * 2};
        for (int j = 1; j <= 2; ++j) MultipliedByXToTheN(b[j], d, E / 3 * j);
        for (int j = 0; j < d; ++j) {
            enum { L = 0, H = 1 };
            const ull A[] = {b[0][j], b[0][j + d]};
            const ull B[] = {b[1][j], b[1][j + d]};
            const ull C[] = {b[2][j], b[2][j + d]};
            b[0][j]       = A[L] + B[L] + C[L];
            b[0][j + d]   = A[H] + B[H] + C[H];
            b[1][j]       = A[L] - B[H] - C[L] + C[H];
            b[1][j + d]   = A[H] + B[L] - B[H] - C[L];
            b[2][j]       = A[L] - B[L] + B[H] - C[H];
            b[2][j + d]   = A[H] - B[L] + C[L] - C[H];
        }
    }
    for (int i = 0; i < 3; ++i) FFT3(a + n * i, d, delta / 3, E / 3 + d * i);
}

// (R[x] / (x^(2*d) + x^(d) + 1))[y] / (y^(2*delta) + y^(delta) + 1)
//   -> (R[x] / (x^(2*d) + x^d + 1))[y] / (y^delta - x^d)
//   ×  (R[x] / (x^(2*d) + x^d + 1))[y] / (y^delta - x^(2*d))
// then apply radix-3 FFT over (R[x] / (x^(2*d) + x^d + 1))[y] / (y^delta - x^d)
//                    and over (R[x] / (x^(2*d) + x^d + 1))[y] / (y^delta - x^(2*d))
void FFT(ull a[], int d, int delta) {
    ull *const b = a + delta * 2 * d;
    for (int i = 0; i < delta; ++i) {
        ull *const c[] = {a + i * d * 2, b + i * d * 2};
        for (int j = 0; j < d; ++j) {
            enum { L = 0, H = 1 };
            const ull A[] = {c[0][j], c[0][j + d]};
            const ull B[] = {c[1][j], c[1][j + d]};
            c[0][j]       = A[L] - B[H];
            c[0][j + d]   = A[H] + B[L] - B[H];
            c[1][j]       = A[L] - B[L] + B[H];
            c[1][j + d]   = A[H] - B[L];
        }
    }
    // a[] stores S[y] / (y^delta - x^d), b[] stores S[y] / (y^delta - x^(2*d))
    FFT3(a, d, delta, d), FFT3(b, d, delta, d * 2);
}

void InvFFT3(ull a[], int d, int delta, int E) {
    assert(delta <= d);
    assert(E % delta == 0);
    if (delta == 1) return;
    const int n = d * 2 * (delta / 3);
    for (int i = 0; i < 3; ++i) InvFFT3(a + n * i, d, delta / 3, E / 3 + d * i);
    for (int i = 0; i < delta / 3; ++i) {
        ull *const b[] = {a + i * d * 2, a + i * d * 2 + n, a + i * d * 2 + n * 2};
        for (int j = 0; j < d; ++j) {
            enum { L = 0, H = 1 };
            const ull A[] = {b[0][j], b[0][j + d]};
            const ull B[] = {b[1][j], b[1][j + d]};
            const ull C[] = {b[2][j], b[2][j + d]};
            b[0][j]       = A[L] + B[L] + C[L];
            b[0][j + d]   = A[H] + B[H] + C[H];
            b[1][j]       = A[L] - B[L] + B[H] - C[H];
            b[1][j + d]   = A[H] - B[L] + C[L] - C[H];
            b[2][j]       = A[L] - B[H] - C[L] + C[H];
            b[2][j + d]   = A[H] + B[L] - B[H] - C[L];
        }
        for (int j = 1; j <= 2; ++j) MultipliedByXToTheN(b[j], d, E / 3 * -j);
    }
}

void InvFFT(ull a[], int d, int delta) {
    ull *const b = a + delta * 2 * d;
    InvFFT3(a, d, delta, d), InvFFT3(b, d, delta, d * 2);
    const ull inv_3_delta = InvMod(delta * 3);
    for (int i = 0; i < delta; ++i) {
        ull *const c[] = {a + i * d * 2, b + i * d * 2};
        for (int j = 0; j < d; ++j) {
            enum { L = 0, H = 1 };
            const ull A[] = {c[0][j] * inv_3_delta, c[0][j + d] * inv_3_delta};
            const ull B[] = {c[1][j] * inv_3_delta, c[1][j + d] * inv_3_delta};
            c[0][j]       = A[L] + A[H] + B[L] + B[L] - B[H];
            c[0][j + d]   = -A[L] + A[H] + A[H] + B[L] + B[H];
            c[1][j]       = -A[L] + A[H] + A[H] + B[L] - B[H] - B[H];
            c[1][j + d]   = -A[L] - A[L] + A[H] + B[L] + B[L] - B[H];
        }
    }
}

// Compute ab mod (x^(2*n) + x^n + 1)
// see:
// [1]: Daniel J. Bernstein. "Multidigit multiplication for mathematicians."
//      Accepted to Advances in Applied Mathematics,
//      but withdrawn by author to prevent irreparable mangling by Academic Press.
//      https://cr.yp.to/papers.html#m3
void Schoenhage(const ull a[], const ull b[], ull ab[], int n) {
    assert(IsPowOf3(n));
    enum { Threshold = 3 * 3 * 3 };
    static_assert(Threshold >= 3);
    if (n <= Threshold) {
        // (a[L] + a[H]*x^n) * (b[L] + b[H]*x^n)
        //   = a[L]*b[L] + (a[L]*b[H] + a[H]*b[L])*x^n + a[H]*b[H]*(-x^n - 1)
        // using x^(2*n) = -x^n - 1
        for (int i = 0; i < n; ++i) {
            enum { L = 0, H = 1 };
            const ull A[] = {a[i], a[i + n]};
            for (int j = 0; j < n; ++j) {
                const ull B[]  = {b[j], b[j + n]};
                const ull ALBL = A[L] * B[L];
                const ull AHBH = A[H] * B[H];
                const ull c    = (A[L] + A[H]) * (B[L] + B[H]) - ALBL - AHBH;
                if (i + j < n) {
                    ab[i + j] += ALBL - AHBH;
                    ab[i + j + n] += c - AHBH;
                } else {
                    ab[i + j] += ALBL - c;
                    ab[i + j - n] -= c - AHBH;
                }
            }
        }
        return;
    }
    const int k     = Log3Ceil(n);
    const int d     = PowOf3((k + 1) / 2);
    const int delta = n / d;
    // R[x] / (x^(2 * d * delta) + x^(d * delta) + 1) ->
    //   (R[x][y] / (y^(2 * delta) + y^delta + 1)) / (y - x^d)
    // Lift to R[x][y] / (y^(2 * delta) + y^delta + 1)
    // Since polynomials in R[x][y] / (y^(2*delta) + y^delta + 1) have x-degree < d
    // We could map to (R[x] / (x^(2*d) + x^d + 1))[y] / (y^(2*delta) + y^delta + 1)
    std::vector<ull> a_hat(n * 4), b_hat(n * 4), ab_hat(n * 4);
    for (int i = 0; i < delta * 2; ++i)
        for (int j = 0; j < d; ++j)
            a_hat[i * d * 2 + j] = a[i * d + j], b_hat[i * d * 2 + j] = b[i * d + j];
    FFT(data(a_hat), d, delta), FFT(data(b_hat), d, delta);
    for (int i = 0; i < delta * 2; ++i)
        Schoenhage(data(a_hat) + i * d * 2, data(b_hat) + i * d * 2, data(ab_hat) + i * d * 2, d);
    InvFFT(data(ab_hat), d, delta);
    for (int i = 0; i < delta * 2; ++i)
        for (int j = 0; j < d * 2; ++j)
            if (i * d + j < n * 2) {
                ab[i * d + j] += ab_hat[i * d * 2 + j];
            } else if (i * d + j < n * 3) {
                // x^(2*n) = -x^n - 1
                ab[i * d + j - n * 1] -= ab_hat[i * d * 2 + j];
                ab[i * d + j - n * 2] -= ab_hat[i * d * 2 + j];
            } else {
                __builtin_unreachable();
            }
}

std::vector<ull> Product(std::vector<ull> a, std::vector<ull> b) {
    if (empty(a) || empty(b)) return {};
    const int n = size(a), m = size(b);
    int N = 1;
    while (N < n) N *= 3;
    while (N < m) N *= 3;
    a.resize(N * 2), b.resize(N * 2);
    std::vector<ull> ab(N * 2);
    Schoenhage(data(a), data(b), data(ab), N);
    ab.resize(n + m - 1);
    return ab;
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int n, m;
    std::cin >> n >> m;
    std::vector<ull> a(n), b(m);
    for (int i = 0; i < n; ++i) std::cin >> a[i];
    for (int i = 0; i < m; ++i) std::cin >> b[i];
    const auto ab = Product(std::move(a), std::move(b));
    for (int i = 0; i < n + m - 1; ++i) std::cout << ab[i] << ' ';
    return 0;
}

Test cases

| Env | Name | Status | Elapsed | Memory |
| --- | --- | --- | --- | --- |
| clang++ | all_same_00 | AC | 1238 ms | 232 MB |
| clang++ | all_same_01 | AC | 1301 ms | 232 MB |
| clang++ | all_same_02 | AC | 1294 ms | 234 MB |
| clang++ | all_same_03 | AC | 1302 ms | 234 MB |
| clang++ | example_00 | AC | 9 ms | 22 MB |
| clang++ | example_01 | AC | 8 ms | 24 MB |
| clang++ | gen_131072_00 | AC | 412 ms | 92 MB |
| clang++ | gen_131073_00 | AC | 413 ms | 90 MB |
| clang++ | gen_177147_00 | AC | 430 ms | 93 MB |
| clang++ | gen_177148_00 | AC | 1194 ms | 226 MB |
| clang++ | gen_262144_00 | AC | 1216 ms | 229 MB |
| clang++ | gen_262145_00 | AC | 1232 ms | 227 MB |
| clang++ | gen_265721_00 | AC | 1223 ms | 228 MB |
| clang++ | gen_265722_00 | AC | 1236 ms | 229 MB |
| clang++ | gen_354294_00 | AC | 1258 ms | 231 MB |
| clang++ | gen_354295_00 | AC | 1259 ms | 229 MB |
| clang++ | gen_524288_00 | AC | 1308 ms | 234 MB |
| clang++ | gen_524288_01 | AC | 1323 ms | 234 MB |
| clang++ | medium_00 | AC | 48 ms | 30 MB |
| clang++ | medium_01 | AC | 19 ms | 23 MB |
| clang++ | medium_02 | AC | 20 ms | 21 MB |
| clang++ | random_00 | AC | 1287 ms | 230 MB |
| clang++ | random_01 | AC | 1299 ms | 231 MB |
| clang++ | random_02 | AC | 1213 ms | 228 MB |
| clang++ | small_00 | AC | 8 ms | 18 MB |
| clang++ | small_01 | AC | 8 ms | 18 MB |
| clang++ | small_02 | AC | 8 ms | 22 MB |
| clang++ | small_03 | AC | 9 ms | 22 MB |
| clang++ | small_04 | AC | 9 ms | 22 MB |
| clang++ | small_05 | AC | 9 ms | 24 MB |
| clang++ | small_06 | AC | 8 ms | 22 MB |
| clang++ | small_07 | AC | 8 ms | 18 MB |
| clang++ | small_08 | AC | 9 ms | 22 MB |
| clang++ | small_09 | AC | 9 ms | 22 MB |
| clang++ | small_10 | AC | 8 ms | 20 MB |
| clang++ | small_11 | AC | 8 ms | 22 MB |
| clang++ | small_12 | AC | 8 ms | 22 MB |
| clang++ | small_13 | AC | 8 ms | 20 MB |
| clang++ | small_14 | AC | 8 ms | 22 MB |
| clang++ | small_15 | AC | 8 ms | 22 MB |
| clang++ | small_and_large_00 | AC | 1236 ms | 228 MB |
| clang++ | small_and_large_01 | AC | 1216 ms | 229 MB |
| clang++ | small_and_large_02 | AC | 1220 ms | 228 MB |
| clang++ | small_and_large_03 | AC | 1218 ms | 226 MB |