SPOJ 913 - Query on a tree II
問題文 : Query on a tree II
解法 :
複数テストケースなので初期化する位置に気をつけないと死んでしまう問題 (ぼく死んでしまいました).
距離を求めるところはeuler-tourっぽくやるのも, doubling でやるのもよしです.
k 番目の頂点は, 根からの深さとlca(u, v)の深さの差でうまく求めることができます.
O(T * (N + Q) log N) 時間ほどで動作するプログラムが出来ます.
コード :
#include <cstdio> #include <cstring> #include <algorithm> #include <vector> #define MAX_N (10000) #define LOG_N (14) using namespace std; struct Edge { int to; int cost; Edge(int to, int cost) : to(to), cost(cost) {} Edge(){} }; int T, N; vector<Edge> G[MAX_N]; int par[LOG_N][MAX_N]; int depth[MAX_N]; int id[MAX_N]; int bit[MAX_N * 2]; void add(int pos, int val) { while (pos < MAX_N * 2){ bit[pos] += val; pos += pos & -pos; } } int sum(int pos) { int ret = 0; while (pos){ ret += bit[pos]; pos &= (pos - 1); } return (ret); } void dfs(int v, int p, int d, int &k) { depth[v] = d; par[0][v] = p; id[v] = k++; for (int i = 0; i < G[v].size(); i++){ Edge e = G[v][i]; if (e.to != p){ add(k, e.cost); dfs(e.to, v, d + 1, k); add(++k, -e.cost); } } } void init() { int k = 0; memset(par, -1, sizeof(par)); memset(bit, 0, sizeof(bit)); dfs(0, -1, 0, k); for (int i = 1; i < LOG_N; i++){ for (int j = 0; j < N; j++){ if (~par[i - 1][j]) par[i][j] = par[i - 1][par[i - 1][j]]; } } } int getLCA(int u, int v) { if (depth[u] > depth[v]) swap(u, v); for (int i = LOG_N - 1; i >= 0; i--){ if ((depth[v] - depth[u]) >> i & 1){ v = par[i][v]; } } if (u == v) return (u); for (int i = LOG_N - 1; i >= 0; i--){ if (par[i][u] != par[i][v]){ u = par[i][u]; v = par[i][v]; } } return (par[0][v]); } int getParent(int dst, int v) { for (int i = LOG_N - 1; i >= 0; i--){ if (dst >> i & 1) v = par[i][v]; } return (v); } int main() { scanf("%d", &T); int x = 0; while (T--){ scanf("%d", &N); for (int i = 0; i < N; i++) G[i].clear(); for (int i = 0; i < N - 1; i++){ int a, b, c; scanf("%d %d %d", &a, &b, &c); --a; --b; G[a].push_back(Edge(b, c)); G[b].push_back(Edge(a, c)); } init(); if (x++) printf("\n"); char query[128]; while (scanf("%s", query) && query[1] != 'O'){ int a, b, c; if (query[0] == 'D'){ scanf("%d %d", &a, &b); int p = getLCA(--a, --b); printf("%d\n", sum(id[a]) + sum(id[b]) - 2 * sum(id[p])); } else { scanf("%d %d %d", &a, &b, &c); c--; int p = getLCA(--a, --b); int dA = depth[a] - depth[p], dB = depth[b] - depth[p]; int ret; if (dA >= c) ret = getParent(c, a); else ret = getParent(dB - (c - dA), b); printf("%d\n", ret + 1); } } } return (0); }