# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
class GradMultiply(torch.autograd.Function):
    """Pass a tensor through unchanged while scaling its gradient.

    The forward pass returns the values of ``x`` as-is; the backward pass
    multiplies the incoming gradient by ``scale``. Invoke it through
    ``GradMultiply.apply(x, scale)``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Return a tensor built from ``x`` and remember ``scale``.

        Args:
            ctx: autograd context object; ``scale`` is stashed on it so that
                ``backward`` can read it.
            x: input tensor.
            scale: factor applied to the gradient in ``backward``.

        Returns:
            A new tensor object constructed from ``x`` via ``x.new(x)``.
        """
        ctx.scale = scale
        # A distinct tensor object is returned instead of ``x`` itself.
        # NOTE(review): ``x.new(x)`` presumably aliases x's storage rather
        # than copying it -- confirm against the installed PyTorch version.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        """Scale the incoming gradient by the factor saved in ``forward``.

        Args:
            ctx: autograd context object carrying ``scale``.
            grad: gradient with respect to the output of ``forward``.

        Returns:
            A pair ``(grad * scale, None)``: one entry per ``forward`` input,
            with ``None`` for ``scale`` because no gradient is produced
            for it.
        """
        return grad * ctx.scale, None