読者です 読者をやめる 読者になる 読者になる

Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

Codeforces 342E - Xenia and Tree

問題文 : Xenia and Tree

概要 : 大きさ$n$ の木が与えられる. 最初頂点0 は赤色で, その他の頂点は青色である. 次の$m$ 個のクエリに答えよ :

・頂点$v$ の色を赤色に変える.
・頂点$v$ から赤色のノードまでの最短距離を求める.

$1 \leqq n,\ m \leqq 10^5$


解法 :

クエリを平方分割する. $\sqrt{m}$ 個のクエリごとにそのクエリブロックの赤色のノードから$O(n+m)$ 時間でbfs を行い, 赤色のノードからの最短距離を更新する作業を行えば, 最短距離クエリに対しては$O(\sqrt{m})$ 個の赤ノードからの距離だけを比較すれば良いことになるから, クエリを$O(\sqrt{m} \log n)$ 時間, 各段階での処理の合計を$O((n+m)\sqrt{m})$ 時間で行うことができ, $O((n+m)\sqrt{m} \log n)$ 時間でこの問題を解くことが出来る. 定数が軽いのでこれで通る.

bfs では同じノードを二回通らないように気をつける. また, 最短距離はLCA で求めれば良い.

コード :

#include <cstdio>
#include <cstring>
#include <queue>
#include <vector>
#include <algorithm>

using namespace std;

vector<int> G[100000];
int depth[100000], par[17][100000], d[100000];

void dfs(int v, int p, int dep)
{
	par[0][v] = p;
	depth[v] = d[v] = dep;
	for (int i = 0; i < G[v].size(); i++){
		if (G[v][i] != p){
			dfs(G[v][i], v, dep + 1);
		}
	}
}

void init(int n)
{
	memset(par, -1, sizeof(par));
	dfs(0, -1, 0);
	
	for (int i = 0; i + 1 < 17; i++){
		for (int j = 0; j < n; j++){
			if (~par[i][j]) par[i + 1][j] = par[i][par[i][j]];
		}
	}
}

int lca(int u, int v)
{
	if (depth[u] < depth[v]) swap(u, v);
	
	for (int i = 0; i < 17; i++){
		if ((depth[u] - depth[v]) >> i & 1){
			u = par[i][u];
		}
	}
	
	if (u == v) return (u);
	
	for (int i = 16; i >= 0; i--){
		if (par[i][u] != par[i][v]){
			u = par[i][u]; v = par[i][v];
		}
	}
	
	return (par[0][u]);
}

void bfs(vector<int> v)
{
	bool vis[100000];
	fill(vis, vis + 100000, 0);
	queue<pair<int, int> > q;
	
	for (int i = 0; i < v.size(); i++){
		q.push(make_pair(v[i], 0));
		vis[v[i]] = true;
	}
	
	for (; q.size(); q.pop()){
		pair<int, int> x = q.front();
		
		if (x.second < d[x.first]) d[x.first] = x.second;
		
		for (int i = 0; i < G[x.first].size(); i++){
			if (!vis[G[x.first][i]]){
				vis[G[x.first][i]] = true;
				q.push(make_pair(G[x.first][i], x.second + 1));
			}
		}
	}
}

int main()
{
	int n, m;
	
	scanf("%d %d", &n, &m);
	
	for (int i = 0; i < n - 1; i++){
		int a, b;
		scanf("%d %d", &a, &b);
		G[--a].push_back(--b);
		G[b].push_back(a);
	}
	
	init(n);
	
	vector<int> block;
	block.push_back(0);
	
	for (int i = 0; i < m; i++){
		if (i % 317 == 0){
			bfs(block);
			block.clear();
		}
		int q, v;
		scanf("%d %d", &q, &v);
		
		if (q == 1){
			block.push_back(--v);
		}
		else {
			int ret = d[--v];
			for (int j = 0; j < block.size(); j++){
				ret = min(ret, depth[v] + depth[block[j]] - 2 * depth[lca(v, block[j])]);
			}
			printf("%d\n", ret);
		}
	}
	
	return (0);
}