読者です 読者をやめる 読者になる 読者になる

Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

JOI春合宿 2012-day1 fish

問題文 :

解法 :
初めに, 尺取り法で, 起こりうる魚の組み合わせ(赤a 匹, 緑b 匹, 青c 匹 以下の組み合わせなら全て作成可能である) というものを列挙する. これをa * b * c の直方体として扱う.
その後に, セグメント木で体積を, 以下の要領で求めていく.

①直方体をa の大きさで降順にソートする.
②0 ~ b[i] の範囲の大きさを, 元の大きさと値c[i]のうち大きい方に更新する.
③0 ~ 500000 の範囲の和 * (a[i - 1] - a[i]) を合計の体積に加える.

これは, 2 つのセグメント木を用いることで, O(N log N) で実現できる.

コード :

#include <cstdio>
#include <vector>
#include <algorithm>

#define MAX_N (500000)

using namespace std;

typedef long long lint;
typedef pair<lint, pair<lint, lint> > T;

const int sz = 1 << 19;

struct Node {
	lint val, lazy;
} seg[2 * sz];

void changeMax(int a, int b, lint x, int k = 0, int l = 0, int r = sz)
{
	seg[k].val = max(seg[k].val, (r - l) * seg[k].lazy);
	if (k < sz - 1){
		seg[k * 2 + 1].lazy = max(seg[k * 2 + 1].lazy, seg[k].lazy);
		seg[k * 2 + 2].lazy = max(seg[k * 2 + 2].lazy, seg[k].lazy);
	}
	seg[k].lazy = 0;
	if (b <= a || b <= l || r <= a) return;
	if (a <= l && r <= b){
		seg[k].lazy = max(seg[k].lazy, x);
		seg[k].val = max(seg[k].val, (r - l) * seg[k].lazy);
		if (k < sz - 1){
			seg[k * 2 + 1].lazy = max(seg[k * 2 + 1].lazy, seg[k].lazy);
			seg[k * 2 + 2].lazy = max(seg[k * 2 + 2].lazy, seg[k].lazy);
		}
		seg[k].lazy = 0;
		return;
	}
	changeMax(a, b, x, k * 2 + 1, l, (l + r) / 2);
	changeMax(a, b, x, k * 2 + 2, (l + r) / 2, r);
	seg[k].val = seg[k * 2 + 1].val + seg[k * 2 + 2].val;
}
 
inline lint getSum()
{
	seg[0].val = max(seg[0].val, sz * seg[0].lazy);
	seg[1].lazy = max(seg[1].lazy, seg[0].lazy);
	seg[2].lazy = max(seg[2].lazy, seg[0].lazy);
	return (seg[0].val);
}

int seg2[sz * 2];

void set(int pos, int val)
{
	pos += sz - 1;
	while (pos){
		seg2[pos] = max(seg2[pos], val);
		pos = (pos - 1) / 2;
	}
}

int find(int val)
{
	int pos = 0;
	while (pos < sz - 1){
		if (seg2[pos * 2 + 2] > val) pos = pos * 2 + 2;
		else pos = pos * 2 + 1;
	}
	return (pos - sz + 1);
}

int n;
pair<int, char> dat[MAX_N];
T v[2 * MAX_N];
int main()
{
	scanf("%d", &n);
	
	for (int i = 0; i < n; i++){
		char buf[2];
		scanf("%d %s", &dat[i].first, buf);
		dat[i].second = buf[0];
	}
	
	sort(dat, dat + n);
	reverse(dat, dat + n);
	
	int head, tail;
	head = tail = 0;
	
	int rcount = 0, gcount = 0, bcount = 0;
	int rec = 0;
	while (head != n){
		while (tail != n && dat[tail].first * 2 > dat[head].first){
			(dat[tail].second == 'R') ? (rcount++) : ((dat[tail].second == 'G') ? (gcount++) : (bcount++));
			tail++;
		}
		
		v[rec++] = make_pair(rcount + 1, make_pair(gcount + 1, bcount + 1));
		
		(dat[head].second == 'R') ? (rcount--) : ((dat[head].second == 'G') ? (gcount--) : (bcount--));
		head++;
	}
	v[rec++] = (make_pair(0, make_pair(0, 0)));
	sort(v, v + rec);
	reverse(v, v + rec);
	
	lint ans = 0;
	
	for (int i = 0; i < rec; i++){
		if (i) ans += getSum() * (v[i - 1].first - v[i].first);
		if (v[i].first == 0) break;
		int pos = find(v[i].second.second);
		if (seg2[pos + sz - 1] > v[i].second.second) pos++;
		changeMax(pos, v[i].second.first, v[i].second.second);
		set(v[i].second.first - 1, v[i].second.second);
	}
	
	printf("%lld\n", ans - 1);
	
	return (0);
}