Source code for fairseq.modules.character_token_embedder

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Tuple

import torch
import torch.nn.functional as F
from fairseq.data import Dictionary
from torch import nn


# Character id used for padding. It equals the ``padding_idx`` of the character
# embedding table, and in character-input mode a row whose first character is
# this id is treated as a padding token.
CHAR_PAD_IDX = 0
# Character id that marks an end-of-sentence row in character-input mode. The
# character embedding table has 257 rows (ids 0..256), so this id is not a
# valid lookup index; ``forward`` zeroes such rows before embedding them.
CHAR_EOS_IDX = 257


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class CharacterTokenEmbedder(torch.nn.Module):
    """Build token embeddings from the bytes of each token with a char-CNN.

    Every token is represented as a fixed-length row of character ids
    (byte value + 1, with 0 reserved for padding). The ids are embedded, run
    through a bank of 1-D convolutions with max-over-time pooling, optionally
    passed through highway layers, and projected to ``word_embed_dim``.

    Padding tokens embed to zero; end-of-sentence and (in vocabulary mode)
    unknown tokens use two dedicated learned vectors instead of the CNN
    output.

    Args:
        vocab: dictionary used to map word ids to character ids. May be
            ``None`` only when ``char_inputs`` is true.
        filters: ``(kernel_width, out_channels)`` pairs, one per convolution.
        char_embed_dim: size of each character embedding.
        word_embed_dim: size of the produced token embedding.
        highway_layers: number of highway layers; ``0`` disables them.
        max_char_len: number of character slots per token.
        char_inputs: if true, ``forward`` receives character ids directly
            instead of word ids.
    """

    def __init__(
        self,
        vocab: Dictionary,
        filters: List[Tuple[int, int]],
        char_embed_dim: int,
        word_embed_dim: int,
        highway_layers: int,
        max_char_len: int = 50,
        char_inputs: bool = False,
    ):
        super(CharacterTokenEmbedder, self).__init__()

        self.onnx_trace = False
        self.embedding_dim = word_embed_dim
        self.max_char_len = max_char_len

        # 256 byte values shifted by one, plus id 0 for padding.
        self.char_embeddings = nn.Embedding(257, char_embed_dim, padding_idx=0)
        # Row ``eos_idx`` is the end-of-sentence vector, row ``unk_idx`` the
        # unknown-word vector. Values are set in ``reset_parameters``.
        self.symbol_embeddings = nn.Parameter(torch.FloatTensor(2, word_embed_dim))
        self.eos_idx, self.unk_idx = 0, 1
        self.char_inputs = char_inputs

        self.convolutions = nn.ModuleList()
        for width, out_c in filters:
            self.convolutions.append(
                nn.Conv1d(char_embed_dim, out_c, kernel_size=width)
            )

        # Pooled conv outputs are concatenated, so widths add up.
        last_dim = sum(f[1] for f in filters)

        self.highway = Highway(last_dim, highway_layers) if highway_layers > 0 else None

        self.projection = nn.Linear(last_dim, word_embed_dim)

        assert (
            vocab is not None or char_inputs
        ), "vocab must be set if not using char inputs"
        self.vocab = None
        if vocab is not None:
            self.set_vocab(vocab, max_char_len)

        self.reset_parameters()

    def prepare_for_onnx_export_(self):
        """Switch ``forward`` to the ``torch.where`` based code path."""
        self.onnx_trace = True

    def set_vocab(self, vocab, max_char_len):
        """Build the word-id -> character-ids lookup table for ``vocab``.

        Each non-special entry is UTF-8 encoded; byte ``b`` becomes id
        ``b + 1`` and the row is right-padded with 0 up to ``max_char_len``.
        Entries with more bytes than ``max_char_len`` are truncated and
        counted in an info-level log message. The first ``vocab.nspecial``
        entries get an all-zero row.

        Stores the results on ``self.vocab`` and ``self.word_to_char``.
        """
        # Uninitialised storage is fine: every row is assigned below.
        word_to_char = torch.LongTensor(len(vocab), max_char_len)

        truncated = 0
        for i in range(len(vocab)):
            if i < vocab.nspecial:
                char_idxs = [0] * max_char_len
            else:
                chars = vocab[i].encode()
                # +1 for padding
                char_idxs = [c + 1 for c in chars] + [0] * (max_char_len - len(chars))
            if len(char_idxs) > max_char_len:
                truncated += 1
                char_idxs = char_idxs[:max_char_len]
            word_to_char[i] = torch.LongTensor(char_idxs)

        if truncated > 0:
            logger.info(
                "truncated {} words longer than {} characters".format(
                    truncated, max_char_len
                )
            )

        self.vocab = vocab
        self.word_to_char = word_to_char

    @property
    def padding_idx(self):
        """Padding index of the vocabulary, or of a default ``Dictionary``."""
        return Dictionary().pad() if self.vocab is None else self.vocab.pad()

    def reset_parameters(self):
        """(Re-)initialise embeddings and projection; keep the pad row at zero."""
        nn.init.xavier_normal_(self.char_embeddings.weight)
        nn.init.xavier_normal_(self.symbol_embeddings)
        nn.init.xavier_uniform_(self.projection.weight)

        # The Xavier init above overwrote the padding row; zero it again.
        nn.init.constant_(
            self.char_embeddings.weight[self.char_embeddings.padding_idx], 0.0
        )
        nn.init.constant_(self.projection.bias, 0.0)

    def forward(
        self,
        input: torch.Tensor,
    ):
        """Embed a batch of tokens.

        Args:
            input: word ids when the module was built with
                ``char_inputs=False``, otherwise character ids whose
                trailing elements are regrouped into rows of
                ``max_char_len``. The first two dimensions are kept in the
                output (presumably batch and sequence -- confirm with the
                caller). ``input`` is never modified.

        Returns:
            Tensor shaped ``input.size()[:2] + (-1,)`` holding one embedding
            per token.
        """
        if self.char_inputs:
            chars = input.view(-1, self.max_char_len)
            # A token's kind is read from its first character slot.
            pads = chars[:, 0].eq(CHAR_PAD_IDX)
            eos = chars[:, 0].eq(CHAR_EOS_IDX)
            if eos.any():
                # ``chars`` is a view of ``input``. Blank the EOS rows out of
                # place: writing into the view would overwrite the caller's
                # tensor, and a later pass over the same tensor would then
                # classify those rows as padding instead of EOS.
                chars = torch.where(eos.unsqueeze(1), chars.new_zeros(1), chars)
            unk = None
        else:
            flat_words = input.view(-1)
            # ``word_to_char`` is a plain attribute, not a buffer, so the ids
            # are cast to its type for the lookup and the result cast back.
            chars = self.word_to_char[flat_words.type_as(self.word_to_char)].type_as(
                input
            )
            pads = flat_words.eq(self.vocab.pad())
            eos = flat_words.eq(self.vocab.eos())
            unk = flat_words.eq(self.vocab.unk())

        word_embs = self._convolve(chars)
        if self.onnx_trace:
            if pads.any():
                word_embs = torch.where(
                    pads.unsqueeze(1), word_embs.new_zeros(1), word_embs
                )
            if eos.any():
                word_embs = torch.where(
                    eos.unsqueeze(1), self.symbol_embeddings[self.eos_idx], word_embs
                )
            if unk is not None and unk.any():
                word_embs = torch.where(
                    unk.unsqueeze(1), self.symbol_embeddings[self.unk_idx], word_embs
                )
        else:
            # ``word_embs`` is freshly computed, so in-place writes are local.
            if pads.any():
                word_embs[pads] = 0
            if eos.any():
                word_embs[eos] = self.symbol_embeddings[self.eos_idx]
            if unk is not None and unk.any():
                word_embs[unk] = self.symbol_embeddings[self.unk_idx]

        return word_embs.view(input.size()[:2] + (-1,))

    def _convolve(
        self,
        char_idxs: torch.Tensor,
    ):
        """Map rows of character ids to token vectors.

        Embeds the ids, applies every convolution with max-pooling over the
        character axis followed by ReLU, concatenates the pooled features,
        then applies the optional highway layers and the output projection.
        """
        char_embs = self.char_embeddings(char_idxs)
        char_embs = char_embs.transpose(1, 2)  # BTC -> BCT

        conv_result = []
        for conv in self.convolutions:
            x = conv(char_embs)
            x, _ = torch.max(x, -1)
            x = F.relu(x)
            conv_result.append(x)

        x = torch.cat(conv_result, dim=-1)

        if self.highway is not None:
            x = self.highway(x)
        x = self.projection(x)

        return x
