!pip install accelerate transformers datasets
from conllu import parse  # parse() for CoNLL-U treebanks (conllu package)

def load_ud_data(filename, verbose=True, cutoff=None):
"""
Loads a treebank from a file.
:param filename: the filename of the treebank
:param verbose: whether to print progress information
:param cutoff: the number of sentences to load
:return: A tuple of lists: One for the words, one with the upos tags"""
print("read file ", filename) if verbose else None
    with open(filename, 'r', encoding='utf-8') as f:
        raw_data = f.read()
print('parse data into token lists') if verbose else None
dep_sents = parse(raw_data)
print('apply cutoff of ', cutoff) if verbose else None
if cutoff is not None:
dep_sents = dep_sents[:cutoff]
words = []
uposs = []
print('extract words and tags') if verbose else None
for sent in dep_sents:
        words.append([word['form'] for word in sent])
uposs.append([word['upos'] for word in sent])
return words, uposs
def get_word_spans(text):
"""Given a list of sentences, return a list of tensors of word spans."""
spanss = []
for sent in text:
spans = torch.zeros(2, len(sent), dtype=int)
start = 0
for i,word in enumerate(sent):
end = start + len(word)
spans[0, i] = start
spans[1, i] = end
start = end + 1
spanss.append(spans)
return spanss
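A quick worked example of the span computation (toy sentence, single spaces between words as assumed above):
spans = get_word_spans([["I", "saw", "it"]])[0]
# "I" covers chars 0..1, "saw" 2..5, "it" 6..8 in "I saw it"
assert spans.tolist() == [[0, 2, 6], [1, 5, 8]]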
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm
rare_chars = {'$', 'ã', '¸', 'ë', '[', 'Ï', '·', 'à', 'ô', 'ù', '£', 'Á', 'Ø',
'ø', 'ñ', 'µ', 'Ì', '¼', 'Ô', 'ò', '&', 'É', 'ç', 'ú', 'Å',
']', 'Ù', 'Í', 'û', 'å', 'ï', 'ê', 'æ', 'º'}
class LinesDataset(Dataset):
def __init__(self, lines):
self.lines = lines
def __len__(self):
return len(self.lines)
def __getitem__(self, idx):
return self.lines[idx]
def load_splits_from_file(num_sents):
lines = []
with open("europarl-v7.de-en.de", "r") as f:
for i, line in enumerate(f):
if i == num_sents:
break
lines.append(line.strip())
# split into train and test. Test is the last 10% of the data
split = int(0.9 * len(lines))
train_lines = lines[:split]
test_lines = lines[split:]
return train_lines, test_lines
def preprocess(train_lines, test_lines):
"""Specific preprocessing steps for the German LM.
Replaces rare chars with _ and adds a $ as end of line token.
Returns the preprocessed train and test lines."""
train_lines_filtered = []
for line in train_lines:
# replace rare chars with _
line_filtered = ''.join(['_' if c in rare_chars else c for c in line])
train_lines_filtered.append(line_filtered)
test_lines_filtered = []
for line in test_lines:
# replace rare chars with _
line_filtered = ''.join(['_' if c in rare_chars else c for c in line])
test_lines_filtered.append(line_filtered)
# add a dollar sign as end of line token
train_lines_filtered = [line + "$" for line in train_lines_filtered]
test_lines_filtered = [line + "$" for line in test_lines_filtered]
return train_lines_filtered, test_lines_filtered
def encode(text, chars2idx):
# return list of indices, unknown chars are mapped to '_'
return [chars2idx[c] if c in chars2idx else chars2idx['_'] for c in text]
def decode(indices, idx2chars):
# return string of chars
if type(indices) == int:
# sometimes you only decode a single index
indices = [indices]
return ''.join([idx2chars[i] for i in indices])
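A small round-trip example with a toy vocabulary (the dictionaries here are made up purely for illustration):
toy_c2i = {'_': 0, 'a': 1, 'b': 2}
toy_i2c = {i: c for c, i in toy_c2i.items()}
assert encode("abx", toy_c2i) == [1, 2, 0]   # unknown 'x' falls back to '_'
assert decode([1, 2, 0], toy_i2c) == "ab_"
assert decode(1, toy_i2c) == "a"             # a single index is also accepted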
def load_german_data(num_sents=10000, verbose=True):
"""All preprocessing steps for training German LSTM character LM.
THe data is loaded from the europarl corpus, which should
be located in the same directory.
It returns a tuple with
- the encoding dict
- the decoding dict
- the training data as a LinesDataSet object
- the validation data as a LinesDataSet object
Validation data is the last 10% of the num_sents sentences.
"""
# load from file
train_lines, test_lines = load_splits_from_file(num_sents)
print(f"Loaded {len(train_lines)}, {len(test_lines)} train/test sents") if verbose else None
print(f"Preprocess...") if verbose else None
train_lines, test_lines = preprocess(train_lines, test_lines)
# build an encoding dictionary. "&" is the padding token
chars_in_train = sorted(set(''.join(train_lines)+ "&"))
chars2idx = {c: i for i, c in enumerate(chars_in_train)}
    idx2chars = {i: c for i, c in enumerate(chars_in_train)}
if verbose:
print(f"Built encoding and decoding dicts")
print(f"Vocab size: {len(chars2idx)}")
print(f"Encode train and test data...") if verbose else None
train_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=int) for line in train_lines]
test_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=int) for line in test_lines]
print(f"Create datasets...") if verbose else None
train_dataset = LinesDataset(train_lines_enc)
test_dataset = LinesDataset(test_lines_enc)
    return chars2idx, idx2chars, train_dataset, test_dataset
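Typical usage (the variable names on the left are just suggestions):
chars2idx, idx2chars, train_dataset, test_dataset = load_german_data(num_sents=10000)
vocab_size = len(chars2idx)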
def determine_early_stopping(losses, k=3):
"""Losses is a list of losses, k is the number of epochs without improvement
Return true if the lowest loss is at least k+1 epochs ago
"""
if len(losses) < k+1:
return False
epoch_of_lowest_loss = np.argmin(losses)
distance_to_end = len(losses) - epoch_of_lowest_loss
# print(f"Epoch of lowest loss: {epoch_of_lowest_loss}, distance to end: {distance_to_end}")
return distance_to_end > k
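A short worked example of the stopping criterion:
losses = [3.0, 2.5, 2.4, 2.45, 2.6, 2.7]
# lowest loss at index 2, four entries from the end -> 4 > k=3 -> stop
assert determine_early_stopping(losses, k=3)
# with only three epochs since the minimum, keep training
assert not determine_early_stopping(losses[:5], k=3)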
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import encode, decode
class LSTMLanguageModel(nn.Module):
def __init__(self, vocab_size, embedding_dim=128, hidden_dim=128, lstm_layers=2, lstm_dropout=0.0):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, lstm_layers, dropout=lstm_dropout, batch_first=True)
self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
def forward(self, x):
# x: (batch_size, seq_len)
embs = self.embedding(x)
        # embs: (batch_size, seq_len, embedding_dim)
        # lstm_out: (batch_size, seq_len, hidden_dim)
        lstm_out, _ = self.lstm(embs)
# logits: (batch_size, seq_len, vocab_size)
logits = self.lm_head(lstm_out)
return logits
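A minimal instantiation sketch, assuming the character vocabulary returned by load_german_data:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMLanguageModel(vocab_size=len(chars2idx)).to(device)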
def perplexity(model, dataset, verbose=True):
"""Compute the perplexity of the model on the given dataset"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
# compute the total loss
all_log_probs = []
num_sents = len(dataset)
for i, sentence in enumerate(dataset):
sentence = sentence.to(device)
logits = model(sentence)
y = sentence[1:]
log_probs = F.log_softmax(logits, dim=-1)
log_probs_cut = log_probs[:-1]
# get the log probs for the true chars
log_probs_true = torch.gather(log_probs_cut, dim=-1, index=y.unsqueeze(-1)).squeeze(-1)
all_log_probs.append(log_probs_true)
print(f"PPL: {i+1}/{num_sents} ", end="\r", flush=True) if verbose else None
log_ps_collected = torch.cat(all_log_probs)
num_toks = log_ps_collected.shape[0]
# compute the perplexity
ppl = ((log_ps_collected / num_toks).sum()*-1).exp()
return ppl
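This computes PPL = exp(-(1/N) * sum_t log p(c_t | c_<t)), i.e. the exponentiated average negative log-likelihood over all N predicted characters in the dataset.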
def generate(model, start_text, chars2idx, idx2chars, num_chars=5, mode="sample"):
"""Generates num_chars characters as continuation of start_text
mode: "greedy" or "sample". "sample" is default.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
for i in range(num_chars):
# encode start_text
start_text_enc = torch.tensor(encode(start_text, chars2idx), dtype=int).unsqueeze(0).to(device)
# start_text_enc: (1, seq_len)
# generate next char
logits = model.forward(start_text_enc)
# logits: (1, seq_len, vocab_size)
# take last char
last_char_logits = logits[:, -1, :]
# last_char_logits: (1, vocab_size)
if mode == "greedy":
last_char_idx = torch.argmax(last_char_logits, dim=1)
elif mode == "sample":
# sample from the distribution
last_char_idx = torch.multinomial(F.softmax(last_char_logits, dim=-1), num_samples=1)
# last_char_idx: (1)
# decode last char
last_char = decode(last_char_idx.item(), idx2chars)
# append last char to start_text
start_text += last_char
return start_text
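Example call (prompt and length chosen arbitrarily):
print(generate(model, "Die Kommission ", chars2idx, idx2chars, num_chars=100, mode="sample"))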
----------------------
# RNNs
h = torch.tanh(self.W(x[t]) + self.U(h) + self.b)
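For context, this update sits inside a loop over time steps; a minimal Elman-style RNN built around it might look as follows (a sketch, not the reference implementation; W, U and b are assumed to be a linear input map, a linear recurrent map and a bias of hidden size):
class SimpleRNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, hidden_dim, bias=False)   # input-to-hidden
        self.U = nn.Linear(hidden_dim, hidden_dim, bias=False)  # hidden-to-hidden
        self.b = nn.Parameter(torch.zeros(hidden_dim))

    def forward(self, x):
        # x: (seq_len, input_dim); returns all hidden states as (seq_len, hidden_dim)
        h = torch.zeros(self.b.shape[0], device=x.device)
        hs = []
        for t in range(x.shape[0]):
            h = torch.tanh(self.W(x[t]) + self.U(h) + self.b)
            hs.append(h)
        return torch.stack(hs)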
--------------------------------------
class Encoder:
"""Mapping between classes and the corresponding indices.
>>> classes = ["English", "German", "French"]
    >>> enc = Encoder(classes)
>>> assert "English" == enc.decode(enc.encode("English"))
>>> assert "German" == enc.decode(enc.encode("German"))
>>> assert "French" == enc.decode(enc.encode("French"))
>>> set(range(3)) == set(enc.encode(cl) for cl in classes)
True
>>> for cl in classes:
... ix = enc.encode(cl)
    ...     assert 0 <= ix < enc.size()
... assert cl == enc.decode(ix)
"""
def __init__(self, classes):
self.class_to_ix = {}
self.ix_to_class = {}
for cl in classes:
if cl not in self.class_to_ix:
ix = len(self.class_to_ix)
self.class_to_ix[cl] = ix
self.ix_to_class[ix] = cl
def size(self):
return len(self.class_to_ix)
def encode(self, cl):
return self.class_to_ix[cl]
def decode(self, ix):
return self.ix_to_class[ix]
--------------------------
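The tagger code below uses training-side objects (train_words, train_upos_enc, pos_enc) that are assumed to have been built along these lines, mirroring the dev-set encoding; the fastText embeddings ft and the MLPTagger class are defined elsewhere:
train_words, train_upos = load_ud_data('en_ewt-ud-train.conllu')
pos_enc = Encoder([tag for sent in train_upos for tag in sent])
train_upos_enc = [torch.tensor([pos_enc.encode(tag) for tag in tags])
                  for tags in train_upos]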
Development set:
# load the dev data
dev_words, dev_upos = load_ud_data('en_ewt-ud-dev.conllu')
assert len(dev_words) == len(dev_upos)
assert all([len(dev_words[i]) == len(dev_upos[i]) for i in range(len(dev_words))])
# encode the dev data wrt. the training data
dev_upos_enc = []
for upos_list in dev_upos:
upos_list_enc = torch.tensor([pos_enc.encode(tag) for tag in upos_list])
dev_upos_enc.append(upos_list_enc)
ft_pos_tagger = MLPTagger(ft, pos_enc)
optim = torch.optim.Adam(ft_pos_tagger.parameters(), lr=.01)
criterion = torch.nn.CrossEntropyLoss()
# define the new training loop
num_sents = len(train_words)
for e in range(5):
train_accs = []
train_losses = []
for i, (sent, tags) in enumerate(zip(train_words, train_upos_enc)):
tag_pred = ft_pos_tagger(sent)
loss = criterion(tag_pred, tags)
loss.backward()
optim.step()
optim.zero_grad()
acc = (tag_pred.argmax(dim=1) == tags).float().mean()
train_accs.append(acc)
train_losses.append(loss.item())
print(f"Epoch {e}, sent. {i} / {num_sents}", end="\r", flush=True),
train_loss = sum(train_losses)
train_acc = sum(train_accs) / len(train_accs)
# report dev set performance
with torch.no_grad():
dev_accs = []
dev_losses = []
for i, (sent, tags) in enumerate(zip(dev_words, dev_upos_enc)):
tag_pred = ft_pos_tagger(sent)
loss = criterion(tag_pred, tags)
acc = (tag_pred.argmax(dim=1) == tags).float().mean()
dev_accs.append(acc)
dev_losses.append(loss.item())
print(f"Epoch {e}, train loss: {train_loss:.2f}, train acc: {train_acc:.2f} | dev loss: {sum(dev_losses):.2f}, dev acc: {sum(dev_accs) / len(dev_accs):.2f}")