!pip install accelerate transformers datasets

def load_ud_data(filename, verbose=True, cutoff=None):
    """
    Loads a treebank (CoNLL-U text) from a file.

    :param filename: the filename of the treebank
    :param verbose: whether to print progress information
    :param cutoff: the maximum number of sentences to keep (None keeps all)
    :return: A tuple of lists ``(words, uposs)``: for every sentence, the list
        of word forms and the parallel list of UPOS tags
    """
    if verbose:
        print("read file ", filename)
    # CoNLL-U is UTF-8; be explicit instead of relying on the locale default.
    with open(filename, 'r', encoding='utf-8') as f:
        raw_data = f.read()

    if verbose:
        print('parse data into token lists')
    # NOTE(review): `parse` is not imported in the visible code; presumably
    # conllu.parse, returning a list of token lists - confirm.
    dep_sents = parse(raw_data)

    if verbose:
        print('apply cutoff of ', cutoff)
    if cutoff is not None:
        dep_sents = dep_sents[:cutoff]

    if verbose:
        print('extract words and tags')
    words = [[token['form'] for token in sent] for sent in dep_sents]
    uposs = [[token['upos'] for token in sent] for sent in dep_sents]

    return words, uposs


def get_word_spans(text):
    """Compute the character offsets of every word in every sentence.

    Words are assumed to be joined by exactly one separator character, i.e.
    the offsets index into ``" ".join(sentence)``.

    :param text: a list of sentences, each a sequence of words (strings)
    :return: a list with one int64 tensor of shape ``(2, num_words)`` per
        sentence; row 0 holds start offsets, row 1 exclusive end offsets
    """
    result = []
    for sentence in text:
        lengths = torch.tensor([len(w) for w in sentence], dtype=int)
        # every word plus its trailing separator advances the cursor by len+1;
        # removing that 1 again yields the exclusive end offset of each word
        ends = torch.cumsum(lengths + 1, dim=0) - 1
        starts = ends - lengths
        result.append(torch.stack((starts, ends)))
    return result



import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm


# Characters that preprocess() replaces with '_' in every line, presumably
# because they are too infrequent in the corpus to deserve their own
# vocabulary entry. The set also contains '$' and '&': preprocess() appends
# '$' as the end-of-line token and load_german_data() adds '&' as the padding
# token, so stripping them from the raw text keeps those special tokens
# unambiguous.
rare_chars = {'$', 'ã', '¸', 'ë', '[', 'Ï', '·', 'à', 'ô', 'ù', '£', 'Á', 'Ø',
               'ø', 'ñ', 'µ', 'Ì', '¼', 'Ô', 'ò', '&', 'É', 'ç', 'ú', 'Å', 
               ']', 'Ù', 'Í', 'û', 'å', 'ï', 'ê', 'æ', 'º'}

class LinesDataset(Dataset):
    """Map-style dataset that wraps an in-memory sequence of lines.

    The sequence is stored as given (no copy); indexing and length are
    delegated to it unchanged.
    """

    def __init__(self, lines):
        self.lines = lines

    def __getitem__(self, idx):
        item = self.lines[idx]
        return item

    def __len__(self):
        count = len(self.lines)
        return count

def load_splits_from_file(num_sents, filename="europarl-v7.de-en.de", encoding="utf-8"):
    """Read the first ``num_sents`` lines of a corpus file and split them 90/10.

    :param num_sents: maximum number of lines to read; ``None`` reads the
        whole file
    :param filename: path of the corpus file, one sentence per line; the
        default is a relative path, resolved against the current working
        directory
    :param encoding: text encoding of the file; explicit so that German text
        does not depend on the platform's locale encoding
    :return: tuple ``(train_lines, test_lines)`` of stripped lines; the test
        split is the last 10% of the lines read
    """
    lines = []
    with open(filename, "r", encoding=encoding) as f:
        for i, line in enumerate(f):
            if i == num_sents:
                break
            lines.append(line.strip())

    # split into train and test. Test is the last 10% of the data
    split = int(0.9 * len(lines))
    return lines[:split], lines[split:]

def preprocess(train_lines, test_lines, *, rare=None):
    """Specific preprocessing steps for the German LM.

    Replaces every rare char with ``_`` and appends a ``$`` to each line as the
    end-of-line token. Train and test lines get exactly the same treatment.

    :param train_lines: iterable of training strings
    :param test_lines: iterable of test strings
    :param rare: optional collection of characters to replace; defaults to
        the module-level ``rare_chars``
    :return: tuple ``(train_lines_filtered, test_lines_filtered)`` of lists
    """
    if rare is None:
        rare = rare_chars

    def _clean(line):
        # replace rare chars with _ and add a dollar sign as end of line token
        return ''.join('_' if c in rare else c for c in line) + "$"

    train_lines_filtered = [_clean(line) for line in train_lines]
    test_lines_filtered = [_clean(line) for line in test_lines]
    return train_lines_filtered, test_lines_filtered


