Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

JOI 春合宿 2012-day3 Sokoban

問題文 : 倉庫番

解法 :
倉庫番の逆」を考える. すなわち, 目標地点から箱を引っ張りまわすBFS を考える. 箱の置き方がO(MN) 個, 隣接する頂点がたかだか4 個であるため, 状態数はO(MN) である.

各状態に遷移したら, 答えには連結成分の大きさを足してやれば良い. 連結成分の大きさは何回かのDFS で求まる.
ある頂点を取り除いた際の4 つの頂点の接続関係を用いれば, うまくこの足す作業を行うことができる.

コード :

#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>

using namespace std;

typedef long long lint;

int M, N;
int reachSize;
int size[1024][1024][4];
int type[1024][1024][4];
char map[1024][1024];
bool vis[1024][1024];
bool done[1024][1024][4];

int dy[] = {0, 1, 0, -1}, dx[] = {1, 0, -1, 0};

struct Event {
	int ty, tx, dir;
	Event(int ty, int tx, int dir) : ty(ty), tx(tx), dir(dir) {}
	Event(){}
};

bool valid(int ny, int nx)
{
	return (0 <= ny && ny < M && 0 <= nx && nx < N && map[ny][nx] != '#');
}

int rev(int dir)
{
	if (dir >= 2) return (dir - 2);
	return (dir + 2);
}

int countReach(int sy, int sx)
{
	int ret = 1;
	vis[sy][sx] = true;
	
	for (int i = 0; i < 4; i++){
		int ny = sy + dy[i], nx = sx + dx[i];
		if (valid(ny, nx) && !vis[ny][nx]){
			ret += countReach(ny, nx);
		}
	}
	return (ret);
}

int fill(int my, int mx, int dir)
{
	int ret = 1;
	vis[my][mx] = true;
	
	for (int i = 0; i < 4; i++){
		int ny = my + dy[i], nx = mx + dx[i];
		if (valid(ny, nx) && !vis[ny][nx]){
			int c = fill(ny, nx, i);
			ret += c;
			size[my][mx][i] = c;
		}
	}
	
	if (~dir) size[my][mx][rev(dir)] = reachSize - ret;
	
	return (ret);
}
int t[1024][1024];
int ord[1024][1024], fin[1024][1024], low[1024][1024];
vector<pair<int, int> > chi[1024][1024];

void predfs(int sy, int sx, int py, int px, int &k)
{
	vis[sy][sx] = true;
	ord[sy][sx] = low[sy][sx] = k++;
	
	for (int i = 0; i < 4; i++){
		int ny = sy + dy[i], nx = sx + dx[i];
		if (valid(ny, nx) && !vis[ny][nx]){
			predfs(ny, nx, sy, sx, k);
			chi[sy][sx].push_back(make_pair(ny, nx));
			low[sy][sx] = min(low[sy][sx], low[ny][nx]);
		}
		else if (valid(ny, nx) && !(ny == py && nx == px)){
			low[sy][sx] = min(low[sy][sx], ord[ny][nx]);
		}
	}
	
	fin[sy][sx] = k++;
}

bool isDecendant(int uy, int ux, int vy, int vx)
{
	return (ord[vy][vx] <= ord[uy][ux] && fin[uy][ux] <= fin[vy][vx]);
}

pair<int, int> getChild(int uy, int ux, int vy, int vx)
{
	for (int i = 0; i < chi[uy][ux].size(); i++){
		if (isDecendant(vy, vx, chi[uy][ux][i].first, chi[uy][ux][i].second)) return (chi[uy][ux][i]);
	}
	return (make_pair(-1, -1));
}

bool isConnected(int ny, int nx, int ny2, int nx2, int vy, int vx)
{
	if (!isDecendant(ny, nx, vy, vx) && !isDecendant(ny2, nx2, vy, vx))
		return (true);
	else if (isDecendant(ny, nx, vy, vx) && isDecendant(ny2, nx2, vy, vx)){
		pair<int, int> u1 = getChild(vy, vx, ny, nx), u2 = getChild(vy, vx, ny2, nx2);
		
		if (u1 == u2 ||
			(low[u1.first][u1.second] < ord[vy][vx] && low[u2.first][u2.second] < ord[vy][vx])){
			return (true);
		}
	}
	else {
		pair<int, int> u;
		if (isDecendant(ny, nx, vy, vx)) u = getChild(vy, vx, ny, nx);
		else u = getChild(vy, vx, ny2, nx2);
		
		if (low[u.first][u.second] < ord[vy][vx]) return (true);
	}
	return (false);
}

void dfs(int vy, int vx)
{
	vis[vy][vx] = true;
	if (!valid(vy, vx)) return;
	for (int i = 0; i < 4; i++) type[vy][vx][i] = i;
	
	for (int i = 0; i < 4; i++){
		int ny = vy + dy[i], nx = vx + dx[i];
		for (int j = i + 1; j < 4; j++){
			int ny2 = vy + dy[j], nx2 = vx + dx[j];
			if (!valid(ny, nx) || !valid(ny2, nx2)) continue;
			if (isConnected(ny, nx, ny2, nx2, vy, vx)){
				type[vy][vx][j] = min(type[vy][vx][j], type[vy][vx][i]);
				type[vy][vx][i] = min(type[vy][vx][j], type[vy][vx][i]);
			}
		}
	}
	
	for (int i = 0; i < 4; i++){
		int ny = vy + dy[i], nx = vx + dx[i];
		if (valid(ny, nx) && !vis[ny][nx]) dfs(ny, nx);
	}
}

int main()
{
	
	scanf("%d %d", &M, &N);
	
	int sy, sx;
	for (int i = 0; i < M; i++){
		scanf("%s", map[i]);
		for (int j = 0; j < N; j++){
			if (map[i][j] == 'X'){
				sy = i; sx = j;
			}
		}
	}
	
	memset(vis, 0, sizeof(vis));
	reachSize = countReach(sy, sx) - 1;
	
	memset(vis, 0, sizeof(vis));
	fill(sy, sx, -1);
	
	memset(vis, 0, sizeof(vis));
	int k = 0;
	predfs(sy, sx, -1, -1, k);
	
	memset(vis, 0, sizeof(vis));
	dfs(sy, sx);
	
	lint ans = 0;
	queue<Event> q;
	for (int i = 0; i < 4; i++){
		done[sy][sx][i] = true;
		if (valid(sy + dy[i], sx + dx[i])){
			q.push(Event(sy + dy[i], sx + dx[i], i));
		}
	}
	
	for (; q.size(); q.pop()){
		Event x = q.front();
		
		if (done[x.ty][x.tx][type[x.ty][x.tx][x.dir]]) continue;
		
		int ny = x.ty + dy[x.dir], nx = x.tx + dx[x.dir];
		if (valid(ny, nx)){
			for (int i = 0; i < 4; i++){
				if (type[x.ty][x.tx][i] == type[x.ty][x.tx][x.dir]){
					ans += size[x.ty][x.tx][i];
					t[x.ty][x.tx] += size[x.ty][x.tx][i];
				}
			}
			
			for (int i = 0; i < 4; i++){
				if (type[x.ty][x.tx][i] == type[x.ty][x.tx][x.dir])
					if (valid(x.ty + dy[i], x.tx + dx[i])) q.push(Event(x.ty + dy[i], x.tx + dx[i], i));
			}
			done[x.ty][x.tx][type[x.ty][x.tx][x.dir]] = true;
		}
		
	}
	
	printf("%lld\n", ans);
	
	return (0);
}