תחי ישראל - אין לנו ארץ אחרת

תחי ישראל -אין לנו ארץ אחרת

זיהוי הודעות SMS ספאמיות בעברית בעזרת מודל טרנספורמר heBERT

מחבר:
בתאריך:

הפצצה של הטלפונים הניידים בהודעות SMS ספאמיות היא מטרד שכיח. במדריך אני משתמש בלמידת מכונה לסיווג מהיר של הודעות. ישנם מגוון של מודלים של למידת מכונה שיכולים להתמודד עם שפה אנושית NLP אבל מ-2017 המודלים הטובים ביותר הם מסוג טרנספורמר transformer. השנה נוסף מודל טרנספורמר heBERT שאומן על טקסטים בעברית ובמקור משמש לאנליזת סנטימנט. במדריך זה אדגים כיצד לאמן את המודל למשימה של זיהוי הודעות ספאם קצרות.

detect spam in hebrew sms messages with hebrew transformer and machine learning

נתקין את חבילת טרנספורמר ע"פ ההוראות בדף החבילה אצל hugging face:

!pip install transformers -q
!pip install tqdm -q
בנוסף, חבילת tqdm בשביל progress bar שיאפשר לנו לעקוב אחר התקדמות תהליכים.
  • בנוסף, חבילת tqdm בשביל progress bar שיאפשר לנו לעקוב אחר התקדמות תהליכים.

נייבא את החבילות:

import numpy as np
import pandas as pd
 
import torch
 
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel, BertTokenizerFast, BertForSequenceClassification
 
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
 
from tqdm.notebook import tqdm

נעבוד עם ספריית pyTorch. אפשר להסתפק ב-CPU:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

הטרנספורמר המאומן בו נשתמש:

MODEL_NAME = "avichr/heBERT"

ניתן לקרוא אודות המודל בדף הפרויקט avichr/heBERT.

בשביל אימון המודל והערכת היכולות השתמשתי ב-141 הודעות SMS שקיבלתי לנייד האישי (80 מתוכם סיווגתי כספאם). את ההודעות סדרתי בשתי רשימות כעין אילה:

spam_messages = [
"תבחרו בי לראשות הממשלה כי אין כמוני. אני בטוח אציל אתכם",
"תענו לסקר הקצר שלנו ותקבלו 2 מסטיק בזוקה במחיר מבצע"
]
Ham_messages = [
"עדכון מהבנק: חשבונך זוכה ב-XXXX שח"
"קוד האימות הוא אישי ואין לשתף אותו עם שום גורם"
]

את הרשימות ריכזתי לתוך data frame כאשר קטגורית ספאם סווגה 1, וההודעות התקינות מסווגות 0:

df = pd.DataFrame(columns=["Message","Category"])
 
for i,msg in enumerate(spam_messages):
 df = df.append({"Message": msg, "Category": 1}, ignore_index=True)
 
for i,msg in enumerate(ham_messages):
 df = df.append({"Message": msg, "Category": 0}, ignore_index=True)

את ההודעות חילקתי ל-3 קבוצות אקראיות על פי החלוקה המקובלת - train, val, test - באמצעות:

X = df["Message"]
y = df["Category"]
 
train_texts, val_texts, train_labels, val_labels = train_test_split(list(X), list(y),
  test_size=0.3,
  random_state=42)
val_texts, test_texts, val_labels, test_labels = train_test_split(val_texts, val_labels,
  test_size=0.33,
  random_state=42)
  • בסוף התהליך, היו לי 98 הודעות ששימשו לאימון המודל 28 להערכה בזמן הריצה של המודל, ו-15 למבחן סופי.

נאתחל את המודל:

model = BertForSequenceClassification.from_pretrained(MODEL_NAME)

נשתמש בטוקנייזר כדי לקודד את ההודעות העבריות למערכים מספריים איתם המחשב יודע לעבוד:

tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME, max_length=512)
train_encodings = tokenizer(list(train_texts), truncation=True, padding=True)
val_encodings = tokenizer(list(val_texts), truncation=True, padding=True)
test_encodings = tokenizer(list(test_texts), truncation=True, padding=True)

קלאס מסד נתונים דרוש להעמסת הנתונים על המודל:

class HeSmsDataset(torch.utils.data.Dataset):
   def __init__(self, encodings, labels): #
       #print(encodings)
       self.encodings = encodings
       self.labels = labels
       # data loading
       # self.data = torch.unsqueeze(torch.from_numpy(X_train.values), 1)
       # self.labels = torch.unsqueeze(torch.from_numpy(y_train.values), 1)
 
   def __getitem__(self, idx):
       # for key, val in self.encodings.items():
       #     print("keys: ", key," vals: ", val)
       item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
       item['labels'] = torch.tensor(self.labels[idx])
       return item
       # dataset[index] to get the index-th item
       # return self.data[idx], self.labels[idx]
 
   def __len__(self):
       # size of dataset
       return len(self.labels)
 
train_dataset = HeSmsDataset(train_encodings, train_labels)
val_dataset = HeSmsDataset(val_encodings, val_labels)
test_dataset = HeSmsDataset(test_encodings, test_labels)

נאמן את המודל משך 5 epochs:

from transformers import Trainer, TrainingArguments
 
model.to(device)
model.train()
 
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
 
optim = torch.optim.AdamW(model.parameters(), lr=5e-5)
 
