ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • 4-2. CNN Visualization(2): 시각화 방법
    쿠다 4기/<네이버 부스트코스> 컴퓨터 비전 2023. 10. 11. 14:46

    1. Embedding Feature Analysis

    모델의 행동 방식 및 특성을 분석하는 방법 중 첫 번째인 모델의 embedding feature을 분석하는 방법을 알아보자. 

    query image를 입력하면, DB내에서 query image와 유사한 이미지가 나열되는 Nearest-Neighbor 방식이다.  embedding space 내에서도 유사한 이미지들끼리 군집을 이루고 있음을 알 수 있다. 위의 파란색 박스 부분을 보면 같은 개이지만 다른 자세이기 때문에 벡터 값으로 비교하면 다를것이다. 그러나 유사하다고 본 것을 통해 물체의 의미나 개념을 이해해 군집화되었음을 알 수 있다.   이미지를 embedding vector로 표현하는 것은 학습된 네트워크에 forward pass시켜 수행되는데, DB의 이미지와 embedding vector을 매칭시켜두고, nearest-neighbor 알고리즘을 수행하는 것이다.

    미리 학습된 neural network에 query 영상을 넣어주면 특징이 추출되고 추출된 특징이 고차원 공간 어딘가에 위치한다. 이렇게 모두 뽑아 놓은 데이터의 feature들은 고차원 공간 상에 존재하고 미리 뽑아 놓은 건 db에 저장한다. 따라서 queryimage를 넣으면 특징을 뽑아 이 특징과 거리가 가까운 image들을 return 하는 것이다.

    그러나 이 방법은 전체적인 그림을 한번에 확인할 수 없다는 문제가 있다.

     

    2. t-distributed stochastic neighbor embedding(t-SNE)

    매우 고차원 벡터인 embedding vector을 사람이 쉽게 파악할 수 있도록 차원 축소를 하면 위 방법의 문제점을 해결할 수 있다. 대표적인 방법으로 t-SNE방법이 있다. t-SNE는 높은 차원의 복잡한 데이터를 2차원으로 차원 축소 하는 방법이다. 그러면 높은 차원 공간에서 비슷한 데이터 구조는 낮은 차원 공간에서 가깝게 대응하며, 비슷하지 않은 데이터 구조는 멀리 떨어져 대응된다.  이러한 t-SNE를 사용하면 데이터를 저차원으로 투영하여 시각화하고, 데이터 포인트를 클러스터링하여 유사한 항목끼리 묶을 수 있다. 이를 통해 데이터의 패턴 및 구조를 파악하고 해석할 수 있다.

    위 그림에서 0-9까지의 숫자 이미지로 구성된  MNIST 데이터의 특징 벡터를 t-SNE를 사용하여 저차원 벡터로 mapping한 결과입니다. 각 클래스마다 색깔 별로 구분하고 있으며, 확인해보면 몇몇의 아웃라이어들을 제외하고 클래스에 따라 군집화된 결과를 확인할 수 있다.

     

    3. Layer Activation

    layer의 activation을 분석해 모델의 특성을 파악할 수도 있다.

    위의 그림을 보면 각각 AlexNet conv5 layer의 138번째 채널, 53번째 채널의 activation을 적절한 값으로 thresholding하여 mask를 만들고, 이를 원본 이미지에 overlay하면 위와 같은 결과를 얻을 수 있다. 이를 통해 각 layer의 hidden node들의 역할(ex. 얼굴을 탐지하는 혹은 계단을 탐지하는 역할 등)을 파악할 수 있다. CNN은 중간중간 hidden unit들이 각각 간단한 얼굴 detection, 손 detection 등 간단한 다층으로 쌓아서 그것들을 조합해서 물체를 인식한다고 해석할 수 있다.

     

    비슷한 방법으로 layer activation에서 가장 큰 값을 가지는 patch의 위치를 기반으로 분석하는 방법도 있다.

    여기서 patch란 이미지의 특정 부분을 지정하는 작은 부분으로 일반적으로 사각형 영역을 가리킨다. 구체적으로 이미지를 입력시켜 특정 레이어의 activation map을 구하고, activation map에서 가장 큰 값을 가지는 patch의 위치 정보를 저장한다. 입력한 이미지에서 그 위치에 해당하는 부분을 잘라서 확인해보면 위의 결과와 같이 layer의 역할(ex. 강아지의 눈 혹은 코 등 검정색 동그란 부분을 탐지하는 역할 등)을 추정해볼 수 있다.

     

    4. class visualization

    class visualization은 데이터를 사용하지 않고 네트워크에 내재되어 있는 정보를 시각화하는 방법이다.

    구체적으로 수행 과정을 살펴보면, 먼저 임의의 dummy 이미지에 대하여 클래스 스코어를 예측한다. 이후 앞서 살펴봤던 loss function에 따라 backpropagation을 수행하여, 클래스 스코어를 최대화하는 방향으로 입력으로 사용한 dummy 이미지를 업데이트 시킨다. 이후 업데이트된 이미지를 입력으로 사용하여, 동일한 과정을 반복적으로 수행하여 클래스에 대한 시각화 결과를 얻는다. 

    이때 아래와 같은 loss function을 이용한다.

    f(I)는 이미지를 CNN에 입력해주었을 때 출력된 하나의 클래스 스코어로 특정 클래스에 대한 클래스 스코어를 최대로 하는 이미지 I를 찾는 것이다. 이때 규제는 L2 norm을 사용하였는데, 꼭 L2norm이 아니어도 크게 상관은 없다. 규제를 하는 이유는 argmax항에만 의존하여 결과값이 너무 커지면 사람이 이미지의 형태로 파악하기 어려운 결과가 출력될 수 있기 때문이다. 아래의 사진을 보면 그 의미를 알 수 있다.

    왼쪽부터 규제가 없는 그림, L2 norm 적용한 그림, L1 norm 적용한 그림

Designed by Tistory.