【Python】 PyTorch ~ CNN編 ~

■ はじめに

https://dk521123.hatenablog.com/entry/2020/07/05/000000

の続き。
今回は、畳み込みニューラルネットワーク(CNN)について、扱う。

■ サンプル

例1:SRCNN(Super-Resolution Convolutional Neural Network)
 => 超解像 畳み込みニューラルネットワーク

超解像(super resolution)とは

* 入力画像の解像度を高めて出力する技術

例1:SRCNN

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn, optim
from PIL import Image
from sklearn.datasets import fetch_lfw_people
from torch.utils.data import TensorDataset, DataLoader

# 有名人の画像データ(400サンプル)を取得
# ↓に時間がかかる
lfw_people = fetch_lfw_people(resize=1, color=True)
print("Done download")
lfw_people = lfw_people.images.astype(np.uint8)[:400]

inputs = []
for face_image in lfw_people:
  image = Image.fromarray(face_image)
  small_image = image.resize((
    image.size[0] // 3, image.size[1] // 3))
  low_image = small_image.resize((
    image.size[0], image.size[1]))
  inputs.append(np.asarray(low_image))
inputs = np.asarray(inputs) / 255
outputs = lfw_people / 255

# テストデータ(10サンプル)をTensor形式で抽出
x = torch.tensor(inputs[10:], dtype=torch.float32) \
  .transpose(1, 2) \
  .transpose(1, 3)
y = torch.tensor(outputs[10:], dtype=torch.float32) \
  .transpose(1, 2) \
  .transpose(1, 3)
dataset = TensorDataset(x, y)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

# SRCNN(Super-Resolution Convolutional Neural Network)
model = nn.Sequential(
  # 畳み込み層
  nn.Conv2d(3, 64, 9, padding=4),
  # 活性化関数
  nn.ReLU(),
  # 畳み込み層
  nn.Conv2d(64, 32, 1),
  # 活性化関数
  nn.ReLU(),
  # 畳み込み層
  nn.Conv2d(32, 3, 5, padding=2)
)

# 最適化アルゴリズム adAdam
optimizer = optim.Adam(model.parameters())
mse = nn.MSELoss()

max_loop = 10
for i in range(max_loop):
  for batch_x, batch_y in loader:
    # 勾配リセット
    optimizer.zero_grad()
    # 予想値を計算
    y_product = model(batch_x)
    # 平均二乗誤差
    loss = mse(y_product, batch_y)
    # 微分して勾配を計算
    loss.backward()
    # パラメータを更新
    optimizer.step()
  print("[{}/{}] Loss = {}".format(
    (i+1), max_loop, loss.item()))

# テスト用データに適用する
x = torch.tensor(inputs[:20], dtype=torch.float32) \
  .transpose(1, 2) \
  .transpose(1, 3)
y = model(x) \
  .transpose(1, 3) \
  .transpose(1, 2)
y = y.detach().numpy().clip(0, 1)
for i in range(10):
  plt.subplot(131).imshow(inputs[i])
  plt.subplot(132).imshow(outputs[i])
  plt.subplot(133).imshow(y[i])
  plt.show()

print("Done")

出力結果

Done download
[1/10] Loss = 0.024243367835879326
[2/10] Loss = 0.010735847055912018
[3/10] Loss = 0.007179690524935722
[4/10] Loss = 0.006752449087798595
[5/10] Loss = 0.00441069295629859
[6/10] Loss = 0.004010963719338179
[7/10] Loss = 0.0025954623706638813
[8/10] Loss = 0.001879471936263144
[9/10] Loss = 0.0017351649003103375
[10/10] Loss = 0.0018771884497255087
Done

関連記事

PyTorch ~ 深層学習ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/07/05/000000
NumPy ~ 数値計算ライブラリ ~
https://dk521123.hatenablog.com/entry/2018/03/28/224532
Pandas ~ データ解析支援ライブラリ ~
https://dk521123.hatenablog.com/entry/2019/10/22/014957
TensorFlow ~ 入門編 ~
https://dk521123.hatenablog.com/entry/2018/02/16/103500
Keras ~ 深層学習用ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/03/03/235302
Matplotlib ~ グラフ描画ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/03/01/000000
scikit-learn ~ 機械学習用ライブラリ・入門編 ~
https://dk521123.hatenablog.com/entry/2020/03/02/233902
scikit-learn ~ 機械学習用ライブラリ・基本編 ~
https://dk521123.hatenablog.com/entry/2020/03/08/113356
scikit-learn ~ 決定木 / ランダムフォレスト ~
https://dk521123.hatenablog.com/entry/2020/04/04/021413
scikit-learn ~ 線形回帰 ~
https://dk521123.hatenablog.com/entry/2020/07/04/000000
scikit-learn ~ リッジ回帰 ~
https://dk521123.hatenablog.com/entry/2020/04/25/174503