深度学习库 Gluon

深度学习库 Gluon

Apache
Python
跨平台
微软
2017-10-21
红薯

Gluon 是微软联合亚马逊推出的一个开源深度学习库,这是一个清晰、简洁、简单但功能强大的深度学习 API,该规范可以提升开发人员学习深度学习的速度,而无需关心所选择的深度学习框架。Gluon API 提供了灵活的接口来简化深度学习原型设计、创建、训练以及部署,而且不会牺牲数据训练的速度。

Gluon 规范已经在 Apache MXNet 中实现,只需要安装最新的 MXNet 即可使用。推荐使用 Python 3.3 或者更新版本。

主要优势包括:

  • 代码简单,易于理解

  • 灵活,命令式结构: 不需要严格定义神经网络模型,而是将训练算法和模型更紧密地结合起来,开发灵活

  • 动态图: Gluon 可以让开发者动态的定义神经网络模型,这意味着他们可以在运行时创建模型、结构,以及使用任何 Python 原生的控制流

  • 高性能: Gluon 所提供的这些优势对底层引擎的训练速度并没有任何影响

示例代码:

import mxnet as mx
from mxnet import gluon, autograd, ndarray
import numpy as np

train_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=True, 
			transform=lambda data, label: (data.astype(np.float32)/255, label)),
            batch_size=32, shuffle=True)
test_data = mx.gluon.data.DataLoader(mx.gluon.data.vision.MNIST(train=False, 
			transform=lambda data, label: (data.astype(np.float32)/255, label)),
            batch_size=32, shuffle=False)                     

# First step is to initialize your model
net = gluon.nn.Sequential()
# Then, define your model architecture
with net.name_scope():
    net.add(gluon.nn.Dense(128, activation="relu")) # 1st layer - 128 nodes
    net.add(gluon.nn.Dense(64, activation="relu")) # 2nd layer – 64 nodes
    net.add(gluon.nn.Dense(10)) # Output layer

# We start with random values for all of the model’s parameters from a
# normal distribution with a standard deviation of 0.05
net.collect_params().initialize(mx.init.Normal(sigma=0.05))

# We opt to use softmax cross entropy loss function to measure how well the # model is able to predict the correct answer
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

# We opt to use the stochastic gradient descent (sgd) training algorithm
# and set the learning rate hyperparameter to .1
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

epochs = 10
for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(mx.cpu()).reshape((-1, 784))
        label = label.as_in_context(mx.cpu())
        with autograd.record(): # Start recording the derivatives
            output = net(data) # the forward iteration
            loss = softmax_cross_entropy(output, label)
            loss.backward()
        trainer.step(data.shape[0])
        # Provide stats on the improvement of the model over each epoch
        curr_loss = ndarray.mean(loss).asscalar()
    print("Epoch {}. Current Loss: {}.".format(e, curr_loss))
的码云指数为
超过 的项目
加载中

评论(6)

milin
milin
一个tensotflow就够了,其他都一样
乳沟
乳沟
无法入手
netkiller-
netkiller-
先观望
sdvdxl
sdvdxl
选择困难症了
BetaYuan
BetaYuan
继前端轮子满天飞以后人工智能轮子满天飞的时代来了。
假红薯
md 这么多框架 到底哪个好?

暂无资讯

暂无问答

mxnet的gluon接口

mxnet新出了个gluon接口,文档和api都还不错,有torch和tensorflow相似的一些风格,使得从这两个框架转过来相对容易一些。gluon接口中文网站传送门:https://zh.gluon.ai/index.html. gluon...

2017/10/27 22:56
75
1
JavaFX Scene Builder 8.2 下载地址

Scene Builder 8.X 下载: http://gluonhq.com/open-source/scene-builder/ **官方已移交给Team Gluon维护** Team Gluon : http://gluonhq.com/ 引用: http://www.javafx-tutorials.com/ja...

2016/05/24 11:04
134
2
2018年值得关注的5个大数据趋势

随着大数据系统日益高效,每年的大数据趋势变得更具开创性。根据调研机构Forrester Research最近发布的营销报告,随着组织的领导者开始意识到大量使用大数据技术所需的工作量,人工智能(AI)正...

2018/06/14 12:23
19
0
Coding and Paper Letter(三十九)

资源整理。 1 Coding: 1.Python库benchmark rio s3,用于在访问S3上的文件时对Rasterio / GDAL的多线程性能进行基准测试的工具。 benchmark rio s3 2.Pangeo-Binder Cookiecutter模板。 cook...

2018/10/19 15:08
10
0

没有更多内容

加载失败,请刷新页面

返回顶部
顶部