hly1204's library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub hly1204/library

✔ test/number_theory/sqrt_mod.0.test.cpp

Depends on: rng.hpp, xgcd.hpp

Code

#define PROBLEM "https://judge.yosupo.jp/problem/sqrt_mod"

#include "rng.hpp"
#include "xgcd.hpp"
#include <iostream>
#include <random>

// Modular inverse of a modulo mod, taken from inv_gcd's Bezout coefficient.
// A negative coefficient is shifted up by adding mod once. The gcd is NOT
// checked, so the result is meaningless when gcd(a, mod) != 1.
int inv_mod(int a, int mod) {
    const int coef = inv_gcd(a, mod)[0];
    if (coef >= 0) return coef;
    return coef + mod;
}

// Computes a^e modulo `mod` by binary exponentiation; the result is in [0, mod).
// - `mod` must be positive.
// - `a` may be any int (negative or >= mod); it is reduced into [0, mod) first.
// - A negative `e` computes (a^(-1))^(-e) through inv_mod, which is only
//   meaningful when gcd(a, mod) == 1 (inv_mod does not check this).
int pow_mod(int a, int e, int mod) {
    // C++ `%` truncates toward zero, so fix up negative remainders.
    long long base = a % mod;
    if (base < 0) base += mod;

    // Keep the exponent magnitude in a wider type: `-e` overflows int when
    // e == INT_MIN.
    long long k = e;
    if (k < 0) {
        base = inv_mod(static_cast<int>(base), mod);
        k = -k;
    }

    long long res = 1 % mod; // 0 when mod == 1, since everything is 0 mod 1
    while (k != 0) {
        if (k & 1) res = res * base % mod;
        k >>= 1;
        // Skip the final, unused squaring.
        if (k != 0) base = base * base % mod;
    }
    return static_cast<int>(res);
}

// Tonelli--Shanks's algorithm
// see:
// [1]: Daniel. J. Bernstein. Faster Square Roots in Annoying Finite Fields.
//
// Returns some x with x^2 = a (mod `mod`), or -1 when a is a quadratic
// non-residue. `mod` must be prime (this is not checked). `a` may be any int:
// it is first reduced into [0, mod), so e.g. a == mod is treated as 0.
int sqrt_mod(int a, int mod) {
    // mod must be prime
    a %= mod;
    if (a < 0) a += mod;
    if (a == 0 || mod == 2) return a;

    // Euler's criterion: a nonzero x is a square iff x^((p-1)/2) = 1 (mod p).
    auto is_square = [](int x, int p) { return pow_mod(x, (p - 1) / 2, p) == 1; };

    if (!is_square(a, mod)) return -1;

    static xoshiro256starstar rng{std::random_device{}()};
    std::uniform_int_distribution<> dis(2, mod - 1);

    // Rejection-sample a quadratic non-residue r.
    int r;
    do { r = dis(rng); } while (is_square(r, mod));

    // Factor mod - 1 = 2^n * m with m odd.
    int n = 1, m = (mod - 1) / 2;
    while (m % 2 == 0) ++n, m /= 2;

    const int am = pow_mod(a, m, mod);
    // r is a non-residue, so c = r^m has order exactly 2^n.
    const int c = pow_mod(r, m, mod);

    // Find e such that a^m = c^e, one bit per round. Bit 0 is 0 because a is
    // a square. At round i (j = 2^i) the element a^m * c^(-e) has order
    // dividing 2^(n-i); if raising it to 2^(n-i-1) does not give 1, bit i of e
    // must be set.
    int e = 0;
    for (int i = 1, j = 2; i < n; ++i, j *= 2) {
        // One can reduce the constant factor by
        // calculating something during the iteration,
        // but it is not necessary.
        const int rest =
            static_cast<int>(static_cast<long long>(am) * pow_mod(c, -e, mod) % mod);
        // 1 << (n - i - 1) == (mod - 1) / (m * 2j)
        if (pow_mod(rest, 1 << (n - i - 1), mod) == 1) continue;
        e += j;
    }

    // Write m = 2k + 1: a^(2k) * a = c^e  =>  a = c^e * a^(-2k), so
    // c^(e/2) * a^(-k) is a square root of a.
    return static_cast<int>(static_cast<long long>(pow_mod(c, e / 2, mod)) *
                            pow_mod(a, -(m / 2), mod) % mod);
}

