# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Batch norm done in fp32 (for fp16 training)."""
import torch
import torch.nn as nn

class Fp32BatchNorm(nn.Module):
    """Batch norm whose state and computation are kept in fp32.

    Wraps :class:`torch.nn.BatchNorm1d`, or :class:`torch.nn.SyncBatchNorm`
    when ``sync`` is requested and more than one worker is running. On every
    forward pass, the wrapped module's floating-point state is cast back to
    fp32 if it is not already fp32 (for example after ``model.half()``). The
    input is then normalized as fp32, and the result is cast back to the
    input's dtype.

    Args:
        sync: request :class:`torch.nn.SyncBatchNorm`. Falls back to
            ``BatchNorm1d`` when fairseq reports a global world size of 1.
            Note that ``sync`` is the first positional parameter, so pass
            ``num_features`` and other batch-norm options after it or by
            keyword.
        *args, **kwargs: forwarded unchanged to the batch-norm constructor.
    """

    def __init__(self, sync=False, *args, **kwargs):
        super().__init__()

        if sync:
            # Imported lazily so the non-sync path does not need fairseq's
            # distributed utilities.
            from fairseq.distributed import utils

            # With a single worker there is nothing to synchronize across.
            if utils.get_global_world_size() == 1:
                sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        # Effective setting after the world-size fallback; read by forward().
        self.sync = sync

    def forward(self, input):
        """Apply batch norm in fp32 and return the result in ``input``'s dtype.

        Side effect: if the wrapped module's state is not fp32, it is cast
        to fp32 in place and stays that way.
        """
        bn = self.bn

        # running_mean is None when track_running_stats=False, so fall back to
        # the affine weight (None when affine=False). If both are None, the
        # module holds no float state that could need casting.
        reference = bn.running_mean if bn.running_mean is not None else bn.weight
        if reference is not None and reference.dtype != torch.float:
            if self.sync:
                # Cast the running statistics and affine parameters
                # individually; fall back to casting the whole module.
                if bn.running_mean is not None:
                    bn.running_mean = bn.running_mean.float()
                if bn.running_var is not None:
                    bn.running_var = bn.running_var.float()
                if bn.affine:
                    try:
                        bn.weight = bn.weight.float()
                        bn.bias = bn.bias.float()
                    except TypeError:
                        # nn.Module rejects assigning a plain Tensor (what
                        # .float() returns for a non-fp32 Parameter) to a
                        # registered parameter, so cast in place instead.
                        bn.float()
            else:
                bn.float()

        output = bn(input.float())
        return output.type_as(input)