[TF]Tensorflowの学習パラメータの保存と読み込み方法


Tensorflow で学習したパラメータの保存と読み込みには、tf.train.Saverを使用します。

保存
保存をするときは、作成したsaver classのsaveメソッドを使用します。

python
saver = tf.train.Saver()
なんらかの処理
#保存
saver.save(sess, model.ckpt)

保存は学習の最後でもいいし、学習の途中のタイミングでもいいです。

読み込み
読み込みをするときは、作成したsaver classのrestoreメソッドを使用します。
sessionが必要なので、sessionを作成した後に読み込みます。
ipython上で実行する場合は、tf.InteractiveSession()で、通常はtf.Session()でsessionを作成します。

python
sess=tf.InteractiveSession()
saver.restore(sess, model.ckpt)

実際に保存と読み込みを行った様子が下記になります。
流れは下記のようになっています。
1.モデル作成
2.学習
3.あとで比較するためにパラメータを別の変数に保存
4.パラメータをファイルに保存
5.Session Close
6.Session作成
7.初期化(本来はこれは必要ありません。比較するためにわざと初期化しました。)
8.保存したパラメータと比較(これはひとつ前で初期化したので差がでます。)
9.ファイルからパラメータを読み込む
10.保存したパラメータと比較(これは一致します)
11.学習

コード

python
# # import
# In[1]:
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# # load dataset
# In[2]:
mnist = input_data.read_data_sets(./data/mnist/, one_hot=True)

# # build model
# In[3]:
def mlp(x, output_dim, reuse=False):
この記事の続きを読む

サイト名: Qiita - Python

無料メールマガジン登録

週1回、注目のAIニュースやイベント情報を
編集部がピックアップしてお届けしています。

こちらの規約にご同意のうえチェックしてください。

規約に同意する


Leave a Reply

Your email address will not be published.

This site uses Akismet to reduce spam. Learn how your comment data is processed.