Implementing RNN in Tensorflow

예전 프로젝트에서는 RNN 구현에 있어서 Keras를 사용해서 (example코드를 거의 그대로 사용) 별 고민 없이 쉽게 해결했었다.

물론 cs231n의 숙제로 RNN을 numpy로 쩔쩔매며 구현하기도 했었지만, 실제 프로젝트에서 자주 쓰기 위해서는 역시  Tensorflow 레벨에서의 활용방법을 간단히 정리해놓는 것이 필요하다고 판단.

텐서플로우에서 RNN을 사용하는 방식은 다음과 같다.

  • Cell을 정의한다. (BasicLSTMCell 등의 구현된 클래스 사용. 내부에 연산한 후 output값과 state값을 넘겨주는 구조가 정의되어 있음.)
cell = tf.contrib.rnn.BasicLSTMCell(num_hidden_unit,
                                    state_is_tuple=True)

 

  • RNN cell 내부에서 받을 hidden state값(previous step으로부터 넘겨받는)에 대한 초기값을 지정하기 위해서 일단 cell모양 그대로 0을 채워놓은 텐서를 저장해놓는다. (이작업을 안하면 그냥 initialize_all_variables() 로 커버될 것 같기도 했는데, 정확히는 모르겠음.)
initial_state = cell.zero_state(batch_size, tf.float32)

 

  • 정의된 cell들을 가지고 (static 혹은) dynamic rnn 구조를 정의한다. 여기서 최종 레이어 단계에서의 output sequence과 최종 state 값을 리턴받는다. (만약 이 스테이트 값을 다음 iteration에서 쓰고자 한다면 받아오고, 아닐 경우 사용하지 않는다.)
output, _ = tf.nn.dynamic_rnn(cell, input_tensor, sequence_length,
                              time_major=False, dtype=tf.float32)

중요한 부분은 input_tensor로 받아오는 텐서 (즉 인풋)의 shape을 가지고 알아서 time step의 길이를 추정한다는 것이다. (sequence_length부분을 생략해도 된다.)

Batch size x time steps x features

이 부분에 유동성을 위해서 time_major라는 argument가 쓰이는데, 보통 [batch_size, num_steps, state_size]의 꼴로 처리를 하지만 이것을 True로 설정하면 [num_steps, batch_size, state_size]의 꼴로 처리한다. 특정 스텝에서의 결과값을 얻어내는 데에 유용하게 쓰일 수 있다.

 

  • 혹은 이 때 시퀀스의 길이가 일정하다면 static RNN을 사용해도 상관없다.(하지만 메모리를 미리 잡는 이슈가 있어서 그냥 dynamic_rnn을 사용하라는 이야기를 줏어 들었다.)
output, _ = tf.nn.static_rnn(cell, input_tensor,
                             dtype=tf.float32)

 

  • 멀티 레이어 RNN을 사용하고 싶다면, RNN셀을 각각 생성한 뒤에 이를 리스트로 묶어서 tf.contrib.rnn.MultiRNNCell안에 인풋으로 넣어주면 된다.
rnn_cells = tf.contrib.rnn.MultiRNNCell([cell1, cell2])

 

  • 혹은 다음과 같이 아예 cell을 생성하는 함수를 만들면 더 편리하다.
# RNN cell layer generating function
def create_rnn_cell():
    cell = tf.contrib.rnn.BasicLSTMCell(num_units = hidden_size,
                                        state_is_tuple = True)
    return cell

# 2 layers for hidden layer
multi_cells = tf.contrib.rnn.MultiRNNCell([create_rnn_cell()
                                           for _ in range(2)],
                                           state_is_tuple=True)

 

  • 그리고 정의된 multi_cells를 가지고 static 혹은 dynamic rnn 구조를 정의한다.
outputs, _= tf.nn.dynamic_rnn(multi_cells, x_data,
                                     dtype=tf.float32)

tf.nn.dynamic_rnn을 처리한 output의 dimension은 cell의 크기와 동일하게 된다. (cell에 지정한 num_units 만큼의 output의 dimension이 결정된다.)

즉, [batch_size, sequence_length, input_dim]을 인풋으로 넣으면 [batch_size, sequence_length, num_units]의 output이 나오게 되는 것이다.

 

  • dynamic_rnn에서 한가지 더 짚고 넘어갈 만한 부분은 sequence_length의 부분이다.
outputs, _states = tf.nn.dynamic_rnn(cell, x_data,
                                     dtype=tf.float32,
                                     sequence_length=[1, 3, 2])

