Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

準急さんの問題

昔semiexpさんがこんな問題をつぶやいていたので、解いてみました. 結論から言うと, 僕の環境だと5秒かかります (絶望)

解法 :
二項定理から, (x+1)^n = Σ[i = 0, n] nCi * x^(n-i) と表せるので, x^0 から x^n の値がわかれば, (x + 1) ^ 0 から (x + 1) ^ n の値を求めることが出来る.
これは、線形性が成り立つ演算なので(掛け算と足し算で表されるから), 繰り返し二乗法を使うことで, x^k を, O(k^3 log x) 時間で求めることが出来る.

(ここからすべて0-indexed)

ここで, i番目の成分をx^iとした列ベクトルx_を考える. このとき, (x+1)_ は, 先ほどの議論より, (i, j)成分を iCj としたパスカルの三角形の行列 A を用いて, Ax_ と表すことが出来る.

ここで, 求めたいのは1^k+2^k+...N^kで, x_のk番目の成分を足し合わせればよい. ここで, x_のk番目の成分は, A^x の (k, 0)成分であることに注意すると, (E+)A+A^2+A^3+...+A^xの(k, 0)成分がわかれば, それが答えになっているはずである.
この行列の和も、繰り返し二乗法により O(k^3 log N) 時間で求めることができる.

ゆえに, 全体の計算量は O(k^3 log N) 時間である. (10^7くらいだけど、定数が重めらしく絶望)

コード :

#include <cstdio>
#include <cstring>
#include <vector>

using namespace std;

typedef long long lint;
typedef vector<lint> vec;
typedef vector<vec> matrix;

lint n, k, mod;

lint memo[128][128];

lint comb(int n, int r)
{
	if (n < r) return (0);
	if (!r || n == r) return (1 % mod);
	if (memo[n][r] >= 0) return (memo[n][r]);
	
	return (memo[n][r] = (comb(n - 1, r) + comb(n - 1, r - 1)) % mod);
}

matrix mul(matrix &A, matrix &B)
{
    matrix C(A.size(), vec(B[0].size()));
    
    for (int i = 0; i < A.size(); i++){
        for (int j = 0; j < B[0].size(); j++){
            for (int k = 0; k < A[0].size(); k++){
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % mod;
            }
        }
    }
    
    return (C);
}

matrix pow(matrix &A, lint n)
{
    matrix B(A.size(), vec(A.size()));
    
    for (int i = 0; i < A.size(); i++){
        B[i][i] = 1;
    }
    
    while (n > 0){
        if (n & 1){
            B = mul(B, A);
        }
        A = mul(A, A);
        n >>= 1;
    }
    
    return (B);
}

int main()
{
    scanf("%d %d %d", &k, &n, &mod);
    k++;
    memset(memo, -1, sizeof(memo));
    matrix A(2 * k, vec(2 * k));
    
    for (int i = 0; i < k; i++){
        for (int j = 0; j < k; j++){
            A[i][j] = comb(i, j);
        }
        A[k + i][i] = A[k + i][k + i] = 1;
    }
    
    A = pow(A, n + 1);

    printf("%lld\n", A[2 * k - 1][0]);
    
    return (0);
}