본문 바로가기
On Going/Computer Vision

[SAM2] SAM2 transfer learning with custom datasets, .py format

by 에아오요이가야 2024. 9. 9.

SAM2 모델의 배치 학습 구현하기

안녕하세요! 오늘은 SAM2(Segment Anything Model 2) 모델의 배치 학습을 구현하는 방법에 대해 알아보겠습니다. SAM2는 이미지 세그멘테이션 작업에 매우 효과적인 모델이지만, 기본 구현은 단일 이미지 처리에 초점이 맞춰져 있습니다. 여기서는 배치 학습을 통해 학습 효율성을 높이는 방법을 소개하겠습니다.

1. 필요한 라이브러리 임포트 및 설정

먼저, 필요한 라이브러리를 임포트 하고 GPU 설정을 합니다:

python
import torch
import os 
import numpy as np 
import cv2 
from sam2.build_sam 
import build_sam2 
from sam2.sam2_image_predictor 
import SAM2ImagePredictor 

# bfloat16 혼합 정밀도 활성화 
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Ampere GPU 이상에서 TensorFloat-32 활성화 
if torch.cuda.get_device_properties(0).major >= 8: 
	torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True

2. 모델 및 데이터 준비

SAM2 모델을 불러오고 데이터셋을 준비합니다:

python
# SAM2 모델 구축 
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt" 
model_cfg = "sam2_hiera_l.yaml" 
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda") 
predictor = SAM2ImagePredictor(sam2_model)

# 데이터셋 준비 
og_data_dir = "/path/to/images" 
mask_data_dir = "/path/to/masks" 

data = [{'image': os.path.join(og_data_dir, name.split('.')[0] + '.png'),
'annotation': os.path.join(mask_data_dir, name)} for name in os.listdir(mask_data_dir)]

3. 배치 데이터 읽기 함수 구현

배치 단위로 데이터를 읽는 함수를 구현합니다:

def read_batch(data, batch_size=4):
    """Read a batch of images and their corresponding annotations."""
    batch_entries = np.random.choice(data, batch_size, replace=False)
    images, masks, points, labels = [], [], [], []
    
    for entry in batch_entries:
        image = cv2.imread(entry["image"])[..., ::-1]
        ann_map = cv2.imread(entry["annotation"])
        # print(image.shape,ann_map.shape)
        # Resize image and annotation
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

        # print(r)
        # print(image.shape,ann_map.shape)

        mat_map = ann_map[:, :, 0].astype(np.int64)
        ves_map = ann_map[:, :, 2].astype(np.int64)
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        inds = np.unique(mat_map)[1:]
        image_masks, image_points = [], []
        for ind in inds:
            mask = (mat_map == ind).astype(np.uint8)
            image_masks.append(mask)
            coords = np.argwhere(mask > 0)
            random_coord = np.array(coords[np.random.randint(len(coords))])
            image_points.append([[random_coord[1], random_coord[0]]])

        images.append(image)
        masks.append(np.array(image_masks))
        points.append(np.array(image_points))
        labels.append(np.ones([len(image_masks), 1]))

    return images, masks, points, labels

4. 학습 루프 구현

배치 학습을 위한 학습 루프를 구현합니다:

# Enable training for mask decoder and prompt encoder
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# Set up optimizer and gradient scaler
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler()

# Training loop
best_iou = 0.0
batch_size = 128

for itr in range(50000):
    with torch.cuda.amp.autocast():
        images, masks, input_points, input_labels = read_batch(data, batch_size)
        
        batch_loss = 0
        batch_iou = 0
        
        for i in range(batch_size):
            if masks[i].shape[0] == 0:
                continue

            predictor.set_image(images[i])

            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_points[i], input_labels[i], box=None, mask_logits=None, normalize_coords=True
            )
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )

            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=True,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            gt_mask = torch.tensor(masks[i].astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log(1 - prd_mask + 1e-5)).mean()

            inter = (gt_mask * (prd_mask > 0.5)).sum(dim=[1, 2])
            iou = inter / (gt_mask.sum(dim=[1, 2]) + (prd_mask > 0.5).sum(dim=[1, 2]) - inter)
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

            loss = seg_loss + score_loss * 0.05
            batch_loss += loss
            batch_iou += np.mean(iou.cpu().detach().numpy())

        # Average loss and IoU over the batch
        batch_loss /= batch_size
        batch_iou /= batch_size

        # Backpropagation and optimization step
        predictor.model.zero_grad()
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update and save model
        if itr == 0:
            mean_iou = 0
        mean_iou = mean_iou * 0.99 + 0.01 * batch_iou

        if mean_iou > best_iou * 1.1:
            best_iou = mean_iou
            torch.save(predictor.model.state_dict(), f"model_batch.torch")
            print(f"Step {itr}, Accuracy (IoU) = {mean_iou:.4f}")

결론

이렇게 구현한 배치 학습 방식은 여러 이미지를 동시에 처리함으로써 학습 효율성을 높입니다. 배치 크기는 GPU 메모리와 계산 자원에 따라 조정할 수 있습니다. 이 방식을 통해 SAM2 모델의 학습 속도를 향상하고 더 안정적인 학습 결과를 얻을 수 있습니다. 이 구현에서 주목할 점은 다음과 같습니다:

 

1. 배치 단위로 데이터를 읽고 처리합니다.

2. 각 이미지에 대해 개별적으로 손실을 계산한 후 배치 평균을 구합니다.

3. 혼합 정밀도 학습을 사용하여 메모리 효율성을 높입니다.

4. IoU 점수를 통해 모델의 성능을 모니터링하고 최선의 모델을 저장합니다.

 

SAM2 모델의 배치 학습 구현에 대해 궁금한 점이 있다면 댓글로 남겨주세요!

댓글