Source code for fairseq.modules.mean_pool_gating_network

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn.functional as F


class MeanPoolGatingNetwork(torch.nn.Module):
    """A simple mean-pooling gating network for selecting experts.

    This module applies mean pooling over an encoder's output and returns
    responsibilities for each expert. The encoder format is expected to match
    :class:`fairseq.models.transformer.TransformerEncoder`.

    Args:
        embed_dim (int): size of the encoder's channel dimension
        num_experts (int): number of experts to produce responsibilities for
        dropout (float, optional): dropout probability applied after the
            hidden layer; no dropout module is created when ``None``
    """

    def __init__(self, embed_dim, num_experts, dropout=None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_experts = num_experts

        self.fc1 = torch.nn.Linear(embed_dim, embed_dim)
        self.dropout = torch.nn.Dropout(dropout) if dropout is not None else None
        self.fc2 = torch.nn.Linear(embed_dim, num_experts)

    def forward(self, encoder_out):
        """Compute per-expert log-responsibilities from an encoder output.

        Args:
            encoder_out (dict): must contain the keys ``'encoder_out'``
                (a tensor of shape `T x B x C` with ``C == embed_dim``) and
                ``'encoder_padding_mask'`` (a `B x T` mask that is
                true/non-zero at padded positions, or ``None``).

        Returns:
            Tensor: log-probabilities over experts, of shape
            `B x num_experts`, in the same dtype as the gating logits.

        Raises:
            ValueError: if *encoder_out* is not a dict with the expected keys
                or its channel dimension does not equal ``embed_dim``.
        """
        if not (
            isinstance(encoder_out, dict)
            and 'encoder_out' in encoder_out
            and 'encoder_padding_mask' in encoder_out
            and encoder_out['encoder_out'].size(2) == self.embed_dim
        ):
            raise ValueError('Unexpected format for encoder_out')

        # mean pooling over time
        encoder_padding_mask = encoder_out['encoder_padding_mask']  # B x T
        encoder_out = encoder_out['encoder_out'].transpose(0, 1)  # B x T x C
        if encoder_padding_mask is not None:
            # Normalise to bool: byte-mask indexing is deprecated, and
            # ``1 - mask`` is not defined for bool tensors in current PyTorch.
            pad_mask = encoder_padding_mask.to(torch.bool)
            # Out-of-place fill leaves the caller's tensor untouched, so no
            # explicit clone of the transposed view is needed.
            encoder_out = encoder_out.masked_fill(pad_mask.unsqueeze(-1), 0)
            ntokens = (~pad_mask).sum(dim=1, keepdim=True)
            # Clamp so a fully padded row pools to zeros instead of 0/0 = NaN.
            ntokens = ntokens.clamp(min=1)
            x = torch.sum(encoder_out, dim=1) / ntokens.type_as(encoder_out)
        else:
            x = torch.mean(encoder_out, dim=1)

        x = torch.tanh(self.fc1(x))
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.fc2(x)
        # Softmax is evaluated in float32 for numerical stability, then cast
        # back to the dtype of the logits.
        return F.log_softmax(x, dim=-1, dtype=torch.float32).type_as(x)