class Highway(torch.nn.Module):
    """
    A stack of `Highway layers <https://arxiv.org/abs/1505.00387>`_,
    following the AllenNLP implementation.

    Each layer mixes its input with a ReLU-transformed copy of it, weighted
    by a learned sigmoid gate: ``gate * x + (1 - gate) * relu(transform(x))``.
    With ``num_layers=0`` the module returns its input unchanged.
    """

    def __init__(self, input_dim: int, num_layers: int = 1):
        super().__init__()
        self.input_dim = input_dim
        # A single Linear per layer emits the transform and the gate logits
        # side by side, hence the doubled output width.
        stacked = []
        for _ in range(num_layers):
            stacked.append(nn.Linear(input_dim, 2 * input_dim))
        self.layers = nn.ModuleList(stacked)
        self.activation = nn.ReLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the weights and bias every gate towards carrying."""
        half = self.input_dim
        for linear in self.layers:
            nn.init.xavier_normal_(linear.weight)
            # The first half of the bias belongs to the transform, the second
            # half to the gate. A positive gate bias pushes the sigmoid
            # towards 1, so a fresh layer mostly passes its input through.
            nn.init.zeros_(linear.bias[:half])
            nn.init.ones_(linear.bias[half:])

    def forward(self, x: torch.Tensor):
        """Apply every highway layer in turn and return the final tensor."""
        for linear in self.layers:
            candidate, gate_logits = torch.chunk(linear(x), 2, dim=-1)
            carry = torch.sigmoid(gate_logits)
            x = carry * x + (1 - carry) * self.activation(candidate)
        return x