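As mentioned earlier, the Transformer needs three matrices: Q (query), K (key), and V (value). They are combined with scaled dot-product attention:
$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
where $d_k$ is the depth (dimension) of the key vectors. This is the formula earlier researchers found to work best. If I find time later to understand how it was derived, I'll add that here.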
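To make the formula concrete without TensorFlow, here is a minimal pure-Python sketch (standard library only; `softmax` and `attention` are my own helper names, not a library API). It applies the formula one query at a time to the same K and V matrices used in the examples below.

```python
import math


def softmax(xs):
    # Subtract the max before exponentiating; this avoids overflow and does not change the result.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V on plain Python lists, one query row at a time."""
    d_k = len(K[0])
    outputs, weights = [], []
    for q in Q:
        # One score per key: dot(q, k) / sqrt(d_k)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        weights.append(w)
        # Output row = weighted sum of the value rows
        outputs.append([sum(wj * v[c] for wj, v in zip(w, V)) for c in range(len(V[0]))])
    return outputs, weights


K = [[10, 0, 0], [0, 10, 0], [0, 0, 10], [0, 0, 10]]
V = [[1, 0], [10, 0], [100, 5], [1000, 6]]

out, w = attention([[0, 10, 0]], K, V)
print([round(x, 3) for x in w[0]])    # [0.0, 1.0, 0.0, 0.0]
print([round(x, 3) for x in out[0]])  # [10.0, 0.0]
```

The loop mirrors the formula term by term: dot products, division by $\sqrt{d_k}$, a softmax over the keys, then a weighted sum of V.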
import numpy as np
import tensorflow as tf
def scaled_dot_product_attention(q, k, v, mask=None):
    """Calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead)
    but it must be broadcastable for addition.
    Args:
      q: query shape == (..., seq_len_q, depth)
      k: key shape == (..., seq_len_k, depth)
      v: value shape == (..., seq_len_v, depth_v)
      mask: Float tensor with shape broadcastable
            to (..., seq_len_q, seq_len_k). Defaults to None.
    Returns:
      output, attention_weights
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
    # scale matmul_qk by sqrt(d_k), the depth of the keys
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)
    # add the mask to the scaled tensor: positions where mask == 1 get -1e9,
    # so their softmax weight becomes ~0.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)
    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)
    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)
    return output, attention_weights
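This is how the positions that deserve attention are kept while the rest are effectively removed: keys that match the query get softmax weights close to 1, poorly matching keys get weights close to 0, and masked positions (mask value 1, multiplied by -1e9 before the softmax) end up with a weight of essentially 0.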
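To see how the `mask * -1e9` trick removes positions, here is a small pure-Python sketch. It assumes the same convention as the code above (1 = hide, 0 = keep); `look_ahead_mask` and `masked_weights` are hypothetical helper names, and the scores are made-up numbers chosen only for illustration.

```python
import math

NEG_LARGE = -1e9  # same constant the TensorFlow code multiplies the mask by


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # very negative inputs underflow to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def look_ahead_mask(n):
    # Hypothetical helper: 1.0 marks a future position (j > i) to hide, 0.0 keeps it.
    return [[1.0 if j > i else 0.0 for j in range(n)] for i in range(n)]


def masked_weights(scores, mask_row):
    # Mirrors `scaled_attention_logits += (mask * -1e9)` followed by a softmax.
    logits = [s + m * NEG_LARGE for s, m in zip(scores, mask_row)]
    return softmax(logits)


mask = look_ahead_mask(3)
print(mask)  # [[0.0, 1.0, 1.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
scores = [2.0, 5.0, 9.0]  # made-up logits
for row in mask:
    print([round(w, 3) for w in masked_weights(scores, row)])
# prints [1.0, 0.0, 0.0], then [0.047, 0.953, 0.0], then [0.001, 0.018, 0.981]
```

Row i of the look-ahead mask hides every position after i, so the first row can only attend to itself, and the remaining weight in each row is redistributed over the unmasked positions.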
def print_out(q, k, v):
    temp_out, temp_attn = scaled_dot_product_attention(
        q, k, v, None)
    print('Attention weights are:')
    print(temp_attn)
    print('Output is:')
    print(temp_out)
np.set_printoptions(suppress=True)
temp_k = tf.constant([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]], dtype=tf.float32)  # (4, 3)
temp_v = tf.constant([[1, 0],
                      [10, 0],
                      [100, 5],
                      [1000, 6]], dtype=tf.float32)  # (4, 2)
# This `query` aligns with the second `key`,
# so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 1. 0. 0.]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[10. 0.]], shape=(1, 2), dtype=float32)
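Working the numbers by hand for the query [0, 10, 0] (here $d_k = 3$) shows why the weights print as [0, 1, 0, 0]:

$$
\begin{aligned}
q K^\top &= [0,\ 100,\ 0,\ 0] \\
\frac{q K^\top}{\sqrt{3}} &\approx [0,\ 57.7,\ 0,\ 0] \\
\mathrm{softmax}(\cdot) &\approx [e^{-57.7},\ 1,\ e^{-57.7},\ e^{-57.7}] \approx [8.4\times10^{-26},\ 1,\ 8.4\times10^{-26},\ 8.4\times10^{-26}] \\
\text{output} &\approx 1 \cdot [10,\ 0] = [10,\ 0]
\end{aligned}
$$

The non-matching weights are not exactly zero, just small enough to display as 0.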
# This query aligns with a repeated key (third and fourth),
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0. 0. 0.5 0.5]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[550. 5.5]], shape=(1, 2), dtype=float32)
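Working through the query [0, 0, 10] by hand ($d_k = 3$): it scores equally against the duplicated third and fourth keys, so the softmax splits the weight evenly between them and the output is the plain average of their values:

$$
\begin{aligned}
\frac{q K^\top}{\sqrt{3}} &\approx [0,\ 0,\ 57.7,\ 57.7] \\
\mathrm{softmax}(\cdot) &\approx [0,\ 0,\ 0.5,\ 0.5] \\
\text{output} &= 0.5 \cdot [100,\ 5] + 0.5 \cdot [1000,\ 6] = [550,\ 5.5]
\end{aligned}
$$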
# This query aligns equally with the first and second key,
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=tf.float32) # (1, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor([[0.5 0.5 0. 0. ]], shape=(1, 4), dtype=float32)
Output is:
tf.Tensor([[5.5 0. ]], shape=(1, 2), dtype=float32)
Now stack all of these queries into a single Q matrix (3 queries, depth 3):
temp_q = tf.constant([[0, 0, 10],
                      [0, 10, 0],
                      [10, 10, 0]], dtype=tf.float32)  # (3, 3)
print_out(temp_q, temp_k, temp_v)
Attention weights are:
tf.Tensor(
[[0. 0. 0.5 0.5]
[0. 1. 0. 0. ]
[0.5 0.5 0. 0. ]], shape=(3, 4), dtype=float32)
Output is:
tf.Tensor(
[[550. 5.5]
[ 10. 0. ]
[ 5.5 0. ]], shape=(3, 2), dtype=float32)
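Each row of Q is scored against all of K independently, so stacking the queries gives the same rows as running them one at a time; only the shapes change: (3, 3) × (3, 4) → (3, 4) weights, then (3, 4) × (4, 2) → (3, 2) output. A minimal pure-Python sketch of the matrix form (standard library only; `matmul`, `transpose`, `softmax_rows`, and `attention` are my own helper names) checks this on the same data:

```python
import math


def matmul(A, B):
    # (n, m) x (m, p) -> (n, p)
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def softmax_rows(M):
    # Row-wise softmax, matching tf.nn.softmax(..., axis=-1)
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        s = sum(exps)
        out.append([e / s for e in exps])
    return out


def attention(Q, K, V):
    d_k = len(K[0])
    logits = [[x / math.sqrt(d_k) for x in row] for row in matmul(Q, transpose(K))]  # (n_q, n_k)
    weights = softmax_rows(logits)  # each row sums to 1
    return matmul(weights, V), weights  # (n_q, depth_v)


K = [[10, 0, 0], [0, 10, 0], [0, 0, 10], [0, 0, 10]]
V = [[1, 0], [10, 0], [100, 5], [1000, 6]]
Q = [[0, 0, 10], [0, 10, 0], [10, 10, 0]]

batched, _ = attention(Q, K, V)
one_by_one = [attention([q], K, V)[0][0] for q in Q]
print([[round(x, 2) for x in row] for row in batched])  # [[550.0, 5.5], [10.0, 0.0], [5.5, 0.0]]
print(batched == one_by_one)  # True
```

The rows match the TensorFlow printout above, and the equality check confirms that batching the queries does not change any individual result.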