Day 26 Implementing a Translation Program with Transformer on Your Own (8): Multi-head Attention

Multi-head attention

The related concepts were already covered in Day 12 Self-attention (6): Multi-Head Self-attention.

A detailed walkthrough of the code will be added later; since I am still reading up on this material myself, it may take some time.
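
The MultiHeadAttention class below calls scaled_dot_product_attention, which was implemented on an earlier day of this series. So that the snippet can run on its own, here is a minimal sketch of that helper (together with the tensorflow import), assuming the standard softmax(QK^T / sqrt(d_k))V formulation from the TensorFlow transformer tutorial; the version in the earlier article may differ in detail.

import tensorflow as tf

def scaled_dot_product_attention(q, k, v, mask):
  """Minimal sketch (assumed standard formulation): softmax(QK^T / sqrt(d_k)) V."""
  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # Scale by sqrt(d_k) so the softmax stays in a well-behaved range.
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # The mask (if given) is added as a large negative number so that masked
  # positions receive ~0 weight after the softmax.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)

  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights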

class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    # d_model must split evenly across the heads so that each head
    # works on a slice of size depth = d_model // num_heads.
    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    # Linear projections for the queries, keys, and values.
    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    # Final linear layer applied after the heads are concatenated.
    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention,
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights

temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512))  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, k=y, q=y, mask=None)
out.shape, attn.shape
(TensorShape([1, 60, 512]), TensorShape([1, 8, 60, 60]))
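
As an extra (hypothetical) check that is not in the original article: the query sequence does not need to be the same length as the key/value sequence, which is exactly the situation in the decoder's cross-attention. With a shorter query, the attention weights take the shape (batch_size, num_heads, seq_len_q, seq_len_k):

q_in = tf.random.uniform((1, 45, 512))  # hypothetical shorter "decoder" sequence
out2, attn2 = temp_mha(y, k=y, q=q_in, mask=None)
out2.shape, attn2.shape
(TensorShape([1, 45, 512]), TensorShape([1, 8, 45, 60]))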
