Lilliput Steps

小さな一歩から着実に. 数学やプログラミングのことを書きます.

JOI春合宿 2011-day1 joitter

問題文 : ジョイッター

解法 :
(1) 1の人がいる場合
(2) 1の人がいなくて, 2, 3の人がいる場合
(3) 3の人だけの場合

で場合分けをする.

(1)
1の人がいる時は, その人と他の人を全員友達にしてあげれば, 最小の辺数を達成できる.

(2)
N人の人がいれば, ある人と皆が友人であればN-1回の友だち登録で皆が日記を読める. これをN 人全員について試す.
ただし, 2 の人が二人の場合, その2 人が友達であれば, その2 人とのコストのうち小さい方の人と友だちになった時が最適な解となる可能性もあるので, それもチェックする.

(3)
3 の人だけの場合, これはただの最小全域木を求める問題となる. 隣接行列が与えられているので, 久しぶりにprim で書いてみた.

全体で, O(N^2) 時間で計算が終了する.

コード :

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long lint;

int cost[1024][1024], mincost[1024];
bool done[1024][1024], used[1024];
int n;
int type[1024];
int num[3] = {0};

int prim(void)
{
    int i;
    int res;
    int u, v;
    
    for (i = 0; i < n; i++){
        mincost[i] = 1000000000;
        used[i] = 0;
    }
    
    mincost[0] = 0;
    res = 0;
    
    while (1){
        v = -1;
        for (u = 0; u < n; u++){
            if (!used[u] && (v == -1 || mincost[u] < mincost[v])){
                v = u;
            }
        }
        
        if (v == -1){
            break;
        }
        used[v] = 1;
        res += mincost[v];
        
        for (u = 0; u < n; u++){
            mincost[u] = min(mincost[u], cost[v][u]);
        }
    }
    return (res);
}

int main()
{
	scanf("%d", &n);
	
	for (int i = 0; i < n; i++){
		scanf("%d", &type[i]);
		num[--type[i]]++;
	}
	
	for (int i = 0; i < n; i++){
		for (int j = 0; j < n; j++){
			scanf("%d", &cost[i][j]);
		}
	}
	
	int ret = 0, tot = 0;
	if (num[0]){
		for (int i = 0; i < n; i++){
			if (type[i] == 0){
				for (int j = 0; j < n; j++){
					if (i != j && !done[i][j]){
						done[i][j] = done[j][i] = 1;
						ret++;
						tot += cost[i][j];
					}
				}
			}
		}
		
		printf("%d %d\n", ret, tot);
	}
	else if (num[1]){
		int ans = -1;
		
		for (int i = 0; i < n; i++){
			int temp = 0;
			for (int j = 0; j < n; j++) temp += cost[i][j];
			if (ans == -1 || ans > temp) ans = temp;
		}
		
		if (num[1] == 2){
			int sum = 0;
			for (int i = 0; i < n; i++){
				int tempVal = 1024;
				if (type[i] == 2){
					for (int j = 0; j < n; j++) if (type[j] == 1) tempVal = min(tempVal, cost[i][j]);
				}
				else for (int j = 0; j < n; j++) if (i > j && type[j] == 1) tempVal = cost[i][j];
				sum += (tempVal != 1024 ? tempVal : 0);
			}
			ans = min(ans, sum);
		}
		
		printf("%d %d\n", n - 1, ans);
	}
	else {
		printf("%d %d\n", n - 1, prim());
	}
	return (0);
}