# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import options, utils
from fairseq.models import (
FairseqEncoder,
FairseqIncrementalDecoder,
FairseqEncoderDecoderModel,
register_model,
register_model_architecture,
)
from fairseq.modules import AdaptiveSoftmax
@register_model('lstm')
class LSTMModel(FairseqEncoderDecoderModel):
    """Encoder-decoder model pairing an :class:`LSTMEncoder` with an :class:`LSTMDecoder`."""

    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-freeze-embed', action='store_true',
                            help='freeze encoder embeddings')
        parser.add_argument('--encoder-hidden-size', type=int, metavar='N',
                            help='encoder hidden size')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='number of encoder layers')
        parser.add_argument('--encoder-bidirectional', action='store_true',
                            help='make all layers of encoder bidirectional')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-freeze-embed', action='store_true',
                            help='freeze decoder embeddings')
        parser.add_argument('--decoder-hidden-size', type=int, metavar='N',
                            help='decoder hidden size')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='number of decoder layers')
        parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
                            help='decoder output embedding dimension')
        parser.add_argument('--decoder-attention', type=str, metavar='BOOL',
                            help='decoder attention')
        parser.add_argument('--adaptive-softmax-cutoff', metavar='EXPR',
                            help='comma separated list of adaptive softmax cutoff points. '
                                 'Must be used with adaptive_loss criterion')
        parser.add_argument('--share-decoder-input-output-embed', default=False,
                            action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', default=False, action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')

        # Granular dropout settings (if not specified these default to --dropout)
        parser.add_argument('--encoder-dropout-in', type=float, metavar='D',
                            help='dropout probability for encoder input embedding')
        parser.add_argument('--encoder-dropout-out', type=float, metavar='D',
                            help='dropout probability for encoder output')
        parser.add_argument('--decoder-dropout-in', type=float, metavar='D',
                            help='dropout probability for decoder input embedding')
        parser.add_argument('--decoder-dropout-out', type=float, metavar='D',
                            help='dropout probability for decoder output')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Args:
            args: parsed options; any missing hyper-parameters are filled in
                by ``base_architecture`` before use.
            task: supplies ``source_dictionary`` and ``target_dictionary``.

        Returns:
            LSTMModel: the assembled encoder-decoder model.

        Raises:
            ValueError: when the layer counts differ, or when the embedding
                sharing options are incompatible with the dictionaries or
                embedding dimensions.
        """
        # make sure that all args are properly defaulted (in case there are any new ones)
        base_architecture(args)

        if args.encoder_layers != args.decoder_layers:
            raise ValueError('--encoder-layers must match --decoder-layers')

        def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
            # Fresh randomly-initialised embedding, then overwrite the rows
            # for tokens that appear in the embedding file.
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
            embed_dict = utils.parse_embedding(embed_path)
            utils.print_embed_overlap(embed_dict, dictionary)
            return utils.load_embedding(embed_dict, dictionary, embed_tokens)

        if args.encoder_embed_path:
            pretrained_encoder_embed = load_pretrained_embedding_from_file(
                args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
        else:
            num_embeddings = len(task.source_dictionary)
            pretrained_encoder_embed = Embedding(
                num_embeddings, args.encoder_embed_dim, task.source_dictionary.pad()
            )

        if args.share_all_embeddings:
            # double check all parameters combinations are valid
            if task.source_dictionary != task.target_dictionary:
                raise ValueError('--share-all-embeddings requires a joint dictionary')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise ValueError(
                    '--share-all-embed not compatible with --decoder-embed-path'
                )
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    '--share-all-embeddings requires --encoder-embed-dim to '
                    'match --decoder-embed-dim'
                )
            pretrained_decoder_embed = pretrained_encoder_embed
            args.share_decoder_input_output_embed = True
        else:
            # separate decoder input embeddings; None lets the decoder build its own
            pretrained_decoder_embed = None
            if args.decoder_embed_path:
                pretrained_decoder_embed = load_pretrained_embedding_from_file(
                    args.decoder_embed_path,
                    task.target_dictionary,
                    args.decoder_embed_dim
                )

        # one last double check of parameter combinations
        if args.share_decoder_input_output_embed and (
                args.decoder_embed_dim != args.decoder_out_embed_dim):
            raise ValueError(
                '--share-decoder-input-output-embeddings requires '
                '--decoder-embed-dim to match --decoder-out-embed-dim'
            )

        encoder = LSTMEncoder(
            dictionary=task.source_dictionary,
            embed_dim=args.encoder_embed_dim,
            hidden_size=args.encoder_hidden_size,
            num_layers=args.encoder_layers,
            dropout_in=args.encoder_dropout_in,
            dropout_out=args.encoder_dropout_out,
            bidirectional=args.encoder_bidirectional,
            pretrained_embed=pretrained_encoder_embed,
        )
        decoder = LSTMDecoder(
            dictionary=task.target_dictionary,
            embed_dim=args.decoder_embed_dim,
            hidden_size=args.decoder_hidden_size,
            out_embed_dim=args.decoder_out_embed_dim,
            num_layers=args.decoder_layers,
            dropout_in=args.decoder_dropout_in,
            dropout_out=args.decoder_dropout_out,
            attention=options.eval_bool(args.decoder_attention),
            encoder_output_units=encoder.output_units,
            pretrained_embed=pretrained_decoder_embed,
            share_input_output_embed=args.share_decoder_input_output_embed,
            adaptive_softmax_cutoff=(
                options.eval_str_list(args.adaptive_softmax_cutoff, type=int)
                if args.criterion == 'adaptive_loss' else None
            ),
        )

        # Freeze via the modules' own embeddings, after construction. When no
        # decoder embedding was supplied, ``pretrained_decoder_embed`` is None
        # and the decoder creates its embedding internally. Freezing
        # ``pretrained_decoder_embed`` directly would crash with AttributeError.
        if args.encoder_freeze_embed:
            encoder.embed_tokens.weight.requires_grad = False
        if args.decoder_freeze_embed:
            decoder.embed_tokens.weight.requires_grad = False

        return cls(encoder, decoder)
class LSTMEncoder(FairseqEncoder):
    """LSTM encoder.

    Args:
        dictionary: source vocabulary; its length and ``pad()`` index size the
            token embedding.
        embed_dim (int): token embedding dimension.
        hidden_size (int): LSTM hidden size per direction.
        num_layers (int): number of stacked LSTM layers.
        dropout_in (float): dropout applied to the token embeddings.
        dropout_out (float): dropout applied to the LSTM outputs. It is also
            used between LSTM layers when ``num_layers > 1``.
        bidirectional (bool): use a bidirectional LSTM. This doubles
            ``output_units``.
        left_pad (bool): whether ``src_tokens`` arrive left-padded. If so, they
            are converted to right padding before packing.
        pretrained_embed: optional embedding module used instead of a freshly
            initialised one.
        padding_value (float): value written into padded positions of the
            LSTM output.
    """

    def __init__(
        self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
        dropout_in=0.1, dropout_out=0.1, bidirectional=False,
        left_pad=True, pretrained_embed=None, padding_value=0.,
    ):
        super().__init__(dictionary)
        self.num_layers = num_layers
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.bidirectional = bidirectional
        self.hidden_size = hidden_size

        num_embeddings = len(dictionary)
        self.padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, self.padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.lstm = LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # inter-layer dropout only has an effect for stacked LSTMs
            dropout=self.dropout_out if num_layers > 1 else 0.,
            bidirectional=bidirectional,
        )
        self.left_pad = left_pad
        self.padding_value = padding_value

        # Width of each encoder output vector; forward/backward directions
        # are concatenated when bidirectional.
        self.output_units = hidden_size
        if bidirectional:
            self.output_units *= 2

    def forward(self, src_tokens, src_lengths):
        """Encode a batch of source sentences.

        Args:
            src_tokens (LongTensor): ``bsz x seqlen`` token ids.
            src_lengths (LongTensor): unpadded length of each sentence.
                NOTE: with ``pack_padded_sequence``'s default settings the
                lengths must be sorted in decreasing order.

        Returns:
            dict:
                - ``'encoder_out'``: tuple ``(x, final_hiddens, final_cells)``.
                  ``x`` is ``seqlen x bsz x output_units``. The final states
                  are ``num_layers x bsz x output_units``; when bidirectional,
                  the two directions are concatenated per layer.
                - ``'encoder_padding_mask'``: ``seqlen x bsz`` mask that is
                  true at padding positions, or ``None`` if the batch has no
                  padding.
        """
        if self.left_pad:
            # nn.utils.rnn.pack_padded_sequence requires right-padding;
            # convert left-padding to right-padding
            src_tokens = utils.convert_padding_direction(
                src_tokens,
                self.padding_idx,
                left_to_right=True,
            )

        bsz, seqlen = src_tokens.size()

        # embed tokens
        x = self.embed_tokens(src_tokens)
        x = F.dropout(x, p=self.dropout_in, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # pack embedded source tokens into a PackedSequence
        packed_x = nn.utils.rnn.pack_padded_sequence(x, src_lengths.data.tolist())

        # apply LSTM, starting from zero initial states
        if self.bidirectional:
            state_size = 2 * self.num_layers, bsz, self.hidden_size
        else:
            state_size = self.num_layers, bsz, self.hidden_size
        h0 = x.new_zeros(*state_size)
        c0 = x.new_zeros(*state_size)
        packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))

        # unpack outputs and apply dropout
        x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value)
        x = F.dropout(x, p=self.dropout_out, training=self.training)
        assert list(x.size()) == [seqlen, bsz, self.output_units]

        if self.bidirectional:
            def combine_bidir(outs):
                # (num_layers * 2) x bsz x H -> num_layers x bsz x (2 * H):
                # concatenate the two directions of each layer
                out = outs.view(self.num_layers, 2, bsz, -1).transpose(1, 2).contiguous()
                return out.view(self.num_layers, bsz, -1)

            final_hiddens = combine_bidir(final_hiddens)
            final_cells = combine_bidir(final_cells)

        # computed on the (now right-padded) tokens, transposed to T x B
        encoder_padding_mask = src_tokens.eq(self.padding_idx).t()

        return {
            'encoder_out': (x, final_hiddens, final_cells),
            'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
        }

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder the batch dimension of ``encoder_out`` to follow ``new_order``.

        Every tensor in ``'encoder_out'`` and the padding mask keeps its batch
        on dim 1, so all of them are re-indexed along that dim. The dict is
        updated in place and also returned.
        """
        encoder_out['encoder_out'] = tuple(
            eo.index_select(1, new_order)
            for eo in encoder_out['encoder_out']
        )
        if encoder_out['encoder_padding_mask'] is not None:
            encoder_out['encoder_padding_mask'] = \
                encoder_out['encoder_padding_mask'].index_select(1, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return int(1e5)  # an arbitrary large number
class AttentionLayer(nn.Module):
    """Attention over encoder outputs for one decoder time step.

    The decoder state is linearly projected into the encoder-output space.
    Each source position is scored by a dot product with that projection, and
    the scores are softmax-normalised over source positions. The resulting
    context vector is concatenated with the decoder state and passed through a
    ``tanh``-activated linear layer.

    Args:
        input_embed_dim (int): size of the decoder state given to :meth:`forward`.
        source_embed_dim (int): size of each encoder output vector.
        output_embed_dim (int): size of the returned attentional vector.
        bias (bool): whether both projections carry a bias term.
    """

    def __init__(self, input_embed_dim, source_embed_dim, output_embed_dim, bias=False):
        super().__init__()
        # decoder state -> encoder-output space, used to score source positions
        self.input_proj = Linear(input_embed_dim, source_embed_dim, bias=bias)
        # [context; decoder state] -> attentional output
        combined_dim = input_embed_dim + source_embed_dim
        self.output_proj = Linear(combined_dim, output_embed_dim, bias=bias)

    def forward(self, input, source_hids, encoder_padding_mask):
        """Attend over ``source_hids`` using ``input`` as the query.

        Args:
            input: ``bsz x input_embed_dim`` decoder state.
            source_hids: ``srclen x bsz x source_embed_dim`` encoder outputs.
            encoder_padding_mask: ``srclen x bsz`` mask of positions to
                ignore, or ``None``.

        Returns:
            tuple: ``(out, attn)``. ``out`` is ``bsz x output_embed_dim`` and
            ``attn`` holds the ``srclen x bsz`` softmax weights.
        """
        query = self.input_proj(input)  # bsz x source_embed_dim

        # one score per (source position, batch element): srclen x bsz
        scores = torch.sum(source_hids * query.unsqueeze(0), dim=2)

        if encoder_padding_mask is not None:
            # Masked positions become -inf so the softmax gives them zero
            # weight. The fill is done in fp32 and cast back for FP16 support.
            filled = scores.float().masked_fill_(encoder_padding_mask, float('-inf'))
            scores = filled.type_as(scores)

        weights = F.softmax(scores, dim=0)  # normalise over source positions

        # attention-weighted sum of encoder outputs: bsz x source_embed_dim
        context = torch.sum(weights.unsqueeze(2) * source_hids, dim=0)

        combined = torch.cat((context, input), dim=1)
        return torch.tanh(self.output_proj(combined)), weights
class LSTMDecoder(FairseqIncrementalDecoder):
    """LSTM decoder.

    Args:
        dictionary: target vocabulary; its length sizes the embedding and
            output projection.
        embed_dim (int): token embedding dimension.
        hidden_size (int): hidden size of every ``LSTMCell`` layer.
        out_embed_dim (int): feature size fed to the output projection. When
            it differs from ``hidden_size``, ``additional_fc`` is created.
        num_layers (int): number of stacked ``LSTMCell`` layers.
        dropout_in (float): stored as ``self.dropout_in``.
        dropout_out (float): stored as ``self.dropout_out``. Also passed as
            the adaptive softmax dropout.
        attention (bool): whether to build an :class:`AttentionLayer`.
        encoder_output_units (int): width of the encoder outputs.
        pretrained_embed: optional embedding module used instead of a fresh one.
        share_input_output_embed (bool): reuse the input embedding matrix as
            the output projection.
        adaptive_softmax_cutoff (list of int, optional): if given, an
            ``AdaptiveSoftmax`` is built and no output projection is created.
    """

    def __init__(
        self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
        num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
        encoder_output_units=512, pretrained_embed=None,
        share_input_output_embed=False, adaptive_softmax_cutoff=None,
    ):
        super().__init__(dictionary)
        self.dropout_in = dropout_in
        self.dropout_out = dropout_out
        self.hidden_size = hidden_size
        self.share_input_output_embed = share_input_output_embed
        # toggled by make_generation_fast_; presumably read by extract_features
        self.need_attn = True
        self.adaptive_softmax = None

        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        if pretrained_embed is None:
            self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
        else:
            self.embed_tokens = pretrained_embed

        self.encoder_output_units = encoder_output_units
        if encoder_output_units != hidden_size:
            # NOTE(review): presumably maps encoder final states to the
            # decoder's hidden size; the consumer is not visible here.
            self.encoder_hidden_proj = Linear(encoder_output_units, hidden_size)
            self.encoder_cell_proj = Linear(encoder_output_units, hidden_size)
        else:
            self.encoder_hidden_proj = self.encoder_cell_proj = None

        self.layers = nn.ModuleList([
            LSTMCell(
                # layer 0 takes hidden_size + embed_dim inputs (presumably the
                # token embedding concatenated with an input-feeding vector);
                # upper layers take the previous layer's hidden state
                input_size=hidden_size + embed_dim if layer == 0 else hidden_size,
                hidden_size=hidden_size,
            )
            for layer in range(num_layers)
        ])

        if attention:
            # TODO make bias configurable
            self.attention = AttentionLayer(hidden_size, encoder_output_units, hidden_size, bias=False)
        else:
            self.attention = None

        # only defined when the sizes differ
        if hidden_size != out_embed_dim:
            self.additional_fc = Linear(hidden_size, out_embed_dim)

        if adaptive_softmax_cutoff is not None:
            # setting adaptive_softmax dropout to dropout_out for now but can be redefined
            self.adaptive_softmax = AdaptiveSoftmax(num_embeddings, hidden_size, adaptive_softmax_cutoff,
                                                    dropout=dropout_out)
        elif not self.share_input_output_embed:
            self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """Run the decoder and project the resulting features.

        Returns:
            tuple: ``(output, attn_scores)``. ``output`` is
            :meth:`output_layer` applied to the features from
            ``extract_features``. ``attn_scores`` is passed through unchanged.
        """
        # NOTE(review): ``extract_features`` is not defined in this class as
        # shown here; confirm it was not lost from this file.
        x, attn_scores = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        return self.output_layer(x), attn_scores

    def output_layer(self, x):
        """Project features to the vocabulary size.

        With an adaptive softmax configured, ``x`` is returned unchanged.
        Presumably the adaptive-softmax criterion does the projection in that
        case.
        """
        if self.adaptive_softmax is None:
            if self.share_input_output_embed:
                x = F.linear(x, self.embed_tokens.weight)
            else:
                x = self.fc_out(x)
        return x

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder the cached decoder state along dim 0 to follow ``new_order``.

        ``cached_state`` is a tuple whose entries are tensors or lists of
        tensors. Each tensor is re-indexed along dim 0.
        """
        super().reorder_incremental_state(incremental_state, new_order)
        cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
        if cached_state is None:
            return

        def reorder_state(state):
            if isinstance(state, list):
                return [reorder_state(state_i) for state_i in state]
            return state.index_select(0, new_order)

        new_state = tuple(map(reorder_state, cached_state))
        utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        return int(1e5)  # an arbitrary large number

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether attention weights are needed during generation."""
        self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx=None):
    """Build an ``nn.Embedding`` with weights drawn from U(-0.1, 0.1).

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of each embedding vector.
        padding_idx (int, optional): row that is zeroed after initialisation.
            ``None`` means there is no padding row.

    Returns:
        nn.Embedding: the initialised embedding module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.uniform_(m.weight, -0.1, 0.1)
    if padding_idx is not None:
        # Guarded because ``m.weight[None]`` is a view of the *whole* matrix;
        # zeroing it unconditionally would wipe every embedding.
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
def LSTM(input_size, hidden_size, **kwargs):
    """Build an ``nn.LSTM`` whose weight and bias tensors are drawn from U(-0.1, 0.1).

    Extra keyword arguments (``num_layers``, ``dropout``, ``bidirectional``,
    ...) are forwarded to ``nn.LSTM`` unchanged.
    """
    lstm = nn.LSTM(input_size, hidden_size, **kwargs)
    for param_name, tensor in lstm.named_parameters():
        if not ('weight' in param_name or 'bias' in param_name):
            continue
        tensor.data.uniform_(-0.1, 0.1)
    return lstm
def LSTMCell(input_size, hidden_size, **kwargs):
    """Build an ``nn.LSTMCell`` whose weight and bias tensors are drawn from U(-0.1, 0.1).

    Extra keyword arguments are forwarded to ``nn.LSTMCell`` unchanged.
    """
    cell = nn.LSTMCell(input_size, hidden_size, **kwargs)
    # Same parameter order as named_parameters(), so RNG consumption matches.
    targets = [
        tensor for param_name, tensor in cell.named_parameters()
        if 'weight' in param_name or 'bias' in param_name
    ]
    for tensor in targets:
        tensor.data.uniform_(-0.1, 0.1)
    return cell
def Linear(in_features, out_features, bias=True, dropout=0):
    """Linear layer (input: N x T x C) with parameters drawn from U(-0.1, 0.1).

    Args:
        in_features (int): input feature size.
        out_features (int): output feature size.
        bias (bool): whether to add a bias term.
        dropout (float): accepted for call-site compatibility but not used by
            this function.

    Returns:
        nn.Linear: the initialised layer.
    """
    layer = nn.Linear(in_features, out_features, bias=bias)
    # weight first, then bias: same draw order as before
    to_init = [layer.weight]
    if layer.bias is not None:
        to_init.append(layer.bias)
    for tensor in to_init:
        tensor.data.uniform_(-0.1, 0.1)
    return layer
@register_model_architecture('lstm', 'lstm')
def base_architecture(args):
    """Default hyper-parameters for the base ``lstm`` architecture.

    Attributes already present on ``args`` are kept; missing ones are filled
    in here. Statement order matters: some defaults are derived from values
    set earlier in this function. Hidden sizes follow the embedding dimensions,
    and the granular dropouts follow ``args.dropout``.
    """
    def fill(name, fallback):
        # keep a value the user already supplied, otherwise use ``fallback``
        setattr(args, name, getattr(args, name, fallback))

    fill('dropout', 0.1)

    # encoder
    fill('encoder_embed_dim', 512)
    fill('encoder_embed_path', None)
    fill('encoder_freeze_embed', False)
    fill('encoder_hidden_size', args.encoder_embed_dim)
    fill('encoder_layers', 1)
    fill('encoder_bidirectional', False)
    fill('encoder_dropout_in', args.dropout)
    fill('encoder_dropout_out', args.dropout)

    # decoder
    fill('decoder_embed_dim', 512)
    fill('decoder_embed_path', None)
    fill('decoder_freeze_embed', False)
    fill('decoder_hidden_size', args.decoder_embed_dim)
    fill('decoder_layers', 1)
    fill('decoder_out_embed_dim', 512)
    fill('decoder_attention', '1')
    fill('decoder_dropout_in', args.dropout)
    fill('decoder_dropout_out', args.dropout)

    # embedding sharing and output layer
    fill('share_decoder_input_output_embed', False)
    fill('share_all_embeddings', False)
    fill('adaptive_softmax_cutoff', '10000,50000,200000')
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
    """``lstm_wiseman_iwslt_de_en`` preset.

    256-dimensional embeddings, no input dropout, and no encoder output
    dropout. The decoder output dropout follows ``args.dropout``. Everything
    else falls back to ``base_architecture``.
    """
    # set first: the decoder_dropout_out default below reads it
    args.dropout = getattr(args, 'dropout', 0.1)
    preset = (
        ('encoder_embed_dim', 256),
        ('encoder_dropout_in', 0),
        ('encoder_dropout_out', 0),
        ('decoder_embed_dim', 256),
        ('decoder_out_embed_dim', 256),
        ('decoder_dropout_in', 0),
        ('decoder_dropout_out', args.dropout),
    )
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
    """``lstm_luong_wmt_en_de`` preset.

    Four-layer encoder and decoder with 1000-dimensional embeddings and no
    output dropout. Everything else falls back to ``base_architecture``.
    """
    preset = (
        ('encoder_embed_dim', 1000),
        ('encoder_layers', 4),
        ('encoder_dropout_out', 0),
        ('decoder_embed_dim', 1000),
        ('decoder_layers', 4),
        ('decoder_out_embed_dim', 1000),
        ('decoder_dropout_out', 0),
    )
    for name, fallback in preset:
        setattr(args, name, getattr(args, name, fallback))
    base_architecture(args)