Blog
The latest news from Google AI
Capturing Special Video Moments with Google Photos
Wednesday, April 3, 2019
Posted by Sudheendra Vijayanarasimhan and David Ross, Software Engineers
Recording video of memorable moments to share with friends and loved ones has become commonplace. But as anyone with a sizable video library can tell you, it's a time-consuming task to go through all that raw footage searching for the perfect clips to relive or share with family and friends. Google Photos makes this easier by automatically finding magical moments in your videos (like when your child blows out the candles or when your friend jumps into a pool) and creating animations from them that you can easily share with friends and family.
In "Rethinking the Faster R-CNN Architecture for Temporal Action Localization", we address some of the challenges behind automating this task, which are due to the complexity of identifying and categorizing actions from a highly variable array of input data, by introducing an improved method to identify the exact location within a video where a given action occurs. Our temporal action localization network (TALNet) draws inspiration from advances in region-based object detection methods such as the Faster R-CNN network. TALNet enables identification of moments with large variation in duration, achieving state-of-the-art performance compared to other methods, allowing Google Photos to recommend the best part of a video for you to share with friends and family.
An example of the detected action "blowing out candles"
Identifying Actions for Model Training
The first step in identifying magic moments in videos is to assemble a list of actions that people might wish to highlight. Some examples of actions include "blow out birthday candles", "strike (bowling)", "cat wags tail", etc. We then crowdsourced the annotation of segments within a collection of public videos where these specific actions occurred, in order to create a large training dataset. We asked the raters to find and label all moments, accommodating videos that might have several moments. This final annotated dataset was then used to train our model so that it could identify the desired actions in new, unknown videos.
Comparison to Object Detection
The challenge of recognizing these actions belongs to the field of computer vision known as temporal action localization, which, like the more familiar object detection, falls under the umbrella of visual detection problems. Given a long, untrimmed video as input, temporal action localization aims to identify the start and end times, as well as the action label (like "blowing out candles"), for each action instance in the full video. While object detection aims to produce spatial bounding boxes around an object in a 2D image, temporal action localization aims to produce temporal segments including an action in a 1D sequence of video frames.
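Because both the annotations and the predictions are (start, end) segments, a natural way to compare them is the 1D analogue of bounding-box overlap. The snippet below is a minimal sketch of temporal intersection-over-union; the function name and the example timestamps are illustrative assumptions rather than anything taken from TALNet.

```python
def temporal_iou(seg_a, seg_b):
    """Intersection-over-union of two (start, end) segments, in seconds."""
    start_a, end_a = seg_a
    start_b, end_b = seg_b
    # Overlap is zero when the segments do not touch.
    intersection = max(0.0, min(end_a, end_b) - max(start_a, start_b))
    union = (end_a - start_a) + (end_b - start_b) - intersection
    return intersection / union if union > 0 else 0.0


# Hypothetical example: a predicted segment versus an annotated one.
predicted = (12.0, 18.0)
annotated = (14.0, 20.0)
print(temporal_iou(predicted, annotated))
```

A prediction is then typically counted as correct when this overlap exceeds a chosen threshold and the action label matches.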
Our approach to TALNet is inspired by the Faster R-CNN object detection framework for 2D images. So, to understand TALNet, it is useful to first understand Faster R-CNN. The figure below demonstrates how the Faster R-CNN architecture is used for object detection. The first step is to generate a set of object proposals, regions of the image that can be used for classification. To do this, an input image is first converted into a 2D feature map by a convolutional neural network (CNN). The region proposal network then generates bounding boxes around candidate objects. These boxes are generated at multiple scales in order to capture the large variability in objects' sizes in natural images. With the object proposals now defined, the subjects in the bounding boxes are then classified by a deep neural network (DNN) into specific objects, such as "person", "bike", etc.
Faster R-CNN architecture for object detection
Temporal Action Localization
Temporal action localization is accomplished in a fashion similar to that used by Faster R-CNN. A sequence of input frames from a video is first converted into a 1D feature map that encodes scene context. This map is passed to a segment proposal network that generates candidate segments, each defined by start and end times. A DNN then applies the representations learned from the training dataset to classify the actions in the proposed video segments (e.g., "slam dunk", "pass", etc.). The actions identified in each segment are given weights according to their learned representations, with the top scoring moment selected to share with the user.
Architecture for temporal action localization
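To make the idea of candidate segments concrete, here is a minimal sketch of how anchor segments at several temporal scales could be enumerated over a 1D feature map. The function name, the scale values and the choice to center one anchor per scale on every time step are assumptions made for illustration, not details of the segment proposal network itself.

```python
def anchor_segments(num_steps, scales):
    """Enumerate (start, end) anchors, one per scale, centered on each time step."""
    anchors = []
    for step in range(num_steps):
        center = step + 0.5
        for scale in scales:
            # Clip anchors so they stay inside the video.
            start = max(0.0, center - scale / 2)
            end = min(float(num_steps), center + scale / 2)
            anchors.append((start, end))
    return anchors


# Hypothetical feature map with 4 time steps and three anchor lengths.
for segment in anchor_segments(4, scales=[1, 2, 8]):
    print(segment)
```

A proposal network would score each of these anchors and refine its boundaries; only the enumeration step is shown here.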
Special Considerations for Temporal Action Localization
While temporal action localization can be viewed as the 1D counterpart of the object detection problem, care must be taken to address a number of issues unique to action localization. In particular, we address three specific issues in order to apply the Faster R-CNN approach to the action localization domain, and redesign the architecture to specifically address them.
Actions have much larger variations in durations
The temporal extent of actions varies dramatically, from a fraction of a second to minutes. For long actions, it is not important to understand each and every frame of the action. Instead, we can get a better handle on the action by skimming quickly through the video, using dilated temporal convolutions. This approach allows TALNet to search the video for temporal patterns, while skipping over alternate frames based on a given dilation rate. Analyzing the video with several different rates that are selected automatically according to the anchor segment's length enables efficient identification of actions as large as the entire video or as short as a second.
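The effect of the dilation rate is easiest to see on a toy signal. The following is a minimal pure-Python sketch of a 1D dilated convolution (written as a cross-correlation without padding, which is an assumption made for brevity); the function names and the example kernel are illustrative.

```python
def dilated_conv1d(signal, kernel, dilation):
    """Apply a 1D kernel whose taps are spaced `dilation` steps apart (no padding)."""
    span = (len(kernel) - 1) * dilation
    return [
        sum(w * signal[t + k * dilation] for k, w in enumerate(kernel))
        for t in range(len(signal) - span)
    ]


def receptive_field(kernel_size, dilation):
    """Number of time steps covered by a single output value."""
    return (kernel_size - 1) * dilation + 1


frames = [1, 2, 3, 4, 5, 6, 7]
print(dilated_conv1d(frames, [1, 1, 1], dilation=1))  # looks at 3 consecutive steps
print(dilated_conv1d(frames, [1, 1, 1], dilation=2))  # same kernel, covers 5 steps
print(receptive_field(3, 2))
```

With the same three weights, a larger dilation rate covers a longer stretch of video, which is why matching the rate to the anchor length lets one kernel size handle both short and long actions.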
The context before and after an action is important
The moments preceding and following an action instance contain critical information for localization and classification, arguably more so than the spatial context of an object. Therefore, we explicitly encode the temporal context by extending the length of proposal segments on both the left and right by a fixed percentage of the segment's length in both the proposal generation stage and the classification stage.
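As a small illustration of this context extension, the sketch below widens a proposal on both sides by a fraction of its own length. The function name, the ratio of 0.5 and the clipping to the video bounds are assumptions for the example; the actual percentage used by TALNet is not stated here.

```python
def extend_segment(start, end, context_ratio, video_length):
    """Widen a segment on both sides by `context_ratio` times its length."""
    context = context_ratio * (end - start)
    # Keep the extended segment inside the video.
    return max(0.0, start - context), min(video_length, end + context)


# Hypothetical 10-second proposal in a 60-second video.
print(extend_segment(10.0, 20.0, context_ratio=0.5, video_length=60.0))
```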
Actions require multi-modal input
Actions are defined by appearance, motion and sometimes even audio information. Therefore, it is important to consider multiple modalities of features for the best results. We use a late fusion scheme for both the proposal generation network and the classification network, in which each modality has a separate proposal generation network whose outputs are combined together to obtain the final set of proposals. These proposals are classified using separate classification networks for each modality, which are then averaged to obtain the final predictions.
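The averaging step of late fusion can be sketched in a few lines. The example below assumes each modality's classifier outputs a dictionary of per-label scores for the same proposal; the modality names, labels and scores are made up for illustration.

```python
def late_fusion(scores_per_modality):
    """Average per-label scores produced independently by each modality."""
    labels = scores_per_modality[0].keys()
    count = len(scores_per_modality)
    return {
        label: sum(scores[label] for scores in scores_per_modality) / count
        for label in labels
    }


# Hypothetical classifier outputs for one proposal.
appearance = {"blow out candles": 0.8, "background": 0.2}
motion = {"blow out candles": 0.6, "background": 0.4}
fused = late_fusion([appearance, motion])
print(max(fused, key=fused.get), fused)
```

Because the modalities are only combined at the output, each network can be trained and tuned separately.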
TALNet in Action
As a consequence of these improvements, TALNet achieves state-of-the-art performance for both action proposal and action localization tasks on the THUMOS'14 detection benchmark and competitive performance on the ActivityNet challenge. Now, whenever people save videos to Google Photos, our model identifies these moments and creates animations to share. Here are a few examples shared by our initial testers.
An example of the detected action "sliding down a slide"
An example of the detected actions "jump into the pool" (left), "twirl in a dress" (center) and "feed baby a spoonful" (right).
Next steps
We are continuing work to improve the precision and recall of action localization using more data, features and models. Improvements in temporal action localization can drive progress on a large number of important topics, including video highlights, video summarization and search. We hope to continue improving the state of the art in this domain and at the same time provide more ways for people to reminisce on their memories, big and small.
Acknowledgements
Special thanks to Tim Novikoff and Yu-Wei Chao, as well as Bryan Seybold, Lily Kharevych, Siyu Gu, Tracy Gu, Tracy Utley, Yael Marzan, Jingyu Cui, Balakrishnan Varadarajan, and Paul Natsev for their critical contributions to this project.
Using Deep Learning to Improve Usability on Mobile Devices
Tuesday, April 2, 2019
Posted by Yang Li, Research Scientist, Google AI
Tapping is the most commonly used gesture on mobile interfaces, and is used to trigger all kinds of actions ranging from launching an app to entering text. While the style of clickable elements (e.g., buttons) in traditional desktop graphical user interfaces is often conventionally defined, on mobile interfaces it can still be difficult for people to distinguish tappable versus non-tappable elements due to the diversity of styles. This confusion can lead to false affordances (e.g., a feature that could be mistaken for a button) and a lack of discoverability that can lead to user frustration, uncertainty, and errors. To avoid this, interface designers can conduct a study or a visual affordance test to help clarify the tappability of items in their interfaces. However, such studies are time-consuming and their findings are often limited to a specific app or interface design.
In our CHI'19 paper, "Modeling Mobile Interface Tappability Using Crowdsourcing and Deep Learning", we introduced an approach for modeling the usability of mobile interfaces at scale. We crowdsourced a task to study UI elements across a range of mobile apps to measure the perceived tappability by a user. Our model predictions were consistent with the user group at the ~90% level, demonstrating that a machine learning model can be effectively used to estimate the perceived tappability of interface elements in their design without the need for expensive and time-consuming user testing.
Predicting Tappability with Deep Learning
Designers often use visual properties such as the color or depth of an element to signify its availability for interaction on interfaces, e.g., the blue color and underline of a link. While these common signifiers are useful, it is not always clear when to apply them in each specific design setting. Furthermore, with design trends evolving, traditional signifiers are constantly being altered and challenged, potentially causing user uncertainty and mistakes.
To understand how users perceive this changing landscape, we analyzed the potential signifiers affecting tappability in real mobile apps: element type (e.g., check boxes, text boxes, etc.), location, size, color, and words. We started by crowdsourcing volunteers to label the perceived clickability of ~20,000 unique interface elements from ~3,500 apps. With the exception of text boxes, type signifiers yielded low uncertainty in user perceived tappability. The location signifier refers to the position of a feature on the screen and is informed by the common layout design in mobile apps, as demonstrated in the figure below.
Heatmaps displaying the accuracy of tappable and non-tappable elements by location, where warmer colors represent areas of higher accuracy. Users labeled non-tappable elements more accurately towards the upper center of the interface, and tappable elements towards the bottom center of the interface.
The impact of element size was relatively weak, but did indicate confusion in the case of large non-tappable elements. Users showed a tendency to associate bright colors and short word counts with tappable elements, though word semantics also played a significant role.
We used these labels to train a simple deep neural network that predicts the likelihood that a user will perceive an interface element as tappable versus non-tappable. For a given element of the interface, the model uses a range of features, including the spatial context of the element on the screen (location), the semantics and functionality of the element (words and type), and the visual appearance (size as well as raw pixels). The neural network model applies a convolutional neural network (CNN) to extract features from raw pixels, and uses learned semantic embeddings to represent text content and element properties. The concatenation of all these features is then fed to a fully-connected network layer, the output of which produces a binary classification of an element's tappability.
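The final stage of such a model, concatenating the per-element features and passing them through a fully-connected layer, can be sketched as follows. This is only an illustration of the data flow: the feature sizes, the single-layer sigmoid output and all names are assumptions, and the CNN and embedding lookups that would produce the inputs are not shown.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def predict_tappability(pixel_features, text_embedding, type_embedding,
                        location, size, weights, bias):
    """Concatenate all features and apply one dense layer with a sigmoid output."""
    features = pixel_features + text_embedding + type_embedding + location + size
    if len(weights) != len(features):
        raise ValueError("expected one weight per concatenated feature")
    logit = sum(w * f for w, f in zip(weights, features)) + bias
    return sigmoid(logit)  # probability that users perceive the element as tappable


# Hypothetical, tiny feature vectors and untrained (made-up) weights.
probability = predict_tappability(
    pixel_features=[0.2, 0.7],   # would come from a CNN over raw pixels
    text_embedding=[0.1, -0.3],  # would come from a learned word embedding
    type_embedding=[0.5],        # would come from a learned element-type embedding
    location=[0.5, 0.9],         # normalized x, y on the screen
    size=[0.3, 0.1],             # normalized width, height
    weights=[0.4, 0.4, 0.1, 0.1, 0.8, 0.0, 0.6, 0.2, 0.2],
    bias=-0.5,
)
print(round(probability, 3))
```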
Evaluation of the Model
The model allowed us to automatically diagnose mismatches between the tappability of each interface element as perceived by a user—predicted by our model—and the intended or actual tappable state of the element specified by the developer or designer. In the example below, our model predicts that there is a 73% chance that a user would think the labels such as "Followers" or "Following" are tappable, while these interface elements are in fact not programmed to be tappable.
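This kind of diagnosis amounts to comparing the predicted perception with the state the developer specified. The sketch below assumes a simple 0.5 decision threshold and a list of (name, predicted probability, actually tappable) tuples; apart from the 73% figure for "Followers" mentioned above, the elements and numbers are invented for illustration.

```python
def find_mismatches(elements, threshold=0.5):
    """Return names of elements whose predicted perception differs from their real state."""
    mismatches = []
    for name, predicted_probability, is_tappable in elements:
        perceived_tappable = predicted_probability >= threshold
        if perceived_tappable != is_tappable:
            mismatches.append(name)
    return mismatches


screen = [
    ("Followers", 0.73, False),     # looks tappable, but is not
    ("Follow button", 0.95, True),  # hypothetical element
    ("Plain caption", 0.10, False), # hypothetical element
]
print(find_mismatches(screen))
```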
To understand how our model behaves compared to human users, particularly when there is ambiguity in human perception, we generated a second, independent dataset by crowdsourcing an effort among 290 volunteers to label each of 2,000 unique interface elements with respect to their perceived tappability. Each element was labeled independently by five different users. We found that more than 40% of the elements in our sample were labeled inconsistently by volunteers. Our model matches this uncertainty in human perception quite well, as demonstrated in the figure below.
The scatterplot of the tappability probability predicted by the model (the Y axis) versus the consistency in the human user labels (the X axis) for each element in the consistency dataset.
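One simple way to quantify the agreement among the five labels collected for each element is the share of users who marked it as tappable. The helper names and example ratings below are assumptions for illustration; the exact consistency measure used in the study may differ.

```python
def tappable_fraction(labels):
    """Share of users who labeled the element as tappable (0.0 to 1.0)."""
    return sum(labels) / len(labels)


def labeled_inconsistently(labels):
    """True when the users did not all give the same answer."""
    return len(set(labels)) > 1


# Hypothetical ratings from five users per element.
ratings = {
    "login_button": [True, True, True, True, True],
    "section_header": [True, False, True, False, False],
}
for name, labels in ratings.items():
    print(name, tappable_fraction(labels), labeled_inconsistently(labels))
```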
When users agree on an element's tappability, our model tends to give a more definite answer: a probability close to 1 for tappable and close to 0 for not tappable. When users are less consistent on an element (towards the middle of the X axis), our model is also less certain about the decision. Overall, our model achieved reasonable accuracy of matching human perception in identifying tappable UI elements with a mean precision of 90.2% and recall of 87.0%.
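For readers less familiar with these two metrics, here is a minimal sketch of how precision and recall are computed for a binary "tappable" prediction. The labels in the example are invented and are unrelated to the reported 90.2% and 87.0%.

```python
def precision_recall(predicted, actual):
    """Precision and recall for boolean predictions against boolean ground truth."""
    pairs = list(zip(predicted, actual))
    true_pos = sum(1 for p, a in pairs if p and a)
    false_pos = sum(1 for p, a in pairs if p and not a)
    false_neg = sum(1 for p, a in pairs if not p and a)
    precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
    recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0
    return precision, recall


# Hypothetical predictions (model) and labels (human perception).
model_says_tappable = [True, True, False, False]
humans_say_tappable = [True, False, True, True]
print(precision_recall(model_says_tappable, humans_say_tappable))
```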
Predicting tappability is merely one example of what we can do with machine learning to solve usability issues in user interfaces. There are many other challenges in interaction design and user experience research where deep learning models can offer a vehicle to distill large, diverse user experience datasets and advance scientific understandings about interaction behaviors.
Acknowledgements
This research was a joint work of Amanda Swearngin, summer intern at Google, and Yang Li, a Research Scientist in Deep Learning and Human Computer Interaction.
Unifying Physics and Deep Learning with TossingBot
Tuesday, March 26, 2019
Posted by Andy Zeng, Student Researcher, Robotics at Google
Though considerable progress has been made in enabling robots to grasp objects efficiently, visually self adapt or even learn from real-world experiences, robotic operations still require careful consideration in how they pick up, handle, and place various objects -- especially in unstructured settings. Consider, for example, this picking robot, which took 1st place in the stowing task of the Amazon Robotics Challenge:
It's an impressive system, built with many design features that kinematically prevent it from dropping objects due to unforeseen dynamics: from its steady and deliberate movements, to its gripper fingers that mechanically constrain the momentum of the object so that it doesn't slip.
This robot, like many others, is designed to tolerate the dynamics of the unstructured world. But instead of just tolerating dynamics, can robots learn to use them advantageously, developing an "intuition" of physics that would allow them to complete tasks more efficiently? Perhaps in doing so, robots can improve their capabilities and acquire complex athletic skills like tossing, sliding, spinning, swinging, or catching, potentially leading to many useful applications, such as more efficient debris clearing robots in disaster response scenarios -- where time is of the essence.
To explore this concept, we worked with researchers at Princeton, Columbia, and MIT to develop TossingBot: a picking robot for our real, random world that learns to grasp and throw objects into selected boxes outside its natural range. We find that by learning to throw, TossingBot is capable of achieving picking speeds that are twice as fast as previous systems, with twice the effective placing range. TossingBot jointly learns grasping and throwing policies using an end-to-end neural network that maps from visual observations (RGB-D images) to control parameters for motion primitives. Using overhead cameras to track where objects land, TossingBot improves itself over time through self-supervision. More technical details are available in an early preprint on arXiv.
The Challenges
Throwing is a particularly difficult task as it depends on many factors: from how the object is picked up (i.e., "pre-throw conditions"), to the object's physical properties like mass, friction, aerodynamics, etc. For example, if you grasp a screwdriver by the handle near the center of mass and throw it, it would land much closer than if you had grasped it from the metal tip, which would swing forward and land much farther away. Regardless of how you grasped it though, tossing a screwdriver is incredibly different from tossing a ping pong ball, which would land closer due to air resistance. Manually designing a solution that explicitly handles these factors for every random object is nearly impossible.
Throwing depends on many factors: from how you picked it up, to object properties and dynamics.
Through deep learning, however, our robots can learn from experience rather than rely on manual case-by-case engineering. Previously we've shown that our robots can learn to push and grasp a large variety of objects, but accurately throwing objects requires a larger understanding of projectile physics. Acquiring this knowledge from scratch with only trial-and-error is not only time-consuming and expensive, but also generally doesn't work outside of very specific and carefully set up training scenarios.
Unifying Physics and Deep Learning
A fundamental component of TossingBot is that it learns to throw by integrating simple physics and deep learning, which enables it to train quickly and generalize to new scenarios. Physics provides prior models of how the world works, and we can leverage these models to develop initial controllers for our robots. In the case of throwing, for example, we can use projectile ballistics to provide an estimate for the throwing velocity that is needed to get an object to land at a target location. We can then use neural networks to predict adjustments on top of that estimate from physics, in order to compensate for unknown dynamics as well as the noise and variability of the real world. We call this hybrid formulation Residual Physics, and it enables TossingBot to achieve throwing accuracies of 85%.
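The structure of this hybrid can be sketched with textbook projectile motion. The example below solves for the release speed needed to reach a target at a given horizontal distance, ignoring air resistance, and then adds a learned correction on top. The fixed 45-degree release angle, the function names and the example residual value are assumptions for illustration; in the real system the residual would be predicted by the neural network from visual observations.

```python
import math

G = 9.81  # gravitational acceleration in m/s^2


def ballistic_release_speed(distance, height_diff=0.0, angle_deg=45.0):
    """Release speed (m/s) for a drag-free projectile to land `distance` meters away.

    `height_diff` is the target height minus the release height, in meters.
    """
    theta = math.radians(angle_deg)
    drop = distance * math.tan(theta) - height_diff
    if distance <= 0 or drop <= 0:
        raise ValueError("target is not reachable at this release angle")
    return math.sqrt(G * distance ** 2 / (2 * math.cos(theta) ** 2 * drop))


def residual_physics_speed(distance, predicted_residual, height_diff=0.0):
    """Physics estimate plus a learned correction (supplied by the caller here)."""
    return ballistic_release_speed(distance, height_diff) + predicted_residual


# Hypothetical throw: target 1.5 m away and 0.2 m below the release point,
# with a made-up residual standing in for the network's prediction.
print(ballistic_release_speed(1.5, height_diff=-0.2))
print(residual_physics_speed(1.5, predicted_residual=0.15, height_diff=-0.2))
```

The physics term gives a sensible starting point for any target location, so the learned part only has to account for what the simple model misses, such as how the object was grasped or its aerodynamics.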
At the start of training with randomly initialized weights, TossingBot repeatedly attempts bad grasps. Over time, however, TossingBot learns better ways to grasp objects and simultaneously improves its ability to throw. Occasionally the robot randomly explores what happens if it throws an object at a velocity that it hasn't tried before. When the bin is emptied, TossingBot lifts the boxes to allow objects to slide back into the bin. This way, human intervention is kept at a minimum during training. By 10,000 grasp and throw attempts (or 14 hours of training time), it is capable of achieving throwing accuracies of 85%, with a grasping reliability of 87% in clutter.
TossingBot starts out performing poorly (left), but progressively learns to grasp and toss overnight (right).
Generalizing to New Scenarios
By integrating physics and deep learning, TossingBot is capable of rapidly adapting to never-before-seen throwing locations and objects. For example, after training on objects with simple shapes like wooden blocks, balls, and markers, it can perform reasonably well on new objects such as fake fruit, decorative items, and office objects. On new objects, TossingBot starts out with lower performance, but quickly adapts within a few hundred training steps (i.e., an hour or two) to achieve performance similar to that on the training objects. We've found that combining physics and deep learning with Residual Physics yields better performance than baseline alternatives (e.g. deep learning without physics). We even tried this task ourselves, and we were pleasantly surprised to learn that TossingBot is more accurate than any of us engineers! Though take that with a grain of salt, as we've yet to test TossingBot against anyone with any actual athletic talent.
TossingBot can generalize to new objects, and is more accurate at throwing than the average Googler.
We also test our policies on their ability to generalize to new target locations previously unseen in training. To this end, we train on a set of boxes, then later test on a different set of boxes with entirely different landing areas. In this setting, we find that Residual Physics for throwing helps significantly, since the initial estimates of throwing velocities from projectile ballistics easily generalize to new target locations, while the residuals help make adjustments on top of those estimates to compensate for varying object properties in the real world. This is in contrast to the baseline alternative of using deep learning without physics, which can only handle target locations seen during training.
TossingBot uses Residual Physics to throw objects to unforeseen locations.
Emerging Semantics from Interaction
To explore what TossingBot learns, we place several objects in the bin, capture images, and feed them into TossingBot's trained neural network to extract intermediate pixel-wise deep features. By clustering these features based on similarity and visualizing nearest neighbors as a heatmap (hotter regions indicate more similarity in feature space), we can localize all ping pong balls in the scene. Even though the orange block shares a similar color with the ping pong balls, its features are different enough for TossingBot to make a distinction. Likewise, we can also use the extracted features to localize all marker pens, which share similar shape and mass, but do not share color. These observations suggest that TossingBot likely learns to rely more on geometric cues (e.g. shape) to learn grasping and throwing. It is also possible that the learned features reflect second-order attributes such as physical properties, which can influence how the objects should be thrown.
TossingBot learns deep features that distinguish object categories without explicit supervision.
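The heatmap visualization can be approximated with a per-pixel similarity to a chosen query feature. The sketch below uses cosine similarity on a tiny nested-list feature map; the similarity measure, the feature values and the function names are assumptions for illustration and may differ from the analysis actually performed.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def similarity_heatmap(feature_map, query):
    """Similarity of every pixel's feature vector to the query feature."""
    return [[cosine_similarity(pixel, query) for pixel in row] for row in feature_map]


# Hypothetical 2x2 feature map with 2-dimensional features per pixel.
features = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[2.0, 0.0], [1.0, 1.0]],
]
query = features[0][0]  # e.g. the feature at a pixel on a ping pong ball
for row in similarity_heatmap(features, query):
    print([round(value, 3) for value in row])
```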
These emerging features were learned implicitly from scratch without any explicit supervision beyond task-level grasping and throwing. Yet, they seem to be sufficient for enabling the system to distinguish between object categories (i.e., ping pong balls and marker pens). As such, this experiment speaks to a broader concept related to machine vision: how should robots learn the semantics of the visual world? From the perspective of classic computer vision, semantics are often pre-defined using human-fabricated image datasets and manually constructed class categories. However, our experiment suggests that it is possible to implicitly learn such object-level semantics from physical interactions alone, as long as they matter for the task at hand. The more complex these interactions, the higher the resolution of the semantics. Towards more generally intelligent robots -- perhaps it is sufficient for them to develop their own notion of semantics through interaction, without requiring any human intervention.
Limitations and Future Work
Although TossingBot's results are promising, it does have its limitations. For example, it assumes that objects are robust enough to withstand landing collisions after being thrown -- further work is required to learn throws that account for fragile objects, or possibly train other robots to catch objects in ways that cushion the landing. Furthermore, TossingBot infers control parameters only from visual data -- exploring additional senses (e.g. force-torque or tactile) may enable the system to better react to new objects.
The combination of physics and deep learning that made TossingBot possible naturally leads to an interesting question: what else could benefit from Residual Physics? Investigating how the idea generalizes to other types of tasks and interactions is a promising direction for future research.
You can learn more about this work in the summary video below.
Acknowledgements
This research was done by Andy Zeng, Shuran Song (faculty at Columbia University), Johnny Lee, Alberto Rodriguez (faculty at MIT), and Thomas Funkhouser (faculty at Princeton University), with special thanks to Ryan Hickman for valuable managerial support, Ivan Krasin and Stefan Welker for fruitful technical discussions, Brandon Hurd, Julian Salazar, and Sean Snyder for hardware support, Chad Richards and Jason Freidenfelds for helpful feedback on writing, Erwin Coumans for advice on PyBullet, Laura Graesser for video narration, and Regina Hickman for photography. An early preprint is available on arXiv.
Simulated Policy Learning in Video Models
Monday, March 25, 2019
Posted by Łukasz Kaiser and Dumitru Erhan, Research Scientists, Google AI
Deep reinforcement learning (RL) techniques can be used to learn policies for complex tasks from visual inputs, and have been applied with great success to classic Atari 2600 games. Recent work in this field has shown that it is possible to get super-human performance in many of them, even in challenging exploration regimes such as that exhibited by Montezuma's Revenge. However, one of the limitations of many state-of-the-art approaches is that they require a very large number of interactions with the game environment, often much larger than what people would need to learn to play well. One plausible hypothesis explaining why people learn these tasks so much more efficiently is that they are able to predict the effect of their own actions, and thus implicitly learn a model of which action sequences will lead to desirable outcomes. This general idea (building a so-called model of the game and using it to learn a good policy for selecting actions) is the main premise of model-based reinforcement learning (MBRL).
In "Model-Based Reinforcement Learning for Atari", we introduce the Simulated Policy Learning (SimPLe) algorithm, an MBRL framework to train agents for Atari gameplay that is significantly more efficient than current state-of-the-art techniques, and shows competitive results using only ~100K interactions with the game environment (equivalent to roughly two hours of real-time play by a person). In addition, we have open sourced our code as part of the tensor2tensor open source library. The release contains a pretrained world model that can be run with a simple command line and that can be played using an Atari-like interface.
Learning a SimPLe World Model
At a high level, the idea behind SimPLe is to alternate between learning a world model of how the game behaves and using that model to optimize a policy (with model-free reinforcement learning) within the simulated game environment. The basic principles behind this algorithm are well established and have been employed in numerous recent model-based reinforcement learning methods.
Main loop of SimPLe. 1) The agent starts interacting with the real environment. 2) The collected observations are used to update the current world model. 3) The agent updates the policy by learning inside the world model.
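The three steps of this loop can be illustrated on a toy problem. The sketch below is not SimPLe: it replaces the Atari game with a five-state chain, the video prediction network with a lookup table of observed transitions, and PPO with tabular Q-learning. All names and hyperparameters are assumptions chosen to keep the example short; only the alternation between real interaction, model update and policy learning inside the model mirrors the description above.

```python
import random

N_STATES = 5
ACTIONS = (-1, 1)


def real_step(state, action):
    """Toy stand-in for the real game: walk along a chain, reward at the right end."""
    next_state = min(max(state + action, 0), N_STATES - 1)
    reward = 1.0 if next_state == N_STATES - 1 else 0.0
    return next_state, reward


def greedy(q, state):
    return max(ACTIONS, key=lambda a: q.get((state, a), 0.0))


def simple_training_loop(iterations=5, real_steps=20, simulated_steps=200, seed=0):
    rng = random.Random(seed)
    world_model = {}  # (state, action) -> (next_state, reward)
    q = {}            # action values that define the policy
    real_interactions = 0
    for _ in range(iterations):
        # 1) Interact with the real environment using the current policy.
        state = 0
        for _ in range(real_steps):
            explore = rng.random() < 0.3
            action = rng.choice(ACTIONS) if explore else greedy(q, state)
            next_state, reward = real_step(state, action)
            # 2) Update the world model with the collected observation.
            world_model[(state, action)] = (next_state, reward)
            real_interactions += 1
            state = next_state
        # 3) Improve the policy using only the learned world model.
        known = list(world_model)
        for _ in range(simulated_steps):
            s, a = rng.choice(known)
            s2, r = world_model[(s, a)]
            target = r + 0.9 * max(q.get((s2, b), 0.0) for b in ACTIONS)
            old = q.get((s, a), 0.0)
            q[(s, a)] = old + 0.5 * (target - old)
    return world_model, q, real_interactions


model, values, used = simple_training_loop()
print(used, "real interactions;", len(model), "transitions in the learned model")
```

Most of the learning updates happen in step 3, which never touches the real environment; that is where the savings in real interactions come from.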
To train an Atari game playing model we first need to generate plausible versions of the future in pixel space. In other words, we seek to predict what the next frame will look like, by taking as input a sequence of already observed frames and the commands given to the game, such as "left", "right", etc. One of the important reasons for training a world model in observation space is that it is, in effect, a form of self-supervision, where the observations (pixels, in our case) form a dense and rich supervision signal.
If such a model (e.g. a video predictor) is trained successfully, one essentially has a learned simulator of the game environment that can be used to generate trajectories for training a good policy for a gaming agent, i.e. choosing a sequence of actions such that long-term reward of the agent is maximized. In other words, instead of having the policy be trained on sequences from the real game, which is prohibitively intensive in both time and computation, we train the policy on sequences coming from the world model / learned simulator.
Our world model is a feedforward convolutional network that takes in four frames and predicts the next frame as well as the reward (see figure above). However, in the case of Atari, the future is non-deterministic given only a horizon of the previous four frames. For example, a pause in the game longer than four frames, such as when the ball falls out of the frame in Pong, can lead to a failure of the model to predict subsequent frames successfully. We handle stochasticity problems such as these with a new video model architecture that does much better in this setting, inspired by previous work.
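The four-frame input can be maintained with a small sliding window. The sketch below uses placeholder strings instead of image arrays and pads the window with copies of the first frame at the start of an episode; the class name and the padding choice are assumptions for illustration.

```python
from collections import deque


class FrameHistory:
    """Keeps the most recent frames that a world model would take as input."""

    def __init__(self, first_frame, size=4):
        # Assumption: pad with copies of the first frame until enough frames exist.
        self.frames = deque([first_frame] * size, maxlen=size)

    def push(self, frame):
        self.frames.append(frame)  # the oldest frame is dropped automatically

    def model_input(self, action):
        return list(self.frames), action


history = FrameHistory("frame0")
for name in ["frame1", "frame2", "frame3", "frame4"]:
    history.push(name)
print(history.model_input("left"))
```

Anything that happened more than four frames ago is invisible to a model fed this way, which is the limitation the paragraph above describes.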
One example of an issue arising from stochasticity is seen when the SimPLe model is applied to Kung Fu Master. In the animation, the left is the output of the model, the middle is the ground truth, and the right panel is the pixel-wise difference between the two. Here the model's predictions deviate from the real game by spawning a different number of opponents.
At each iteration, after the world model is trained, we use this learned simulator to generate rollouts (i.e. sample sequences of actions, observations and outcomes) that are used to improve the game playing policy using the Proximal Policy Optimization (PPO) algorithm. One important detail for making SimPLe work is that the sampling of rollouts starts from the real dataset frames. Because prediction errors typically compound over time and make long-term predictions very difficult, SimPLe only uses medium-length rollouts. Luckily, the PPO algorithm can learn long-term effects between actions and rewards from its internal value function too, so rollouts of limited length are sufficient even for games with sparse rewards like Freeway.
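The reason a value function makes short rollouts workable can be shown with the standard bootstrapped return: rewards observed inside the rollout are summed with discounting, and the value estimate at the last simulated state stands in for everything that would have happened afterwards. This is a generic sketch of that idea, not the PPO implementation used in SimPLe; the function name and the numbers are illustrative.

```python
def bootstrapped_return(rewards, final_value, gamma=0.99):
    """Discounted return of a truncated rollout, completed by a value estimate."""
    total = final_value
    for reward in reversed(rewards):
        total = reward + gamma * total
    return total


# Hypothetical sparse-reward rollout: no reward seen within the rollout,
# but the value function expects a reward shortly after it ends.
print(bootstrapped_return([0.0, 0.0, 0.0], final_value=1.0, gamma=0.9))
```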
SimPLe Efficiency
One measure of success is to demonstrate that the model is highly efficient. For this, we evaluated the output of our policies after 100K interactions with the environment, which corresponds to roughly two hours of real-time game play by a person. We compare our SimPLe method with two state-of-the-art model-free RL methods, Rainbow and PPO, applied to 26 different games. In most cases, the SimPLe approach has a sample efficiency more than 2x better than the other methods.
The number of interactions needed by the respective model-free algorithms (left - Rainbow; right - PPO) to match the score achieved using our SimPLe training method. The red line indicates the number of interactions used by our method.
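The comparison in the figure boils down to a ratio: how many interactions a model-free method needs to reach the score SimPLe gets with its fixed budget of 100K. The baseline figure in the example below is a made-up number used only to show the arithmetic.

```python
SIMPLE_INTERACTIONS = 100_000  # budget used by SimPLe, as stated above


def efficiency_ratio(baseline_interactions):
    """How many times more interactions the baseline needs to match SimPLe's score."""
    return baseline_interactions / SIMPLE_INTERACTIONS


# Hypothetical baseline that needs 250K interactions on some game.
print(efficiency_ratio(250_000))
```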
SimPLe Success
An exciting result of the SimPLe approach is that for two of the games, Pong and Freeway, an agent trained in the simulated environment is able to achieve the maximum score. Here is a video of our agent playing the game using the game model that we learned for Pong:
For Freeway, Pong and Breakout, SimPLe can generate nearly pixel-perfect predictions up to 50 steps into the future, as shown below.
Nearly pixel-perfect predictions can be made by SimPLe, on Breakout (top) and Freeway (bottom). In each animation, the left is the output of the model, the middle is the ground truth, and the right pane is the pixel-wise difference between the two.
SimPLe Surprises
SimPLe does not always make correct predictions, however. The most common failure is due to the world model not accurately capturing or predicting small but highly relevant objects. Some examples are: (1) in Atlantis and Battlezone, bullets are so small that they tend to disappear, and (2) in Private Eye, the agent traverses different scenes, teleporting from one to the other. We found that our model generally struggled to capture such large global changes.
In Battlezone, we find the model struggles with predicting small, relevant parts, such as the bullet.
Conclusion
The main promise of model-based reinforcement learning methods is in environments where interactions are costly, slow or require human labeling, such as many robotics tasks. In such environments, a learned simulator would enable a better understanding of the agent's environment and could lead to new, better and faster ways of doing multi-task reinforcement learning. While SimPLe does not yet match the performance of standard model-free RL methods, it is substantially more sample-efficient, and we expect future work to further improve the performance of model-based techniques.
If you'd like to develop your own models and experiments, head to our
repository
and
colab
where you'll find instructions on how to reproduce our work along with pre-trained world models.
Acknowledgements
This work was done in collaboration with the University of Illinois at Urbana-Champaign, the University of Warsaw and deepsense.ai. We would like to give special recognition to paper co-authors Mohammad Babaeizadeh, Piotr Miłoś, Błażej Osiński, Roy H Campbell, Konrad Czechowski, Chelsea Finn, Piotr Kozakowski, Sergey Levine, Ryan Sepassi, George Tucker and Henryk Michalewski.
Reducing the Need for Labeled Data in Generative Adversarial Networks
Wednesday, March 20, 2019
Posted by Mario Lučić, Research Scientist and Marvin Ritter, Software Engineer, Google AI Zürich
Generative adversarial networks
(GANs) are a powerful class of deep generative models. The main idea behind
GANs
is to train two neural networks:
the generator,
which learns how to synthesise data (such as an image), and the
discriminator,
which learns how to distinguish real data from the ones synthesised by the generator. This approach has been successfully used for
high-fidelity natural image synthesis
,
improving learned image compression
, data augmentation, and more.
Evolution of the generated samples as training progresses on ImageNet. The generator network is conditioned on the class (e.g., "great gray owl" or "golden retriever").
For natural image synthesis, state-of-the-art results are achieved by
conditional GANs
that, unlike unconditional GANs, use labels (e.g. car or dog) during training. While this makes the task easier and leads to significant improvements, this approach requires a large amount of labeled data that is rarely available in practice.
In "
High-Fidelity Image Generation With Fewer Labels
", we propose a new approach to reduce the amount of labeled data required to train state-of-the-art conditional GANs. When combined with recent advancements on large-scale GANs, we match the state-of-the-art in high-fidelity natural image synthesis using
10x fewer labels
. Based on this research, we are also releasing a major update to the
Compare GAN library
, which contains all the components necessary to train and evaluate modern GANs.
Improvements via Semi-supervision and Self-supervision
In conditional GANs, both the generator and discriminator are typically conditioned on class labels. In this work, we propose to replace the hand-annotated ground truth labels with inferred ones. To infer high-quality labels for a large dataset of mostly unlabeled data, we take a two-step approach: First, we learn a feature representation using only the unlabeled portion of the dataset. To learn the feature representations we make use of
self-supervision
in the form of a
recently introduced
approach, in which the unlabeled images are randomly rotated and a deep convolutional neural network is tasked with predicting the rotation angle. The idea is that the models need to be able to recognize the main objects and their shapes in order to be successful on this task.
An unlabeled image is randomly rotated and the network is tasked with predicting the rotation angle. Successful models need to capture semantically meaningful image features which can then be used for other vision tasks.
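As a minimal sketch of how such a rotation task can be set up, the snippet below builds one training example from an unlabeled image. It assumes rotations are restricted to multiples of 90 degrees and represents an image as a list of rows; the function names are illustrative and not taken from the paper's code.

```python
import random

def rotate90(image):
    """Rotate a 2D grid (list of rows) by 90 degrees clockwise."""
    return [list(row) for row in zip(*image[::-1])]

def make_rotation_example(image, rng):
    """Return (rotated_image, label), where label k means a rotation of k * 90 degrees."""
    k = rng.randrange(4)
    rotated = [list(row) for row in image]
    for _ in range(k):
        rotated = rotate90(rotated)
    return rotated, k

unlabeled_image = [[1, 2, 3],
                   [4, 5, 6]]
rotated_image, rotation_label = make_rotation_example(unlabeled_image, random.Random(0))
```

The label comes for free from the transformation itself, which is why no human annotation is needed for this stage.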
We then consider the activation pattern of one of the intermediate layers of the trained network as the new feature representation of the input, and train a classifier to recognize the label of that input using the labeled portion of the original data set. As the network was pre-trained to extract semantically meaningful features from the data (on the rotation prediction task), training this classifier is more sample-efficient than training the entire network from scratch. Finally, we use this classifier to label the unlabeled data.
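The second step can be sketched as follows. A nearest-centroid classifier is used here purely as a simple stand-in for the classifier trained on the learned features, and the two-dimensional feature vectors and class names are made up for illustration.

```python
def fit_centroids(features, labels):
    """Average the feature vectors of each class (the small labeled subset)."""
    sums, counts = {}, {}
    for vec, label in zip(features, labels):
        acc = sums.setdefault(label, [0.0] * len(vec))
        for i, value in enumerate(vec):
            acc[i] += value
        counts[label] = counts.get(label, 0) + 1
    return {label: [v / counts[label] for v in acc] for label, acc in sums.items()}

def predict(centroids, vec):
    """Assign the class whose centroid is closest in squared Euclidean distance."""
    def sq_dist(center):
        return sum((a - b) ** 2 for a, b in zip(vec, center))
    return min(centroids, key=lambda label: sq_dist(centroids[label]))

# Hypothetical features, as if read from an intermediate layer of the pre-trained network.
labeled_features = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [6.0, 5.0]]
labeled_classes = ["owl", "owl", "dog", "dog"]
unlabeled_features = [[0.2, 0.4], [5.5, 4.8]]

centroids = fit_centroids(labeled_features, labeled_classes)
inferred_labels = [predict(centroids, vec) for vec in unlabeled_features]
```

The inferred labels then play the role that hand-annotated labels would otherwise play when conditioning the GAN.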
To further improve the model quality and training stability we encourage the discriminator network to learn meaningful feature representations which are not forgotten during training by means of an auxiliary loss we introduced
previously
. These two advancements, combined with large-scale training, lead to state-of-the-art conditional GANs for the task of ImageNet synthesis as measured by the
Fréchet Inception Distance
.
Given a latent vector, the generator network produces an image. In each row, linear interpolation between the latent codes of the leftmost and the rightmost image results in a semantic interpolation in the image space.
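The latent-space interpolation mentioned in the caption can be written in a few lines. This is a sketch with a hypothetical helper name; each returned vector would be fed to the generator to produce one image in the row.

```python
def interpolate_latents(z_start, z_end, num_steps):
    """Return num_steps latent vectors evenly spaced from z_start to z_end (inclusive)."""
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    path = []
    for step in range(num_steps):
        t = step / (num_steps - 1)   # t runs from 0.0 to 1.0
        path.append([(1 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return path

latent_path = interpolate_latents([0.0, 2.0], [1.0, 0.0], 3)
```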
Compare GAN: A Library for Training and Evaluating GANs
Cutting-edge research on GANs is heavily dependent on a well-engineered and well-tested codebase, since even replicating prior results and techniques requires a significant effort. In order to foster open science and allow the research community to benefit from recent advancements, we are releasing a major update of the
Compare GAN
library. The library includes loss functions, regularization and normalization schemes, neural architectures, and quantitative metrics commonly used in modern GANs, and now supports:
Training on GPUs and TPUs.
Lightweight configuration via
Gin
(
examples
).
A plethora of data sets via the
TensorFlow datasets
library.
Conclusions and Future Work
Given the growing gap between labeled and unlabeled data sources, it is becoming
increasingly important
to be able to learn from only partially labeled data. We have shown that a simple yet powerful combination of self-supervision and semi-supervision can help to close this gap for GANs. We believe that self-supervision is a powerful idea that should be investigated for other generative modeling tasks.
Acknowledgments
Work conducted in collaboration with colleagues on the Google Brain team in Zürich, ETH Zürich and UCLA. We would like to thank our paper co-authors Michael Tschannen, Xiaohua Zhai, Olivier Bachem and Sylvain Gelly for their input and feedback. We would like to thank Alexander Kolesnikov, Lucas Beyer and Avital Oliver for helpful discussion on
self-supervised learning
and semi-supervised learning. We would like to thank Karol Kurach and Marcin Michalski for their major contributions to the Compare GAN library. We would also like to thank Andy Brock, Jeff Donahue and Karen Simonyan for their insights into training GANs on TPUs. The work described in this post also builds upon our work on “
Self-Supervised Generative Adversarial Networks
” with Ting Chen and Neil Houlsby.
Measuring the Limits of Data Parallel Training for Neural Networks
Tuesday, March 19, 2019
Posted by Chris Shallue, Senior Software Engineer and George Dahl, Senior Research Scientist, Google AI
Over the past decade, neural networks have achieved state-of-the-art results in a wide variety of prediction tasks, including
image classification
,
machine translation
, and
speech recognition
. These successes have been driven, at least in part, by hardware and software improvements that have significantly accelerated neural network training. Faster training has directly resulted in dramatic improvements to model quality, both by allowing more training data to be processed and by allowing researchers to try new ideas and configurations more rapidly. Today, hardware developments like
Cloud TPU Pods
are rapidly increasing the amount of computation available for neural network training, which raises the possibility of harnessing additional computation to make neural networks train even faster and facilitate even greater improvements to model quality. But how exactly should we harness this unprecedented amount of computation, and should we always expect more computation to facilitate faster training?
The most common way to utilize massive compute power is to
distribute computations
between different processors and perform those computations simultaneously. When training neural networks, the primary ways to achieve this are
model parallelism
, which involves distributing the neural network across different processors, and
data parallelism
, which involves distributing training examples across different processors and computing updates to the neural network in parallel. While model parallelism makes it possible to train neural networks that are larger than a single processor can support, it usually requires tailoring the model architecture to the available hardware. In contrast, data parallelism is model agnostic and applicable to any neural network architecture – it is the simplest and most widely used technique for parallelizing neural network training. For the most common neural network training algorithms (synchronous
stochastic gradient descent
and its variants), the scale of data parallelism corresponds to the
batch size
, the number of training examples used to compute each update to the neural network. But what are the limits of this type of parallelization, and when should we expect to see large speedups?
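To make the link between data parallelism and batch size concrete, here is a minimal sketch of one synchronous SGD step for a one-parameter linear model. The model, the data and the function names are illustrative assumptions; the point is only that averaging the gradients of equally sized shards gives the same update as computing the gradient on the whole batch.

```python
def gradient(w, batch):
    """Gradient of the mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def data_parallel_step(w, batch, num_workers, learning_rate):
    """One synchronous SGD step with the batch split evenly across workers."""
    shard_size = len(batch) // num_workers   # assumes the batch divides evenly
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]
    worker_grads = [gradient(w, shard) for shard in shards]   # run in parallel in practice
    avg_grad = sum(worker_grads) / num_workers                # synchronization point
    return w - learning_rate * avg_grad

batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w_single = 0.0 - 0.01 * gradient(0.0, batch)
w_parallel = data_parallel_step(0.0, batch, num_workers=2, learning_rate=0.01)
```

Adding workers while keeping the per-worker shard size fixed therefore amounts to increasing the batch size.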
In "
Measuring the Effects of Data Parallelism in Neural Network Training
", we investigate the relationship between batch size and training time by running experiments on six different types of neural networks across seven different datasets using three different optimization algorithms ("optimizers"). In total, we trained over 100K individual models across ~450 workloads, and observed a seemingly universal relationship between batch size and training time across all workloads we tested. We also studied how this relationship varies with the dataset, neural network architecture, and optimizer, and found extremely large variation between workloads. Additionally, we are excited to
share our raw data
for further analysis by the research community. The data includes over 71M model evaluations to make up the training curves of all 100K+ individual models we trained, and can be used to reproduce all 24 plots in
our paper
.
Universal Relationship Between Batch Size and Training Time
In an idealized data parallel system that spends negligible time synchronizing between processors, training time can be measured in the number of training
steps
(updates to the neural network's parameters). Under this assumption, we observed three distinct scaling regimes in the relationship between batch size and training time: a "perfect scaling" regime where doubling the batch size halves the number of training steps required to reach a target
out-of-sample error
, followed by a regime of "diminishing returns", and finally a "maximal data parallelism" regime where further increasing the batch size does not reduce training time, even assuming idealized hardware.
For all workloads we tested, we observed a universal relationship between batch size and training speed with three distinct regimes: perfect scaling (following the dashed line), diminishing returns (diverging from the dashed line), and maximal data parallelism (where the trend plateaus). The transition points between the regimes vary dramatically between different workloads.
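The three regimes can be told apart by looking at what happens to the number of steps each time the batch size doubles. The sketch below does this for a set of measurements; the numbers are invented for illustration and are not taken from our experiments, and the 5% tolerance is an arbitrary assumption.

```python
def classify_doubling(steps_before, steps_after, tolerance=0.05):
    """Classify the effect of doubling the batch size on steps-to-target."""
    ratio = steps_after / steps_before
    if ratio <= 0.5 + tolerance:
        return "perfect scaling"            # steps are (nearly) halved
    if ratio >= 1.0 - tolerance:
        return "maximal data parallelism"   # steps no longer decrease
    return "diminishing returns"

# Hypothetical measurements: batch size -> steps needed to reach the target error.
measurements = {64: 8000, 128: 4000, 256: 2000, 512: 1400, 1024: 1100, 2048: 1090}

batch_sizes = sorted(measurements)
regimes = [
    classify_doubling(measurements[small], measurements[large])
    for small, large in zip(batch_sizes, batch_sizes[1:])
]
```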
Although the basic relationship between batch size and training time appears to be universal, we found that the transition points between the different scaling regimes vary dramatically across neural network architectures and datasets. This means that while simple data parallelism can provide large speedups for some workloads at the limits of today's hardware (e.g.
Cloud TPU Pods
), and perhaps beyond, some workloads require moving beyond simple data parallelism in order to benefit from the largest scale hardware that exists today, let alone hardware that has yet to be built. For example, in the plot above,
ResNet-8
on
CIFAR-10
cannot benefit from batch sizes larger than 1,024, whereas
ResNet-50
on
ImageNet
continues to benefit from increasing the batch size up to at least 65,536.
Optimizing Workloads
If one could predict which workloads benefit most from data parallel training, then one could tailor their workloads to make maximal use of the available hardware. However, our results suggest that this will often not be straightforward, because the maximum useful batch size depends, at least somewhat, on every aspect of the workload: the neural network architecture, the dataset, and the optimizer. For example, some neural network architectures can benefit from much larger batch sizes than others, even when trained on the same dataset with the same optimizer. Although this effect sometimes depends on the width and depth of the network, it is inconsistent between different types of network and some networks do not even have obvious notions of "width" and "depth". And while we found that some datasets can benefit from much larger batch sizes than others, these differences are not always explained by the size of the dataset—sometimes smaller datasets benefit more from larger batch sizes than larger datasets.
Left:
A
transformer
neural network scales to much larger batch sizes than an
LSTM
neural network on the
LM1B
dataset.
Right:
The
Common Crawl
dataset does not benefit from larger batch sizes than the LM1B dataset, even though it is 1,000 times the size.
Perhaps our most promising finding is that even small changes to the optimization algorithm, such as allowing
momentum
in stochastic gradient descent, can dramatically improve how well training scales with increasing batch size. This raises the possibility of designing new optimizers, or testing the scaling properties of optimizers that we did not consider, to find optimizers that can make maximal use of increased data parallelism.
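For reference, the change from plain SGD to SGD with momentum is small. The sketch below uses one common formulation of the update (conventions differ slightly between libraries) and is not the exact configuration used in our experiments.

```python
def sgd_momentum_step(w, velocity, grad, learning_rate, momentum):
    """One SGD update with momentum; momentum=0.0 recovers plain SGD."""
    velocity = momentum * velocity + grad   # running accumulation of past gradients
    w = w - learning_rate * velocity
    return w, velocity

# Two steps with a constant gradient of 2.0, starting from w = 1.0.
w, v = 1.0, 0.0
w, v = sgd_momentum_step(w, v, grad=2.0, learning_rate=0.1, momentum=0.9)
w, v = sgd_momentum_step(w, v, grad=2.0, learning_rate=0.1, momentum=0.9)
```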
Future Work
Utilizing additional data parallelism by increasing the batch size is a simple way to produce valuable speedups across a range of workloads, but, for all the workloads we tried, the benefits diminished within the limits of state-of-the-art hardware. However, our results suggest that some optimization algorithms may be able to consistently extend the perfect scaling regime across many models and data sets. Future work could perform the same measurements with other optimizers, beyond the few closely-related ones we tried, to see if any existing optimizer extends perfect scaling across many problems.
Acknowledgements
The authors of this study were Chris Shallue, Jaehoon Lee, Joe Antognini, Jascha Sohl-Dickstein, Roy Frostig and George Dahl (Chris and Jaehoon contributed equally). Many researchers have done work in this area that we have built on, so please see
our paper
for a full discussion of related work.
A Summary of the Google Flood Forecasting Meets Machine Learning Workshop
Monday, March 18, 2019
Posted by Sella Nevo, Senior Software Engineer and Rainier Aliment, Program Manager
Recently, we hosted the
Google Flood Forecasting Meets Machine Learning
workshop in our Tel Aviv office, which brought together hydrology and machine learning experts from Google and the broader research community to discuss existing efforts in this space, build a common vocabulary between these groups, and catalyze promising collaborations. In line with
our belief
that machine learning has the potential to significantly improve flood forecasting efforts and help the hundreds of millions of people affected by floods every year, this workshop discussed improving flood forecasting by aggregating and sharing large data sets, automating calibration and modeling processes, and applying modern statistical and machine learning tools to the problem.
Panel on challenges and opportunities in flood forecasting, featuring (from left to right):
Prof. Paolo Burlando
(ETH Zürich),
Dr. Tyler Erickson
(Google Earth Engine),
Dr. Peter Salamon
(Joint Research Centre) and
Prof. Dawei Han
(University of Bristol).
The event was kicked off by Google's
Yossi Matias
, who discussed
recent
machine
learning
work and its potential relevance for
flood forecasting
,
crisis response
and
AI for Social Good
, followed by two introductory sessions aimed at bridging some of the knowledge gap between the two fields: an introduction to hydrology for computer scientists by
Prof. Peter Molnar
of ETH Zürich, and an introduction to machine learning for hydrologists by
Prof. Yishay Mansour
of Tel Aviv University and
Google
.
Included in the two-day event was a wide range of
fascinating talks
and
posters
across the flood forecasting landscape, from both hydrologic and machine learning points of view.
An overview of research areas in flood forecasting addressed in the workshop.
Presentations from the research community included:
Dr. Dhanya C. T.
of IIT Delhi gave a talk on satellite precipitation error characterization.
Adarsh M. S.
, Assistant Director of the Indian Ministry of Water Resources, presented India's
Central Water Commission's
role and challenges.
Prof. Andras Bardossy
of the University of Stuttgart discussed variation in discharge series and the challenges this presents.
Frederik Kratzert
of Johannes Kepler University presented recent work on hydrologic modeling using
LSTMs
.
Prof. Paul Bates
of the University of Bristol gave a keynote on the potential uses of machine learning in inundation modelling.
Prof. Emmanouil Anagnostou
of the University of Connecticut spoke about hyper-resolution hydrologic simulations at global scale.
Prof. Efrat Morin
of the Hebrew University highlighted flood prediction challenges in dry climate regions.
Dr. Zachary Flamig
of the University of Chicago presented NASA's new
global flash flood prediction project
.
Alongside these talks, we presented the various efforts across Google to try and improve flood forecasting and foster collaborations in the field, including:
Vova Anisimov presented our progress in hydraulic modeling.
Ami Weisel presented our research on
remote discharge estimation
.
Stephan Hoyer presented our work on a
data-driven discretization approach to solving partial differential equations
.
Jason Hickey presented our efforts using machine learning for precipitation prediction.
Avinatan Hassidim presented lessons learned from
previous
projects
in
Google
, and how they apply to our
flood forecasting efforts
.
Additionally, at this workshop we piloted an experimental "ML Consultation" panel, where Googlers Gal Elidan, Sasha Goldshtein and Doron Kukliansky gave advice on how to best use machine learning in several hydrology-related tasks. Finally, we concluded the workshop with a moderated panel on the greatest challenges and opportunities in flood forecasting, with hydrology experts
Prof. Paolo Burlando
of ETH Zürich,
Prof. Dawei Han
of the University of Bristol,
Dr. Peter Salamon
of the Joint Research Centre and
Dr. Tyler Erickson
of Google Earth Engine.
Flood forecasting is an incredibly important and challenging task that is one part of our larger
AI for Social Good
efforts. We believe that effective global-scale solutions can be achieved by combining modern techniques with the domain expertise already existing in the field. The workshop was a great first step towards creating much-needed understanding, communication and collaboration between the flood forecasting community and the machine learning community, and we look forward to
our continued engagement
with the broad research community to tackle this challenge.
Acknowledgements
We would like to thank Avinatan Hassidim, Carla Bromberg, Doron Kukliansky, Efrat Morin, Gal Elidan, Guy Shalev, Jennifer Ye, Nadav Rabani and Sasha Goldshtein for their contributions to making this workshop happen.
Google Faculty Research Awards 2018
Friday, March 15, 2019
Posted by Maggie Johnson, VP, Education and Negar Saei, Program Manager, University Relations
We just completed another round of the
Google Faculty Research Awards
, our annual open call for proposals on computer science and related topics, such as quantum computing, machine learning, algorithms and theory, natural language processing and more. Our grants cover tuition for a graduate student and provide both faculty and students the opportunity to work directly with Google researchers and engineers.
This round we received 910 proposals covering 40 countries and over 320 universities. After expert reviews and committee discussions, we decided to fund 158 projects. The subject areas that received the most support this year were
human computer interaction
,
machine learning
,
machine perception
, and
systems
.
Congratulations to the well-deserving
recipients
of this round's awards. More information on how to apply for the next round will be available at the end of the summer on
our website
. You can find award recipients from previous years
here
.
Harnessing Organizational Knowledge for Machine Learning
Thursday, March 14, 2019
Posted by Alex Ratner, Stanford University and Cassandra Xia, Google AI
One of the biggest bottlenecks in developing machine learning (ML) applications is the need for the large, labeled datasets used to train modern ML models. Creating these datasets involves the investment of significant time and expense, requiring annotators with the right expertise. Moreover, due to the evolution of real-world applications, labeled datasets often need to be thrown out or re-labeled.
In collaboration with Stanford and Brown University, we present "
Snorkel DryBell: A Case Study in Deploying Weak Supervision at Industrial Scale
," which explores how existing knowledge in an organization can be used as noisier, higher-level supervision (or, as it is often termed,
weak supervision
) to quickly label large training datasets. In this study, we use an experimental internal system, Snorkel DryBell, which adapts the open-source
Snorkel
framework to use diverse
organizational knowledge resources—
like internal models, ontologies, legacy rules, knowledge graphs and more—in order to generate training data for machine learning models at web scale. We find that this approach can match the efficacy of hand-labeling tens of thousands of data points, and reveals some core lessons about how training datasets for modern machine learning models can be created in practice.
Rather than labeling training data by hand, Snorkel DryBell enables writing
labeling functions
that label training data programmatically. In this work, we explored how these labeling functions can capture engineers' knowledge about how to use existing resources as heuristics for weak supervision. As an example, suppose our goal is to identify content related to celebrities. One can leverage an existing
named-entity recognition
(NER) model for this task by labeling any content that does not contain a person as not related to celebrities. This illustrates how existing knowledge resources (in this case, a trained model) can be combined with simple programmatic logic to label training data for a new model. Note also, importantly, that this labeling function returns None (i.e. abstains) in many cases, and thus only labels some small part of the data; our overall goal is to use these labels to train a modern machine learning model that can generalize to new data.
In our example of a labeling function, rather than hand-labeling a data point (1), one utilizes an existing knowledge resource—in this case, a NER model (2)—together with some simple logic expressed in code (3) to heuristically label data.
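A labeling function of this kind might look like the following sketch. The `toy_ner_model` helper is a hypothetical stand-in for an existing NER model, and the label constants are our own naming, not part of the Snorkel DryBell interface.

```python
ABSTAIN = None
NOT_CELEBRITY = 0

def toy_ner_model(text):
    """Hypothetical stand-in for a NER model: returns (span, entity_type) pairs."""
    known_people = ["Ada Lovelace", "Alan Turing"]
    return [(name, "PERSON") for name in known_people if name in text]

def lf_no_person_not_celebrity(text):
    """Label content with no detected person as unrelated to celebrities."""
    entities = toy_ner_model(text)
    if not any(entity_type == "PERSON" for _, entity_type in entities):
        return NOT_CELEBRITY
    return ABSTAIN   # a person was found, so this heuristic has nothing to say

label_weather = lf_no_person_not_celebrity("The weather is sunny today")
label_talk = lf_no_person_not_celebrity("Alan Turing gave a talk")
```

Note that the function never asserts that content is about a celebrity; it only rules out the negative case and abstains otherwise.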
This programmatic interface for labeling training data is much faster and more flexible than hand-labeling individual data points, but the resulting labels are obviously of much lower quality than manually-specified labels. The labels generated by these labeling functions will often overlap and disagree, as the labeling functions may not only have arbitrary unknown accuracies, but may also be correlated in arbitrary ways (for example, from sharing a common data source or heuristic).
To solve the problem of noisy and correlated labels, Snorkel DryBell uses a
generative modeling technique
to automatically estimate the accuracies and correlations of the labeling functions in a provably consistent way—without any ground truth training labels—then uses this to re-weight and combine their outputs into a single probabilistic label per data point. At a high level, we rely on the observed agreements and disagreements between the labeling functions (the
covariance matrix
), and learn the labeling function accuracy and correlation parameters that best explain this observed output
using a new matrix completion-style approach
. The resulting labels can then be used to train an arbitrary model (e.g. in
TensorFlow
), as shown in the system diagram below.
Using Diverse Knowledge Sources as Weak Supervision
To study the efficacy of Snorkel DryBell, we used three production tasks and corresponding datasets, aimed at classifying topics in web content, identifying mentions of certain products, and detecting certain real-time events. Using Snorkel DryBell, we were able to make use of various existing or quickly specified sources of information such as:
Heuristics and rules
:
e.g. existing human-authored rules about the target domain.
Topic models, taggers, and classifiers:
e.g. machine learning models about the target domain or a related domain.
Aggregate statistics:
e.g. tracked metrics about the target domain.
Knowledge or entity graphs:
e.g. databases of facts about the target domain.
In Snorkel DryBell, the goal is to train a machine learning model (C), for example to do content or event classification over web data. Rather than hand-labeling training data to do this, in Snorkel DryBell users write labeling functions that express various organizational knowledge resources (A), which are then automatically reweighted and combined (B).
We used these organizational knowledge resources to write labeling functions in a
MapReduce
template-based pipeline. Each labeling function takes in a data point and either abstains, or outputs a label. The result is a large set of programmatically-generated training labels. However, many of these labels were very noisy (e.g. from the heuristics), conflicted with each other, or were far too coarse-grained (e.g. the topic models) for our task, leading to the next stage of Snorkel DryBell, aimed at automatically cleaning and integrating the labels into a final training set.
Modeling the Accuracies to Combine & Repurpose Existing Sources
To handle these noisy labels, the next stage of Snorkel DryBell combines the outputs from the labeling functions into a single, confidence-weighted training label for each data point. The challenging technical aspect is that this must be done without any ground-truth labels. We use a
generative modeling technique
that learns the accuracy of each labeling function using only unlabeled data. This technique learns by observing the matrix of agreements and disagreements between the labeling functions' outputs, taking into account known (or
statistically estimated
) correlation structures between them. In Snorkel DryBell, we also use a new, faster, sampling-free version of this modeling approach, implemented in TensorFlow, in order to handle web-scale data.
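To give a feel for what a confidence-weighted label is, the sketch below combines the votes of several labeling functions for one data point. It is a deliberate simplification and not the Snorkel DryBell model: it assumes binary labels, independent labeling functions, a uniform prior, and accuracies that are already known, whereas the real system estimates accuracies and correlations from unlabeled data.

```python
import math

def probabilistic_label(votes, accuracies):
    """Return P(y = +1) from votes in {+1, -1, None}, where None means abstain.

    Assumes independent labeling functions with known accuracies in (0, 1)
    and a uniform prior over the two classes.
    """
    log_odds = 0.0
    for vote, accuracy in zip(votes, accuracies):
        if vote is None:
            continue   # abstaining functions contribute no evidence
        log_odds += vote * math.log(accuracy / (1 - accuracy))
    return 1 / (1 + math.exp(-log_odds))

# A 90%-accurate function votes +1 and a 60%-accurate one votes -1.
p_conflict = probabilistic_label([1, -1], [0.9, 0.6])
p_all_abstain = probabilistic_label([None, None], [0.9, 0.6])
```

In the conflicting case the more accurate function dominates, so the combined label leans towards +1 without being fully certain.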
By combining and modeling the output of the labeling functions using this procedure in Snorkel DryBell, we were able to generate high-quality training labels. In fact, on the two applications where hand-labeled training data was available for comparison, we achieved the same predictive accuracy training a model with Snorkel DryBell's labels as we did when training that same model with 12,000 and 80,000 hand-labeled training data points, respectively.
Transferring Non-Servable Knowledge to Servable Models
In many settings, there is also an important distinction between
servable
features—which can be used in production—and
non-servable
features, which are too slow or expensive to be used in production. These non-servable features may carry very rich signal, but a general question is how to use them to train or otherwise help servable models that can be deployed in production.
In many settings, users write labeling functions that leverage organizational knowledge resources that are not servable in production (a)—e.g. aggregate statistics, internal models, or knowledge graphs that are too slow or expensive to use in production—in order to train models that are only defined over production-servable features (b), e.g. cheap, real-time web signals.
In Snorkel DryBell, we found that users could write the labeling functions—i.e. express their organizational knowledge—over one feature set that was not servable, and then use the resulting training labels output by Snorkel DryBell to train a model defined over a different,
servable
feature set. This
cross-feature
transfer boosted our performance by an average of 52% on the benchmark datasets we created. More broadly, it represents a simple but powerful way to use resources that are too slow (e.g. expensive models or aggregate statistics), private (e.g. entity or knowledge graphs), or otherwise unsuitable for deployment, to train servable models over cheap, real-time features. This approach can be viewed as a new type of transfer learning, where instead of transferring a model between different datasets, we're transferring domain knowledge between different feature sets. This has potential use cases not just in industry, but in medical settings and beyond.
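The idea can be illustrated with a toy sketch in which labels are produced from one feature and the model is trained on another. The feature names `expensive_score` and `cheap_signal`, the thresholds and the data are all hypothetical, and the midpoint-threshold "model" is only a stand-in for a real classifier; it also assumes the cheap signal tends to be higher for positive examples.

```python
def label_from_non_servable(example):
    """Labeling function over a feature that is not available in production."""
    score = example["expensive_score"]
    if score >= 0.8:
        return 1
    if score <= 0.2:
        return 0
    return None   # abstain when the expensive signal is inconclusive

def fit_threshold(examples, labels):
    """Train a one-feature model on the servable signal only."""
    positives = [e["cheap_signal"] for e, y in zip(examples, labels) if y == 1]
    negatives = [e["cheap_signal"] for e, y in zip(examples, labels) if y == 0]
    return (sum(positives) / len(positives) + sum(negatives) / len(negatives)) / 2

def predict(threshold, cheap_signal):
    """Serving-time prediction: needs only the cheap feature."""
    return 1 if cheap_signal >= threshold else 0

examples = [
    {"expensive_score": 0.90, "cheap_signal": 7.0},
    {"expensive_score": 0.95, "cheap_signal": 9.0},
    {"expensive_score": 0.10, "cheap_signal": 1.0},
    {"expensive_score": 0.05, "cheap_signal": 3.0},
    {"expensive_score": 0.50, "cheap_signal": 5.0},
]
labels = [label_from_non_servable(e) for e in examples]
threshold = fit_threshold(examples, labels)
```

At serving time `predict` never touches `expensive_score`, which is the property that makes the trained model deployable.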
Next Steps
Moving forward, we're excited to see what other types of organizational knowledge can be used as weak supervision, and how the approach used by Snorkel DryBell can enable new modes of information reuse and sharing across organizations. For more details, check out
our paper
, and for further technical details, blog posts, and tutorials, check out the open-source Snorkel implementation at
snorkel.stanford.edu
.
Acknowledgments
This research was done in collaboration between Google, Stanford, and Brown. We would like to thank all the people who were involved, including Stephen Bach (Brown), Daniel Rodriguez, Yintao Liu, Chong Luo, Haidong Shao, Souvik Sen, Braden Hancock (Stanford), Houman Alborzi, Rahul Kuchhal, Christopher Ré (Stanford) and Rob Malkin.
An All-Neural On-Device Speech Recognizer
Tuesday, March 12, 2019
Posted by Johan Schalkwyk, Google Fellow, Speech Team
In 2012, speech recognition research showed
significant accuracy improvements with deep learning
, leading to early adoption in products such as Google's
Voice Search
. It was the beginning of a revolution in the field: each year, new architectures were developed that further increased quality, from
deep neural networks
(DNNs) to
recurrent neural networks
(RNNs),
long short-term memory networks
(LSTMs),
convolutional networks
(CNNs), and more. During this time, latency remained a prime focus — an automated assistant feels a lot more helpful when it responds quickly to requests.
Today, we're happy to announce the rollout of an end-to-end, all-neural, on-device speech recognizer to power speech input in Gboard. In our recent paper, "
Streaming End-to-End Speech Recognition for Mobile Devices
", we present a model trained using
RNN transducer
(RNN-T) technology that is compact enough to reside on a phone. This means no more network latency or spottiness — the new recognizer is always available, even when you are offline. The model works at the character level, so that as you speak, it outputs words character-by-character, just as if someone was typing out what you say in real-time, and exactly as you'd expect from a keyboard dictation system.
This video compares the production, server-side speech recognizer (left panel) to the new on-device recognizer (right panel) when recognizing the same spoken sentence. Video credit: Akshay Kannan and Elnaz Sarbar
A Bit of History
Traditionally, speech recognition systems consisted of several components: an acoustic model that maps segments of audio (typically 10 millisecond frames) to phonemes, a pronunciation model that connects phonemes together to form words, and a language model that expresses the likelihood of given phrases. In early systems, these components were optimized independently.
Around 2014, researchers began to focus on training a single neural network to directly map an input audio waveform to an output sentence. This sequence-to-sequence approach, which learns a model by generating a sequence of words or graphemes given a sequence of audio features, led to the development of "attention-based" and "listen-attend-spell" models. While these models showed great promise in terms of accuracy, they typically work by reviewing the entire input sequence and do not allow streaming outputs as the input comes in, a necessary feature for real-time voice transcription.
Meanwhile, an independent technique called connectionist temporal classification (CTC) had helped halve the latency of the production recognizer at that time. This proved to be an important step in creating the RNN-T architecture adopted in this latest release, which can be seen as a generalization of CTC.
Recurrent Neural Network Transducers
RNN-Ts are a form of sequence-to-sequence model that does not employ attention mechanisms. Unlike most sequence-to-sequence models, which typically need to process the entire input sequence (the waveform in our case) to produce an output (the sentence), the RNN-T continuously processes input samples and streams output symbols, a property that is welcome for speech dictation. In our implementation, the output symbols are the characters of the alphabet. The RNN-T recognizer outputs characters one by one, as you speak, with white spaces where appropriate. It does this with a feedback loop that feeds symbols predicted by the model back into it to predict the next symbols, as shown in the figure below.
Representation of an RNN-T, with the input audio samples, x, and the predicted symbols, y. The predicted symbols (outputs of the Softmax layer) are fed back into the model through the Prediction Network, as y_{u-1}, ensuring that the predictions are conditioned both on the audio samples so far and on past outputs. The Prediction and Encoder Networks are LSTM RNNs; the Joint model is a feedforward network (paper). The Prediction Network comprises 2 layers of 2048 units, with a 640-dimensional projection layer. The Encoder Network comprises 8 such layers. Image credit: Chris Thornton
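To make the feedback loop concrete, here is a minimal Python sketch of a greedy streaming decode in the style of a transducer. It is an illustration only, not the recognizer described in this post: the names `greedy_transducer_decode`, `toy_joint_step`, `make_frames` and the `<blank>` symbol are assumptions, and the "frames" are toy tuples rather than audio. The point it shows is that each output character depends on both the current frame and the characters emitted so far, and that characters are emitted as soon as they are predicted.

```python
from typing import Callable, Iterator, List, Sequence, Tuple

BLANK = "<blank>"  # hypothetical "no output" symbol: move on to the next frame

# Toy stand-in for an audio frame: (characters expected before it, characters in it).
Frame = Tuple[int, str]


def greedy_transducer_decode(
    frames: Sequence[Frame],
    joint_step: Callable[[Frame, List[str]], str],
    max_symbols_per_frame: int = 5,
) -> Iterator[str]:
    """Consume frames one at a time and yield symbols as soon as they are predicted."""
    history: List[str] = []  # past outputs, fed back into every prediction
    for frame in frames:
        for _ in range(max_symbols_per_frame):
            symbol = joint_step(frame, history)
            if symbol == BLANK:
                break  # nothing more to emit for this frame
            history.append(symbol)
            yield symbol


def toy_joint_step(frame: Frame, history: List[str]) -> str:
    """Deterministic stand-in for the encoder, prediction and joint networks."""
    offset, text = frame
    position = len(history) - offset  # depends on the history, not just the frame
    return text[position] if 0 <= position < len(text) else BLANK


def make_frames(chunks: Sequence[str]) -> List[Frame]:
    """Build toy frames from the characters 'heard' in each time slice."""
    frames, offset = [], 0
    for chunk in chunks:
        frames.append((offset, chunk))
        offset += len(chunk)
    return frames


frames = make_frames(["h", "", "i ", "yo", "u"])
transcript = "".join(greedy_transducer_decode(frames, toy_joint_step))
print(transcript)  # hi you
```

A real system would replace `toy_joint_step` with trained networks that produce a probability distribution over characters, and would typically keep several hypotheses rather than only the single best one.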
Training such models efficiently was already difficult, but with our development of a new training technique that further reduced the word error rate by 5%, it became even more computationally intensive. To deal with this, we developed a parallel implementation so the RNN-T loss function could run efficiently in large batches on Google's high-performance Cloud TPU v2 hardware. This yielded an approximate 3x speedup in training.
Offline Recognition
In a traditional speech recognition engine, the acoustic, pronunciation, and language models we described above are "composed" together into a large search graph whose edges are labeled with the speech units and their probabilities. When a speech waveform is presented to the recognizer, a "decoder" searches this graph for the path of highest likelihood, given the input signal, and reads out the word sequence that path takes. Typically, the decoder assumes a Finite State Transducer (FST) representation of the underlying models. Yet, despite sophisticated decoding techniques, the search graph remains quite large, almost 2GB for our production models. Since this is not something that could be hosted easily on a mobile phone, this method requires online connectivity to work properly.
To improve the usefulness of speech recognition, we sought to avoid the latency and inherent unreliability of communication networks by hosting the new models directly on device. As such, our end-to-end approach does not need a search over a large decoder graph. Instead, decoding consists of a beam search through a single neural network. The RNN-T we trained offers the same accuracy as the traditional server-based models but is only 450MB, making smarter use of parameters and packing information more densely. However, even on today's smartphones, 450MB is a lot, and propagating signals through such a large network can be slow.
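For readers unfamiliar with beam search, the following Python sketch shows the general idea: keep only the few best-scoring partial outputs at each step instead of exploring a large graph. This is a generic beam search over a hypothetical next-symbol scorer, not the RNN-T-specific search from the paper; `beam_search`, `toy_next_log_probs`, the `</s>` end symbol and the probability table are all assumptions made for illustration.

```python
import math
from typing import Callable, Dict, List, Tuple

Hypothesis = Tuple[Tuple[str, ...], float]  # (symbols so far, total log-probability)


def beam_search(
    next_log_probs: Callable[[Tuple[str, ...]], Dict[str, float]],
    beam_width: int,
    max_len: int,
    end_symbol: str = "</s>",
) -> List[Hypothesis]:
    """Return the surviving hypotheses, best first."""
    beams: List[Hypothesis] = [((), 0.0)]
    for _ in range(max_len):
        candidates: List[Hypothesis] = []
        for prefix, score in beams:
            if prefix and prefix[-1] == end_symbol:
                candidates.append((prefix, score))  # finished: carry over unchanged
                continue
            for symbol, logp in next_log_probs(prefix).items():
                candidates.append((prefix + (symbol,), score + logp))
        # Keep only the `beam_width` highest-scoring hypotheses.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
        if all(p and p[-1] == end_symbol for p, _ in beams):
            break
    return beams


def toy_next_log_probs(prefix: Tuple[str, ...]) -> Dict[str, float]:
    """Stand-in for a neural network: a fixed table of next-symbol probabilities."""
    table = {
        (): {"a": 0.6, "b": 0.4},
        ("a",): {"x": 0.5, "y": 0.5},
        ("b",): {"z": 0.9, "w": 0.1},
    }
    probs = table.get(prefix, {"</s>": 1.0})
    return {symbol: math.log(p) for symbol, p in probs.items()}


best_prefix, best_score = beam_search(toy_next_log_probs, beam_width=2, max_len=4)[0]
greedy_prefix, _ = beam_search(toy_next_log_probs, beam_width=1, max_len=4)[0]
print(best_prefix)    # ('b', 'z', '</s>')
print(greedy_prefix)  # ('a', 'x', '</s>')
```

In this toy table, a beam of width 1 commits to "a" because it looks best at the first step, while a beam of width 2 keeps "b" alive long enough to find the more probable sequence overall (0.36 versus 0.30).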
We further reduced the model size by using the parameter quantization and hybrid kernel techniques we developed in 2016 and made publicly available through the model optimization toolkit in the TensorFlow Lite library. Model quantization delivered a 4x compression with respect to the trained floating point models and a 4x speedup at run-time, enabling our RNN-T to run faster than real-time speech on a single core. After compression, the final model is 80MB.
Our new all-neural, on-device Gboard speech recognizer is initially being launched to all Pixel phones in American English only. Given the trends in the industry, with the convergence of specialized hardware and algorithmic improvements, we are hopeful that the techniques presented here can soon be adopted in more languages and across broader domains of application.
Acknowledgements:
Raziel Alvarez, Michiel Bacchiani, Tom Bagby, Françoise Beaufays, Deepti Bhatia, Shuo-yiin Chang, Zhifeng Chen, Chung-Chen Chiu, Yanzhang He, Alex Gruenstein, Anjuli Kannan, Bo Li, Wei Li, Qiao Liang, Ian McGraw, Patrick Nguyen, Ruoming Pang, Rohit Prabhavalkar, Golan Pundak, Kanishka Rao, David Rybach, Tara Sainath, Haşim Sak, June Yuan Shangguan, Matt Shannon, Mohammadinamul Sheik, Khe Chai Sim, Gabor Simko, Trevor Strohman, Mirkó Visontai, Ron Weiss, Yonghui Wu, Ding Zhao, Dan Zivkovic, and Yu Zhang.