A Note on the Johnson-Lindenstrauss Lemma

Introduction

A recent thread on Theoretical CS StackExchange comparing the Johnson-Lindenstrauss Lemma with the Singular Value Decomposition piqued my interest enough that I decided to spend some time last night reading the standard JL papers. Until this week, I only had a vague understanding of what the JL Lemma implied. I previously mistook the JL Lemma for a purely theoretical result that established the existence of distance-preserving projections from high-dimensional spaces into low-dimensional spaces.

This vague understanding of the JL Lemma turns out to be almost correct, but it also led me to overlook the most interesting part of the literature. The papers on the JL Lemma do not simply establish that such projections exist. They also provide (1) an explicit bound on the dimensionality a projection needs in order to approximately preserve distances and (2) an explicit construction of a random matrix, \(A\), that produces the desired projection.

Once I knew that the JL Lemma came with a constructive proof, I decided to write some Julia code that constructs examples of this family of random projections. The rest of this post walks through that code as a way of explaining how the JL Lemma works in practice.

Formal Statement of the JL Lemma

The JL Lemma, as stated in “An elementary proof of the Johnson-Lindenstrauss Lemma” by Dasgupta and Gupta, is the following result about dimensionality reduction:

For any \(0 < \epsilon < 1\) and any integer \(n\), let \(k\) be a positive integer such that \(k \geq 4(\epsilon^2/2 - \epsilon^3/3)^{-1}\log(n)\). Then for any set \(V\) of \(n\) points in \(\mathbb{R}^d\), there is a map \(f : \mathbb{R}^d \to \mathbb{R}^k\) such that for all \(u, v \in V\), $$ (1 - \epsilon) ||u - v||^2 \leq ||f(u) - f(v)||^2 \leq (1 + \epsilon) ||u - v||^2. $$ Further this map can be found in randomized polynomial time.
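Where does the constant in the bound on \(k\) come from, and why does the statement say "randomized"? Here is a sketch of the standard Chernoff-style argument, specialized (as an assumption) to the Gaussian matrix used later in this post. Fix a pair \(u, v \in V\), write \(x = u - v\), let \(A\) have IID \(N(0, 1/k)\) entries, and take \(\log\) to be the natural logarithm. Each coordinate \((Ax)_i\) is an independent \(N(0, ||x||^2/k)\) draw, so the scaled squared length is chi-squared with \(k\) degrees of freedom, and its tails obey the usual Chernoff bounds:

$$ \frac{k\,||Ax||^2}{||x||^2} \sim \chi^2_k, \qquad \Pr\left[\chi^2_k \geq \beta k\right] \leq \exp\left(\tfrac{k}{2}(1 - \beta + \log\beta)\right) \text{ for } \beta > 1, \qquad \Pr\left[\chi^2_k \leq \beta k\right] \leq \exp\left(\tfrac{k}{2}(1 - \beta + \log\beta)\right) \text{ for } \beta < 1. $$

Taking \(\beta = 1 + \epsilon\) and using \(\log(1 + \epsilon) \leq \epsilon - \epsilon^2/2 + \epsilon^3/3\), together with \(k \geq 4(\epsilon^2/2 - \epsilon^3/3)^{-1}\log(n)\):

$$ \Pr\left[||Ax||^2 \geq (1 + \epsilon)||x||^2\right] \leq \exp\left(-\tfrac{k}{2}\left(\epsilon^2/2 - \epsilon^3/3\right)\right) \leq \exp(-2\log(n)) = \frac{1}{n^2}. $$

Taking \(\beta = 1 - \epsilon\) and using \(\log(1 - \epsilon) \leq -\epsilon - \epsilon^2/2\), together with \(k \geq 8\log(n)/\epsilon^2\) (which follows because \(\epsilon^2/2 - \epsilon^3/3 \leq \epsilon^2/2\)):

$$ \Pr\left[||Ax||^2 \leq (1 - \epsilon)||x||^2\right] \leq \exp\left(-k\epsilon^2/4\right) \leq \frac{1}{n^2}. $$

A union bound over all pairs then gives

$$ \Pr\left[\text{some pair is distorted}\right] \leq \binom{n}{2} \cdot \frac{2}{n^2} = 1 - \frac{1}{n}. $$

So a single draw of \(A\) succeeds with probability at least \(1/n\); checking every pair and redrawing on failure yields the randomized polynomial-time algorithm the statement mentions. The \(1/n\) figure is only a lower bound on the success probability.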

To fully appreciate this result, we can unpack the abstract statement of the lemma into two components.

The JL Lemma in Two Parts

Part 1: Given the number of data points, \(n\), that we wish to project and a relative error, \(\epsilon\), that we are willing to tolerate, we can compute a target dimensionality, \(k\), that is large enough for a projection into \(\mathbb{R}^k\) to preserve all pairwise squared distances up to a factor of \(1 \pm \epsilon\).

In particular, \(k = \left \lceil{4(\epsilon^2/2 - \epsilon^3/3)^{-1}\log(n)} \right \rceil\), where \(\log\) denotes the natural logarithm.

Note that this dimensionality depends only on the number of points, \(n\), and the tolerance, \(\epsilon\). It does not depend on the dimensionality, \(d\), of the original space.
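A quick tabulation makes this concrete. The sketch below (assuming Julia 1.x) uses the same formula as the `mindim` function defined later in this post and fixes \(\epsilon = 0.1\); notice that \(d\) never appears in the calculation.

```julia
# JL target dimension: k = ceil(4 log(n) / (ε^2/2 - ε^3/3)), with log the natural log.
# Only the number of points n and the tolerance ε enter; the original dimension d does not.
mindim(n::Integer, ε::Real) = ceil(Int, 4 * log(n) / (ε^2 / 2 - ε^3 / 3))

# k grows only logarithmically in the number of points.
for n in (3, 10, 1_000, 1_000_000)
    println("n = ", n, "  =>  k = ", mindim(n, 0.1))
end
```

This prints \(k = 942, 1974, 5921, 11842\): going from \(3\) points to a million raises \(k\) by a factor of roughly \(12.6\).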

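Part 2: Given an input matrix, \(X\), whose \(n\) columns are points in \(d\)-dimensional space, we can explicitly construct a map, \(f\), such that the squared distance between any pair of columns of \(X\) will not be distorted by more than a factor of \(1 \pm \epsilon\). The construction is randomized: a single draw succeeds with positive probability, and a draw that fails can simply be repeated.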

Surprisingly, this map \(f\) can be a simple \(k \times d\) matrix, \(A\), whose \(k \cdot d\) entries are IID draws from a Gaussian with mean \(0\) and variance \(\frac{1}{k}\), so that \(f(x) = Ax\).
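To make the construction concrete, here is a minimal sketch using only Julia's standard library (assuming Julia 1.x). The helper name `gaussian_jl_matrix` is hypothetical; the post's own `projection` function draws from the same distribution with `rand(Normal(0, 1 / sqrt(k)), k, d)` from Distributions.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical helper: a k × d matrix with IID N(0, 1/k) entries.
# randn draws standard normals; dividing by sqrt(k) scales their standard
# deviation to 1/sqrt(k), so each entry has variance 1/k.
gaussian_jl_matrix(k::Integer, d::Integer; rng::AbstractRNG = Random.default_rng()) =
    randn(rng, k, d) ./ sqrt(k)

rng = MersenneTwister(42)               # fixed seed for a reproducible draw
k, d = 942, 50
A = gaussian_jl_matrix(k, d; rng = rng)

println(size(A))                        # (942, 50)
println(k * var(vec(A)))                # close to 1, because Var(A_ij) = 1/k

u = ones(d)
println(norm(A * u)^2 / norm(u)^2)      # close to 1, because E||Au||^2 = ||u||^2
```

The last line shows why the variance is \(1/k\) rather than \(1\): each of the \(k\) coordinates of \(Au\) has variance \(||u||^2/k\), so their squares sum to \(||u||^2\) in expectation.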

Coding Up The Projections

We can translate the first part of the JL Lemma into a single line of code that computes the dimensionality, \(k\), of our low-dimensional space given the number of data points, \(n\), and the error, \(\epsilon\), that we are willing to tolerate:

# ceil(Int, x) rounds up to an integer (it replaces the old iceil, which no longer exists)
mindim(n::Integer, ε::Real) = ceil(Int, 4 * log(n) / (ε^2 / 2 - ε^3 / 3))

Having defined this function, we can try it out on a simple problem:

mindim(3, 0.1)
# => 942

This result surprised me: to represent \(3\) points with no more than \(10\)% error, the lemma asks for nearly \(1,000\) dimensions. This reflects an important fact about the JL Lemma: its bound can be extremely conservative for low-dimensional inputs. For a data set of \(3\) points in \(100\)-dimensional space, we could simply keep all \(100\) dimensions (the identity map) and preserve distances exactly.

But this observation neglects one of the essential aspects of the JL Lemma: the dimensions required by the lemma will be sufficient whether our data set contains points in \(100\)-dimensional space or points in \(10^{100}\)-dimensional space. No matter what dimensionality the raw data lies in, the JL Lemma says that \(942\) dimensions suffice to preserve the distances between \(3\) points to within \(10\)%.

I found this statement unintuitive at first. To check it experimentally, let’s construct a random projection matrix, \(A\), and confirm that the JL Lemma really works:

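using Distributions

# Draw a k × d matrix A with IID N(0, 1/k) entries and project the columns of X.
# Normal(μ, σ) is parameterized by the standard deviation, so σ = 1/sqrt(k)
# gives each entry variance 1/k.
function projection(
    X::AbstractMatrix,
    ε::Real,
    k::Integer = mindim(size(X, 2), ε)   # default: the JL bound for n = size(X, 2) points
)
    d, n = size(X)
    A = rand(Normal(0, 1 / sqrt(k)), k, d)
    return A, k, A * X
end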

This projection function constructs a matrix, \(A\), following the Gaussian construction described above. It also returns the target dimensionality, \(k\) (the number of rows of \(A\)), and the result of projecting the input, \(X\), into the new space defined by \(A\). To get a feel for how this works, we can try it out on a very simple data set:

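using LinearAlgebra                  # provides I; this stands in for eye(3, 3), which no longer exists in Julia 1.x
X = Matrix{Float64}(I, 3, 3)         # three points (the columns) in 3-dimensional space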
 
ε = 0.1
 
A, k, AX = projection(X, ε)
# =>
# (
# 942x3 Array{Float64,2}:
#  -0.035269    -0.0299966   -0.0292959 
#  -0.00501367   0.0316806    0.0460191 
#   0.0633815   -0.0136478   -0.0198676 
#   0.0262627    0.00187459  -0.0122604 
#   0.0417169   -0.0230222   -0.00842476
#   0.0236389    0.0585979   -0.0642437 
#   0.00685299  -0.0513301    0.0501431 
#   0.027723    -0.0151694    0.00274466
#   0.0338992    0.0216184   -0.0494157 
#   0.0612926    0.0276185    0.0271352 
#   ⋮                                   
#  -0.00167347  -0.018576     0.0290964 
#   0.0158393    0.0124403   -0.0208216 
#  -0.00833401   0.0323784    0.0245698 
#   0.019355     0.0057538    0.0150561 
#   0.00352774   0.031572    -0.0262811 
#  -0.0523636   -0.0388993   -0.00794319
#  -0.0363795    0.0633939   -0.0292289 
#   0.0106868    0.0341909    0.0116523 
#   0.0072586   -0.0337501    0.0405171 ,
# 
# 942,
# 942x3 Array{Float64,2}:
#  -0.035269    -0.0299966   -0.0292959 
#  -0.00501367   0.0316806    0.0460191 
#   0.0633815   -0.0136478   -0.0198676 
#   0.0262627    0.00187459  -0.0122604 
#   0.0417169   -0.0230222   -0.00842476
#   0.0236389    0.0585979   -0.0642437 
#   0.00685299  -0.0513301    0.0501431 
#   0.027723    -0.0151694    0.00274466
#   0.0338992    0.0216184   -0.0494157 
#   0.0612926    0.0276185    0.0271352 
#   ⋮                                   
#  -0.00167347  -0.018576     0.0290964 
#   0.0158393    0.0124403   -0.0208216 
#  -0.00833401   0.0323784    0.0245698 
#   0.019355     0.0057538    0.0150561 
#   0.00352774   0.031572    -0.0262811 
#  -0.0523636   -0.0388993   -0.00794319
#  -0.0363795    0.0633939   -0.0292289 
#   0.0106868    0.0341909    0.0116523 
#   0.0072586   -0.0337501    0.0405171 )

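Because \(X\) is the identity matrix, \(AX\) is simply \(A\), which is why the two printed matrices are identical. According to the JL Lemma, the new matrix, \(AX\), should approximately preserve the distances between columns of \(X\). We can write a quick function that checks this claim for every pair of columns: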

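using LinearAlgebra   # norm
using Printf          # @printf

# Check every pair of columns of X and report whether the squared distance
# after projecting by A stays within [1 - ε, 1 + ε] times the original.
function ispreserved(X::AbstractMatrix, A::AbstractMatrix, ε::Real)
    d, n = size(X)

    for i in 1:n
        for j in (i + 1):n
            u, v = X[:, i], X[:, j]
            d_old = norm(u - v)^2
            d_new = norm(A * u - A * v)^2
            @printf("Considering the pair X[:, %d], X[:, %d]...\n", i, j)
            @printf("\tOld distance: %f\n", d_old)
            @printf("\tNew distance: %f\n", d_new)
            @printf(
                "\tWithin bounds %f <= %f <= %f\n",
                (1 - ε) * d_old,
                d_new,
                (1 + ε) * d_old
            )
            # Fail as soon as the projected distance leaves the bounds
            if !((1 - ε) * d_old <= d_new <= (1 + ε) * d_old)
                return false
            end
        end
    end

    return true
end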

And then we can test out the results:

ispreserved(X, A, ε)
# =>
# Considering the pair X[:, 1], X[:, 2]...
#     Old distance: 2.000000
#     New distance: 2.104506
#     Within bounds 1.800000 <= 2.104506 <= 2.200000
# Considering the pair X[:, 1], X[:, 3]...
#     Old distance: 2.000000
#     New distance: 2.006130
#     Within bounds 1.800000 <= 2.006130 <= 2.200000
# Considering the pair X[:, 2], X[:, 3]...
#     Old distance: 2.000000
#     New distance: 1.955495
#     Within bounds 1.800000 <= 1.955495 <= 2.200000

As claimed, the squared distances are preserved to within a factor of \(1 \pm \epsilon\). But, as we noted earlier, the JL Lemma has a somewhat perverse consequence for our \(3×3\) matrix: we’ve expanded our input into a \(942×3\) matrix rather than reduced its dimensionality.
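Printing every pair gets noisy as the data set grows. A compact variant, sketched below, returns the ratio of new to old squared distance for each pair, which makes the worst-case distortion easy to read off with `extrema`. The helpers `distortions` and `within_bounds` are hypothetical names, not part of the original code.

```julia
using LinearAlgebra

# Hypothetical helper: for every pair of columns (i, j) of X, return the ratio
# ||A(x_i - x_j)||^2 / ||x_i - x_j||^2. The JL guarantee holds for this draw of A
# exactly when every ratio lies in [1 - ε, 1 + ε].
function distortions(X::AbstractMatrix, A::AbstractMatrix)
    n = size(X, 2)
    ratios = Float64[]
    for i in 1:n, j in (i + 1):n
        delta = X[:, i] - X[:, j]
        push!(ratios, norm(A * delta)^2 / norm(delta)^2)
    end
    return ratios
end

# True when every ratio sits inside the JL bounds.
within_bounds(ratios, ε) = all(r -> 1 - ε <= r <= 1 + ε, ratios)
```

For example, `extrema(distortions(X, A))` gives the smallest and largest ratios across all pairs, and `within_bounds(distortions(X, A), ε)` returns the same verdict as `ispreserved` without the printed trace.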

To get meaningful dimensionality reduction, we need to project a data set from a space that has more than \(942\) dimensions. So let’s try out a \(50,000\)-dimensional example:

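using LinearAlgebra
X = Matrix{Float64}(I, 50_000, 3)    # three points in 50,000-dimensional space

A, k, AX = projection(X, ε)          # A is 942 × 50,000: about 47 million entries (~377 MB)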
 
ispreserved(X, A, ε)
# =>
# Considering the pair X[:, 1], X[:, 2]...
#     Old distance: 2.000000
#     New distance: 2.021298
#     Within bounds 1.800000 <= 2.021298 <= 2.200000
# Considering the pair X[:, 1], X[:, 3]...
#     Old distance: 2.000000
#     New distance: 1.955502
#     Within bounds 1.800000 <= 1.955502 <= 2.200000
# Considering the pair X[:, 2], X[:, 3]...
#     Old distance: 2.000000
#     New distance: 1.988945
#     Within bounds 1.800000 <= 1.988945 <= 2.200000

In this case, the JL Lemma again works as claimed: the squared pairwise distances between columns of \(X\) stay within the \(1 \pm \epsilon\) bounds. And we’ve done this while reducing the dimensionality of our data from \(50,000\) to \(942\). Moreover, the same \(k = 942\) would still suffice if the input space had \(10\) million dimensions.
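A single draw cannot tell us how reliable the construction is. The sketch below repeats the draw many times on a small example and reports the fraction of draws that keep every pair within bounds. The `success_rate` helper is hypothetical, and it uses the standard library's `randn` in place of Distributions. For a Gaussian \(A\), the distribution of each ratio depends on \(k\) but not on \(d\), so a small \(d\) keeps the experiment cheap.

```julia
using LinearAlgebra, Random

mindim(n::Integer, ε::Real) = ceil(Int, 4 * log(n) / (ε^2 / 2 - ε^3 / 3))

# Hypothetical experiment: draw `trials` independent Gaussian JL matrices for the
# same data X and return the fraction of draws that keep every squared pairwise
# distance within [1 - ε, 1 + ε] times its original value.
function success_rate(X::AbstractMatrix, ε::Real, trials::Integer;
                      rng::AbstractRNG = Random.default_rng())
    d, n = size(X)
    k = mindim(n, ε)
    successes = 0
    for _ in 1:trials
        A = randn(rng, k, d) ./ sqrt(k)     # IID N(0, 1/k) entries
        ok = true
        for i in 1:n, j in (i + 1):n
            delta = X[:, i] - X[:, j]
            ratio = norm(A * delta)^2 / norm(delta)^2
            ok &= (1 - ε <= ratio <= 1 + ε)
        end
        successes += ok                     # a Bool adds as 0 or 1
    end
    return successes / trials
end

X = Matrix{Float64}(I, 20, 3)               # 3 points in 20 dimensions
println(success_rate(X, 0.1, 200; rng = MersenneTwister(1)))
```

The lemma's proof only promises a success probability of at least \(1/n\) per draw, so any observed rate above that is consistent with it; a fixed seed makes the estimate reproducible.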

Conclusion

Contrary to my naive conception of the JL Lemma, the literature on the lemma does not only tell us that, abstractly, distances can be preserved by dimensionality reduction. It also tells us how to perform this reduction, and the mechanism is both simple and general.