Let's walk through training the SAM2 model on a custom dataset. This is the interactive-interpreter (notebook-style) version.
import torch
import numpy as np
import os
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
First, import the required packages.
(Before this, you need to set up an environment, e.g. a virtual environment, in which the sam2 package is installed and usable.)
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
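Everything below assumes a CUDA-capable GPU (the model is built with device="cuda"). A minimal guard, which is my own addition and not part of the original script:

# Hypothetical guard: fail early if no CUDA device is available
assert torch.cuda.is_available(), "The SAM2 training code in this post requires a CUDA GPU"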
Next, declare the SAM2 model and load the pretrained checkpoint.
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
Now point to the data: original images in one folder and mask images in another, declared as 'image' / 'annotation' pairs. In the data = ... part, just adjust the file names and extensions to match how your own files are organized.
og_data_dir = "/path/to/your/or_image"
mask_data_dir = "/path/to/your/mask_image"
data = [{'image': os.path.join(og_data_dir, name.split('.')[0] + '.png'),
         'annotation': os.path.join(mask_data_dir, name)}
        for name in os.listdir(mask_data_dir)]
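The list comprehension assumes that every mask file in mask_data_dir has a same-named .png in og_data_dir. A quick sanity check along those lines (my own addition, using only the variables defined above) can catch missing pairs before training starts:

# Hypothetical sanity check: report mask files without a matching original image
missing = [entry['image'] for entry in data if not os.path.exists(entry['image'])]
if missing:
    print(f"{len(missing)} mask files have no matching original image, e.g. {missing[:3]}")
print(f"{len(data)} image/annotation pairs found")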
To train in batches, define a function that reads a batch of images and annotations.
def read_batch(data, batch_size=4):
    """Read a batch of images and their corresponding annotations."""
    batch_entries = np.random.choice(data, batch_size, replace=False)
    images, masks, points, labels = [], [], [], []
    for entry in batch_entries:
        image = cv2.imread(entry["image"])[..., ::-1]  # BGR -> RGB
        ann_map = cv2.imread(entry["annotation"])

        # Resize image and annotation so the longer side is at most 1024
        r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
        image = cv2.resize(image, (int(image.shape[1] * r), int(image.shape[0] * r)))
        ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

        # Merge the two annotation channels into a single instance map
        mat_map = ann_map[:, :, 0].astype(np.int64)
        ves_map = ann_map[:, :, 2].astype(np.int64)
        mat_map[mat_map == 0] = ves_map[mat_map == 0] * (mat_map.max() + 1)

        # For each instance id, build a binary mask and sample one prompt point inside it
        inds = np.unique(mat_map)[1:]
        image_masks, image_points = [], []
        for ind in inds:
            mask = (mat_map == ind).astype(np.uint8)
            image_masks.append(mask)
            coords = np.argwhere(mask > 0)
            random_coord = np.array(coords[np.random.randint(len(coords))])
            image_points.append([[random_coord[1], random_coord[0]]])  # (x, y)

        images.append(image)
        masks.append(np.array(image_masks))
        points.append(np.array(image_points))
        labels.append(np.ones([len(image_masks), 1]))
    return images, masks, points, labels
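Before launching a long run, it can help to confirm what read_batch returns. This shape check is my own addition, not part of the original script:

# Hypothetical shape check on one small batch
b_images, b_masks, b_points, b_labels = read_batch(data, batch_size=2)
for img, msk, pts, lbl in zip(b_images, b_masks, b_points, b_labels):
    # img: (H, W, 3) RGB image, msk: (num_objects, H, W) binary masks,
    # pts: (num_objects, 1, 2) prompt points as (x, y), lbl: (num_objects, 1) positive labels
    print(img.shape, msk.shape, pts.shape, lbl.shape)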
Next, set up what is needed to train SAM2: put the mask decoder and prompt encoder into training mode, and create the optimizer and gradient scaler.
# Enable training for mask decoder and prompt encoder
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
# Set up optimizer and gradient scaler
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
scaler = torch.cuda.amp.GradScaler()
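Note that the optimizer above is handed all of the model's parameters, even though only the mask decoder and prompt encoder are switched into training mode. If you want to be explicit that only those two modules are fine-tuned, one variant (a sketch of an alternative, not what the code above does) is to restrict the parameter list:

# Alternative sketch: optimize only the mask decoder and prompt encoder parameters
trainable_params = (list(predictor.model.sam_mask_decoder.parameters())
                    + list(predictor.model.sam_prompt_encoder.parameters()))
optimizer = torch.optim.AdamW(params=trainable_params, lr=1e-5, weight_decay=4e-5)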
Now run the training loop. The segmentation loss is a pixel-wise binary cross entropy against the ground-truth masks, and the score loss pushes the model's predicted mask scores toward the actual IoU of each predicted mask; the two are combined as seg_loss + 0.05 * score_loss.
# Training loop
best_iou = 0.0
batch_size = 128
for itr in range(50000):
    with torch.cuda.amp.autocast():
        images, masks, input_points, input_labels = read_batch(data, batch_size)
        batch_loss = 0
        batch_iou = 0
        for i in range(batch_size):
            if masks[i].shape[0] == 0:
                continue
            predictor.set_image(images[i])
            mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
                input_points[i], input_labels[i], box=None, mask_logits=None, normalize_coords=True
            )
            sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
                points=(unnorm_coords, labels), boxes=None, masks=None,
            )
            high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
            low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
                image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
                image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=True,
                repeat_image=True,
                high_res_features=high_res_features,
            )
            prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])
            gt_mask = torch.tensor(masks[i].astype(np.float32)).cuda()
            prd_mask = torch.sigmoid(prd_masks[:, 0])
            seg_loss = (-gt_mask * torch.log(prd_mask + 1e-5) - (1 - gt_mask) * torch.log(1 - prd_mask + 1e-5)).mean()
            inter = (gt_mask * (prd_mask > 0.5)).sum(dim=[1, 2])
            iou = inter / (gt_mask.sum(dim=[1, 2]) + (prd_mask > 0.5).sum(dim=[1, 2]) - inter)
            score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
            loss = seg_loss + score_loss * 0.05
            batch_loss += loss
            batch_iou += np.mean(iou.cpu().detach().numpy())

        # Average loss and IoU over the batch
        batch_loss /= batch_size
        batch_iou /= batch_size

    # Backpropagation and optimization step
    predictor.model.zero_grad()
    scaler.scale(batch_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Update and save model
    if itr == 0:
        mean_iou = 0
    mean_iou = mean_iou * 0.99 + 0.01 * batch_iou
    if mean_iou > best_iou * 1.1:
        best_iou = mean_iou
        torch.save(predictor.model.state_dict(), "model_batch.torch")
    print(f"Step {itr}, Accuracy (IoU) = {mean_iou:.4f}")
Done!
Of course we should also use the trained model to look at some results.
Here I check against an image that was used for training; checking on data the model has never seen would be more meaningful.
image_path = data[0]['image']
mask_path = data[0]['annotation']
num_samples = 2 # number of points/segment to sample
This part reads the image and picks prompt points from the mask.
def read_image(image_path, mask_path):  # read and resize image and mask
    img = cv2.imread(image_path)[..., ::-1]  # read image (BGR -> RGB)
    mask = cv2.imread(mask_path, 0)  # read mask as grayscale
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    mask = cv2.resize(mask, (int(mask.shape[1] * r), int(mask.shape[0] * r)), interpolation=cv2.INTER_NEAREST)
    return img, mask

def get_points(mask, num_points):  # Sample points inside the input mask
    points = []
    for i in range(num_points):
        coords = np.argwhere(mask > 0)
        yx = np.array(coords[np.random.randint(len(coords))])
        points.append([[yx[1], yx[0]]])
    return np.array(points)
Read the image and mask and sample the prompt points. The get_points part can look like the confusing bit of SAM2, but if you test it step by step it becomes clear.
# read image and sample points
image,mask = read_image(image_path, mask_path)
input_points = get_points(mask,num_samples)
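With num_samples = 2, get_points returns an array of shape (2, 1, 2), one (x, y) pixel coordinate per sampled point. A quick look, which I'm adding as a check and which is not in the original code:

# Hypothetical check of the sampled prompt points
print(input_points.shape)  # (num_samples, 1, 2)
print(input_points[:, 0])  # each row is an (x, y) coordinate that lies inside the mask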
Declare the model again, and this time we also need to load the weights we just trained.
# Load model: you need the base SAM2 checkpoint plus the fine-tuned weights saved above
sam2_checkpoint = "./checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Build net and load weights
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load("model_batch.torch"))
This part shows the result and saves two images: annotation (the predicted segmentation map) and mix (the segmentation result overlaid on the original image).
# Predict masks for the sampled points
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )

# Sort predicted masks from high to low score
np_masks = np.array(masks[:, 0])  # masks is already a numpy array
np_scores = scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]

# Stitch predicted masks into one segmentation map
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)
for i in range(sorted_masks.shape[0]):
    mask = sorted_masks[i]
    # Skip masks that mostly overlap regions already claimed by higher-scoring masks
    if (mask * occupancy_mask).sum() / mask.sum() > 0.15:
        continue
    mask[occupancy_mask] = 0
    seg_map[mask > 0.5] = i + 1  # 0.5 as the threshold for a positive prediction
    occupancy_mask[mask > 0.5] = 1

# Create a colored annotation map
height, width = seg_map.shape
rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
# Map each class id to a random color
for id_class in range(1, seg_map.max() + 1):
    rgb_image[seg_map == id_class] = [np.random.randint(255), np.random.randint(255), np.random.randint(255)]

# Save and display
cv2.imwrite("annotation.png", rgb_image)
# image was converted to RGB when it was read, so convert back to BGR before writing the overlay with cv2
cv2.imwrite("mix.png", cv2.cvtColor((rgb_image / 2 + image / 2).astype(np.uint8), cv2.COLOR_RGB2BGR))
# plt.imshow(cv2.resize((rgb_image / 2 + image / 2).astype(np.uint8), (800, 600)))  # view the overlay instead
plt.imshow(image)
plt.show()
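Since the ground-truth mask for this image is available, you can also put a rough number on the result. The snippet below is my own addition: it reloads the mask at the resized resolution and reports how much of the ground-truth foreground the stitched prediction covers.

# Hypothetical quick check: foreground coverage of the stitched prediction
_, gt_mask = read_image(image_path, mask_path)  # ground-truth mask at the same resolution as seg_map
gt_fg = gt_mask > 0
pred_fg = seg_map > 0
coverage = (gt_fg & pred_fg).sum() / max(gt_fg.sum(), 1)
print(f"Ground-truth foreground covered by prediction: {coverage:.3f}")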