[텐서플로우 정리] 09. argmax 함수

argmax 함수가 어떤 역할을 하는지만 알면 충분한데, 갑자기 궁금해졌다. 두 번째 파라미터로 전달되는 차원이 무엇을 의미하는건지. 너무 당연하게 전달하는 배열의 차원을 알려달라는 것인지 알았다.


import tensorflow as tf
import functions

a1 = tf.Variable([0.1, 0.3, 0.5])
functions.showOperation(tf.argmax(a1, 0))
[출력 결과]
2

두 번째 파라미터는 one-hot-encoding을 적용할 차원을 알려주는 매개변수이다. 1차원 배열에 대해서는 0, 2차원 배열에 대해서는 0과 1, 3차원 배열에 대해서는 0, 1, 2를 사용할 수 있다. 헷갈릴 수 있는데, 0은 열(column), 1은 행(row), 2는 면(page, 행열)을 가리킨다. 그러고 보니 flag 2에 대해서는 정확한 용어를 모르겠다. 보통 차원(dimension)을 얘기할 때 점선면이라고 말하니까, 여기서는 '면'이라고 하고 page라고 읽겠다.

1차원 배열을 전달했으니까 사용할 수 있는 플래그는 0 밖에 없다. 그런데, 느낌이 이상한 것이 1행밖에 없는데 왜 결과가 한 개밖에 나오지 않을까,라고 생각할 수도 있다. 그러나, 1차원 배열을 행이 아니라 열로 생각할 수도 있다. 다시 말해 열이 1개 있기 때문에 결과도 하나가 나왔다,라고 보면 된다.


import tensorflow as tf
import functions

a2 = tf.Variable([[0.1, 0.3, 0.5]])
functions.showOperation(tf.argmax(a2, 0))
functions.showOperation(tf.argmax(a2, 1))
[출력 결과]
[0 0 0]
[2]

1행 3열의 2차원 배열을 사용했다. 사용할 수 있는 플래그는 0과 1의 두 가지다.

먼저 열 단위로 찾기 위해서 0을 전달했다. 결과는 3개가 나왔고 모두 0이다. 열 단위로는 데이터가 1개밖에 없으니까, 위치는 항상 0이 될 수밖에 없다.

행 단위로 찾기 위해서 1을 전달하면, 결과는 1개만 나온다. 행이 하나밖에 없으니까. 0번째 행에서 가장 큰 값은 0.5이고, 2번째 위치에 있어서 2를 반환했다.


import tensorflow as tf
import functions

a3 = tf.Variable([[[0.1, 0.3, 0.5],
[0.3, 0.5, 0.1]],
[[0.5, 0.1, 0.3],
[0.1, 0.3, 0.5]],
[[0.3, 0.5, 0.1],
[0.5, 0.1, 0.3]]])

functions.showOperation(tf.argmax(a3, 0))
functions.showOperation(tf.argmax(a3, 1))
functions.showOperation(tf.argmax(a3, 2))
[출력 결과]
[[1 2 0]
[2 0 1]]
[[1 1 0]
[0 1 1] [1 0 1]]
[[2 1]
[0 2]
[1 0]]

a3은 3x2x3 크기를 갖는 3차원 배열이다. 어떻게 이런 결과가 나왔는지 말로 하는 것이 쉽지 않다. 행과 열, 페이지에 따라 계산한 방식을 빨간색으로 표시해 보았다. 행열 기준의 경우에는 행과 같은 행에 있으면서 열이 일치하는 요소끼리 비교하고 있다. 0행의 0열과 1열, 1행의 0열과 1열 등등.

플래그 0 (열 기준)
[[0.1, 0.3, 0.5], [0.3, 0.5, 0.1]]
[[0.5, 0.1, 0.3], [0.1, 0.3, 0.5]]
[[0.3, 0.5, 0.1], [0.5, 0.1, 0.3]]
[[  1     2     0 ], [  2     0     1 ]]

플래그 1 (행 기준)
[[0.1, 0.3, 0.5],
[0.3, 0.5, 0.1]]    [[1 1 0],
[[0.5, 0.1, 0.3],
[0.1, 0.3, 0.5]]     [0 1 1],
[[0.3, 0.5, 0.1],
[0.5, 0.1, 0.3]]     [1 0 1]]

플래그 2 (행열 기준)
[[0.1, 0.3, 0.5], [0.3, 0.5, 0.1]]     [[2 1],
[[0.5, 0.1, 0.3], [0.1, 0.3, 0.5]]      [0 2],
[[0.3, 0.5, 0.1], [0.5, 0.1, 0.3]]      [1 0]]