REMOTE SENSING

TorchGeo: 객체 탐지(object detection) 예제 소개

유병혁 2024. 4. 7. 01:36

TorchGeo는 torchvision과 유사한 PyTorch 도메인 라이브러리로, 지리공간 데이터에 특화된 데이터셋, 샘플러, 변환, 그리고 사전 훈련된 모델을 제공합니다. 이번 실습은 TorchGeo에서 객체 탐지(object detection) 예제를 소개해 보겠습니다. 이 예제는 Microsoft AI for Good의 케일럽 로빈슨(Caleb Robinson) 님이 제공하는 Jupyter Notebook 코드에 설명을 덧붙인 것입니다.

GPU 선택

학습에 앞서 노트북 상단 메뉴에서 런타임 > 런타임 유형 변경을 선택한 후, 하드웨어 가속기 메뉴에서 GPU를 선택하고 저장합니다. 저는 Colab Pro를 구독하고 있으며 A100 GPU를 사용했습니다.

TorchGeo 설치

TorchGeo는 `pip install torchgeo`를 비교적 가볍게 만들기 위해 필수 의존성 세트만 설치하고 있습니다. 선택적 의존성 세트를 포함한 전체 설치는 `pip install torchgeo[datasets]`를 사용하시면 됩니다.

  • `pip install torchgeo`: "필수(Required)" 의존성 세트 설치
  • `pip install torchgeo[datasets]`:  "선택적(Optional)" 의존성 세트를 포함한 전체 설치
%pip install -q -U torchgeo[datasets]

PyTorch Lightning 설치

PyTorch Lightning은 PyTorch의 상위 레벨 인터페이스를 제공하여, 모델 학습 과정을 더 간결하고 효율적으로 만들어 줍니다.

VHR-10 데이터셋 다운로드

VHR-10 데이터셋을 다운로드한 후, 열어보는 방법은 이전 글을 참고하시면 됩니다.

 

TorchGeo: NWPU VHR-10 데이터셋 다운로드 방법 소개

TorchGeo는 torchvision과 유사한 PyTorch 도메인 라이브러리로, 지리공간 데이터에 특화된 데이터셋, 샘플러, 변환, 그리고 사전 훈련된 모델을 제공합니다. 이번 글은 TorchGeo에서 NWPU VHR-10 데이터셋 쉽

foss4g.tistory.com

모델 학습

아래는 PyTorch의 데이터 로더(DataLoader)를 사용하여 collate_fn 함수를 정의하고 이를 사용하여 데이터를 로드하는 과정을 보여줍니다.

collate_fn 함수는 주어진 배치에서 각 항목의 image, boxes, labels, masks를 추출하여 새로운 배치 딕셔너리를 구성합니다. 이렇게 구성된 새로운 배치는 모델 학습이나 평가에 직접 사용됩니다. DataLoader는 이 함수를 사용하여 데이터를 배치 단위로 모델에 공급할 준비를 합니다. shuffle=True 옵션은 모델 학습 시 데이터의 순서에 대한 의존성을 줄이기 위해 데이터셋의 순서를 무작위로 섞습니다.
def collate_fn(batch):
    new_batch = {
        "image": [item["image"] for item in batch],  # 이미지
        "boxes": [item["boxes"] for item in batch],  # 바운딩 박스
        "labels": [item["labels"] for item in batch],  # 레이블
        "masks": [item["masks"] for item in batch],  # 마스크
    }
    return new_batch  # 새 배치 반환


# 데이터 로더

dl = DataLoader(
    ds,  # 데이터셋
    batch_size=32,  # 한 번에 로드할 데이터 수
    num_workers=2,  # 데이터 로딩을 위해 사용할 프로세스 수
    shuffle=True,  # 데이터를 로드하기 전에 데이터셋을 섞을지 여부
    collate_fn=collate_fn,  # 배치 처리를 위한 collate_fn 함수
)

 

이 코드는 객체 탐지(Object Detection) 작업을 위한 학습 클래스를 정의하고 인스턴스를 생성하는 과정을 보여줍니다. 특히, 이 클래스는 가변 크기의 입력을 처리할 수 있도록 설계되었습니다.

VariableSizeInputObjectDetectionTask 클래스는 표준 객체 탐지 작업(ObjectDetectionTask)을 상속받아, 각 배치 내에서 가변 크기의 입력 이미지를 처리할 수 있도록 training_step 메서드를 정의합니다. 이를 통해 모델이 다양한 크기의 입력 이미지에 대해 효과적으로 학습할 수 있습니다. 생성된 인스턴스는 Faster R-CNN 모델을 사용하여 지정된 설정으로 객체 탐지 작업을 수행할 준비를 합니다.

 

