שימוש באוגמנטציה לפתרון הבעיה של מחסור בתמונות בלמידת מכונה

מחבר:
בתאריך:

אחת הבעיות העיקריות באימון רשתות נוירונויות הוא שהם דורשות המון מידע, אבל מה לעשות שלא תמיד זמינה לנו כמות מספיקה. לדוגמה, אם אנחנו רוצים להציב מצלמות בהרי אפגניסטן במטרה לזהות נמרי שלג נדירים אז תהיה לנו בעיה כי אין כמעט תמונות של מין זה המונה מספר מצומצם של פרטים. וגם לגבי דוגמאות שימוש יותר קונבנציונליות ייתכן שהמידע הזמין לנו הוא מצומצם מכדי לאמן רשת נוירונית.

אחת הדרכים המקובלות להגדלת מאגר המידע העומד לרשותנו היא באמצעות אוגמנטציה augmentation, המגדילה את כמות הדוגמאות באמצעות תוספת עותקים שעברו שינויים קלים של המידע הקיים או בדרך של יצירת מידע סינטטי. במדריך נדגים את הדרך לעשות augmentation של תמונות באמצעות ספריית PyTorch. הספרייה תייצר מגוון תמונות המבוססות על תמונות מקוריות שעברו שינויים דוגמת סיבוב וצביעה. כך נדגים כיצד ניתן באמצעות פקודה פשוטה להגדיל את כמות התמונות היכולות לשמש לאימון הרשת הנוירונית בסדרי גודל.

 

יבוא הספריות שישמשו במדריך ומערכת הקבצים

import os
import random

import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

import torch
import torchvision
import torchvision.transforms
from torchvision.utils import save_image
  • המודולים os ו-random ישמשו לפעולות על מערכת הקבצים ולבחירת מספרים אקראיים.
  • Numpy משמשת לביצוע חישובים מהירים על מערכים וטנסורים.
  • Matplotlib ו-PIL ישמשו לעבודה עם תמונות.
  • PyTorch היא ספרייה של למידת מכונה.
  • בשביל הטרנספורמציות של התמונות נשתמש במודול torchvision.transforms

 

במדריך נשתמש בשתי תיקיות:

DIR_SRC = './data/src/'
DIR_AUGMENTED = './data/augmented/'
  • DIR_SRC - בה יהיו התמונות המקוריות.
  • DIR_AUGMENTED - בה יהיו התמונות שיווצרו בתהליך האוגמנטציה.

 

התמונות המקוריות

את התמונות המקוריות הורדתי מהוויקיפדיה והם שייכות לשלושת סנדקי הבינה המלאכותית AI GodFathers: יהושע בנג'ו, ג'פרי הינטון Geoffrey Hinton , ויאן לאקון Yann LeCun . הרביעי הוא אנדרו נג Andrew Ng שאי אפשר בלעדיו.

נציג את התמונות בתוך טבלה באמצעות הפונקציה הבאה:

def plot_imgs(imgs, nrows, ncols, at_random=True):  
    figsize = (12,12)
    number_of_images = len(imgs)
    
    figure = plt.figure(figsize=figsize)

    for i in range(nrows*ncols):
      ax = figure.add_subplot(nrows, ncols, i+1)

      if at_random:
        idx = random.randint(0, (number_of_images-1))
      else:
        idx = i

      # show the image with the index
      img = Image.open(imgs[idx])
      print(img)
      ax.imshow(img)

    plt.tight_layout()

הפונקציה מקבלת את הפרמטרים:

  • imgs - רשימת התמונות להצגה.
  • nrows - מספר השורות בטבלה.
  • ncols - מספר העמודות בטבלה.
  • at_random - בוליאני. האם להציג את התמונות לפי הסדר או באקראי.
img_paths = []

for filename in os.listdir("./data/src/godfathers_of_ai/"):
  f = os.path.join("./data/src/godfathers_of_ai/", filename)
  # checking if it is a file
  if os.path.isfile(f):
      img_paths.append(f)

plot_imgs(img_paths, 1, 4, False)

התמונות המקוריות שעליהם נעשה אוגמנטציה במדריך

 

יצירת dataset ואוגמנטציה

