當前位置: 妍妍網 > 碼農

大模型訓練中最佳化策略(數據並列、模型並列、ZeRO等)

2024-05-27碼農

GPU 視訊記憶體分析

GPU視訊記憶體分布.png

在微調時,模型視訊記憶體占用主要包括 模型參數 參數梯度 最佳化器 中間結果 四個部份。

對於一個 6B 參數量的模型,它的模型參數占用為:

將模型參數視為基準,模型梯度占用量與模型參數相同。

最佳化器主采用 Adam Optimizer ,它核心計算公式如下:

由於需要保存 m 和 v,而 m 和 v 規模與參數梯度相同,因此最佳化器需要兩倍視訊記憶體容量。

同時,在計算中得到的中間結果需要保存在視訊記憶體中,以便反向傳播時計算梯度。對於每一個中間結果,其數據形狀為 [Batch, SeqLen, Dim]。

Collective Operations

為了節省視訊記憶體,可以將模型或者數據分配到不同的顯卡上,顯卡之間有如下幾種 Collective Operations。

Broadcast

廣播.png

The Broadcast operation copies an N-element buffer on the root rank to all ranks.

廣播操作將一張顯卡上數據廣播到所有顯卡。

AllReduce、Reduce、ReduceScatter

AllReduce.png

reduce.png

ReduceScatter.png

The AllReduce operation is performing reductions on data (for example, sum, min, max) across devices and writing the result in the receive buffers of every rank.

The Reduce operation is performing the same operation as AllReduce, but writes the result only in the receive buffers of a specified root rank.

The ReduceScatter operation performs the same operation as the Reduce operation, except the result is scattered in equal blocks between ranks, each rank getting a chunk of data based on its rank index.

AllReduce 操作將所有顯卡上數據進行聚合如 求和 取最大值 取最小值 ,並將結果寫入所有顯卡。

Reduce 只會將結果寫入一張顯卡。

ReduceScatter 則將結果分散在所有顯卡中。

AllGather

AllGather.png

The AllGather operation gathers N values from k ranks into an output of size k*N, and distributes that result to all ranks.

AllGather 操作會收集所有顯卡數據,並寫入所有顯卡中。

數據並列

數據並列是將數據分成若幹份,裝載到不同節點上進行計算。

數據並列.png

數據平行計算流程如下:

  1. 有個參數伺服器保存模型參數。

  2. 參數被復制到不同的裝置中,構成若幹 replicas 。每個 replica 處理一部份數據,進行前向傳播和反向傳播。

  3. 每個裝置得到梯度進行 Reduce 操作,得到最終梯度,並按照這個梯度更新參數伺服器中的模型參數。

  4. 在後向傳播時,每計算完一層的梯度,就可以進行 Reduce 操作,提高並列性。

分布式數據並列

分布式數據並列.png

分布式數據並列中不存在參數伺服器,其計算流程如下:

  1. 每個 replica 都保存模型參數,但是分別計算部份數據,進行前向傳播和反向傳播。

  2. 每個裝置都得到梯度後進行 AllReduce 操作,將梯度寫入所有裝置,每個裝置根據自己的最佳化器和梯度更新參數。

分布式數據並列中,每個裝置視訊記憶體占用情況如圖:

分布式數據並列視訊記憶體占用.png

其中每個裝置仍需要保存模型參數、梯度和最佳化器參數。

模型並列

由於模型越來越大,單個裝置保存模型參數、梯度和最佳化器越來越難。因為深度學習主要是矩陣計算,而矩陣計算可以分塊計算,因此可以將模型參數拆成若幹份,每份單獨計算,以減少視訊記憶體占用。

模型並列.png

其計算流程如下:

  1. 將參數矩陣分成若幹子矩陣,分發到不同裝置中。

  2. 每個裝置計算不同矩陣,然後將結果收集起來。

模型並列後,視訊記憶體占用如下:

模型並列視訊記憶體占用.png

由於每個裝置處理所有數據,因此中間結果都會保存在所有裝置中。

ZeRO

在分布式數據並列中,最後梯度更新在不同裝置進行的操作相同,多個裝置中參數相同,梯度相同,最佳化器狀態相同,存在大量冗余。

ZeRO-1 對最佳化器狀態進行分片。

ZeRO-1.png

ZeRO-1 計算流程如下:

  1. 每個 replica 處理一部份數據輸入。

  2. 獨立進行前向傳播。

  3. 獨立進行反向傳播。

  4. 得到完整梯度後進行 ReduceScatter,每個 replica 得到對應梯度。

  5. 每個 replica 更新梯度對應的部份參數。

  6. 使用 AllGather 同步更新所有參數。

ZeRO-2 計算流程與1基本相同,ZeRO-2在後向傳播時,每計算一層梯度,就可以使用 ReduceScatter 進行同步,提高並列度。同時由於不需要完整計算梯度之後進行 ReduceScatter,每個 replica 只需要保存部份梯度即可。

ZeRO-3 在 2 的基礎上,將模型參數進行分片。

ZeRO-3.png

ZeRO-3 計算流程如下:

  1. 每個 replica 處理一部份輸入。

  2. 前向傳播時,當需要別的層參數,使用 AllGather 獲取。

  3. 反向傳播時,當需要別的層參數時,使用 AllGather 獲取,同時計算出每一層梯度時,使用 ReduceScatter 分發到對應 replica。

  4. 每個 replica 用於部份最佳化器參數和梯度,進行對應參數更新。

不同 ZeRO 對應的視訊記憶體占用情況:

ZeRO視訊記憶體占用.png

流水線並列

將模型一層一層分開,不同層放入不同 GPU 進行計算。個人理解與模型並列不同的是,模型並列保留從頭到尾每一層的部份參數,輸入可以計算出結果。流水線並列需要等前一層計算完畢才能進行計算。

流水線並列.png

流水線並列視訊記憶體分析:

流水線並列視訊記憶體分析.png

混合精度

FP16 相較於 FP32 計算更快,同時占用更少的視訊記憶體。但同時 FP16 表示的範圍小,可能產生溢位錯誤。

特別的,在權重更新時 gradient * lr 導致下溢位。

混合精度訓練的思路在最佳化器中保留一份 FP32 格式的參數副本,而模型權重、梯度等數據在訓練中都是用 FP16 來儲存。

混合精度.png

最佳化器中參數更新在 FP32 格式下保證精度,之後轉換為 FP16 格式。

Checkpointing

由於模型反向傳播需要中間結果計算梯度,大量中間結果占用大量視訊記憶體。

Checkpointing 思路是保存部份隱藏層的結果(作為檢查點),其余的中間結果直接釋放。當反向傳播需要計算梯度時,從檢查點開始重新前向傳播計算中間結果,得到梯度後再次釋放。