This is a follow-up to the previous notebook, in which we trained a binary classification model to detect cyberbullying in Polish tweets. The dataset comes from a Polish NLP competition - PolEval 2019 (http://2019.poleval.pl/index.php/tasks/task6). It is also included in the Polish NLP benchmark KLEJ (https://klejbenchmark.com/). Our goal is to reach state-of-the-art results, with the following points of reference:

  • Best result in last year's competition: 58.58 F1 (n-waves ULMFiT)
  • Best result for a base BERT model on KLEJ: 66.7 (Polish Roberta base)
  • Best result for a large BERT model on KLEJ: 72.4 (XLM-RoBERTa large + NKJP)

To achieve that, we will work with the HuggingFace Transformers library and PyTorch.

Setup

Let's start by installing transformers and importing the relevant libraries. From here on we will work mostly with PyTorch.

!pip install transformers -q
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn import model_selection, metrics
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import AdamW, get_linear_schedule_with_warmup, BertTokenizerFast, BertPreTrainedModel, BertModel, BertConfig
from tqdm.autonotebook import tqdm
from torch.utils.data.sampler import WeightedRandomSampler

Data preparation

Let's start with downloading the dataset and converting it into a dataframe.

!wget -q https://klejbenchmark.com/static/data/klej_cbd.zip
!unzip -q klej_cbd.zip
df = pd.read_csv('train.tsv', delimiter='\t')
df.columns = ['text', 'label']
df = df.dropna().reset_index(drop=True)
df.label = df.label.astype(int)

We will now switch from a single train-validation split to 5-fold cross-validation. We will also be more careful with the split, using a stratified k-fold so that each fold has a similar proportion of positive labels. With cross-validation, our goal is to benefit from all of the training data. At the same time, by ensembling the models trained on each fold, we should reduce random errors, making the ensemble's predictions more stable.

df["kfold"] = -1
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
kf = model_selection.StratifiedKFold(n_splits=5)

for fold, (trn_, val_) in enumerate(kf.split(X=df, y=df.label.values)):
    df.loc[val_, 'kfold'] = fold

df.to_csv('train.csv', index=False)
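
As a quick sanity check (a small sketch, assuming the dataframe built above), we can confirm that the stratified split keeps the share of positive labels roughly equal across folds:

print(df.groupby('kfold')['label'].agg(['mean', 'count']))  # positive-label rate and size per fold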

We will also apply some pre-processing to the tweets. First, we will replace '@anonymized_account' with '@ użytkownik'. Second, we will replace emoji characters with their plain-text counterparts. Both modifications are based on the Polish Roberta training scripts (https://github.com/sdadas/polish-roberta). These changes should help the Polish BERT model represent the text better.

emoji = {
    '😀': ':D',
    '😃': ':)',
    '😄': ':)',
    '😁': ':)',
    '😆': 'xD',
    '😅': ':)',
    '🤣': 'xD',
    '😂': 'xD',
    '🙂': ':)',
    '🙃': ':)',
    '😉': ';)',
    '😊': ':)',
    '😇': ':)',
    '🥰': ':*',
    '😍': ':*',
    '🤩': ':*',
    '😘': ':*',
    '😗': ':*',
    '☺': ':)',
    '😚': ':*',
    '😋': ':P',
    '😛': ':P',
    '😜': ':P',
    '😝': ':P',
    '🤑': ':P',
    '🤪': ':P',
    '🤗': ':P',
    '🤭': ':P',
    '🤫': ':|',
    '🤔': ':|',
    '🤨': ':|',
    '😐': ':|',
    '😑': ':|',
    '😶': ':|',
    '😏': ':)',
    '😒': ':(',
    '🙄': ':|',
    '🤐': ':|',
    '😬': ':$',
    '😌': 'zzz',
    '😔': ':(',
    '😪': 'zzz',
    '🤤': ':(',
    '🤒': ':(',
    '🤕': ':(',
    '🤢': ':(',
    '🤮': ':(',
    '🤧': ':(',
    '🥵': ':(',
    '🥶': ':(',
    '🥴': ':(',
    '😵': ':(',
    '🤯': ':(',
    '🤠': ':)',
    '🥳': ':)',
    '😎': ':)',
    '🤓': ':)',
    '🧐': ':)',
    '😕': ':(',
    '😟': ':(',
    '🙁': ':(',
    '☹': ':(',
    '😮': ':O',
    '😯': ':O',
    '😲': ':O',
    '😳': ':(',
    '🥺': ':(',
    '😦': ':(',
    '😧': ':(',
    '😨': ':(',
    '😰': ':(',
    '😥': ':(',
    '😢': ':(',
    '😭': ':(',
    '😱': ':(',
    '😖': ':(',
    '😣': ':(',
    '😞': ':(',
    '😓': ':(',
    '😩': ':(',
    '😫': ':(',
    '🥱': 'zzz',
    '😤': ':(',
    '😡': ':(',
    '😠': ':(',
    '🤬': ':(',
    '😈': ']:->',
    '👿': ']:->',
    '💀': ':(',
    '☠': ':(',
    '💋': ':*',
    '💔': ':(',
    '💤': 'zzz'
}
df['text'] = df['text'].apply(lambda r: r.replace("@anonymized_account", "@ użytkownik"))
df['text'] = df['text'].apply(lambda r: "".join((emoji.get(c, c) for c in r)))

Helper functions

class AverageMeter:
    """
    Computes and stores the average and current value
    """

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
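
We will use it to keep running averages of the loss and F1 score during training, for example:

meter = AverageMeter()
meter.update(0.9, n=64)   # metric value for a first batch of 64 examples
meter.update(0.7, n=64)   # second batch
print(meter.avg)          # 0.8 -- the average is weighted by the number of examples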

Configuration

Let's define some key hyperparameters that influence our training:

  • max length: how many tokens should be used per tweet? Based on the training data, the longest tweet is 91 tokens with the Polbert tokenizer, so we will set the max length to 92 tokens and pad all shorter sequences to that length with the [PAD] token.
  • batch size: we will use a batch size of 64; a larger one might be difficult to fit on some GPUs
  • number of epochs: our dataset is fairly small, so training for many epochs might lead to overfitting. Let's settle on 2 epochs here.
  • learning rate: we will use discriminative learning rates, applying a higher learning rate to the classification layer (which starts with random weights) and a lower learning rate to the encoder (which has been pretrained and so should already have 'good' weights)
  • warm-up: we will use a linear schedule with warm-up, so the learning rate is increased for the number of steps defined here and then linearly decreased to zero
  • pretrained model and tokenizer: we will again work with the Polbert uncased model
MAX_LEN = 92
TRAIN_BATCH_SIZE = 64
VALID_BATCH_SIZE = 64
EPOCHS = 2
LR = 2e-5
HEAD_LR = 1e-4
WARMUP_STEPS = 30
BERT_PATH = 'dkleczek/bert-base-polish-uncased-v1'
TOKENIZER = BertTokenizerFast.from_pretrained(BERT_PATH)
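
The 92-token limit is based on the training data; a quick way to sanity-check it (a sketch, assuming the preprocessed dataframe from above) is to tokenize every tweet without truncation and look at the longest sequence:

lengths = [len(TOKENIZER(t).input_ids) for t in df.text.values]
print(max(lengths))  # longest tokenized tweet, special tokens included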

PyTorch Dataset and Model

Let's start by defining a PyTorch Dataset. It needs to implement the __len__ and __getitem__ methods. We will again use the HuggingFace tokenizer to convert text into the input_ids, attention mask, and token_type_ids that our BERT layer expects.

class CBDDataset:
    def __init__(self, text, label):
        self.text = text
        self.label = label
        self.tokenizer = TOKENIZER
    
    def __len__(self):
        return len(self.text)

    def __getitem__(self, item):
        text = ' '.join(self.text[item].split())
        label = self.label[item]
        enc = self.tokenizer(text, max_length=MAX_LEN, truncation=True, padding='max_length', return_tensors='pt')

        return {
            'ids': enc.input_ids[0],
            'mask': enc.attention_mask[0],
            'token_type_ids': enc.token_type_ids[0],
            'targets': torch.tensor(label, dtype=torch.long)
        } 
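
A single item from the dataset is a dict of fixed-length tensors; a quick sketch of what the model will receive:

ds = CBDDataset(text=df.text.values, label=df.label.values)
sample = ds[0]
print({k: tuple(v.shape) for k, v in sample.items()})
# ids, mask and token_type_ids are padded/truncated to MAX_LEN (92); targets is a 0-dim label tensor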

Now it's time to define our model! First, let's look at the elements that are normally expected:

  • bert layer: the entire pretrained BERT model becomes a single layer in our model. We again use pretrained weights from the HuggingFace hub.
  • dropout: another hyperparameter that can be tuned; here we set it directly in the model
  • linear classification layer: this is a binary classification problem with 2 classes (True and False), and we define a linear layer for it. It comes with random weights, which we initialize here.

We are also making some modifications here that should help us improve the results:

  • using the full hidden states rather than the [CLS] token output: there is some research showing that the last layers of a pretrained model are very specific to the pretraining task and don't help much in fine-tuning. We will output all hidden states from the model and use the penultimate layer (-2) for our task
  • max pooling: we will take the outputs from all tokens (768 features * 92 tokens) and take the max value of each feature across all tokens (see the shape sketch after the model definition below). The intuition is that the model may encode 'cyberbullying' in a token representation, and if it is present anywhere in a tweet, we should use that information.
class CBDModel(BertPreTrainedModel):
    def __init__(self, conf):
        super(CBDModel, self).__init__(conf)
        self.bert = BertModel.from_pretrained(BERT_PATH, config=conf)
        self.mx = nn.MaxPool1d(MAX_LEN)
        self.drop_out = nn.Dropout(0.5)
        self.l0 = nn.Linear(768, 2)
        torch.nn.init.normal_(self.l0.weight, std=0.02)
    
    def forward(self, ids, mask, token_type_ids):
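        # With config.output_hidden_states=True and the tuple-style outputs of the transformers
        # version used here, self.bert returns (sequence_output, pooled_output, all_hidden_states);
        # we keep only the tuple of hidden states.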
        _, _, out = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        out = out[-2]
        out = out.permute(0,2,1)
        out = torch.squeeze(self.mx(out))
        out = self.drop_out(out)
        out = self.l0(out)
        return out
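
To make the pooling step concrete, here is a small sketch with a dummy hidden-state tensor (random values, batch of 4):

hidden = torch.randn(4, MAX_LEN, 768)            # (batch, tokens, features), as returned per layer
pooled = nn.MaxPool1d(MAX_LEN)(hidden.permute(0, 2, 1))
print(pooled.shape)                              # torch.Size([4, 768, 1])
print(torch.squeeze(pooled).shape)               # torch.Size([4, 768]) -- one max per feature

One caveat: torch.squeeze without a dim argument would also drop the batch dimension for a batch of size 1, which is worth keeping in mind when running inference on single tweets.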

We will use cross entropy loss here.

def loss_fn(outputs, targets):
    return nn.CrossEntropyLoss()(outputs, targets)
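
nn.CrossEntropyLoss expects raw logits of shape (batch, 2) and integer class targets, which matches what the model and dataset above produce; for example:

logits = torch.tensor([[2.0, -1.0], [0.5, 1.5]])  # raw outputs for two examples
targets = torch.tensor([0, 1])                    # true classes
print(loss_fn(logits, targets))                   # a scalar loss tensor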

Training and Evaluation Loop with Weighted Random Sampling

In this section, we define our training and evaluation functions and the runner that executes the training. The key modification here is using a Weighted Random Sampler to address the class imbalance issue.
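
The weighting logic that appears inside run() below can be illustrated with a toy example (hypothetical labels, 4 negatives and 1 positive): each example gets a weight inversely proportional to its class frequency, so the sampler draws minority-class examples more often than their share of the data.

target = np.array([0, 0, 0, 0, 1])
class_sample_count = np.array([len(np.where(target == t)[0]) for t in np.unique(target)])  # [4, 1]
weight = 1. / class_sample_count                                                           # [0.25, 1.0]
samples_weight = torch.from_numpy(np.array([weight[t] for t in target])).double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
print(list(sampler))  # 5 indices drawn with replacement; index 4 (the positive) is oversampled relative to its 1/5 share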

def train_fn(data_loader, model, optimizer, device, scheduler=None):
    model.train()
    losses = AverageMeter()
    f1s = AverageMeter()
    tk0 = tqdm(data_loader, total=len(data_loader))
    
    for bi, d in enumerate(tk0):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]
        targets = d["targets"]

        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)
        targets = targets.to(device, dtype=torch.long)

        model.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
        loss = loss_fn(outputs, targets)
        loss.backward()
        optimizer.step()
        scheduler.step()
        outputs = torch.argmax(outputs, dim=1).cpu().detach().numpy()
        targets = targets.cpu().detach().numpy().astype(int)
        f1 = metrics.f1_score(targets,outputs)
        f1s.update(f1, ids.size(0))
        losses.update(loss.item(), ids.size(0))
        tk0.set_postfix(loss=losses.avg, f1=f1s.avg)
def eval_fn(data_loader, model, device):
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for bi, d in tqdm(enumerate(data_loader), total=len(data_loader)):
            ids = d["ids"]
            token_type_ids = d["token_type_ids"]
            mask = d["mask"]
            targets = d["targets"]

            ids = ids.to(device, dtype=torch.long)
            token_type_ids = token_type_ids.to(device, dtype=torch.long)
            mask = mask.to(device, dtype=torch.long)
            
            outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
            outputs = torch.argmax(outputs, dim=1).cpu().detach().numpy().tolist()
            fin_targets.extend(targets.cpu().detach().numpy().tolist())
            fin_outputs.extend(outputs)

    f1 = metrics.f1_score(fin_targets,fin_outputs)
    return f1
def run(fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    # df_train = df_train[:64]
    # df_valid = df_valid[:64]

    target = df_train.label.values
    class_sample_count = np.array([len(np.where(target == t)[0]) for t in np.unique(target)])
    weight = 1. / class_sample_count
    samples_weight = np.array([weight[t] for t in target])
    samples_weight = torch.from_numpy(samples_weight)
    samples_weight = samples_weight.double()
    sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
    
    train_dataset = CBDDataset(
        text=df_train.text.values,
        label=df_train.label.values,
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=1, 
        sampler=sampler
    )

    valid_dataset = CBDDataset(
        text=df_valid.text.values,
        label=df_valid.label.values,
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        shuffle=False,
        num_workers=2
    )

    device = torch.device("cuda")
    model_config = BertConfig.from_pretrained(BERT_PATH)
    model_config.output_hidden_states = True
    model = CBDModel(conf=model_config)
    model.to(device)

    num_train_steps = int(len(df_train) / TRAIN_BATCH_SIZE * EPOCHS)
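    # The last two named parameters are the classifier's l0.weight and l0.bias; we exclude them
    # here so that they can be added below in their own parameter groups with the higher HEAD_LR.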
    param_optimizer = list(model.named_parameters())[:-2]
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    
    optimizer_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if (any(nd in n for nd in no_decay))], 'weight_decay': 0.0},
        {'params': model.l0.weight, "lr": HEAD_LR, 'weight_decay': 0.01}, 
        {'params': model.l0.bias, "lr": HEAD_LR, 'weight_decay': 0.0}, 
    ]
    optimizer = AdamW(optimizer_parameters, lr=LR)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, 
        num_warmup_steps=WARMUP_STEPS, 
        num_training_steps=num_train_steps
    )

    print(f"Training is starting for fold: {fold}")
    
    for epoch in range(EPOCHS):
        train_fn(train_data_loader, model, optimizer, device, scheduler=scheduler)
        f1 = eval_fn(valid_data_loader, model, device)
        print(f"Epoch: {epoch}, F1 score = {f1}")
    
    model_path=f"model_{fold}.bin"
    torch.save(model.state_dict(), model_path)

Let's train!

run(0)
Training is starting for fold: 0
Epoch: 0, F1 score = 0.4273255813953489
Epoch: 1, F1 score = 0.5588235294117647

run(1)
Training is starting for fold: 1
Epoch: 0, F1 score = 0.4831460674157303
Epoch: 1, F1 score = 0.5245901639344263

run(2)
Training is starting for fold: 2
Epoch: 0, F1 score = 0.4551083591331269
Epoch: 1, F1 score = 0.5592233009708738

run(3)
Training is starting for fold: 3
Epoch: 0, F1 score = 0.4318181818181818
Epoch: 1, F1 score = 0.5283018867924528

run(4)
Training is starting for fold: 4
Epoch: 0, F1 score = 0.47330960854092524
Epoch: 1, F1 score = 0.5536480686695279

Evaluation and results

We have now trained 5 models on different folds. Let's apply these models to our test set, pre-processed in the same way as the training set. We will average the raw logits (outputs) from the five models and then apply argmax to choose the predicted class.

!wget -q https://raw.githubusercontent.com/ptaszynski/cyberbullying-Polish/master/task%2001/test_set_clean_only_tags.txt
df_test = pd.read_csv('test_features.tsv', delimiter='\t')
df_test.columns = ['text']
df_test['text'] = df_test['text'].apply(lambda r: r.replace("@anonymized_account", "@ użytkownik"))
df_test['text'] = df_test['text'].apply(lambda r: "".join((emoji.get(c, c) for c in r)))
df_test['label'] = 0
df_lbls = pd.read_csv('test_set_clean_only_tags.txt',names=['label'])
labels = df_lbls.label.values
device = torch.device("cuda")
model_config = BertConfig.from_pretrained(BERT_PATH)
model_config.output_hidden_states = True
model1 = CBDModel(conf=model_config)
model1.to(device)
model1.load_state_dict(torch.load("model_0.bin"))
model1.eval()

model2 = CBDModel(conf=model_config)
model2.to(device)
model2.load_state_dict(torch.load("model_1.bin"))
model2.eval()

model3 = CBDModel(conf=model_config)
model3.to(device)
model3.load_state_dict(torch.load("model_2.bin"))
model3.eval()

model4 = CBDModel(conf=model_config)
model4.to(device)
model4.load_state_dict(torch.load("model_3.bin"))
model4.eval()

model5 = CBDModel(conf=model_config)
model5.to(device)
model5.load_state_dict(torch.load("model_4.bin"))
model5.eval();
final_output = []

test_dataset = CBDDataset(
        text=df_test.text.values,
        label=df_test.label.values,
    )

data_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=VALID_BATCH_SIZE,
    num_workers=1
)

with torch.no_grad():
    tk0 = tqdm(data_loader, total=len(data_loader))
    for bi, d in enumerate(tk0):
        ids = d["ids"]
        token_type_ids = d["token_type_ids"]
        mask = d["mask"]

        ids = ids.to(device, dtype=torch.long)
        token_type_ids = token_type_ids.to(device, dtype=torch.long)
        mask = mask.to(device, dtype=torch.long)

        outputs1 = model1(ids=ids, mask=mask, token_type_ids=token_type_ids)
        outputs2 = model2(ids=ids, mask=mask, token_type_ids=token_type_ids)
        outputs3 = model3(ids=ids, mask=mask, token_type_ids=token_type_ids)
        outputs4 = model4(ids=ids, mask=mask, token_type_ids=token_type_ids)
        outputs5 = model5(ids=ids, mask=mask, token_type_ids=token_type_ids)

        outputs = (outputs1 + outputs2 + outputs3 + outputs4 + outputs5) / 5
        outputs = torch.argmax(outputs, dim=1).cpu().detach().numpy().tolist()
        final_output.extend(outputs)

precision, recall, f1, _ = precision_recall_fscore_support(labels, final_output, average='binary')
acc = accuracy_score(labels, final_output)
print( {
    'accuracy': acc,
    'f1': f1,
    'precision': precision,
    'recall': recall
})
{'accuracy': 0.905, 'f1': 0.671280276816609, 'precision': 0.6258064516129033, 'recall': 0.7238805970149254}

This looks good! Our F1 score is around 0.67 (reruns typically land in the 0.66-0.68 range), which is in the range of the best base-size BERT model on the KLEJ benchmark (Polish Roberta base reported results in the 0.63-0.69 range).

Improvements

What can be done to further improve the results? Here are some ideas:

  • Data augmentation. Can we add more variety/examples via text augmentation?
  • More hyperparameter tuning. The key watch-out is to keep a sound cross-validation approach so that we don't tune on the test set.
  • Multi-sample dropout. This technique was used by winning teams in recent Kaggle NLP competitions; a minimal sketch of such a head follows this list.
  • Multi-lingual transfer. We have large toxicity datasets in English; can we use them with a multilingual model like XLM-Roberta to classify Polish tweets?
  • Multi-task learning. We could train a single model on several tasks, e.g. from the KLEJ benchmark, to see if that helps.
  • Ensembling/Stacking. Ensembling results across models with different encoders and fine-tuning protocols is very likely to improve the score further.
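
For the multi-sample dropout idea, here is a minimal sketch of how such a head could look (an illustration, not code from the winning solutions): the pooled features are passed through several independent dropout masks, the same classifier is applied to each draw, and the resulting logits are averaged.

class MultiSampleDropoutHead(nn.Module):
    """Sketch only: average the classifier output over several dropout draws."""
    def __init__(self, hidden_size=768, num_classes=2, num_samples=5, p=0.5):
        super().__init__()
        self.dropouts = nn.ModuleList([nn.Dropout(p) for _ in range(num_samples)])
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, pooled):
        # Apply the same linear layer to several dropout-masked copies and average the logits.
        logits = [self.classifier(drop(pooled)) for drop in self.dropouts]
        return torch.stack(logits, dim=0).mean(dim=0)

In CBDModel this would replace the single drop_out + l0 pair, with the rest of the forward pass unchanged.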