# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn.functional as F
from fairseq import utils
from .conv_tbc import ConvTBC
class LinearizedConvolution(ConvTBC):
    """An optimized version of nn.Conv1d.

    At training time, this module uses ConvTBC, which is an optimized version
    of Conv1d. At inference time, it optimizes incremental generation (i.e.,
    one time step at a time) by replacing the convolutions with linear layers.
    Note that the input order changes from training to inference.
    """

    def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
        """Build the underlying ConvTBC and set up the linearized-weight cache.

        All arguments are forwarded unchanged to ``ConvTBC.__init__``.
        """
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        # Lazily-built 2-D view of ``self.weight`` used by the incremental
        # (linear layer) path; ``None`` means "needs to be rebuilt".
        self._linearized_weight = None
        # Drop the cache whenever a backward pass runs through this module so
        # the next incremental step rebuilds it from the current weights.
        self.register_backward_hook(self._clear_linearized_weight)

    def forward(self, input, incremental_state=None):
        """Run the convolution, either over a full sequence or one step.

        Args:
            input: the input tensor (layout given under "Input" below).
            incremental_state: Used to buffer signal; if not None, then input
                is expected to contain a single frame. If the input order
                changes between time steps, call reorder_incremental_state.

        Input:
            Time x Batch x Channel during training
            Batch x Time x Channel during inference

        Returns:
            Without ``incremental_state``: the ConvTBC output with the last
            ``padding[0]`` time steps removed (when ``kernel_size[0] > 1`` and
            ``padding[0] > 0``).
            With ``incremental_state``: a tensor viewed as ``bsz x 1 x -1``
            holding the output for the newest frame only.
        """
        if incremental_state is None:
            output = super().forward(input)
            if self.kernel_size[0] > 1 and self.padding[0] > 0:
                # remove future timesteps added by padding
                output = output[:-self.padding[0], :, :]
            return output

        # Incremental path: the convolution over the last ``kw`` frames is
        # computed as a single linear layer over the flattened buffer.
        weight = self._get_linearized_weight()
        kw = self.kernel_size[0]

        bsz = input.size(0)  # input: bsz x len x dim
        if kw > 1:
            input = input.data
            input_buffer = self._get_input_buffer(incremental_state)
            if input_buffer is None:
                # First step: start from an all-zero history of kw frames.
                input_buffer = input.new_zeros(bsz, kw, input.size(2))
                self._set_input_buffer(incremental_state, input_buffer)
            else:
                # shift buffer (clone: source and destination overlap)
                input_buffer[:, :-1, :] = input_buffer[:, 1:, :].clone()
            # append next input (only the newest frame of ``input`` is used)
            input_buffer[:, -1, :] = input[:, -1, :]
            input = input_buffer
        with torch.no_grad():
            output = F.linear(input.view(bsz, -1), weight, self.bias)
        return output.view(bsz, 1, -1)

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the buffered frames along the batch dimension.

        Args:
            incremental_state: state container holding this module's buffer.
            new_order: index tensor passed to ``index_select`` on dim 0.
        """
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            input_buffer = input_buffer.index_select(0, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state):
        """Fetch this module's ``'input_buffer'`` entry from the state."""
        return utils.get_incremental_state(self, incremental_state, 'input_buffer')

    def _set_input_buffer(self, incremental_state, new_buffer):
        """Store ``new_buffer`` as this module's ``'input_buffer'`` entry."""
        return utils.set_incremental_state(self, incremental_state, 'input_buffer', new_buffer)

    def _get_linearized_weight(self):
        """Return ``self.weight`` reshaped to ``(out_channels, kw * in_channels)``.

        The result is cached on the module. The cache is rebuilt when it is
        empty or when its device/dtype no longer matches ``self.weight``: the
        cache is a plain attribute, so module-level moves or casts do not
        carry it along and a stale copy would make ``F.linear`` fail.
        """
        cached = self._linearized_weight
        if (
            cached is None
            or cached.device != self.weight.device
            or cached.dtype != self.weight.dtype
        ):
            kw = self.kernel_size[0]
            # Permute the weight to (out_channels, kw, in_channels) so that
            # flattening the last two dims lines up with the flattened
            # (kw, channel) input buffer.
            weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous()
            assert weight.size() == (self.out_channels, kw, self.in_channels)
            cached = weight.view(self.out_channels, -1)
            self._linearized_weight = cached
        return cached

    def _clear_linearized_weight(self, *args):
        """Invalidate the cached linearized weight (hook args are ignored)."""
        self._linearized_weight = None