TrieKNN: Unleashing KNN's Power on Mixed Data Types
Discover TrieKNN, a novel approach to extend the K-Nearest Neighbors algorithm to datasets with both categorical and numerical features. Learn how it works, see it in action, and explore its potential.

In This Post
- Dissect the limitations of traditional KNN when faced with mixed data types.
- Introduce TrieKNN, a Trie-based approach that elegantly handles mixed data.
- Walk through the implementation and training of a TrieKNN model.
- Run it on sample data and discuss its potential impact.
The Allure and Limitation of KNN
In the realm of machine learning, the K-Nearest Neighbors (KNN) algorithm stands out for its intuitive nature and ease of implementation. Its principle is simple: classify a data point by the majority class among its ‘k’ nearest neighbors in the feature space. This non-parametric approach makes no assumptions about the underlying data distribution, which makes it versatile across applications.
However, KNN’s Achilles’ heel lies in its reliance on distance metrics, which are inherently designed for numerical data. Real-world datasets often contain a mix of numerical and categorical features, posing a significant challenge for KNN. How do you measure the distance between ‘red’ and ‘blue,’ or ‘large’ and ‘small’?
Prior Art
Several strategies have been proposed to adapt KNN for mixed data:
- One-Hot Encoding: Converts categorical features into numerical vectors, but can lead to high dimensionality.
- Distance Functions for Mixed Data: Develop and apply custom distance metrics that handle both numerical and categorical features, such as the Heterogeneous Euclidean-Overlap Metric (HEOM).
- Mean/Mode Imputation: Replace missing values with the mean (for numerical features) or the mode (for categorical features).
These methods often involve compromises, either distorting the data’s inherent structure or adding computational overhead.
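To make the custom-metric option concrete, here is a minimal sketch of HEOM as it is usually defined: a categorical attribute contributes 0 when the two values are equal and 1 otherwise, and a numerical attribute contributes its absolute difference divided by that attribute's range. The function name `heom`, its arguments, and the toy records are illustrative assumptions, and missing-value handling is left out.

```python
import math

def heom(a, b, is_categorical, ranges):
    """Heterogeneous Euclidean-Overlap Metric between two records.

    a, b           -- sequences of attribute values
    is_categorical -- one bool per attribute
    ranges         -- max - min per numerical attribute (ignored for categorical ones)
    """
    total = 0.0
    for x, y, cat, rng in zip(a, b, is_categorical, ranges):
        if cat:
            d = 0.0 if x == y else 1.0            # overlap distance
        else:
            d = abs(x - y) / rng if rng else 0.0  # range-normalized difference
        total += d * d
    return math.sqrt(total)

# Hypothetical records: (color, size, weight in kg)
is_cat = [True, True, False]
ranges = [None, None, 10.0]  # weight spans 10 kg in this toy dataset
d_same = heom(("red", "large", 4.0), ("red", "large", 4.0), is_cat, ranges)
d_diff = heom(("red", "large", 4.0), ("blue", "large", 9.0), is_cat, ranges)
```

In the second comparison the color mismatch contributes 1 and the weight difference contributes 0.5, so the distance is the square root of 1.25. Every categorical mismatch costs the same, which is the kind of compromise mentioned above.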
Enter TrieKNN: A Novel Approach
What if we could cleverly sidestep the distance calculation problem for categorical features, while still leveraging KNN’s power? TrieKNN offers just that: a way to run KNN on mixed data without ever defining a distance between categories.
TrieKNN combines the strengths of Trie data structures and KNN to handle mixed data types gracefully. Here’s the core idea:
- Trie-Based Categorical Encoding: A Trie is used to store the categorical features of the data. Each node in the Trie represents a category.
- Leaf-Node KNN Models: At the leaf nodes of the Trie, where specific combinations of categorical features are found, we fit individual KNN models using only the numerical features.
- Weighted Prediction: To classify a new data point, we traverse the Trie along its categorical features. Each leaf is weighted by the share of the training data it holds, and the KNN model at the leaf turns the numerical features into a class-probability score.
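The three steps above can be sketched compactly. The version below is a simplification under stated assumptions: it replaces the Trie with a flat dictionary keyed by the full tuple of categorical values (equivalent to looking only at the leaves), it skips the weighting, and `fit_groups` and `predict_proba` are hypothetical helper names.

```python
from collections import Counter, defaultdict
import numpy as np

def fit_groups(cats, nums, labels):
    """Group numerical rows and labels by their categorical combination."""
    groups = defaultdict(lambda: {"X": [], "y": []})
    for c, x, y in zip(cats, nums, labels):
        groups[tuple(c)]["X"].append(x)
        groups[tuple(c)]["y"].append(y)
    return {key: (np.asarray(g["X"], dtype=float), np.asarray(g["y"]))
            for key, g in groups.items()}

def predict_proba(groups, cat, num, k=3):
    """Vote among the k nearest rows that share the query's categories."""
    X, y = groups[tuple(cat)]  # raises KeyError for an unseen combination
    dist = np.sqrt(((X - np.asarray(num, dtype=float)) ** 2).sum(axis=1))
    nearest = y[np.argsort(dist)[:k]]
    votes = Counter(nearest.tolist())
    return {label: n / len(nearest) for label, n in votes.items()}

# Toy data: three rows share ("red", "large"), one row is ("blue", "small")
cats = [("red", "large"), ("red", "large"), ("red", "large"), ("blue", "small")]
nums = [(1.0, 1.0), (1.2, 0.9), (5.0, 5.0), (1.0, 1.0)]
labels = [0, 0, 1, 1]
groups = fit_groups(cats, nums, labels)
proba = predict_proba(groups, ("red", "large"), (1.1, 1.0), k=2)
```

The query only ever competes with rows from its own categorical group, so the distance is computed purely on numerical columns. The Trie adds what this flat dictionary lacks: counts and class statistics at every prefix of the categorical path.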
Why This Works
- No Direct Distance Calculation for Categorical Features: Rows are grouped by exact match on their categorical values, so no distance between categories is ever needed.
- Localized KNN Models: By fitting KNN models at the leaf nodes, we ensure that distance calculations are performed only on relevant numerical features.
- Scalability: A lookup costs one dictionary access per categorical feature, and rows that share leading categorical values share Trie nodes.
Building a TrieKNN Model
Let’s dive into the implementation. We’ll start with the TrieNode and Trie classes, then move on to the KNN model and the training/prediction process.
Trie Implementation
import numpy as np
from collections import Counter

class TrieNode:
    def __init__(self):
        self.children = {}  # Child nodes, keyed by category value
        self.is_end_of_word = False  # True if a full categorical path ends here
        self.count = 0  # Number of inserted rows that pass through this node
        self.class_counts = {}  # Rows per class label passing through this node
        self.class_weights = []  # (data index, weight) pairs set by apply_weight_to_indexes
        self.model = None  # Model at leaf nodes
        self.indexes = []  # Data indexes belonging to this leaf
        self.labels = []  # Class labels of the rows belonging to this leaf
        self.node_weight = None
class Trie:
    def __init__(self):
        self.root = TrieNode()  # Root node of the Trie
        self.data_index = 0  # Initialize data index

    def insert(self, word_val, model):
        current_node = self.root
        word, val = word_val
        current_node.count += 1
        # Adding class counts
        if val not in current_node.class_counts:
            current_node.class_counts[val] = 0
        current_node.class_counts[val] += 1
        for char in word:
            # If the character is not in children, add a new TrieNode
            if char not in current_node.children:
                current_node.children[char] = TrieNode()
            current_node = current_node.children[char]
            # Adding count of instances
            current_node.count += 1
            # Adding class counts
            if val not in current_node.class_counts:
                current_node.class_counts[val] = 0
            current_node.class_counts[val] += 1
        # Mark the end of the word and store the row at the leaf
        current_node.is_end_of_word = True
        current_node.indexes.append(self.data_index)  # Store the data index
        current_node.labels.append(val)
        current_node.model = model
        self.data_index += 1  # Increment data index

    def search(self, word):
        current_node = self.root
        for char in word:
            # If the character doesn't exist in the children, the word doesn't exist
            if char not in current_node.children:
                return False
            current_node = current_node.children[char]
        # Return True if it's the end of a word and the word exists
        return current_node.is_end_of_word

    def count_word(self, word):
        current_node = self.root
        for char in word:
            # If the character doesn't exist, the word doesn't exist;
            # return the class counts of the deepest node reached
            if char not in current_node.children:
                return 0, current_node.class_counts
            current_node = current_node.children[char]
        # Return the count of the word
        return current_node.count, current_node.class_counts

    def display(self):
        # Recursively display the tree
        def _display(node, word):
            if node.is_end_of_word:
                print(f"Data: {word}, Count: {node.count}, Indexes: {len(node.indexes)} Classes :{node.class_counts} weights:{len(node.class_weights)}")
            for char, child in node.children.items():
                _display(child, word + char)
        _display(self.root, "")

    def apply(self, func):
        """
        Applies a function to all models in the leaf nodes.
        """
        def _apply(node):
            if node.is_end_of_word and node.model is not None:
                func(node)
            for child in node.children.values():
                _apply(child)
        _apply(self.root)

    def apply_weight_to_indexes(self, weight):
        """
        Applies a weight to the indexes based on the percentage of data available.
        """
        # Total number of inserted rows, computed once for the whole traversal
        total_count = sum(self.root.children[child].count for child in self.root.children)

        def _apply_weight_to_indexes(node):
            if node.is_end_of_word:
                percentage = node.count / total_count if total_count > 0 else 0
                weighted_indexes = [(index, weight * percentage) for index in node.indexes]
                node.class_weights = weighted_indexes
            for child in node.children.values():
                _apply_weight_to_indexes(child)
        _apply_weight_to_indexes(self.root)
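As a worked example of what `apply_weight_to_indexes` computes: every data index in a leaf receives `weight * leaf_count / total_count`. The sketch below reproduces only that arithmetic on hypothetical leaf counts; the helper name `leaf_weights` and the counts themselves are assumptions for illustration.

```python
def leaf_weights(leaf_counts, weight):
    """Per-leaf weight: the base weight scaled by the leaf's share of all rows."""
    total = sum(leaf_counts.values())
    return {path: (weight * n / total if total > 0 else 0.0)
            for path, n in leaf_counts.items()}

# Hypothetical row counts for three categorical paths (10 rows in total)
counts = {("Anything ", "see"): 5, ("Chance ", "lets"): 3, ("By ", "go"): 2}
weights = leaf_weights(counts, weight=0.5)
```

With a base weight of 0.5, the leaf holding half of the rows gets 0.25, and the weights over all leaves sum back to 0.5. Leaves backed by more data therefore carry more weight.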
KNN Model
class KNNModel:
    def __init__(self, k=5):
        self.data = None
        self.labels = []
        self.k = k

    def fit(self, data, indexes, labels):
        # Keep only the rows that belong to this leaf
        self.data = data[indexes].astype(float)
        self.labels = np.array(labels).astype(float)

    def predict(self, data):
        # Euclidean distance from the query point to every stored row
        dist_ind = np.sqrt(np.sum((self.data - data) ** 2, axis=1))
        main_arr = np.column_stack((self.labels, dist_ind))  # labels with distance
        main = main_arr[main_arr[:, 1].argsort()]  # sorting based on distance
        count = Counter(main[0:self.k, 0])  # counting labels among the k nearest
        # Order the counts by label so the output does not depend on neighbor order;
        # only classes present among the k nearest appear in the result
        classes = sorted(count)
        sums = np.array([count[c] for c in classes])
        return sums / np.sum(sums)  # prediction as probability
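One practical detail when comparing or combining leaf outputs: a vote over the k nearest labels only mentions classes that actually appear, so two leaves can return vectors of different lengths. A minimal sketch of aligning the votes to a fixed class list follows; `votes_to_proba` is a hypothetical helper, not part of the model above.

```python
from collections import Counter
import numpy as np

def votes_to_proba(nearest_labels, classes):
    """Turn the labels of the k nearest rows into one probability per class."""
    votes = Counter(nearest_labels)
    counts = np.array([votes[c] for c in classes], dtype=float)  # 0 for absent classes
    return counts / counts.sum()

classes = [0, 1]
mixed = votes_to_proba([0, 1, 0, 0, 1], classes)  # both classes present
pure = votes_to_proba([1, 1, 1, 1, 1], classes)   # class 0 never appears
```

Both results have one entry per class, in the order given by `classes`, so position 0 always refers to class 0 even when that class received no votes.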
Training and Evaluation
Here’s how we train and evaluate the TrieKNN model:
# Sample data: two categorical columns followed by two numerical columns.
# np.array casts every column to strings here; KNNModel.fit converts the
# numerical columns back to float.
n = 10000
data = np.array((np.random.choice(['Anything ', 'By ', 'Chance '], p=[0.6, 0.1, 0.3], size=n),
                 np.random.choice(['can', 'go', 'here', 'lets', 'see', 'it'], p=[0.1, 0.1, 0.1, 0.2, 0.4, 0.1], size=n),
                 np.random.normal(3, 1, size=n),
                 np.random.normal(5, 2, size=n))).T
y_label = np.random.choice([0, 1], p=[0.7, 0.3], size=n)  # drawn independently of the features

# Trie training: the two categorical columns form the path
trie = Trie()
for X, y in zip(data, y_label):
    trie.insert((X[:2], y), None)

# Apply weights to indexes
trie.apply_weight_to_indexes(0.5)

# Fit models of leaf nodes
def add_model(node, data):
    node.model = KNNModel()
    node.model.fit(data, node.indexes, node.labels)

def traverse_and_add_model(node, data):
    if node.is_end_of_word:
        add_model(node, data)  # Add model to leaf node
    for child in node.children.values():
        traverse_and_add_model(child, data)

traverse_and_add_model(trie.root, data[:, 2:])
Explanation
- We create sample data with mixed categorical and numerical features.
- We insert each data point into the Trie, using the categorical features as the path.
- After the Trie is built, we traverse it and fit a KNN model to the data points stored at each leaf node.
- Finally, we can predict the class of new data points by traversing the Trie and using the KNN model at the corresponding leaf node.
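Routing a single new point to its leaf is the step that turns the structure into a classifier. A minimal sketch of that lookup follows, using nested dictionaries in place of the `Trie` class so that it runs on its own. The helper names `insert_path` and `find_leaf`, the `"_count"` and `"_rows"` keys, and the choice to return the deepest matching node for an unseen combination are all assumptions for illustration.

```python
def insert_path(root, cats, row_index):
    """Walk (and create) one node per categorical value, counting visits."""
    node = root
    node["_count"] = node.get("_count", 0) + 1
    for value in cats:
        node = node.setdefault("children", {}).setdefault(value, {})
        node["_count"] = node.get("_count", 0) + 1
    node.setdefault("_rows", []).append(row_index)  # the leaf keeps its row indexes

def find_leaf(root, cats):
    """Return (node, depth reached); depth < len(cats) means the path was unseen."""
    node, depth = root, 0
    for value in cats:
        child = node.get("children", {}).get(value)
        if child is None:
            break
        node, depth = child, depth + 1
    return node, depth

root = {}
rows = [("Anything ", "see"), ("Anything ", "see"), ("Anything ", "go"), ("By ", "see")]
for i, cats in enumerate(rows):
    insert_path(root, cats, i)

leaf, depth = find_leaf(root, ("Anything ", "see"))      # full match
partial, reached = find_leaf(root, ("Anything ", "it"))  # second value unseen
```

With a full match, `leaf["_rows"]` selects the numerical rows to hand to that leaf's KNN model. With a partial match, the caller has to decide how to back off, for example to the class counts stored at the deepest node reached, which is what `count_word` returns for an unknown path.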
Results and Discussion
Let us display the trie.
trie.display()
To query the fitted models, we run the same numerical test point through the KNN model at every leaf:
# Prediction example
def predict_with_model(node):
    predictions = node.model.predict(np.array([2, 5]))
    print("Predictions:", predictions)

trie.apply(predict_with_model)
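The predictions vary from run to run because the sample data is random. Since the labels are drawn independently of the features, the output shows only that the pipeline runs end to end on mixed data types; it is not a measure of accuracy.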
Conclusion: A Promising Path Forward
TrieKNN presents a compelling solution for extending the applicability of KNN to mixed data types. By leveraging the Trie data structure, it avoids direct distance calculations on categorical features, enabling the use of localized KNN models for numerical data.
Further research could explore:
- Optimizing the weighting scheme for combining predictions from different Trie levels.
- Comparing TrieKNN’s performance against other mixed-data KNN approaches on benchmark datasets.
- Extending TrieKNN to handle missing data and noisy categorical features.
TrieKNN opens up new possibilities for applying KNN in domains where mixed data types are prevalent, such as healthcare, e-commerce, and social science.
Resources and further reading:
- Nomclust R package
- An Improved kNN Based on Class Contribution and Feature Weighting
- An Improved Weighted KNN Algorithm for Imbalanced Data Classification
- A weighting approach for KNN classifier
- Unsupervised Outlier Detection for Mixed-Valued Dataset Based on the Adaptive k-Nearest Neighbor Global Network
- A hybrid approach based on k-nearest neighbors and decision tree for software fault prediction
- Analysis of Decision Tree and K-Nearest Neighbor Algorithm in the Classification of Breast Cancer