for epoch in range(5):
   for batch in train_loader:
       optim.zero_grad()
       input_ids = batch['input_ids'].to(device)
       attention_mask = batch['attention_mask'].to(device)
       labels = batch['labels'].to(device)
       outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
       loss = outputs[0]
       loss.backward()
       optim.step()
 
model.eval()

נעריך את התוצאות על סט המבחן שלא השתתף באימון:

test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)
def validation(dataloader, device_):
 global model
 predictions_labels = []
 true_labels = []
 #total loss for this epoch.
 total_loss = 0
 model.eval()
 # Evaluate data for one epoch
 for batch in tqdm(dataloader, total=len(dataloader)):
   # add original labels
   true_labels += batch['labels'].numpy().flatten().tolist()
   # move batch to device
   batch = {k:v.type(torch.long).to(device_) for k,v in batch.items()}
   with torch.no_grad():       
     outputs = model(**batch)
     loss, logits = outputs[:2]
     logits = logits.detach().cpu().numpy()
     total_loss += loss.item()
     predict_content = logits.argmax(axis=-1).flatten().tolist()
     predictions_labels += predict_content
 
   avg_epoch_loss = total_loss / len(dataloader)
 
   return true_labels, predictions_labels, avg_epoch_loss
# Get predictions
y_actual, y_pred, avg_epoch_loss = validation(test_loader, device)

מה מידת הדיוק?

acc = accuracy_score(y_actual, y_pred)
f'The accuracy is %.2f' % (acc)
The accuracy is 0.93
  • 93% דיוק

באיזו קטגוריה השגיאות?

print(y_actual)
print(y_pred)
[1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0]
[1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0]
print(confusion_matrix(y_actual, y_pred))
[[5 0]
 [1 8]]
  • 1 מתוך 9 הודעות ספאם בסט הבדיקה סווגה בטעות כלגיטימית.

באיזו הודעה שגה המודל?

z = list(zip(y_actual,y_pred))
for idx,s in enumerate(z):
   if s[0]!=s[1]:
       print(s[0],s[1],idx,list(val_texts)[idx])
1 0 10 מי לדעתך מתאים יותר להיות ראש הממשלה (נא להגיב במספר בלבד)? 1. לפיד 2. נתניהו 3. אחר 4. גנץ איננו שולחים סקר זה למי שכבר ענה. הליכוד

את המודל המאומן שמרתי ב-drive שלי לשימוש בעתיד:

from google.colab import drive
drive.mount('/content/gdrive')
import os
BASE_DIR = '/content/gdrive/MyDrive/projects/he_sms/'

PATH_TO_TRAINED_MODEL = os.path.join(BASE_DIR, 'trained_models', '061022')
 
# save
model.save_pretrained(PATH_TO_TRAINED_MODEL)

לטעינת המודל:

# load
model = BertForSequenceClassification.from_pretrained(PATH_TO_TRAINED_MODEL)

 

מדריכים נוספים שעשויים לעניין אותך

הטרנספורמרים משנים את עולם הבינה המלאכותית

heBERT - מודל Transformer בעברית

איתור SMS ספאמי באמצעות טכנולוגית Transformer

 

לכל המדריכים בנושא של למידת מכונה

 

אהבתם? לא אהבתם? דרגו!

0 הצבעות, ממוצע 0 מתוך 5 כוכבים

 

 

המדריכים באתר עוסקים בנושאי תכנות ופיתוח אישי. הקוד שמוצג משמש להדגמה ולצרכי לימוד. התוכן והקוד המוצגים באתר נבדקו בקפידה ונמצאו תקינים. אבל ייתכן ששימוש במערכות שונות, דוגמת דפדפן או מערכת הפעלה שונה ולאור השינויים הטכנולוגיים התכופים בעולם שבו אנו חיים יגרום לתוצאות שונות מהמצופה. בכל מקרה, אין בעל האתר נושא באחריות לכל שיבוש או שימוש לא אחראי בתכנים הלימודיים באתר.

למרות האמור לעיל, ומתוך רצון טוב, אם נתקלת בקשיים ביישום הקוד באתר מפאת מה שנראה לך כשגיאה או כחוסר עקביות נא להשאיר תגובה עם פירוט הבעיה באזור התגובות בתחתית המדריכים. זה יכול לעזור למשתמשים אחרים שנתקלו באותה בעיה ואם אני רואה שהבעיה עקרונית אני עשוי לערוך התאמה במדריך או להסיר אותו כדי להימנע מהטעיית הציבור.

שימו לב! הסקריפטים במדריכים מיועדים למטרות לימוד בלבד. כשאתם עובדים על הפרויקטים שלכם אתם צריכים להשתמש בספריות וסביבות פיתוח מוכחות, מהירות ובטוחות.

המשתמש באתר צריך להיות מודע לכך שאם וכאשר הוא מפתח קוד בשביל פרויקט הוא חייב לשים לב ולהשתמש בסביבת הפיתוח המתאימה ביותר, הבטוחה ביותר, היעילה ביותר וכמובן שהוא צריך לבדוק את הקוד בהיבטים של יעילות ואבטחה. מי אמר שלהיות מפתח זו עבודה קלה ?

השימוש שלך באתר מהווה ראייה להסכמתך עם הכללים והתקנות שנוסחו בהסכם תנאי השימוש.

הוסף תגובה חדשה

 

 

ענה על השאלה הפשוטה הבאה כתנאי להוספת תגובה:

מהם שלוש רשויות השלטון בישראל?