一番単純な、2入力1出力です。(論理積を学習させます)
以前、chainerとかのサンプルも探しましたがどれも学習結果を使う所が見つかりませんでした。TensorFlowだとダイレクトにプログラムしないといけないので、その辺りは楽にできました。
とりあえず、サンプルです。
- # -*- coding: utf-8 -*-
- import tensorflow as tf
- import numpy as np
- # 入力データの定義 4行2列(データの定義方法がchainerとは違うようです)
- # x_data = [
- # np.array([0., 0.]),
- # np.array([0., 1.]),
- # np.array([1., 0.]),
- # np.array([1., 1.])
- # ]
- x_data = np.array([
- [0., 0.],
- [0., 1.],
- [1., 0.],
- [1., 1.]
- ])
- # 結果データの定義(4行1列)
- # y_data = [
- # np.array([0.]),
- # np.array([0.]),
- # np.array([0.]),
- # np.array([1.])
- # ]
- y_data = np.array([
- [0.],
- [0.],
- [0.],
- [1.]
- ])
- # 機械学習で最適化するWとbを設定する。Wは4行2列のテンソル。bは4行1列のテンソル。
- W = tf.Variable(tf.random_uniform([4, 2], -1.0, 1.0))
- b = tf.Variable(tf.zeros([4, 1]))
- y = W * x_data + b
- loss = tf.reduce_mean(tf.square(y_data - y))
- optimizer = tf.train.GradientDescentOptimizer(0.5)
- train = optimizer.minimize(loss)
- # 学習を始める前にこのプログラムで使っている変数を全てリセットして空っぽにする
- init = tf.initialize_all_variables()
- # Launch the graph.(おきまりの文句)
- sess = tf.Session()
- sess.run(init)
- # 学習を1000回行い、100回目ごとに画面に学習回数とWとbのその時点の値を表示する
- for step in xrange(1001):
- sess.run(train)
- if step % 100 == 0:
- print step, sess.run(W), sess.run(b)
- # 学習結果を確認
- x_input = np.array([
- [0., 0.],
- [0., 1.],
- [1., 0.],
- [1., 1.]
- ])
- y_res = tf.Variable(tf.zeros([4, 1]))
- y_res = W * x_input + b
- print sess.run(y_res)
- # 4行1列の結果を期待しているのだが、4行2列になってしまう?
- # 学習が十分進めば、どちらの列も同じような結果になるからいいか。
- print sess.run(b)
最後に、学習結果を確認しています。結果はこうなります。(一応、期待通りの結果になりました)
最後の、変数y_resとbの出力だけを貼り付けました。 どうしてもまだ理解できないのが、y_resの計算結果は4行1列になるはずなんですが、4行2列になってしまいます。(代入する直前でy_resを4行1列で定義しても上書きされてしまいます。bは4行2列でも、4行1列でも構わないようです)ただ結果をみると各行の値はどちらの列をとっても同じような値なのでよさそうです。 まだ深層学習の理論的な本を買って読み始めた所でよくわかっていませんが、TensrFlowの方がPrimitiveにプログラムしなければいけない分、直感的にわかります。(でもなんとなく、chainerの方がこの後進めていくのによさげな気もしますので、しばらく悩みます)
- [[ 0.00000000e+00 0.00000000e+00]
- [ -2.62142518e-23 1.62012971e-23]
- [ 7.70256410e-23 -1.24630115e-22]
- [ 9.99999881e-01 1.00000012e+00]]
- [[ 0.00000000e+00]
- [ -2.62142518e-23]
- [ -1.24630115e-22]
- [ 6.69862151e-01]]
0 件のコメント:
コメントを投稿