def encode(text, chars2idx):
    """Map each character of ``text`` to its vocabulary index.

    Characters missing from ``chars2idx`` fall back to the index of ``'_'``;
    a KeyError is raised only if such a character occurs and ``'_'`` is itself
    not in the vocabulary.
    """
    indices = []
    for ch in text:
        key = ch if ch in chars2idx else '_'
        indices.append(chars2idx[key])
    return indices

def decode(indices, idx2chars):
    """Turn vocabulary indices back into a string.

    :param indices: a single int, an iterable of ints, or an object with a
        ``tolist()`` method (e.g. a torch tensor or numpy array, including
        0-d ones)
    :param idx2chars: mapping from index to character
    :return: the decoded string
    """
    if hasattr(indices, "tolist"):
        # Iterating a tensor yields 0-d tensors, which hash by identity and
        # therefore never match the int keys of idx2chars; convert first.
        indices = indices.tolist()
    if isinstance(indices, int):
        # sometimes you only decode a single index
        indices = [indices]
    return ''.join(idx2chars[i] for i in indices)

def load_german_data(num_sents=10000, verbose=True):
    """All preprocessing steps for training the German LSTM character LM.

    The data is loaded from the Europarl corpus file through
    ``load_splits_from_file``, which opens it by a relative path, so the file
    must be in the current working directory.
    It returns a tuple with
    - the encoding dict (char -> index)
    - the decoding dict (index -> char)
    - the training data as a LinesDataset of int64 index tensors
    - the validation data as a LinesDataset of int64 index tensors

    Validation data is the last 10% of the num_sents sentences. The vocabulary
    is built from the training lines only, plus the padding token "&" and the
    unknown-character token "_".
    """
    # load from file
    train_lines, test_lines = load_splits_from_file(num_sents)
    if verbose:
        print(f"Loaded {len(train_lines)}, {len(test_lines)} train/test sents")
        print("Preprocess...")
    train_lines, test_lines = preprocess(train_lines, test_lines)

    # Build the encoding dictionary. "&" is the padding token. "_" is the
    # fallback that encode() uses for out-of-vocabulary characters, so it must
    # be present even if no rare char occurred in the training lines;
    # otherwise encoding a validation line with an unseen char raises KeyError.
    chars_in_train = sorted(set(''.join(train_lines)) | {"&", "_"})
    chars2idx = {c: i for i, c in enumerate(chars_in_train)}
    idxchars = dict(enumerate(chars_in_train))
    if verbose:
        print("Built encoding and decoding dicts")
        print(f"Vocab size: {len(chars2idx)}")
        print("Encode train and test data...")

    train_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=torch.long) for line in train_lines]
    test_lines_enc = [torch.tensor(encode(line, chars2idx), dtype=torch.long) for line in test_lines]

    if verbose:
        print("Create datasets...")
    train_dataset = LinesDataset(train_lines_enc)
    test_dataset = LinesDataset(test_lines_enc)

    return chars2idx, idxchars, train_dataset, test_dataset

def determine_early_stopping(losses, k=3):
    """Decide whether training should stop early.

    :param losses: per-epoch losses, oldest first
    :param k: patience, i.e. the number of epochs without improvement
    :return: True if the lowest loss is at least k+1 epochs ago, counting the
        epoch of the lowest loss itself (ties resolve to the earliest epoch)
    """
    n_epochs = len(losses)
    if n_epochs >= k + 1:
        # span from the best epoch (inclusive) to the end of the history
        span_since_best = n_epochs - np.argmin(losses)
        return span_since_best > k
    return False
    
    
    
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils import encode, decode

