2025-03-19:标记所有节点需要的时间。用go语言,给定一棵无向树,树中的节点编号从 0 到 n-1。同时给出一个长度为 n-1 的二维整数数组 edges,其中 edges[i] = [ui, vi] 表示节点 ui 和 vi 之间有一条边。
初始时,所有节点都未被标记。对于节点 i:
1.如果 i 是奇数,且在前一个时刻(x-1)至少有一个相邻节点被标记,那么节点 i 会在时刻 x 被标记。
2.如果 i 是偶数,且在前两个时刻(x-2)至少有一个相邻节点被标记,那么节点 i 会在时刻 x 被标记。
你需要返回一个数组 times,其中 times[i] 表示:如果在时刻 t=0 标记节点 i,那么时刻 times[i] 时,树中所有节点都会被标记。
注意:每个 times[i] 的计算是独立的,即在计算 times[i] 时,其他节点都未被标记。
2 <= n <= 100000。
edges.length == n - 1。
edges[i].length == 2。
0 <= edges[i][0], edges[i][1] <= n - 1。
输入保证 edges 表示一棵合法的树。
输入:edges = [[0,1],[0,2]]。
输出:[2,4,3]。
解释:
对于 i = 0 :
节点 1 在时刻 t = 1 被标记,节点 2 在时刻 t = 2 被标记。
对于 i = 1 :
节点 0 在时刻 t = 2 被标记,节点 2 在时刻 t = 4 被标记。
对于 i = 2 :
节点 0 在时刻 t = 2 被标记,节点 1 在时刻 t = 3 被标记。
题目来自leetcode3241。
大体步骤如下:
- 构建图的邻接表表示:
- 根据输入的边列表
edges
,构建一个邻接表g
来表示树的结构。邻接表g
是一个二维数组,其中g[x]
存储与节点x
直接相连的所有节点。这一步的目的是将树的结构存储为便于遍历的形式。
- 深度优先搜索(DFS)计算子树的最大深度:
- 定义一个递归的 DFS 函数
dfs
,用于计算每个节点的子树的最大深度maxD
和次大深度maxD2
。 - 对于每个节点
x
,遍历其所有相邻节点y
(不包括父节点fa
),递归计算y
的子树深度,并加上从x
到y
的边权(边权根据节点编号的奇偶性决定:如果y
是奇数,边权为1
;如果是偶数,边权为2
)。 - 更新节点
x
的maxD
和maxD2
,并记录达到maxD
的子节点y
。这一步的目的是为后续的重新根化(Rerooting)做准备。
- 重新根化(Rerooting)计算每个节点的标记时间:
- 定义一个递归的
reroot
函数,用于计算每个节点的标记时间ans[x]
。 - 对于每个节点
x
,其标记时间ans[x]
是fromUp
(从父节点方向传递过来的最大深度)和p.maxD
(从子节点方向传递过来的最大深度)中的较大值。 - 对于每个子节点
y
,根据y
是否是达到maxD
的子节点,决定传递fromUp
的值:
- 如果
y
是达到maxD
的子节点,则传递maxD2
(次大深度)加上边权。 - 否则,传递
maxD
(最大深度)加上边权。
- 这一步的目的是通过动态调整根节点,计算每个节点作为根时的标记时间。
- 输出结果:
- 最终,
ans
数组存储了每个节点的标记时间,即times[i]
。这个数组表示:如果以节点i
为起点开始标记,那么所有节点被标记完成的时间为times[i]
。
总的时间复杂度和总的额外空间复杂度
- 时间复杂度:
- 构建邻接表的时间复杂度为
O(n)
,其中n
是节点的数量。 - DFS 遍历每个节点一次,计算
maxD
和maxD2
的时间复杂度为O(n)
。 - Rerooting 过程同样遍历每个节点一次,计算每个节点的标记时间,时间复杂度为
O(n)
。 - 因此,总的时间复杂度为
O(n)
。
- 额外空间复杂度:
- 邻接表
g
的空间复杂度为O(n)
。 nodes
数组存储每个节点的maxD
、maxD2
和y
,空间复杂度为O(n)
。ans
数组存储每个节点的标记时间,空间复杂度为O(n)
。- 递归调用栈的深度最多为树的高度,最坏情况下为
O(n)
。 - 因此,总的额外空间复杂度为
O(n)
。
总结
- 时间复杂度:
O(n)
。 - 额外空间复杂度:
O(n)
。 - 该算法通过两次深度优先搜索(DFS)和动态调整根节点的方式,高效地计算了每个节点作为起点时的标记时间。
Go完整代码如下:
package main
import (
"fmt"
)
func timeTaken(edges [][]int) []int {
g := make([][]int, len(edges)+1)
for _, e := range edges {
x, y := e[0], e[1]
g[x] = append(g[x], y)
g[y] = append(g[y], x)
}
// nodes[x] 保存子树 x 的最大深度 maxD,次大深度 maxD2,以及最大深度要往儿子 y 走
nodes := make([]struct{ maxD, maxD2, y int }, len(g))
var dfs func(int, int) int
dfs = func(x, fa int) int {
p := &nodes[x]
for _, y := range g[x] {
if y == fa {
continue
}
maxD := dfs(y, x) + 2 - y%2 // 从 x 出发,往 y 方向的最大深度
if maxD > p.maxD {
p.maxD2 = p.maxD
p.maxD = maxD
p.y = y
} else if maxD > p.maxD2 {
p.maxD2 = maxD
}
}
return p.maxD
}
dfs(0, -1)
ans := make([]int, len(g))
var reroot func(int, int, int)
reroot = func(x, fa, fromUp int) {
p := nodes[x]
ans[x] = max(fromUp, p.maxD)
for _, y := range g[x] {
if y == fa {
continue
}
w := 2 - x%2 // 从 y 到 x 的边权
if y == p.y { // 对于 y 来说,上面要选次大的
reroot(y, x, max(fromUp, p.maxD2)+w)
} else { // 对于 y 来说,上面要选最大的
reroot(y, x, max(fromUp, p.maxD)+w)
}
}
}
reroot(0, -1, 0)
return ans
}
func main() {
edges := [][]int{{0, 1}, {0, 2}}
result := timeTaken(edges)
fmt.Println(result)
}
Python完整代码如下:
# -*-coding:utf-8-*-
from typing import List
def time_taken(edges: List[List[int]]) -> List[int]:
n = len(edges) + 1 # 节点数量
g = [[] for _ in range(n)] # 邻接表
for x, y in edges:
g[x].append(y)
g[y].append(x)
# 定义节点信息:maxD 最大深度,maxD2 次大深度,y 最大深度对应的子节点
nodes = [{'maxD': 0, 'maxD2': 0, 'y': -1} for _ in range(n)]
# 第一次 DFS:计算每个子树的最大深度和次大深度
def dfs(x: int, fa: int) -> int:
p = nodes[x]
for y in g[x]:
if y == fa:
continue
max_d = dfs(y, x) + 2 - y % 2 # 从 x 出发,往 y 方向的最大深度
if max_d > p['maxD']:
p['maxD2'] = p['maxD']
p['maxD'] = max_d
p['y'] = y
elif max_d > p['maxD2']:
p['maxD2'] = max_d
return p['maxD']
dfs(0, -1)
ans = [0] * n
# 第二次 DFS:换根,计算每个节点作为根时的最大深度
def reroot(x: int, fa: int, from_up: int) -> None:
p = nodes[x]
ans[x] = max(from_up, p['maxD'])
for y in g[x]:
if y == fa:
continue
w = 2 - x % 2 # 从 y 到 x 的边权
if y == p['y']: # 对于 y 来说,上面要选次大的
reroot(y, x, max(from_up, p['maxD2']) + w)
else: # 对于 y 来说,上面要选最大的
reroot(y, x, max(from_up, p['maxD']) + w)
reroot(0, -1, 0)
return ans
# 测试用例
if __name__ == "__main__":
edges = [[0, 1], [0, 2]]
result = time_taken(edges)
print(result) # 输出结果