string模拟实现(直接上源码)

小a草

关注

阅读 18

2024-04-28


前言

通过这篇文章,你可以学习到Tensorflow实现MultiHeadAttention的底层原理。


一、MultiHeadAttention的本质内涵

1.Self_Atention机制

2.MultiHead_Atention机制

二、使MultiHeadAttention在TensorFlow中的代码实现

1.参数说明

2.整体结构

        ''' 多头映射层 '''
        query = self._query_dense(query)
        key = self._key_dense(key)
        value = self._value_dense(value)
        
        ''' 注意力层 '''
        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask, training
        )
        
        ''' 输出层 '''
        attention_output = self._output_dense(attention_output)

3.多头映射层

4.注意力层

5.输出映射层


验证

import tensorflow as tf


layer = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=2)
target = tf.keras.Input(shape=[9, 16])
source = tf.keras.Input(shape=[4, 16])
output_tensor, weights = layer(query=target, value=source,
                               return_attention_scores=True)

''' 手动计算训练参数总数 '''
sum = 16*2*2*3+2*2*3+2*2*16+16
print(f'手动计算的训练参数总数为 : {sum}')
print(f'训练参数总共为 : {layer.count_params()}')
print(f'输出shape为 : {output_tensor.shape}')
print(f'注意力分数shape为 : {weights.shape}')



手动计算的训练参数总数为 : 284
训练参数总共为 : 284
输出shape为 : (None, 9, 16)
注意力分数shape为 : (None, 2, 9, 4)

精彩评论(0)

0 0 举报