class LSTMLanguageModel(nn.Module):
    """Character-level language model: embedding -> LSTM -> linear LM head.

    Maps a tensor of token indices to unnormalised next-token logits with a
    trailing ``vocab_size`` dimension.
    """

    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=128, lstm_layers=2, lstm_dropout=0.0,
                 batch_first=False):
        """
        :param vocab_size: number of tokens in the vocabulary
        :param embedding_dim: size of the token embeddings
        :param hidden_dim: hidden size of the LSTM
        :param lstm_layers: number of stacked LSTM layers
        :param lstm_dropout: dropout between stacked LSTM layers
        :param batch_first: layout of batched (2-D) inputs. False (default,
            nn.LSTM's own default) reads them as ``(seq_len, batch_size)``;
            True reads them as ``(batch_size, seq_len)``. Unbatched 1-D inputs
            ``(seq_len,)`` work with either setting.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, lstm_layers, dropout=lstm_dropout,
                            batch_first=batch_first)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x):
        """Return next-token logits for every position of ``x``.

        :param x: LongTensor of token indices: ``(seq_len,)`` unbatched, or
            2-D laid out according to ``batch_first`` (see ``__init__``)
        :return: logits of shape ``x.shape + (vocab_size,)``
        """
        embs = self.embedding(x)
        # embs: x.shape + (embedding_dim,)

        # The final (h, c) state is discarded; only per-position outputs are used.
        lstm_out, _ = self.lstm(embs)
        # lstm_out: x.shape + (hidden_dim,)

        logits = self.lm_head(lstm_out)
        # logits: x.shape + (vocab_size,)
        return logits


def perplexity(model, dataset, verbose=True):
    """Compute the perplexity of the model on the given dataset.

    Each item of ``dataset`` is a 1-D tensor of token indices that is fed to
    the model unbatched, so the model must return ``(seq_len, vocab_size)``
    logits. Position t predicts token t+1, and the result is
    ``exp(-mean log p(true token))`` over all predicted tokens of all
    sentences.

    The model runs in eval mode (dropout disabled) and its previous
    train/eval mode is restored afterwards. Inputs are moved to the device of
    the model's parameters (or to cuda/cpu if it has none).

    :param model: an ``nn.Module`` language model
    :param dataset: indexable/iterable of 1-D LongTensors with ``len()``
    :param verbose: print a progress counter
    :return: the perplexity as a 0-d tensor
    :raises ValueError: if no sentence has at least two tokens, where the
        perplexity is undefined
    """
    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = model.training
    model.eval()
    all_log_probs = []
    try:
        with torch.no_grad():
            num_sents = len(dataset)
            for i, sentence in enumerate(dataset):
                sentence = sentence.to(device)
                log_probs = F.log_softmax(model(sentence), dim=-1)

                # drop the last prediction (nothing to predict) and the first
                # token (never predicted), then pick the true-token log probs
                targets = sentence[1:]
                log_probs_true = log_probs[:-1].gather(-1, targets.unsqueeze(-1)).squeeze(-1)
                all_log_probs.append(log_probs_true)
                if verbose:
                    print(f"PPL: {i+1}/{num_sents}     ", end="\r", flush=True)
    finally:
        model.train(was_training)

    log_ps_collected = torch.cat(all_log_probs) if all_log_probs else torch.empty(0)
    if log_ps_collected.numel() == 0:
        raise ValueError("perplexity is undefined: the dataset contains no predictable tokens")
    return torch.exp(-log_ps_collected.mean())
    
def generate(model, start_text, chars2idx, idx2chars, num_chars=5, mode="sample"):
    """Generates num_chars characters as continuation of start_text.

    mode: "greedy" (always take the most likely next char) or "sample" (draw
    the next char from the model's softmax). "sample" is default.

    At every step the whole text so far is re-encoded and fed to the model
    unbatched, so the model must accept a 1-D index tensor and return
    ``(seq_len, vocab_size)`` logits. Characters outside ``chars2idx`` are
    encoded as ``'_'`` but kept verbatim in the returned text. The model runs
    in eval mode and its previous train/eval mode is restored afterwards.

    :return: start_text followed by the generated characters
    :raises ValueError: if mode is neither "greedy" nor "sample"
    """
    if mode not in ("greedy", "sample"):
        raise ValueError(f'mode must be "greedy" or "sample", got {mode!r}')

    param = next(model.parameters(), None)
    if param is not None:
        device = param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(num_chars):
                # Feed a 1-D (unbatched) sequence. With the LSTM's default
                # batch_first=False, a (1, seq_len) tensor would be read as
                # seq_len independent one-step sequences, so the last position
                # would never see the preceding text.
                text_enc = torch.tensor(encode(start_text, chars2idx), dtype=torch.long, device=device)
                logits = model(text_enc)
                # logits: (seq_len, vocab_size); the last row predicts the next char
                last_char_logits = logits[-1]
                if mode == "greedy":
                    next_idx = torch.argmax(last_char_logits)
                else:
                    next_idx = torch.multinomial(F.softmax(last_char_logits, dim=-1), num_samples=1)
                start_text += decode(next_idx.item(), idx2chars)
    finally:
        model.train(was_training)
    return start_text
        
        
# ----------------------







# RNNs

# Recurrence step of a plain (Elman-style) RNN:
#     h = tanh(W x_t + U h_prev + b)
# NOTE(review): this is a notes fragment, not runnable code. It refers to
# self.W, self.U, self.b, x, t and h, which would have to come from an
# enclosing method that is not visible here. At module level it raises
# NameError.
h = torch.tanh(self.W(x[t]) + self.U(h) + self.b)




# --------------------------------------

class Encoder:

    """Mapping between classes and the corresponding indices.

    Indices are assigned in order of first occurrence, starting at 0;
    repeated classes are ignored.

    >>> classes = ["English", "German", "French"]
    >>> enc = Encoder(classes)
    >>> assert "English" == enc.decode(enc.encode("English"))
    >>> assert "German" == enc.decode(enc.encode("German"))
    >>> assert "French" == enc.decode(enc.encode("French"))
    >>> set(range(3)) == set(enc.encode(cl) for cl in classes)
    True
    >>> for cl in classes:
    ...     ix = enc.encode(cl)
    ...     assert 0 <= ix < enc.size()
    ...     assert cl == enc.decode(ix)
    """

    def __init__(self, classes):
        # dict.fromkeys keeps first-occurrence order and drops duplicates
        unique = list(dict.fromkeys(classes))
        self.class_to_ix = {cl: ix for ix, cl in enumerate(unique)}
        self.ix_to_class = dict(enumerate(unique))

    def size(self):
        """Return the number of distinct classes."""
        return len(self.class_to_ix)

    def encode(self, cl):
        """Return the index of class ``cl``; raises KeyError if unknown."""
        return self.class_to_ix[cl]

    def decode(self, ix):
        """Return the class with index ``ix``; raises KeyError if unknown."""
        return self.ix_to_class[ix]
        
        
# --------------------------

# Development set:

# load the dev data
dev_words, dev_upos = load_ud_data('en_ewt-ud-dev.conllu')

# sanity checks: one tag list per sentence and one tag per word
assert len(dev_words) == len(dev_upos)
assert all(len(words) == len(tags) for words, tags in zip(dev_words, dev_upos))

# Encode the dev tags with the encoder built on the training data (pos_enc is
# defined elsewhere; a tag unseen in training raises KeyError).
# dtype=torch.long keeps the targets as class indices even for an empty
# sentence, where torch.tensor([]) would otherwise default to float.
dev_upos_enc = [
    torch.tensor([pos_enc.encode(tag) for tag in upos_list], dtype=torch.long)
    for upos_list in dev_upos
]
 
 
 
ft_pos_tagger = MLPTagger(ft, pos_enc)
optim = torch.optim.Adam(ft_pos_tagger.parameters(), lr=.01)
criterion = torch.nn.CrossEntropyLoss()

# define the new training loop

num_epochs = 5
num_sents = len(train_words)

for e in range(num_epochs):
    # NOTE(review): assumes MLPTagger is an nn.Module (it exposes .parameters());
    # train()/eval() toggle dropout/batch-norm behaviour if it has any.
    ft_pos_tagger.train()
    train_accs = []
    train_losses = []
    for i, (sent, tags) in enumerate(zip(train_words, train_upos_enc)):
        optim.zero_grad()
        tag_pred = ft_pos_tagger(sent)

        loss = criterion(tag_pred, tags)
        loss.backward()
        optim.step()

        train_accs.append((tag_pred.argmax(dim=1) == tags).float().mean().item())
        train_losses.append(loss.item())
        print(f"Epoch {e}, sent. {i} / {num_sents}", end="\r", flush=True)
    # Per-sentence averages, so train and dev numbers are comparable even
    # though the two sets have different sizes.
    train_loss = sum(train_losses) / len(train_losses)
    train_acc = sum(train_accs) / len(train_accs)

    # report dev set performance
    ft_pos_tagger.eval()
    with torch.no_grad():
        dev_accs = []
        dev_losses = []
        for sent, tags in zip(dev_words, dev_upos_enc):
            tag_pred = ft_pos_tagger(sent)
            dev_losses.append(criterion(tag_pred, tags).item())
            dev_accs.append((tag_pred.argmax(dim=1) == tags).float().mean().item())
    dev_loss = sum(dev_losses) / len(dev_losses)
    dev_acc = sum(dev_accs) / len(dev_accs)

    print(f"Epoch {e}, train loss: {train_loss:.2f}, train acc: {train_acc:.2f} | dev loss: {dev_loss:.2f}, dev acc: {dev_acc:.2f}")