בזמן שאנחנו מייצרים את מסד הנתונים שישמש ללמידת המכונה אנחנו מעבירים פרמטר transform שבתוכו נגדיר רשימה של טרנספורמציות להפעלה על התמונות. במקרה זה אנחנו יוצרים את מסד הנתונים בהתבסס על תיקייה קיימת המכילה תמונות:

# to create a dataset based on a PyTorch class
# load the train and dataset with "torchvision.datasets"
# module which provides utility classes for building your datasets
# here we use the "ImageFolder"
# to load the images from folders
IMG_SIZE = 120

transformations = [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomRotation(degrees=40),
        transforms.RandomVerticalFlip(p=0.05),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.3),
        transforms.RandomInvert(p=0.1),
        transforms.ToTensor(),            
]

train_dataset = torchvision.datasets.ImageFolder(
    root = DIR_SRC,
    # data transformation pipeline
    # including: resizing and converting from image to tensor
    transform = transforms.Compose(transformations)
)

התוצאה:

Dataset ImageFolder
    Number of datapoints: 4
    Root location: ./data/src/
    StandardTransform
Transform: Compose(
               Resize(size=(120, 120), interpolation=bilinear, max_size=None, antialias=None)
               RandomRotation(degrees=[-40.0, 40.0], interpolation=nearest, expand=False, fill=0)
               RandomVerticalFlip(p=0.05)
               RandomHorizontalFlip(p=0.5)
               RandomGrayscale(p=0.3)
               RandomInvert(p=0.1)
               ToTensor()
           )

את הטרנספורמציות עשינו באמצעות המודול torchvision.transforms שיודע לשנות את המראה של התמונות. קיימות מגוון גדול של טרנספורמציות. במקרה זה השתמשנו:

  • transforms.Resize((IMG_SIZE, IMG_SIZE)) - משנה את ממדי התמונות.
  • transforms.RandomRotation(degrees) - מטה חלק אקראי של התמונות במספר מעלות שמעבירים לו כפרמטר.
  • transforms.RandomVerticalFlip(p) - הופך חלק אקראי של התמונות סביב הציר האנכי. החלק מוגדר בפרמטר p.
  • transforms.RandomHorizontalFlip(p) - הופך חלק אקראי p של התמונות סביב הציר האופקי
  • transforms.RandomGrayscale(p) - צובע חלק אקראי p של התמונות בצבע אפור.
  • transforms.RandomInvert(p) - היפוך גוון של חלק אקראי p של התמונות.
  • transforms.ToTensor() - טרנספורמציה שהופכת את התמונה לטנסורים חייבת לבוא בסוף. ההפיכה לטנסורים מאפשרת ל- PyTorch לעבוד עם התמונות.

 

התמונות יווצרו בפועל רק בתנאי שנחלץ אותם מתוך האיטרטור train_dataset.

בתוך הלולאה הבאה נייצר 5 תמונות שעברו טרנספורמציה מכל תמונה מקורית:

counter = 0
for _ in range(5):
    for img, label in train_dataset:
        save_image(img, DIR_AUGMENTED+str(counter)+'.png')
        counter += 1
  • בתוך התיקייה DIR_AUGMENTED נוכל למצוא עכשיו 20 תמונות שנוצרו מ-4 המקוריות.

נאסוף את הנתיבים של התמונות שנוצרו בתהליך לתוך רשימה ונציג אותם בטבלה באמצעות הפונקציה:

img_paths = []

for filename in os.listdir(DIR_AUGMENTED):
  f = os.path.join(DIR_AUGMENTED, filename)
  # checking if it is a file
  if os.path.isfile(f):
      img_paths.append(f)


plot_imgs(img_paths, 3, 3)

התמונות המקוריות שעליהם נעשה אוגמנטציה במדריך

 

מדריכים נוספים שעשויים לעניין אותך:

הגדלת כמות התמונות ללמידת מכונה באמצעות augmentation וספריית Keras

10 דברים שחובה להכיר כשעובדים עם טנסורים של pytorch

רגרסיה קווית להערכת מחירי דירות באמצעות PyTorch ולמידת מכונה

 

לכל המדריכים בנושא של למידת מכונה

 

אהבתם? לא אהבתם? דרגו!

0 הצבעות, ממוצע 0 מתוך 5 כוכבים

 

 

הוסף תגובה חדשה

 

= 4 + 5