#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author = 'wyx'
@time = 2018/11/7 15:45
@annotation = ''

Softmax regression on MNIST with the TensorFlow 1.x graph API.

``train`` fits the model and stores it both as a checkpoint and as a
SavedModel; ``predict`` reloads it in one of three ways and evaluates it;
``check_ckpt`` dumps the tensors stored in the checkpoint.
"""
import os
import shutil
from datetime import datetime

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.saved_model import tag_constants

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
save_model_path = 'mnist_model/model.ckpt'

# Export directory shared by train() (writer) and predict(mode=3) (reader).
_SIMPLE_MODEL_DIR = 'simple_model'


def _build_graph():
    """Add the inference part of the model to the current default graph.

    The tensor names created here ('X:0', 'y:0', 'pred/predict:0',
    'acc/acc:0') are the ones ``predict`` looks up by name, and the variable
    names ('weight', 'bais') are the keys stored in the checkpoint, so none
    of them may be renamed without invalidating saved models.

    Returns:
        Tuple ``(X, y, W, b, logits, y_pred, accuracy)``.
    """
    X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
    y = tf.placeholder(tf.float32, shape=[None, 10], name='y')

    W = tf.get_variable(shape=[784, 10], name='weight')
    # NOTE: 'bais' is a typo for 'bias', kept because it is the checkpoint key.
    b = tf.get_variable(initializer=tf.zeros([10]), name='bais')

    with tf.name_scope('pred'):
        logits = tf.add(tf.matmul(X, W), b, name='logits')
        y_pred = tf.nn.softmax(logits, name='predict')

    with tf.name_scope('acc'):
        correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='acc')

    return X, y, W, b, logits, y_pred, accuracy


def train():
    """Train the model, log summaries and save it to disk.

    Side effects:
        * TensorBoard summaries are written to ``mnist/<UTC timestamp>``.
        * A checkpoint is written to ``save_model_path`` every 100 steps
          (suffixed with the global step) and once more, without suffix,
          after training.
        * A SavedModel is exported to ``simple_model``; a previous export in
          that directory is removed first.
        * Per-epoch mean loss and the final test accuracy are printed.
    """
    learning_rate = 0.05
    batch_size = 100
    max_epochs = 100
    num_of_batch = int(mnist.train.num_examples / batch_size)
    now = datetime.utcnow().strftime("%Y%m%d%H%M%S")

    # A private graph keeps repeated calls (or a later predict()) from
    # colliding with variables already present in the global default graph.
    with tf.Graph().as_default():
        X, y, W, b, logits, y_pred, accuracy = _build_graph()
        # Printed so the names needed by predict() are easy to look up.
        print(X.name, y.name)
        print(y_pred.name)
        print(accuracy.name)

        tf.summary.histogram("weights", W)
        tf.summary.histogram("biases", b)

        with tf.name_scope('loss'):
            # The op applies softmax itself, so it must be fed the raw
            # logits; passing y_pred would apply softmax twice.
            loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
            tf.summary.scalar('loss', loss)

        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

        merged_summary_op = tf.summary.merge_all()
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()

        checkpoint_dir = os.path.dirname(save_model_path)
        if checkpoint_dir and not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        with tf.Session() as sess:
            sess.run(init_op)
            writer = tf.summary.FileWriter('mnist/{}'.format(now), sess.graph)
            try:
                for epoch in range(max_epochs):
                    # Reset per epoch so the printed value is this epoch's mean.
                    loss_avg = 0.0
                    for i in range(num_of_batch):
                        batch_x, batch_y = mnist.train.next_batch(batch_size)
                        summary_str, _, l = sess.run(
                            [merged_summary_op, optimizer, loss],
                            feed_dict={X: batch_x, y: batch_y})
                        loss_avg += l
                        global_step = epoch * num_of_batch + i
                        writer.add_summary(summary_str, global_step)
                        if global_step % 100 == 0:
                            print('Epoch {}: {} save model'.format(epoch, i))
                            # Intermediate checkpoint while training is running.
                            saver.save(sess, save_model_path, global_step=global_step)
                    loss_avg /= num_of_batch
                    print('Epoch {}: Loss {}'.format(epoch, loss_avg))
            finally:
                # Flush pending events even if training is interrupted.
                writer.close()

            print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))
            saver.save(sess, save_model_path)

            # The SavedModel builder rejects an existing export directory, so
            # drop the previous export instead of failing after training.
            if os.path.isdir(_SIMPLE_MODEL_DIR):
                shutil.rmtree(_SIMPLE_MODEL_DIR)
            tf.saved_model.simple_save(
                sess,
                _SIMPLE_MODEL_DIR,
                inputs={'X': X, 'y': y},
                outputs={'pred': y_pred},
            )


def predict(mode=1):
    """Reload the trained model and evaluate it.

    Prints the accuracy on the MNIST test set, shows training image 90 with
    matplotlib and prints the class predicted for it.

    Args:
        mode: how the model is loaded.
            1 - import the graph from ``mnist_model/model.ckpt.meta`` and
                restore the latest checkpoint in ``mnist_model``.
            2 - rebuild the graph in code and restore ``save_model_path``.
            3 - load the SavedModel stored in ``simple_model``.

    Raises:
        ValueError: if ``mode`` is not 1, 2 or 3.
    """
    if mode not in (1, 2, 3):
        raise ValueError('mode must be 1, 2 or 3, got {!r}'.format(mode))

    def _predict(sess):
        # Imported lazily so that training does not require matplotlib.
        import matplotlib.pyplot as plt
        import numpy as np

        graph = tf.get_default_graph()
        X = graph.get_tensor_by_name('X:0')
        y = graph.get_tensor_by_name('y:0')
        accuracy = graph.get_tensor_by_name('acc/acc:0')
        print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))

        pred = graph.get_tensor_by_name('pred/predict:0')
        i = 90
        img_orign = mnist.train.images[i]
        img = img_orign.reshape((28, 28))
        plt.imshow(img, cmap='gray')
        plt.title(mnist.train.labels[i])
        plt.show()

        a = sess.run(pred, feed_dict={X: img_orign.reshape(-1, 784)})
        print(np.argmax(a))

    # Load into a fresh graph so the by-name lookups above cannot hit tensors
    # left over from an earlier train()/predict() call in the same process.
    with tf.Graph().as_default():
        with tf.Session() as sess:
            if mode == 1:
                meta_path = 'mnist_model/model.ckpt.meta'
                checkpoint_path = 'mnist_model'
                saver = tf.train.import_meta_graph(meta_path)
                saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
            elif mode == 2:
                # A bare Saver() raises "No variables to save" on an empty
                # graph, and a checkpoint holds values only - so the model
                # graph has to be rebuilt before restoring into it.
                _build_graph()
                tf.train.Saver().restore(sess, save_model_path)
            else:
                tf.saved_model.loader.load(sess, [tag_constants.SERVING], _SIMPLE_MODEL_DIR)
            _predict(sess)


def check_ckpt():
    """Print every tensor stored in the checkpoint at ``save_model_path``."""
    from tensorflow.python.tools import inspect_checkpoint as chkp

    chkp.print_tensors_in_checkpoint_file(save_model_path, tensor_name='', all_tensors=True)


if __name__ == '__main__':
    # Manual entry point: enable whichever step should run.
    # train()
    predict(mode=3)
    # check_ckpt()