読者です 読者をやめる 読者になる 読者になる

Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

UnKoder Challenges - XOR Graph

すぎむさんが問題を書いている UnKoder の問題をちょっとずつ解いています。
解いていて面白かった問題を紹介しようと思います。 UnKoder の問題には, このリンクから挑戦することができます。

今回はこの問題セットの中から XOR Graph の解法を紹介しようと思います。

問題概要

$N$ 個の頂点と $ M $ 本の辺からなる無向グラフが与えられる。頂点は $0$ から $N−1$ まで非負整数の番号が振られている。グラフは自己ループや多重辺を含まない。
あなたは, このグラフに何本かの辺を追加し, グラフ全体を連結にしたい。
$x$ 番目の頂点と $y$ 番目の頂点の間に辺を追加するには, $x \bigoplus y$ だけのコストが掛かる。ただし, $\bigoplus$ はビット XOR の記号である。
グラフ全体を連結にするのに必要な総コストの最小値を求めよ。

制約
  • $1 \leqq N \leqq 10^5$
  • $0 \leqq M \leqq 10^5$
  • グラフは自己ループや多重辺を含まない
解法

この問題は, 頂点 $i$, $j$ 間のコスト $c_{ij}$ を以下のように定義すると, 最小全域木を作る問題に帰着できる。

$$
c_{ij} = \left\{
\begin{array}{l}
i \bigoplus j &\text{iff. edge}\ (i,\ j) \not\in E\\
0 &\text{iff. edge}\ (i,\ j) \in E
\end{array}
\right.
$$

ここで, $E$ はグラフの辺を表す集合である。

ここで, 全ての辺を列挙して最小全域木を構築すると, 辺の数が $O(N^2)$ 本になってしまうため, 時間計算量が $O(N^2 \log N)$ となってしまい, 時間制限を超過してしまう。

そこで, 頂点間のコスト $c_{ij}$ について考察をする。 $c_{ij} = 0$ となる辺はたかだか $ M $ 本しか無いので, すべて最小全域木の構築に使用して構わない。
$c_{ij} \neq 0$ のとき, できるだけ $c_{ij}$ が小さい辺を用いて最小全域木を構築したい。そこで, $c_{ij}$ の値が小さい順に, どういうふうな辺の結びつきができているかを調べてみる。

まずは, $c_{ij}$ が $1$ であるような辺はどういう辺か考えてみる。
$c_{ij} = i \bigoplus j = 1$ ということは, $i$ と $j$ の下位 $1$ bit のみが異なるということである。すなわち,

0--1 2--3 4--5 6--7 8--9 ...

という風に, $2i$ と $2i+1$ 間の辺のコストが $1$ になっているということである。

$c_{ij} = 2$ であるような辺は

0--2 1--3 4--6 5--7 8--10 9--11 ...

という辺である。 これと $c_{ij} = 1$ である辺を組み合わせると,

0--1--2--3 4--5--6--7

という風に, 4 つの頂点ごとにコストが最小の木ができていることが分かる。

次に, $c_{ij} = 3$ であるような辺について考えてみると

0--3 1--2 4--7 5--6 ...

という辺であることがわかるが, これは $c_{ij} \leqq 2$ の辺の集合で連結になっている頂点の組み合わせである。
つまり, 下位 2 bit のみが変化する頂点の繋がりは, $c_{ij} = 1$ または $c_{ij} = 2$ の頂点の組み合わせで網羅することが可能であるということである。

同様な考察を続けていくと, 使うべき辺は $c_{ij} = 2^n$ となるような辺だけであることがわかる。このような辺は $O(n \log n)$ 本のみ存在するため, 計算量は $O((m + n \log n) \log n)$ 時間でこの問題を解くことが可能である。

コード
#include <bits/stdc++.h>

#define MAX_N (100000)

using namespace std;

struct Edge {
    int cost;
    int from, to;
    Edge(){}
    Edge(int cost, int from, int to) : cost(cost), from(from), to(to){}
    bool operator < (const Edge &a) const {
        return (cost < a.cost);
    }
};

int par[MAX_N];

void init()
{
    memset(par, -1, sizeof(par));
}

int find(int x)
{
    return (par[x] < 0 ? x : find(par[x]));
}

void merge(int x, int y)
{
    x = find(x); y = find(y);
    if (x == y) return;
    if (-par[x] < -par[y]) swap(x, y);
    par[x] += par[y];
    par[y] = x;
}

bool same(int x, int y)
{
    return (find(x) == find(y));    
}

int main()
{
    int n, m;
    
    scanf("%d %d", &n, &m);
    
    vector<Edge> e;
    
    for (int i = 0; i < m; i++){
        int x, y;
        scanf("%d %d", &x, &y);
        e.push_back(Edge(0, x, y));
    }
    
    for (int i = 0; i < n; i++){
        for (int j = 0; j <= 17; j++){
            if ((i ^ (1 << j)) < n){
                e.push_back(Edge((1 << j), i, (i ^ (1 << j))));
            }
        }
    }
    
    sort(e.begin(), e.end());
    
    int ans = 0;
    
    init();
    for (int i = 0; i < e.size(); i++){
        if (!same(e[i].from, e[i].to)){
            merge(e[i].from, e[i].to);
            ans += e[i].cost;
        }
    }
    
    printf("%d\n", ans);
    
    return (0);
}
追記

zerokugi さんと tokoharu さんの指摘により, 使用する辺を更に節約できる事が分かりました。
イメージとしては, $2^n$ 個ずつの頂点の集まりの, 隣接する2つの集まりを順に接続していく感じです。
使用する辺の本数は $O(n)$ 本に抑えられるため, 計算量は実装の仕方次第で $O( (n+m) \log (n + m) )$ だったり $O(M \alpha(M) + n \log n)$ になったりします。 ただし $\alpha(n)$ はアッカーマン関数逆関数です。

#include <bits/stdc++.h>

#define MAX_N (100000)

using namespace std;

struct Edge {
    int cost;
    int from, to;
    Edge(){}
    Edge(int cost, int from, int to) : cost(cost), from(from), to(to){}
    bool operator < (const Edge &a) const {
        return (cost < a.cost);
    }
};

int par[MAX_N];

void init()
{
    memset(par, -1, sizeof(par));
}

int find(int x)
{
    return (par[x] < 0 ? x : find(par[x]));
}

void merge(int x, int y)
{
    x = find(x); y = find(y);
    if (x == y) return;
    if (-par[x] < -par[y]) swap(x, y);
    par[x] += par[y];
    par[y] = x;
}

bool same(int x, int y)
{
    return (find(x) == find(y));    
}

int main()
{
    int n, m;
    
    scanf("%d %d", &n, &m);
    
    vector<Edge> e;
    
    for (int i = 0; i < m; i++){
        int x, y;
        scanf("%d %d", &x, &y);
        e.push_back(Edge(0, x, y));
    }
    
    for (int i = 1; i < n; i *= 2){
        for (int j = 0; j + i < n; j += i * 2){
            e.push_back(Edge(i, j, j + i));
        }
    }
    
    sort(e.begin(), e.end());
    
    int ans = 0;
    
    init();
    for (int i = 0; i < e.size(); i++){
        if (!same(e[i].from, e[i].to)){
            merge(e[i].from, e[i].to);
            ans += e[i].cost;
        }
    }
    
    printf("%d\n", ans);
    
    return (0);
}