main의 files 배열에 평균과 표준 편차를 구할 이미지들의 경로를 전달하기만 하면 평균과 표준 편차를 torch.tensor() 형태로 반환해 준다.
from typing import List
import torch
from PIL import Image
from torchvision.transforms import transforms
from glob import glob
class NormStdDataSet(object):
def __init__(self, dataset_list: List[str]):
super().__init__()
self.dataset_list = dataset_list
self.length = len(self.dataset_list)
self.transform = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
])
def __getitem__(self, index):
img = Image.open(self.dataset_list[index], 'r')
img = self.transform(img)
return img
def __len__(self):
return self.length
def getNormStd(target_set: List[str]):
norm_std_dataset = NormStdDataSet(dataset_list=target_set)
# 평균 및 표준 편차 초기화
mean = torch.zeros(3)
std = torch.zeros(3)
images = None
# 평균 계산
for images in norm_std_dataset:
for i in range(3):
mean[i] += images[i, :, :].mean()
mean /= norm_std_dataset.length
# 표준 편차 계산
for images in norm_std_dataset:
for i in range(3):
std[i] += ((images[i, :, :] - mean[i]) ** 2).sum()
std = torch.sqrt(std / (norm_std_dataset.length * images.size(1) * images.size(2)))
return mean, std
if __name__ == '__main__':
# Norm, Std를 구할 이미지의 경로가 담긴 배열
files = glob('파일경로/*.png')
print(getNormStd(files))