Problem: By listing the first six prime numbers: 2, 3, 5, 7, 11, and 13, we can see that the 6th prime is 13. What is the 10,001st prime number?

Python Solution

Since we’re looking for primes, we could try factoring increasing numbers until we come across the 10,001st prime. (This will be slow. We’ll make a faster version later.) In Project Euler Problem 5 we had a cheap (and inefficient) factorization routine. We’ll use that here.

import time
 
# factor a given positive integer n into its prime factors
def factors(n):
    found = []
    # remove any factors of 2 first
    while n % 2 == 0:
        found.append(2)
        n //= 2 # integer division keeps n an int
    # now look for odd factors
    p = 3
    while n != 1:
        while n % p == 0:
            found.append(p)
            n //= p
        p += 2
    return found
 
def nth_prime(n):
    prime = 2 # last prime
    count = 1 # number of primes
    num = 3 # next number to check
    while count < n:
        if len(factors(num)) == 1:
            prime = num
            count += 1
        num += 2 # only check odd numbers
    return prime
 
start = time.time()
prime = nth_prime(10001)
elapsed = (time.time() - start)
 
print("found %s in %s seconds" % (prime,elapsed))

This does produce the correct result, but (as expected) it is slow, taking close to a minute and a half.

found 104743 in 88.1212220192 seconds

We can speed this up immediately by realizing that we don’t need to factor each number completely; we only need to determine whether it is prime. We will therefore replace the factor routine with an “is_prime” routine that returns False as soon as it finds a non-trivial factor. It also only needs to test divisors up to the square root of n, since any composite n has a factor no larger than that.

import time
 
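# test whether a given positive integer n is prime
def is_prime(n):
    # handle the small cases: 0 and 1 are not prime, 2 is
    if n < 2: return False
    if n == 2: return True
    # look for factors of 2 first
    if n % 2 == 0: return False
    # now look for odd factors up to the square root of n
    p = 3
    while p < n**0.5+1:
        if n % p == 0: return False
        p += 2
    return True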
 
def nth_prime(n):
    prime = 2 # last prime
    count = 1 # number of primes
    num = 3 # next number to check
    while count < n:
        if is_prime(num):
            prime = num
            count += 1
        num += 2 # only check odd numbers
    return prime
 
start = time.time()
prime = nth_prime(10001)
elapsed = (time.time() - start)
 
print("found %s in %s seconds" % (prime,elapsed))

This runs much more quickly, roughly 90 times faster than the first version.

found 104743 in 0.990342855453 seconds
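
The loop bound n**0.5+1 in is_prime relies on floating-point arithmetic. As a minimal sketch of the same test using integers only, the loop can instead stop once p*p exceeds n. The name is_prime_int is a hypothetical helper for illustration, not part of the solution above.

```python
def is_prime_int(n):
    """Trial division using only integer arithmetic (hypothetical helper)."""
    if n < 2:
        return False          # 0 and 1 are not prime
    if n < 4:
        return True           # 2 and 3 are prime
    if n % 2 == 0:
        return False          # even numbers above 2 are composite
    p = 3
    # If n = a * b with a <= b, then a * a <= n, so stopping once
    # p * p exceeds n cannot miss a factor.
    while p * p <= n:
        if n % p == 0:
            return False
        p += 2
    return True

small_primes = [k for k in range(30) if is_prime_int(k)]
print(small_primes)  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

The sixth entry of that list is 13, which matches the example in the problem statement.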

That performs better, but it still isn’t fast. Instead of testing increasing numbers one at a time, we could use a prime sieve (the Sieve of Eratosthenes). The idea is to make a list longer than the expected value of the 10,001st prime, where each entry records whether its index is known to be prime, and fill it in as follows: mark every entry True to start, leave 2 as True but set all larger multiples of 2 to False, then leave 3 as True but set all larger multiples of 3 to False, and so on, each time moving to the next index still marked True. In the end, we simply look for the 10,001st entry with value True; its index is the prime in question.
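
To see the marking procedure on a small scale, here is a minimal sketch that sieves the numbers below 30. The function name sieve_flags is hypothetical and chosen only for this illustration.

```python
def sieve_flags(limit):
    """Return a list where flags[k] is True exactly when k is prime (k < limit)."""
    flags = [True] * limit
    flags[0:2] = [False] * min(limit, 2)   # 0 and 1 are not prime
    for k in range(2, limit):
        if flags[k]:
            # k survived every earlier pass, so it is prime;
            # cross out its multiples 2k, 3k, ...
            for multiple in range(2 * k, limit, k):
                flags[multiple] = False
    return flags

flags = sieve_flags(30)
primes_below_30 = [k for k, is_p in enumerate(flags) if is_p]
print(primes_below_30)      # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
print(primes_below_30[5])   # 13, the 6th prime from the problem statement
```

The full version that follows applies the same idea to a much longer list and stops as soon as it has counted n primes.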

import time
 
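def fast_nth_prime(n, limit=125000):
    if limit % 2 != 0: limit += 1
    # primes[k] stays True until k is found to be composite
    primes = [True] * limit
    primes[0],primes[1] = [None] * 2 # 0 and 1 are neither prime nor composite
    count = 0 # how many primes have we found?
    for ind,val in enumerate(primes):
        if val is True:
            # sieve out non-primes by multiples of known primes;
            # the slice 2*ind, 3*ind, ... has ((limit - 1)//ind) - 1 entries
            primes[ind*2::ind] = [False] * (((limit - 1)//ind) - 1)
            count += 1
        if count == n: return ind
    return False # limit was too small to reach the nth prime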
 
start = time.time()
prime = fast_nth_prime(10001)
elapsed = (time.time() - start)
 
print("found %s in %s seconds." % (prime,elapsed))

This does indeed run much more quickly: about 24 times faster than the is_prime version, and over 2,000 times faster than our original approach.

found 104743 in 0.0413010120392 seconds.
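
The default limit of 125000 has to be larger than the prime we are looking for, which we don't know in advance. One way to pick it without guessing is the known bound that, for n ≥ 6, the nth prime is less than n(ln n + ln ln n). The sketch below computes that bound; nth_prime_upper_bound is a hypothetical helper name, and the fallback of 13 for n < 6 is an assumption made only so that the function covers small inputs.

```python
import math

def nth_prime_upper_bound(n):
    """Upper bound on the n-th prime (hypothetical helper).

    For n >= 6 the n-th prime is below n * (ln n + ln ln n).
    For smaller n we return 13, the 6th prime, which exceeds
    each of the first five primes.
    """
    if n < 6:
        return 13
    return int(math.ceil(n * (math.log(n) + math.log(math.log(n)))))

limit = nth_prime_upper_bound(10001)
print(limit)
```

For n = 10001 this gives a value a little above 114,000, so the hand-picked 125000 is comfortably large enough, and the computed value could be passed as the limit argument instead.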

Cython Solution

Recoding our fastest Python version (the sieve) in Cython, we get the following code.

%cython
 
import time
from libc.stdlib cimport malloc, free
 
# a function to find the nth prime number using a sieve
cdef fast_nth_prime(unsigned long n,unsigned long limit = 125000):
    cdef bint *primes = <bint *>malloc(limit * sizeof(bint))
    if not primes:
        raise MemoryError()
    cdef unsigned long count = 0
    cdef unsigned long j, index = 0
    # mark every entry as a candidate prime...
    while index < limit:
        primes[index] = 1
        index += 1
    # ...then rule out 0 and 1
    primes[0] = 0
    primes[1] = 0
    index = 2
    while index < limit:
        if primes[index] == 1:
            j = index ** 2
            while j < limit:
                primes[j] = 0
                j += index
            count += 1
        if count == n:
            free(primes)
            return index
        index += 1
    free(primes)
    return False
 
start = time.time()
prime = fast_nth_prime(10001)
elapsed = (time.time() - start)
 
print("found %s in %s seconds" % (prime,elapsed))

When executed, we get the following result.

found 104743 in 0.00318908691406 seconds

Thus, the Cython code is roughly 13 times faster than our fastest Python version.
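
Each timing in this post comes from a single run measured with time.time(), so ratios like this one are approximate. A more repeatable measurement could use the standard library's timeit module, as in the sketch below. The function sieve_nth_prime is a hypothetical plain-Python stand-in used only as something to time; it is not the exact code from above.

```python
import timeit

def sieve_nth_prime(n, limit=125000):
    """Plain-Python sieve, used here only as something to time."""
    flags = [True] * limit
    flags[0] = flags[1] = False
    count = 0
    for k in range(2, limit):
        if flags[k]:
            count += 1
            if count == n:
                return k
            # smaller multiples of k were already crossed out
            # by smaller primes, so start at k * k
            for multiple in range(k * k, limit, k):
                flags[multiple] = False
    return None  # limit was too small

# Run the call 5 times per measurement, repeat the measurement 3 times,
# and keep the fastest, which is least disturbed by other activity.
timings = timeit.repeat(lambda: sieve_nth_prime(10001), number=5, repeat=3)
best = min(timings) / 5
print("best of 3: %.5f seconds per call" % best)
```

The absolute numbers will differ from machine to machine; what matters is comparing versions measured the same way.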

