๐ฅฅ CoCoNut; Training Large Language Models to Reason in a Continuous Latent Space (Meta, 2024)
1/ Chain-Of-Thought (CoT)
CoT ํ๊ณ: LLM์ reasoning์ด ํ ์คํธ ํํ๋ก ์์ฑ๋์ด์ผ ํ๋ค๋ ์ ์ ์ ์ฝ์ ๊ฐํ ์ ์๋ค.
Neuroimaging ์ฐ๊ตฌ์ ์ํ๋ฉด ์ธ์ด ์ดํด ๋ฐ ์์ฑ์ ๋ด๋นํ๋ ์ธ๊ฐ ๋๋ ์์ญ์ด ์ถ๋ก ๊ณผ์ ์ค์๋ ๋นํ์ฑํ๋๋ค๊ณ ํจ. ์ด๋ ์ธ์ด๋ communication์ ์ ํฉํ ๋ฟ ๋ณต์กํ ๋ฌธ์ ํด๊ฒฐ์๋ ๋ถํ์ํ๋ค๋ ๊ฒ์ ์์ฌํ๋ค.
์ธ๊ฐ์ด ์ถ๋ก ์ค์ ์๊ฐ์ ์ธ์ด๋ก ๋ฐ๊ฟ์ผ ํ ํ์๊ฐ ์๋ ๊ฒ์ฒ๋ผ AI๋ ๋ง์ฐฌ๊ฐ์ง์ด๋ค. (์ธ๊ณต์ง๋ฅ์ด ์ธ๊ฐ์ ๋ฐ๋ผํ๋ค๋ ๊ฒ์ด ๋ง์ด ๋๊ปด์ง๋ ๋ถ๋ถ) โ LLM๋ language space ๋์ latent space์์ reasoning์ ์ํํ ํ์๊ฐ ์๋ค. ๋ชจ๋ธ์ด ์์ฑํ๋ ๋๋ถ๋ถ์ tokens๋ ํ ์คํธ์ ์ผ๊ด์ฑ์ ์ ์งํ๋ ๋ฐ ํ์ํ ๋ฟ ์ค์ ๋ก ์ถ๋ก ์ ํฌ๊ฒ ๊ธฐ์ฌํ์ง ์๋๋ค.
์ธ์ด ์ ์ฝ ์์ด ์์ ๋ก์ด ์ถ๋ก ์ด ๊ฐ๋ฅํ๋๋ก. ํ์ํ ๋๋ง ๊ฒฐ๊ณผ๋ฅผ ์ธ์ด๋ก ๋ฒ์ญํ ํ์๊ฐ ์๋ค.
๋ณธ ์ฐ๊ตฌ๋ word-based reasoning์ ์ ์ฝ์์ ๋ฒ์ด๋ LLM์ด continuous latent space์์ ์ถ๋ก ํด์ผํ๋ค๊ณ ์ ์ํ๋ค. ํด๋น method๋ฅผ CoCoNUT (Chain of Continuous Thought)๋ผ๊ณ ๋ถ๋ฅธ๋ค.
2/ CoT vs CoCoNUT (Chain of Continuous Thought)

CoT๋ ์ถ๋ก ๊ณผ์ ์ word token sequence๋ก ์์ฑํ๋ ๋ฐ๋ฉด, CoCoNUT๋ last hidden state๋ฅผ reasoning state(continuous thought)๋ก ํํํ์ฌ next input embedding์ผ๋ก ์ง์ ์ด์ฉํ๋ค.
โ LLM์ด language space๊ฐ ์๋๋ผ ์ ์ฝ์ด ์๋ latent space์์ ์ถ๋ก ํ ์ ์๋๋ก ํ๋ค.
3/ CoT Method
question โ LLM์ input tokens๋ก embedding๋์ด input.
โ response์ ์ฒซ ๋ฒ์งธ token์ ๋ฐ์ (์ถ๋ก ๊ณผ์ ์์), ํด๋น token์ last hidden state์์ ๊ฐ์ ธ์ด. (์ฆ, backbone Transformer์ ๋ง์ง๋ง ๋ ์ด์ด์ ์ถ๋ ฅ)
โ forward pass ๋ฐ๋ณต, ํ์ฌ stage๊น์ง ๊ฐ์ง reasoning process tokens๊ณผ question์ ๊ณต๊ธ
4/ Coconut Method
language mode์์ latent thought mode๋ก ๋ณ๊ฒฝ.
๋ชจ๋ธ์ ๊ธฐ๋ณธ language model๋ก ์๋ํ๋ฉด์ next token์ ์์ฑ.
latent mode์์ last hidden state๋ฅผ next step์ input์ผ๋ก ์ฌ์ฉ.
last hidden state๋ current reasoning state๋ฅผ ๋ํ๋ด๊ณ , ์ด๋ฅผ โcontinuous thoughtโ๋ผ๊ณ ํจ.
<bot> special token + question *<bot> : latent thought mode ์์
โ question์ ์ฒ๋ฆฌํ๊ณ last hidden state๋ฅผ ์์ฐ (์ด์ ์๋ language token์ผ๋ก ๋ฐ๊ฟจ๋๋ฐ, ์ฌ๊ธฐ์๋ ์๋)
๋์ hidden state๊ฐ ๋ค์ ๋ชจ๋ธ์ input embedding์ผ๋ก question์ embeddings์ special token๊ณผ ํจ๊ป ๋ค์ด๊ฐ๋ค.
โ ๋ฐ๋ณต
โ ๋ฐ๋ณตํ๋ฉด์ ์ ์ ๋ ๋ง์ thought tokens๋ฅผ input์ผ๋ก ํ์ฉ
โ <eot> special token ์ฌ์ฉ *<eot> : latent thought mode ์ข
๋ฃ ๋ฐ language mode ์์
5/ Training

