您正在使用IE低版浏览器,为了您的雷峰网账号安全和更好的产品体验,强烈建议使用更快更安全的浏览器
雷峰网
  • AI研习社
  • 雷峰网公开课
  • 活动中心
  • GAIR
  • 专题
  • 爱搞机
此为临时链接,仅用于文章预览,将在时失效
人工智能 正文
发私信给AI研习社
发送

1

TensorFlow极速入门

本文作者: AI研习社 2017-02-11 18:25
导语:目前,深度学习已经广泛应用于各个领域,很多童鞋想要一探究竟,这里抛砖引玉的介绍下最火的深度学习开源框架tensorflow。

雷锋网按:本文原载于Qunar技术沙龙,原作者已授权雷锋网发布。作者孟晓龙,2016年加入Qunar,目前在去哪儿网机票事业部担任算法工程师。热衷于深度学习技术的探索,对新事物有着强烈的好奇心。

一、前言

目前,深度学习已经广泛应用于各个领域,比如图像识别,图形定位与检测,语音识别,机器翻译等等,对于这个神奇的领域,很多童鞋想要一探究竟,这里抛砖引玉的简单介绍下最火的深度学习开源框架 tensorflow。本教程不是 cookbook,所以不会将所有的东西都事无巨细的讲到,所有的示例都将使用 python。

那么本篇教程会讲到什么?首先是一些基础概念,包括计算图,graph 与 session,基础数据结构,Variable,placeholder 与 feed_dict 以及使用它们时需要注意的点。最后给出了在 tensorflow 中建立一个机器学习模型步骤,并用一个手写数字识别的例子进行演示。

1、tensorflow是什么?

tensorflow 是 google 开源的机器学习工具,在2015年11月其实现正式开源,开源协议Apache 2.0。

下图是 query 词频时序图,从中可以看出 tensorflow 的火爆程度。

TensorFlow极速入门

2、 why tensorflow?

Tensorflow 拥有易用的 python 接口,而且可以部署在一台或多台 cpu , gpu 上,兼容多个平台,包括但不限于 安卓/windows/linux 等等平台上,而且拥有 tensorboard这种可视化工具,可以使用 checkpoint 进行实验管理,得益于图计算,它可以进行自动微分计算,拥有庞大的社区,而且很多优秀的项目已经使用 tensorflow 进行开发了。

3、 易用的tensorflow工具

