読者です 読者をやめる 読者になる 読者になる

Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

Codeforces 226E - More Queries to Array...

問題文 : More Queries to Array

問題概要:
配列 $ \normalsize a$ が与えられるので,

  • 区間 $ \normalsize [l,\ r] $ の値を $ \normalsize x$ に変更する.
  • $ \normalsize \displaystyle \sum_{i = l}^{r} a_{i} \cdot (i - l + 1)^k$ (ただし $\normalsize 0 \leq k \leq 5$.

...というクエリに $ \normalsize m $ 個答える問題.

解法 :
遅延評価を用いて, $ \normalsize \displaystyle \sum_{i = l}^{r} i^k \cdot a_{i}$ を持つセグメント木を 6 個作っておけば, 二項定理より,


$ \displaystyle \sum_{i = l}^{r} a_{i} \cdot (i - l + 1)^k = \displaystyle \sum_{j = 0}^{k} \biggl( {}_k\text{C}_j \cdot (1 - l)^j \cdot \sum_{i = l}^{r}i^{k-j} \cdot a_{i} \biggr)$

となり, $O (m \log n)$ 時間でどちらのクエリにも答えることができる.

コード :

#include <cstdio>
#include <algorithm>

#define MOD (1000000007)

using namespace std;
typedef long long lint;

struct Node {
	lint sum, lazy;
	Node(){
		sum = 0, lazy = -1;
	}
} seg[6][1 << 18];

lint pSum[6][1 << 17];

inline void evaluate(int k, int power, int l, int r)
{
	lint _pSum = (pSum[power][r] - pSum[power][l] + MOD) % MOD;
	if (~seg[power][k].lazy) seg[power][k].sum = seg[power][k].lazy * _pSum % MOD;
	
	if (k < (1 << 17) - 1 && ~seg[power][k].lazy){
		seg[power][k * 2 + 1].lazy = seg[power][k].lazy;
		seg[power][k * 2 + 2].lazy = seg[power][k].lazy;
	}
	
	seg[power][k].lazy = -1;
}

inline void _update(int power, int k)
{
	seg[power][k].sum = (seg[power][k * 2 + 1].sum + seg[power][k * 2 + 2].sum) % MOD;
}

void update(int a, int b, int x, int power, int k = 0, int l = 0, int r = 1 << 17)
{
	evaluate(k, power, l, r);
	if (b <= l || r <= a) return;
	
	if (a <= l && r <= b){
		seg[power][k].lazy = x;
		evaluate(k, power, l, r);
		return;
	}
	
	update(a, b, x, power, k * 2 + 1, l, (l + r) / 2);
	update(a, b, x, power, k * 2 + 2, (l + r) / 2, r);
	_update(power, k);
}

lint getSum(int a, int b, int power, int k = 0, int l = 0, int r = 1 << 17)
{
	evaluate(k, power, l, r);
	if (b <= l || r <= a) return (0);
	
	if (a <= l && r <= b){
		return (seg[power][k].sum);
	}
	lint sum = 0;
	sum += getSum(a, b, power, k * 2 + 1, l, (l + r) / 2);
	sum += getSum(a, b, power, k * 2 + 2, (l + r) / 2, r);
	_update(power, k);
	
	return (sum % MOD);
}

int main()
{
	int C[6][6] = {0};
	
	for (int i = 0; i < 6; i++){
		C[i][0] = 1;
		for (int j = 1; j <= i; j++){
			C[i][j] = C[i - 1][j] + C[i - 1][j - 1];
		}
	}
	
	for (int i = 0; i < 6; i++){
		for (int j = 1; j < 1 << 17; j++){
			lint v = 1;
			for (int k = 0; k < i; k++) v = (v * j) % MOD;
			pSum[i][j] = (pSum[i][j - 1] + v) % MOD;
		}
	}
	
	int N, Q;
	scanf("%d %d", &N, &Q);
	
	for (int i = 0; i < N; i++){
		int t;
		scanf("%d", &t);
		for (int j = 0; j < 6; j++) update(i, i + 1, t, j);
	}
	
	for (int i = 0; i < Q; i++){
		char Qtype[6];
		int arg[3];
		
		scanf("%s", Qtype);
		for (int j = 0; j < 3; j++) scanf("%d", arg + j);
		
		if (Qtype[0] == '='){
			for (int j = 0; j < 6; j++) update(arg[0] - 1, arg[1], arg[2], j);
		}
		else {
			lint res = 0, mul = 1;
			for (int j = arg[2]; j >= 0; j--){
				res = (res + (mul * C[arg[2]][j]) % MOD * getSum(arg[0] - 1, arg[1], j) + MOD) % MOD;
				mul = (mul * (1 - arg[0]) + MOD) % MOD;
			}
			printf("%lld\n", res);
		}
	}
	
	return (0);
}