Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

SPOJ 913 - Query on a tree II

問題文 : Query on a tree II

解法 :

複数テストケースなので初期化する位置に気をつけないと死んでしまう問題 (ぼく死んでしまいました).

距離を求めるところはeuler-tourっぽくやるのも, doubling でやるのもよしです.
k 番目の頂点は, 根からの深さとlca(u, v)の深さの差でうまく求めることができます.

O(T * (N + Q) log N) 時間ほどで動作するプログラムが出来ます.

コード :

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>

#define MAX_N (10000)
#define LOG_N (14)

using namespace std;

struct Edge {
	int to;
	int cost;
	Edge(int to, int cost) : to(to), cost(cost) {}
	Edge(){}
};

int T, N;
vector<Edge> G[MAX_N];
int par[LOG_N][MAX_N];
int depth[MAX_N];
int id[MAX_N];
int bit[MAX_N * 2];

void add(int pos, int val)
{
	while (pos < MAX_N * 2){
		bit[pos] += val;
		pos += pos & -pos;
	}
}

int sum(int pos)
{
	int ret = 0;
	while (pos){
		ret += bit[pos];
		pos &= (pos - 1);
	}
	return (ret);
}

void dfs(int v, int p, int d, int &k)
{
	depth[v] = d;
	par[0][v] = p;
	id[v] = k++;
	
	for (int i = 0; i < G[v].size(); i++){
		Edge e = G[v][i];
		if (e.to != p){
			add(k, e.cost);
			
			dfs(e.to, v, d + 1, k);
			
			add(++k, -e.cost);
		}
	}
}

void init()
{
	int k = 0;
	memset(par, -1, sizeof(par));
	memset(bit, 0, sizeof(bit));
	dfs(0, -1, 0, k);
	
	for (int i = 1; i < LOG_N; i++){
		for (int j = 0; j < N; j++){
			if (~par[i - 1][j]) par[i][j] = par[i - 1][par[i - 1][j]];
		}
	}
}

int getLCA(int u, int v)
{
	if (depth[u] > depth[v]) swap(u, v);
	
	for (int i = LOG_N - 1; i >= 0; i--){
		if ((depth[v] - depth[u]) >> i & 1){
			v = par[i][v];
		}
	}
	
	if (u == v) return (u);
	
	for (int i = LOG_N - 1; i >= 0; i--){
		if (par[i][u] != par[i][v]){
			u = par[i][u]; v = par[i][v];
		}
	}
	
	return (par[0][v]);
}

int getParent(int dst, int v)
{
	for (int i = LOG_N - 1; i >= 0; i--){
		if (dst >> i & 1) v = par[i][v];
	}
	
	return (v);
}

int main()
{
	
	scanf("%d", &T);
	
	int x = 0;
	while (T--){
		scanf("%d", &N);
		
		for (int i = 0; i < N; i++) G[i].clear();
		
		for (int i = 0; i < N - 1; i++){
			int a, b, c;
			scanf("%d %d %d", &a, &b, &c);
			--a; --b;
			G[a].push_back(Edge(b, c));
			G[b].push_back(Edge(a, c));
		}
		
		init();
		
		if (x++) printf("\n");
		char query[128];
		while (scanf("%s", query) && query[1] != 'O'){
			int a, b, c;
			if (query[0] == 'D'){
				scanf("%d %d", &a, &b);
				int p = getLCA(--a, --b);
				printf("%d\n", sum(id[a]) + sum(id[b]) - 2 * sum(id[p]));
			}
			else {
				scanf("%d %d %d", &a, &b, &c);
				c--;
				int p = getLCA(--a, --b);
				int dA = depth[a] - depth[p], dB = depth[b] - depth[p];
				int ret;
				if (dA >= c) ret = getParent(c, a);
				else ret = getParent(dB - (c - dA), b);
				
				printf("%d\n", ret + 1);
			}
		}
	}
	
	return (0);
}