๐ง ํธ๋์คํฌ๋จธ(Transformer) ๊ธฐ์ด์ ์๋ฆฌ
ํธ๋์คํฌ๋จธ๋ ์ธ๊ณต์ง๋ฅ ๋ถ์ผ์์ ํ๋ช ์ ์ธ ์ฑ๊ณผ๋ฅผ ์ด๋ฃฌ ๋ฅ๋ฌ๋ ์ํคํ ์ฒ์ ๋๋ค. 2017๋ ๊ตฌ๊ธ ์ฐ๊ตฌ์ง์ด ๋ฐํํ "Attention Is All You Need" ๋ ผ๋ฌธ์ ํตํด ์๊ฐ๋์์ผ๋ฉฐ, ์ดํ GPT, BERT ๋ฑ ๋๊ท๋ชจ ์ธ์ด๋ชจ๋ธ์ ๊ธฐ๋ฐ์ด ๋์์ต๋๋ค.
๐ ํธ๋์คํฌ๋จธ์ ํ์ ๋ฐฐ๊ฒฝ
ํธ๋์คํฌ๋จธ ์ด์ ์๋ ์ํ์ค ๋ชจ๋ธ๋ง์ ์ํด RNN(Recurrent Neural Network)๊ณผ LSTM(Long Short-Term Memory)์ด ์ฃผ๋ก ์ฌ์ฉ๋์์ต๋๋ค. ๊ทธ๋ฌ๋ ์ด๋ค ์ํคํ ์ฒ๋ ์์ฐจ์ ์ฒ๋ฆฌ ๋ฐฉ์์ผ๋ก ์ธํด ๋ณ๋ ฌ ์ฒ๋ฆฌ๊ฐ ๋ถ๊ฐ๋ฅํ๊ณ , ๊ธด ์ํ์ค ํ์ต ์ Gradient Vanishing ๋ฌธ์ ๊ฐ ์์์ต๋๋ค.
ํธ๋์คํฌ๋จธ๋ Self-Attention Mechanism์ ๋์ ํ์ฌ ๋ชจ๋ ํ ํฐ ๊ฐ์ ๊ด๊ณ๋ฅผ ๋ณ๋ ฌ๋ก ๊ณ์ฐํ ์ ์๊ฒ ํ๊ณ , ์ด๋ก ์ธํด ํจ์ฌ ๋น ๋ฅธ ํ์ต ์๋์ ๋ ๋์ ์ ํ๋๋ฅผ ๋ฌ์ฑํ ์ ์์์ต๋๋ค.
โ๏ธ ์ฃผ์ ๊ตฌ์ฑ ์์
Encoder
์ ๋ ฅ ์ํ์ค๋ฅผ ์ธ์ฝ๋ฉํ์ฌ ์ปจํ ์คํธ ๋ฒกํฐ ์์ฑ
Decoder
Encoder์ ์ถ๋ ฅ์ ๋ฐ์ ๋ชฉํ ์ํ์ค ์์ฑ
Self-Attention
๊ฐ ํ ํฐ์ด ๋ค๋ฅธ ํ ํฐ๊ณผ ์ด๋ค ๊ด๊ณ๋ฅผ ๊ฐ๋์ง ๊ณ์ฐ
Multi-Head
์ฌ๋ฌ ๊ฐ์ Attention์ ๋ณ๋ ฌ๋ก ์คํํ์ฌ ๋ค์ํ ๊ด์ ํ์ต
๐ Self-Attention Mechanism
Self-Attention์ ํธ๋์คํฌ๋จธ์ ๊ฐ์ฅ ํต์ฌ์ ์ธ ๊ธฐ์ ์ ๋๋ค. ์ด ๋ฉ์ปค๋์ฆ์ ํตํด ๊ฐ ๋จ์ด๊ฐ ๋ฌธ์ฅ ๋ด ๋ค๋ฅธ ๋จ์ด๋ค๊ณผ ์ด๋ค ๊ด๊ณ๋ฅผ ๋งบ๋์ง(์ฐ๊ด์ฑ)์ ํ์ ํ ์ ์์ต๋๋ค.
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
def __init__(self, embed_size, heads):
super(SelfAttention, self).__init__()
self.embed_size = embed_size
self.heads = heads
self.head_dim = embed_size // heads
assert self.head_dim * heads == embed_size, "Embed size must be divisible by heads"
self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
def forward(self, values, keys, queries, mask):
N = queries.shape[0]
value_len, key_len, query_len = values.shape[1], keys.shape[1], queries.shape[1]
# Split embedding into self.heads pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = queries.reshape(N, query_len, self.heads, self.head_dim)
# Compute attention scores
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
attention = torch.softmax(energy / (self.embed_size ** 0.5), dim=3)
# Apply attention to values
out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)
out = self.fc_out(out)
return out
์ ์ฝ๋๋ PyTorch๋ก ๊ตฌํํ Self-Attention ๋ ์ด์ด์ ๋๋ค. Query, Key, Value ์ธ ๋ฒกํฐ๋ฅผ ์ฌ์ฉํด attention ์ ์๋ฅผ ๊ณ์ฐํ๊ณ , ์ด๋ฅผ ํตํด ๊ฐ ํ ํฐ์ ์ค์๋๋ฅผ ๊ฒฐ์ ํฉ๋๋ค.
๐ ํธ๋์คํฌ๋จธ์ ์์ฉ
ํธ๋์คํฌ๋จธ ์ํคํ ์ฒ๋ ์ด์ ๋ค์ํ ๋ถ์ผ์์ ํ์ฉ๋๊ณ ์์ต๋๋ค:
- NLP: GPT, BERT, T5 ๋ฑ ์ธ์ด๋ชจ๋ธ
- Computer Vision: Vision Transformer (ViT)
- Audio: WaveNet, Wav2Vec
- Reinforcement Learning: AlphaFold, AlphaZero
- Multi-modal: CLIP, DALL-E
โ ์๋ฆฌ์ ์ฎ๊น โ
๐ ์ฐธ๊ณ : "Attention Is All You Need" (2017, Google Research)
๋๊ธ๋ชฉ๋ก
๋ฑ๋ก๋ ๋๊ธ์ด ์์ต๋๋ค.