SAM2 모델의 배치 학습 구현하기
안녕하세요! 오늘은 SAM2(Segment Anything Model 2) 모델의 배치 학습을 구현하는 방법에 대해 알아보겠습니다. SAM2는 이미지 세그멘테이션 작업에 매우 효과적인 모델이지만, 기본 구현은 단일 이미지 처리에 초점이 맞춰져 있습니다. 여기서는 배치 학습을 통해 학습 효율성을 높이는 방법을 소개하겠습니다.
1. 필요한 라이브러리 임포트 및 설정
먼저, 필요한 라이브러리를 임포트 하고 GPU 설정을 합니다:
python
import torch
import os
import numpy as np
import cv2
from sam2.build_sam
import build_sam2
from sam2.sam2_image_predictor
import SAM2ImagePredictor
# bfloat16 혼합 정밀도 활성화
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Ampere GPU 이상에서 TensorFloat-32 활성화
if torch.cuda.get_device_properties(0).major >= 8:
torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True
2. 모델 및 데이터 준비
SAM2 모델을 불러오고 데이터셋을 준비합니다:
python
# SAM2 모델 구축
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
# 데이터셋 준비
og_data_dir = "/path/to/images"
mask_data_dir = "/path/to/masks"
data = [{'image': os.path.join(og_data_dir, name.split('.')[0] + '.png'),
'annotation': os.path.join(mask_data_dir, name)} for name in os.listdir(mask_data_dir)]
3. 배치 데이터 읽기 함수 구현
배치 단위로 데이터를 읽는 함수를 구현합니다:
def read_batch(data, batch_size=4):
"""Read a batch of images and their corresponding annotations."""
batch_entries = np.random.choice(data, batch_size, replace=False)
images, masks, points, labels = [], [], [], []
for entry in batch_entries:
image = cv2.imread(entry["image"])[..., ::-1]
ann_map = cv2.imread(entry["annotation"])
# print(image.shape,ann_map.shape)
# Resize image and annotation
r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
# print(r)
# print(image.shape,ann_map.shape)
mat_map = ann_map[:, :, 0].astype(np.int64)
ves_map = ann_map[:, :, 2].astype(np.int64)
mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)
inds = np.unique(mat_map)[1:]
image_masks, image_points = [], []
for ind in inds:
mask = (mat_map == ind).astype(np.uint8)
image_masks.append(mask)
coords = np.argwhere(mask > 0)
random_coord = np.array(coords[np.random.randint(len(coords))])
image_points.append([[random_coord[1], random_coord[0]]])
images.append(image)
masks.append(np.array(image_masks))
points.append(np.array(image_points))
labels.append(np.ones([len(image_masks), 1]))
return images, masks, points, labels
4. 학습 루프 구현
배치 학습을 위한 학습 루프를 구현합니다:
# Enable training for mask decoder and prompt encoder
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
# Set up optimizer and gradient scaler
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler()
# Training loop
best_iou = 0.0
batch_size = 128
for itr in range(50000):
with torch.cuda.amp.autocast():
images, masks, input_points, input_labels = read_batch(data, batch_size)
batch_loss = 0
batch_iou = 0
for i in range(batch_size):
if masks[i].shape[0] == 0:
continue
predictor.set_image(images[i])
mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
input_points[i], input_labels[i], box=None, mask_logits=None, normalize_coords=True
)
sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
points=(unnorm_coords, labels), boxes=None, masks=None,
)
high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
repeat_image=True,
high_res_features=high_res_features,
)
prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
gt_mask = torch.tensor(masks[i].astype(np.float32)).cuda()
prd_mask = torch.sigmoid(prd_masks[:, 0])
seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log(1 - prd_mask + 1e-5)).mean()
inter = (gt_mask * (prd_mask > 0.5)).sum(dim=[1, 2])
iou = inter / (gt_mask.sum(dim=[1, 2]) + (prd_mask > 0.5).sum(dim=[1, 2]) - inter)
score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
loss = seg_loss + score_loss * 0.05
batch_loss += loss
batch_iou += np.mean(iou.cpu().detach().numpy())
# Average loss and IoU over the batch
batch_loss /= batch_size
batch_iou /= batch_size
# Backpropagation and optimization step
predictor.model.zero_grad()
scaler.scale(batch_loss).backward()
scaler.step(optimizer)
scaler.update()
# Update and save model
if itr == 0:
mean_iou = 0
mean_iou = mean_iou * 0.99 + 0.01 * batch_iou
if mean_iou > best_iou * 1.1:
best_iou = mean_iou
torch.save(predictor.model.state_dict(), f"model_batch.torch")
print(f"Step {itr}, Accuracy (IoU) = {mean_iou:.4f}")
결론
이렇게 구현한 배치 학습 방식은 여러 이미지를 동시에 처리함으로써 학습 효율성을 높입니다. 배치 크기는 GPU 메모리와 계산 자원에 따라 조정할 수 있습니다. 이 방식을 통해 SAM2 모델의 학습 속도를 향상하고 더 안정적인 학습 결과를 얻을 수 있습니다. 이 구현에서 주목할 점은 다음과 같습니다:
1. 배치 단위로 데이터를 읽고 처리합니다.
2. 각 이미지에 대해 개별적으로 손실을 계산한 후 배치 평균을 구합니다.
3. 혼합 정밀도 학습을 사용하여 메모리 효율성을 높입니다.
4. IoU 점수를 통해 모델의 성능을 모니터링하고 최선의 모델을 저장합니다.
SAM2 모델의 배치 학습 구현에 대해 궁금한 점이 있다면 댓글로 남겨주세요!
'On Going > Computer Vision' 카테고리의 다른 글
[Super Resolution] Using Hugging Face Diffusers (2) | 2024.09.10 |
---|---|
[SAM2] Custom 학습 - SAM2 transfer learning with custom datasets, .ipynb (0) | 2024.09.09 |
[SAM2] segment anything 2 (0) | 2024.08.08 |
[ECW] ECW 파일포맷을 다루고싶어!! (0) | 2024.08.06 |
[ECW] ECW file 포맷을 다루고 싶어! (0) | 2024.08.06 |
댓글