MNIST classification with TensorFlow

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author = 'wyx'
@time = 2018/11/7 15:45
@annotation = ''
"""
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.saved_model import tag_constants

mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
save_model_path = 'mnist_model/model.ckpt'


def train():
    learning_rate = 0.05
    batch_size = 100
    max_epochs = 100
    num_of_batch = int(mnist.train.num_examples / batch_size)
    now = datetime.utcnow().strftime("%Y%m%d%H%M%S")

    X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
    y = tf.placeholder(tf.float32, shape=[None, 10], name='y')
    print(X.name, y.name)

    W = tf.get_variable(shape=[784, 10], name='weight')
    b = tf.get_variable(initializer=tf.zeros([10]), name='bias')
    tf.summary.histogram("weights", W)
    tf.summary.histogram("biases", b)

    with tf.name_scope('pred'):
        logits = tf.matmul(X, W) + b
        y_pred = tf.nn.softmax(logits, name='predict')
        print(y_pred.name)

    with tf.name_scope('loss'):
        # the cross-entropy op applies softmax itself, so it must be fed the
        # raw logits; feeding y_pred would apply softmax twice
        loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
        tf.summary.scalar('loss', loss)
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

    with tf.name_scope('acc'):
        correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='acc')
        print(accuracy.name)

    merged_summary_op = tf.summary.merge_all()
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(init_op)

        writer = tf.summary.FileWriter('mnist/{}'.format(now), sess.graph)
        for epoch in range(max_epochs):
            loss_avg = 0  # reset the running loss at the start of each epoch
            for i in range(num_of_batch):
                batch_x, batch_y = mnist.train.next_batch(batch_size)
                summary_str, _, l = sess.run([merged_summary_op, optimizer, loss], feed_dict={X: batch_x, y: batch_y})
                loss_avg += l
                global_step = epoch * num_of_batch + i
                writer.add_summary(summary_str, global_step)

                if global_step % 100 == 0:
                    print('Epoch {}, step {}: saving model'.format(epoch, i))
                    # checkpoint periodically during training
                    saver.save(sess, save_model_path, global_step=global_step)

            loss_avg /= num_of_batch
            print('Epoch {}: Loss {}'.format(epoch, loss_avg))

        print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))
        saver.save(sess, save_model_path)
        tf.saved_model.simple_save(sess, 'simple_model', inputs={
            'X': X,
            'y': y,
        }, outputs={
            'pred': y_pred,
        })


def predict(mode=1):
    def _predict():
        graph = tf.get_default_graph()
        X = graph.get_tensor_by_name('X:0')
        y = graph.get_tensor_by_name('y:0')
        accuracy = graph.get_tensor_by_name('acc/acc:0')
        print(sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels}))

        pred = graph.get_tensor_by_name('pred/predict:0')
        # show one training image and compare its label with the prediction
        i = 90
        img_origin = mnist.train.images[i]
        img = img_origin.reshape((28, 28))
        plt.imshow(img, cmap='gray')
        plt.title('label: {}'.format(np.argmax(mnist.train.labels[i])))
        plt.show()
        a = sess.run(pred, feed_dict={X: img_origin.reshape(-1, 784)})
        print('predicted: {}'.format(np.argmax(a)))

    if mode == 1:
        meta_path = 'mnist_model/model.ckpt.meta'
        checkpoint_path = 'mnist_model'
    elif mode == 2:
        # tf.train.Saver() requires at least one variable in the graph,
        # otherwise it raises "ValueError: No variables to save"
        _ = tf.Variable(0)
        saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        if mode == 1:
            saver = tf.train.import_meta_graph(meta_path)
            saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
        elif mode == 2:
            saver.restore(sess, save_model_path)
        elif mode == 3:
            tf.saved_model.loader.load(sess, [tag_constants.SERVING], 'simple_model')
        _predict()


def check_ckpt():
    from tensorflow.python.tools import inspect_checkpoint as chkp
    chkp.print_tensors_in_checkpoint_file(save_model_path, tensor_name='', all_tensors=True)


if __name__ == '__main__':
    # train()
    predict(mode=3)
    # check_ckpt()
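
The mode=3 branch above still looks tensors up by hard-coded names. A SavedModel also records its inputs and outputs in a signature, so a consumer can resolve the tensor names from it instead. A minimal sketch, assuming the simple_save export above (simple_save registers the signature under the default 'serving_default' key, and mnist is the dataset loaded at the top of the script):

with tf.Session(graph=tf.Graph()) as sess:
    # load returns the MetaGraphDef, which carries the signature map
    meta_graph = tf.saved_model.loader.load(sess, [tag_constants.SERVING], 'simple_model')
    sig = meta_graph.signature_def['serving_default']
    x_name = sig.inputs['X'].name         # e.g. 'X:0'
    pred_name = sig.outputs['pred'].name  # e.g. 'pred/predict:0'
    probs = sess.run(pred_name, feed_dict={x_name: mnist.test.images[:5]})
    print(probs.argmax(axis=1))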

MLP

def multi_layer(X):
    n_classes = 10
    with tf.name_scope('layer'):
        l1 = tf.layers.dense(X, 32, activation=tf.nn.relu)
        l2 = tf.layers.dense(l1, 16, activation=tf.nn.relu)
        out = tf.layers.dense(l2, n_classes)  # raw logits, no activation
    with tf.name_scope('pred'):
        y_pred = tf.nn.softmax(out, name='predict')
        print(y_pred.name)
    return out, y_pred  # return the logits too, for the cross-entropy loss
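
A sketch of how multi_layer drops into train() in place of the hand-rolled softmax regression; the placeholders are the ones train() already defines, and the loss must be computed from the returned logits:

X = tf.placeholder(tf.float32, shape=[None, 784], name='X')
y = tf.placeholder(tf.float32, shape=[None, 10], name='y')

logits, y_pred = multi_layer(X)
with tf.name_scope('loss'):
    # as in train(), the cross-entropy op is fed the raw logits
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))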

Linear layer

def linear_layer(X):
    W = tf.get_variable(shape=[784, 10], name='weight')
    b = tf.get_variable(initializer=tf.zeros([10]), name='bias')
    tf.summary.histogram("weights", W)
    tf.summary.histogram("biases", b)

    with tf.name_scope('pred'):
        logits = tf.matmul(X, W) + b
        y_pred = tf.nn.softmax(logits, name='predict')
        print(y_pred.name)

    return logits, y_pred  # return the logits too, for the cross-entropy loss