def forward(self, X: TT) -> TT:
    """Linearly transform the row vectors in X.

    Arguments:
        X: a batch (matrix) of input vectors, one vector per row.

    Returns:
        A batch of output vectors (one vector per row), computed as
        ``torch.mm(X, self.M.t())``.

    Raises:
        AssertionError: if X is not 2-dimensional, or if its number of
            columns differs from ``self.isize()``.
    """
    # Explicitly check the input shape so a mismatch is reported with a clear
    # message instead of an IndexError (1-D input) or an opaque torch.mm error.
    # AssertionError is kept for compatibility with existing callers; note that
    # assert statements are stripped when Python runs with -O.
    assert X.dim() == 2, (
        f"expected a 2-D batch of row vectors, got shape {tuple(X.shape)}"
    )
    assert X.shape[1] == self.isize(), (
        f"input dimension mismatch: expected {self.isize()} columns, "
        f"got {X.shape[1]}"
    )
    # NOTE(review): presumably self.M has one row per output dimension, i.e.
    # shape (osize, isize) -- confirm; multiplying by its transpose maps each
    # input row to an output row.
    return torch.mm(X, self.M.t())