위와 같이 지정을 하면, sequence_length 파라미터에 들어오는 array의 element들 만큼의 길이로 차례차례 input을 처리하게 된다. (그니까 서로다른 길이의 인풋들이 들어올 때, zero-padding을 해서 길이를 맞추어 줄 필요 없이 지정된 길이만큼만 시퀀스를 학습하는 것.)

 

  • tf.contrib.seq2seq.sequence_loss를 사용한 시퀀스의 각 엘리먼트에 대한 loss계산
 

weights = tf.ones([batch_size, sequence_length])
sequence_loss = tf.contrib.seq2seq.sequence_loss(logits = RNN_cell_output,
                                                 targets = True_Y,
                                                 weights = weights)

weight는 시퀀스의 각 엘리먼트이 갖는 가중치(loss에서의)이다. 여기서 중요한 것은 logits부분은 one hot encoding이고, targets부분은  one hot encoding으로 하지 않는다는 사실이다.

따라서,

  • logits에 들어갈 RNN_cell_output : [batch_size, sequence_length, num_classes]
  • targets에 들어갈 레이블 True_Y : [batch_size, sequence_length]

이 둘을 가지고 알아서 sequence에 대한 cross-entropy loss를 계산하여 준다.

 

  • 참고로, 이번에 작업중인 프로젝트에서는 LSTM의 마지막 스텝에서의 아웃풋만 가지고 loss를 계산하도록 해보았는데, 이럴 경우 다음과 같이 뽑아내면 된다.
 

# flatten the LSTM output to make input of fully connected layer(after LSTM)
input_for_fc = tf.reshape(LSTM_output, [-1, hidden_LSTM_cell_size])

# perform fc layer
fc_output = tf.contrib.layers.fully_connected(input = input_for_fc,
                                              num_outputs = num_output_classes,
                                              activation_fn = None)

# reshaping process to get last step's [batch x num_output_classes]
fc_output_reshaped = tf.reshape(fc_output, [input_batch_size,
                                            num_steps,
                                            num_output_classes]
last_index = tf.shape(fc_output_reshaped)[1]
fc_output_reshaped_again = tf.transpose(fc_output_reshaped, [1, 0, 2])

# our [batch x num_output_classes] tensor at the last step
last_output = tf.nn.embedding_lookup(fc_output_reshaped_again,
                                     last_index)

만약 fully connected layer 전에 LSTM에서 특정 스텝의 아웃풋을 뽑아내는 경우라면, 앞서 살펴본 time_major argument를 사용하면 편하게 작업이 가능하다.

 

 

Advertisements

2 thoughts on “Implementing RNN in Tensorflow”

  1. 안녕하세요! 올려주신 글 감사히 봤습니다.
    한 가지 질문이 있는데요,
    seq_length = 3
    data_dim = 3 (feature 수)
    hidden_dim = 2

    일때,
    input의 shape이 shape=(?, 3, 3) 인데,
    rnn을 학습 후 rnn의 weight를 출력해보니
    rnn/basic_rnn_cell/kernel:0 [[ 0.5102138 -1.4371608 ]
    [ 4.3501725 -0.3554855 ]
    [ 4.857006 -0.42872694]
    [-30.570776 0.0319937 ]
    [ 8.434559 1.4694374 ]]
    rnn/basic_rnn_cell/bias:0 [-7.840317 -0.64124054]
    이렇게 나옵니다.

    근데 어떻게 input(3,3)과 weight (5,2)가 곱해질 수 있나요?
    저 출력된 rnn weight 의미가 이해가 안 가요
    올려주신 글의 dynamic_rnn과 static_rnn관련이 있는 문제인가 해서 질문 드려봅니다..ㅜㅜ

    감사합니다.

    Like

    1. 엇 답이 너무 늦었네요 죄송합니다. 블로그를 방치한지 너무 오래되어서..ㅜ 음 이미 찾아봐서 알고 계실 것 같긴 하지만.. 답변드리자면 rnn의 cell은 한번에 한 time step에 대해 연산을 하게 됩니다. 즉, 이번 스텝에서 인풋으로 들어온 3dim의 값들과 이전 스텝에서 받아온 hidden state 2dim의 값들에 대해 곱해줄 weight가 필요하기 때문에 5개가 필요합니다. (이를 2dim의 출력으로 내어 놓게 되는 것입니다.) 음 그리고.. BasicRNNCell은 output과 hidden state out을 같은 걸 뱉어내는 cell이기 때문에 같은 output이 두 곳에 리턴됩니다. 그래서 5 by 2 weight만 필요한 것입니다.

      Like

Leave a Reply

Fill in your details below or click an icon to log in:

WordPress.com Logo

You are commenting using your WordPress.com account. Log Out /  Change )

Google+ photo

You are commenting using your Google+ account. Log Out /  Change )

Twitter picture

You are commenting using your Twitter account. Log Out /  Change )

Facebook photo

You are commenting using your Facebook account. Log Out /  Change )

Connecting to %s