【深層距離学習】画像分類が難しい場合の学習法

前書き

画像分類をする上でどうしても特徴が2クラス間で似通っており、
分類させるのが難しいといったケースが存在します。
そういった場合にどういった学習が有効なのか調査してみました。

深層距離学習

深層距離学習の考え方についてはこちらの記事を参考にしていただけると良いでしょう。

cpp-learning.com

設定した課題から抽出したデータ数値に基準を設け、
その基準から離れているか否かを「距離」で示します。

「距離」は実際には二次元空間(Embedding Space)で表現されることが多く、
「距離」は各データの特徴量を座標情報として示した点と点の距離に近い考え方です。

距離の閾値の定義としてはユークリッド距離やマハラノビス距離など様々で、
距離学習を損失関数に埋め込む際もCenterLoss・ArcFaceなど様々です。

距離学習は教師なし学習であるクラスタリングに非常に近いです。
クラスタリングはデータの類似度に従って、データのグループ分けを行う考え方です。

ledge.ai

クラスタリングでは、似たグループのデータ同士は分散が小さくなるように、
異なるグループのデータ同士は分散が大きくなるように学習が進みます。
次はこの距離学習の考え方を畳み込みニューラルネットワークに導入してみます。

実装編(CenterLoss)


今回はCenterLossを適用しました。
CenterLossはCrossEntropyLossなどの損失関数と併用する形で実装します。
以下はCenterLoss論文の中身を説明した様子が記されています。

www.slideshare.net

CenterLossは以下のコードで動作可能です。

class CenterLoss(nn.Module):

    def __init__(self, num_classes=10, feat_dim=2, use_gpu=True):
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        self.use_gpu = use_gpu

        if self.use_gpu:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim).cuda())
        else:
            self.centers = nn.Parameter(torch.randn(self.num_classes, self.feat_dim))

    def forward(self, x, labels):
        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t()
        distmat.addmm_(1, -2, x, self.centers.t())

        classes = torch.arange(self.num_classes).long()
        if self.use_gpu: classes = classes.cuda()
        labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
        mask = labels.eq(classes.expand(batch_size, self.num_classes))

        dist = distmat * mask.float()
        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

上記のコードは以下のリポジトリを参考にしました。
github.com

ネットワークモデル

データはMNISTの中でも比較的形が似ている「1 」と「7」の分類を行います。
ネットワークモデルはMNISTに適応したモデルを構築します。

class ConvNet(nn.Module):
    
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1_1 = nn.Conv2d(1, 32, 5, stride=1, padding=2)
        self.prelu1_1 = nn.PReLU()
        self.conv1_2 = nn.Conv2d(32, 32, 5, stride=1, padding=2)
        self.prelu1_2 = nn.PReLU()
        
        self.conv2_1 = nn.Conv2d(32, 64, 5, stride=1, padding=2)
        self.prelu2_1 = nn.PReLU()
        self.conv2_2 = nn.Conv2d(64, 64, 5, stride=1, padding=2)
        self.prelu2_2 = nn.PReLU()
        
        self.conv3_1 = nn.Conv2d(64, 128, 5, stride=1, padding=2)
        self.prelu3_1 = nn.PReLU()
        self.conv3_2 = nn.Conv2d(128, 128, 5, stride=1, padding=2)
        self.prelu3_2 = nn.PReLU()
        
        self.fc1 = nn.Linear(128*3*3, 2)
        self.prelu_fc1 = nn.PReLU()
        self.fc2 = nn.Linear(2, num_classes)
        
    def forward(self, x):
        x = self.prelu1_1(self.conv1_1(x))
        x = self.prelu1_2(self.conv1_2(x))
        x = f.max_pool2d(x, 2)
        
        x = self.prelu2_1(self.conv2_1(x))
        x = self.prelu2_2(self.conv2_2(x))
        x = f.max_pool2d(x, 2)
        
        x = self.prelu3_1(self.conv3_1(x))
        x = self.prelu3_2(self.conv3_2(x))
        x = f.max_pool2d(x, 2)
        
        x = x.view(-1, 128*3*3)
        x = self.prelu_fc1(self.fc1(x))
        y = self.fc2(x)

        return x, y

net = ConvNet().to(device)

順伝播関数において返り値はxとyの2つを返しています。
xにて返している値はCNNから抽出した特徴量を
二次元座標情報として返し、次のCenterLossに入力します。
yは2クラス分類の結果を出力します。

CenterLossとCrossEntropyLossを合計して、損失関数に適用していきます。

criterion_xent = nn.CrossEntropyLoss()
criterion_cent = CenterLoss(num_classes=num_classes, feat_dim=2, use_gpu=True)
optimizer_model = torch.optim.SGD(net.parameters(), lr=0.001, weight_decay=5e-04, momentum=0.9)
optimizer_centloss = torch.optim.SGD(criterion_cent.parameters(), lr=0.5)
lr_sceduler = torch.optim.lr_scheduler.StepLR(optimizer_model, step_size=10, gamma=0.1)

学習の様子

以下は訓練を示したコードの一部を抜粋したものです。
center_loss_alphaはCenterLossの影響度です。
Loss = loss_xent + loss_cent × center_loss_alpha となり、
LossはCenterLossとCrossEntropyLossの合計値です。
今回は試しにcenter_loss_alpha = 1.00にしてみます。

images, labels = images.to(device), labels.to(device)
optimizer_model.zero_grad()

# 順伝搬の計算
features, outputs = net(images)

# centerlossの計算
loss_xent = criterion_xent(outputs, labels)
loss_cent = criterion_cent(features, labels)
loss_cent *= center_loss_alpha
loss = loss_xent+loss_cent
optimizer_model.zero_grad()
optimizer_centloss.zero_grad()
        
train_loss += loss.item()

結果

f:id:electric-city:20201207111714p:plain:h200:w400
CenterLossなし
f:id:electric-city:20201207111718p:plain:h200:w400
CenterLossあり

結果の通りCenterLossを適応した方が、分布のグループ化がはっきりと離れています。

参考文献

cpp-learning.com