// Reads a query count followed by (a, mod) pairs and prints sqrt_mod(a, mod)
// for each pair on its own line.
int main() {
    // Untie/unsync the C++ streams for faster bulk I/O.
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    int queries;
    std::cin >> queries;
    for (; queries != 0; --queries) {
        int a, mod;
        std::cin >> a >> mod;
        const int root = sqrt_mod(a, mod);
        std::cout << root << '\n';
    }
    return 0;
}
#line 1 "test/number_theory/sqrt_mod.0.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/sqrt_mod"

#line 2 "rng.hpp"

#include <cstdint>
#include <limits>

// xoshiro256** pseudo-random generator: 64-bit outputs from a 256-bit state.
// see: https://prng.di.unimi.it/xoshiro256starstar.c
// original license CC0 1.0
//
// Provides result_type / min() / max() / operator() so it can drive <random>
// distributions.
// see: https://en.cppreference.com/w/cpp/named_req/UniformRandomBitGenerator
class xoshiro256starstar {
    using u64 = std::uint64_t;

    u64 state_[4];

    // Left bit rotation; callers only pass shifts strictly between 0 and 64.
    static constexpr u64 rotate_left(u64 v, int shift) {
        return (v << shift) | (v >> (64 - shift));
    }

    // Produce one output, then advance the state by one step.
    u64 step() {
        const u64 out     = rotate_left(state_[1] * 5, 7) * 9;
        const u64 shifted = state_[1] << 17;

        state_[2] ^= state_[0];
        state_[3] ^= state_[1];
        state_[1] ^= state_[2];
        state_[0] ^= state_[3];

        state_[2] ^= shifted;
        state_[3] = rotate_left(state_[3], 45);
        return out;
    }

public:
    using result_type = u64;

    // Fill the four state words with consecutive splitmix64 outputs of `seed`.
    // see: https://prng.di.unimi.it/splitmix64.c
    // original license CC0 1.0
    explicit xoshiro256starstar(u64 seed) {
        for (u64 &word : state_) {
            seed += 0x9e3779b97f4a7c15;
            u64 z = seed;
            z ^= z >> 30;
            z *= 0xbf58476d1ce4e5b9;
            z ^= z >> 27;
            z *= 0x94d049bb133111eb;
            word = z ^ (z >> 31);
        }
    }

    static constexpr result_type min() { return std::numeric_limits<result_type>::min(); }
    static constexpr result_type max() { return std::numeric_limits<result_type>::max(); }

    result_type operator()() { return step(); }
};
#line 2 "xgcd.hpp"

#include <array>
#include <type_traits>
#include <utility>

// Extended Euclidean algorithm.
// Returns [x, y, g] such that a*x + b*y == g, where g is gcd(a, b) up to sign.
// g is nonnegative when a, b >= 0, but its sign is not normalized for negative
// inputs (e.g. xgcd(-5, 0) yields g == -5).
template<typename Int>
inline std::enable_if_t<std::is_signed_v<Int>, std::array<Int, 3>> xgcd(Int a, Int b) {
    // Invariant, with a0/b0 the original arguments:
    //   a == a0*s0 + b0*t0   and   b == a0*s1 + b0*t1
    Int s0 = 1, t0 = 0;
    Int s1 = 0, t1 = 1;
    while (b != 0) {
        const Int q  = a / b;
        const Int s2 = s0 - q * s1;
        const Int t2 = t0 - q * t1;
        const Int r  = a - q * b;
        s0 = s1, s1 = s2;
        t0 = t1, t1 = t2;
        a = b, b = r;
    }
    return {s0, t0, a};
}

// Returns [x, g] such that a*x == g (mod b), where g is gcd(a, b) up to sign.
// x is NOT reduced modulo b and may be negative; it is a modular inverse of a
// only when g == 1. Callers should check both (sign of x, and g == 1).
template<typename Int>
inline std::enable_if_t<std::is_signed_v<Int>, std::array<Int, 2>> inv_gcd(Int a, Int b) {
    // Only the coefficient of the first argument is tracked (a0/b0 = inputs):
    //   a == a0*x_cur (mod b0)   and   b == a0*x_next (mod b0)
    Int x_cur = 1, x_next = 0;
    for (; b != 0;) {
        const Int q     = a / b;
        const Int x_new = x_cur - q * x_next;
        x_cur  = x_next;
        x_next = x_new;
        const Int r = a - q * b;
        a = b;
        b = r;
    }
    return {x_cur, a};
}
#line 5 "test/number_theory/sqrt_mod.0.test.cpp"
#include <iostream>
#include <random>

