CV&DL

PyTorch를 이용한 3채널 이미지 평규과 표준 편차 구하기

main의 files 배열에 평균과 표준 편차를 구할 이미지들의 경로를 전달하기만 하면 평균과 표준 편차를 torch.tensor() 형태로 반환해 준다.

from typing import List
import torch
from PIL import Image
from torchvision.transforms import transforms
from glob import glob


class NormStdDataSet(object):
    def __init__(self, dataset_list: List[str]):
        super().__init__()
        self.dataset_list = dataset_list
        self.length = len(self.dataset_list)
        self.transform = transforms.Compose([
            transforms.Resize(32),
            transforms.ToTensor(),
        ])

    def __getitem__(self, index):
        img = Image.open(self.dataset_list[index], 'r')
        img = self.transform(img)
        return img

    def __len__(self):
        return self.length


def getNormStd(target_set: List[str]):
    norm_std_dataset = NormStdDataSet(dataset_list=target_set)

    # 평균 및 표준 편차 초기화
    mean = torch.zeros(3)
    std = torch.zeros(3)
    images = None

    # 평균 계산
    for images in norm_std_dataset:
        for i in range(3):
            mean[i] += images[i, :, :].mean()

    mean /= norm_std_dataset.length

    # 표준 편차 계산
    for images in norm_std_dataset:
        for i in range(3):
            std[i] += ((images[i, :, :] - mean[i]) ** 2).sum()

    std = torch.sqrt(std / (norm_std_dataset.length * images.size(1) * images.size(2)))

    return mean, std


if __name__ == '__main__':
    # Norm, Std를 구할 이미지의 경로가 담긴 배열
    files = glob('파일경로/*.png')

    print(getNormStd(files))