def f1(n):
= []
out for a in range(1, n+1):
for b in range(1, n+1):
for c in range(1, n+1):
for d in range(1, n+1):
if a**3 + b**3 == c**3 + d**3:
out.append((a, b, c, d))return out
In this blogpost, we will take a question from Cracking the Coding Interview. I discussed this question with Masters students at IITGN. We came up with some great answers. I’ll show how we increasingly went towards better solutions starting from naive ones.
Problem statement
Find all integer solutions to the problem \(a^3 + b^3 = c^3 + d^3\)
where \(1<=a<=n, 1<=b<=n, 1<=c<=n, 1<=d<=n\)
First attempt : Naive bruteforce \(O(n^4)\)
Let’s write a very simple first attempt. We will write four nested loops. This will be \(O(n^4)\) solution.
= %timeit -o f1(50) f1_time
6.65 s ± 203 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
f1_time.average
6.646897936570895
Second attempt : Reduce computations in brute force method
Let’s now try to optimise f1. We will still use a solution of \(O(n^4)\) solution. However, we add one small optimisation fo f1. We break from the innermost loop once we find a match. This will hopefull save us some computations.
def f2(n):
= []
out for a in range(1, n+1):
for b in range(1, n+1):
for c in range(1, n+1):
for d in range(1, n+1):
if a**3 + b**3 == c**3 + d**3:
out.append((a, b, c, d))break
return out
= %timeit -o f2(50) f2_time
6.29 s ± 26.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Ok. We’re little better than f1. Every reduced computation is time saved!
Third attempt : Reduce repeated computations by saving cubes of numbers
One of the student came up with an excellent observation. Why should we keep on computing cubes of numbers? This is a repeated operation. Let’s instead store them in a dictionary.
def f3(n):
= {}
cubes for x in range(1, n+1):
= x**3
cubes[x] = []
out for a in range(1, n+1):
for b in range(1, n+1):
for c in range(1, n+1):
for d in range(1, n+1):
if cubes[a] + cubes[b] == cubes[c] + cubes[d]:
out.append((a, b, c, d))break
return out
= %timeit -o f3(50) f3_time
1.05 s ± 4.11 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Ok. We now mean business! This is about 6 times quicker than our previous version.
Fourth attempt : Reduce one loop \(O(n^3)\)
In this solution, we will reduce one loop. We can solve for \(d^3 = a^3 + b^3 - c^3\) and find all the integer solutions. Now, there’s another clever optimisation that I have added. We can precompute the cubes and the cuberoots corresponding to numbers from 1 to N and perfect cubes from 1 to \(N^3\) respectively.
def f4(n):
= {}
cubes = {}
cuberoots for x in range(1, n+1):
= x**3
x3 = x3
cubes[x] = x
cuberoots[x3] = []
out for a in range(1, n+1):
for b in range(1, n+1):
for c in range(1, n+1):
= (cubes[a] + cubes[b] - cubes[c])
d3 if d3 in cuberoots:
out.append((a, b, c, cuberoots[d3]))return out
= %timeit -o f4(50) f4_time
21.7 ms ± 1.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
This is seriously fast now!
Fifth attempt : Reduce another loop \(O(n^2)\)
In this solution, we will reduce one more loop. We can compute \(a^3 + b^3\) for all a, b
. And then find c
and d
where \(c^3 + d^3\) is the same as \(a^3 + b^3\). This has a few Python tricks inside! One of the special cases to handle is of the type \(1^3 + 2^3 = 2^3 + 1^3\)
def f5(n):
= []
out = {}
cubes for x in range(1, n+1):
= x**3
cubes[x]
= {}
sum_a3_b3 for a in range(1, n+1):
for b in range(1, n+1):
= cubes[a]+cubes[b]
temp if temp in sum_a3_b3:
sum_a3_b3[temp].append((a, b))else:
= [(a, b)]
sum_a3_b3[temp]
for c in range(1, n+1):
for d in range(1, n+1):
= cubes[c] + cubes[d]
sum_c3_d3 if sum_c3_d3 in sum_a3_b3:
for (a, b) in sum_a3_b3[sum_c3_d3]:
out.append((a, b, c, d))
return out
= %timeit -o f5(50) f5_time
1.97 ms ± 235 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Plain Wow! Going from 6 seconds to about 2 ms! Let’s plot the timings on a log scale to learn more.
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
= pd.Series({'Naive (O(N^4))':f1_time.average,
s 'Naive (O(N^4)) with break':f2_time.average,
'Naive (O(N^4)) with break and precomputing cubes':f3_time.average,
'(O(N^3))':f4_time.average,
'(O(N^2))':f5_time.average})
='bar', logy=True)
s.plot(kind"Time"); plt.ylabel(
Hope this was fun!