// Modular inverse of a modulo mod, taken from inv_gcd's Bezout coefficient.
// A negative coefficient is shifted up by adding mod once. The gcd is NOT
// checked, so the result is meaningless when gcd(a, mod) != 1.
int inv_mod(int a, int mod) {
    const int coef = inv_gcd(a, mod)[0];
    if (coef >= 0) return coef;
    return coef + mod;
}

// Computes a^e modulo `mod` by binary exponentiation; the result is in [0, mod).
// - `mod` must be positive.
// - `a` may be any int (negative or >= mod); it is reduced into [0, mod) first.
// - A negative `e` computes (a^(-1))^(-e) through inv_mod, which is only
//   meaningful when gcd(a, mod) == 1 (inv_mod does not check this).
int pow_mod(int a, int e, int mod) {
    // C++ `%` truncates toward zero, so fix up negative remainders.
    long long base = a % mod;
    if (base < 0) base += mod;

    // Keep the exponent magnitude in a wider type: `-e` overflows int when
    // e == INT_MIN.
    long long k = e;
    if (k < 0) {
        base = inv_mod(static_cast<int>(base), mod);
        k = -k;
    }

    long long res = 1 % mod; // 0 when mod == 1, since everything is 0 mod 1
    while (k != 0) {
        if (k & 1) res = res * base % mod;
        k >>= 1;
        // Skip the final, unused squaring.
        if (k != 0) base = base * base % mod;
    }
    return static_cast<int>(res);
}

// Tonelli--Shanks's algorithm
// see:
// [1]: Daniel. J. Bernstein. Faster Square Roots in Annoying Finite Fields.
//
// Returns some x with x^2 = a (mod `mod`), or -1 when a is a quadratic
// non-residue. `mod` must be prime (this is not checked). `a` may be any int:
// it is first reduced into [0, mod), so e.g. a == mod is treated as 0.
int sqrt_mod(int a, int mod) {
    // mod must be prime
    a %= mod;
    if (a < 0) a += mod;
    if (a == 0 || mod == 2) return a;

    // Euler's criterion: a nonzero x is a square iff x^((p-1)/2) = 1 (mod p).
    auto is_square = [](int x, int p) { return pow_mod(x, (p - 1) / 2, p) == 1; };

    if (!is_square(a, mod)) return -1;

    static xoshiro256starstar rng{std::random_device{}()};
    std::uniform_int_distribution<> dis(2, mod - 1);

    // Rejection-sample a quadratic non-residue r.
    int r;
    do { r = dis(rng); } while (is_square(r, mod));

    // Factor mod - 1 = 2^n * m with m odd.
    int n = 1, m = (mod - 1) / 2;
    while (m % 2 == 0) ++n, m /= 2;

    const int am = pow_mod(a, m, mod);
    // r is a non-residue, so c = r^m has order exactly 2^n.
    const int c = pow_mod(r, m, mod);

    // Find e such that a^m = c^e, one bit per round. Bit 0 is 0 because a is
    // a square. At round i (j = 2^i) the element a^m * c^(-e) has order
    // dividing 2^(n-i); if raising it to 2^(n-i-1) does not give 1, bit i of e
    // must be set.
    int e = 0;
    for (int i = 1, j = 2; i < n; ++i, j *= 2) {
        // One can reduce the constant factor by
        // calculating something during the iteration,
        // but it is not necessary.
        const int rest =
            static_cast<int>(static_cast<long long>(am) * pow_mod(c, -e, mod) % mod);
        // 1 << (n - i - 1) == (mod - 1) / (m * 2j)
        if (pow_mod(rest, 1 << (n - i - 1), mod) == 1) continue;
        e += j;
    }

    // Write m = 2k + 1: a^(2k) * a = c^e  =>  a = c^e * a^(-2k), so
    // c^(e/2) * a^(-k) is a square root of a.
    return static_cast<int>(static_cast<long long>(pow_mod(c, e / 2, mod)) *
                            pow_mod(a, -(m / 2), mod) % mod);
}

// Reads a query count followed by (a, mod) pairs and prints sqrt_mod(a, mod)
// for each pair on its own line.
int main() {
    // Untie/unsync the C++ streams for faster bulk I/O.
    std::cin.tie(nullptr);
    std::ios::sync_with_stdio(false);

    int queries;
    std::cin >> queries;
    for (; queries != 0; --queries) {
        int a, mod;
        std::cin >> a >> mod;
        const int root = sqrt_mod(a, mod);
        std::cout << root << '\n';
    }
    return 0;
}
Back to top page