KNN Optimizations
Naive KNN has several drawbacks that the following optimizations address.
- Standardization
- Distance-Weighted KNN
- Mahalanobis Distance
⭐️Say one feature is ‘Annual Income’ (0-1M), and another feature is ‘Years of Experience’ (0-40).
👉The Euclidean distance will be almost entirely dominated by income 💵.
💡So, we standardize each feature so that it has mean \(\mu = 0\) and standard deviation \(\sigma = 1\).
\[z=\frac{x-\mu}{\sigma}\]
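A minimal sketch of this preprocessing step, assuming scikit-learn and NumPy are available; the income/experience values are purely illustrative:

```python
# Sketch: z-score standardization before KNN (illustrative data, not a real dataset).
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Two features on wildly different scales: annual income and years of experience.
X = np.array([[950_000, 5], [40_000, 4], [38_000, 20], [900_000, 22]], dtype=float)
y = np.array([1, 0, 0, 1])

# StandardScaler rescales each feature to mean 0 and standard deviation 1,
# so income no longer dominates the Euclidean distance.
model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
model.fit(X, y)
print(model.predict([[45_000, 6]]))
```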
⭐️Vanilla KNN treats the 1st nearest neighbor and the k-th nearest neighbor as equally important.
💡A neighbor that is 0.1 units away should have more influence than a neighbor that is 10 units away.
👉We assign a weight 🏋️♀️ to each neighbor; the most common strategy is the inverse of the squared distance.
\[w_i = \frac{1}{d(x_q, x_i)^2 + \epsilon}\]
Improvements (a code sketch follows the list below):
- Noise/Outlier: Reduces the impact of ‘noise’ or ‘outlier’ (distant neighbors).
- Imbalanced Data: Closer points dominate, mitigating impact of imbalanced data.
- e.g., if a query point is surrounded by 2 very close ‘Class A’ points and 3 distant ‘Class B’ points, weighted 🏋️♀️ KNN will correctly pick ‘Class A’.
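That example can be reproduced with a short sketch; it assumes scikit-learn, whose `KNeighborsClassifier` accepts a callable for `weights`, and the points below are made up:

```python
# Sketch: distance-weighted KNN with inverse squared-distance weights.
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def inverse_squared(distances, eps=1e-8):
    # w_i = 1 / (d^2 + eps); eps guards against division by zero on exact matches.
    return 1.0 / (distances ** 2 + eps)

X = np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 0.0], [11.0, 0.0], [12.0, 0.0]])
y = np.array([0, 0, 1, 1, 1])  # 2 close 'Class A' (0) points, 3 distant 'Class B' (1) points

# With k=5, unweighted voting would pick Class B (3 votes vs 2);
# inverse squared-distance weighting lets the two nearby points win.
clf = KNeighborsClassifier(n_neighbors=5, weights=inverse_squared)
clf.fit(X, y)
print(clf.predict([[0.05, 0.0]]))  # -> [0], i.e. Class A
```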
⭐️Euclidean distance assumes that all features are independent and each provides unique information.
💡‘Height’ and ‘Weight’ are highly correlated.
👉If we use Euclidean distance, we are effectively ‘double-counting’ the size of the person.
🏇Mahalanobis distance measures distance in terms of standard deviations from the mean, accounting for the covariance between features.
\[d(x, y) = \sqrt{(x - y)^T \Sigma^{-1} (x - y)}\]
\(\Sigma\): Covariance matrix of the data (a code sketch follows the list below).
- If \(\Sigma\) is identity matrix, Mahalanobis distance reduces to Euclidean distance.
- If \(\Sigma\) is a diagonal matrix, Mahalanobis distance reduces to Normalized Euclidean distance.
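A minimal NumPy sketch of the formula on synthetic, correlated height/weight data:

```python
# Sketch: Mahalanobis vs. Euclidean distance on correlated features (synthetic data).
import numpy as np

rng = np.random.default_rng(0)
height = rng.normal(170, 10, size=500)
weight = 0.9 * height + rng.normal(0, 5, size=500)   # weight is correlated with height
X = np.column_stack([height, weight])

sigma = np.cov(X, rowvar=False)                      # 2x2 covariance matrix
sigma_inv = np.linalg.inv(sigma)

def mahalanobis(x, y, vi=sigma_inv):
    d = x - y
    return float(np.sqrt(d @ vi @ d))                # sqrt((x-y)^T Sigma^-1 (x-y))

x_q = np.array([180.0, 165.0])
x_i = np.array([170.0, 155.0])
print(mahalanobis(x_q, x_i))                         # accounts for the correlation
print(np.linalg.norm(x_q - x_i))                     # plain Euclidean, for comparison
```

To use it inside a KNN model rather than hand-rolling the distance, scikit-learn's `KNeighborsClassifier` can be configured with `metric='mahalanobis'` and `metric_params={'VI': sigma_inv}`.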
🦀Naive KNN shifts all computation 💻 to inference time ⏰, which makes queries very slow.
- To find the neighbors for one query, we must scan every row of the \(n \times d\) data matrix.
- If \(n = 10^9\), a single query can take seconds, but we need milliseconds.
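For a sense of scale, here is a minimal NumPy sketch of what a single brute-force query costs (the \(n\) and \(d\) below are illustrative):

```python
# Sketch: a single brute-force KNN query scans the entire n x d matrix.
import numpy as np

n, d, k = 100_000, 64, 5
X = np.random.default_rng(0).standard_normal((n, d))   # the stored training data
x_q = np.random.default_rng(1).standard_normal(d)      # one query point

dists = np.linalg.norm(X - x_q, axis=1)                 # O(n * d) work per query
nearest = np.argpartition(dists, k)[:k]                 # indices of the k closest points
print(nearest, dists[nearest])
```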
- K-D Trees (\(d < 20\)): Recursively partition space into axis-aligned hyper-rectangles; \(O(\log n)\) average search (see the sketch after this list).
- Ball Trees: Better suited to higher-dimensional data; support arbitrary metrics such as Haversine distance for geospatial 🌎 data.
- Approximate Nearest Neighbors (ANN)
  - Locality Sensitive Hashing (LSH): Uses ‘bucketizing’ 🗑️ hashes; points that are close have a high probability of landing in the same bucket.
  - Hierarchical Navigable Small World (HNSW): A layered graph of vectors; search is a ‘greedy walk’ across levels.
  - Product Quantization: Reduces the memory 🧠 footprint 👣 of high-dimensional vectors.
  - ScaNN (Google)
  - FAISS (Meta)
- Dimensionality Reduction (mitigates the ‘Curse of Dimensionality’)
  - PCA
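A minimal sketch of the tree-based indexes above, assuming scikit-learn and synthetic data:

```python
# Sketch: build a spatial index once, then answer many queries without full scans.
import numpy as np
from sklearn.neighbors import BallTree, KDTree

rng = np.random.default_rng(0)
X = rng.standard_normal((100_000, 8))     # low-dimensional data suits a K-D tree
x_q = rng.standard_normal((1, 8))

kd = KDTree(X)                            # build once, reuse for many queries
dist, ind = kd.query(x_q, k=5)            # ~O(log n) per query on average
print(ind, dist)

ball = BallTree(X)                        # BallTree also accepts other metrics,
dist, ind = ball.query(x_q, k=5)          # e.g. 'haversine' for lat/lon in radians
print(ind, dist)
```

Libraries such as FAISS, ScaNN, and HNSW implementations trade some accuracy for much faster approximate search at billion-point scale.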