https://leetcode.cn/problems/minimize-malware-spread/description/?envType=daily-question&envId=2024-04-11
思路很好想的一道题。连通的节点中出现恶意节点,那么里面的所有节点都会被感染。题目要求去除一个最初被感染的节点后,整个网络中被感染节点数最少,那么对于每个最初被感染的节点就有两种情况:
- 节点所在的连通分量中只有该初始节点被感染;
- 节点所在的联通分量中还有其他被感染的节点;
因为只考虑被感染的节点,所以没有列出没有节点被感染的连通分量。
对于第一种情况的节点,通过DFS或BFS记录下每个连通分量的节点数,最多的即为题目所需的答案。
对于第二种情况,无论取走哪个节点,该连通分量都会被感染,所以直接置为0就好。
DFS
下面是我用DFS实现的解答:
int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
int graphSize = graph.size(), initialSize = initial.size();
std::vector<int> infected(graphSize, -1);
std::unordered_map<int, int> infectNum;
std::function<void(const int&, const int&)> dfs = [&] (const int &root, const int& pathogen) {
if (infected[root] == pathogen)
return;
infected[root] = pathogen;
for (int node = 0; node < graphSize; ++node) {
if (graph[root][node]) {
dfs(node, pathogen);
infectNum[pathogen]++;
}
}
};
for (int i = 0; i < initialSize; ++i) {
int curNode = initial[i];
if (infected[curNode] != -1) {
infectNum[curNode] = -1, infectNum[infected[curNode]] = -1;
} else {
infectNum[curNode] = 0;
dfs(curNode, curNode);
}
}
int res = graphSize, maxNum = -1;
for (auto &element : infectNum) {
if (element.second > maxNum)
res = element.first, maxNum = element.second;
else if (element.second == maxNum && element.first < res)
res = element.first;
}
return res;
}
时间复杂度 $O(n^2)$ ,空间复杂度 $O(n)$ 。
并查集
执行用时 148ms ,击败 32.27% 的解答,我不满意 XD
题解有用并查集维护连通关系,同时统计每个连通分量大小,最后再遍历initial找出答案。
于是去学了一下并查集,参考了LeetCode上一篇不错的回答,最后写出如下:
int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
class UnionFind {
private:
std::vector<int> parent;
std::vector<int> unionSize;
public:
UnionFind(const int &n) {
parent.resize(n);
unionSize.resize(n);
std::fill(parent.begin(), parent.end(), -1);
std::fill(unionSize.begin(), unionSize.end(), 1);
}
void unite(const int &x, const int &y) {
int xRoot = find(x), yRoot = find(y);
if (xRoot == yRoot)
return;
if (parent[xRoot] < parent[yRoot]) {
parent[yRoot] = xRoot;
unionSize[xRoot] += unionSize[yRoot];
} else if (parent[xRoot] > parent[yRoot]) {
parent[xRoot] = yRoot;
unionSize[yRoot] += unionSize[xRoot];
} else {
parent[xRoot] = yRoot;
parent[yRoot] -= 1;
unionSize[yRoot] += unionSize[xRoot];
}
}
int find(const int &x) {
return parent[x] < 0 ? x : parent[x] = find(parent[x]);
}
int getInfectedSize(const int &x) {
return unionSize[find(x)];
}
int getInfectedSizeByRoot(const int &root) {
return unionSize[root];
}
};
int n = graph.size();
UnionFind unionFind(n);
for (int i = 0; i < n; ++i) {
for (int j = i + 1; j < n; ++j) {
if (graph[i][j])
unionFind.unite(i, j);
}
}
std::vector<int> pathogenNum(n, 0);
for (const auto &x : initial)
pathogenNum[unionFind.find(x)] += 1;
int res = n, maxRange = -1;
for (const auto &x : initial) {
int root = unionFind.find(x);
int affectRange = pathogenNum[root] == 1 ? unionFind.getInfectedSizeByRoot(root) : 0;
if (affectRange > maxRange)
res = x, maxRange = affectRange;
else if (affectRange == maxRange && x < res)
res = x;
}
return res;
}
时间复杂度 $O(n^2 \times \alpha(n))$ ,空间复杂度 $O(n)$ 。
用时129ms,击败 75.29% ,暂时就这样吧,懒了 。