Cascaded Encoders for a Hybrid Streaming + Non-Streaming ASR Model
Streaming + Non-Streaming Speech Recognition

Inspired by Google's paper CASCADED ENCODERS FOR UNIFYING STREAMING AND NON-STREAMING ASR, I plan to adopt cascaded encoders to improve second-pass recognition quality.

There have also been other recent papers on hybrid streaming/non-streaming models, such as TRANSFORMER TRANSDUCER: ONE MODEL UNIFYING STREAMING AND NON-STREAMING SPEECH RECOGNITION and Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition, but from an engineering standpoint the cascaded model is the most economical option. The encoder structure of the cascaded model is shown below:

[Figure: cascaded encoder architecture]

On top of the causal encoder, a non-causal (offline) encoder is added, so the model can be viewed as a hybrid online + offline model. It can display streaming recognition results in real time, and at the end of the utterance perform a second-pass correction to show a better full-sentence result. In addition, the cascaded structure lets the offline encoder reuse the causal encoder's output, sharing a large part of its computation.

The final loss is computed with a weighting coefficient: $L = \lambda L_s + (1-\lambda) L_a$, where $L_s$ is the loss of the causal (streaming) path and $L_a$ the loss of the non-causal path. Google also describes the approach they use in practice: during training, one of the two paths is randomly sampled for the forward pass and gradient computation.
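As a minimal illustration of this weighted objective (my own sketch, not code from the paper; `loss_fn`, `causal_logits`, `non_causal_logits`, and `targets` are hypothetical placeholders, with `loss_fn` standing in for the RNN-T loss used in the paper):

```python
def combined_loss(loss_fn, causal_logits, non_causal_logits, targets, lam=0.5):
    # L = lam * L_s + (1 - lam) * L_a: both paths are evaluated per example.
    # The sampling trick described below instead evaluates only one path per
    # step, taking the causal path with probability lam, which matches this
    # objective in expectation.
    loss_s = loss_fn(causal_logits, targets)      # streaming (causal) path
    loss_a = loss_fn(non_causal_logits, targets)  # non-streaming path
    return lam * loss_s + (1.0 - lam) * loss_a
```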

In practice, we found that we can decrease the step-time during training by sampling from es and ea within a mini-batch using λ as the sampling rate. Therefore, for each input utterance, we stochastically choose either the causal or the non-causal processing path at each training step. This alleviates the need to compute RNN-T loss twice for each training example at each training step. With sampling, the model converges roughly after the same number of steps as a standalone streaming model.

This avoids computing the loss twice at every training step. Intuitively, the expected loss under sampling is exactly $\lambda L_s + (1-\lambda) L_a$, so as long as each path accumulates a similar total number of steps, convergence should be roughly the same as computing both losses explicitly. I also wrote some pseudocode for this, shown below.

```python
import tensorflow as tf

# ProtoParserCascadedCtc and TransformerNetwork are internal modules of the
# training codebase and are not shown here.


class CascadedCtcMultiLayer:
    """
    CASCADED ENCODERS FOR UNIFYING STREAMING AND NON-STREAMING ASR

    Ref: https://arxiv.org/abs/2010.14606

    Architecture:

        Inputs
          |
          |
    Causal Encoder-------Non-Causal Encoder
          |                     |
          |_____________________|
          |
        Logits
          |
       CTC Loss
    """

    def __init__(self, is_train_condition, proto_path, vocab_size,
                 trainable=True, weight_item=0.5):
        self.is_train_condition = is_train_condition

        # weight_item is the sampling rate lambda: the probability of taking
        # the causal path at a given training step.
        self.weight_item = tf.cast(weight_item, dtype=tf.float32)

        self.proto = ProtoParserCascadedCtc(proto_path)

        # Causal (streaming) encoder network.
        self.encoder = TransformerNetwork.CtcEncoderNetwork(
            self.proto, is_train_condition, trainable)

        # Logit network, shared by the causal and non-causal paths.
        self.logit = TransformerNetwork.CtcLogitNetwork(
            self.proto, is_train_condition, trainable, vocab_size)

        # Non-causal (cascaded) encoder stacked on the causal encoder output.
        self.cascaded_encoder = TransformerNetwork.CascadedCtcEncoderNetwork(
            self.proto, is_train_condition, trainable)

    def __call__(self, inputs, sequence_lengths):
        '''
        @return:
            logit: Logit from the causal network.
            cascaded_logit: Logit from the non-causal network.
            sequence_lengths: sequence lengths.
        '''
        acoustic_features, sequence_lengths = self.encoder(
            inputs, sequence_lengths)

        logit, cascaded_logit = tf.cond(
            self.is_train_condition,
            lambda: self.training(acoustic_features, sequence_lengths),
            lambda: self.testing(acoustic_features, sequence_lengths))

        return logit, cascaded_logit, sequence_lengths

    def training(self, acoustic_features, sequence_lengths):
        '''
        For each input utterance, we stochastically choose either the causal
        or the non-causal processing path at each training step.
        This alleviates the need to compute RNN-T loss twice for each training
        example at each training step. With sampling, the model converges
        roughly after the same number of steps as a standalone streaming model.
        '''
        random_val = tf.random.uniform(
            shape=(), minval=0., maxval=1., dtype=tf.float32)

        def causal():
            # Streaming path: logits directly from the causal encoder output.
            return self.logit(acoustic_features)

        def non_causal():
            # Non-streaming path: the causal features pass through the
            # cascaded (non-causal) encoder before the shared logit network.
            cascaded_acoustic_features, _ = self.cascaded_encoder(
                acoustic_features, sequence_lengths)
            return self.logit(cascaded_acoustic_features)

        logit = tf.cond(
            tf.math.less_equal(random_val, self.weight_item),
            causal,
            non_causal)

        return logit, logit

    def testing(self, acoustic_features, sequence_lengths):
        # At inference both paths are computed: causal logits for the
        # streaming result, cascaded logits for the second-pass result.
        logit = self.logit(acoustic_features)

        cascaded_acoustic_features, _ = self.cascaded_encoder(
            acoustic_features, sequence_lengths)

        cascaded_logit = self.logit(cascaded_acoustic_features)

        return logit, cascaded_logit
```
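
A hedged usage sketch of the class above (the proto path, vocabulary size, and feature tensors are illustrative placeholders, not values from my actual setup):

```python
# Hypothetical usage; shapes and values are placeholders.
is_training = tf.constant(True)
model = CascadedCtcMultiLayer(is_training,
                              proto_path="cascaded_ctc.pbtxt",
                              vocab_size=5000)

features = tf.random.normal([2, 100, 80])   # [batch, time, feat_dim]
feature_lengths = tf.constant([100, 80])    # valid frames per utterance

# In training mode both returned logits are the sampled path's logits;
# in testing mode they are the causal and cascaded logits respectively.
logit, cascaded_logit, out_lengths = model(features, feature_lengths)
```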

Unlike the original paper, which uses the RNN-T loss, I am currently using only the CTC loss; that experiment is still in progress. The logit network is also shared between the two paths.
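
For the CTC variant, here is a minimal sketch of attaching the loss to one path's logits (my own illustration: it assumes batch-major logits of shape [batch, time, vocab], dense integer labels, and blank id 0, none of which come from the model code above):

```python
def ctc_loss_from_logits(logits, labels, logit_lengths, label_lengths):
    # logits: [batch, time, vocab], labels: dense int32 [batch, max_label_len].
    loss = tf.nn.ctc_loss(
        labels=labels,
        logits=logits,
        label_length=label_lengths,
        logit_length=logit_lengths,
        logits_time_major=False,  # logits are batch-major here
        blank_index=0)            # assuming blank id 0
    return tf.reduce_mean(loss)

# With path sampling, only the single returned logit needs a loss per step:
# logit, _, out_lengths = model(features, feature_lengths)
# loss = ctc_loss_from_logits(logit, labels, out_lengths, label_lengths)
```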

The results reported in the original paper with the cascaded structure are as follows:

[Table: WER of the cascaded model vs. standalone streaming and non-streaming RNN-T baselines, from the paper]

Compared with the unidirectional RNN-T, the causal part improves slightly. On the VS test set, the non-causal part is within about a 6% relative gap of a purely bidirectional RNN-T model. Given the advantages of the cascaded model for online deployment, it is a structure well worth trying.

01/14/2021