continuous latent space์์ ์ถ๋ก ํ๋ ๋ฒ์ LLM์๊ฒ ์ด๋ป๊ฒ ํ๋ จํ ๊น?
Stage 0์์ ๋ชจ๋ธ์ thought tokens๋ฅผ ์์ฑํ์ง ์๋๋ค๋ค. CoT samples์ผ ๋ต์ ๋ง์ถ๋๋ก, reasoning traces๋ฅผ ์์ฐํ๋๋ก ํ๋ จ๋์ด ์์ ๋ฟ์ด๋ค.
๊ทธ ์ดํ ๊ฐ stage์์ sample๋ก๋ถํฐ ํ๋์ reasoning step์ ์ ๊ฑฐํ๊ณ , ๋์ ์ thought tokens์ ๋ฃ๋๋ค. (ํํ ๊ฐ์ singe reasoning step ๋์ ํ ๊ฐ์ thought tokens์ด ๊ฐ stage์ ์ถ๊ฐ, hyperparametr c๋ก ์ปจํธ๋กค)
*c : ํ ๋ฒ์ ์ถ๋ก ๋จ๊ณ์์ ์์ฑํ ์ ์๋ thought (continuous embeddings)์ ๊ฐ์
๊ฐ stage๋ง๋ค ๋จ์ ์๋ reasoning steps๊ณผ answer์ ๊ดํด์๋ง loss ๊ณ์ฐํ๋ค. (thought token๋ loss ๊ณ์ฐ ์ํจ) ๊ฐ pass๋ง๋ค ์๋ก์ด latent thought๋ฅผ ๊ณ์ฐํ๊ณ , ๋จ์ ์๋ text sequence์ ๋ํด์ loss๋ฅผ ์ป๋๋ค.
loss objective๋ continuous thought๊ฐ language thought๋ฅผ ์์ถํ๊ธฐ๋ณด๋ค reasoning์ ์์ธก ๋ฅ๋ ฅ์ ํฅ์์ํค๋ ๋ฐ ์ง์คํ๋ค. ๊ทธ๋ฌ๋ฏ๋ก ๋ชจ๋ธ์ด human language์ ๋น๊ตํ์ ๋ ๋ ํจ์จ์ ์ธ ํํ์ ๋ฐฐ์ธ ์ ์๋ค. language tokens๋ฅผ ์์ฑํ ํ์ ์์ด ๋ด๋ถ์ ์ผ๋ก ๊ณ์ ์ถ๋ก ํ ์ ์๋๋ก ํ๋ค.
6/ Switching
๋ชจ๋ธ์ ์ธ์ latent thought mode์์ language mode๋ก ๋ฐ๋๋์ง ์ด๋ป๊ฒ ์๊น?
- ๋ชจ๋ธ์ด latent thoughts์ ๊ธฐ๋ฐํด binary classifier ์ฌ์ฉ์ ๊ฒฐ์ ํ๋ ๊ฒ์ ๋ด๋ฒ๋ ค๋๊ธฐ
- latent thoughts์ ์ผ์ ํ ์๋ฅผ ์ฌ์ฉ
๋ ์ ๋ต ๋ชจ๋ ์ ์ฌํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์คฌ์ง๋ง, ๋จ์์ฑ์ ์ด์ ๋ก constant number of thoughts๋ฅผ ํํ๋ค.
7/ Results

