Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

AOJ 1341: Longest Chain

問題文 : Longest Chain | Aizu Online Judge

概要

 \mathbb{Z}^3 に半順序関係  \prec

 (x_1, y_1, z_1) \prec (x_2, y_2, z_2) \overset{\triangle{}}{\Leftrightarrow} x_1 < x_2 かつ  y_1 < y_2 かつ  z_1 < z_2

で導入する。 n 個の  \mathbb{Z}^3 の点の中で、上記で定めた半順序を満たす最長の列の長さを求めよ。

  • 制約
    •  1 \leqq n \leqq 3 \times 10^5

解法

入力が  \mathbb{Z}^2 の点で、同じような半順序を入れた問題であった場合、 x 座標で昇順ソートをしておけば、あとは有名な LIS DP を行えば  O(n \log n) で問題が解ける。この際に  x 座標が同じ点については  y 座標が大きいものが先にくるようにソートすると特別な処理をする必要がなくなる。

では入力が  \mathbb{Z}^3 の点である場合はどうだろうか。このときも x 座標で昇順に点をソートしておけば、 y,\ z の 2 つの座標による 2 次元の最長増加部分列問題に帰着される。素朴な DP で解くと  O(n^2) 時間かかってしまい間に合わない。

そこで、1 次元の LIS DP を 2 分探索で高速化した時のように、2 分探索を使って dp をすることにする。
1 次元の LIS DP では  dp[k] を長さ k の増加列の末尾に来る値の最小値という風に定めていたが、今回は 1 次元の場合と異なりすべての点の間に順序を定めることが出来ないため、最小値が一意に定まらない。よって、 dp[k] を長さ k の増加列の末尾に来る点集合とする。この点集合を適切に更新していく必要がある。更新には std::set を用いると楽である。時間計算量は  O(n \log^2 n) となる。

コード

#include <bits/stdc++.h>

using namespace std;

int a, b, C = ~(1<<31), M = (1<<16)-1;
int r()
{
    a = 36969 * (a & M) + (a >> 16);
    b = 18000 * (b & M) + (b >> 16);
      return (C & ((a << 16) + b)) % 1000000;
}

struct Point {
    int x, y, z;
    Point(int x, int y, int z) : x(x), y(y), z(z) {}
    Point(){}
    bool operator < (const Point &a) const {
        return (x == a.x ? y == a.y ? z > a.z : y > a.y : x < a.x);
    }
};

struct Point2 {
    int y, z;
    Point2(int y, int z) : y(y), z(z) {}
    Point2(){}
    bool operator < (const Point2 &a) const {
        return (y < a.y);
    }
};

int main()
{
    int m, n, A, B;

    while (scanf("%d %d %d %d", &m, &n, &A, &B) && m + n){
        a = A; b = B;
        vector<Point> v(m + n);
        vector<set<Point2>> s(m + n);
        for (int i = 0; i < m; i++){
            int x, y, z;
            scanf("%d %d %d", &x, &y, &z);
            v[i] = Point(x, y, z);
        }
        for (int i = 0; i < n; i++){
            int x = r();
            int y = r();
            int z = r();
            v[i + m] = Point(x, y, z);
        }
        sort(v.begin(), v.end());

        int ans = 1;
        s[0].insert(Point2(v[0].y, v[0].z));
        for (int i = 1; i < m + n; i++){
            int lf = -1, rg = ans - 1;
            while (lf != rg){
                int mid = (lf + rg + 1) / 2;
                auto it = s[mid].lower_bound(Point2(v[i].y, v[i].z));
                if (it != s[mid].begin() && (--it)->z < v[i].z) lf = mid;
                else rg = mid - 1;
            }
            rg++;
            ans = max(ans, rg + 1);

            auto it = s[rg].lower_bound(Point2(v[i].y, v[i].z));
            while (it != s[rg].end() && it->z >= v[i].z) it = s[rg].erase(it);

            if (!s[rg].size()) s[rg].insert(Point2(v[i].y, v[i].z));
            else {
                auto it = s[rg].lower_bound(Point2(v[i].y, v[i].z));
                if (it == s[rg].begin() || (--it)->z != v[i].z) s[rg].insert(Point2(v[i].y, v[i].z));
            }
        }
        printf("%d\n", ans);
    }

    return (0);
}