■ はじめに
https://dk521123.hatenablog.com/entry/2020/07/05/000000
の続き。 今回は、畳み込みニューラルネットワーク(CNN)について、扱う。
■ サンプル
例1:SRCNN(Super-Resolution Convolutional Neural Network) => 超解像 畳み込みニューラルネットワーク
超解像(super resolution)とは
* 入力画像の解像度を高めて出力する技術
例1:SRCNN
import torch import numpy as np import matplotlib.pyplot as plt from torch import nn, optim from PIL import Image from sklearn.datasets import fetch_lfw_people from torch.utils.data import TensorDataset, DataLoader # 有名人の画像データ(400サンプル)を取得 # ↓に時間がかかる lfw_people = fetch_lfw_people(resize=1, color=True) print("Done download") lfw_people = lfw_people.images.astype(np.uint8)[:400] inputs = [] for face_image in lfw_people: image = Image.fromarray(face_image) small_image = image.resize(( image.size[0] // 3, image.size[1] // 3)) low_image = small_image.resize(( image.size[0], image.size[1])) inputs.append(np.asarray(low_image)) inputs = np.asarray(inputs) / 255 outputs = lfw_people / 255 # テストデータ(10サンプル)をTensor形式で抽出 x = torch.tensor(inputs[10:], dtype=torch.float32) \ .transpose(1, 2) \ .transpose(1, 3) y = torch.tensor(outputs[10:], dtype=torch.float32) \ .transpose(1, 2) \ .transpose(1, 3) dataset = TensorDataset(x, y) loader = DataLoader(dataset, batch_size=32, shuffle=True) # SRCNN(Super-Resolution Convolutional Neural Network) model = nn.Sequential( # 畳み込み層 nn.Conv2d(3, 64, 9, padding=4), # 活性化関数 nn.ReLU(), # 畳み込み層 nn.Conv2d(64, 32, 1), # 活性化関数 nn.ReLU(), # 畳み込み層 nn.Conv2d(32, 3, 5, padding=2) ) # 最適化アルゴリズム adAdam optimizer = optim.Adam(model.parameters()) mse = nn.MSELoss() max_loop = 10 for i in range(max_loop): for batch_x, batch_y in loader: # 勾配リセット optimizer.zero_grad() # 予想値を計算 y_product = model(batch_x) # 平均二乗誤差 loss = mse(y_product, batch_y) # 微分して勾配を計算 loss.backward() # パラメータを更新 optimizer.step() print("[{}/{}] Loss = {}".format( (i+1), max_loop, loss.item())) # テスト用データに適用する x = torch.tensor(inputs[:20], dtype=torch.float32) \ .transpose(1, 2) \ .transpose(1, 3) y = model(x) \ .transpose(1, 3) \ .transpose(1, 2) y = y.detach().numpy().clip(0, 1) for i in range(10): plt.subplot(131).imshow(inputs[i]) plt.subplot(132).imshow(outputs[i]) plt.subplot(133).imshow(y[i]) plt.show() print("Done")
出力結果
Done download [1/10] Loss = 0.024243367835879326 [2/10] Loss = 0.010735847055912018 [3/10] Loss = 0.007179690524935722 [4/10] Loss = 0.006752449087798595 [5/10] Loss = 0.00441069295629859 [6/10] Loss = 0.004010963719338179 [7/10] Loss = 0.0025954623706638813 [8/10] Loss = 0.001879471936263144 [9/10] Loss = 0.0017351649003103375 [10/10] Loss = 0.0018771884497255087 Done
関連記事
PyTorch ~ 深層学習ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/07/05/000000
NumPy ~ 数値計算ライブラリ ~
https://dk521123.hatenablog.com/entry/2018/03/28/224532
Pandas ~ データ解析支援ライブラリ ~
https://dk521123.hatenablog.com/entry/2019/10/22/014957
TensorFlow ~ 入門編 ~
https://dk521123.hatenablog.com/entry/2018/02/16/103500
Keras ~ 深層学習用ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/03/03/235302
Matplotlib ~ グラフ描画ライブラリ ~
https://dk521123.hatenablog.com/entry/2020/03/01/000000
scikit-learn ~ 機械学習用ライブラリ・入門編 ~
https://dk521123.hatenablog.com/entry/2020/03/02/233902
scikit-learn ~ 機械学習用ライブラリ・基本編 ~
https://dk521123.hatenablog.com/entry/2020/03/08/113356
scikit-learn ~ 決定木 / ランダムフォレスト ~
https://dk521123.hatenablog.com/entry/2020/04/04/021413
scikit-learn ~ 線形回帰 ~
https://dk521123.hatenablog.com/entry/2020/07/04/000000
scikit-learn ~ リッジ回帰 ~
https://dk521123.hatenablog.com/entry/2020/04/25/174503