【U-Net】PytorchでSemantic Segmentationの推論+切り抜きを実装
推論の準備
学習後、作成した重みファイル(.pth)をロードします。
# 学習モデルを引っ張ってくる(前回の記事参照) model = ResNetUNet(num_class).to(device) # 学習済みパラメータをロード state_dict = torch.load("./model/unet_10.pth", map_location={'cuda:0': 'cpu'}) model.load_state_dict(state_dict) print('ネットワーク設定完了:学習済みの重みをロードしました')
入力する画像をPillowで開きます。
# 1. 元画像の表示 img_original = Image.open('input.jpg') # [高さ][幅][色RGB] img_width, img_height = img_original.size
ネットワークモデルを推論モードに変更後、
入力データの画像前処理+Tensorに変換します。
ここではBaseTransformクラスからインスタンスを生成しています。
torch.unsqueeze(0)で元のテンソルを変えず、次元を増やすことで、
ミニバッチ化を実現しているようです。
torch.unsqueezeはonehot化にも有効なようですので、覚えておく。
# ネットワークモデルを推論モードに変更 model.eval() # (RGB)の色の平均値と標準偏差 resize = 224 color_mean = (0.485, 0.456, 0.406) color_std = (0.229, 0.224, 0.225) # TestTransformに変換 base_transform = BaseTransform(resize, color_mean, color_std) img_transformed = base_transform(img_original) # torch.Size([3, 224, 224]) x_img = img_transformed.unsqueeze(0) # ミニバッチ化:torch.Size([1, 3, 475, 475]) print(x.shape)
ではBaseTransformクラスを見てみましょう。
コンストラクタではResize・中央を中心にした切り取り・テンソル変換・色情報の標準化
を行い、base_transformとインスタンス名を関数のように呼び出すことで、__call__が発動します。
class BaseTransform(): def __init__(self, resize, mean, std): self.base_transform = transforms.Compose([ transforms.Resize(resize), # 短い辺の長さがresizeの大きさになる transforms.CenterCrop(resize), # 画像中央をresize × resizeで切り取り transforms.ToTensor(), # Torchテンソルに変換 transforms.Normalize(mean, std) # 色情報の標準化 ]) def __call__(self, img): return self.base_transform(img)
推論
U-Netを通して得た出力は確率として出力されているので、
最も可能性の高いクラスが求められます(今回ではPixelごとの二値分類なので0 or 1で構成されたTensor)。
ここでなぜか8bit符号化(uint8変換)しないとエラーを吐くのですが、
Pillowで画像を入力時のサイズにResizeしてやります。
outputs = model(x_img.to(device)).cpu() y = outputs # UNetの出力から最大クラスを求め、カラーパレット形式にし、画像サイズを元に戻す y = y[0].detach().numpy() # y:torch.Size([1, 2, 475, 475]) y = np.argmax(y, axis=0) result = Image.fromarray(np.uint8(y)) result = result.resize((img_width, img_height), Image.NEAREST)
検出箇所の切り出し
検出箇所の切り出しはPillowやOpenCVでも可能なようですが、
自分はすごく古典的な方法で行いました。
検出した結果の画素値が0(その画素に対象オブジェクトが存在しない)なら、
元画像の画素値は黒色化させました。
対象オブジェクト以外が黒色になります。
result = result.convert('RGB') for x in range(img_width): for y in range(img_height): pixel = result.getpixel((x,y)) if pixel[0] == 0 and pixel[1] == 0 and pixel[2] == 0: #黒色にする img_original .putpixel((x,y),(0,0,0)) else: continue plt.imshow(img_original) plt.show()