Tutorial: Classifying Names with a Character-Level RNN

In this tutorial we will extend fairseq to support classification tasks. In particular we will re-implement the PyTorch tutorial for Classifying Names with a Character-Level RNN in fairseq. We recommend skimming that tutorial before beginning this one.

This tutorial covers:

  1. Preprocessing the data to create dictionaries.
  2. Registering a new Model that encodes an input sentence with a simple RNN and predicts the output label.
  3. Registering a new Task that loads our dictionaries and dataset.
  4. Training the Model using the existing command-line tools.
  5. Writing an evaluation script that imports fairseq and allows us to interactively evaluate our model on new inputs.

1. Preprocessing the data

The original tutorial provides raw data, but we’ll work with a modified version of the data that is already tokenized into characters and split into separate train, valid and test sets.

Download and extract the data from here: tutorial_names.tar.gz

Once extracted, let’s preprocess the data using the fairseq-preprocess command-line tool to create the dictionaries. While this tool is primarily intended for sequence-to-sequence problems, we’re able to reuse it here by treating the label as a “target” sequence of length 1. We’ll also output the preprocessed files in “raw” format using the --dataset-impl option to enhance readability:

> fairseq-preprocess \
  --trainpref names/train --validpref names/valid --testpref names/test \
  --source-lang input --target-lang label \
  --destdir names-bin --dataset-impl raw

After running the above command you should see a new directory, names-bin/, containing the dictionaries for inputs and labels.
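
To make the role of the dictionaries more concrete, the sketch below counts the space-separated symbols in a few character-tokenized lines and prints them one `symbol count` pair per line, which is roughly the information a dictionary file holds. It is a plain-Python illustration rather than fairseq's own implementation; the helper name `build_char_counts` and the sample names are made up for this example.

```python
from collections import Counter


def build_char_counts(lines):
    """Count how often each space-separated token occurs across the lines."""
    counts = Counter()
    for line in lines:
        counts.update(line.split())
    return counts


# Hypothetical sample in the tutorial's format: characters separated by spaces.
sample_inputs = ["S a t o s h i", "S i n b a d"]
counts = build_char_counts(sample_inputs)

# Print in "symbol count" form, most frequent first (ties broken alphabetically).
for symbol, count in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
    print(symbol, count)
```

Note that upper- and lower-case characters are counted as distinct symbols here, so "S" and "s" get separate entries.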

2. Registering a new Model

Next we’ll register a new model in fairseq that will encode an input sentence with a simple RNN and predict the output label. Compared to the original PyTorch tutorial, our version will also work with batches of data and GPU Tensors.

First let’s copy the simple RNN module implemented in the PyTorch tutorial. Create a new file named fairseq/models/rnn_classifier.py with the following contents:

import torch
import torch.nn as nn

class RNN(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()

        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.i2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.i2o(combined)
        output = self.softmax(output)
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)

We must also register this model with fairseq using the register_model() function decorator. Once the model is registered we’ll be able to use it with the existing Command-line Tools.
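
If the decorator-based registration pattern is unfamiliar, the following toy sketch shows the general idea: a decorator stores the class in a dictionary under a string name so that it can later be looked up by that name. This is only an illustration of the pattern, not fairseq's actual code; `TOY_MODEL_REGISTRY`, `register_toy_model` and `ToyClassifier` are hypothetical names.

```python
TOY_MODEL_REGISTRY = {}  # hypothetical stand-in for a library's internal registry


def register_toy_model(name):
    """Return a class decorator that records the class under the given name."""
    def decorator(cls):
        if name in TOY_MODEL_REGISTRY:
            raise ValueError("duplicate model name: {}".format(name))
        TOY_MODEL_REGISTRY[name] = cls
        return cls  # the class itself is returned unchanged
    return decorator


@register_toy_model("toy_classifier")
class ToyClassifier:
    pass


# A command-line tool could now look the class up by its string name.
model_cls = TOY_MODEL_REGISTRY["toy_classifier"]
print(model_cls.__name__)
```

Because the decorator runs when the class is defined, the module containing the class has to be imported before the name can be resolved.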

All registered models must implement the BaseFairseqModel interface, so we’ll create a small wrapper class in the same file and register it in fairseq with the name 'rnn_classifier':

from fairseq.models import BaseFairseqModel, register_model

# Note: the register_model "decorator" should immediately precede the
# definition of the Model class.

@register_model('rnn_classifier')
class FairseqRNNClassifier(BaseFairseqModel):

    @staticmethod
    def add_args(parser):
        # Models can override this method to add new command-line arguments.
        # Here we'll add a new command-line argument to configure the
        # dimensionality of the hidden state.
        parser.add_argument(
            '--hidden-dim', type=int, metavar='N',
            help='dimensionality of the hidden state',
        )

    @classmethod
    def build_model(cls, args, task):
        # Fairseq initializes models by calling the ``build_model()``
        # function. This provides more flexibility, since the returned model
        # instance can be of a different type than the one that was called.
        # In this case we'll just return a FairseqRNNClassifier instance.

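        # Initialize our RNN module
        rnn = RNN(
            # We'll define the Task in the next section, but for now just
            # notice that the task holds the dictionaries for the "source"
            # (i.e., the input sentence) and "target" (i.e., the label).
            input_size=len(task.source_dictionary),
            hidden_size=args.hidden_dim,
            output_size=len(task.target_dictionary),
        )

        # Return the wrapped version of the module
        return FairseqRNNClassifier(
            rnn=rnn,
            input_vocab=task.source_dictionary,
        )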

    def __init__(self, rnn, input_vocab):
        super(FairseqRNNClassifier, self).__init__()

        self.rnn = rnn
        self.input_vocab = input_vocab

        # The RNN module in the tutorial expects one-hot inputs, so we can
        # precompute the identity matrix to help convert from indices to
        # one-hot vectors. We register it as a buffer so that it is moved to
        # the GPU when ``cuda()`` is called.
        self.register_buffer('one_hot_inputs', torch.eye(len(input_vocab)))

    def forward(self, src_tokens, src_lengths):
        # The inputs to the ``forward()`` function are determined by the
        # Task, and in particular the ``'net_input'`` key in each
        # mini-batch. We'll define the Task in the next section, but for
        # now just know that *src_tokens* has shape `(batch, src_len)` and
        # *src_lengths* has shape `(batch)`.
        bsz, max_src_len = src_tokens.size()

        # Initialize the RNN hidden state. Compared to the original PyTorch
        # tutorial we'll also handle batched inputs and work on the GPU.
        hidden = self.rnn.initHidden()
        hidden = hidden.repeat(bsz, 1)  # expand for batched inputs
        hidden = hidden.to(src_tokens.device)  # move to GPU

        for i in range(max_src_len):
            # WARNING: The inputs have padding, so we should mask those
            # elements here so that padding doesn't affect the results.
            # This is left as an exercise for the reader. The padding symbol
            # is given by ``self.input_vocab.pad()`` and the unpadded length
            # of each input is given by *src_lengths*.

            # One-hot encode a batch of input characters.
            input = self.one_hot_inputs[src_tokens[:, i].long()]

            # Feed the input to our RNN.
            output, hidden = self.rnn(input, hidden)

        # Return the final output state for making a prediction
        return output
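
The padding warning in the loop above can be handled in several ways. One possible approach, sketched here under stated assumptions, is to keep the previous hidden state for any row whose current token is the padding symbol. The helper name `masked_update` and the padding index `1` are assumptions for this example; inside the model the index would come from ``self.input_vocab.pad()``.

```python
import torch


def masked_update(new_hidden, old_hidden, step_tokens, pad_idx):
    """Keep the old hidden state for rows whose current token is padding."""
    is_real = (step_tokens != pad_idx).unsqueeze(1)  # shape (batch, 1)
    # The (batch, 1) mask broadcasts across the hidden dimension.
    return torch.where(is_real, new_hidden, old_hidden)


pad_idx = 1  # assumed padding index for this example
step_tokens = torch.tensor([5, pad_idx])  # the second row is already padded
old_hidden = torch.zeros(2, 3)
new_hidden = torch.ones(2, 3)

hidden = masked_update(new_hidden, old_hidden, step_tokens, pad_idx)
print(hidden)
```

Since the model returns the output of the final time step, the same selection would also have to be applied to `output`, otherwise shorter inputs would still be scored on a padded position.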

Finally let’s define a named architecture with the configuration for our model. This is done with the register_model_architecture() function decorator. Thereafter this named architecture can be used with the --arch command-line argument, e.g., --arch pytorch_tutorial_rnn:

from fairseq.models import register_model_architecture

# The first argument to ``register_model_architecture()`` should be the name
# of the model we registered above (i.e., 'rnn_classifier'). The function we
# register here should take a single argument *args* and modify it in-place
# to match the desired architecture.

@register_model_architecture('rnn_classifier', 'pytorch_tutorial_rnn')
def pytorch_tutorial_rnn(args):
    # We use ``getattr()`` to prioritize arguments that are explicitly given
    # on the command-line, so that the defaults defined below are only used
    # when no other value has been specified.
    args.hidden_dim = getattr(args, 'hidden_dim', 128)
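
The ``getattr()`` pattern can be tried in isolation with a plain `argparse.Namespace`. The sketch below assumes that an option which was not specified is simply absent from the namespace (rather than present with a value of `None`); `apply_defaults` is a hypothetical helper name used only here.

```python
from argparse import Namespace


def apply_defaults(args):
    """Fill in hidden_dim only when it is absent from the namespace."""
    args.hidden_dim = getattr(args, "hidden_dim", 128)
    return args


unset = apply_defaults(Namespace())                   # nothing specified
explicit = apply_defaults(Namespace(hidden_dim=256))  # value given explicitly

print(unset.hidden_dim, explicit.hidden_dim)
```

The explicitly supplied value survives, while the missing one falls back to the architecture default.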

3. Registering a new Task

Now we’ll register a new FairseqTask that will load our dictionaries and dataset. Tasks can also control how the data is batched into mini-batches, but in this tutorial we’ll reuse the batching provided by fairseq.data.LanguagePairDataset.

Create a new file named fairseq/tasks/simple_classification.py with the following contents:

import os
import torch

from fairseq.data import Dictionary, LanguagePairDataset
from fairseq.tasks import LegacyFairseqTask, register_task

@register_task('simple_classification')
class SimpleClassificationTask(LegacyFairseqTask):

    @staticmethod
    def add_args(parser):
        # Add some command-line arguments for specifying where the data is
        # located and the maximum supported input length.
        parser.add_argument('data', metavar='FILE',
                            help='file prefix for data')
        parser.add_argument('--max-positions', default=1024, type=int,
                            help='max input length')

    @classmethod
    def setup_task(cls, args, **kwargs):
        # Here we can perform any setup required for the task. This may include
        # loading Dictionaries, initializing shared Embedding layers, etc.
        # In this case we'll just load the Dictionaries.
        input_vocab = Dictionary.load(os.path.join(args.data, 'dict.input.txt'))
        label_vocab = Dictionary.load(os.path.join(args.data, 'dict.label.txt'))
        print('| [input] dictionary: {} types'.format(len(input_vocab)))
        print('| [label] dictionary: {} types'.format(len(label_vocab)))

        return SimpleClassificationTask(args, input_vocab, label_vocab)

    def __init__(self, args, input_vocab, label_vocab):
        super().__init__(args)
        self.input_vocab = input_vocab
        self.label_vocab = label_vocab

    def load_dataset(self, split, **kwargs):
        """Load a given dataset split (e.g., train, valid, test)."""

        prefix = os.path.join(self.args.data, '{}.input-label'.format(split))

        # Read input sentences.
        sentences, lengths = [], []
        with open(prefix + '.input', encoding='utf-8') as file:
            for line in file:
                sentence = line.strip()

                # Tokenize the sentence, splitting on spaces
                tokens = self.input_vocab.encode_line(
                    sentence, add_if_not_exist=False,
                )

                sentences.append(tokens)
                lengths.append(tokens.numel())

        # Read labels.
        labels = []
        with open(prefix + '.label', encoding='utf-8') as file:
            for line in file:
                label = line.strip()
                labels.append(
                    # Convert label to a numeric ID.
                    torch.LongTensor([self.label_vocab.add_symbol(label)])
                )

        assert len(sentences) == len(labels)
        print('| {} {} {} examples'.format(self.args.data, split, len(sentences)))

        # We reuse LanguagePairDataset since classification can be modeled as a
        # sequence-to-sequence task where the target sequence has length 1.
        self.datasets[split] = LanguagePairDataset(
            src=sentences,
            src_sizes=lengths,
            src_dict=self.input_vocab,
            tgt=labels,
            tgt_sizes=torch.ones(len(labels)),  # targets have length 1
            tgt_dict=self.label_vocab,
            left_pad_source=False,
            # Since our target is a single class label, there's no need for
            # teacher forcing. If we set this to ``True`` then our Model's
            # ``forward()`` method would receive an additional argument called
            # *prev_output_tokens* that would contain a shifted version of the
            # target sequence.
            input_feeding=False,
        )

    def max_positions(self):
        """Return the max input length allowed by the task."""
        # The source should be less than *args.max_positions* and the "target"
        # has max length 1.
        return (self.args.max_positions, 1)

    @property
    def source_dictionary(self):
        """Return the source :class:`~fairseq.data.Dictionary`."""
        return self.input_vocab

    @property
    def target_dictionary(self):
        """Return the target :class:`~fairseq.data.Dictionary`."""
        return self.label_vocab

    # We could override this method if we wanted more control over how batches
    # are constructed, but it's not necessary for this tutorial since we can
    # reuse the batching provided by LanguagePairDataset.
    # def get_batch_iterator(
    #     self, dataset, max_tokens=None, max_sentences=None, max_positions=None,
    #     ignore_invalid_inputs=False, required_batch_size_multiple=1,
    #     seed=1, num_shards=1, shard_id=0, num_workers=0, epoch=1,
    #     data_buffer_size=0, disable_iterator_cache=False,
    # ):
    #     (...)
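
To picture what batching variable-length inputs involves, here is a simplified plain-Python sketch that right-pads lists of token IDs to a common length and records the original lengths, mirroring the `(batch, src_len)` and `(batch)` shapes the model's ``forward()`` expects. It is not how LanguagePairDataset is implemented; `pad_batch`, the token IDs and the padding index `1` are assumptions for illustration.

```python
def pad_batch(sequences, pad_idx):
    """Right-pad lists of token IDs to a common length; also return lengths."""
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = [seq + [pad_idx] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


# Hypothetical token IDs for two names of different length.
batch, lengths = pad_batch([[7, 8, 9, 2], [5, 2]], pad_idx=1)
print(batch)
print(lengths)
```

Padding on the right corresponds to not left-padding the source, so the real tokens of every example start at position 0.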

4. Training the Model

Now we’re ready to train the model. We can use the existing fairseq-train command-line tool for this, making sure to specify our new Task (--task simple_classification) and Model architecture (--arch pytorch_tutorial_rnn). The dimensionality of the hidden state can optionally be set with the --hidden-dim argument; the command below leaves it at the architecture default:

> fairseq-train names-bin \
  --task simple_classification \
  --arch pytorch_tutorial_rnn \
  --optimizer adam --lr 0.001 --lr-shrink 0.5 \
  --max-tokens 1000
(...)
| epoch 027 | loss 1.200 | ppl 2.30 | wps 15728 | ups 119.4 | wpb 116 | bsz 116 | num_updates 3726 | lr 1.5625e-05 | gnorm 1.290 | clip 0% | oom 0 | wall 32 | train_wall 21
| epoch 027 | valid on 'valid' subset | valid_loss 1.41304 | valid_ppl 2.66 | num_updates 3726 | best 1.41208
| done training in 31.6 seconds

The model files should appear in the checkpoints/ directory.
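
The numbers in the log above can be sanity-checked with a little arithmetic. The sketch below assumes that the reported loss is measured in base 2 (so that perplexity is 2 raised to the loss) and that each learning-rate shrink multiplies the rate by the --lr-shrink factor. The count of six shrinks is inferred from the logged value, not stated in the log; both helper names are made up for this example.

```python
def perplexity_from_loss(loss_base2):
    """Perplexity under the assumption that the loss is measured in bits."""
    return 2 ** loss_base2


def lr_after_shrinks(initial_lr, shrink, num_shrinks):
    """Learning rate after multiplying by `shrink` a given number of times."""
    return initial_lr * shrink ** num_shrinks


print(round(perplexity_from_loss(1.200), 2))    # train loss from the log
print(round(perplexity_from_loss(1.41304), 2))  # validation loss from the log
print(lr_after_shrinks(0.001, 0.5, 6))          # compare with the logged lr
```

Under these assumptions the computed values agree with the logged ppl, valid_ppl and lr fields.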

5. Writing an evaluation script

Finally we can write a short script to evaluate our model on new inputs. Create a new file named eval_classifier.py with the following contents:

from fairseq import checkpoint_utils, data, options, tasks

# Parse command-line arguments for generation
parser = options.get_generation_parser(default_task='simple_classification')
args = options.parse_args_and_arch(parser)

# Setup task
task = tasks.setup_task(args)

# Load model
print('| loading model from {}'.format(args.path))
models, _model_args = checkpoint_utils.load_model_ensemble([args.path], task=task)
model = models[0]

while True:
    sentence = input('\nInput: ')

    # Tokenize into characters
    chars = ' '.join(list(sentence.strip()))
    tokens = task.source_dictionary.encode_line(
        chars, add_if_not_exist=False,
    )

    # Build mini-batch to feed to the model
    batch = data.language_pair_dataset.collate(
        samples=[{'id': -1, 'source': tokens}],  # bsz = 1
        pad_idx=task.source_dictionary.pad(),
        eos_idx=task.source_dictionary.eos(),
        left_pad_source=False,
        input_feeding=False,
    )

    # Feed batch to the model and get predictions
    preds = model(**batch['net_input'])

    # Print top 3 predictions and their log-probabilities
    top_scores, top_labels = preds[0].topk(k=3)
    for score, label_idx in zip(top_scores, top_labels):
        label_name = task.target_dictionary.string([label_idx])
        print('({:.2f})\t{}'.format(score, label_name))

Now we can evaluate our model interactively. Note that we have included the original data path (names-bin/) so that the dictionaries can be loaded:

> python eval_classifier.py names-bin --path checkpoints/checkpoint_best.pt
| [input] dictionary: 64 types
| [label] dictionary: 24 types
| loading model from checkpoints/checkpoint_best.pt

Input: Satoshi
(-0.61) Japanese
(-1.20) Arabic
(-2.86) Italian

Input: Sinbad
(-0.30) Arabic
(-1.76) English
(-4.08) Russian
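
The scores in parentheses are natural-log probabilities, since the model ends in a LogSoftmax layer. As a small illustration, they can be converted back to probabilities by exponentiating; the values below are copied from the "Satoshi" example output.

```python
import math

# Scores copied from the "Satoshi" example output.
log_probs = {"Japanese": -0.61, "Arabic": -1.20, "Italian": -2.86}

# exp() undoes the natural log applied by LogSoftmax.
probs = {label: math.exp(score) for label, score in log_probs.items()}
for label, p in probs.items():
    print("{}: {:.3f}".format(label, p))
```

The three probabilities sum to less than 1 because only the top 3 labels are shown.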