Glow (PyTorch) メモ

Glow のグラフレベル最適化処理（glow::optimize）に関する情報

ドキュメント

ソース

/// Runs the graph-level optimization pipeline over \p F.
///
/// The passes are invoked in a fixed order. Several of them rely on earlier
/// passes having already run (see the per-step comments), so the sequence
/// should not be reordered casually. Two passes, constant transposition and
/// batch-norm merging, run only when \p opts requests inference compilation.
///
/// \param F     The function whose graph the passes below operate on.
/// \param opts  Compilation options. Only `opts.mode` is consulted here.
void glow::optimize(Function *F, const CompilationOptions &opts) {
  // The mode is checked at each use rather than cached up front.
  auto compilingForInference = [&opts]() {
    return opts.mode == CompilationMode::Infer;
  };

  // Earlier (e.g. backend-specific) transformations may have left unused
  // nodes behind. Clear them out before optimizing further.
  DCE(F);

  // Sink transposes so that pairs of them can cancel out. Repeat until
  // sinking stops making progress, removing dead nodes after every
  // productive round. The original notes report that large functions
  // usually converge within 2-3 rounds.
  for (bool sank = sinkCode(F); sank; sank = sinkCode(F)) {
    DCE(F);
  }

  // Transposes that do not move data become Reshapes, which later passes
  // can optimize further.
  optimizeTransposeIntoReshape(F);

  // Reshapes and transposes can block other rewrites, so simplify them early.
  optimizeReshape(F);
  if (compilingForInference()) {
    transposeConstants(F);
  }

  optimizePool(F);
  CSE(F);
  mergePadIntoConvolution(F);
  DCE(F);

  // Fuse multiple MatMul / BatchedAdd nodes into larger ones, and turn
  // ReduceMean into AveragePool where possible.
  mergeMatMul(F);
  mergeBatchedAdd(F);
  optimizeReduceMean(F);
  DCE(F);

  // Batch-norm merging is placed after transposeConstants() because
  // transposed weights can stop it from triggering.
  if (compilingForInference()) {
    optimizeBatchNorm(F);
  }

  CSE(F);
  optimizeConcatNodes(F);
  // Simplifications based on algebraic identities.
  optimizeArithmeticNodes(F);
  optimizeSliceOfSplat(F);

  // mergeTransposeIntoMatMulOrFC() needs accurate node-user counts, which
  // is why DCE runs immediately before it.
  DCE(F);
  mergeTransposeIntoMatMulOrFC(F);

  // Drop redundant intermediate type conversions.
  optimizeConversions(F);

  // Quantization cleanup is alternated with rescale sinking until sinking
  // reports no further change.
  optimizeQuantization(F);
  for (;;) {
    if (!sinkRescaleQuantizedNode(F)) {
      break;
    }
    DCE(F);
    optimizeQuantization(F);
  }

  // Final sweep of anything the passes above left unused.
  DCE(F);
}