Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

PKU 2763 : Housewife Wind

問題文 : Housewife Wind

解法 :
頂点u, vの最小共通祖先をpとすると, cost(u, v) = cost(u, p) + cost(v, p) = cost(root, u) + cost(root, v) - 2 * cost(root, p)となる.
rootからある頂点への距離は, 通った辺を相殺する形で, BITで管理することでO(log N)で求めることが出来る.
この処理は, セグメント木を用いて最小共通祖先を求めれば, その初期化に付随してできるので簡潔に行える.

全体の計算量は, 初期化にO(n log n), 各クエリにO(log n)で応答するので, O ((q + n) log n)となり, 高速に答えを求めることが出来る.


コード :

#include <cstdio>
#include <cstring>
#include <vector>

#define MAX_N (111111)

using namespace std;

struct Edge {
	int id, to, cost;
	Edge(int id, int to, int cost) : id(id), to(to), cost(cost){}
	Edge(){}
};

struct Node {
	int pos, val;
	Node(int pos, int val) : pos(pos), val(val){}
	Node(){}
};

bool operator < (const Node &a, const Node &b)
{
	return (a.val < b.val);
}

vector<Edge> G[MAX_N];
int sz; //セグメント木の大きさ
int id[MAX_N];
int vs[MAX_N * 2 - 1], depth[MAX_N * 2 - 1];
int es[2 * (MAX_N - 1)];
Node seg[1 << 19];

int cost[MAX_N];
int bit[2 * MAX_N];

void add(int k, int x)
{
	while (k <= 2 * MAX_N - 1){
		bit[k] += x;
		k += k & -k;
	}
}

int sum(int k)
{
	int ret = 0;
	while (k){
		ret += bit[k];
		k &= (k - 1);
	}
	
	return (ret);
}

void dfs(int v, int p, int d, int &k)
{
	id[v] = k;
	vs[k]= v;
	depth[k++] = d;
	
	for (int i = 0; i < G[v].size(); i++){
		if (G[v][i].to != p){
			add(k, G[v][i].cost);
			es[G[v][i].id * 2] = k;
			dfs(G[v][i].to, v, d + 1, k);
			vs[k] = v;
			depth[k++] = d;
			add(k, -G[v][i].cost);
			es[G[v][i].id * 2 + 1] = k;
		}
	}
}

Node null = Node(-1, 999999);

void _init(int _n)
{
	sz = 1;
	while (sz < _n) sz *= 2;
	
	for (int i = 0; i < 1 << 19; i++) seg[i] = null;
	
	for (int i = 0; i < _n; i++){
		if (~vs[i]){
			int k = i + sz - 1;
			seg[k].pos = i;
			seg[k].val = depth[i];
			while (k){
				k = (k - 1) / 2;
				seg[k] = min(seg[k * 2 + 1], seg[k * 2 + 2]);
			}
		}
	}
}

void init(int v)
{
	int k = 0;
	dfs(0, -1, 0, k);
	
	_init(v * 2 - 1);
}

Node getMin(int a, int b, int k = 0, int l = 0, int r = sz)
{
	if (r <= a || b <= l){
		return (null);
	}
	
	if (a <= l && r <= b){
		return (seg[k]);
	}
	
	Node left = getMin(a, b, k * 2 + 1, l, (l + r) / 2);
	Node right = getMin(a, b, k * 2 + 2, (l + r) / 2, r);
	
	return (min(left, right)); 
}

int lca(int u, int v)
{
	int idu = id[u], idv = id[v];
	if (idu > idv) swap(idu, idv);
	
	return (vs[getMin(idu, idv + 1).pos]);
}

int main()
{
	int n, q, s;
	scanf("%d %d %d", &n, &q, &s);
	
	s--;
	
	for (int i = 0; i < n - 1; i++){
		int a, b, c;
		scanf("%d %d %d", &a, &b, &c);
		a--; b--;
		G[a].push_back(Edge(i, b, c));
		G[b].push_back(Edge(i, a, c));
		cost[i] = c;
	}
	
	init(n);
	
	for (int i = 0; i < q; i++){
		int qtype;
		scanf("%d", &qtype);
		
		if (qtype == 0){
			int next;
			scanf("%d", &next);
			int p = lca(s, next - 1);
			printf("%d\n", sum(id[s]) + sum(id[next - 1]) - sum(id[p]) * 2);
			s = next - 1;
		}
		else {
			int nt, nc;
			scanf("%d %d", &nt, &nc);
			int x = nt - 1;
			add(es[x * 2], nc - cost[x]);
			add(es[x * 2 + 1], cost[x] - nc);
			cost[x] = nc;
		}
	}
	
	return (0);
}