GPT-2๋ฅผ ๊ธฐ๋ณธ ๋ชจ๋ธ๋ก ์ฌ์ฉ.
- No-CoT ๋๋น ๋ชจ๋ ๋ฐ์ดํฐ์ ์์ ์ฑ๋ฅ ์ฐ์ธ
- CoT ๋๋น ์ํ (GSM8k; ์ด๋ฑํ ์์ค ์ํ ๋ฌธ์ ) ์ CoT๊ฐ ๋ ์ฐ์ธ. ๊ทธ๋ฌ๋ CoT๋ Coconut์ ๋นํด ๋ ๋ง์ tokens๋ฅผ ํ์๋ก ํจ. โ Coconut์ด ํจ์ฌ ํจ์จ์ . ProsQA์ ๊ฐ์ด ๋จ๊ณ์ ์ฌ๊ณ ๊ฐ ๋ ํ์ํ ๋ฐ์ดํฐ์ ์์ ์ฐ์ธ.
- i-CoT (์ถ๋ก ๊ฐ์ ์ ๋ค๋ฅธ ๋ฐฉ๋ฒ์ผ๋ก ๋ด๋ถํํ๋ ค๋ ๋ฐฉ๋ฒ) ๋๋น ์ ์ฌํ๋ ์ํ์ coconut์ด ๋ ์ฐ์ธ. ๊ทธ๋ฌ๋ i-CoT๋ ๋ ์ ์ tokens ์ฌ์ฉ
w/o curriculum๋ multi-stage tranining์ ์ค์์ฑ์ ๋ณด์ฌ์ค๋ค.
8/ BFS-like Reasoning

ProsQA ๋ฐ์ดํฐ์ ๊ณผ ๊ฐ์ planning-intensive task์์ Coconut์ด CoT๋ณด๋ค ์ฐ์ธํ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์ฌ์ค๋ค.
CoT๋ ์ถ๋ก ์ค hallucinated edge๋ฅผ ์์ฑํด ์๋ชป๋ ๊ฒฐ๊ณผ๋ฅผ ๋์ถํ๊ณ , Coconut์ thought tokens๋ฅผ ์ฌ์ฉํด BFS์ ๋น์ทํ๊ฒ ์ฌ๋ฌ ๊ฒฝ๋ก ํ์ํ๋ค. thought token๋ฅผ ํ๋ ์ฌ์ฉํ ๋๋ ์ค๋ต์ด์ง๋ง, ๋ ๊ฐ๋ฅผ ์ฌ์ฉํ ๋๋ ๋ ๋ง์ ๊ฒฝ๋ก๋ฅผ ์ฐพ๊ณ ์ฌ๋ฐ๋ฅธ ๋ต ๋์ถํ๋ค. ๋จ์ผ ๊ฒฝ๋ก๊ฐ ์๋๋ผ ์ฌ๋ฌ ๊ฒฝ๋ก๋ฅผ ํตํด ํ์ํ๋ค๋ ์ ์ BFS์ ๋น์ทํ๋ค.
9/ Takeaways

Auto mode์์๋ LLM์ planning์ ํ์ง ๋ชปํ๋ค.
CoT, ReACT (Reasoning + Acting)
์ฌ๋์ด ๊ณ์ ํ๋กฌํํธ๋ฅผ ์์ ํ๋ ๊ณผ์ ์์ ๋ชจ๋ธ์ด ์ ๋ต์ ๋ง์ถ๋ ๊ฒ๊ณผ ๊ฐ์ Clever Hans ํจ๊ณผ๊ฐ ๋ฐ์ํ ์ ์๋ค.
*Clever Hans : LLM ์์ฒด ๊ฒ์ฆ ๋ฐ ๊ฐ์ ๋ฅ๋ ฅ์ด ๋ถ์กฑํ์ฌ ๋ ผ๋ฆฌ์ ์ผ๋ก ๋ฌธ์ ํด๊ฒฐํ๋ ๊ฒ์ด ์๋๋ผ ํ๋กฌํํธ, ์ฆ shallow heuristic์ ์์กดํ๋ฉด์ ๋ต์ ๋ด๋๋ ๊ฒ.
Clever Hans or Neural Theory of Mind? Stress Testing Social Reasoning in Large Language Models
(๋ค๋ง ํด๋น ๋ ผ๋ฌธ์์ 4o, o1์ ์๋ ์ ์ฐธ๊ณ ํด์ผ ํ ๊ฒ ๊ฐ๋ค… N-ToM;Neural Theory of Mind๊ณผ ๊ฐ์ด ์ธ๊ฐ ์์ค์ ์ ์์ ์ถ๋ก ๋ฅ๋ ฅ์ ๋ชจ๋ฐฉํ๊ธฐ ์ํด์๋ ์์ง ๊ฐ์ ์ด ํ์..?)
Reference
Training Large Language Models to Reason in a Continuous Latent Space
Chain of Continuous Thought (AIPapers Academy)
๋ฉํ์ ์ฝ์ฝ๋(COCONUT): ์ธ์ด ์์ด ์๊ฐํ๋ AI ๋ฐฉ๋ฒ (AI๋ท)