이 포스트는 Github 접속 제약이 있을 경우를 위한 것이며, 아래와 동일 내용을 실행 결과와 함께 Jupyter notebook으로도 보실 수 있습니다.
You can also see the following as Jupyter notebook along with execution result screens if you have no trouble connecting to the Github.

01. Introduction to Graph(연산/계산 설계도)

1-1. Default Graph 사용

  • 연산을 위한 설계도(Graph)를 별도 정의하지 않으면, 기본적으로 제공되는 Graph가 자동 적용.
import tensorflow as tf

a1 = tf.Variable(3)
b1 = tf.Variable(5)
c1 = a1 + b1

sess1 = tf.Session()       #.Session() 괄호에 아무것도 없으면 기본 Graph가 자동 적용됨.
sess1.run(tf.global_variables_initializer())
print(sess1.run(c1))       # 계산 결과는 print안에서 .run을 호출해야 함.

Session 저장

saver = tf.train.Saver()
saver.save(sess1,'saved/my_test_sess')

1-2. Customer Graph 사용

  • epoch step 별 session 저장 가능 ```python cg1 = tf.Graph() sess2 = tf.Session(graph = cg1)

with cg1.as_default():

a2 = tf.Variable(3)
b2 = tf.Variable(5)
c2 = a2 + b2

sess2.run(tf.global_variables_initializer())

#모델 저장 시 별도의 saver를 만들어준다.
saver2 = tf.train.Saver()
saver2.save(sess2,'saved/my_test_sess') 

writer = tf.summary.FileWriter("./log", sess2.graph)

#summary_writer = tf.train.SummaryWriter("my_sess2_summary", sess2.graph)

print(sess2.run(c2))


- tensorboard 확인은 prompt window에서 입력
```python
tensorboard --logdir=./log

1-3. session 실행이 맞지 않으면 Error 발생

cg2 = tf.Graph()

with cg2.as_default():
    a3 = tf.Variable(3)
    b3 = tf.Variable(5)
    c3 = a3 + b3
    
sess3 = tf.Session(graph = cg2)      # with 구문 밖으로 나왔음.
sess3.run(tf.global_variables_initializer())
print(sess2.run(c3))                 # sess2에 c3을 연산하므로 맞지 않음

1-4. Graph와 Session을 따로 정의하는 방법

cg3 = tf.Graph()

with cg3.as_default():
    v0 = tf.placeholder(tf.int32, name = "V0")
    v1 = tf.Variable(10, name = "V1")
    v2 = tf.Variable(20, name = "V2")
    v3 = tf.add(v0,v2, name = "add")
    
with tf.Session(graph = cg3) as sess3:
    saver3 = tf.train.Saver()
    sess3.run(tf.global_variables_initializer())
    save_path = saver3.save(sess3, "./saved/test_sess3")
    feed_dict = {v0:7}
    output = sess3.run([v3], feed_dict = feed_dict)
    
    writer = tf.summary.FileWriter("./log", sess3.graph)

print(output)

1-5. graph 소속 여부 확인

c2.graph is tf.get_default_graph()

c2.graph is cg1

1-6. 사용 중인 graph node 확인

[node.name for node in cg3.as_graph_def().node]            # cg3.get_operations() 와 동일 결과

[node.input for node in cg3.as_graph_def().node]       # .input, op, device, attr 등 확인 가능

1-7. graph 요소 확인

graph_collection_key_list = cg1.get_all_collection_keys()
graph_collection_list = cg1.get_collection(graph_collection_key_list[0])
oper_list = cg1.get_operations()            # cg1.as_graph_def().node와 동일 결과

1-8. 저장된 model file을 불러와서 Graph node 확인

  • 불러온 모델의 graph 저장 방법 1
im1_graph = tf.Graph()
with im1_graph.as_default():
    im1_saver = tf.train.import_meta_graph("./saved/test_sess3.meta")
    train_op = tf.get_collection('variables')         # variables 확인

print(train_op)
[node.name for node in im1_graph.as_graph_def().node] 
with tf.Session(graph = im1_graph) as im1_sess:
    im1_saver.restore(im1_sess, tf.train.latest_checkpoint('./saved'))

    # 기존 graph에 있는 변수를 가져와 새로 지정
    new_v0 = im1_graph.get_tensor_by_name("V0:0")
    new_v3 = im1_graph.get_tensor_by_name("add:0")
    feed_dict = {new_v0:7}
    output1 = im1_sess.run([new_v3], feed_dict = feed_dict)
    
    # 기존 그래프에 새로운 그래프를 연결하여 연산
    new_v1 = im1_graph.get_tensor_by_name("V1:0")
    new_op = tf.multiply(new_v1, 3, name="multiply")
    output2 = im1_sess.run(new_op)    
    
        
print(output1)
print(output2)

1-9. pb file 만들기 (Freeze)

  • 그래프를 고정한다 라고 이야기 하기도 함.
  • 더 이상 학습이 안되도록 모델의 구조를 가중치와 결합하는 작업.
  • 안드로이드나 다른 디바이스에 포팅하기 위함
# tf.train.write_graph( graph_or_graph_def, logdir, name, as_text=True )
# as_text=True 또는 생략은 pbtxt 파일이 생성되며, False는 pb파일이 생성됨.

tf.train.write_graph(sess3.graph_def,'./saved','sess3.pbtxt', as_text=True)
tf.train.write_graph(sess3.graph_def,'./saved','sess3.pb', as_text=False)

1-10. pb file 불러와서 graph에 입력하기

im2_graph = tf.Graph()
with im2_graph.as_default():
    im2_graph_def = tf.GraphDef()
    with tf.gfile.FastGFile('./saved/my_test_sess3.pb', 'rb') as f:
        im2_graph_def.ParseFromString(f.read())
        tf.import_graph_def(im2_graph_def, name="")