๐ฅ๐ฅ[KCache] EFFICIENT LLM INFERENCE WITH KCACHE(@Qiaozhi He, Zhihua Wu)
https://arxiv.org/pdf/2404.18057
TL; DR
์ด ๋
ผ๋ฌธ์ "KCache"๋ผ๋ ์๋ก์ด ๊ธฐ์ ์ ๋ํด ์ค๋ช
ํฉ๋๋ค. ์ด ๊ธฐ์ ์ ๋ํ ์ธ์ด ๋ชจ๋ธ(LLM)์ ์ถ๋ก ๊ณผ์ ์์ ๋ฉ๋ชจ๋ฆฌ ๋ณ๋ชฉ ํ์์ ์ํํ๊ณ ์์คํ
์ฒ๋ฆฌ๋์ 40% ํฅ์์ํค๋ฉด์๋ ์ ํ์ฑ์ ์ ์งํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค. KCache๋ ๊ธฐ์กด์ KV Cache๋ฅผ ๋์ฒดํ์ฌ, ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ค์ด๋ฉด์ ์ถ๋ก ์๋๋ฅผ ๊ฐ์ ํฉ๋๋ค. ์ด๋ฅผ ์ํด K ์บ์๋ HBM(High Bandwidth Memory)์ ์ ์งํ๊ณ , V ์บ์๋ CPU ๋ฉ๋ชจ๋ฆฌ์ ์ ์ฅํ์ฌ ํ์ํ ์ ๋ณด๋ง ๋์ ์ผ๋ก ์ ํํ๊ณ GPU๋ก ๋ณต์ฌํฉ๋๋ค.
๋ํ, ์ด ๋
ผ๋ฌธ์ LLM์ ๋ ๊ฐ์ง ์ฃผ์ ๋จ๊ณ์ธ prefill ๋จ๊ณ์ decode ๋จ๊ณ์์ KCache์ ๊ตฌํ ๋ฐฉ๋ฒ์ ์ค๋ช
ํฉ๋๋ค. Prefill ๋จ๊ณ์์๋ ์
๋ ฅ ํ๋กฌํํธ๋ฅผ ๋ฐ์ ๋ณ๋ ฌ๋ก ๊ณ์ฐ์ ์ํํ๊ณ , decode ๋จ๊ณ์์๋ ์์ฑ๋ ํ ํฐ์ ๊ธฐ๋ฐ์ผ๋ก ํ๋์ฉ ์ถ๋ ฅ ํ ํฐ์ ์์ฑํฉ๋๋ค. ์ด ๊ณผ์ ์์ KCache๋ Attention ๊ณ์ฐ์ ๊ฒฐ๊ณผ๋ฅผ ๋ฐํ์ผ๋ก ์ค์ํ V ์บ์ ๋ฐ์ดํฐ๋ง์ ์ ํ์ ์ผ๋ก GPU๋ก ์ ์กํ์ฌ ์ฒ๋ฆฌ๋์ ์ต์ ํํฉ๋๋ค.
๋
ผ๋ฌธ์ KCache๊ฐ ์ฌ๋ฌ LLM ๋ชจ๋ธ์์ ๋ฉ๋ชจ๋ฆฌ ํจ์จ์ฑ์ ๊ฐ์ ํ๊ณ , ๋์ ์ฒ๋ฆฌ๋์ ๋ฌ์ฑํ๋ฉด์๋ ์ถ๋ก ์ ํ๋๋ฅผ ์ ์งํ๊ฑฐ๋ ํฅ์์ํฌ ์ ์์์ ์คํ์ ํตํด ๋ณด์ฌ์ค๋๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก, KCach๋ ์ถ๋ก ๊ณผ์ ์์์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๊ณผ GPU ๋ถํ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ๊ด๋ฆฌํ๋ฉด์ ๋์ ์ฑ๋ฅ์ ์ ๊ณตํ๋ ์ ์ฐํ๊ณ ํ์ฅ ๊ฐ๋ฅํ ์๋ฃจ์
์ ์ ๊ณตํฉ๋๋ค.
Instruction
LLM ์ถ๋ก ์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ฃผ๋ก ์ธ ๋ถ๋ถ์ผ๋ก ๊ตฌ์ฑ๋ฉ๋๋ค: Weight, Activations, ๊ทธ๋ฆฌ๊ณ KV Cache์ ๋๋ค. ์๋ฅผ ๋ค์ด, LLaMA2-7B ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ๊ฐ์ค์น๋ fp16์์ ์ฝ 14GB์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํฉ๋๋ค. ๋ฐฐ์น ํฌ๊ธฐ๊ฐ 8์ด๊ณ ์ํ์ค ๊ธธ์ด๊ฐ 32 × 1024์ผ ๋, KV ์บ์๋ ์ฝ 128GB์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฐจ์งํ๋ฉฐ, layer-wise memory sharing ์ ๋ต์ ์ฌ์ฉํ๋ฉด activation์ ๋จ์ง 2GB์ ๋ฉ๋ชจ๋ฆฌ๋ง์ ์ฐจ์งํฉ๋๋ค. ๋ฐฐ์น ํฌ๊ธฐ์ ์ํ์ค ๊ธธ์ด๊ฐ ์ฆ๊ฐํจ์ ๋ฐ๋ผ, KV ์บ์์ ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋์ ์ ํ์ ์ผ๋ก ์ฆ๊ฐํฉ๋๋ค.
KV Cache์ ๋ํ ๋ด์ฉ์ ๋ ์์ธํ ์์๋ณด์๊ณ ์ถ์ผ์๋ค๋ฉด, ์ด์ ์ ๋ธ๋ก๊ทธ์ ์์ฑํ ๋ด์ฉ์ ์ฐธ๊ณ ํ์๋ฉด ์ข์ต๋๋ค
์ต๊ทผ์ LLM๋ค์ ์ถ๋ก ๊ณผ์ ์์ KV Cache์์ ๋ณ๋ชฉ ํ์์ด ๋ฐ์ํฉ๋๋ค. ์ด๋ฌํ ๋ณ๋ชฉ ํ์์ ์ํํ๊ธฐ ์ํด ์ฌ๋ฌ ์ต์ ํ ๋ฐฉ๋ฒ์ด ์ ์๋๊ณ ์์ต๋๋ค. KV ์บ์๋ฅผ ์์ถํ๊ธฐ ์ํ ๋ฐฉ์๋ค๋ก์จ ๋ฐ์ดํธ์ ๊ด์ ์์ ์์ถํ๋ ์์ํ ์๊ณ ๋ฆฌ์ฆ, ์ํ์ค ๊ธธ์ด(s) ๊ด์ ์์ ์์ถํ๋ ์ปจํ ์คํธ ์๋์ฐ ์์ถ ์๊ณ ๋ฆฌ์ฆ์ด ์ ์๋์์ต๋๋ค. Adaptive computation ์ฐ์ฐ ์๊ณ ๋ฆฌ์ฆ์ ๋ ์ด์ด(l) ๊ด์ ์์ ๊ณ์ฐ์ ์ค์ด๊ธฐ ์ํด early exit decoding์ ์ ์ํ์์ผ๋ฉฐ, ๋ฉํฐ ํค๋ ์ดํ ์ (MHA)์ ๊ตฌ์กฐ๋ฅผ ๊ฐ์ ํ์ฌ ์ถ๋ก ์ ๊ฐ์ํํ๋ ๋ฐฉ๋ฒ๋ ์ ์๋์์ต๋๋ค. K ์บ์์ V ์บ์ ๊ด์ ์์, ์ถ๋ก ์ค์ CPU๋ก ์คํ๋ก๋ฉํ๊ณ GPU๋ก ๋ค์ ๋ก๋ฉํ๋ ๊ฒ์ GPU ๋ฉ๋ชจ๋ฆฌ์ ๋ํ ์๋ฐ์ ์ํํ ์ ์์ง๋ง, ํ์ฌ ํธ์คํธ-๋๋ฐ์ด์ค(H2D) ๋ฐ ๋๋ฐ์ด์ค-ํธ์คํธ(D2H) ๋์ญํญ์ด ์ถ๋ก ์ ์๋ก์ด ๋ณ๋ชฉ ์ง์ ์ด ๋ฉ๋๋ค.
Method
์๋๋ KCache๋ฅผ ํํํ๋ ๊ทธ๋ฆผ์ ๋๋ค.
Prefill ๋จ๊ณ ๋์์๋ ๊ฐ ๋ ์ด์ด์ ๊ณ์ฐ ๊ฒฐ๊ณผ๊ฐ ๊ณ ๋์ญํญ ๋ฉ๋ชจ๋ฆฌ(HBM)๋ก push๋ฉ๋๋ค. ๊ทธ ํ, V ์บ์์ ์ผ๋ถ๊ฐ CPU๋ก ๋น๋๊ธฐ์ ์ผ๋ก ๋ณต์ฌ๋๋ฉฐ, ์ด ๊ณผ์ ์์ ํด๋น V ์บ์๊ฐ ์ฐจ์งํ๋ GPU ๋ฉ๋ชจ๋ฆฌ๊ฐ ํด์ ๋ฉ๋๋ค. Decode ๋จ๊ณ์์๋ K state๊ฐ KV Cache๋ก push๋๊ณ pull๋ฉ๋๋ค. ๊ทธ๋ฌ๋ KCache๋ TopN์ attention score๋ฅผ ๊ณ์ฐํ๊ณ , TopN ๊ฒฐ๊ณผ์ ์ธ๋ฑ์ค์ ๊ธฐ๋ฐํ์ฌ, ํด๋น V Cache๋ฅผ CPU์์ HBM์ผ๋ก ์ค์๊ฐ์ผ๋ก ๊ฐ์ ธ์ ํ์ ๊ณ์ฐ์ ์๋ฃํฉ๋๋ค.
์ฅ๋ฌธ์ Context์์ ์ฌ์ฉ์๋ ์ผ๋ฐ์ ์ผ๋ก ๊ธด ์ํ์ค์ ๊ธฐ๋ฐํ ์ฌ๋ฌ ๋ผ์ด๋์ ์ง๋ฌธ์ ํฉ๋๋ค. ๊ฐ ์ง๋ฌธ์ ๊ธด ๋งฅ๋ฝ์ ๋ค๋ฅธ ์ธ๊ทธ๋จผํธ์ ์ด์ ์ ๋ง์ถ ์ ์์ต๋๋ค. ๊ฐ ๋ผ์ด๋์์ ๊ฒฐ๊ณผ์ ์ ํ์ฑ์ ๊ทน๋ํํ๊ธฐ ์ํด, KCache์์๋ KV state์ ์ถ์๋ ์์ถ์ ํผํจ์ผ๋ก์จ ๋ชจ๋ธ์ upper bound๋ฅผ ๋ณด์ฅํฉ๋๋ค. ๊ทธ๋ฌ๋, ์ถ๋ก ์ค์ KV state๋ฅผ CPU ๋ฉ๋ชจ๋ฆฌ๋ก ์คํ๋ก๋ฉํ๊ณ GPU๋ก ๋ค์ ๋ก๋ฉํ๋ ๊ฒ์ ์ถ๋ก ์ E2E ์๊ฐ์ ํฌ๊ฒ ์ฆ๊ฐ์ํฌ ๊ฒ์ ๋๋ค. ๋ฐ๋ผ์ ๋ชจ๋ธ์ ํจ๊ณผ์ฑ๊ณผ ์ถ๋ก ์ง์ฐ ์๊ฐ ์ฌ์ด์ ๊ท ํ์ ๋ง์ถ๊ธฐ ์ํด์๋ ํ์ํ ์ ๋ณด๋ง HBM์ผ๋ก ๋ค์ ๋ก๋ํ๋ ๋ฐฉ๋ฒ์ ์ฐพ์์ผ ํ๋ฉฐ, ์ด๋ ์ด๋ค ์ ๋ณด๊ฐ ์ค์ํ์ง ๊ฒฐ์ ํ๋ ๋ชจ๋์ด ํ์ํจ์ ์๋ฏธํฉ๋๋ค. ๋คํํ๋, Attention ๋ฉ์ปค๋์ฆ์์ Key์ Value ์์ ์๋ฏธ๋ฅผ ๊ณ ๋ คํ ๋, Key๋ Query์์ ๊ด๋ จ์ฑ์ ๊ณ์ฐํ๋ ๋ฐ ์ฌ์ฉ๋๊ณ Value๋ Key์ ๊ด๋ จ๋ ์ค์ ์ ๋ณด๋ฅผ ๋ํ๋ด๋ฏ๋ก, K ์บ์์ V ์บ์์ ์ผ๋ถ๋ฅผ CPU ๋ฉ๋ชจ๋ฆฌ๋ก ์คํ๋ก๋ํ๋ ๊ฒ์ด KCache์ ์๊ฐ์ ์คฌ์ต๋๋ค.
Analysis of KCache Performance
Prefill ํ์ด์ฆ์์๋ V Cache์ ์ผ๋ถ๋ CPU ๋ฉ๋ชจ๋ฆฌ์ ๋น๋๊ธฐ์ ์ผ๋ก ๋ณต์ฌ๋์ด์ผ ํฉ๋๋ค. ์ด๋ฌํ ์ํฉ์์ ์ ์๋ค์ ๊ฐ ๋ ์ด์ด์ ๋ํ computation time์ด ์ด์ ๋ ์ด์ด์ ๋ฐ์ดํฐ๋ฅผ ์นดํผํ๋ ์๊ฐ๊ณผ ์ค๋ฒ๋ฉ๋ ๊ฒ์ผ๋ก ๋ฐ๋ฌ์ต๋๋ค. ๊ฐ ํธ๋์คํฌ๋จธ ๋ธ๋ก์ ๋ํด์ $ 2bsd $ bytes์ ๋ฐ์ดํฐ๊ฐ D2H๋ก ์ ์ก๋์ด์ผํ๊ณ , $ 22bsd^2 + 4bs^2d $ FLops๊ฐ ์ฐ์ฐ๋์ด์ผ ํฉ๋๋ค.
NVIDIA A100 (80GB) GPU๋ฅผ ์ฌ์ฉํ์ ๋, LLaMA2-7B๋ $ d = 4096 $์ด๊ณ , ์ด๋ ์๋์ ๊ฐ์ ๋ถ๋ฑ์์ ๋ง์กฑํฉ๋๋ค.
Decode ํ์ด์ฆ์์๋ MHA ๋ชจ๋์ ์ ํ์ ์ธ memory-boud ์์ ์ ๋๋ค. Arithmetic Intensity๋ ๋ถ๋ ์์์ ์ฐ์ฐ(FLOPs)๊ณผ ์ ์ถ๋ ฅ ๋ฐ์ดํธ์ ๋น์จ๋ก ์ ์๋ฉ๋๋ค. ์ด๋ ๋์ฝ๋ฉ ์ค MHA ๋ชจ๋์ ๊ณ์ฐ ์๊ฐ์ด ๋ฉ๋ชจ๋ฆฌ ์ ๊ทผ๋์ ํฌ๊ฒ ์์กดํจ์ ๋ํ๋ ๋๋ค. ํนํ, ๋์ฝ๋ ๋จ๊ณ์์ MHA ๋ชจ๋์ ์ฑ๋ฅ์ hidden size์ ์ํ์ค ๊ธธ์ด์๋ ๋ ๋ฆฝ์ ์ด๋ฉฐ ์ค๋ก์ง batch size์๋ง ์ํฅ์ ๋ฐ์ต๋๋ค. ์ด ๊ด์ฐฐ์ ๋ค์๊ณผ ๊ฐ์ ๊ธฐ๋๋ฅผ ์ด๋์ด๋ ๋๋ค: ์ ์๋ KCache MHA ๋ชจ๋์ ๊ณ์ฐ ์๊ฐ๊ณผ ๋ฐ์ดํฐ ์ ์ก ์๊ฐ์ ๊ธฐ์กด์ KV ์บ์ MHA ๊ตฌํ๋ณด๋ค ์ ์ ์ ์์ต๋๋ค.
NVIDIA A100 (80GB) GPU๋ฅผ ์ฌ์ฉํ์ ๋, ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ์ 2039GB/s GPU Memory Bandwidth and 32GB/s H2D Bandwidth์ด๊ณ , KCache๋ ์ด์ ๋ฐ๋ผ A100์์ $s/N > 64$์ธ ๊ฒฝ์ฐ์๋ ์ฑ๋ฅ์ด ๋จ์ด์ง์ง ์์ต๋๋ค.
Analysis of KCache Accuracy
๋ฐฉ๋ฒ๋ก ์ ๋ฐ๋ฅด๋ฉด Value ํ ์๊ฐ CPU ๋ฉ๋ชจ๋ฆฌ์ ๋น๋๊ธฐ์ ์ผ๋ก ๋ณต์ฌ๋๋ prefill ํ์ด์ฆ์์๋ ์ถ๋ก ์ฑ๋ฅ๊ณผ ์ ํ๋์ ์ํฅ์ ์ฃผ์ง ์์ต๋๋ค. Decode ํ์ด์ฆ์์๋ H2D ๋ฐ์ดํฐ์ ์์ ์ค์ด๋ ๊ฒ์ด ํ์ฐ์ ์ ๋๋ค. ๋ง์ฝ ์ดํ ์ ์ค์ฝ์ด์ ๊ฒฐ๊ณผ๊ฐ ์ถฉ๋ถํ sparseํ๋ค๋ฉด, ์ต์ข ๊ฒฐ๊ณผ์ ๋ฏธ์น๋ ์ํฅ์ ๋ฌด์ํ ์ ์์ ์ ๋๋ก ์์์ง๋๋ค.
Experiments
Accuracy
Accuracy ์ธก๋ฉด์์, KCache๋ ๊ธฐ๋ณธ์ ์ผ๋ก ์์ค ์์ด ์ ํ์ฑ์ ์ ์งํ์ผ๋ฉฐ, ์ฌ๋ฌ ๋ฐ์ดํฐ ์ธํธ์ ๋ชจ๋ธ์์ ๋ ๋์ ์ฑ๋ฅ์ ๋ฌ์ฑํ์ต๋๋ค. L์ ํนํ ๋ชจ๋ธ ์ฑ๋ฅ์ด ์ถฉ๋ถํ ๊ฐ๋ ฅํ ๊ฒฝ์ฐ ๋ชจ๋ธ ์ ํ๋์ ์๋์ ์ผ๋ก ์์ ์ํฅ์ ๋ฏธ์นฉ๋๋ค. ๋ชจ๋ธ ์ฑ๋ฅ์ด ์๋์ ์ผ๋ก ์ฝํ ๊ฒฝ์ฐ์๋ L์ ๋ ํฌ๊ฒ ์ค์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค.
N์ด ํด์๋ก ์ ํ๋๊ฐ ๋์์ง๋ฉฐ, N = 128์ผ ๋ KCache๋ ๊ธฐ์ค๊ณผ ๋์ผํ๊ฑฐ๋ ๋ ๋์ ์ ํ๋๋ฅผ ์ ์งํ์ต๋๋ค. ์ ์๋ค์ TopN์ด ์ํํธ๋งฅ์ค๋ฅผ ์ ๊ทํํ๊ณ ๋ ธ์ด์ฆ ์ ๋ณด๋ฅผ ์ถ๊ฐ๋ก ํํฐ๋งํ๋ค๊ณ ๋ฏฟ์ต๋๋ค. ์ธ ๊ฐ์ง ๋ฐ์ดํฐ ์ธํธ์ ๋ํ ์คํ์ ํตํด ์ปจํ ์คํธ ๊ธธ์ด๊ฐ ์ฝ 2K ์ดํ์ธ ๊ฒฝ์ฐ N์ 64 ๋๋ 128๋ก ์ค์ ํด๋ ์ ํ๋์ ํฐ ์ํฅ์ ๋ฏธ์น์ง ์๋ ๊ฒ์ผ๋ก ํ์ธ๋์์ต๋๋ค.
Performance
Performance ์ธก๋ฉด์์, 64GB ๋ฉ๋ชจ๋ฆฌ, 1TB GPU ๋ฉ๋ชจ๋ฆฌ ๋์ญํญ ๋ฐ 180TFLOPS๋ฅผ ๊ฐ์ถ GPU์์ ์คํ์ ์ํํ์ต๋๋ค. LLaMA2-7B ๋ชจ๋ธ์ ํ๊ฐํ ๊ฒฐ๊ณผ, ํ 3์ ์คํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋๋ค. ์ ๋ฐ์ ์ผ๋ก S >> N ์ผ ๋ KCache๊ฐ ์ฑ๋ฅ ์ด์ ์ ๋ณด์ฌ์ค๋๋ค. ๋์์ KCache๋ 15K ์ปจํ ์คํธ ๊ธธ์ด์ N = 128 ์ค์ ์์ ์ถ๋ก ์ฒ๋ฆฌ๋์ 40% ์ด์ ํฅ์์์ผฐ์ต๋๋ค.
'AI > Papers' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
S-LoRA: Serving Thousands of Concurrent LoRA Adapters (0) | 2023.12.13 |
---|