Coin Tosses, Binomials and Dynamic Programming

Today someone asked on Stack Overflow about the probability of outcomes of coin tosses. It’s an interesting question because it touches on several areas that programmers should know, from maths (probability and counting) to dynamic programming.

Dynamic programming is a divide-and-conquer technique that is often overlooked by programmers. It can sometimes be hard to spot situations where it applies, but when it does apply it will typically reduce an algorithm from exponential complexity (which is impractical in all but the smallest of cases) to polynomial complexity.

I will explain these concepts in one of the simplest forms: the humble coin toss.

Bernoulli Trials

A Bernoulli trial is a random event (or experiment) with exactly two possible outcomes, success and failure. The probability of each outcome is known: the probability of success is p and the probability of failure is 1-p, where 0 < p < 1. The outcome of each trial is independent of any other trial.

What constitutes success is arbitrary. For the purposes of this post, success will be defined as a coin coming up heads. To clarify the independence property, the probability of a coin coming up heads does not change regardless of any previous tosses of that (or any other) coin.

Assume a fair coin (p = 0.5). If you want to determine the probability of k successes (heads) from n trials in the general case, it is worth first recognizing that if you make n coin tosses there are 2^n permutations. We are not interested in the order of the tosses, however, only in the number of heads. For n = 2, heads then tails counts the same as tails then heads. If we represent 0 as tails and 1 as heads, the outcomes for the first few values of n are:

                 Number of Heads
Number of Coins  0     1                       2                                   3                       4
1                0     1
2                00    01, 10                  11
3                000   100, 010, 001           011, 101, 110                       111
4                0000  1000, 0100, 0010, 0001  0011, 0101, 0110, 1001, 1010, 1100  0111, 1101, 1011, 1110  1111

If you ignore the actual values and reduce it to the number of permutations:

                 Number of Heads
Number of Coins  0   1   2   3   4   Total
1                1   1               2
2                1   2   1           4
3                1   3   3   1       8
4                1   4   6   4   1   16

Divide the number of desired outcomes by the total number of outcomes and you have your probability.
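As a quick sanity check of the tables above, here is a minimal sketch that simply enumerates every toss sequence and counts those with the desired number of heads. The class and method names (HeadCounts, countOutcomes, probability) are made up for illustration, and encoding each sequence as an int bitmask is an assumption that limits it to n below 31.

```java
public class HeadCounts {
    // Counts how many of the 2^n toss sequences contain exactly k heads.
    // Each sequence is encoded as a bitmask: a set bit means heads.
    // Assumes 0 <= n < 31 so that 1 << n fits in an int.
    public static int countOutcomes(int n, int k) {
        int matches = 0;
        for (int mask = 0; mask < (1 << n); mask++) {
            if (Integer.bitCount(mask) == k) {
                matches++;
            }
        }
        return matches;
    }

    // Desired outcomes divided by total outcomes, for a fair coin.
    public static double probability(int n, int k) {
        return (double) countOutcomes(n, k) / (1 << n);
    }

    public static void main(String[] args) {
        System.out.println(countOutcomes(4, 2)); // prints 6
        System.out.println(probability(4, 2));   // prints 0.375
    }
}
```

This is exponential in n, so it is only useful for checking small cases by hand.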

Binomial Distribution

The above numbers are obviously not random: they correspond to the coefficients of a binomial expansion:

(1 + x)^n

For example:

(1 + x)^2 = x^2 + 2x + 1 (1, 2, 1)
(1 + x)^3 = x^3 + 3x^2 + 3x + 1 (1, 3, 3, 1)
(1 + x)^4 = x^4 + 4x^3 + 6x^2 + 4x + 1 (1, 4, 6, 4, 1)

Any of the above coefficients can be found with factorials:

f(n,k) = nCk = n! / ((n-k)! x k!)

Binomial Probability

Since we can calculate the number of permutations for our desired outcome and we know the total number of outcomes is 2^n, then:

P(n,k) = nCk x 2^(-n)

assuming:

P = the probability of k coins coming up heads for n coin tosses
n = number of coin tosses
k = number of desired successes

Unfair Coins

The above is accurate for fair coins but what about unfair coins? Assume that a given coin has a 60% chance of coming up heads. What does that do to our formula? Very little actually. The above formula is simply a special case of a more general formula.

P(n,k,p) = nCk x p^k x (1-p)^(n-k)

assuming:

P = the probability of k coins coming up heads for n coin tosses
n = number of coin tosses
k = number of desired successes
p = probability of a coin coming up heads

Plug in p = 0.5 and the formula reduces to the earlier version.
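The general formula translates almost directly into code. The following is only a sketch: the names BinomialProbability, choose and probability are invented here, and the coefficient is computed with a running product in floating point rather than with literal factorials, which is a substitution made to keep the example short.

```java
public class BinomialProbability {
    // nCk as a double, built up as a running product instead of factorials
    // (an assumption of this sketch, not the formula as written above).
    public static double choose(int n, int k) {
        double result = 1.0;
        for (int i = 1; i <= k; i++) {
            result = result * (n - k + i) / i;
        }
        return result;
    }

    // P(n,k,p) = nCk x p^k x (1-p)^(n-k)
    public static double probability(int n, int k, double p) {
        return choose(n, k) * Math.pow(p, k) * Math.pow(1 - p, n - k);
    }

    public static void main(String[] args) {
        // Fair coin: 2 heads from 4 tosses, 6 / 16
        System.out.println(probability(4, 2, 0.5)); // prints 0.375
        // Unfair coin with a 60% chance of heads: 6 heads from 10 tosses
        System.out.println(probability(10, 6, 0.6));
    }
}
```

Because it works in doubles, the result is approximate for large n, which is usually acceptable for a probability.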

Pascal’s Triangle

Dealing with factorials is cumbersome and impractical for even small values of n (e.g., 1000! has 2,568 digits). Fortunately there is a much faster method of calculating coefficients known as Pascal's triangle.

If you write the coefficients out row by row as a triangle, you’ll see that any coefficient is the sum of the two coefficients above it (e.g., 10 = 4 + 6). This leads to the following function:

C(n,k) = 1                         when k = 0 or k = n
C(n,k) = C(n-1,k-1) + C(n-1,k)     for 0 < k < n

which could lead to the following naive implementation:

public static int coefficient(int n, int k) {
  return k == 0 || k == n ? 1 : coefficient(n - 1, k - 1) + coefficient(n - 1, k);
}
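To get a feel for how much work this function does, here is a small sketch that counts its invocations. The counter and the names NaiveCallCount and countCalls are additions for illustration only; the recursion itself is the same as above.

```java
public class NaiveCallCount {
    // Number of times coefficient() has been entered since the last reset.
    private static long calls = 0;

    // Same naive recursion as above, with a call counter added.
    public static int coefficient(int n, int k) {
        calls++;
        return k == 0 || k == n ? 1 : coefficient(n - 1, k - 1) + coefficient(n - 1, k);
    }

    // Resets the counter, runs the recursion and reports how many calls it took.
    public static long countCalls(int n, int k) {
        calls = 0;
        coefficient(n, k);
        return calls;
    }

    public static void main(String[] args) {
        for (int n : new int[] {10, 20, 24}) {
            System.out.printf("C(%d,%d) needed %d calls%n", n, n / 2, countCalls(n, n / 2));
        }
    }
}
```

Every leaf of the recursion contributes exactly 1 to the result, so the number of calls is 2 x C(n,k) - 1: computing C(10,5) = 252 already takes 503 calls.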

Unfortunately this simple function has terrible performance: O(2^n) in the worst case. The reason is simple: it does an awful lot of redundant calculations, recomputing the same coefficients over and over. There are two ways to solve this performance problem:

Memoization

Memoization is an optimization technique that avoids making repeated calls to the same function with the same arguments. This assumes a pure function, meaning any two calls with the same arguments will evaluate to the same result and there are no relevant side effects. One such implementation is:

private static class Pair {
  public final int n;
  public final int k;

  private Pair(int n, int k) {
    this.n = n;
    this.k = k;
  }
  
  @Override
  public int hashCode() {
    return (n + 37881) * (k + 47911);
  }
  
  @Override
  public boolean equals(Object ob) {
    if (!(ob instanceof Pair)) {
      return false;
    }
    Pair p = (Pair)ob;
    return n == p.n && k == p.k;
  }
}

// needs: import java.util.HashMap; import java.util.Map;
private static final Map<Pair, Integer> CACHE = new HashMap<Pair, Integer>();

public static int coefficient2(int n, int k) {
  if (k == 0 || k == n) {
    return 1;
  } else if (k == 1 || k == n - 1) {
    return n;
  }
  Pair p = new Pair(n, k);
  Integer i = CACHE.get(p);
  if (i == null) {
    i = coefficient2(n - 1, k - 1) + coefficient2(n - 1, k);
    CACHE.put(p, i);
  }
  return i;
}

To compare, calculating C(34,20) took 10.7 seconds on my PC with the first method and 2.4 milliseconds with the second. C(340,200) still takes a mere 21.1 ms, although its true value is far too large for an int, so the returned number has overflowed and only the timing is meaningful. There probably isn’t enough time left in the universe for the first method to finish that.

The second method uses O(n^2) space and O(n^2) time.

Introducing Dynamic Programming

The key to understanding dynamic programming and applying it to a problem is identifying a recurrence relation that expresses the solution for a given value n in terms of the solution for a lower value of n. We’ve already identified this.

It is worth noting that the only thing we need to calculate the values for any row of Pascal’s triangle is the values of the previous row. Therefore, we should never need to keep more than one row in memory at a time, which reduces our desired algorithm to O(n) space.

Consider this implementation:

public static int coefficient3(int n, int k) {
  int[] work = new int[n + 1];
  work[0] = 1; // row 0 of the triangle, so that C(0,0) is also correct
  for (int i = 1; i <= n; i++) {
    // go right to left so work[j - 1] still holds the previous row's value
    for (int j = i; j >= 0; j--) {
      if (j == 0 || j == i) {
        work[j] = 1;
      } else if (j == 1 || j == i - 1) {
        work[j] = i;
      } else {
        work[j] += work[j - 1];
      }
    }
  }
  return work[k];
}

This algorithm simply calculates one row at a time from the first until the required row is found. Only one row (i.e., work) is ever kept. The inner loop satisfies the boundary conditions (arguably unnecessarily) and no recursion or more complex cache is required.

This algorithm is still O(n^2) time but only O(n) space, and in practice it will be faster than the previous version. My example of C(340,200) runs in 1.5 ms compared to roughly 21 ms for the memoized version.
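One caveat with int arithmetic is that coefficients outgrow 32 bits quickly, and C(340,200) is far beyond that range. If exact values are needed, the same one-row idea can be sketched with BigInteger. The class name BigCoefficient is hypothetical, and the boundary special cases are dropped here in favour of a single update rule.

```java
import java.math.BigInteger;
import java.util.Arrays;

public class BigCoefficient {
    // One-row Pascal's triangle using arbitrary-precision integers.
    public static BigInteger coefficient(int n, int k) {
        BigInteger[] work = new BigInteger[n + 1];
        Arrays.fill(work, BigInteger.ZERO);
        work[0] = BigInteger.ONE; // row 0 of the triangle
        for (int i = 1; i <= n; i++) {
            // Right to left, so work[j - 1] is still the previous row's value.
            for (int j = i; j >= 1; j--) {
                work[j] = work[j].add(work[j - 1]);
            }
        }
        return work[k];
    }

    public static void main(String[] args) {
        System.out.println(coefficient(34, 20));  // prints 1391975640
        System.out.println(coefficient(340, 200)); // a value far too large for an int
    }
}
```

The trade-off is speed: BigInteger additions are slower than int additions, so the timings quoted above would not carry over.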

Different Coins

Up until now it has been assumed that the same coin (or at least an identical coin) is used for every test. In mathematical terms, p is constant. What if it isn’t? Imagine we have two coins: the first has a 70% chance of coming up heads and the second has a 60% chance. What is the probability of getting exactly one head when tossing the pair of coins?

Our formula clearly breaks down because it has been assumed until now that HT and TH are interchangeable and (more importantly) equally likely. That is no longer the case.

The simplest solution traverses all permutations, finds those that satisfy the condition and works out their cumulative probability. Consider this terrible implementation:

private static final int SLICES = 10;
private static int[] COINS = {7, 6};

private static void runBruteForce(int... coins) {
  long start = System.nanoTime();
  int[] tally = new int[coins.length + 1];
  bruteForce(tally, coins, 0, 0);
  long end = System.nanoTime();
  int total = 0;
  for (int count : tally) {
    total += count;
  }
  for (int i = 0; i < tally.length; i++) {
    System.out.printf("%d : %,.4f%n", i, (double) tally[i] / total);
  }
  System.out.printf("%nBrute force ran for %d COINS in %,.3f ms%n%n",
      coins.length, (end - start) / 1000000d);
}

private static void bruteForce(int[] table, int[] coins, int index, int count) {
  if (index == coins.length) {
    table[count]++;
    return;
  }
  // coins[index] of the SLICES equally likely slices are heads
  for (int i = 0; i < coins[index]; i++) {
    bruteForce(table, coins, index + 1, count + 1);
  }
  // the remaining slices are tails
  for (int i = coins[index]; i < SLICES; i++) {
    bruteForce(table, coins, index + 1, count);
  }
}

For simplicity, this code assumes the probability of heads is an integer multiple of 0.1, so 7 and 6 stand for 0.7 and 0.6. The result:

0 : 0.1200
1 : 0.4600
2 : 0.4200

Inspection confirms this result: two heads requires both coins to come up heads (0.7 x 0.6 = 0.42) and no heads requires both to come up tails (0.3 x 0.4 = 0.12). Of course, this algorithm is O(10^n) time. It can be improved to avoid some redundant calls but the resulting algorithm is still exponential.

Dynamic Programming: Part Two

The key to dynamic programming is to be able to state the problem with a suitable recurrence relation that expresses the solution in terms of smaller values of the solution.

General Problem: Given C, a series of n coins p_1 to p_n where p_i represents the probability of the i-th coin coming up heads, what is the probability of k heads coming up from tossing all the coins?

The recurrence relation required to solve this problem isn’t necessarily obvious:

P(n,k,C,i) = p_i x P(n,k-1,C,i+1) + (1-p_i) x P(n,k,C,i+1)

To put it another way, if you take the subset of the coins C from p_i to p_n, the probability that there will be k heads among those coins can be expressed this way: if the current coin p_i comes up heads then the subsequent coins need only have k-1 heads. But if the current coin comes up tails (at a chance of 1-p_i) then the subsequent coins must contain k heads.

You may need to think about that before it sinks in. Once it does sink in it may not be obvious that this helps solve the problem. The key point to remember is that each step of this relation expresses the solution in terms of one less coin and possibly one less value of k. That divide-and-conquer aspect is the key to dynamic programming.
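One way to convince yourself that the relation works is to code it up literally as a recursive function with a cache. This is a sketch only: the names CoinRecurrence and heads are invented, the coin index is 0-based rather than 1-based, and the base cases (no coins left, or a negative number of heads) are assumptions spelled out in the comments.

```java
public class CoinRecurrence {
    // Probability of exactly k heads among coins[i], coins[i + 1], ...
    // memo[i][k] caches results so each (i, k) pair is computed once.
    private static double heads(double[] coins, int k, int i, Double[][] memo) {
        if (k < 0) {
            return 0.0; // a negative number of heads is impossible
        }
        if (i == coins.length) {
            return k == 0 ? 1.0 : 0.0; // no coins left: only zero heads is possible
        }
        if (memo[i][k] != null) {
            return memo[i][k];
        }
        double p = coins[i];
        // Heads now: need k - 1 from the rest. Tails now: still need k.
        double result = p * heads(coins, k - 1, i + 1, memo)
                + (1 - p) * heads(coins, k, i + 1, memo);
        memo[i][k] = result;
        return result;
    }

    public static double heads(double[] coins, int k) {
        return heads(coins, k, 0, new Double[coins.length][Math.max(k, 0) + 1]);
    }

    public static void main(String[] args) {
        double[] coins = {0.7, 0.6};
        for (int k = 0; k <= coins.length; k++) {
            System.out.printf("%d : %.4f%n", k, heads(coins, k));
        }
    }
}
```

With the cache in place there are at most n x (k + 1) distinct subproblems, which is where the polynomial running time comes from.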

Algorithm Explained

Assume three coins (with probabilities 0.2, 0.3 and 0.4 of coming up heads). Consider the simplest case: k = 0. The probability that all three coins are tails is equal to:

P(3,0,C,1) = (1-p_1) x P(3,0,C,2)
           = (1-p_1) x (1-p_2) x P(3,0,C,3)
           = (1-p_1) x (1-p_2) x (1-p_3)

which is clearly correct. It gets slightly more complicated when k > 0. We need to remember that we now have a table for k = 0. We can construct a table for k = 1 in terms of that table and so on.

P(3,1,C,1) = p_1 x P(3,0,C,2) + (1-p_1) x P(3,1,C,2)
= ...

Once again, if the current coin comes up heads we need k-1 heads from the remaining coins. If it doesn’t, we need k heads from the remaining coins. The above relation adds those two cases together.

In implementation these two cases (k = 0 and k > 0) tend to be treated the same since k = 0 is a special case of k > 0 given the probability of anything where k = -1 is 0.

Expressing this as a table:

                 Coins
Number of Heads  0.200  0.300  0.400  (none)
-1               0.000  0.000  0.000  0.000
0                0.336  0.420  0.600  1.000
1                0.452  0.460  0.400  0.000
2                0.188  0.120  0.000  0.000
3                0.024  0.000  0.000  0.000

Remember each cell in the table is a function of the cell to its right and the cell above that one weighted by the probability of that particular coin. The first column represents the probabilities for 0 to 3 heads from the 3 coins.

Most importantly, this algorithm is O(nk) time and O(nk) space.

Implementation

private static void runDynamic() {
  double[] coins = {0.2, 0.3, 0.4};
  long start = System.nanoTime();
  double[] probs = dynamic(coins);
  long end = System.nanoTime();
  for (int i = 0; i < probs.length; i++) {
    System.out.printf("%d : %,.4f%n", i, probs[i]);
  }
  System.out.printf("%nDynamic ran for %d coins in %,.3f ms%n%n",
      coins.length, (end - start) / 1000000d);
}

private static double[] dynamic(double... coins) {
  double[][] table = new double[coins.length + 2][];
  for (int i = 0; i < table.length; i++) {
    table[i] = new double[coins.length + 1];
  }
  table[1][coins.length] = 1.0d; // everything else is 0.0
  for (int i = 0; i <= coins.length; i++) {
    for (int j = coins.length - 1; j >= 0; j--) {
      table[i + 1][j] = coins[j] * table[i][j + 1] +
          (1 - coins[j]) * table[i + 1][j + 1];
    }
  }
  double[] ret = new double[coins.length + 1];
  for (int i = 0; i < ret.length; i++) {
    ret[i] = table[i + 1][0];
  }
  return ret;
}

Output:

0 : 0.3360
1 : 0.4520
2 : 0.1880
3 : 0.0240

Dynamic ran for 3 coins in 0.018 ms
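Just as with Pascal's triangle, each column of the table depends only on the column next to it, so a single array is enough if it is updated in the right order. The following sketch is a variation rather than the implementation above: the names CoinRow and distribution are hypothetical, and it adds coins one at a time instead of filling a two-dimensional table.

```java
public class CoinRow {
    // Returns dist where dist[k] is the probability of exactly k heads.
    public static double[] distribution(double... coins) {
        double[] dist = new double[coins.length + 1];
        dist[0] = 1.0; // with no coins tossed, zero heads is certain
        for (int c = 0; c < coins.length; c++) {
            double p = coins[c];
            // Right to left, so dist[k - 1] still excludes the current coin.
            for (int k = c + 1; k >= 1; k--) {
                dist[k] = dist[k] * (1 - p) + dist[k - 1] * p;
            }
            dist[0] *= (1 - p);
        }
        return dist;
    }

    public static void main(String[] args) {
        double[] dist = distribution(0.2, 0.3, 0.4);
        for (int k = 0; k < dist.length; k++) {
            System.out.printf("%d : %.4f%n", k, dist[k]);
        }
    }
}
```

The time is unchanged, but the space drops from a full table to one array of n + 1 doubles.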

Conclusion

I hope this serves as a useful introduction to dynamic programming. It can be an incredibly powerful and useful technique but can require some practice to identify where and how to apply it.

The last very short code segment does an awful lot. It can handle large numbers of coins, uniform or not, and deliver results in a timely fashion.

2 comments:

Anonymous said...

There's no need to use pascal's triangle to compute choose(n, k).
Just use these two facts:
choose(n - k, 0) = 1
choose(n, k) = choose(n - 1, k - 1) * n / k

The code looks like this:
result <- 1
for i in 1..k
  result <- result * (n - k + i)
  result <- result / i

The division is always exact, even in integer arithmetic.
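For reference, the commenter's loop might look like this in Java. It is a sketch: the class name Choose is made up, long is used to postpone overflow, and switching to the smaller of k and n - k is an extra assumption that is not part of the comment.

```java
public class Choose {
    // Multiplicative formula: after step i, result equals choose(n - k + i, i),
    // so the integer division by i never leaves a remainder.
    public static long choose(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        k = Math.min(k, n - k); // choose(n, k) == choose(n, n - k)
        long result = 1;
        for (int i = 1; i <= k; i++) {
            result = result * (n - k + i) / i;
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(choose(4, 2));   // prints 6
        System.out.println(choose(34, 20)); // prints 1391975640
    }
}
```

The intermediate product can still overflow a long for large enough n, so this is not a replacement for arbitrary-precision arithmetic.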

Anonymous said...

I have only ever seen the terminology Divide and Conquer applied to problems where the problem space is divided in half (or smaller) for the subproblems, and then combined - for example, Mergesort, where the input sequence is divided in half until there is one element, and then doing the ordered merge between subproblems. Such problems usually yield a runtime with a logarithmic component, since repeatedly dividing a problem of size n in half can only be done log(n) times before you have a subproblem of size 1.

Dynamic programming is where a problem can be solved by solving some combination of subproblems. However, the subproblems usually only add one layer to the onion, as it were - problem n requires problem n-1 to be solved, which requires n-2 to be solved, etc. Their runtime tends to be O(total number of subproblems) * O(number of subproblems examined at each step).
