8️⃣ Segmentation

Hugging Face: Segmentation

Mask Generation with SAM

Segment Anything Model(SAM)은 Meta AI가 공개한 Transformer 기반의 이미지 분할(segmentation) 모델입니다. 이 노트북에서는 SAM을 경량화한 SlimSAM 체크포인트(`Zigeng/SlimSAM-uniform-77`)를 사용합니다.

Dataset

!wget https://huggingface.co/sd-concepts-library/smiling-friend-style/resolve/main/concept_images/21.jpeg -O ./dataset/huggingface_friends.jpg
--2024-05-20 20:33:54--  https://huggingface.co/sd-concepts-library/smiling-friend-style/resolve/main/concept_images/21.jpeg
Resolving huggingface.co (huggingface.co)... 13.225.131.93, 13.225.131.94, 13.225.131.35, ...
Connecting to huggingface.co (huggingface.co)|13.225.131.93|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 69275 (68K) [image/jpeg]
Saving to: ‘./dataset/huggingface_friends.jpg’

./dataset/huggingfa 100%[===================>]  67.65K  --.-KB/s    in 0.002s  

2024-05-20 20:33:54 (33.2 MB/s) - ‘./dataset/huggingface_friends.jpg’ saved [69275/69275]
from PIL import Image

# Load the downloaded image that SAM will segment.
raw_image = Image.open('dataset/huggingface_friends.jpg')
# `resize` returns a new, smaller copy and leaves `raw_image` unchanged;
# presumably this is only here to render a preview in the notebook.
raw_image.resize((720, 375))

SAM Pipeline

from transformers import pipeline

# Automatic mask-generation pipeline backed by the SlimSAM checkpoint
# "Zigeng/SlimSAM-uniform-77" (a compressed variant of SAM).
sam_pipe = pipeline("mask-generation",
    "Zigeng/SlimSAM-uniform-77")

`points_per_batch`는 모델이 한 번에 동시에 처리하는 포인트(prompt) 수입니다. 값이 클수록 추론이 빨라질 수 있지만, 그만큼 메모리(특히 GPU 메모리)를 더 많이 사용합니다.

# Generate masks for the whole image. `points_per_batch` controls how many
# prompt points are processed per forward pass (speed vs. memory trade-off).
output = sam_pipe(
    raw_image, 
    points_per_batch=32
)

Inspecting the Masks

import matplotlib.pyplot as plt
import numpy as np