如果不想去研究 tensorflow 繁杂的API,仅想快速的实现些什么,可以使用其他高层工具。比如 tf.contrib.learn,tf.contrib.slim,Keras 等,它们都提供了高层封装。这里是 tflearn 的样例集(github链接  https://github.com/tflearn/tflearn/tree/master/examples)。

4、 tensorflow安装

目前 tensorflow 的安装已经十分方便,有兴趣可以参考官方文档 (https://www.tensorflow.org/get_started/os_setup)。

二、 tensorflow基础

实际上编写tensorflow可以总结为两步.

(1)组装一个graph;

(2)使用session去执行graph中的operation。

因此我们从 graph 与 session 说起。

1、 graph与session

(1)计算图

Tensorflow 是基于计算图的框架,因此理解 graph 与 session 显得尤为重要。不过在讲解 graph 与 session 之前首先介绍下什么是计算图。假设我们有这样一个需要计算的表达式。该表达式包括了两个加法与一个乘法,为了更好讲述引入中间变量c与d。由此该表达式可以表示为:

TensorFlow极速入门

当需要计算e时就需要计算c与d,而计算c就需要计算a与b,计算d需要计算b。这样就形成了依赖关系。这种有向无环图就叫做计算图,因为对于图中的每一个节点其微分都很容易得出,因此应用链式法则求得一个复杂的表达式的导数就成为可能,所以它会应用在类似tensorflow这种需要应用反向传播算法的框架中。

(2)概念说明

下面是 graph , session , operation , tensor 四个概念的简介。

Tensor:类型化的多维数组,图的边;

Operation:执行计算的单元,图的节点;

Graph:一张有边与点的图,其表示了需要进行计算的任务;

Session:称之为会话的上下文,用于执行图。

Graph仅仅定义了所有 operation 与 tensor 流向,没有进行任何计算。而session根据 graph 的定义分配资源,计算 operation,得出结果。既然是图就会有点与边,在图计算中 operation 就是点而 tensor 就是边。Operation 可以是加减乘除等数学运算,也可以是各种各样的优化算法。每个 operation 都会有零个或多个输入,零个或多个输出。 tensor 就是其输入与输出,其可以表示一维二维多维向量或者常量。而且除了Variables指向的 tensor 外所有的 tensor 在流入下一个节点后都不再保存。

(3)举例

下面首先定义一个图(其实没有必要,tensorflow会默认定义一个),并做一些计算。

import  tensorflow as tf

graph  = tf.Graph()

with  graph.as_default():

    foo = tf.Variable(3,name='foo')

    bar = tf.Variable(2,name='bar')

    result = foo + bar

    initialize =  tf.global_variables_initializer()

这段代码,首先会载入tensorflow,定义一个graph类,并在这张图上定义了foo与bar的两个变量,最后对这个值求和,并初始化所有变量。其中,Variable是定义变量并赋予初值。让我们看下result(下方代码)。后面是输出,可以看到并没有输出实际的结果,由此可见在定义图的时候其实没有进行任何实际的计算。

print(result)  #Tensor("add:0", shape=(), dtype=int32)

下面定义一个session,并进行真正的计算。

with  tf.Session(graph=graph) as sess:

    sess.run(initialize)

    res = sess.run(result)

   print(res)  # 5

这段代码中,定义了session,并在session中执行了真正的初始化,并且求得result的值并打印出来。可以看到,在session中产生了真正的计算,得出值为5。

下图是该graph在tensorboard中的显示。这张图整体是一个graph,其中foo,bar,add这些节点都是operation,而foo和bar与add连接边的就是tensor。当session运行result时,实际就是求得add这个operation流出的tensor值,那么add的所有上游节点都会进行计算,如果图中有非add上游节点(本例中没有)那么该节点将不会进行计算,这也是图计算的优势之一。

TensorFlow极速入门

2、数据结构

Tensorflow的数据结构有着rank,shape,data types的概念,下面来分别讲解。

(1)rank

Rank一般是指数据的维度,其与线性代数中的rank不是一个概念。其常用rank举例如下。

TensorFlow极速入门

(2)shape

Shape指tensor每个维度数据的个数,可以用python的list/tuple表示。下图表示了rank,shape的关系。

TensorFlow极速入门

(3)data type

Data type,是指单个数据的类型。常用DT_FLOAT,也就是32位的浮点数。下图表示了所有的types。

TensorFlow极速入门

3、 Variables

(1)介绍

当训练模型时,需要使用Variables保存与更新参数。Variables会保存在内存当中,所有tensor一旦拥有Variables的指向就不会在session中丢失。其必须明确的初始化而且可以通过Saver保存到磁盘上。Variables可以通过Variables初始化。

weights  = tf.Variable(tf.random_normal([784, 200], stddev=0.35),name="weights")

biases  = tf.Variable(tf.zeros([200]), name="biases")

其中,tf.random_normal是随机生成一个正态分布的tensor,其shape是第一个参数,stddev是其标准差。tf.zeros是生成一个全零的tensor。之后将这个tensor的值赋值给Variable。

(2)初始化

实际在其初始化过程中做了很多的操作,比如初始化空间,赋初值(等价于tf.assign),并把Variable添加到graph中等操作。注意在计算前需要初始化所有的Variable。一般会在定义graph时定义global_variables_initializer,其会在session运算时初始化所有变量。

直接调用global_variables_initializer会初始化所有的Variable,如果仅想初始化部分Variable可以调用tf.variables_initializer。

Init_ab  = tf.variables_initializer([a,b],name=”init_ab”)

Variables可以通过eval显示其值,也可以通过assign进行赋值。Variables支持很多数学运算,具体可以参照官方文档 (https://www.tensorflow.org/api_docs/python/math_ops/)。

(3)Variables与constant的区别

值得注意的是Variables与constant的区别。Constant一般是常量,可以被赋值给Variables,constant保存在graph中,如果graph重复载入那么constant也会重复载入,其非常浪费资源,如非必要尽量不使用其保存大量数据。而Variables在每个session中都是单独保存的,甚至可以单独存在一个参数服务器上。可以通过代码观察到constant实际是保存在graph中,具体如下。

const  = tf.constant(1.0,name="constant")

print(tf.get_default_graph().as_graph_def())

这里第二行是打印出图的定义,其输出如下。

node {

  name: "constant"

  op: "Const"

  attr {

    key: "dtype"

    value {

      type: DT_FLOAT

    }

  }

  attr {

    key: "value"

    value {

      tensor {

        dtype: DT_FLOAT

        tensor_shape {

        }

        float_val: 1.0

      }

    }

  }

}

versions {

  producer: 17

}

(4)命名

另外一个值得注意的地方是尽量每一个变量都明确的命名,这样易于管理命令空间,而且在导入模型的时候不会造成不同模型之间的命名冲突,这样就可以在一张graph中容纳很多个模型。

4、 placeholders与feed_dict

当我们定义一张graph时,有时候并不知道需要计算的值,比如模型的输入数据,其只有在训练与预测时才会有值。这时就需要placeholder与feed_dict的帮助。

定义一个placeholder,可以使用tf.placeholder(dtype,shape=None,name=None)函数。

foo =  tf.placeholder(tf.int32,shape=[1],name='foo')

bar = tf.constant(2,name='bar')

result = foo + bar

with tf.Session() as sess:

    print(sess.run(result))

在上面的代码中,会抛出错误(InvalidArgumentError),因为计算result需要foo的具体值,而在代码中并没有给出。这时候需要将实际值赋给foo。最后一行修改如下:

print(sess.run(result,{foo:[3]}))

其中最后的dict就是一个feed_dict,一般会使用python读入一些值后传入,当使用minbatch的情况下,每次输入的值都不同。

三、mnist识别实例

介绍了一些tensorflow基础后,我们用一个完整的例子将这些串起来。

首先,需要下载数据集,mnist数据可以在Yann LeCun's website( http://yann.lecun.com/exdb/mnist/ )下载到,也可以通过如下两行代码得到。

from  tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/",  one_hot=True)

该数据集中一共有55000个样本,其中50000用于训练,5000用于验证。每个样本分为X与y两部分,其中X如下图所示,是28*28的图像,在使用时需要拉伸成784维的向量。

TensorFlow极速入门

整体的X可以表示为。

TensorFlow极速入门


y为X真实的类别,其数据可以看做如下图的形式。因此,问题可以看成一个10分类的问题。

TensorFlow极速入门

而本次演示所使用的模型为逻辑回归,其可以表示为

TensorFlow极速入门

用图形可以表示为下图,具体原理这里不再阐述,更多细节参考 该链接 (http://tech.meituan.com/intro_to_logistic_regression.html)。

TensorFlow极速入门

那么 let's coding。

当使用tensorflow进行graph构建时,大体可以分为五部分:

   1、 为 输入X与 输出y 定义placeholder;

    2、定义权重W;

    3、定义模型结构;

    4、定义损失函数;

    5、定义优化算法。

首先导入需要的包,定义X与y的placeholder以及 W,b 的 Variables。其中None表示任意维度,一般是min-batch的 batch size。而 W 定义是 shape 为784,10,rank为2的Variable,b是shape为10,rank为1的Variable。

import tensorflow as tf

x = tf.placeholder(tf.float32,  [None, 784])

y_ = tf.placeholder(tf.float32,  [None, 10])

W = tf.Variable(tf.zeros([784,  10]))

b = tf.Variable(tf.zeros([10]))

之后是定义模型。x与W矩阵乘法后与b求和,经过softmax得到y。

y = tf.nn.softmax(tf.matmul(x,  W) + b)

求逻辑回归的损失函数,这里使用了cross entropy,其公式可以表示为:

TensorFlow极速入门

这里的 cross entropy 取了均值。定义了学习步长为0.5,使用了梯度下降算法(GradientDescentOptimizer)最小化损失函数。不要忘记初始化 Variables。

cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))

train_step =  tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

init =  tf.global_variables_initializer()

最后,我们的 graph 至此定义完毕,下面就可以进行真正的计算,包括初始化变量,输入数据,并计算损失函数与利用优化算法更新参数。

with tf.Session() as sess:

    sess.run(init)

    for i in range(1000):

        batch_xs, batch_ys =  mnist.train.next_batch(100)

        sess.run(train_step, feed_dict={x:  batch_xs, y_: batch_ys})

其中,迭代了1000次,每次输入了100个样本。mnist.train.next_batch 就是生成下一个 batch 的数据,这里知道它在干什么就可以。那么训练结果如何呢,需要进行评估。这里使用单纯的正确率,正确率是用取最大值索引是否相等的方式,因为正确的 label 最大值为1,而预测的 label 最大值为最大概率。

correct_prediction =  tf.equal(tf.argmax(y,1), tf.argmax(y_,1))

accuracy = tf.reduce_mean(tf.cast(correct_prediction,  tf.float32))

print(sess.run(accuracy,  feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

至此,我们开发了一个简单的手写数字识别模型。

总结

总结全文,我们首先介绍了 graph 与 session,并解释了基础数据结构,讲解了一些Variable需要注意的地方并介绍了 placeholders 与 feed_dict 。最终以一个手写数字识别的实例将这些点串起来,希望可以给想要入门的你一丢丢的帮助。雷锋网

雷峰网版权文章,未经授权禁止转载。详情见 转载须知。

TensorFlow极速入门

34人收藏
分享:
相关文章
  • 定了!WAVE SUMMIT 深度学习开发者大会2024将于6月2 ...
  • AI Infra 往事之异构计算篇:吴韧与他的学生们
  • 脉脉2023年度人才迁徙报告: ChatGPT研究员6.7万居高 ...
  • 圆桌对话|智能驾驶行业真的需要大模型吗?
AI研习社

编辑

聚焦数据科学,连接 AI 开发者。更多精彩内容,请访问:yanxishe.com
发私信
当月热门文章
最新文章
  • 项目征集 | 全球创新项目路演:AI创新集结号,寻找下一个科技独角兽!
  • 共话大模型技术进展与挑战,CCF大模型论坛北京会议圆满落幕!
  • 美图影像节:聚焦AI工作流,6款新品赋能影像设计行业
  • 港投公司与「港产独角兽」思谋科技今签定战略合作协议
  • 早鸟倒计时3天丨院士领衔、重磅嘉宾云集!中国大模型大会(CLM2024)诚邀您共同探索中国大模型之路!详细日程公开
  • 专访联想集团 CTO 芮勇:智能体是具身智能的基础|具身智能十人谈
热门搜索
Android Apple 银行 Windows 中兴 蚂蚁集团 Nokia GoPro 迅雷 地图 商汤
请填写申请人资料
姓名
电话
邮箱
微信号
作品链接
个人简介
为了您的账户安全,请 验证邮箱
您的邮箱还未验证,完成可获20积分哟!
请验证您的邮箱
立即验证
完善账号信息
您的账号已经绑定,现在您可以 设置密码以方便用邮箱登录
立即设置 以后再说

深圳SEO优化公司龙华优化哪家好赣州网站优化按天计费哪家好通辽百姓网标王推广龙岩建站公司临汾百姓网标王推广甘孜网站推广系统哪家好本溪外贸网站建设公司四平英文网站建设推荐鞍山营销型网站建设南宁网站推广系统哪家好宝安网站推广工具哪家好驻马店推广网站价格南阳seo网站推广哪家好绍兴百度标王报价绵阳SEO按天扣费推荐横岗SEO按效果付费抚州百度竞价包年推广公司桂林百度关键词包年推广公司铁岭建网站南充百姓网标王推广报价大芬英文网站建设多少钱广元网站建设公司北京建网站梅州企业网站改版咸阳seo网站推广多少钱鹤壁百度竞价包年推广推荐太原网站建设设计多少钱山南seo公司徐州网站推广系统报价鹤壁企业网站制作歼20紧急升空逼退外机英媒称团队夜以继日筹划王妃复出草木蔓发 春山在望成都发生巨响 当地回应60岁老人炒菠菜未焯水致肾病恶化男子涉嫌走私被判11年却一天牢没坐劳斯莱斯右转逼停直行车网传落水者说“没让你救”系谣言广东通报13岁男孩性侵女童不予立案贵州小伙回应在美国卖三蹦子火了淀粉肠小王子日销售额涨超10倍有个姐真把千机伞做出来了近3万元金手镯仅含足金十克呼北高速交通事故已致14人死亡杨洋拄拐现身医院国产伟哥去年销售近13亿男子给前妻转账 现任妻子起诉要回新基金只募集到26元还是员工自购男孩疑遭霸凌 家长讨说法被踢出群充个话费竟沦为间接洗钱工具新的一天从800个哈欠开始单亲妈妈陷入热恋 14岁儿子报警#春分立蛋大挑战#中国投资客涌入日本东京买房两大学生合买彩票中奖一人不认账新加坡主帅:唯一目标击败中国队月嫂回应掌掴婴儿是在赶虫子19岁小伙救下5人后溺亡 多方发声清明节放假3天调休1天张家界的山上“长”满了韩国人?开封王婆为何火了主播靠辱骂母亲走红被批捕封号代拍被何赛飞拿着魔杖追着打阿根廷将发行1万与2万面值的纸币库克现身上海为江西彩礼“减负”的“试婚人”因自嘲式简历走红的教授更新简介殡仪馆花卉高于市场价3倍还重复用网友称在豆瓣酱里吃出老鼠头315晚会后胖东来又人满为患了网友建议重庆地铁不准乘客携带菜筐特朗普谈“凯特王妃P图照”罗斯否认插足凯特王妃婚姻青海通报栏杆断裂小学生跌落住进ICU恒大被罚41.75亿到底怎么缴湖南一县政协主席疑涉刑案被控制茶百道就改标签日期致歉王树国3次鞠躬告别西交大师生张立群任西安交通大学校长杨倩无缘巴黎奥运

深圳SEO优化公司 XML地图 TXT地图 虚拟主机 SEO 网站制作 网站优化