読者です 読者をやめる 読者になる 読者になる

Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

ICPC Asia Tokyo Regional 2014 G - Flipping Parentheses

問題文

これ

概要

対応のとれた括弧からなる文字列が与えられる.
括弧の向きを変えるというクエリが大量にくるから, 最も左の別の括弧の向きを変えてまた対応がとれているようにせよ.

解法

括弧の対応を列として考える. 具体的には, 開き括弧を +1, 閉じ括弧を -1 として列をつくる.
この列の累積和テーブルにおいて, 負の項が生じないように最も右の括弧を取ればよい.

開き括弧から閉じ括弧にある括弧を変えるのはコストを 2 減らし, 逆はコストを 2 増やせば良い.
この操作ができるデータ構造として, Starry Sky 木が挙げられる.

Segment Tree は定数が重いので, 不必要なクエリは出来るだけ投げないように気をつけるべき(TLE が出続けていた).

コード

w/遅延評価版と w/o遅延評価版を書きました.

#include <bits/stdc++.h>

using namespace std;

typedef struct {
	int val;
	int add;
} Node;

Node seg[1 << 20];
int n, q;
char s[300001];

inline void evaluate(int idx)
{
	seg[idx].val += seg[idx].add;
	if (idx < (1 << 19) - 1){
		seg[idx * 2 + 1].add += seg[idx].add;
		seg[idx * 2 + 2].add += seg[idx].add;
	}
	
	seg[idx].add = 0;
}

void update(int k)
{
	seg[k].val = min(seg[k * 2 + 1].val, seg[k * 2 + 2].val);
}

void add(int a, int b, int x, int k = 0, int l = 0, int r = n)
{
	evaluate(k);
	if (r <= a || b <= l) return;
	
	if (a <= l && r <= b){
		seg[k].add += x;
		evaluate(k);
		return;
	}
	
	add(a, b, x, k * 2 + 1, l, (l + r) / 2);
	add(a, b, x, k * 2 + 2, (l + r) / 2, r);
	update(k);
}

int getMin(int a, int b, int k = 0, int l = 0, int r = n)
{
	evaluate(k);
	if (r <= a || b <= l) return (INT_MAX);
	
	if (a <= l && r <= b){
		evaluate(k);
		return (seg[k].val);
	}
	
	int left = getMin(a, b, k * 2 + 1, l, (l + r) / 2);
	int right = getMin(a, b, k * 2 + 2, (l + r) / 2, r);
	update(k);
	
	return (min(left, right));
	
}

void fix(set<int> *a, set<int> *b, char before, int x)
{
    a->erase(x);
    b->insert(x);
    
    int base = before == '(' ? -1 : 1;
    s[x] = before == '(' ? ')' : '(';
    
    add(x, n, 2 * base);
    
    int l = 0, r = x;
    int p;
    
    while (l != r){
        int mid = l + r >> 1;
        p = *b->lower_bound(mid);
        if (getMin(p, n) >= 2 * base) r = mid;
        else l = mid + 1;
    }
    
    p = *b->lower_bound(l);
    add(p, n, -2 * base);
    b->erase(p);
    a->insert(p);
    s[p] = before;
    
    printf("%d\n", p + 1);
}

int main()
{
    scanf("%d %d", &n, &q);
    scanf("%s", s);
    
    set<int> lf, rg;
    
    for (int i = 0; s[i]; i++){
        if (s[i] == '('){
            add(i, n, 1);
            lf.insert(i);
        }
        else {
            add(i, n, -1);
            rg.insert(i);
        }
    }
    
    for (int i = 0; i < q; i++){
        int x;
        scanf("%d", &x); x--;
        
        if (s[x] == '(') fix(&lf, &rg, '(', x);
        else fix(&rg, &lf, ')', x);
    }
    
    return (0);
}
w/o 遅延評価 (区間加算の定数がでかい)
#include <bits/stdc++.h>
 
using namespace std;
 
int segMin[1 << 20], segAdd[1 << 20];
int n, q;
char s[300001];
 
void add(int a, int b, int x, int k = 0, int l = 0, int r = n)
{
    if (r <= a || b <= l) return;
     
    if (a <= l && r <= b){
        segAdd[k] += x;
        while (k){
            k = (k - 1) / 2;
            segMin[k] = min(segMin[k * 2 + 1] + segAdd[k * 2 + 1], segMin[k * 2 + 2] + segAdd[k * 2 + 2]);
        }
        return;
    }
     
    add(a, b, x, k * 2 + 1, l, (l + r) / 2);
    add(a, b, x, k * 2 + 2, (l + r) / 2, r);
}
 
int getMin(int a, int b, int k = 0, int l = 0, int r = n)
{
    if (r <= a || b <= l) return (INT_MAX);
     
    if (a <= l && r <= b) return (segMin[k] + segAdd[k]);
     
    int left = getMin(a, b, k * 2 + 1, l, (l + r) / 2);
    int right = getMin(a, b, k * 2 + 2, (l + r) / 2, r);
     
    return (min(left, right) + segAdd[k]);
     
}
 
void fix(set<int> *a, set<int> *b, char before, int x)
{
    a->erase(x);
    b->insert(x);
     
    int base = before == '(' ? -1 : 1;
    s[x] = before == '(' ? ')' : '(';
     
    add(x, n, 2 * base);
     
    int l = 0, r = x;
    int p;
     
    while (l != r){
        int mid = l + r >> 1;
        p = *b->lower_bound(mid);
        if (getMin(p, n) >= 2 * base) r = mid;
        else l = mid + 1;
    }
     
    p = *b->lower_bound(l);
    add(p, n, -2 * base);
    b->erase(p);
    a->insert(p);
    s[p] = before;
     
    printf("%d\n", p + 1);
}
 
int main()
{
    scanf("%d %d", &n, &q);
    scanf("%s", s);
     
    set<int> lf, rg;
     
    for (int i = 0; s[i]; i++){
        if (s[i] == '('){
            add(i, n, 1);
            lf.insert(i);
        }
        else {
            add(i, n, -1);
            rg.insert(i);
        }
    }
     
    for (int i = 0; i < q; i++){
        int x;
        scanf("%d", &x); x--;
         
        if (s[x] == '(') fix(&lf, &rg, '(', x);
        else fix(&rg, &lf, ')', x);
    }
     
    return (0);
}