Derivation of Dyna's Back-Propagation Algorithm (2004)

This is an online appendix to Eisner, Goldlust, and Smith (2005), Compiling Comp Ling: Weighted Dynamic Programming and the Dyna Language.

Section 4.2 and Figure 3 give a forward-chaining algorithm that computes the values of all items in a semiring-weighted logic program.

Then section 5.2 and Figure 4 give a gradient algorithm that differentiates the previous computation (when the semiring is (R,+,*)). It works backwards to rapidly compute the partial derivatives of the goal item with respect to each of the items. This gradient can be used to adjust the values of items that correspond to free parameters, within an algorithm like gradient descent or L-BFGS.

Below is the derivation of this gradient algorithm, which did not fit in the paper (as noted in the Figure 4 caption). It is a reformatted version of an August 12, 2004 email among the authors (from Jason Eisner to Eric Goldlust).

Note that the Dyna compiler produces specialized code for both the forward-chaining algorithm and the gradient algorithm, for any logic program in this semiring.

Background reading: A special case of these forward-chaining and back-propagation algorithms (for a particular finite-state model) was explained in detail in Eisner (2001), sections 4.2 and 4.3 respectively. Our derivation below uses the same notation.

You should already be familiar with the general idea of automatic differentiation (website) in the reverse mode (tutorial, history). Most people first see this in the case of back-propagation in neural networks (slides, video, explanation, book chapter, formulas and C++ code, another tutorial with code). Indeed, the algorithm below follows the same pattern as back-propagation through time for recurrent neural networks (Werbos 1989, Williams and Zipser 1989).


Setup

We need to start by writing out the timestamped forward pass, then derive a matching backward pass. I'll use (t) for a timestamp superscript.

First, here is a picture of a bit of the formula graph given a rule c += a1 * a2 * a3 where a2 is the only driver. You can see that agenda(t-1)[a2] is used to update the chart(t-1) values into chart(t) values and to define a new value for agenda(t)[c]. Note that for a given time t, the chart values feed forward into the agenda values.


                 ,---------------------- agenda(t)[c] 
                /                           *  |  *  \
               /                               |      \
   chart(t)[a1]              chart(t)[a2]       |       \    chart(t)[a3]
        |                      /  +  \         |        \     |
        |                     /       \        |         \    |
   chart(t-1)[a1]   chart(t-1)[a2]    agenda(t-1)[a2]   chart(t-1)[a3]
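
In equations, the picture corresponds to the two updates triggered when the driver a2 is popped (the i = 2 case of this rule):

   chart(t)[a2]  = chart(t-1)[a2] + agenda(t-1)[a2]
   agenda(t)[c] += chart(t)[a1] * agenda(t-1)[a2] * chart(t-1)[a3]

while chart(t)[a1] and chart(t)[a3] are just copies of their (t-1) values.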

Forward pass with explicit timestamps

Here is the forward pass pseudocode.

[Note that we explicitly push current values onto a "tape" (stack) before we update them. Later, during the backward pass, this tape will be read backwards (popped) to help us unwind the computation. From section 5.2 of our paper: "At line 3, a stack is needed to remember the sequence of (a, old, Δ) triples from the original computation. It is a more efficient version of the 'tape' usually used in automatic differentiation."]

initialize the agenda
t:=0
while agenda is not empty

  t := t+1

  // choose an update from agenda
  choose a such that agenda(t-1)[a] != 0  // this a could permanently be called a(t-1)
  push (a, chart(t-1)[a], agenda(t-1)[a]) onto tape

  // compute chart(t)
  chart(t)    := chart(t-1)           
  chart(t)[a] += agenda(t-1)[a]        // overrides (adds to) previous line, for [a]

  // compute agenda(t)
  agenda(t)    := agenda(t-1)         // plus other terms to be added below
  agenda(t)[a] := 0                    // overrides previous line, for [a]
  if (chart(t)[a] != chart(t-1)[a])   // only propagate actual changes
    // the following updates may be done in any convenient order (e.g., by pattern matching on a)
    for each inference rule c += a1*a2*...*ak
      for i from 1 to k 
        for each way (if any) of instantiating the rule's variables 
        such that ai=a and (all j) chart(t)[aj] != 0 or chart(t-1)[aj] != 0
           agenda(t)[c] += prod_{j=1}^{i-1} chart(t)[aj]
                            * agenda(t-1)[a]
                            * prod_{j=i+1}^{k} chart(t-1)[aj]

           // has the same effect as updating chart[a1], chart[a2], ... 
           // one at a time in that order; the only ones that 
           // actually change are the ones that equal a.
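
For concreteness, here is a minimal Python sketch of the forward pass above, specialized to ground (variable-free) rules of the form head += body[0]*body[1]*...*body[k-1]. The names (run_forward, rules, axioms, max_pops) are ours, not from the paper, and the agenda is simply seeded with the axiom values. Only the current chart is stored: since chart(t) and chart(t-1) differ only at the popped item a, chart(t-1)[aj] is recovered as oldchart when aj = a and as chart[aj] otherwise.

from collections import defaultdict

def run_forward(rules, axioms, max_pops=50, tol=1e-12):
    """rules: list of (head, body) pairs; axioms: dict item -> value."""
    chart = defaultdict(float)
    agenda = defaultdict(float, axioms)   # initialize the agenda with the axiom values
    tape = []                             # one (a, old chart[a], delta) triple per pop
    for _ in range(max_pops):
        # choose an update from the agenda
        a = next((x for x in agenda if abs(agenda[x]) > tol), None)
        if a is None:
            break                         # agenda is empty (up to tolerance)
        delta = agenda[a]                 # agenda(t-1)[a]
        oldchart = chart[a]               # chart(t-1)[a]
        tape.append((a, oldchart, delta)) # push onto the tape for the backward pass
        chart[a] = oldchart + delta       # chart(t)[a]
        agenda[a] = 0.0
        # propagate the change through every rule position where a is the driver
        for head, body in rules:
            for i, b in enumerate(body):
                if b != a:
                    continue
                update = delta                        # the driver value agenda(t-1)[a]
                for j in range(i):                    # new chart values left of the driver
                    update *= chart[body[j]]
                for j in range(i + 1, len(body)):     # old chart values right of the driver
                    update *= oldchart if body[j] == a else chart[body[j]]
                agenda[head] += update
    return chart, agenda, tape

For example, run_forward([('a', ['a', 'a']), ('a', ['r'])], {'r': 0.25}) runs the small example program discussed near the end of this page, and chart['a'] approaches 0.5.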

Backward pass with explicit timestamps

Okay, now let's see what the backward pass must do. We'll just differentiate the defining computations above, basically in reverse order.

// t and tape are as they were left at the end of the forward pass
initialize the gradient computation 
  (e.g., set gagenda(t):=0, gchart(t):=0, then gchart(t)[goal]:=1)
while tape is not empty                     
  pop (a, chart(t-1)[a], agenda(t-1)[a])  
                           // In the timestamped version, popping the old 
                           // chart and agenda values isn't actually necessary, 
                           // since we remember those values from the forward pass.

  gagendaold := 0         // holds accumulands into gagenda(t-1)[a]
  gchart(t-1) := 0       // to be accumulated below

  if (chart(t)[a] != chart(t-1)[a])
      // the following updates may be done in any convenient order (e.g., by pattern matching on a)
      for each inference rule c += a1*a2*...*ak
        for i from 1 to k 
          for each way (if any) of instantiating the rule's variables such that ai=a and (all j) chart(t)[aj] != 0 or chart(t-1)[aj] != 0

             for j':=1 to i-1
                gchart(t)[aj'] += gagenda(t)[c]
                                   * prod_{j=1 to i-1 except j'} chart(t)[aj]
                                   * agenda(t-1)[a]
                                   * prod_{j=i+1}^{k} chart(t-1)[aj]

             gagendaold += prod_{j=1}^{i-1} chart(t)[aj]
                           * gagenda(t)[c]
                           * prod_{j=i+1}^{k} chart(t-1)[aj]

             for j':=i+1 to k
                gchart(t-1)[aj'] += prod_{j=1}^{i-1} chart(t)[aj]
                                     * agenda(t-1)[a]
                                     * prod_{j=i+1 to k except j'} chart(t-1)[aj]
                                     * gagenda(t)[c]

  // now gchart(t) is defined
  gagenda(t-1) := gagenda(t)
  gagenda(t-1)[a] := gagendaold + gchart(t)[a] // overrides above since agenda(t-1)[a] does not feed forward into agenda(t)[a], but into other things
  gchart(t-1) += gchart(t) // note that in forward pass, chart(t-1) influences agenda(t) both directly and via chart(t).

  t := t-1
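
Each accumulation above is just the chain rule applied to one forward update. Differentiating agenda(t)[c] += prod_{j=1}^{i-1} chart(t)[aj] * agenda(t-1)[a] * prod_{j=i+1}^{k} chart(t-1)[aj] with respect to any one factor leaves the product of the remaining factors; multiplying that by gagenda(t)[c] and accumulating gives the three += statements for the three kinds of factors (chart(t)[aj'] with j' < i, the driver agenda(t-1)[a], and chart(t-1)[aj'] with j' > i). Likewise, differentiating chart(t)[a] = chart(t-1)[a] + agenda(t-1)[a] and the copy chart(t) := chart(t-1) yields the gchart(t)[a] term in gagenda(t-1)[a] and the final gchart(t-1) += gchart(t).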

Check: Special case where all rules are binary

As a check, the loop over inference rules simplifies as follows in the binary-rule case:

      for each inference rule c += a1*a2

        for each way (if any) of instantiating the rule's variables such that a1=a and chart(t)[a2] != 0 or chart(t-1)[a2] != 0

             gagenda(t-1)[a] += gagenda(t)[c]
                                 * chart(t-1)[a2]

             gchart(t-1)[a2] += agenda(t-1)[a]
                               * gagenda(t)[c]

             // can get these by specializing the general version above, 
             // or by directly differentiating
             //    agenda(t)[c] += agenda(t-1)[a] * chart(t-1)[a2]
           
        for each way of instantiating the rule's variables such that a2=a and chart(t)[a1] != 0 or chart(t-1)[a1] != 0

             gchart(t)[a1] += gagenda(t)[c] 
                               * agenda(t-1)[a]

             gagenda(t-1)[a] += chart(t)[a1]
                                 * gagenda(t)[c]

             // can get these by specializing the general version above, 
             // or by directly differentiating
             //    agenda(t)[c] += chart(t)[a1] * agenda(t-1)[a]
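
Explicitly: differentiating agenda(t)[c] += agenda(t-1)[a] * chart(t-1)[a2] with respect to agenda(t-1)[a] leaves chart(t-1)[a2], and with respect to chart(t-1)[a2] leaves agenda(t-1)[a]; multiplying each by gagenda(t)[c] gives the two accumulations in the first case above, and the second case is symmetric.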

Backward pass without timestamps

Now our next job is to simplify these routines by eliminating the timestamps [and using the tape instead]. The forward pass simplifies as already shown in the paper [Figure 3]. Here's how the backward pass simplifies [Figure 4 in paper]:

initialize the gradient computation 
  (e.g., set gagenda:=0, gchart:=0, then gchart[goal]:=1)
while tape is not empty                     
  // chart currently holds chart(t)
  // gchart and gagenda currently hold gchart(t) and gagenda(t)

  pop (a, oldchart, oldagenda)    // "old" refers to the (t-1) values

  // During the main loop below, chart and gagenda still refer to (t),
  // but gchart now refers to (t-1), as if we had done 
  // gchart(t-1):=0 as in the timestamped version but then had done 
  // the gchart(t-1)+=gchart(t) step early.
  // Because the latter step was done early, before gchart(t) was modified,
  // we'll need to ensure that any modifications to gchart(t) (which is
  // no longer stored) are passed through so that they appropriately affect 
  // gchart(t-1) and gagenda(t-1)[a].
  //
  // As in the paper, we use chart(aj,oldchart) to refer to 
  // aj==a ? oldchart : chart[aj], that is, chart(t-1)[aj].

  gagendaold := 0         // holds accumulands into gagenda(t-1)[a] (while gagenda refers to (t))
  gchartnew := gchart[a]  // holds gchart(t)[a] (while gchart refers to (t-1))

  if (chart[a] != oldchart)     // in other words, if update actually accomplished something
      // the following updates may be done in any convenient order (e.g., by pattern matching on a)
      for each inference rule c += a1*a2*...*ak
        for i from 1 to k
          for each way (if any) of instantiating the rule's variables such that ai=a and (all j) chart[aj] != 0 or aj=a

             for j':=1 to i-1
                // modifying gchart(t), must pass through
                gchart[aj'] += gagenda[c]     
                               * prod_{j=1 to i-1 except j'} chart[aj]
                               * oldagenda
                               * prod_{j=i+1}^{k} chart(aj,oldchart)
                if (aj'=a) then gchartnew += (same thing)

             gagendaold += prod_{j=1}^{i-1} chart[aj]
                           * gagenda[c]
                           * prod_{j=i+1}^{k} chart(aj,oldchart)

             for j':=i+1 to k
                gchart[aj'] += prod_{j=1}^{i-1} chart[aj]
                               * oldagenda
                               * prod_{j=i+1 to k except j'} chart(aj,oldchart)
                               * gagenda[c]

  // now we roll gagenda and chart back to refer to (t-1)
  gagenda[a] := gagendaold + gchartnew   
  chart[a] := oldchart

  t := t-1
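
Continuing the Python sketch begun after the forward pass, here is a minimal version of this timestamp-free backward pass, again for ground rules only (run_backward and its argument names are ours). It returns gagenda; since the forward sketch seeded the agenda with the axiom values, the entries of gagenda for the axioms are the partial derivatives of chart[goal] with respect to those values. We omit the "chart[aj] != 0 or aj=a" filtering, which is only an optimization.

from collections import defaultdict

def run_backward(rules, goal, chart, tape):
    chart = defaultdict(float, dict(chart))  # local copy; rolled back as the tape is popped
    tape = list(tape)                        # local copy, so the caller's tape survives
    gchart = defaultdict(float)
    gagenda = defaultdict(float)
    gchart[goal] = 1.0
    while tape:
        a, oldchart, oldagenda = tape.pop()
        def old(x):                          # chart(t-1)[x], i.e. chart(x,oldchart) in the paper
            return oldchart if x == a else chart[x]
        gagendaold = 0.0                     # holds accumulands into gagenda(t-1)[a]
        gchartnew = gchart[a]                # holds gchart(t)[a] while gchart refers to (t-1)
        if chart[a] != oldchart:             # the update actually accomplished something
            for head, body in rules:
                k = len(body)
                for i, b in enumerate(body):
                    if b != a:               # the popped item a must occupy the driver position
                        continue
                    for jp in range(i):      # passengers left of the driver: gchart(t), pass through
                        term = gagenda[head] * oldagenda
                        for j in range(i):
                            if j != jp:
                                term *= chart[body[j]]
                        for j in range(i + 1, k):
                            term *= old(body[j])
                        gchart[body[jp]] += term
                        if body[jp] == a:
                            gchartnew += term
                    term = gagenda[head]     # gradient with respect to the driver agenda(t-1)[a]
                    for j in range(i):
                        term *= chart[body[j]]
                    for j in range(i + 1, k):
                        term *= old(body[j])
                    gagendaold += term
                    for jp in range(i + 1, k):   # passengers right of the driver: gchart(t-1) directly
                        term = oldagenda * gagenda[head]
                        for j in range(i):
                            term *= chart[body[j]]
                        for j in range(i + 1, k):
                            if j != jp:
                                term *= old(body[j])
                        gchart[body[jp]] += term
        # roll gagenda and chart back to refer to (t-1)
        gagenda[a] = gagendaold + gchartnew
        chart[a] = oldchart
    return gagenda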

Check: Special case where all rules are binary

Again, let's consider what the main loop looks like in the binary-rule case:

      for each inference rule c += a1*a2

        for each way (if any) of instantiating the rule's variables such that a1=a and chart[a2] != 0 or a2=a

           gagendaold += gagenda[c] * chart(a2,oldchart)
           gchart[a2] += oldagenda * gagenda[c]
           if (a2=a) then gchartnew += oldagenda * gagenda[c]

        for each way (if any) of instantiating the rule's variables such that a2=a and chart[a1] != 0 or a1=a

           gchart[a1] += gagenda[c] * oldagenda
           gagendaold += chart[a1] * gagenda[c]   

Example

I suggest you check this with a simple example. Suppose we have the program

   a += a*a.      % note: equivalent to a = a*a since only one summand
   a += r.        % where r is an axiom
Then successive updates will look like this, I think:
   chart[a]=0, agenda[a]=r
   chart[a]=r, agenda[a]=r^2-0^2
   chart[a]=r+r^2, agenda[a]=(r+r^2)^2-r^2
   chart[a]=r+(r+r^2)^2, agenda[a]=(r+(r+r^2)^2)^2-(r+r^2)^2
   chart[a]=r+(r+(r+r^2)^2)^2, agenda[a]=(new chart[a])^2-(old chart[a])^2
The convergence condition is chart[a]=r+chart[a]^2, which is what we want. For example, if r is 0.25, I think chart[a] will converge to 0.5.
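
(Check: x = r + x^2 with r = 0.25 rearranges to x^2 - x + 0.25 = (x - 0.5)^2 = 0, whose only solution is x = 0.5.)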

I suggest you run the forward algorithm by hand for 2 steps (early stopping!), and see what formula you get for chart[a] in terms of r, presumably r+r^2. Then run the backward algorithm by hand, and make sure that gradient[r] ends up holding the correct derivative of that. In both cases, exactly follow the pseudocode in the EMNLP paper, with the modifications listed above. Let me know if it works!
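
If you would rather have the machine do this check, the two Python sketches above can be wired together as follows (goal_value, pops, and eps are our names, not from the paper). Truncating the forward pass after a fixed number of pops and then running the backward sweep should reproduce the derivative of the truncated formula, here d(r+r^2)/dr = 1+2r; a central finite difference confirms it. (In the sketch the axiom r is itself popped first, so chart[a] = r+r^2 appears after three pops rather than two.)

rules = [('a', ['a', 'a']), ('a', ['r'])]     # a += a*a.   a += r.

def goal_value(r, pops):
    chart, _, tape = run_forward(rules, {'r': r}, max_pops=pops)
    return chart['a'], chart, tape

r, pops, eps = 0.25, 3, 1e-6
val, chart, tape = goal_value(r, pops)        # early stopping: chart['a'] = r + r^2 = 0.3125
grad = run_backward(rules, 'a', chart, tape)['r']
fd = (goal_value(r + eps, pops)[0] - goal_value(r - eps, pops)[0]) / (2 * eps)
print(val, grad, fd)                          # expect 0.3125, 1.5, and approximately 1.5 (= 1 + 2r)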

Example

As [another] overly simple check, the back-relaxation code in my thesis (pp. 123-124) should fall out by trying the [backward algorithm] on the following program:

   stop(I) += edgezero(I)*pathto(I).
   pathto(J) += edge(I,J)*pathto(I).
when only pathto items are popped. Note that a1=a2 is impossible for this program. When we pop pathto(I), we get
  gagendaold := 0                     // in thesis, z := 0
  gchartnew := gchart[pathto(I)] // in thesis, doesn't appear; is effectively 0 since chart[pathto(...)] is never used: pathto items are never passengers

  // from stop(I) += edgezero(I)*pathto(I).
  gchart[edgezero(I)] += gagenda[stop(I)] * oldagenda   // in thesis, gPi0 += gpi * Ji
  gagendaold += chart[edgezero(I)] * gagenda[stop(I)]   // in thesis, z += Pi0 * gpi

  // from pathto(J) += edge(I,J)*pathto(I).
  gchart[edge(I,J)] += gagenda[pathto(J)] * oldagenda   // in thesis, gPij += gJ_j * Ji
  gagendaold += chart[edge(I,J)] * gagenda[pathto(J)]   // in thesis, z += Pij * gJ_j

  // after loop
  gagenda[pathto(I)] := gagendaold + gchartnew // in thesis, gJi := z (note gchartnew is effectively 0)
  chart[pathto(I)] := oldchart // in thesis, doesn't appear since chart[pathto(...)] is never stored: pathto items are never passengers (hence, no need to store oldchart on tape, either)
which seems to match, as shown in the comments.

Optimization opportunities

Some further optimizations seem possible though not urgent:


This page online: http://cs.jhu.edu/~jason/papers/eisner+goldlust+smith.emnlp05-appendix.html
Jason Eisner - jason@cs.jhu.edu