Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

Codeforces 472D - Design Tutorial: Inverse the Problem

概要

$n$ 頂点の有向グラフの隣接行列 $A$ が与えられる. この行列が各辺に重みがついた無向グラフであり, 木を表していれば YES を, そうでなければ NO を出力せよ.

制約

$n \leqq 2000$

解法

隣接行列にポカミス($A_{ij} \neq A_{ij},\ A_{ii} \neq 0, A_{ij} = 0$)があれば木でないという判定がすぐに出来る (ただし最後の $A_{ij} = 0$ では $i \neq j$ とする).

グラフが木であるということから次の 3 つの事が考察できる.

  • このグラフが木である、ということを考えるとこのグラフは $n - 1$ 辺から成るはずである. つまり隣接行列のうちの $2(n - 1)$ 要素は辺のコストがそのまま要素になっているはずである(それぞれの要素は正でないといけないので負閉路は生じないから).
  • 直接繋がっていない頂点どうしは, 繋がっている辺をたどった値が距離になるはずである. すなわち, 直接繋がっていない頂点どうしの距離はある辺のコスト以上になる.
  • 最も短い距離をもつ 2 点はその辺が木の中に採用されなければならない.

これらをもとに考えると, 隣接行列から MST を作り, 実際に全点間の距離を求めれば良いということが分かる.

全点間の距離は LCA で求めて $O(n^2 \log n)$ で求まるが, 実装が長くなるので, 今回は木という性質を活かして単に全点で DFS をすれば $O(n^2)$ で点対の距離がすべて求まる.

全点で DFS をすればこの問題は $O(n^2)$ で解ける. (後半で LCA しちゃったけど...)

#include <bits/stdc++.h>

using namespace std;

int p[2048];

void init(int n)
{
    memset(p, -1, sizeof(p));
}

int find(int x)
{
    return (p[x] < 0 ? x : p[x] = find(p[x]));
}

bool same(int x, int y)
{
    return (find(x) == find(y));
}

void merge(int x, int y)
{
    x = find(x);
    y = find(y);
    if (x == y) return;
    
    if (-p[x] < -p[y]) swap(x, y);
    p[x] += p[y];
    p[y] = x;
}

vector<pair<int, int> > G[2048];
int par[13][2048];
int dep[2048];
long long dist[2048];

void dfs(int v, int p, int d, long long dis)
{
    par[0][v] = p;
    dep[v] = d;
    dist[v] = dis;
    
    for (int i = 0; i < G[v].size(); i++){
        if (G[v][i].first != p){
            dfs(G[v][i].first, v, d + 1, dis + G[v][i].second);
        }
    }
}

int lca(int u, int v)
{
    if (dep[u] < dep[v]) swap(u, v);
    for (int i = 12; i >= 0; i--){
        if ((dep[u] - dep[v]) >> i & 1){
            u = par[i][u];
        }
    }
    
    if (u == v) return (u);
    
    for (int i = 12; i >= 0; i--){
        if (par[i][u] != par[i][v]){
            u = par[i][u]; v = par[i][v];
        }
    }
    
    return (par[0][u]);
}

int main()
{
    static int a[2048][2048];
    
    int n;
    
    scanf("%d", &n);
    
    for (int i = 0; i < n; i++){
        for (int j = 0; j < n; j++){
            scanf("%d", &a[i][j]);
        }
    }
    
    vector<pair<int, pair<int, int> > > es;
    
    for (int i = 0; i < n; i++){
        for (int j = i + 1; j < n; j++){
            if (a[i][j] != a[j][i] || a[i][j] == 0) return (!printf("NO\n"));
            es.push_back(make_pair(a[i][j], make_pair(i, j)));
        }
        if (a[i][i] != 0) return (!printf("NO\n"));
    }
    
    init(n);
    sort(es.begin(), es.end());
    
    for (int i = 0; i < es.size(); i++){
        if (!same(es[i].second.first, es[i].second.second)){
            merge(es[i].second.first, es[i].second.second);
            G[es[i].second.first].push_back(make_pair(es[i].second.second, es[i].first));
            G[es[i].second.second].push_back(make_pair(es[i].second.first, es[i].first));
        }
    }
    
    memset(par, -1, sizeof(par));
    dfs(0, -1, 0, 0);
    
    for (int i = 0; i + 1 < 13; i++){
        for (int j = 0; j < n; j++){
            if (~par[i][j]){
                par[i + 1][j] = par[i][par[i][j]];
            }
        }
    }
    
    for (int i = 0; i < n; i++){
        for (int j = i + 1; j < n; j++){
            if (dist[i] + dist[j] - 2 * dist[lca(i, j)] != a[i][j]){
                return (!printf("NO\n"));
            }
        }
    }
    
    printf("YES\n");
    
    return (0);
}