https://leetcode.cn/problems/minimize-malware-spread/description/?envType=daily-question&envId=2024-04-11

思路很好想的一道题。连通的节点中出现恶意节点,那么里面的所有节点都会被感染。题目要求去除一个最初被感染的节点后,整个网络中被感染节点数最少,那么对于每个最初被感染的节点就有两种情况:

  1. 节点所在的连通分量中只有该初始节点被感染;
  2. 节点所在的联通分量中还有其他被感染的节点;
因为只考虑被感染的节点,所以没有列出没有节点被感染的连通分量。

对于第一种情况的节点,通过DFS或BFS记录下每个连通分量的节点数,最多的即为题目所需的答案。

对于第二种情况,无论取走哪个节点,该连通分量都会被感染,所以直接置为0就好。

DFS

下面是我用DFS实现的解答:

int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
    int graphSize = graph.size(), initialSize = initial.size();
    std::vector<int> infected(graphSize, -1);
    std::unordered_map<int, int> infectNum;

    std::function<void(const int&, const int&)> dfs = [&] (const int &root, const int& pathogen) {
        if (infected[root] == pathogen)
            return;

        infected[root] = pathogen;
        for (int node = 0; node < graphSize; ++node) {
            if (graph[root][node]) {
                dfs(node, pathogen);
                infectNum[pathogen]++;
            }
        }
    };

    for (int i = 0; i < initialSize; ++i) {
        int curNode = initial[i];
        if (infected[curNode] != -1) {
            infectNum[curNode] = -1, infectNum[infected[curNode]] = -1;
        } else {
            infectNum[curNode] = 0;
            dfs(curNode, curNode);
        }
    }

    int res = graphSize, maxNum = -1;
    for (auto &element : infectNum) {
        if (element.second > maxNum)
            res = element.first, maxNum = element.second;
        else if (element.second == maxNum && element.first < res)
            res = element.first;
    }
    return res;
}

时间复杂度 $O(n^2)$ ,空间复杂度 $O(n)$ 。

并查集

执行用时 148ms ,击败 32.27% 的解答,我不满意 XD

题解有用并查集维护连通关系,同时统计每个连通分量大小,最后再遍历initial找出答案。

于是去学了一下并查集,参考了LeetCode上一篇不错的回答,最后写出如下:

int minMalwareSpread(std::vector<std::vector<int>>& graph, std::vector<int>& initial) {
    class UnionFind {
    private:
        std::vector<int> parent;
        std::vector<int> unionSize;

    public:
        UnionFind(const int &n) {
            parent.resize(n);
            unionSize.resize(n);
            std::fill(parent.begin(), parent.end(), -1);
            std::fill(unionSize.begin(), unionSize.end(), 1);
        }

        void unite(const int &x, const int &y) {
            int xRoot = find(x), yRoot = find(y);
            if (xRoot == yRoot)
                return;

            if (parent[xRoot] < parent[yRoot]) {
                parent[yRoot] = xRoot;
                unionSize[xRoot] += unionSize[yRoot];
            } else if (parent[xRoot] > parent[yRoot]) {
                parent[xRoot] = yRoot;
                unionSize[yRoot] += unionSize[xRoot];
            } else {
                parent[xRoot] = yRoot;
                parent[yRoot] -= 1;
                unionSize[yRoot] += unionSize[xRoot];
            }
        }

        int find(const int &x) {
            return parent[x] < 0 ? x : parent[x] = find(parent[x]);
        }

        int getInfectedSize(const int &x) {
            return unionSize[find(x)];
        }

        int getInfectedSizeByRoot(const int &root) {
            return unionSize[root];
        }
    };

    int n = graph.size();
    UnionFind unionFind(n);

    for (int i = 0; i < n; ++i) {
        for (int j = i + 1; j < n; ++j) {
            if (graph[i][j])
                unionFind.unite(i, j);
        }
    }

    std::vector<int> pathogenNum(n, 0);
    for (const auto &x : initial)
        pathogenNum[unionFind.find(x)] += 1;

    int res = n, maxRange = -1;
    for (const auto &x : initial) {
        int root = unionFind.find(x);
        int affectRange = pathogenNum[root] == 1 ? unionFind.getInfectedSizeByRoot(root) : 0;
        if (affectRange > maxRange)
            res = x, maxRange = affectRange;
        else if (affectRange == maxRange && x < res)
            res = x;
    }

    return res;
}

时间复杂度 $O(n^2 \times \alpha(n))$ ,空间复杂度 $O(n)$ 。

用时129ms,击败 75.29% ,暂时就这样吧,懒了

最后修改:2024 年 04 月 21 日
如果觉得我的文章对你有用,请随意赞赏