torchaudio.prototype

torchaudio.prototype provides prototype features; see here for more information on prototype features. The module is available only within nightly builds and must be imported explicitly, e.g. import torchaudio.prototype.
Emformer

class torchaudio.prototype.Emformer(input_dim: int, num_heads: int, ffn_dim: int, num_layers: int, dropout: float = 0.0, activation: str = 'relu', left_context_length: int = 0, right_context_length: int = 0, segment_length: int = 128, max_memory_size: int = 0, weight_init_scale_strategy: str = 'depthwise', tanh_on_mem: bool = False, negative_inf: float = -1e8)[source]

Implements the Emformer architecture introduced in Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency Streaming Speech Recognition [1].
Parameters

input_dim (int) – input dimension.
num_heads (int) – number of attention heads in each Emformer layer.
ffn_dim (int) – hidden layer dimension of each Emformer layer's feedforward network.
num_layers (int) – number of Emformer layers to instantiate.
dropout (float, optional) – dropout probability. (Default: 0.0)
activation (str, optional) – activation function to use in each Emformer layer's feedforward network. Must be one of ("relu", "gelu", "silu"). (Default: "relu")
left_context_length (int, optional) – length of left context. (Default: 0)
right_context_length (int, optional) – length of right context. (Default: 0)
segment_length (int, optional) – length of each input segment. (Default: 128)
max_memory_size (int, optional) – maximum number of memory elements to use. (Default: 0)
weight_init_scale_strategy (str, optional) – per-layer weight initialization scaling strategy. Must be one of ("depthwise", "constant", None). (Default: "depthwise")
tanh_on_mem (bool, optional) – if True, applies tanh to memory elements. (Default: False)
negative_inf (float, optional) – value to use for negative infinity in attention weights. (Default: -1e8)
Examples
>>> emformer = Emformer(512, 8, 2048, 20)
>>> input = torch.rand(128, 400, 512)  # batch, num_frames, feature_dim
>>> lengths = torch.randint(1, 200, (128,))  # batch
>>> output = emformer(input, lengths)
>>> output, lengths, states = emformer.infer(input, lengths, None)

forward(input: torch.Tensor, lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Forward pass for training.

B: batch size; T: number of frames; D: feature dimension of each frame.

Parameters

input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.

Returns

Tensor – output frames, with shape (B, T - right_context_length, D).
Tensor – output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.

Return type
(Tensor, Tensor)

infer(input: torch.Tensor, lengths: torch.Tensor, states: Optional[List[List[torch.Tensor]]] = None) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Forward pass for inference.

B: batch size; T: number of frames; D: feature dimension of each frame.

Parameters

input (torch.Tensor) – utterance frames right-padded with right context frames, with shape (B, T, D).
lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in input.
states (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing Emformer internal state generated in preceding invocation of infer. (Default: None)

Returns

Tensor – output frames, with shape (B, T - right_context_length, D).
Tensor – output lengths, with shape (B,) and i-th element representing number of valid frames for i-th batch element in output frames.
List[List[Tensor]] – output states; list of lists of tensors representing Emformer internal state generated in current invocation of infer.

Return type
(Tensor, Tensor, List[List[Tensor]])
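Example
A minimal streaming sketch: infer is called once per chunk, threading the returned states back into the next call. The chunk source (input_chunks), the choice of segment_length=128 and right_context_length=4, and all dimensions below are illustrative assumptions.
>>> emformer = Emformer(512, 8, 2048, 20, segment_length=128, right_context_length=4)
>>> states = None
>>> # input_chunks: assumed iterable of (batch, segment_length + right_context_length, feature_dim) tensors
>>> for chunk in input_chunks:
...     lengths = torch.full((chunk.size(0),), chunk.size(1))
...     output, out_lengths, states = emformer.infer(chunk, lengths, states)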
RNNT

class torchaudio.prototype.RNNT(transcriber: torchaudio.prototype.rnnt._Transcriber, predictor: torchaudio.prototype.rnnt._Predictor, joiner: torchaudio.prototype.rnnt._Joiner)[source]

Recurrent neural network transducer (RNN-T) model.

Note
To build the model, please use one of the factory functions.

Parameters

transcriber (torch.nn.Module) – transcription network.
predictor (torch.nn.Module) – prediction network.
joiner (torch.nn.Module) – joint network.

forward(sources: torch.Tensor, source_lengths: torch.Tensor, targets: torch.Tensor, target_lengths: torch.Tensor, predictor_state: Optional[List[List[torch.Tensor]]] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Forward pass for training.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: feature dimension of each source sequence element.

Parameters

sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol.
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in targets.
predictor_state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing prediction network internal state generated in preceding invocation of forward. (Default: None)

Returns

torch.Tensor – joint network output, with shape (B, max output source length, max output target length, number of target symbols).
torch.Tensor – output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
torch.Tensor – output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.
List[List[torch.Tensor]] – output states; list of lists of tensors representing prediction network internal state generated in current invocation of forward.

Return type
(torch.Tensor, torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
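Example
An illustrative training-style invocation. Here model is assumed to be an RNNT instance built by a factory function (e.g. emformer_rnnt_model below); the batch size, lengths, feature dimension, and the symbol count of 500 are arbitrary.
>>> sources = torch.rand(8, 132, 80)          # (B, T + right context length, D)
>>> source_lengths = torch.full((8,), 132)
>>> targets = torch.randint(0, 500, (8, 20))  # (B, U)
>>> target_lengths = torch.full((8,), 20)
>>> joint_out, src_lens, tgt_lens, pred_state = model(
...     sources, source_lengths, targets, target_lengths)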

transcribe_streaming(sources: torch.Tensor, source_lengths: torch.Tensor, state: Optional[List[List[torch.Tensor]]]) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Applies transcription network to sources in streaming mode.

B: batch size; T: maximum source sequence segment length in batch; D: feature dimension of each source sequence frame.

Parameters

sources (torch.Tensor) – source frame sequence segments right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing transcription network internal state generated in preceding invocation of transcribe_streaming.

Returns

torch.Tensor – output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
List[List[torch.Tensor]] – output states; list of lists of tensors representing transcription network internal state generated in current invocation of transcribe_streaming.

Return type
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
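Example
A streaming sketch, assuming feature_chunks is an iterable of (B, segment_length + right_context_length, D) tensors and model is an RNNT instance; all sizes are illustrative.
>>> state = None
>>> for chunk in feature_chunks:
...     lengths = torch.full((chunk.size(0),), chunk.size(1))
...     enc, enc_lengths, state = model.transcribe_streaming(chunk, lengths, state)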

transcribe(sources: torch.Tensor, source_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor][source]

Applies transcription network to sources in non-streaming mode.

B: batch size; T: maximum source sequence length in batch; D: feature dimension of each source sequence frame.

Parameters

sources (torch.Tensor) – source frame sequences right-padded with right context, with shape (B, T + right context length, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in sources.

Returns

torch.Tensor – output frame sequences, with shape (B, T // time_reduction_stride, output_dim).
torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output frame sequences.

Return type
(torch.Tensor, torch.Tensor)
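Example
A non-streaming sketch: the full utterance, right-concatenated with right_context_length frames, is passed in one call. The shapes below assume right_context_length=4 and are otherwise arbitrary; model is an assumed RNNT instance.
>>> sources = torch.rand(1, 400 + 4, 80)   # (B, T + right context length, D)
>>> source_lengths = torch.tensor([404])
>>> enc, enc_lengths = model.transcribe(sources, source_lengths)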

predict(targets: torch.Tensor, target_lengths: torch.Tensor, state: Optional[List[List[torch.Tensor]]]) → Tuple[torch.Tensor, torch.Tensor, List[List[torch.Tensor]]][source]

Applies prediction network to targets.

B: batch size; U: maximum target sequence length in batch; D: feature dimension of each target sequence frame.

Parameters

targets (torch.Tensor) – target sequences, with shape (B, U) and each element mapping to a target symbol, i.e. in range [0, num_symbols).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing number of valid frames for i-th batch element in targets.
state (List[List[torch.Tensor]] or None) – list of lists of tensors representing internal state generated in preceding invocation of predict.

Returns

torch.Tensor – output frame sequences, with shape (B, U, output_dim).
torch.Tensor – output lengths, with shape (B,) and i-th element representing number of valid elements for i-th batch element in output.
List[List[torch.Tensor]] – output states; list of lists of tensors representing internal state generated in current invocation of predict.

Return type
(torch.Tensor, torch.Tensor, List[List[torch.Tensor]])
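Example
An illustrative call on a batch of target sequences; the symbol range and shapes are arbitrary, and model is an assumed RNNT instance.
>>> targets = torch.randint(0, 500, (8, 20))   # (B, U), values in [0, num_symbols)
>>> target_lengths = torch.full((8,), 20)
>>> pred_out, pred_lengths, pred_state = model.predict(targets, target_lengths, None)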

join(source_encodings: torch.Tensor, source_lengths: torch.Tensor, target_encodings: torch.Tensor, target_lengths: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor][source]

Applies joint network to source and target encodings.

B: batch size; T: maximum source sequence length in batch; U: maximum target sequence length in batch; D: dimension of each source and target sequence encoding.

Parameters

source_encodings (torch.Tensor) – source encoding sequences, with shape (B, T, D).
source_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in source_encodings.
target_encodings (torch.Tensor) – target encoding sequences, with shape (B, U, D).
target_lengths (torch.Tensor) – with shape (B,) and i-th element representing valid sequence length of i-th batch element in target_encodings.

Returns

torch.Tensor – joint network output, with shape (B, T, U, D).
torch.Tensor – output source lengths, with shape (B,) and i-th element representing number of valid elements along dim 1 for i-th batch element in joint network output.
torch.Tensor – output target lengths, with shape (B,) and i-th element representing number of valid elements along dim 2 for i-th batch element in joint network output.

Return type
(torch.Tensor, torch.Tensor, torch.Tensor)
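Example
A sketch tying the pieces together: the joint network consumes the encodings produced by transcribe and predict. Inputs are assumed to come from the earlier examples.
>>> src_enc, src_lens = model.transcribe(sources, source_lengths)
>>> tgt_enc, tgt_lens, _ = model.predict(targets, target_lengths, None)
>>> joint_out, out_src_lens, out_tgt_lens = model.join(src_enc, src_lens, tgt_enc, tgt_lens)
>>> joint_out.shape   # (B, T, U, D)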
emformer_rnnt_base
emformer_rnnt_model

torchaudio.prototype.emformer_rnnt_model(*, input_dim: int, encoding_dim: int, num_symbols: int, segment_length: int, right_context_length: int, time_reduction_input_dim: int, time_reduction_stride: int, transformer_num_heads: int, transformer_ffn_dim: int, transformer_num_layers: int, transformer_dropout: float, transformer_activation: str, transformer_left_context_length: int, transformer_max_memory_size: int, transformer_weight_init_scale_strategy: str, transformer_tanh_on_mem: bool, symbol_embedding_dim: int, num_lstm_layers: int, lstm_layer_norm: bool, lstm_layer_norm_epsilon: float, lstm_dropout: float) → torchaudio.prototype.rnnt.RNNT[source]

Builds an Emformer-based recurrent neural network transducer (RNN-T) model.

Note
For non-streaming inference, the expectation is for transcribe to be called on input sequences right-concatenated with right_context_length frames.
For streaming inference, the expectation is for transcribe_streaming to be called on input chunks comprising segment_length frames right-concatenated with right_context_length frames.
Parameters

input_dim (int) – dimension of input sequence frames passed to transcription network.
encoding_dim (int) – dimension of transcription- and prediction-network-generated encodings passed to joint network.
num_symbols (int) – cardinality of set of target tokens.
segment_length (int) – length of input segment expressed as number of frames.
right_context_length (int) – length of right context expressed as number of frames.
time_reduction_input_dim (int) – dimension to scale each element in input sequences to prior to applying time reduction block.
time_reduction_stride (int) – factor by which to reduce length of input sequence.
transformer_num_heads (int) – number of attention heads in each Emformer layer.
transformer_ffn_dim (int) – hidden layer dimension of each Emformer layer's feedforward network.
transformer_num_layers (int) – number of Emformer layers to instantiate.
transformer_left_context_length (int) – length of left context considered by Emformer.
transformer_dropout (float) – Emformer dropout probability.
transformer_activation (str) – activation function to use in each Emformer layer's feedforward network. Must be one of ("relu", "gelu", "silu").
transformer_max_memory_size (int) – maximum number of memory elements to use.
transformer_weight_init_scale_strategy (str) – per-layer weight initialization scaling strategy. Must be one of ("depthwise", "constant", None).
transformer_tanh_on_mem (bool) – if True, applies tanh to memory elements.
symbol_embedding_dim (int) – dimension of each target token embedding.
num_lstm_layers (int) – number of LSTM layers to instantiate.
lstm_layer_norm (bool) – if True, enables layer normalization for LSTM layers.
lstm_layer_norm_epsilon (float) – value of epsilon to use in LSTM layer normalization layers.
lstm_dropout (float) – LSTM dropout probability.

Returns
Emformer RNN-T model.

Return type
RNNT
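Example
An illustrative instantiation; every hyperparameter value below is an arbitrary choice, not a recommended configuration.
>>> model = emformer_rnnt_model(
...     input_dim=80,
...     encoding_dim=1024,
...     num_symbols=4097,
...     segment_length=16,
...     right_context_length=4,
...     time_reduction_input_dim=128,
...     time_reduction_stride=4,
...     transformer_num_heads=8,
...     transformer_ffn_dim=2048,
...     transformer_num_layers=20,
...     transformer_dropout=0.1,
...     transformer_activation="gelu",
...     transformer_left_context_length=30,
...     transformer_max_memory_size=0,
...     transformer_weight_init_scale_strategy="depthwise",
...     transformer_tanh_on_mem=True,
...     symbol_embedding_dim=512,
...     num_lstm_layers=3,
...     lstm_layer_norm=True,
...     lstm_layer_norm_epsilon=1e-3,
...     lstm_dropout=0.3,
... )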
RNNTBeamSearch

class torchaudio.prototype.RNNTBeamSearch(model: torchaudio.prototype.rnnt.RNNT, blank: int, temperature: float = 1.0, hypo_sort_key: Optional[Callable[[torchaudio.prototype.rnnt_decoder.Hypothesis], float]] = None, step_max_tokens: int = 100)[source]

Beam search decoder for RNN-T model.

Parameters

model (RNNT) – RNN-T model to use.
blank (int) – index of blank token in vocabulary.
temperature (float, optional) – temperature to apply to joint network output. Larger values yield more uniform samples. (Default: 1.0)
hypo_sort_key (Callable[[Hypothesis], float] or None, optional) – callable that computes a score for a given hypothesis to rank hypotheses by. If None, defaults to callable that returns hypothesis score normalized by token sequence length. (Default: None)
step_max_tokens (int, optional) – maximum number of tokens to emit per input time step. (Default: 100)

forward(input: torch.Tensor, length: torch.Tensor, beam_width: int) → List[torchaudio.prototype.rnnt_decoder.Hypothesis][source]

Performs beam search for the given input sequence.

T: number of frames; D: feature dimension of each frame.

Parameters

input (torch.Tensor) – sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor) – number of valid frames in input sequence, with shape () or (1,).
beam_width (int) – beam size to use during search.

Returns
top beam_width hypotheses found by beam search.

Return type
List[Hypothesis]
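Example
Decoding a full utterance; the blank index, feature shape, and beam width below are illustrative, and model is an assumed RNNT instance.
>>> decoder = RNNTBeamSearch(model, blank=4096)
>>> features = torch.rand(400, 80)      # (T, D)
>>> length = torch.tensor(400)
>>> hypotheses = decoder(features, length, beam_width=10)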

infer(input: torch.Tensor, length: torch.Tensor, beam_width: int, state: Optional[List[List[torch.Tensor]]] = None, hypothesis: Optional[torchaudio.prototype.rnnt_decoder.Hypothesis] = None) → Tuple[List[torchaudio.prototype.rnnt_decoder.Hypothesis], List[List[torch.Tensor]]][source]

Performs beam search for the given input sequence in streaming mode.

T: number of frames; D: feature dimension of each frame.

Parameters

input (torch.Tensor) – sequence of input frames, with shape (T, D) or (1, T, D).
length (torch.Tensor) – number of valid frames in input sequence, with shape () or (1,).
beam_width (int) – beam size to use during search.
state (List[List[torch.Tensor]] or None, optional) – list of lists of tensors representing transcription network internal state generated in preceding invocation. (Default: None)
hypothesis (Hypothesis or None) – hypothesis from preceding invocation to seed search with. (Default: None)

Returns

List[Hypothesis] – top beam_width hypotheses found by beam search.
List[List[torch.Tensor]] – list of lists of tensors representing transcription network internal state generated in current invocation.

Return type
(List[Hypothesis], List[List[torch.Tensor]])
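Example
A streaming decoding sketch: each call consumes one chunk, and the returned state plus a hypothesis seed the next call. feature_chunks, the chunking scheme, and the assumption that returned hypotheses are ordered best-first are all illustrative.
>>> state, hypo = None, None
>>> for chunk in feature_chunks:        # assumed iterable of (T_chunk, D) tensors
...     length = torch.tensor(chunk.size(0))
...     hypos, state = decoder.infer(chunk, length, beam_width=10, state=state, hypothesis=hypo)
...     hypo = hypos[0]                 # seed next call with current best hypothesis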
Hypothesis

class torchaudio.prototype.Hypothesis(tokens: List[int], predictor_out: torch.Tensor, state: List[List[torch.Tensor]], score: float, alignment: List[int], blank: int, key: str)[source]

Represents a hypothesis generated by beam search decoder RNNTBeamSearch.

Variables

tokens (List[int]) – Predicted sequence of tokens.
predictor_out (torch.Tensor) – Prediction network output.
state (List[List[torch.Tensor]]) – Prediction network internal state.
score (float) – Score of hypothesis.
alignment (List[int]) – Sequence of timesteps, with the i-th value mapping to the i-th predicted token in tokens.
blank (int) – Token index corresponding to blank token.
key (str) – Value used to determine equivalence in token sequences between Hypothesis instances.
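Example
Reading decoded output from a hypothesis returned by RNNTBeamSearch; mapping token indices back to strings requires a token vocabulary, which is assumed to exist separately.
>>> best = hypotheses[0]
>>> best.score        # hypothesis score
>>> best.tokens       # token indices; map to strings with your token vocabulary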
References

[1] Yangyang Shi, Yongqiang Wang, Chunyang Wu, Ching-Feng Yeh, Julian Chan, Frank Zhang, Duc Le, and Mike Seltzer. Emformer: efficient memory transformer based acoustic model for low latency streaming speech recognition. In ICASSP 2021 - 2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), 6783–6787. 2021.