mnist 데이터셋 정리

TensorFlow  샘플에 보면 mnist 데이터셋이 많이 등장한다.
mnist를 잘 알면, 이후 코드를 보는데 도움이 될 것 같아서 정리해 놓는다.

1. input_data.py
TensorFlow 샘플에 포함된 예제인데, mnist 데이터셋이 없을 경우 인터넷으로부터 다운로드한다.
추가로 DataSet 클래스 등의 정의가 이 안에 들어 있다.
mnist 데이터셋을 다루는 코드의 꼭대기에는 대부분 input_data.py를 import하는 코드가 들어 있다.

2. mnist 로딩
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

코드마다 조금 다를 수 있는데, 이 코드가 제일 좋은 것 같다.
tensorflow를 설치했다면, 그 안에 있는 input_data.py를 사용하고 있다.
꼭 현재 폴더에 있어야 하는 건 아니다. 
눈에 보이지 않고 최신 버전으로 유지되기 때문에 좋은 코드다.

"/tmp/data/"는 리눅스나 맥에서
루트(/) 폴더 밑에 tmp 폴더 밑에 data 폴더에 저장하라는 뜻이다.
간혹 현재 폴더를 가리키는 도트(.)으로 시작하는 경로가 있기도 하다.
tmp 폴더는 삭제해도 괜찮은 파일을 저장하는 용도로 사용하는 폴더이다.

one_hot은 label 데이터셋을 만들 때, label을 one-hot 방식으로 처리할 것인지를 가리킨다.
기본값은 False이기 때문에 사용할 경우 반드시 True를 전달해야 한다.
데이터가 one-hot 방식으로 넘어오면 처리하는 시점에 변환하지 않아도 되므로 편하다.

3. mnist 자료형
앞의 코드에서 반환한 mnist 변수의 자료형은 DataSets 클래스이다.
input_data.py 파일에 정의된 클래스로 mnist 샘플에서만 사용하는 임시 클래스이다.

DataSets 클래스는 구조체처럼 사용하기 위해 만든 클래스다.
train, validation, test의 3개 멤버 변수만 갖고 있고, 이들은 모두 DataSet 클래스이다.
멤버 함수와 같은 것은 전혀 없다.
DataSet 클래스는 마지막에 s가 없는 단수다.

4. test, validation, test
이들의 자료형은 DataSet 클래스라고 얘기했다.
DataSet 클래스는 mnist 샘플에서 가져다 쓸 수 있도록 다양한 멤버들을 갖추고 있다.

mnist 데이터셋은 전체 4개의 파일로 이루어져 있는데,
2개는 이미지 파일이고, 2개는 이미지가 어떤 숫자인지 알려주는 label 파일이다.
이 안에는 validation 파일은 없고, train과 test 파일만 있다.

train 파일에는 6만 개의 이미지가 들어 있고, test 파일에는 1만 개가 들어 있다.
그런데, validation 파일이 없어서 train에 포함된 5,000개를 validation으로 만들어서 사용한다.
이미지와 label 갯수는 train이 55,000, validation이 5,000, test가 10,000개이다.

이들의 자료형은 모두 numpy의 다차원 배열인 ndarray.
차원 변환과 transpose를 할 수 있고, 행렬 연산도 지원하기 때문에 당연히 numpy가 되는 것이 맞다.

train.images 데이터셋을 출력한 결과. 55,000x784 행렬. 1차원으로 하면 43,120,000개.
<class 'numpy.ndarray'> (55000, 784) 43120000

5. DataSet 클래스
images : 이미지 데이터셋
labels : label 데이터셋
num_examples : 데이터 갯수
next_batch : 데이터셋으로부터 필요한 만큼의 데이터를 반환하는 함수

print('갯수 :', len(mnist.train.images))
print('갯수 :", mnist.train.num_examples)

# train 데이터셋으로부터 데이터 100개 가져오기. (이미지, label) 튜플
train_images, train_labels = mnist.train.next_batch(100)

6. 주의
주의랄 것은 없는데, input_data.py 파일의 버전이 여러 개다.
tensorflow를 배포할 때, 일부 수정된 내용이 반영될 수 있다.
파이썬 3에서 예전 버전을 사용했더니 빨간색 경고가 떴다. 코드는 동작했지만, 기분이 좋지 않았다.
최신 버전의 코드로 바꾸니까, 아무런 문제도 발생하지 않았다.

2015년 12월에 배포된 버전 보기. 클릭하면 github로 이동한다.