TensorFlowの高レベルAPIの使用方法2:Dataset APIを使ってみる

TensorFlowにデータを食べさせるには色々なやり方がある。ここでは、Dataset APIを使った方法
https://www.tensorflow.org/programmers_guide/datasets
を用いてみる。Dataset APIを使うと、ランダムバッチなどを簡単に取り扱うことができる。Dataset APIに慣れると、カスタムEstimatorsに慣れやすい、はず。
これまでずっとやってきている、
https://qiita.com/cometscome_phys/items/95ed1b89acc7829950dd
でやっていたことを、Dataset APIを使ってやってみることにする。
バージョン
TensorFlow: 1.8
Python: 3.6.5
再現すべき関数
ここでは、ある関数
$$
y = a_0 x+ a_1 x^2 + b_0 + 3cos(20x)
$$
という関数を考える。ここで、最後のcosはノイズのようなものとして考えており、$a_0$と$a_1$と$b_0$によって得られる二次関数を得ることが目的となる。
データを300点作っておく。
test.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
n = 300
x0 = np.linspace(-2.0, 2.0, n)
a0 = 3.0
a1 = 2.0
b0 = 1.0
y0 = np.zeros(n)
y0[:] = a0*x0+a1*x0**2 + b0 + 3*np.cos(20*x0)
plt.plot(x0,y0 )
plt.show()
plt.savefig(“graph.png”)
この時のグラフは以下の通りである。
上のデータをフィッティングする際には、
$$
y = \sum_{k=0}^{k_{\rm max}} a_k x^k + b_0
$$
という形を考える。ここでは、$x^k$を基底関数として線形回帰をしていることになる。
詳しくは、
JuliaでTensorFlow その4: 線形基底関数を用いた回帰
https://qiita.com/cometscome_

サイト名: Qiita

無料メールマガジン登録

週1回、注目のAIニュースやイベント情報を
編集部がピックアップしてお届けしています。

こちらの規約にご同意のうえチェックしてください。

規約に同意する