🥥 Coconut: Training Large Language Models to Reason in a Continuous Latent Space (Meta, 2024)

1/ Chain-Of-Thought (CoT)

CoT's limitation: it constrains the LLM's reasoning to be generated in the form of text.

According to neuroimaging studies, the regions of the human brain responsible for language comprehension and production stay largely inactive during reasoning. This suggests that language is suited for communication but unnecessary for complex problem solving.

์ธ๊ฐ„์ด ์ถ”๋ก  ์ค‘์— ์ƒ๊ฐ์„ ์–ธ์–ด๋กœ ๋ฐ”๊ฟ”์•ผ ํ•  ํ•„์š”๊ฐ€ ์—†๋Š” ๊ฒƒ์ฒ˜๋Ÿผ AI๋„ ๋งˆ์ฐฌ๊ฐ€์ง€์ด๋‹ค. (์ธ๊ณต์ง€๋Šฅ์ด ์ธ๊ฐ„์„ ๋”ฐ๋ผํ•œ๋‹ค๋Š” ๊ฒƒ์ด ๋งŽ์ด ๋А๊ปด์ง€๋Š” ๋ถ€๋ถ„) โ†’ LLM๋„ language space ๋Œ€์‹  latent space์—์„œ reasoning์„ ์ˆ˜ํ–‰ํ•  ํ•„์š”๊ฐ€ ์žˆ๋‹ค. ๋ชจ๋ธ์ด ์ƒ์„ฑํ•˜๋Š” ๋Œ€๋ถ€๋ถ„์˜ tokens๋Š” ํ…์ŠคํŠธ์˜ ์ผ๊ด€์„ฑ์„ ์œ ์ง€ํ•˜๋Š” ๋ฐ ํ•„์š”ํ•  ๋ฟ ์‹ค์ œ๋กœ ์ถ”๋ก ์— ํฌ๊ฒŒ ๊ธฐ์—ฌํ•˜์ง€ ์•Š๋Š”๋‹ค.

The goal: reason freely without the constraints of language, translating the result into language only when needed.

This work proposes that LLMs should escape the constraints of word-based reasoning and reason in a continuous latent space. The method is called Coconut (Chain of Continuous Thought).

2/ CoT vs Coconut (Chain of Continuous Thought)

image_1.png

Whereas CoT generates the reasoning process as a sequence of word tokens, Coconut treats the last hidden state as the reasoning state (a "continuous thought") and uses it directly as the next input embedding.

โ†’ LLM์ด language space๊ฐ€ ์•„๋‹ˆ๋ผ ์ œ์•ฝ์ด ์—†๋Š” latent space์—์„œ ์ถ”๋ก ํ•  ์ˆ˜ ์žˆ๋„๋ก ํ•œ๋‹ค.

3/ CoT Method

question → embedded as input tokens and fed into the LLM.

โ†’ response์˜ ์ฒซ ๋ฒˆ์งธ token์„ ๋ฐ›์Œ (์ถ”๋ก  ๊ณผ์ • ์‹œ์ž‘), ํ•ด๋‹น token์€ last hidden state์—์„œ ๊ฐ€์ ธ์˜ด. (์ฆ‰, backbone Transformer์˜ ๋งˆ์ง€๋ง‰ ๋ ˆ์ด์–ด์˜ ์ถœ๋ ฅ)

→ The forward pass is repeated, feeding in the question together with the reasoning-process tokens generated so far.
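The loop above can be sketched as a toy autoregressive decoder. The point is that in standard CoT the last hidden state is projected to the vocabulary and *discretized* into a token before being fed back. All names here (TinyLM, cot_decode, etc.) are illustrative, not from the paper's code:

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy stand-in for a Transformer LM: token ids in, next-token logits out."""
    def __init__(self, vocab_size=16, hidden_dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.layer = nn.Linear(hidden_dim, hidden_dim)  # stands in for the Transformer layers
        self.head = nn.Linear(hidden_dim, vocab_size)   # hidden state -> vocab logits

    def forward(self, ids):
        hidden = torch.tanh(self.layer(self.embed(ids)))
        return self.head(hidden[-1])  # logits from the last position's hidden state

def cot_decode(model, question_ids, max_new_tokens=4):
    ids = question_ids.clone()
    for _ in range(max_new_tokens):
        logits = model(ids)                # forward pass over question + tokens so far
        next_id = logits.argmax().view(1)  # greedy pick: discretizes the hidden state
        ids = torch.cat([ids, next_id])    # reasoning stays in token (language) space
    return ids

torch.manual_seed(0)
out = cot_decode(TinyLM(), torch.tensor([1, 2, 3]))
print(out.shape)  # 3 question tokens + 4 generated reasoning tokens
```

The `argmax` step is exactly what Coconut removes: every reasoning step is forced through the discrete vocabulary.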

4/ Coconut Method

language mode์—์„œ latent thought mode๋กœ ๋ณ€๊ฒฝ.

๋ชจ๋ธ์€ ๊ธฐ๋ณธ language model๋กœ ์ž‘๋™ํ•˜๋ฉด์„œ next token์„ ์ƒ์„ฑ.

latent mode์—์„œ last hidden state๋ฅผ next step์˜ input์œผ๋กœ ์‚ฌ์šฉ.

The last hidden state represents the current reasoning state and is called a "continuous thought".

<bot> special token + question (*<bot>: begins latent thought mode)

โ†’ question์„ ์ฒ˜๋ฆฌํ•˜๊ณ  last hidden state๋ฅผ ์ƒ์‚ฐ (์ด์ „์—๋Š” language token์œผ๋กœ ๋ฐ”๊ฟจ๋Š”๋ฐ, ์—ฌ๊ธฐ์„œ๋Š” ์•„๋‹˜)

Instead, the hidden state goes back into the model as an input embedding, alongside the question's embeddings and the special token.

→ Repeat.

→ With each repetition, more and more thought tokens are used as input.

→ The <eot> special token is emitted (*<eot>: ends latent thought mode and resumes language mode).
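A minimal sketch of this latent-mode loop, assuming a toy backbone (TinyBackbone, coconut_rollout, and all dimensions are illustrative, not the repo's implementation). The key line is that the last hidden state is appended *unchanged* as the next input embedding, with no projection to the vocabulary in between:

```python
import torch
import torch.nn as nn

class TinyBackbone(nn.Module):
    """Toy stand-in for a Transformer backbone: input embeddings in, hidden states out."""
    def __init__(self, vocab_size=16, hidden_dim=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)
        self.layer = nn.Linear(hidden_dim, hidden_dim)  # stands in for the Transformer layers

    def forward(self, input_embeds):
        # input_embeds: (seq_len, hidden_dim) -> hidden states of the same shape
        return torch.tanh(self.layer(input_embeds))

def coconut_rollout(model, question_ids, num_thoughts):
    """Latent mode: repeatedly feed the last hidden state back as the next input embedding."""
    embeds = model.embed(question_ids)  # question tokens -> embeddings
    for _ in range(num_thoughts):
        hidden = model(embeds)                       # full forward pass
        thought = hidden[-1:]                        # last hidden state = continuous thought
        embeds = torch.cat([embeds, thought], dim=0) # feed it back, undistorted
    return embeds

out = coconut_rollout(TinyBackbone(), torch.tensor([1, 2, 3]), num_thoughts=2)
print(out.shape)  # 3 question embeddings + 2 continuous thoughts
```

Because the continuous thought never passes through the softmax/vocabulary bottleneck, the chain of thoughts is fully differentiable and can be trained end-to-end with backpropagation.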

5/ Training

image_2.png

continuous latent space์—์„œ ์ถ”๋ก ํ•˜๋Š” ๋ฒ•์„ LLM์—๊ฒŒ ์–ด๋–ป๊ฒŒ ํ›ˆ๋ จํ• ๊นŒ?

Stage 0์—์„œ ๋ชจ๋ธ์€ thought tokens๋ฅผ ์ƒ์„ฑํ•˜์ง€ ์•Š๋Š”๋‹ค๋‹ค. CoT samples์œผ ๋‹ต์„ ๋งž์ถ”๋„๋ก, reasoning traces๋ฅผ ์ƒ์‚ฐํ•˜๋„๋ก ํ›ˆ๋ จ๋˜์–ด ์žˆ์„ ๋ฟ์ด๋‹ค.

๊ทธ ์ดํ›„ ๊ฐ stage์—์„œ sample๋กœ๋ถ€ํ„ฐ ํ•˜๋‚˜์˜ reasoning step์„ ์ œ๊ฑฐํ•˜๊ณ , ๋Œ€์‹ ์— thought tokens์„ ๋„ฃ๋Š”๋‹ค. (ํ•œํ•œ ๊ฐœ์˜ singe reasoning step ๋Œ€์‹  ํ•œ ๊ฐœ์˜ thought tokens์ด ๊ฐ stage์— ์ถ”๊ฐ€, hyperparametr c๋กœ ์ปจํŠธ๋กค)

*c : ํ•œ ๋ฒˆ์˜ ์ถ”๋ก  ๋‹จ๊ณ„์—์„œ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” thought (continuous embeddings)์˜ ๊ฐœ์ˆ˜

๊ฐ stage๋งˆ๋‹ค ๋‚จ์•„ ์žˆ๋Š” reasoning steps๊ณผ answer์— ๊ด€ํ•ด์„œ๋งŒ loss ๊ณ„์‚ฐํ•œ๋‹ค. (thought token๋Š” loss ๊ณ„์‚ฐ ์•ˆํ•จ) ๊ฐ pass๋งˆ๋‹ค ์ƒˆ๋กœ์šด latent thought๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , ๋‚จ์•„ ์žˆ๋Š” text sequence์— ๋Œ€ํ•ด์„œ loss๋ฅผ ์–ป๋Š”๋‹ค.

The loss objective focuses on improving the predictive power of reasoning rather than on compressing language thoughts into continuous thoughts. The model can therefore learn representations that are more efficient than human language, and it can keep reasoning internally without having to generate language tokens.
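The curriculum above can be sketched as a data-construction step. This is a sketch under assumptions: the `<bot>`/`<eot>`/`<thought>` string markers, the step list, and `build_stage` are illustrative, not the repo's actual data format, and the loss mask shown simply zeroes out everything up to and including `<eot>`:

```python
def build_stage(question, steps, answer, stage, c=1):
    """At stage k, the first k language reasoning steps are replaced by k*c thought slots."""
    latent = ["<thought>"] * (stage * c)  # placeholders later filled by continuous thoughts
    remaining = steps[stage:]             # later steps stay in language
    tokens = [question, "<bot>", *latent, "<eot>", *remaining, answer]
    # Loss is computed only on the remaining language steps and the answer;
    # the question, <bot>, the thoughts, and <eot> are masked out (0 = no loss).
    loss_mask = [0, 0] + [0] * len(latent) + [0] + [1] * len(remaining) + [1]
    return tokens, loss_mask

tokens, mask = build_stage("Q", ["step1", "step2", "step3"], "A", stage=2, c=1)
print(tokens)  # ['Q', '<bot>', '<thought>', '<thought>', '<eot>', 'step3', 'A']
print(mask)    # [0, 0, 0, 0, 0, 1, 1]
```

At stage 0 the mask covers the full CoT trace; by the final stage all language steps have been swapped for thought slots and only the answer is supervised.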

6/ Switching

๋ชจ๋ธ์€ ์–ธ์ œ latent thought mode์—์„œ language mode๋กœ ๋ฐ”๋€Œ๋Š”์ง€ ์–ด๋–ป๊ฒŒ ์•Œ๊นŒ?

  • ๋ชจ๋ธ์ด latent thoughts์— ๊ธฐ๋ฐ˜ํ•ด binary classifier ์‚ฌ์šฉ์„ ๊ฒฐ์ •ํ•˜๋Š” ๊ฒƒ์„ ๋‚ด๋ฒ„๋ ค๋‘๊ธฐ
  • latent thoughts์˜ ์ผ์ •ํ•œ ์ˆ˜๋ฅผ ์‚ฌ์šฉ

Both strategies performed comparably, so a constant number of thoughts is chosen for simplicity.

7/ Results

image_3.png

GPT-2 is used as the base model.

  • No-CoT ๋Œ€๋น„ ๋ชจ๋“  ๋ฐ์ดํ„ฐ์…‹์—์„œ ์„ฑ๋Šฅ ์šฐ์„ธ
  • CoT ๋Œ€๋น„ ์ˆ˜ํ•™ (GSM8k; ์ดˆ๋“ฑํ•™ ์ˆ˜์ค€ ์ˆ˜ํ•™ ๋ฌธ์ œ) ์€ CoT๊ฐ€ ๋” ์šฐ์„ธ. ๊ทธ๋Ÿฌ๋‚˜ CoT๋Š” Coconut์— ๋น„ํ•ด ๋” ๋งŽ์€ tokens๋ฅผ ํ•„์š”๋กœ ํ•จ. โ†’ Coconut์ด ํ›จ์”ฌ ํšจ์œจ์ . ProsQA์™€ ๊ฐ™์ด ๋‹จ๊ณ„์  ์‚ฌ๊ณ ๊ฐ€ ๋” ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ์…‹์—์„œ ์šฐ์„ธ.
  • i-CoT (์ถ”๋ก  ๊ฐ€์ •์„ ๋‹ค๋ฅธ ๋ฐฉ๋ฒ•์œผ๋กœ ๋‚ด๋ถ€ํ™”ํ•˜๋ ค๋Š” ๋ฐฉ๋ฒ•) ๋Œ€๋น„ ์œ ์‚ฌํ•˜๋‚˜ ์ˆ˜ํ•™์€ coconut์ด ๋” ์šฐ์„ธ. ๊ทธ๋Ÿฌ๋‚˜ i-CoT๋Š” ๋” ์ ์€ tokens ์‚ฌ์šฉ

The w/o curriculum ablation shows the importance of multi-stage training.

8/ BFS-like Reasoning

image_4.png image_5.png

ProsQA ๋ฐ์ดํ„ฐ์…‹๊ณผ ๊ฐ™์€ planning-intensive task์—์„œ Coconut์ด CoT๋ณด๋‹ค ์šฐ์„ธํ•œ ๊ฒฐ๊ณผ๋ฅผ ๋ณด์—ฌ์ค€๋‹ค.

CoT produced hallucinated edges during reasoning and reached wrong conclusions, whereas Coconut uses its thought tokens to explore multiple paths, much like BFS. With one thought token it gets the answer wrong, but with two it explores more paths and reaches the correct answer. It resembles BFS in that it searches through multiple paths rather than committing to a single one.

9/ Takeaways

image_6.png

Auto mode์—์„œ๋Š” LLM์€ planning์„ ํ•˜์ง€ ๋ชปํ•œ๋‹ค.

CoT, ReAct (Reasoning + Acting)

์‚ฌ๋žŒ์ด ๊ณ„์† ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ˆ˜์ •ํ•˜๋Š” ๊ณผ์ •์—์„œ ๋ชจ๋ธ์ด ์ •๋‹ต์„ ๋งž์ถ”๋Š” ๊ฒƒ๊ณผ ๊ฐ™์€ Clever Hans ํšจ๊ณผ๊ฐ€ ๋ฐœ์ƒํ•  ์ˆ˜ ์žˆ๋‹ค.

*Clever Hans: when an LLM, lacking the ability to verify and improve its own output, produces answers not by solving the problem logically but by leaning on the prompt, i.e., shallow heuristics.

Clever Hans or Neural Theory of Mind? Stress Testing Social Reasoning in Large Language Models

(Note, though, that 4o and o1 are not covered in that paper… Perhaps further improvement is still needed before models can emulate human-level emotional reasoning such as N-ToM (Neural Theory of Mind)?)

Reference

Training Large Language Models to Reason in a Continuous Latent Space

facebookresearch/coconut

Chain of Continuous Thought (AIPapers Academy)

๋ฉ”ํƒ€์˜ ์ฝ”์ฝ”๋„›(COCONUT): ์–ธ์–ด ์—†์ด ์ƒ๊ฐํ•˜๋Š” AI ๋ฐฉ๋ฒ• (AI๋„ท)