关于 TensorFlow 的一些零散知识

阅读 116

2022-06-01

关于 TensorFlow 的一些零散知识

TensorFlow 中的内容相当繁杂, 及时总结是一个好习惯; 平时会收集/总结一些有用的知识点和代码片段, 放在本篇博文下是很合适的. 嘻嘻, 我就是想水一篇文章… ????????????

广而告之

可以在微信中搜索 “珍妮的算法之路” 或者 “world4458” 关注我的微信公众号;另外可以看看知乎专栏 ​​PoorMemory-机器学习​​, 以后文章也会发在知乎专栏中;

变量初始化

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
sess.run(tf.tables_initializer())

前两条代码可以处理 ​​Attempting to use uninitialized value​​​ 的问题, 最后一条用于处理 ​​LookUpTable not initialized​​​ 的问题: 在使用 ​​feature_column​​​ 时, 由于 ​​feature​​ 需要查表获取, 这个表也需要进行初始化, 比如:

FailedPreconditionError (see above for traceback): Table not initialized.
         [[node hash_table_Lookup (defined at 5.py:23)  = LookupTableFindV2[Tin=DT_STRING, Tout=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"](relationship_lookup/hash_table, to_sparse_input_1/values, relationship_lookup/hash_table/Const)]]

获取 TensorFlow 中变量或者 Op 的 Name

all_vars = tf.global_variables()
for v in all_vars:
    print(v.op.name)

graph = tf.get_default_graph()
for op in graph.get_operations():
    print(op.name)

for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    arr = sess.run(var)

from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))

读取 Estimator 对象的 Variable

names = linear_est.get_variable_names()
print('name: ', names)
for i in names:
    print(type(linear_est.get_variable_value(i)))

还有一种方法, 来自: ​​can tf.estimator.Estimator’s parameters be modified by hand?​​​,
通过访问 ​​​tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)​​ 来达到目的, 但如果为了得到模型的权重, 而不是整张图上的变量, 应该访问:

from tensorflow.python.framework import ops
print(tf.get_collection(ops.GraphKeys.MODEL_VARIABLES))

上面链接中的代码如下:

# Restore, Update, Save
# tested only on tesorflow 1.4
import tensorflow as tf
tf.reset_default_graph()

CHECKPOINT_DIR = 'CHECKPOIN_DIR' # for example '/my_checkpoints' as in tf.estimator.LinearClassifier(model_dir='/my_checkpoints'...
checkpoint = tf.train.get_checkpoint_state(CHECKPOINT_DIR)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(checkpoint.model_checkpoint_path + '.meta')
    saver.restore(sess, checkpoint.model_checkpoint_path)

    # just to check all variables values
    # sess.run(tf.all_variables())

    # get your variable
    KEY = 'linear/linear_model/0/weights/part_0:0'# for tf.estimator.LinearClassifier first weight
    var_wights_0 = [v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if v.name == KEY][0]
    sess.run(var_wights_0)

    # your update operation
    var_wights_0_updated = var_wights_0.assign(var_wights_0 - 100)
    sess.run(var_wights_0_updated)

    # you can check that value is updated
    # sess.run(tf.all_variables())

    # this saves updated values to last checkpoint saved by estimator
    saver.save(sess, checkpoint.model_checkpoint_path)

TensorFlow 将整数转化为字符串

使用 ​​tf.string.format​​​, 来自 ​​Tensorflow - How to Convert int32 to string (using Python API for Tensorflow)​​

import tensorflow as tf

x = tf.constant([1, 2, 3], dtype=tf.int32)
x_as_string = tf.map_fn(lambda xi: tf.strings.format('{}', xi), x, dtype=tf.string)

with tf.Session() as sess:
  res = sess.run(x_as_string)
  print(res)
  # [b'1' b'2' b'3']

tf.data 介绍

  • ​​https://www.tensorflow.org/guide/data​​
  • ​​https://www.tensorflow.org/guide/data_performance​​
  • ​​十图详解tensorflow数据读取机制(附代码)​​
  • ​​tf.data.Dataset.interleave() 最通俗易懂的使用详解(有图有真相)​​
  • ​​How to use parallel_interleave in TensorFlow​​
  • ​​Tensorflow踩坑记之tf.data​​

tf.identity 的作用

  • ​​StackOverFlow: tf.identity 的作用​​

总的来说, 主要是两个, ​​tf.identity​​ 相当于创建了一个和原始结果一样的新节点, 可以和各种控制流的 op 配合使用, 具体看链接中的例子; 另一个是给 op 命名.


精彩评论(0)

0 0 举报