Source code for fairseq.modules.fp32_batch_norm

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Batch norm computed in fp32 (for fp16 training).
"""
import torch
import torch.nn as nn


class Fp32BatchNorm(nn.Module):
    def __init__(self, sync=False, *args, **kwargs):
        super().__init__()

        if sync:
            from fairseq.distributed import utils

            if utils.get_global_world_size() == 1:
                sync = False

        if sync:
            self.bn = nn.SyncBatchNorm(*args, **kwargs)
        else:
            self.bn = nn.BatchNorm1d(*args, **kwargs)

        self.sync = sync
    def forward(self, input):
        if self.bn.running_mean.dtype != torch.float:
            if self.sync:
                # convert the running statistics (buffers) to fp32 in place
                self.bn.running_mean = self.bn.running_mean.float()
                self.bn.running_var = self.bn.running_var.float()
                if self.bn.affine:
                    try:
                        self.bn.weight = self.bn.weight.float()
                        self.bn.bias = self.bn.bias.float()
                    except TypeError:
                        # assigning a plain Tensor to a Parameter attribute
                        # raises TypeError; fall back to casting the module
                        self.bn.float()
            else:
                self.bn.float()

        # compute batch norm statistics in fp32, then cast the result
        # back to the input dtype (e.g. fp16)
        output = self.bn(input.float())
        return output.type_as(input)
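
A minimal usage sketch, assuming an arbitrary feature size and input shape (neither appears in the original source): batch norm statistics are computed in fp32 over fp16 activations, and the output is cast back to the input dtype.

import torch

from fairseq.modules.fp32_batch_norm import Fp32BatchNorm

bn = Fp32BatchNorm(sync=False, num_features=256)    # wraps nn.BatchNorm1d(256)
x = torch.randn(8, 256, 100, dtype=torch.float16)   # (batch, channels, length)
y = bn(x)                                           # normalization runs in fp32 internally
assert y.dtype == torch.float16                     # output cast back to the input dtype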