class VariableSizeInputObjectDetectionTask(ObjectDetectionTask):
    # 학습 단계 정의
    def training_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch["image"]  # 이미지
        batch_size = len(x)  # 배치 크기 설정 (이미지 수)
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ] # 각 이미지의 바운딩 박스와 레이블 정보 추출
        loss_dict = self(x, y)  # 손실
        train_loss: Tensor = sum(loss_dict.values())  # 학습 손실 (손실 값의 합)
        self.log_dict(loss_dict)  # 손실 값 기록
        return train_loss  # 학습 손실 반환

task = VariableSizeInputObjectDetectionTask(
    model="faster-rcnn",  # Faster R-CNN 모델
    backbone="resnet18",  # ResNet18 신경망 아키텍처
    weights=True,  # 사전 훈련된 가중치 사용
    in_channels=3,  # 입력 이미지의 채널 수 (RGB 이미지)
    num_classes=11,  # 분류할 클래스의 수 (10개+배경)
    trainable_layers=3,  # 훈련 가능한 층의 수
    lr=1e-3,  # 학습률
    patience=10,  # 학습 중 조기 종료를 위한 대기 횟수 설정
    freeze_backbone=False,  # 백본 네트워크의 가중치를 고정하지 않고 훈련할지 여부
)
task.monitor = "loss_classifier"  # 모니터링할 메트릭 설정 (여기서는 분류기의 손실)

 

PyTorch Lightning 라이브러리를 사용하여 모델 학습을 위한 설정을 준비합니다. 아래 설정에서는 GPU를 사용하여 모델을 학습하고, 학습 로그와 체크포인트를 'logs/' 디렉토리에 저장하며, 최소 6 에포크에서 최대 100 에포크까지의 학습을 설정하고 있습니다.

trainer = pl.Trainer(
    default_root_dir="logs/",  # 기본 디렉토리 설정
    accelerator="gpu",  # 학습에 사용할 하드웨어 가속기 종류 설정 (GPU 사용)
    devices=[0],  # 사용할 디바이스의 ID 목록 ([0]은 첫 번째 GPU 의미)
    min_epochs=6,  # 최소 학습 에포크 수 설정
    max_epochs=100,  # 최대 학습 에포크 수 설정
    log_every_n_steps=20,  # 몇 번의 스텝마다 로그를 기록할지 설정
)

 

학습 시간은 37min 22s가 소요되었습니다.

%%time
# 모델 학습
trainer.fit(task, train_dataloaders=dl)

모델 추론(inference) 예제

데이터 로더(dl)에서 다음 배치를 가져옵니다.

batch = next(iter(dl))

 

작업(task)에서 모델을 가져와 평가 모드로 설정합니다. 이렇게 하면 훈련 중에 사용되는 드롭아웃(Dropout)과 같은 특정 레이어가 비활성화됩니다.


`torch.no_grad()`는 기울기 계산을 비활성화하여 메모리 사용량을 줄이고 계산 속도를 높입니다. 이는 평가나 추론 과정에서 모델을 업데이트하지 않을 때 사용합니다. 이제 모델에 이미지 배치를 전달하여 예측 결과를 얻습니다.

model = task.model
model.eval()

with torch.no_grad():
  out = model(batch["image"])

 

특정 배치 인덱스에 대한 샘플을 정의합니다.

def create_sample(batch, out, batch_idx):
    return {
        "image": batch["image"][batch_idx],  # 이미지
        "boxes": batch["boxes"][batch_idx],  # 실제 경계 상자
        "labels": batch["labels"][batch_idx],  # 실제 라벨
        "masks": batch["masks"][batch_idx],  # 실제 마스크
        "prediction_labels": out[batch_idx]["labels"],  # 모델이 예측한 라벨
        "prediction_boxes": out[batch_idx]["boxes"],  # 모델이 예측한 경계 상자
        "prediction_scores": out[batch_idx]["scores"],  # 각 예측의 신뢰도 점수
    }

batch_idx = 5
sample = create_sample(batch, out, batch_idx)

 

이제 주어진 sample을 시각화합니다. plot 메서드는 sample에 포함된 이미지, 실제 레이블 및 경계 상자, 예측된 레이블 및 경계 상자 등을 시각화합니다.

ds.plot(sample)
plt.savefig('inference.png', bbox_inches='tight')
plt.show()

# 배치 인덱스 14에 대한 샘플 시각화
batch_idx = 14
sample = create_sample(batch, out, batch_idx)

ds.plot(sample)
plt.show()

# 배치 인덱스 5에 대한 샘플 시각화
batch_idx = 5
sample = create_sample(batch, out, batch_idx)

ds.plot(sample)
plt.show()