본문 바로가기

Deep Learning Papers/Optimization

Tensorflow 모델 pruning 적용 방법

1. 개요

tensorflow 모델에 pruning을 적용하는 방법을 정리함

pruning 알고리듬은 'gradual pruning' 방식을 적용함 (초기 sparsity 값으로부터 목표 sparsity 값에 도달할 때까지 진행하며 pruning 속도를 점차 완만하게 늦춤) 

cifar10 CNN 모델을 pruning하는 예제를 통하여 pruning mask 및 weight 변수 값의 변화를 확인함

pruning을 통한 이득을 얻기 위해서는, pruning된 sparse tensor를 압축하는 메커니즘 및 sparse tensor 연산을 가속하기 위한 HW적 지원이 필요함

 

2. Workflow

3. pruning 모델 생성

pruning 대상이 되는 layer에 mask와 threshold variable을 추가함

mask는 해당 layer의 weight tensor와 shape이 동일하며 forward execution에 적용할 weight을 결정하는 역할을 수행함

apply_mask 함수를 이용하여 layer의 weight tensor를 wrapping 함으로써, mask와 threshold가 추가된 convolutional layer를 생성함

ex) convolution layer에 mask와 threshold 변수가 추가된 사례 

ex) pruning operation 제거 전의 그래프

ex) pruning operation 제거 후의 그래프

 

4. pruning 모델의 mask 및 weight 변화 사례

cifar10 CNN 모델의 conv1 layer에 대하여 mask 및 weight 값 변화를 살펴봄

(총 5000 training step을 수행함. 1000~5000 step 구간에서 50 step 간격으로 pruning 진행함(pruning frequency = 50). sparsity 최종 목표 값은 0.7임.)

- cifar10 pruninig 모델 training

python cifar10_train.py --train_dir='./train' 
--max_steps=5000 
--pruning_hparams=name=cifar10_pruning,
   begin_pruning_step=1000,
   end_pruning_step=5000,
   target_sparsity=0.7,
   pruning_frequency=50,
   sparsity_function_begin_step=1000,
   sparsity_function_end_step=5000
 

pruning 시작 전

(900 step)

pruning 종료 시

(5000 steps)

mask 값

mask의 sparsity 값이 0으로 mask의 값이 모두 1로 채워져 있음

mask의 sparsity 값이 0.7에 도달하여 mask의 많은 값이 0으로 변경되어 있음

mask의

sparsity

변화

weight 값

(pruning

operation

제거 전)

weight 값

(pruning

operation

제거 후)

 

mask 값이 weight tensor와 fusing되어, mask 값이 0인 element에 대해서는

weight tensor의 element 값도 0으로 변경됨

5. pruning hyper-parameter

pruning 수행 시에 조정할 수 있는 주요 hyper-parameter들은 다음과 같음

Hyperparameter Type Default Description
name string model_pruning Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope
begin_pruning_step integer 0 The global step at which to begin pruning
end_pruning_step integer -1 The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops
weight_sparsity_map list of strings [""]

list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8].

For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used.

pruning_frequency integer 10 How often should the masks be updated? (in # of global_steps)
initial_sparsity float 0.0 Initial sparsity value
target_sparsity float 0.5 Target sparsity value
sparsity_function_begin_step integer 0 The global step at this which the gradual sparsity function begins to take effect
sparsity_function_end_step integer 100 The global step used as the end point for the gradual sparsity function

 

6. 학습 graph에 pruning operation 추가

학습 그래프에 pruning operation들을 추가해서, layer의 weight 값들의 크기 분포를 모니터링하고 각 pruning step에서 목표로 하는 sparsity 수준을 달성하기 위하여 해당 layer의 threshold를 결정함

 

7. 학습 graph에서 pruning operation 제거

pruning 모델을 training 완료 후에 pruning을 위해 추가했던 variable들(mask 및 threshold)과 operation들을 제거해야 함

strip utility를 이용하여 variable들을 constant로 변환하고 threshold variable들을 제거함

mask 값을 weight tensor와 fusing하여, mask 값이 0인 element에 대해서는 weight tensor의 대응되는 element 값도 0으로 변경됨

- pruning operation 제거

python strip_pruning_vars.py --checkpoint_path=/tmp/cifar10_train 
                                      --output_node_names=softmax_linear/softmax_linear_2 
                                      --filename=cifar_pruned.pb

 

8. 참고 site

https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning