이 포스트는 Github 접속 제약이 있을 경우를 위한 것이며, 아래와 동일 내용을 실행 결과와 함께 Jupyter notebook으로도 보실 수 있습니다.
You can also see the following as Jupyter notebook along with execution result screens if you have no trouble connecting to the Github.
01. Introduction to Graph(연산/계산 설계도)
1-1. Default Graph 사용
- 연산을 위한 설계도(Graph)를 별도 정의하지 않으면, 기본적으로 제공되는 Graph가 자동 적용.
import tensorflow as tf
a1 = tf.Variable(3)
b1 = tf.Variable(5)
c1 = a1 + b1
sess1 = tf.Session() #.Session() 괄호에 아무것도 없으면 기본 Graph가 자동 적용됨.
sess1.run(tf.global_variables_initializer())
print(sess1.run(c1)) # 계산 결과는 print안에서 .run을 호출해야 함.
Session 저장
saver = tf.train.Saver()
saver.save(sess1,'saved/my_test_sess')
1-2. Customer Graph 사용
- epoch step 별 session 저장 가능 ```python cg1 = tf.Graph() sess2 = tf.Session(graph = cg1)
with cg1.as_default():
a2 = tf.Variable(3)
b2 = tf.Variable(5)
c2 = a2 + b2
sess2.run(tf.global_variables_initializer())
#모델 저장 시 별도의 saver를 만들어준다.
saver2 = tf.train.Saver()
saver2.save(sess2,'saved/my_test_sess')
writer = tf.summary.FileWriter("./log", sess2.graph)
#summary_writer = tf.train.SummaryWriter("my_sess2_summary", sess2.graph)
print(sess2.run(c2))
- tensorboard 확인은 prompt window에서 입력
```python
tensorboard --logdir=./log
1-3. session 실행이 맞지 않으면 Error 발생
cg2 = tf.Graph()
with cg2.as_default():
a3 = tf.Variable(3)
b3 = tf.Variable(5)
c3 = a3 + b3
sess3 = tf.Session(graph = cg2) # with 구문 밖으로 나왔음.
sess3.run(tf.global_variables_initializer())
print(sess2.run(c3)) # sess2에 c3을 연산하므로 맞지 않음
1-4. Graph와 Session을 따로 정의하는 방법
cg3 = tf.Graph()
with cg3.as_default():
v0 = tf.placeholder(tf.int32, name = "V0")
v1 = tf.Variable(10, name = "V1")
v2 = tf.Variable(20, name = "V2")
v3 = tf.add(v0,v2, name = "add")
with tf.Session(graph = cg3) as sess3:
saver3 = tf.train.Saver()
sess3.run(tf.global_variables_initializer())
save_path = saver3.save(sess3, "./saved/test_sess3")
feed_dict = {v0:7}
output = sess3.run([v3], feed_dict = feed_dict)
writer = tf.summary.FileWriter("./log", sess3.graph)
print(output)
1-5. graph 소속 여부 확인
c2.graph is tf.get_default_graph()
c2.graph is cg1
1-6. 사용 중인 graph node 확인
[node.name for node in cg3.as_graph_def().node] # cg3.get_operations() 와 동일 결과
[node.input for node in cg3.as_graph_def().node] # .input, op, device, attr 등 확인 가능
1-7. graph 요소 확인
graph_collection_key_list = cg1.get_all_collection_keys()
graph_collection_list = cg1.get_collection(graph_collection_key_list[0])
oper_list = cg1.get_operations() # cg1.as_graph_def().node와 동일 결과
1-8. 저장된 model file을 불러와서 Graph node 확인
- 불러온 모델의 graph 저장 방법 1
im1_graph = tf.Graph()
with im1_graph.as_default():
im1_saver = tf.train.import_meta_graph("./saved/test_sess3.meta")
train_op = tf.get_collection('variables') # variables 확인
print(train_op)
[node.name for node in im1_graph.as_graph_def().node]
with tf.Session(graph = im1_graph) as im1_sess:
im1_saver.restore(im1_sess, tf.train.latest_checkpoint('./saved'))
# 기존 graph에 있는 변수를 가져와 새로 지정
new_v0 = im1_graph.get_tensor_by_name("V0:0")
new_v3 = im1_graph.get_tensor_by_name("add:0")
feed_dict = {new_v0:7}
output1 = im1_sess.run([new_v3], feed_dict = feed_dict)
# 기존 그래프에 새로운 그래프를 연결하여 연산
new_v1 = im1_graph.get_tensor_by_name("V1:0")
new_op = tf.multiply(new_v1, 3, name="multiply")
output2 = im1_sess.run(new_op)
print(output1)
print(output2)
1-9. pb file 만들기 (Freeze)
- 그래프를 고정한다 라고 이야기 하기도 함.
- 더 이상 학습이 안되도록 모델의 구조를 가중치와 결합하는 작업.
- 안드로이드나 다른 디바이스에 포팅하기 위함
# tf.train.write_graph( graph_or_graph_def, logdir, name, as_text=True )
# as_text=True 또는 생략은 pbtxt 파일이 생성되며, False는 pb파일이 생성됨.
tf.train.write_graph(sess3.graph_def,'./saved','sess3.pbtxt', as_text=True)
tf.train.write_graph(sess3.graph_def,'./saved','sess3.pb', as_text=False)
1-10. pb file 불러와서 graph에 입력하기
im2_graph = tf.Graph()
with im2_graph.as_default():
im2_graph_def = tf.GraphDef()
with tf.gfile.FastGFile('./saved/my_test_sess3.pb', 'rb') as f:
im2_graph_def.ParseFromString(f.read())
tf.import_graph_def(im2_graph_def, name="")