def forward(self, X: TT) -> TT:
    """Linearly transform the row vectors in X.

    Computes ``X @ M^T``, i.e. applies the matrix ``self.M`` to every row
    of ``X``.

    Arguments:
        X: a batch (matrix) of input vectors, one vector per row; must be
            2-dimensional with ``self.isize()`` columns.

    Returns:
        A batch of output vectors (one vector per row), of shape
        ``(X.shape[0], self.M.shape[0])``.

    Raises:
        ValueError: if ``X`` is not 2-dimensional, or if its number of
            columns does not match ``self.isize()``.
    """
    # Check the dimensions explicitly with a real exception rather than an
    # `assert`, which is stripped when Python runs with -O.
    if X.dim() != 2:
        raise ValueError(
            f"expected a 2-D batch of row vectors, got {X.dim()}-D input"
        )
    expected = self.isize()
    if X.shape[1] != expected:
        raise ValueError(
            f"expected {expected} columns per input vector, got {X.shape[1]}"
        )
    # M.t() has shape (isize, M.shape[0]), so the product maps each input
    # row to one output row.
    return torch.mm(X, self.M.t())