【U-Net】PytorchでSemantic Segmentationの推論+切り抜きを実装

前書き

melheaven.hatenadiary.jp

前回、上記の記事のように学習を進めました。
今回は推論をやっていきます。

推論の準備

学習後、作成した重みファイル(.pth)をロードします。

# 学習モデルを引っ張ってくる(前回の記事参照)
model = ResNetUNet(num_class).to(device)

# 学習済みパラメータをロード
state_dict = torch.load("./model/unet_10.pth",
                        map_location={'cuda:0': 'cpu'})
model.load_state_dict(state_dict)

print('ネットワーク設定完了:学習済みの重みをロードしました')

入力する画像をPillowで開きます。

# 1. 元画像の表示
img_original = Image.open('input.jpg')   # [高さ][幅][色RGB]
img_width, img_height = img_original.size

ネットワークモデルを推論モードに変更後、
入力データの画像前処理+Tensorに変換します。
ここではBaseTransformクラスからインスタンスを生成しています。

torch.unsqueeze(0)で元のテンソルを変えず、次元を増やすことで、
ミニバッチ化を実現しているようです。
torch.unsqueezeはonehot化にも有効なようですので、覚えておく。

lilaboc.work

# ネットワークモデルを推論モードに変更
model.eval()

# (RGB)の色の平均値と標準偏差
resize = 224
color_mean = (0.485, 0.456, 0.406)
color_std = (0.229, 0.224, 0.225)

# TestTransformに変換
base_transform = BaseTransform(resize, color_mean, color_std)
img_transformed = base_transform(img_original)  # torch.Size([3, 224, 224])
x_img = img_transformed.unsqueeze(0)  # ミニバッチ化:torch.Size([1, 3, 475, 475])
print(x.shape)

ではBaseTransformクラスを見てみましょう。
コンストラクタではResize・中央を中心にした切り取り・テンソル変換・色情報の標準化
を行い、base_transformとインスタンス名を関数のように呼び出すことで、__call__が発動します。

class BaseTransform():

    def __init__(self, resize, mean, std):
        self.base_transform = transforms.Compose([
            transforms.Resize(resize),  # 短い辺の長さがresizeの大きさになる
            transforms.CenterCrop(resize),  # 画像中央をresize × resizeで切り取り
            transforms.ToTensor(),  # Torchテンソルに変換
            transforms.Normalize(mean, std)  # 色情報の標準化
        ])

    def __call__(self, img):
        return self.base_transform(img)

推論

U-Netを通して得た出力は確率として出力されているので、
最も可能性の高いクラスが求められます(今回ではPixelごとの二値分類なので0 or 1で構成されたTensor)。

ここでなぜか8bit符号化(uint8変換)しないとエラーを吐くのですが、
Pillowで画像を入力時のサイズにResizeしてやります。

outputs = model(x_img.to(device)).cpu()
y = outputs  

# UNetの出力から最大クラスを求め、カラーパレット形式にし、画像サイズを元に戻す
y = y[0].detach().numpy()  # y:torch.Size([1, 2, 475, 475])
y = np.argmax(y, axis=0)

result = Image.fromarray(np.uint8(y))
result = result.resize((img_width, img_height), Image.NEAREST)

検出箇所の切り出し

検出箇所の切り出しはPillowやOpenCVでも可能なようですが、
自分はすごく古典的な方法で行いました。

検出した結果の画素値が0(その画素に対象オブジェクトが存在しない)なら、
元画像の画素値は黒色化させました。
対象オブジェクト以外が黒色になります。

result = result.convert('RGB')

for x in range(img_width):
    for y in range(img_height):
        pixel = result.getpixel((x,y))
        
        if pixel[0] == 0 and pixel[1] == 0 and pixel[2] == 0:
            #黒色にする
            img_original .putpixel((x,y),(0,0,0))
        else:
            continue

plt.imshow(img_original)
plt.show()

感想

これで正常に動作すると思います。
まだまだ発展途上なコードですが、今後ブラッシュアップしていきます。

github.com