https://www.tensorflow.org/guide/ Tensorflow官方文档
TensorFlow是一个采用数据流图(data flow graphs),用于数值计算的开源框架。
tf特点: 多语言接口 , 底层c++实现, 分布式分发训练(多显卡计算), tensorboard可视化,是tf的一组web应用 ;落地部署大部分公司使用, 移动端使用。
TF结构
构建图(数据与操作的执行步骤被描述成一个图)->执行图
图graphs: 是 TensorFlow 将计算表示为指令之间的依赖关系的一种表示法,代表程序运算过程,图包含了一组tf.Operation代表的计算单元对象和tf.Tensor代表的计算单元之间流动的数据。
会话session: TensorFlow 跨一个或多个本地或远程设备运行数据流图的机制,是执行程序的一个入口
节点operation : 在图中表示数学操作
线edges :则表示在节点间相互联系的多维数据数组,即张量(tensor)。
图
TensorFlow程序会默认帮我们创建一张图。print图会打印出该图地址
查看默认图的两种方法:
- 通过调用tf.get_default_graph()访问 ,要将操作添加到默认图形中,直接创建OP即可。
- op、sess都含有graph属性 ,默认都在一张图中
创建图
- 可以通过tf.Graph()自定义创建图,实际应用较少
- 如果要在这张图中创建OP,典型用法是使用tf.Graph.as_default()上下文管理器
- 新图若不是默认图, 使用会话运行不了, 加上graph=新图才可用
Operation
opration是广义的, 不仅仅包括运算, 还包括常见定义方法、模型存储保存接口, tf的api大多数都属于op
op可以包含 一个或多个tensor数据的输入与输出, 输入与输出都是tensor数据。(tf.varible比较特殊,接收tensor返回varible数组类型)
每一个tensor, 都有一个唯一的节点名称op_name shape dtype,打印可显示,tensorboard也显示节点名; 如果节点方法一样, 会默认给字符串名称增加下划线和数字区分, 也可以在创建op时指定op_name
类型 | 实例 |
---|---|
标量运算 | add, sub, mul, div, exp, log, greater, less, equal |
向量运算 | concat, slice, splot, constant, rank, shape, shuffle |
矩阵运算 | matmul, matrixinverse, matrixdateminant |
带状态的运算 | Variable, assgin, assginadd |
神经网络组件 | softmax, sigmoid, relu,convolution,max_pool |
存储, 恢复 | Save, Restroe |
队列及同步运算 | Enqueue, Dequeue, MutexAcquire, MutexRelease |
控制流 | Merge, Switch, Enter, Leave, NextIteration、 |
会话
一个运行TensorFlow operation的类。会话包含以下两种开启方式
- tf.Session:用于完整的程序当中
- tf.InteractiveSession:用于交互式上下文中的TensorFlow ,例如shell
会话可能拥有的资源,如 tf.Variable,tf.QueueBase和tf.ReaderBase。当这些资源不再需要时,释放这些资源非常重要。因此,需要调用tf.Session.close会话中的方法,或将会话用作上下文管理器(with as)
- target:如果将此参数留空(默认设置),会话将仅使用本地计算机中的设备。可以指定 grpc:// 网址,以便指定 TensorFlow 服务器的地址,这使得会话可以访问该服务器控制的计算机上的所有设备。
- graph:默认情况下,新的 tf.Session 将绑定到当前的默认图。
- config:此参数允许您指定一个 tf.ConfigProto 以便控制会话的行为。例如,ConfigProto协议用于打印设备使用信息 本地能用于运算的设备, 对于复杂模型 计算量大的节点 显示每个节点对应的处理器
会话的运行:
run(fetches,feed_dict=None, options=None, run_metadata=None)
- 通过使用sess.run()来运行operation , 可直接放入节点列表
- fetches:单一的operation,或者列表、元组(其它不属于tensorflow的类型不行)
- feed_dict:参数允许调用者覆盖图中张量的值,运行时赋值
- 与tf.placeholder搭配使用,则会检查值的形状是否与占位符兼容。
tf.operation.eval()也可运行operation,但必须有上下文环境 或 (指定会话参数)
feed操作(很少用):
- placeholder提供占位符(也是一个节点操作),run时候通过feed_dict指定参数
张量
TensorFlow 的张量就是一个 n 维数组, 类型为tf.Tensor。Tensor具有以下两个重要的属性
- type:数据类型
- shape:形状(阶)
创建张量的指令(同numpy基本一样):
- 固定值 例如tf.ones() 等
- 随机值
- tf.variable 变量op
- tf.placeholder 占位符
形状变换:
- 类型改变 tf.cast
- 形状改变
- 动态形状 tf.reshape 返回新的形状的tensor, 但也不是随意修改, 也要考虑总的元素数量
- 静态形状 tf.set_shape 修改张量本身形状, 若形状固定, 不能修改
形状可用 [none, none]占位,tf.placeholder(shape=[none, none])
tensor.get_shape()获取形状
张量的数学运算:
- 算术运算符
- 基本数学函数
- 矩阵运算
- reduce操作
- 序列索引操作
变量(特殊的张量)
- 存储持久化
- 可修改值
- 可指定被训练
创建变量: tf.variable 这个op返回的不是一个tensor类, 而是返回一个特殊的variable类
- tf.Variable(
initial_value=None,trainable=True,collections=None
,name=None)
- initial_value:初始化的值
- trainable:是否被训练
- collections:新变量将添加到列出的图的集合中collections,默认为[GraphKeys.GLOBAL_VARIABLES],如果trainable是True变量也被添加到图形集合 GraphKeys.TRAINABLE_VARIABLES
- 变量需要显式初始化init = tf.global_variables_initializer(),才能在会话中使用run()运行值
-
使用tf.variable_scope()修改变量的命名空间with tf.variable_scope(“name”):,会在OP的名字前面增加命名空间的指定名字。变量较多时使用。
Tensorboard可视化学习
方便tf程序理解 调试 优化
实现程序可视化过程:
- 数据序列化-events文件(代码最好写在session中)
TensorBoard 通过读取 TensorFlow 的事件文件来运行,需要使用sess.graph获取图,将数据生成一个序列化的 Summary protobuf 对象,生成events文件。
1 2 3 |
# 返回filewriter,写入事件文件到指定目录(最好用绝对路径),以提供给tensorboard使用 tf.summary.FileWriter('文件夹路径', graph=sess.graph) |
这将在指定目录中生成一个 event 文件,其名称格式如下:
1 2 |
events.out.tfevents.{timestamp}.{hostname} |
- 启动TensorBoard
1 2 |
tensorboard --logdir="文件夹路径" |
在浏览器中打开 TensorBoard 的图页面 [ip:6006 ]
- tf.variable_scope()增加命名空间,方便events文件查看更清晰, tensorboard以命名空间划分
-
tf.summary.scalar(‘自定义名称’, 变量名)查看模型损失变化
-
tf.summary.histogram(‘自定义名称’, 变量名) 查看高纬度权重参数
使用 tf.summary.scalar tf.summary.histogram 收集变量 -> merge = tf.summary.merge_all()合并tensor结果 -> 会话中运行观察tensor summary=sess.run(merge) -> 加入 file_writer.add_summary(summary, i) i表示分每步查看
模型的保存与加载
checkpoint文件
saver = tf.train.Saver()添加saver, 每一步保存一次, 默认保存最近5步的模型 -> 会话中 保存模型saver.save(sess, 路径/文件名.ckpt)
加载模型: 在会话中 saver.restore(sess, 路径)