def show_pipe_masks_on_image(raw_image, outputs):
    """Plot every mask from a mask-generation pipeline result over an image.

    Args:
        raw_image: Base image to draw underneath (anything ``np.array`` can
            convert, e.g. a PIL image).
        outputs: Pipeline result; only its ``"masks"`` entry is read. Each
            mask is drawn in a random color via ``show_mask``.
    """
    plt.imshow(np.array(raw_image))
    axis = plt.gca()
    masks = outputs["masks"]
    for single_mask in masks:
        show_mask(single_mask, ax=axis, random_color=True)
    plt.axis("off")
    plt.show()
    
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA color layer.

    Args:
        mask: Mask whose last two dimensions are (height, width) and which holds
            exactly ``height * width`` elements, e.g. shape ``(H, W)`` or
            ``(1, H, W)``. It may be a NumPy array or a torch tensor.
        ax: Matplotlib axes (anything exposing ``imshow``).
        random_color: If True, use a random RGB color; otherwise a fixed blue
            ``(30, 144, 255)``. The alpha channel is always 0.6.
    """
    alpha = np.array([0.6])
    rgb = (np.random.random(3) if random_color
           else np.array([30/255, 144/255, 255/255]))
    rgba = np.concatenate([rgb, alpha], axis=0)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against a (1, 1, 4) color to get an
    # (H, W, 4) image that is transparent wherever the mask is 0.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
# Overlay every mask returned by the pipeline, each in a random color.
show_pipe_masks_on_image(raw_image, output)

Faster Inference: Image and a Single Point

from transformers import SamModel, SamProcessor
# Load the same SlimSAM checkpoint as a bare model plus processor, so it can
# be prompted with an explicit point instead of going through the pipeline.
model = SamModel.from_pretrained(
    "Zigeng/SlimSAM-uniform-77")

processor = SamProcessor.from_pretrained(
    "Zigeng/SlimSAM-uniform-77")
# `resize` returns a new image, so this only previews a smaller copy; the
# processor below still receives the full-resolution `raw_image`.
raw_image.resize((720, 375))

# Point prompts are (x, y) pixel coordinates in the original image, nested as
# [image][point][x, y]. They must lie inside the image: a point such as
# (1600, 700) falls outside this 1004x565 picture and gives SAM an off-image
# prompt. Default to the image centre; replace it with a pixel on the object
# you want to segment.
img_width, img_height = raw_image.size
point_x, point_y = img_width // 2, img_height // 2
if not (0 <= point_x < img_width and 0 <= point_y < img_height):
    raise ValueError(
        f"Point ({point_x}, {point_y}) lies outside the "
        f"{img_width}x{img_height} image."
    )
input_points = [[[point_x, point_y]]]
inputs = processor(
    raw_image,
    input_points=input_points,
    return_tensors="pt"
)
import torch

# Inference only: skip gradient tracking.
with torch.no_grad():
    outputs = model(**inputs)
# Resize the model's low-resolution mask logits back to the original image
# resolution (using the sizes the processor recorded); the result is a list
# with one entry per input image.
predicted_masks = processor.image_processor.post_process_masks(
    outputs.pred_masks,
    inputs["original_sizes"],
    inputs["reshaped_input_sizes"]
)

`predicted_masks`의 길이는 입력으로 사용한 이미지의 수와 같습니다.

# One entry per input image (a single image here).
len(predicted_masks)
1

Inspecting the Predicted Masks

# Masks for the first (and only) image. Its displayed shape is
# (1, 3, 565, 1004): presumably 1 prompt point x 3 candidate masks at the
# original image resolution — confirm against the SamModel docs.
predicted_mask = predicted_masks[0]
predicted_mask.shape
torch.Size([1, 3, 565, 1004])
# The model's predicted IoU (quality) score for each of the 3 candidate masks.
outputs.iou_scores
tensor([[[0.5211, 0.5908, 0.4307]]])
def _figure_to_pil(fig):
    """Rasterize a Matplotlib figure and return it as an RGBA PIL image."""
    from PIL import Image

    fig.canvas.draw()
    # Copy the RGBA buffer so the returned image does not alias canvas memory.
    rgba = np.array(fig.canvas.buffer_rgba())
    return Image.fromarray(rgba)


def show_mask_on_image(raw_image, mask, return_image=False):
    """Display a single mask overlaid on ``raw_image`` in a 15x15-inch figure.

    Args:
        raw_image: Image drawn underneath the mask (anything ``np.array`` can
            convert, e.g. a PIL image).
        mask: Mask to overlay, as a ``torch.Tensor`` or anything
            ``torch.Tensor`` accepts. A 4-D mask is squeezed; the rest is
            passed to ``show_mask``, which reshapes its last two dims to
            ``(h, w, 1)``.
        return_image: If True, also return the rendered figure as a PIL image.

    Returns:
        A PIL ``Image`` of the rendered figure when ``return_image`` is True,
        otherwise ``None``.
    """
    if not isinstance(mask, torch.Tensor):
        mask = torch.Tensor(mask)

    if mask.dim() == 4:
        mask = mask.squeeze()

    fig, ax = plt.subplots(1, 1, figsize=(15, 15))

    mask = mask.cpu().detach()
    ax.imshow(np.array(raw_image))
    show_mask(mask, ax)
    ax.axis("off")

    image = None
    if return_image:
        # Rasterize this figure *before* plt.show(): once show() returns, the
        # backend may have closed it, and plt.gcf() would then hand back a
        # new, empty figure.
        image = _figure_to_pil(fig)

    plt.show()
    return image
# Show every candidate mask for the prompt point in its own figure. Iterate
# over the actual number of masks rather than assuming there are exactly 3.
for mask_idx in range(predicted_mask.shape[1]):
    show_mask_on_image(raw_image, predicted_mask[:, mask_idx])

Depth Estimation with DPT

DPT Pipeline

# Depth-estimation pipeline backed by the "Intel/dpt-hybrid-midas" checkpoint.
depth_estimator = pipeline(
    task="depth-estimation",
    model="Intel/dpt-hybrid-midas")
# NOTE(review): dataset/photo.jpg is not downloaded in this section of the
# notebook — make sure it exists before running this cell.
raw_image = Image.open('dataset/photo.jpg')
# `resize` returns a new image (preview only); raw_image keeps its size.
raw_image.resize((806, 621))
output = depth_estimator(raw_image)
# The result holds the raw depth tensor ("predicted_depth") and a
# ready-made grayscale PIL image ("depth").
output
{'predicted_depth': tensor([[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [2685.2251, 2689.2563, 2688.1436,  ..., 2672.5557, 2666.8252,
           2666.0598],
          [2704.1567, 2702.0527, 2704.7058,  ..., 2678.6560, 2684.4863,
           2670.7288],
          [2695.6284, 2710.9890, 2708.0481,  ..., 2696.0364, 2693.9844,
           2684.9038]]]),
 'depth': <PIL.Image.Image image mode=L size=560x369>}
# Depth map at the model's working resolution (1, 384, 384), not the image size.
output["predicted_depth"].shape
torch.Size([1, 384, 384])
# Insert a channel axis to get the (N, C, H, W) layout interpolate() expects.
output["predicted_depth"].unsqueeze(1).shape
torch.Size([1, 1, 384, 384])

Prediction

# Upsample the depth map to the input image's resolution. PIL's `size` is
# (width, height), so it is reversed to the (height, width) order that
# `interpolate(size=...)` expects.
prediction = torch.nn.functional.interpolate(
    output["predicted_depth"].unsqueeze(1),
    size=raw_image.size[::-1],
    mode="bicubic",
    align_corners=False,
)
prediction.shape
torch.Size([1, 1, 369, 560])
# PIL's `size` is (width, height); reversed, it is the (height, width) target
# used for interpolation. No trailing comma, which would wrap it in a 1-tuple.
raw_image.size[::-1]
((369, 560),)
# Inspect the upsampled depth values.
prediction
tensor([[[[   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
              0.0000],
          ...,
          [2682.9749, 2685.8513, 2687.9121,  ..., 2666.2478, 2664.8928,
           2665.7424],
          [2703.8203, 2702.0054, 2701.3359,  ..., 2683.1323, 2677.5796,
           2668.8230],
          [2694.5044, 2704.0776, 2711.4558,  ..., 2694.5051, 2689.3669,
           2683.9182]]]])

Depth Format

# Convert the upsampled depth map to an 8-bit grayscale PIL image.
# Bicubic interpolation can overshoot below 0 next to sharp edges (such as the
# all-zero rows at the top of `prediction`). Negative floats are not
# representable in uint8 and convert unpredictably (typically wrapping to
# bright pixels), so clip them first.
output = np.clip(prediction.squeeze().cpu().numpy(), 0, None)
max_depth = np.max(output)
# Guard against an all-zero map, which would otherwise divide by zero (NaNs).
scale = 255 / max_depth if max_depth > 0 else 0.0
formatted = (output * scale).astype("uint8")
depth = Image.fromarray(formatted)
depth

Last updated