!pip install accelerate transformers datasets

from conllu import parse  # CoNLL-U parser (requires the `conllu` package)


def load_ud_data(filename, verbose=True, cutoff=None):
    """Loads a treebank from a file.

    :param filename: the filename of the treebank
    :param verbose: whether to print progress information
    :param cutoff: the number of sentences to load
    :return: A tuple of lists: one with the words, one with the UPOS tags
    """
    if verbose:
        print("read file ", filename)
    with open(filename, 'r') as f:
        raw_data = f.readlines()
    raw_data = ''.join(raw_data)

    if verbose:
        print('parse data into token lists')
    dep_sents = parse(raw_data)

    if verbose:
        print('apply cutoff of ', cutoff)
    if cutoff is not None:
        dep_sents = dep_sents[:cutoff]

    words = []
    uposs = []
    if verbose:
        print('extract words and tags')
    for sent in dep_sents:
        words.append([word['form'] for word in sent])
        uposs.append([word['upos'] for word in sent])
    return words, uposs


def get_word_spans(text):
    """Given a list of sentences, return a list of tensors of word spans."""
    spanss = []
    for sent in text:
        spans = torch.zeros(2, len(sent), dtype=int)
        start = 0
        for i, word in enumerate(sent):
            end = start + len(word)
            spans[0, i] = start
            spans[1, i] = end
            start = end + 1  # +1 for the space between words
        spanss.append(spans)
    return spanss


import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

# characters that occur rarely in the Europarl data; preprocess() replaces them with '_'
rare_chars = {'$', 'ã', '¸', 'ë', '[', 'Ï', '·', 'à', 'ô', 'ù', '£', 'Á', 'Ø', 'ø', 'ñ',
              'µ', 'Ì', '¼', 'Ô', 'ò', '&', 'É', 'ç', 'ú', 'Å', ']', 'Ù', 'Í', 'û', 'å',
              'ï', 'ê', 'æ', 'º'}


class LinesDataset(Dataset):
    def __init__(self, lines):
        self.lines = lines

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        return self.lines[idx]


def load_splits_from_file(num_sents):
    lines = []
    with open("europarl-v7.de-en.de", "r") as f:
        for i, line in enumerate(f):
            if i == num_sents:
                break
            lines.append(line.strip())
    # split into train and test. Test is the last 10% of the data
    split = int(0.9 * len(lines))
    train_lines = lines[:split]
    test_lines = lines[split:]
    return train_lines, test_lines


def preprocess(train_lines, test_lines):
    """Specific preprocessing steps for the German LM.

    Replaces rare chars with _ and adds a $ as end-of-line token.
    Returns the preprocessed train and test lines.
    """
    train_lines_filtered = []
    for line in train_lines:
        # replace rare chars with _
        line_filtered = ''.join(['_' if c in rare_chars else c for c in line])
        train_lines_filtered.append(line_filtered)
    test_lines_filtered = []
    for line in test_lines:
        # replace rare chars with _
        line_filtered = ''.join(['_' if c in rare_chars else c for c in line])
        test_lines_filtered.append(line_filtered)
    # add a dollar sign as end-of-line token
    train_lines_filtered = [line + "$" for line in train_lines_filtered]
    test_lines_filtered = [line + "$" for line in test_lines_filtered]
    return train_lines_filtered, test_lines_filtered


def encode(text, chars2idx):
    # return list of indices, unknown chars are mapped to '_'
    return [chars2idx[c] if c in chars2idx else chars2idx['_'] for c in text]


def decode(indices, idx2chars):
    # return string of chars
    if type(indices) == int:  # sometimes you only decode a single index
        indices = [indices]
    return ''.join([idx2chars[i] for i in indices])
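A quick round-trip sketch for encode/decode, not part of the original pipeline: the toy lines, the probe string and the tiny hand-built vocabulary below are made up for illustration. In the real setup the vocabulary is built inside load_german_data from the preprocessed Europarl training lines, where '_' is present because preprocess() substitutes rare characters.

# Illustrative only: toy vocabulary with '&' (padding) and '_' (unknown) added by hand.
sample_lines = ["das ist ein satz$", "noch ein satz$"]
chars2idx = {c: i for i, c in enumerate(sorted(set(''.join(sample_lines)) | {'&', '_'}))}
idx2chars = {i: c for c, i in chars2idx.items()}

ids = encode("das ist neu$", chars2idx)   # 'u' is unseen and falls back to '_'
print(decode(ids, idx2chars))             # -> "das ist ne_$"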
def load_german_data(num_sents=10000, verbose=True):
    """All preprocessing steps for training the German LSTM character LM.

    The data is loaded from the Europarl corpus, which should be located in
    the same directory. It returns a tuple with
    - the encoding dict
    - the decoding dict
    - the training data as a LinesDataset object
    - the validation data as a LinesDataset object
    Validation data is the last 10% of the num_sents sentences.
    """
    # load from file
    train_lines, test_lines = load_splits_from_file(num_sents)
    if verbose:
        print(f"Loaded {len(train_lines)}, {len(test_lines)} train/test sents")
        print("Preprocess...")
    train_lines, test_lines = preprocess(train_lines, test_lines)

    # build an encoding dictionary. "&" is the padding token
    chars_in_train = sorted(set(''.join(train_lines) + "&"))
    chars2idx = {c: i for i, c in enumerate(chars_in_train)}
    idx2chars = {i: c for i, c in enumerate(chars_in_train)}
    if verbose:
        print("Built encoding and decoding dicts")
        print(f"Vocab size: {len(chars2idx)}")
        print("Encode train and test data...")
    train_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=int) for line in train_lines]
    test_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=int) for line in test_lines]

    if verbose:
        print("Create datasets...")
    train_dataset = LinesDataset(train_lines_enc)
    test_dataset = LinesDataset(test_lines_enc)
    return chars2idx, idx2chars, train_dataset, test_dataset


def determine_early_stopping(losses, k=3):
    """losses is a list of per-epoch losses, k is the allowed number of epochs
    without improvement. Return True if the lowest loss is at least k+1 epochs ago.
    """
    if len(losses) < k + 1:
        return False
    epoch_of_lowest_loss = np.argmin(losses)
    distance_to_end = len(losses) - epoch_of_lowest_loss
    # print(f"Epoch of lowest loss: {epoch_of_lowest_loss}, distance to end: {distance_to_end}")
    return distance_to_end > k


import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import encode, decode


class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=128,
                 lstm_layers=2, lstm_dropout=0.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True so batched inputs/outputs are (batch_size, seq_len, ...)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, lstm_layers,
                            dropout=lstm_dropout, batch_first=True)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        # x: (batch_size, seq_len)
        embs = self.embedding(x)
        # embs: (batch_size, seq_len, embedding_dim)
        lstm_out, _ = self.lstm(embs)
        # lstm_out: (batch_size, seq_len, hidden_dim)
        logits = self.lm_head(lstm_out)
        # logits: (batch_size, seq_len, vocab_size)
        return logits


def perplexity(model, dataset, verbose=True):
    """Compute the perplexity of the model on the given dataset."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        # collect the log probability of every gold next character
        all_log_probs = []
        num_sents = len(dataset)
        for i, sentence in enumerate(dataset):
            sentence = sentence.to(device)
            logits = model(sentence)
            y = sentence[1:]
            log_probs = F.log_softmax(logits, dim=-1)
            log_probs_cut = log_probs[:-1]
            # get the log probs for the true chars
            log_probs_true = torch.gather(log_probs_cut, dim=-1, index=y.unsqueeze(-1)).squeeze(-1)
            all_log_probs.append(log_probs_true)
            if verbose:
                print(f"PPL: {i+1}/{num_sents} ", end="\r", flush=True)
        log_ps_collected = torch.cat(all_log_probs)
        num_toks = log_ps_collected.shape[0]
        # compute the perplexity: exp of the negative mean log probability
        ppl = ((log_ps_collected / num_toks).sum() * -1).exp()
    return ppl
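The model, the perplexity metric and the early-stopping helper are defined above, but the training loop for the LM itself is not shown here. The sketch below is one way to wire them together, not the original code: train_lm is a hypothetical helper, it trains sentence by sentence (batch size 1, so the "&" padding token is never needed), and the optimizer, learning rate and epoch cap are placeholder choices.

def train_lm(model, train_dataset, test_dataset, epochs=20, lr=1e-3, k=3):
    """Sketch: next-character training with early stopping on validation perplexity."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    val_ppls = []
    for epoch in range(epochs):
        model.train()
        for sentence in train_dataset:
            if len(sentence) < 2:          # nothing to predict for 1-char lines
                continue
            sentence = sentence.to(device)
            logits = model(sentence)       # (seq_len, vocab_size) for an unbatched input
            # next-character prediction: char t+1 is the target for position t
            loss = criterion(logits[:-1], sentence[1:])
            loss.backward()
            optim.step()
            optim.zero_grad()
        model.eval()
        val_ppl = perplexity(model, test_dataset, verbose=False).item()
        val_ppls.append(val_ppl)
        print(f"epoch {epoch}: validation perplexity {val_ppl:.2f}")
        if determine_early_stopping(val_ppls, k=k):
            print(f"no improvement for more than {k} epochs, stopping")
            break
    return model

# e.g. (placeholder sizes):
# chars2idx, idx2chars, train_ds, val_ds = load_german_data(num_sents=10000)
# lm = train_lm(LSTMLanguageModel(vocab_size=len(chars2idx)), train_ds, val_ds)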
""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") with torch.no_grad(): for i in range(num_chars): # encode start_text start_text_enc = torch.tensor(encode(start_text, chars2idx), dtype=int).unsqueeze(0).to(device) # start_text_enc: (1, seq_len) # generate next char logits = model.forward(start_text_enc) # logits: (1, seq_len, vocab_size) # take last char last_char_logits = logits[:, -1, :] # last_char_logits: (1, vocab_size) if mode == "greedy": last_char_idx = torch.argmax(last_char_logits, dim=1) elif mode == "sample": # sample from the distribution last_char_idx = torch.multinomial(F.softmax(last_char_logits, dim=-1), num_samples=1) # last_char_idx: (1) # decode last char last_char = decode(last_char_idx.item(), idx2chars) # append last char to start_text start_text += last_char return start_text ---------------------- # RNNs h = torch.tanh(self.W(x[t]) + self.U(h) + self.b) -------------------------------------- class Encoder: """Mapping between classes and the corresponding indices. >>> classes = ["English", "German", "French"] >>> enc = Encoding(classes) >>> assert "English" == enc.decode(enc.encode("English")) >>> assert "German" == enc.decode(enc.encode("German")) >>> assert "French" == enc.decode(enc.encode("French")) >>> set(range(3)) == set(enc.encode(cl) for cl in classes) True >>> for cl in classes: ... ix = enc.encode(cl) ... assert 0 <= ix <= enc.class_num ... assert cl == enc.decode(ix) """ def __init__(self, classes): self.class_to_ix = {} self.ix_to_class = {} for cl in classes: if cl not in self.class_to_ix: ix = len(self.class_to_ix) self.class_to_ix[cl] = ix self.ix_to_class[ix] = cl def size(self): return len(self.class_to_ix) def encode(self, cl): return self.class_to_ix[cl] def decode(self, ix): return self.ix_to_class[ix] -------------------------- Development set: # load the dev data dev_words, dev_upos = load_ud_data('en_ewt-ud-dev.conllu') assert len(dev_words) == len(dev_upos) assert all([len(dev_words[i]) == len(dev_upos[i]) for i in range(len(dev_words))]) # encode the dev data wrt. the training data dev_upos_enc = [] for upos_list in dev_upos: upos_list_enc = torch.tensor([pos_enc.encode(tag) for tag in upos_list]) dev_upos_enc.append(upos_list_enc) ft_pos_tagger = MLPTagger(ft, pos_enc) optim = torch.optim.Adam(ft_pos_tagger.parameters(), lr=.01) criterion = torch.nn.CrossEntropyLoss() # define the new training loop num_sents = len(train_words) for e in range(5): train_accs = [] train_losses = [] for i, (sent, tags) in enumerate(zip(train_words, train_upos_enc)): tag_pred = ft_pos_tagger(sent) loss = criterion(tag_pred, tags) loss.backward() optim.step() optim.zero_grad() acc = (tag_pred.argmax(dim=1) == tags).float().mean() train_accs.append(acc) train_losses.append(loss.item()) print(f"Epoch {e}, sent. {i} / {num_sents}", end="\r", flush=True), train_loss = sum(train_losses) train_acc = sum(train_accs) / len(train_accs) # report dev set performance with torch.no_grad(): dev_accs = [] dev_losses = [] for i, (sent, tags) in enumerate(zip(dev_words, dev_upos_enc)): tag_pred = ft_pos_tagger(sent) loss = criterion(tag_pred, tags) acc = (tag_pred.argmax(dim=1) == tags).float().mean() dev_accs.append(acc) dev_losses.append(loss.item()) print(f"Epoch {e}, train loss: {train_loss:.2f}, train acc: {train_acc:.2f} | dev loss: {sum(dev_losses):.2f}, dev acc: {sum(dev_accs) / len(dev_accs):.2f}")