본문 바로가기
On Going/Computer Vision

[SAM2] Custom 학습 - SAM2 transfer learning with custom datasets, .ipynb

by 에아오요이가야 2024. 9. 9.

SAM2 model의 custom 학습을 진행해 보겠습니다. 대화형 인터프리터 버전

import torch
import numpy as np
import cv2,os
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

 

 

우선 필요한 package들을 import 해줍니다.

 

<그전에 sam2 model을 사용할 수 있도록 가상환경 등을 이용하여 환경설정을 해줘야 합니다.>

 

# Enter a bfloat16 autocast context for CUDA; it is never exited, so it stays
# active for everything that runs afterwards in this session.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# TF32 is available on GPUs with compute capability major version >= 8 (Ampere and newer):
# https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
gpu_major = torch.cuda.get_device_properties(0).major
if gpu_major >= 8:
    # allow TF32 both for CUDA matmuls and for cuDNN kernels
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

 

sam2 모델을 선언해 주고 pretrained 모델을 load 해줍니다.

 

# Path of the pretrained checkpoint file and name of the model config to build from
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# Build the SAM2 model on the GPU from the config and the checkpoint
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Wrap the model in the image predictor that the rest of the notebook uses
predictor = SAM2ImagePredictor(sam2_model)

 

 

데이터를 image_original과 image_mask로 분리하여 원본과 annotation으로 선언해 줍니다.

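`data = [...]` 부분에서는 마스크 폴더의 파일 이름을 기준으로 원본 이미지 경로('image')와 마스크 경로('annotation')를 짝지어 리스트로 정리합니다. 원본 이미지는 마스크와 같은 이름(확장자 제외)의 .png 파일이어야 합니다.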

og_data_dir = "/path/to/your/or_image"
mask_data_dir = "/path/to/your/mask_iamge"

# Build one record per file found in the mask directory. The matching original
# image is looked up under the same base name (text before the first dot) with
# a ".png" extension.
data = []
for mask_name in os.listdir(mask_data_dir):
    base_name = mask_name.split('.')[0]
    data.append({
        'image': os.path.join(og_data_dir, base_name + '.png'),
        'annotation': os.path.join(mask_data_dir, mask_name),
    })

 

batch단위로 학습하기 위해 image, annotation read 하는 부분을 선언해 주겠습니다.

def read_batch(data, batch_size=4):
    """Read a random batch of images with per-instance masks and point prompts.

    Args:
        data: Sequence of dicts, each with an "image" path and an "annotation" path.
        batch_size: Number of distinct entries to draw. Sampling is without
            replacement, so it must not exceed ``len(data)``.

    Returns:
        A tuple ``(images, masks, points, labels)`` of lists with one item per
        sampled entry:
          * images: RGB image, rescaled so that its longer side is about 1024 px.
          * masks: uint8 array with one binary mask per instance id
            (shape ``(n_instances, H, W)``; empty when no instance is found).
          * points: array of shape ``(n_instances, 1, 2)`` holding one ``(x, y)``
            pixel sampled at random inside each instance mask.
          * labels: array of ones with shape ``(n_instances, 1)``.

    Raises:
        FileNotFoundError: If an image or annotation file cannot be read.
    """
    batch_entries = np.random.choice(data, batch_size, replace=False)
    images, masks, points, labels = [], [], [], []

    for entry in batch_entries:
        image = cv2.imread(entry["image"])
        ann_map = cv2.imread(entry["annotation"])
        # cv2.imread reports failure by returning None instead of raising
        if image is None:
            raise FileNotFoundError(f"Could not read image: {entry['image']}")
        if ann_map is None:
            raise FileNotFoundError(f"Could not read annotation: {entry['annotation']}")
        image = image[..., ::-1]  # OpenCV loads BGR; flip to RGB

        # Resize image and annotation with the same factor (derived from the image);
        # nearest-neighbour keeps the annotation ids intact.
        scale = min(1024 / image.shape[1], 1024 / image.shape[0])
        image = cv2.resize(image, (int(image.shape[1] * scale), int(image.shape[0] * scale)))
        ann_map = cv2.resize(
            ann_map,
            (int(ann_map.shape[1] * scale), int(ann_map.shape[0] * scale)),
            interpolation=cv2.INTER_NEAREST,
        )

        # Instance ids come from channel 0. Where channel 0 is zero, fall back to
        # channel 2, multiplied by (max id of channel 0 + 1) so that the two id
        # ranges cannot collide.
        mat_map = ann_map[:, :, 0].astype(np.int64)
        ves_map = ann_map[:, :, 2].astype(np.int64)
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        # Keep every non-zero id (0 is background). Filtering by value instead of
        # slicing off the first unique entry avoids losing a real instance when
        # the map contains no background pixel at all.
        inds = np.unique(mat_map)
        inds = inds[inds != 0]

        image_masks, image_points = [], []
        for ind in inds:
            mask = (mat_map == ind).astype(np.uint8)
            coords = np.argwhere(mask > 0)  # (row, col) of every pixel in the instance
            row, col = coords[np.random.randint(len(coords))]
            image_masks.append(mask)
            image_points.append([[col, row]])  # prompts are stored as (x, y)

        images.append(image)
        masks.append(np.array(image_masks))
        points.append(np.array(image_points))
        labels.append(np.ones([len(image_masks), 1]))

    return images, masks, points, labels

 

 

sam2를 학습하기 위한 설정(학습 모드, optimizer, gradient scaler)을 적어주겠습니다.

# Enable training for mask decoder and prompt encoder
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# Set up optimizer and gradient scaler
# NOTE(review): the optimizer is given every parameter of predictor.model, not only
# those of the two sub-modules switched to train mode above.
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
# Loss scaler used together with autocast in the training loop
scaler = torch.cuda.amp.GradScaler()

 

학습을 진행합니다.

# Training loop
#
# Each iteration samples a batch, runs the prompt encoder and mask decoder once
# per image, averages the per-image losses, and takes one optimizer step.
# The weights are saved whenever the smoothed IoU beats the best value so far
# by more than 10 %.
best_iou = 0.0
batch_size = 128
mean_iou = 0.0  # exponential moving average of the batch IoU

for itr in range(50000):
    images, masks, input_points, input_labels = read_batch(data, batch_size)

    batch_loss = 0
    batch_iou = 0
    num_valid = 0  # number of images that actually contributed a loss term

    # Autocast wraps only the forward passes and the loss computation;
    # the backward pass runs after the context has been left.
    with torch.cuda.amp.autocast():
        for image, gt_masks, points, point_labels in zip(images, masks, input_points, input_labels):
            if gt_masks.shape[0] == 0:
                continue  # image without any annotated instance

            predictor.set_image(image)

            # Encode the point prompts
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                points, point_labels, box=None, mask_logits=None, normalize_coords=True
            )
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )

            # Decode masks from the image features cached by set_image
            high_res_features = [
                feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]
            ]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=True,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

            # Segmentation loss: pixel-wise binary cross-entropy on the first
            # predicted mask of every prompt (1e-5 keeps the logs finite)
            gt_mask = torch.tensor(gt_masks.astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log(1 - prd_mask + 1e-5)).mean()

            # Score loss: the predicted score should match the IoU actually
            # reached by the mask thresholded at 0.5
            inter = (gt_mask * (prd_mask > 0.5)).sum(dim=[1, 2])
            iou = inter / (gt_mask.sum(dim=[1, 2]) + (prd_mask > 0.5).sum(dim=[1, 2]) - inter)
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

            loss = seg_loss + score_loss * 0.05
            batch_loss += loss
            batch_iou += np.mean(iou.cpu().detach().numpy())
            num_valid += 1

    if num_valid == 0:
        # Every sampled image was skipped: batch_loss is still the plain int 0,
        # which has no graph to back-propagate through.
        continue

    # Average loss and IoU over the images that contributed, so that skipped
    # images do not bias both values towards zero
    batch_loss /= num_valid
    batch_iou /= num_valid

    # Backpropagation and optimization step
    predictor.model.zero_grad()
    scaler.scale(batch_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Update the smoothed IoU and save the model when it improves by more than 10 %
    mean_iou = mean_iou * 0.99 + 0.01 * batch_iou

    if mean_iou > best_iou * 1.1:
        best_iou = mean_iou
        torch.save(predictor.model.state_dict(), "model_batch.torch")
        print(f"Step {itr}, Accuracy (IoU) = {mean_iou:.4f}")

끝!

 

이걸 활용해서 결과도 봐야겠죠?

여기선 학습에 사용된 데이터를 확인한 것인데요 학습하지 않은 데이터로 확인하는 것이 더 효용 있겠죠?

image_path = data[0]['image']
# The records in `data` store the mask path under the 'annotation' key, and the
# read_image(image_path, mask_path) call below expects the name `mask_path`.
mask_path = data[0]['annotation']
num_samples = 2 # number of points/segment to sample

 

이미지를 읽고 point를 잡아오는 부분입니다.

def read_image(image_path, mask_path): # read and resize image and mask
    """Load an image and its mask and rescale both with the same factor.

    Args:
        image_path: Path of the colour image.
        mask_path: Path of the mask; it is read as a single-channel image.

    Returns:
        ``(img, mask)``: the image in RGB channel order and the mask, both
        resized by the factor that brings the image's longer side to about
        1024 px. The mask is resized with nearest-neighbour interpolation so
        its values are not blended.

    Raises:
        FileNotFoundError: If either file cannot be read.
    """
    img = cv2.imread(image_path)
    mask = cv2.imread(mask_path, 0)
    # cv2.imread reports failure by returning None instead of raising
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    if mask is None:
        raise FileNotFoundError(f"Could not read mask: {mask_path}")

    img = img[..., ::-1]  # OpenCV loads BGR; flip to RGB
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask


def get_points(mask, num_points): # Sample points inside the input mask
    """Sample random pixel positions from the non-zero region of ``mask``.

    Args:
        mask: Array whose positive entries mark the region to sample from.
        num_points: Number of points to draw (with replacement).

    Returns:
        Integer array of shape ``(num_points, 1, 2)``; each point is ``(x, y)``,
        i.e. (column, row).

    Raises:
        ValueError: If points are requested but ``mask`` has no positive pixel.
    """
    if num_points <= 0:
        return np.empty((0, 1, 2), dtype=np.int64)

    # Locate the candidate pixels once instead of once per sampled point
    coords = np.argwhere(mask > 0)
    if len(coords) == 0:
        raise ValueError("cannot sample points: mask has no positive pixels")

    picks = coords[np.random.randint(len(coords), size=num_points)]
    # argwhere yields (row, col); swap the two columns to get (x, y)
    return picks[:, [1, 0]].reshape(num_points, 1, 2)

 

이미지와 마스크를 읽고 point를 짚어줍니다. get_points 부분이 sam2에서 어려워 보일 수 있는데요, 천천히 테스트하다 보면 알 수 있습니다.

# read image and sample points
# read_image returns the RGB image and the single-channel mask, both resized
image,mask = read_image(image_path, mask_path)
# draw num_samples random (x, y) points from the positive region of the mask
input_points = get_points(mask,num_samples)

 

모델을 선언해 주고

 

여기서 학습한 모델의 가중치를 불러와야겠죠?

# Load model you need to have pretrained model already made
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Build net and load weights
predictor = SAM2ImagePredictor(sam2_model)
# The training loop saves its best weights as "model_batch.torch", so that is
# the file to load here.
predictor.model.load_state_dict(torch.load("model_batch.torch"))

 

결과를 보여주고,

annotation (추론한 이미지)와 mix(seg 추론결과를 original image위에 overlay 한 결과 이미지)를 저장하는 부분입니다.

 

# predict mask
# one label per sampled point, all set to 1
point_labels = np.ones([input_points.shape[0], 1])

# inference only, so gradients are not tracked
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(point_coords=input_points, point_labels=point_labels)

# Sort predicted masks from high to low score

# keep the first returned mask and score of every prompt
# (masks is already a numpy array, so no .cpu().numpy() is needed)
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
# argsort is ascending; reverse the index order to get the best score first
descending = np.argsort(np_scores)[::-1]
shorted_masks = np_masks[descending]


# Stitch predicted mask into one segmentation mask
#
# Masks are visited from best to worst score. A mask is dropped when more than
# 15 % of it overlaps pixels that earlier masks already claimed; otherwise its
# remaining pixels are labelled i + 1 in seg_map.

seg_map = np.zeros_like(shorted_masks[0],dtype=np.uint8)
occupancy_mask = np.zeros_like(shorted_masks[0],dtype=bool)
for i in range(shorted_masks.shape[0]):
    mask = shorted_masks[i]
    mask_area = mask.sum()
    # An empty prediction has nothing to paint; skipping it also avoids the
    # 0/0 division (NaN plus RuntimeWarning) in the overlap ratio below.
    if mask_area == 0: continue
    if (mask*occupancy_mask).sum()/mask_area>0.15: continue
    mask[occupancy_mask]=0  # give up pixels that are already claimed
    # Convert mask to boolean for indexing
    claimed = mask > 0.5  # Assuming 0.5 as the threshold for positive prediction
    seg_map[claimed] = i + 1
    occupancy_mask[claimed] = 1


# create colored annotation map
height, width = seg_map.shape

# Start from a black RGB canvas of the same size as the segmentation map
rgb_image = np.zeros((height, width, 3), dtype=np.uint8)

# Paint every segment id with its own random color (each channel drawn from 0..254)
for id_class in range(1, seg_map.max() + 1):
    color = [np.random.randint(255) for _ in range(3)]
    rgb_image[seg_map == id_class] = color

# save and display

# annotation.png: the colored segmentation map on its own
cv2.imwrite("annotation.png",rgb_image)
# mix.png: segmentation blended 50/50 with the input image.
# cv2.imwrite expects BGR channel order while `image` was flipped to RGB in
# read_image, so flip it back before blending; otherwise red and blue end up
# swapped in the saved overlay.
cv2.imwrite("mix.png",(rgb_image/2+image[..., ::-1]/2).astype(np.uint8))

# plt.imshow(cv2.resize(rgb_image, (800, 600))) # Use cv2_imshow instead of cv2.imshow
# plt.imshow(cv2.resize((rgb_image/2+image/2).astype(np.uint8),(800, 600))) # Use cv2_imshow instead of cv2.imshow
plt.imshow(image) # show the (RGB) input image; matplotlib expects RGB



# cv2.waitKey()

댓글