def rmsprop_update(loss, params, grad_sq, lr=1e-3, alpha=0.8, epsilon=1e-8):
  """
  Perform an RMSprop update on a collection of parameters

  Args:
    loss: Tensor
      A scalar tensor containing the loss whose gradient will be computed
    params: Iterable
      Collection of parameters with respect to which we compute gradients
    grad_sq: Iterable
      Moving average of squared gradients, one buffer per parameter
    lr: Float
      Scalar specifying the learning rate or step-size for the update
    alpha: Float
      Moving average parameter
    epsilon: Float
      Small offset added to the denominator for numerical stability

  Returns:
    Nothing
  """
  # Clear gradients, as PyTorch accumulates gradients from successive
  # backward calls
  zero_grad(params)

  # Compute gradients on the given objective
  loss.backward()

  with torch.no_grad():
    for (par, gsq) in zip(params, grad_sq):
      # Update the running estimate of the squared gradient
      gsq.data = alpha * gsq.data + (1 - alpha) * par.grad**2

      # Update parameters, scaling each step by the root of that estimate
      par.data -= lr * (par.grad / (epsilon + gsq.data)**0.5)
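For context, here is a minimal sketch of how `rmsprop_update` might be driven in a training loop. The toy linear-regression data, the parameter setup, and the `zero_grad` helper below are assumptions added for illustration (the helper simply mirrors the call made inside the function); only `rmsprop_update` itself comes from the code above.

import torch

def zero_grad(params):
  # Hypothetical helper: reset accumulated gradients on each parameter
  for par in params:
    if par.grad is not None:
      par.grad.detach_()
      par.grad.zero_()

# Toy linear-regression problem (illustrative data)
torch.manual_seed(0)
X = torch.randn(100, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(100)

# Parameters to optimize
w = torch.zeros(3, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
params = [w, b]

# One moving-average buffer per parameter, initialized to zero
grad_sq = [torch.zeros_like(p) for p in params]

for step in range(200):
  loss = ((X @ w + b - y)**2).mean()
  rmsprop_update(loss, params, grad_sq, lr=1e-2)

Note that `grad_sq` is allocated once outside the loop and updated in place, so the moving average of squared gradients persists across steps, which is what gives RMSprop its adaptive per-parameter step sizes.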