Language Model Pre-training Improves Generalization in Policy Learning

Shuang Li¹, Xavier Puig¹, Yilun Du¹, Ekin Akyürek¹, Antonio Torralba¹, Jacob Andreas¹, Igor Mordatch²
¹MIT CSAIL, ²Google Brain



Language model (LM) pre-training has proven useful for a wide variety of language processing tasks, including tasks that require nontrivial planning and reasoning capabilities. Can these capabilities be leveraged for more general machine learning problems? We investigate the effectiveness of LM pre-training to scaffold learning and generalization in autonomous decision-making. We use a pre-trained GPT-2 LM to initialize an interactive policy, which we fine-tune via imitation learning to perform interactive tasks in a simulated household environment featuring partial observability, large action spaces, and long time horizons. To leverage pre-training, we first encode observations, goals, and history information as templated English strings, and train the policy to predict the next action. We find that this form of pre-training enables generalization in policy learning: for test tasks involving novel goals or environment states, initializing policies with language models improves task completion rates by nearly 20%. Additional experiments explore the role of language-based encodings in these results; we find that it is possible to train a simple adapter layer that maps from observations and action histories to LM embeddings, and thus that language modeling provides an effective initializer even for tasks with no language as input or output. Together, these results suggest that language modeling induces representations that are useful for modeling not just language, but natural goals and plans; these representations can aid learning and generalization even outside of language processing.


Shuang Li, Xavier Puig, Yilun Du, Ekin Akyürek, Antonio Torralba, Jacob Andreas, and Igor Mordatch
Language Model Pre-training Improves Generalization in Policy Learning
arXiv 2021

Policy generated by the pre-trained language model for a given household task

The policy learned by fine-tuning the pre-trained language model successfully finishes the task described in the goal predicates. We highlight the key actions in the map, where the agent is finding, grabbing, or placing objects in the target positions.

Quantitative Results

Experiment 1. Can LMs be used to initialize policies if state and action information is presented in a format that looks like a standard language modeling problem?

To test this, we encode the inputs to the policy (observations, goals, and action histories) as templated English phrases.
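The encoding step can be pictured with a minimal sketch. Only the goal example "ON(fork, table)" → "put one fork on the table" comes from this page; the other templates, the "Goal:/History:/Observation:" layout, and the helper names (`encode_goal`, `encode_policy_input`) are hypothetical illustrations, not the authors' exact format. The `use_history` flag mirrors the "w/o Hist" ablations described below.

```python
import re

# Hypothetical predicate-to-phrase templates. Only the ON example
# ("ON(fork, table)" -> "put one fork on the table") is taken from the text.
GOAL_TEMPLATES = {
    "ON": "put one {obj} on the {target}",
    "INSIDE": "put one {obj} inside the {target}",
}

# Matches two-argument predicates such as "ON(fork, table)".
PREDICATE_RE = re.compile(r"^\s*(\w+)\(\s*(\w+)\s*,\s*(\w+)\s*\)\s*$")


def encode_goal(predicate: str) -> str:
    """Turn a goal predicate like 'ON(fork, table)' into an English phrase."""
    match = PREDICATE_RE.match(predicate)
    if match is None:
        raise ValueError(f"unrecognized predicate: {predicate!r}")
    name, obj, target = match.groups()
    return GOAL_TEMPLATES[name].format(obj=obj, target=target)


def encode_policy_input(goals, history, observed, use_history=True) -> str:
    """Serialize goals, past actions, and visible objects into one string.

    With use_history=False the history segment is dropped, as in the
    hypothetical "w/o Hist" input format.
    """
    parts = ["Goal: " + ", ".join(encode_goal(g) for g in goals) + "."]
    if use_history:
        hist = ", ".join(history) if history else "none"
        parts.append("History: " + hist + ".")
    parts.append("Observation: I see " + ", ".join(observed) + ".")
    return " ".join(parts)
```

For example, `encode_policy_input(["ON(fork, table)"], ["walk to kitchen", "grab fork"], ["fork", "table", "plate"])` returns `"Goal: put one fork on the table. History: walk to kitchen, grab fork. Observation: I see fork, table, plate."`; the resulting string would then be tokenized and fed to the language model, which is trained to predict the next action.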

"LM (ft) (Ours)" is the proposed model. "MLP-N", "MLP-1", and "LSTM" are baselines that do not use a transformer. "LM (scratch) w/o Hist" and "LM (ft) w/o Hist" use the transformer architecture but omit the action history from the input. "LM (scratch)" and "LM (ft) (Ours)" use the transformer architecture and include the history in the input. "scratch" means the transformer is trained from scratch on our data, while "ft" means the transformer is pre-trained on language tasks and then fine-tuned on our data.

Experiment 2A. When using the pre-trained model's own string encoding mechanism, how important is it that the input strings resemble the language model's pre-training data (i.e., natural language)?

To answer this question, we replace the natural-language tokens (e.g., serializing the goal "ON(fork, table)" as "put one fork on the table") with random ones (e.g., serializing "ON(fork, table)" as "brought wise character trees fine yet").
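One way to produce such random strings is sketched below, under the assumption that each natural-language word is replaced by a fixed, randomly chosen vocabulary word, applied consistently across the whole dataset. The word list, seed, and function names are hypothetical; this page does not specify how the random strings were actually generated.

```python
import random

# Hypothetical pool of replacement words (includes the words from the example above).
VOCAB = ["brought", "wise", "character", "trees", "fine", "yet",
         "river", "almost", "quiet", "number"]


def build_random_mapping(natural_tokens, vocabulary, seed=0):
    """Assign each natural-language token a fixed, randomly chosen replacement word.

    Sampling without replacement keeps distinct tokens distinguishable,
    so the serialization stays informative even though it is no longer English.
    """
    if len(natural_tokens) > len(vocabulary):
        raise ValueError("vocabulary too small for a one-to-one mapping")
    rng = random.Random(seed)
    replacements = rng.sample(vocabulary, len(natural_tokens))
    return dict(zip(natural_tokens, replacements))


def randomize(text, mapping):
    """Replace every whitespace-separated token using the fixed mapping."""
    return " ".join(mapping[tok] for tok in text.split())
```

Usage: `tokens = sorted(set("put one fork on the table".split()))`, then `randomize("put one fork on the table", build_random_mapping(tokens, VOCAB))` yields a six-word string of unrelated words. Because the mapping is fixed, the policy still sees a consistent, learnable code; it just cannot rely on what the pre-trained model knows about English.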

"LM (scratch)" and "LM (ft)" use natural strings as input, while "LM (ft) (Random)" uses random strings. "LM (scratch)" trains the language model from scratch on the collected data, while "LM (ft)" and "LM (ft) (Random)" fine-tune the pre-trained language model on the collected data.

Experiment 2B. Given a non-linguistic task, if an effective string-based encoding cannot be generated arbitrarily, can such an encoding at least be learned?

To answer this question, we retain the discrete, serial format of the goal, history, and observation representation, but replace the embedding layer from the pre-trained language model with a new embedding layer trained from scratch.

"LM (ft) (Ours)" uses the pre-trained language encodings. In "rep Goal", we use the learned encoding for the goal and the pre-trained language encodings for history and observation. Similarly, "rep Hist" and "rep Obs" use the learned encoding for history and observation, respectively. "rep Goal-Hist-Obs" uses learned encodings for goal, history, and observation.
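The configurations above can be read as a per-field choice of embedding table. The sketch below is a toy, stdlib-only illustration under assumptions not stated on this page: a single freshly initialized table (shared by all replaced fields) over the same token vocabulary, small Gaussian initialization, and the hypothetical class name `FieldwiseEmbedding`. The training loop that would update the new table is omitted.

```python
import random

# Which input fields use the newly learned embedding table in each configuration.
CONFIGS = {
    "LM (ft) (Ours)": set(),
    "rep Goal": {"goal"},
    "rep Hist": {"history"},
    "rep Obs": {"observation"},
    "rep Goal-Hist-Obs": {"goal", "history", "observation"},
}


class FieldwiseEmbedding:
    """Toy embedding lookup that routes each field to a pre-trained or new table."""

    def __init__(self, pretrained, replaced_fields, dim, seed=0):
        self.pretrained = pretrained  # token -> vector from the pre-trained LM
        self.replaced_fields = set(replaced_fields)
        rng = random.Random(seed)
        # New table over the same discrete tokens, randomly initialized;
        # in a real model it would be trained from scratch during fine-tuning.
        self.learned = {
            tok: [rng.gauss(0.0, 0.02) for _ in range(dim)] for tok in pretrained
        }

    def embed(self, field, tokens):
        """Embed a token sequence belonging to 'goal', 'history', or 'observation'."""
        table = self.learned if field in self.replaced_fields else self.pretrained
        return [table[tok] for tok in tokens]
```

Because the input keeps its discrete, serial format, only the lookup table changes between configurations; the transformer layers above it are still initialized from the pre-trained language model.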

Qualitative Results

1. Policies with a pre-trained language model on different test settings

2. Policy trained from scratch vs. policy with a pre-trained language model

3. Policy with an LSTM vs. policy with a pre-trained language model