Manually Compiling Lua Code

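If a piece of Lua code is translated directly into C++, one problem arises: Lua has proper tail calls, but C++ does not guarantee tail-call elimination. Take the function below, which sums all the nodes of a binary tree. The second recursive call to Visit is a tail call, so a direct translation into C++ loses part of that optimization.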

function SumTree(root)
    local Sum = 0
    local function Visit(CurNode)
        if not CurNode then
            return
        end
        Sum = Sum + CurNode.val
        Visit(CurNode.left)
        return Visit(CurNode.right) -- tail call: Lua only treats "return f(...)" as a tail call
    end
    Visit(root)
    return Sum
end
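
To make the problem concrete, here is a minimal sketch of the direct translation in C++. The `TreeNode` struct (LeetCode-style) and the name `SumTreeNaive` are assumptions for illustration. The C++ standard does not require tail-call elimination, so whether the last `Visit` call reuses the stack frame depends on the compiler and optimization level.

```cpp
#include <functional>

// Assumed LeetCode-style binary tree node (not defined in the original text).
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Line-by-line translation of the Lua SumTree.
int SumTreeNaive(TreeNode* root) {
    int Sum = 0;
    std::function<void(TreeNode*)> Visit = [&](TreeNode* CurNode) {
        if (!CurNode) {
            return;
        }
        Sum += CurNode->val;
        Visit(CurNode->left);
        // A tail call in Lua; in C++ an ordinary call with no elimination guarantee.
        Visit(CurNode->right);
    };
    Visit(root);
    return Sum;
}
```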

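There is a way to turn this code into efficient C++ that keeps the effect of the tail call while avoiding most of the function-call overhead.
Step 1: apply the CPS (continuation-passing style) transform, which turns the code into the form below. Every function call becomes a tail call, and every function gets an extra parameter, Cont.
When the code runs, the system stack does not grow; only Cont grows, and only when the left child is visited.
The effect of this transform is to replace the system stack with a user-defined Cont while preserving the tail calls.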

function SumTree(root)
    local Sum = 0

    local function ApplyCont(Cont)
        return Cont()
    end

    local function ID()
    end

    local function Visit(CurNode, Cont)
        if not CurNode then
            return ApplyCont(Cont)
        end
        Sum = Sum + CurNode.val
        local Cont1 = function()
            return Visit(CurNode.right, Cont)
        end
        return Visit(CurNode.left, Cont1) -- Cont grows here
    end
    Visit(root, ID)

    return Sum
end
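
A rough C++ rendering of step 1, assuming the same hypothetical `TreeNode` struct, with `std::function<void()>` standing in for Lua closures (`ContFn` and `SumTreeCPS` are illustrative names). It only shows the shape of the CPS code: because C++ does not guarantee tail calls, the native stack can still grow here, which is what the later steps remove.

```cpp
#include <functional>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// A continuation: "what to do after the current subtree is finished".
using ContFn = std::function<void()>;

int SumTreeCPS(TreeNode* root) {
    int Sum = 0;
    std::function<void(TreeNode*, ContFn)> Visit;
    Visit = [&](TreeNode* CurNode, ContFn Cont) {
        if (!CurNode) {
            return Cont();  // ApplyCont(Cont)
        }
        Sum += CurNode->val;
        // Cont1 captures the right child and the outer Cont, so the
        // continuation chain grows by one level per left descent.
        ContFn Cont1 = [&Visit, CurNode, Cont]() {
            Visit(CurNode->right, Cont);
        };
        return Visit(CurNode->left, Cont1);
    };
    Visit(root, [] {});  // the empty lambda plays the role of ID
    return Sum;
}
```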

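Step 2: represent Cont with a user-defined data structure (a technique known as defunctionalization). The main point is to separate the free variables captured by the Cont function from the function's code.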

function SumTree(root)
    local Sum = 0

    local Visit
    local ActionMap = {
        function(FreeVars, Cont)
            return Visit(FreeVars[1], Cont)
        end,
    }

    local function ApplyCont(Cont)
        if not Cont then
            return
        end
        local Action = ActionMap[Cont.ActionIndex]
        return Action(Cont.FreeVars, Cont.Cont)
    end

    local ID = nil

    Visit = function(CurNode, Cont)
        if not CurNode then
            return ApplyCont(Cont)
        end
        Sum = Sum + CurNode.val
        local Cont1 = {
            FreeVars = {CurNode.right},
            ActionIndex = 1,
            Cont = Cont,
        }
        return Visit(CurNode.left, Cont1)
    end
    Visit(root, ID)

    return Sum
end
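
The defunctionalized continuation from step 2 might look like this in C++. This is a sketch; `ContNode`, `DefunVisitor`, and `SumTreeDefun` are hypothetical names. `nullptr` plays the role of `ID`, `Node` holds the single free variable (`CurNode.right`), and `Next` is the outer `Cont`. `std::shared_ptr` is used only to keep ownership simple.

```cpp
#include <memory>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Continuation as plain data instead of a closure.
struct ContNode {
    int ActionIndex;                 // which piece of code to run (only 1 exists here)
    TreeNode* Node;                  // captured free variable: CurNode->right
    std::shared_ptr<ContNode> Next;  // the outer Cont; nullptr stands for ID
};
using ContPtr = std::shared_ptr<ContNode>;

struct DefunVisitor {
    int Sum = 0;

    void ApplyCont(const ContPtr& Cont) {
        if (!Cont) {
            return;  // ID: nothing left to do
        }
        switch (Cont->ActionIndex) {
        case 1:  // the former Cont1 closure: visit the saved right child
            return Visit(Cont->Node, Cont->Next);
        }
    }

    void Visit(TreeNode* CurNode, const ContPtr& Cont) {
        if (!CurNode) {
            return ApplyCont(Cont);
        }
        Sum += CurNode->val;
        ContPtr Cont1 = std::make_shared<ContNode>(ContNode{1, CurNode->right, Cont});
        return Visit(CurNode->left, Cont1);
    }
};

int SumTreeDefun(TreeNode* root) {
    DefunVisitor V;
    V.Visit(root, nullptr);
    return V.Sum;
}
```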

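Step 3: the Cont in the code above is a simple linked list that is only ever operated on at one end, so it can be replaced by an external Stack, and the Cont parameter can be removed from every function. At the same time, inline the code from ActionMap into ApplyCont. This code has only one kind of Action, so ActionIndex is not strictly necessary.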

function SumTree(root)
    local Sum = 0

    local Visit
    local Stack = {}

    local function ApplyCont()
        if #Stack == 0 then
            return
        end
        local Top = Stack[#Stack]
        Stack[#Stack] = nil
        if Top.ActionIndex == 1 then
            return Visit(Top.FreeVars[1])
        end
    end

    Visit = function(CurNode)
        if not CurNode then
            return ApplyCont()
        end
        Sum = Sum + CurNode.val
        table.insert(Stack, {
            FreeVars = {CurNode.right},
            ActionIndex = 1,
        })
        return Visit(CurNode.left)
    end
    Visit(root)

    return Sum
end
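
In C++, step 3 could be sketched as follows (hypothetical names `StackVisitor` and `SumTreeStack`, same assumed `TreeNode`). The continuation chain becomes a `std::vector` used as a stack. `Visit` and `ApplyCont` still call each other, always in tail position; optimizing compilers often turn such calls into jumps, but that is not guaranteed.

```cpp
#include <vector>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

struct StackVisitor {
    struct Frame {
        int ActionIndex;  // only action 1 exists: "visit Node"
        TreeNode* Node;   // the saved right child
    };
    std::vector<Frame> Stack;  // replaces the Cont linked list
    int Sum = 0;

    void ApplyCont() {
        if (Stack.empty()) {
            return;  // empty stack plays the role of ID
        }
        Frame Top = Stack.back();
        Stack.pop_back();
        if (Top.ActionIndex == 1) {
            return Visit(Top.Node);
        }
    }

    void Visit(TreeNode* CurNode) {
        if (!CurNode) {
            return ApplyCont();
        }
        Sum += CurNode->val;
        Stack.push_back({1, CurNode->right});
        return Visit(CurNode->left);
    }
};

int SumTreeStack(TreeNode* root) {
    StackVisitor V;
    V.Visit(root);
    return V.Sum;
}
```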

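Step 4: once a function makes a tail call, the lifetime of the current function's parameters and local variables is over, so they can be replaced by a few external "registers". At the same time, adjust the structure of Stack slightly; the FreeVars in the code above was there mainly to make the explanation clearer.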

function SumTree(root)
    local Sum = 0

    local Stack = {}
    local R_CurNode, R_Temp

    local Visit

    local function ApplyCont()
        if #Stack == 0 then
            return
        end
        R_Temp = Stack[#Stack]
        Stack[#Stack] = nil
        if R_Temp.ActionIndex == 1 then
            R_CurNode = R_Temp.Node
            return Visit()
        end
    end

    Visit = function()
        if not R_CurNode then
            return ApplyCont()
        end
        Sum = Sum + R_CurNode.val
        table.insert(Stack, {
            ActionIndex = 1,
            Node = R_CurNode.right,
        })
        R_CurNode = R_CurNode.left
        return Visit()
    end
    R_CurNode = root
    Visit()

    return Sum
end
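
A C++ sketch of step 4, again with hypothetical names (`RegisterMachine`, `SumTreeRegisters`) and the assumed `TreeNode`. Parameters and locals have become data members acting as registers, so `Visit` and `ApplyCont` take no arguments, return nothing, and declare no locals. A call then carries no information beyond where control goes next.

```cpp
#include <vector>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

struct RegisterMachine {
    struct Frame {
        int ActionIndex;
        TreeNode* Node;
    };
    std::vector<Frame> Stack;
    TreeNode* R_CurNode = nullptr;  // register replacing Visit's parameter
    int Sum = 0;

    void ApplyCont() {
        if (Stack.empty()) {
            return;
        }
        switch (Stack.back().ActionIndex) {
        case 1:
            R_CurNode = Stack.back().Node;  // load the register, then "jump"
            Stack.pop_back();
            return Visit();
        }
    }

    void Visit() {
        if (!R_CurNode) {
            return ApplyCont();
        }
        Sum += R_CurNode->val;
        Stack.push_back({1, R_CurNode->right});
        R_CurNode = R_CurNode->left;
        return Visit();
    }
};

int SumTreeRegisters(TreeNode* root) {
    RegisterMachine M;
    M.R_CurNode = root;  // corresponds to the first call to Visit
    M.Visit();
    return M.Sum;
}
```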

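Step 5: translate the code above into C++. Every function now has no parameters, no return value, and no local variables, and every function call is a tail call, so the calls can be replaced directly with goto statements. When there are several kinds of Action, ActionIndex becomes necessary.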

int SumTree(TreeNode* root)
{
    int Sum = 0;
    struct StackNode
    {
        int ActionIndex;
        TreeNode* Node;
    };
    vector<StackNode> Stack;

    TreeNode* R_CurNode = root;  // corresponds to the first call to Visit in the code above
    goto L_Visit;

L_ApplyCont:
    if (Stack.empty())
    {
        return Sum;  // the final return is moved here
    }
    // small adjustment: R_Temp is no longer needed
    switch (Stack.back().ActionIndex)
    {
    case 1:
        R_CurNode = Stack.back().Node;
        Stack.pop_back();
        goto L_Visit;
        // break; not needed
    }

L_Visit:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    Sum += R_CurNode->val;
    Stack.push_back({1, R_CurNode->right});
    R_CurNode = R_CurNode->left;
    goto L_Visit;
}
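
To illustrate the case with several kinds of Action, here is a hypothetical example that is not from the original: the height of a binary tree, whose recursive form is `Height(n) = n ? 1 + max(Height(n->left), Height(n->right)) : 0`. Neither recursive call is a tail call, so the continuation needs two actions, and a "return value" register `R_Value` carries each subtree's result into `L_ApplyCont`. A sketch, assuming the same LeetCode-style `TreeNode`:

```cpp
#include <algorithm>
#include <vector>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

int TreeHeight(TreeNode* root)
{
    struct StackNode
    {
        int ActionIndex;  // 1: left subtree done; 2: right subtree done
        TreeNode* Node;   // action 1: right child still to visit
        int Saved;        // action 2: height of the left subtree
    };
    std::vector<StackNode> Stack;
    TreeNode* R_CurNode = root;
    int R_Value = 0;  // "return value" register
    goto L_Visit;

L_ApplyCont:  // R_Value holds the height of the subtree just finished
    if (Stack.empty())
    {
        return R_Value;
    }
    switch (Stack.back().ActionIndex)
    {
    case 1:
        // Left height is in R_Value: save it and go visit the right child.
        R_CurNode = Stack.back().Node;
        Stack.back() = {2, nullptr, R_Value};
        goto L_Visit;
    case 2:
        // Right height is in R_Value, left height is on the stack.
        R_Value = 1 + std::max(Stack.back().Saved, R_Value);
        Stack.pop_back();
        goto L_ApplyCont;
    }

L_Visit:
    if (!R_CurNode)
    {
        R_Value = 0;  // Height(nullptr) == 0
        goto L_ApplyCont;
    }
    Stack.push_back({1, R_CurNode->right, 0});
    R_CurNode = R_CurNode->left;
    goto L_Visit;
}
```

Action 1 overwrites the top stack entry in place with an action-2 entry instead of popping and pushing; this is only a small design choice, and a pop followed by a push would behave the same.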

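Finally, let's use LeetCode's N-ary Tree Level Order Traversal problem to check the correctness and the speedup of this approach.
The cleverer solution to this problem is an iterative algorithm built on a queue. LeetCode's official iterative code has a best runtime of 40 ms, but when I submitted the same code, it ran in roughly 45 to 70 ms.
First, implement a naive recursive algorithm. The code below runs in roughly 60 to 100 ms, with a best time of 58 ms.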

vector<vector<int>> levelOrder(Node* root) {
    vector<vector<int>> Results;
    function<void(Node*, int)> Visit;
    Visit = [&](Node* CurNode, int Depth)
    {
        if (!CurNode)
        {
            return;
        }
        if (Results.size() < Depth)
        {
            Results.resize(Depth);
        }
        Results[Depth - 1].push_back(CurNode->val);
        for (auto Child : CurNode->children)
        {
            Visit(Child, Depth + 1);
        }
    };
    Visit(root, 1);
    return Results;
}

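Then convert it with the method above into the code below. It passes all test cases and runs in roughly 45 to 85 ms, with a best time of 44 ms. This recursive implementation has no tail calls, so most of the speedup comes from replacing function calls with goto statements.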

vector<vector<int>> levelOrder(Node* root) {
    vector<vector<int>> Results;
    struct StackNode {
        int ChildIndex;
        Node* CurNode;
        int CurDepth;
    };
    vector<StackNode> Stack;
    int R_ChildIndex;
    Node* R_CurNode = root;
    int R_CurDepth = 1;
    goto L_Recursive;

L_ApplyCont:
    if (Stack.empty())
    {
        return Results;
    }
    R_ChildIndex = Stack.back().ChildIndex;
    R_CurNode = Stack.back().CurNode;
    R_CurDepth = Stack.back().CurDepth;
    Stack.pop_back();
    goto L_LoopChild;

L_LoopChild:
    if (R_ChildIndex < R_CurNode->children.size())
    {
        Stack.push_back({R_ChildIndex + 1, R_CurNode, R_CurDepth});
        R_CurNode = R_CurNode->children[R_ChildIndex];
        R_CurDepth++;
        goto L_Recursive;
    }
    else
    {
        goto L_ApplyCont;
    }

L_Recursive:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    if (Results.size() < R_CurDepth)
    {
        Results.resize(R_CurDepth);
    }
    Results[R_CurDepth - 1].push_back(R_CurNode->val);
    R_ChildIndex = 0;
    goto L_LoopChild;
}
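
For comparison, here is a sketch of the queue-based iterative approach mentioned above. The `Node` struct mirrors the public fields of LeetCode's N-ary `Node` class and is an assumption here; `levelOrderQueue` is a typical version of the algorithm, not LeetCode's official code, and no timings are claimed for it.

```cpp
#include <cstddef>
#include <queue>
#include <vector>

// Assumed N-ary tree node with the same public fields as LeetCode's Node.
struct Node {
    int val;
    std::vector<Node*> children;
};

std::vector<std::vector<int>> levelOrderQueue(Node* root) {
    std::vector<std::vector<int>> Results;
    if (!root) {
        return Results;
    }
    std::queue<Node*> Q;
    Q.push(root);
    while (!Q.empty()) {
        // Everything currently in the queue belongs to the same level.
        std::size_t LevelSize = Q.size();
        Results.emplace_back();
        for (std::size_t i = 0; i < LevelSize; ++i) {
            Node* Cur = Q.front();
            Q.pop();
            Results.back().push_back(Cur->val);
            for (Node* Child : Cur->children) {
                Q.push(Child);
            }
        }
    }
    return Results;
}
```

The same test tree can be fed to the recursive and goto versions of `levelOrder` to check that all three produce identical output.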

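One flexible aspect of this way of transforming code is that function calls which do not need to be optimized can be treated as primitive operations (like + - * /); the generated C++ code then simply contains ordinary calls to those Lua functions.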
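
A minimal sketch of that idea, with a hypothetical callback `Weight` standing in for a Lua function we chose not to transform (in the real setting it would be a call back into Lua through whatever binding is used). Only `Visit` is compiled into goto form; `Weight` is called like `+`, as an ordinary call that returns before the next jump, so it does not interfere with the state machine. `WeightedSum` and the `TreeNode` struct are assumptions for illustration.

```cpp
#include <functional>
#include <vector>

// Assumed LeetCode-style binary tree node.
struct TreeNode {
    int val;
    TreeNode* left;
    TreeNode* right;
};

// Sums Weight(val) over all nodes; Weight is an untransformed "primitive".
int WeightedSum(TreeNode* root, const std::function<int(int)>& Weight)
{
    struct StackNode
    {
        int ActionIndex;
        TreeNode* Node;
    };
    std::vector<StackNode> Stack;
    int Sum = 0;
    TreeNode* R_CurNode = root;
    goto L_Visit;

L_ApplyCont:
    if (Stack.empty())
    {
        return Sum;
    }
    switch (Stack.back().ActionIndex)
    {
    case 1:
        R_CurNode = Stack.back().Node;
        Stack.pop_back();
        goto L_Visit;
    }

L_Visit:
    if (!R_CurNode)
    {
        goto L_ApplyCont;
    }
    // Ordinary (non-tail) call: it returns here before the next goto.
    Sum += Weight(R_CurNode->val);
    Stack.push_back({1, R_CurNode->right});
    R_CurNode = R_CurNode->left;
    goto L_Visit;
}
```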