好得很程序员自学网

<tfoot draggable='sEl'></tfoot>

KDTree实现KNN算法

KDTree实现KNN算法

完整的实验代码在我的github上👉 QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐

在之前的博客中,我们已经学习了KNN算法的原理和代码实现。KNN算法通过计算待分类样本点和已知样本点之间的距离,选取距离最近的K个点,通过多数表决的方式进行分类。但是,当样本数据量很大时,计算所有样本之间的距离会变得非常耗时,因此我们需要一种更高效的方法来解决这个问题。

KDTree介绍

KDTree是一种常见的数据结构,可以用于高效地查找多维空间中的最近邻点。在KDTree中,每个节点都是一个k维点,节点可以分为左右子树,子树中的节点代表k维空间中的点集。建立KDTree的过程可以通过递归来实现,对于每个节点,我们需要选择一个维度和一个分割值,将该节点的点集按照这个维度的值分为两部分,分别放到左右子树中。分割值可以选取中位数或者其他的分位数,这样可以保证左右子树的平衡,避免树的深度过大,影响查询效率。

基于KDTree的KNN代码实现

在代码实现中,定义了一个 TreeNode 类来表示 KD Tree 的节点,每个节点包含了四个属性: data 表示节点对应的数据点, label 表示数据点的标签, fi 表示当前节点所在的维度, fv 表示当前节点所在维度的特征值,以及 left 和 right 表示左右子节点。

 class TreeNode:
    def __init__(self,data=None,label=None,fi=None,fv=None,left=None,right=None):
        self.data = data
        self.label = label
        self.fi = fi
        self.fv = fv
        self.left = left
        self.right = right
 

接着定义了 KDTreeKNN 类,其中 __init__ 函数接收一个参数 k ,表示 K 近邻算法中的 K 值,即选择最近的 K 个邻居。 buildTree 函数是构建 KD Tree 的核心函数,它接收三个参数: X 表示数据集, y 表示数据集对应的标签,以及 depth 表示当前节点所在的深度。在递归过程中,每次选择当前节点所在的维度 fi ,并将数据集按照该维度的特征值排序,选择排序后中间位置的数据点作为当前节点,然后递归构建左右子树,并返回当前节点。

 class KDTreeKNN:
    def __init__(self,k=3):
        self.k = k
    
    def buildTree(self,X,y,depth):
        n_size,n_feature = X.shape
        #递归终止条件
        if n_size == 1:
            tree = TreeNode(data=X[0],label=y[0])
            return tree

        fi = depth % n_feature
        argsort = np.argsort(X[:,fi])
        middle_idx = argsort[n_size // 2]
        left_idxs,right_idxs = argsort[:n_size//2],argsort[n_size//2+1:]

        fv = X[middle_idx,fi]
        data,label = X[middle_idx],y[middle_idx]
        left,right = None,None
        if len(left_idxs) > 0:
            left = self.buildTree(X[left_idxs],y[left_idxs],depth+1)
        if len(right_idxs) > 0:
            right = self.buildTree(X[right_idxs],y[right_idxs],depth+1)
        tree = TreeNode(data,label,fi,fv,left,right)
        return tree
 

当我们在KNN算法中找到当前测试样本最近的k个训练样本后,需要根据这k个训练样本的标签来决定当前测试样本的预测标签。为了找到最近的k个训练样本,我们需要使用一个函数来计算两个样本之间的距离,这就是 distance 函数的作用。在该KDTreeKNN类中,欧式距离被用于计算两个样本之间的距离,该函数的实现方式非常简单,只需要计算两个样本在每个特征上差值的平方和的平方根即可。

 @staticmethod
def distance(a,b):
    return np.sqrt(((a-b)**2).sum())
 

而 find_nearest 函数的作用是找到当前测试样本在KD树上的最近邻居。该函数的实现方式是通过递归搜索KD树,首先从根节点开始,在递归搜索过程中,如果当前节点到测试样本的距离小于当前最近邻居距离,则将当前节点设为当前最近邻居。如果当前节点有左子节点,则根据该节点所代表的维度与测试样本在该维度上的比较,决定搜索哪一个子树;如果当前节点有右子节点,则同理。如果该节点没有子节点或者该节点到测试样本最近的距离小于当前最近邻居与该节点的距离,则返回上一级节点。在递归搜索完成后,函数返回当前测试样本的最近邻居,该邻居的标签即为当前测试样本的预测标签。

 @staticmethod
def distance(a,b):
    return np.sqrt(((a-b)**2).sum())

def find_nearest(self,x,finded):
    nearest_point = None
    nearest_dis = np.inf
    nearest_label = None
    def travel(kdtree,x):
        nonlocal nearest_dis,nearest_point,nearest_label
        if kdtree == None:
            return

        #如果根节点到目标点的距离小于最近距离,则更新nearest_point和nearest_dis
        if KDTreeKNN.distance(kdtree.data,x) < nearest_dis and not self._isin(kdtree.data,finded) :
            nearest_dis = KDTreeKNN.distance(kdtree.data,x)
            nearest_point = kdtree.data
            nearest_label = kdtree.label

        if kdtree.fi == None or kdtree.fv == None:
            return

        #进入下一个相应的子节点
        if x[kdtree.fi] < kdtree.fv:
            travel(kdtree.left,x)
            if x[kdtree.fi] + nearest_dis > kdtree.fv:
                travel(kdtree.right,x)
        elif x[kdtree.fi] > kdtree.fv:
            travel(kdtree.right,x)
            if x[kdtree.fi] - nearest_dis < kdtree.fv:
                travel(kdtree.left,x)
        else:
            travel(kdtree.left,x)
            travel(kdtree.right,x)
    travel(self.tree,x)
    return nearest_point,nearest_dis,nearest_label
 

最后我们整合一个加上 fit , _predict , predict , score 等完成 KDTreeKNN 类的全部功能,包括训练、预测和评估的。

fit 函数是用来训练KD-Tree KNN模型的。它的输入是训练集的特征 X 和标签 y 。在 fit 函数内部,首先通过调用 buildTree 函数,构建出一个KD树。然后将这个KD树保存在 tree 属性中,以备后续的预测和评估使用。

 def fit(self,X,y):
        self.tree = self.buildTree(X,y,0)
 

_predict 函数是用来预测单个样本的标签的。它的输入是一个样本的特征 x 。在函数内部,首先通过调用 find_nearest 函数,找到最近的 k 个样本点。然后根据这 k 个样本点的标签,通过投票的方式决定该样本点的标签。最后返回该样本点的预测标签。

以下是 _predict 函数的代码实现:

 def _predict(self,x):
        finded = []
        labels = []
        for i in range(self.k):
            nearest_point,nearest_dis,nearest_label = self.find_nearest(x,finded)
            finded.append(nearest_point)
            labels.append(nearest_label)
        
        counter={}
        for i in labels:
            counter.setdefault(i,0)
            counter[i]+=1
        sort=sorted(counter.items(),key=lambda x:x[1])
        return sort[0][0]
 

predict 函数是用来预测整个测试集的标签的。它的输入是测试集的特征 X 。在函数内部,通过循环调用 _predict 函数,对测试集中的每个样本点进行预测。最后将所有预测结果保存在一个 numpy 数组中,并返回该数组。

以下是 predict 函数的代码实现:

 def predict(self,X):
        return np.array([self._predict(x) for x in tqdm(X)])
 

score 函数是用来评估模型的性能的。它的输入是测试集的特征 X 和标签 y 。在函数内部,首先通过调用 predict 函数,得到测试集的预测结果。然后将预测结果和真实标签进行比较,计算模型的准确率。最后将准确率作为评估结果返回。

以下是 score 函数的代码实现:

 def score(self,X,y):
    	return np.sum(self.predict(X)==y) / len(y)
 

最后是全部代码

 import numpy as np
from collections import Counter
from tqdm import tqdm

class TreeNode:
    def __init__(self,data=None,label=None,fi=None,fv=None,left=None,right=None):
        self.data = data
        self.label = label
        self.fi = fi
        self.fv = fv
        self.left = left
        self.right = right

class KDTreeKNN:
    def __init__(self,k=3):
        self.k = k
    
    def buildTree(self,X,y,depth):
        n_size,n_feature = X.shape
        #递归终止条件
        if n_size == 1:
            tree = TreeNode(data=X[0],label=y[0])
            return tree

        fi = depth % n_feature
        argsort = np.argsort(X[:,fi])
        middle_idx = argsort[n_size // 2]
        left_idxs,right_idxs = argsort[:n_size//2],argsort[n_size//2+1:]

        fv = X[middle_idx,fi]
        data,label = X[middle_idx],y[middle_idx]
        left,right = None,None
        if len(left_idxs) > 0:
            left = self.buildTree(X[left_idxs],y[left_idxs],depth+1)
        if len(right_idxs) > 0:
            right = self.buildTree(X[right_idxs],y[right_idxs],depth+1)
        tree = TreeNode(data,label,fi,fv,left,right)
        return tree
    
    def fit(self,X,y):
        self.tree = self.buildTree(X,y,0)
        
    def _predict(self,x):
        finded = []
        labels = []
        for i in range(self.k):
            nearest_point,nearest_dis,nearest_label = self.find_nearest(x,finded)
            finded.append(nearest_point)
            labels.append(nearest_label)
        
        counter={}
        for i in labels:
            counter.setdefault(i,0)
            counter[i]+=1
        sort=sorted(counter.items(),key=lambda x:x[1])
        return sort[0][0]
    
    def predict(self,X):
        return np.array([self._predict(x) for x in X])

    def score(self,X,y):
    	return np.sum(self.predict(X)==y) / len(y)
    
    def _isin(self,x,finded):
        for f in finded:
            if KDTreeKNN.distance(x,f) < 1e-6: return True
        return False
        
    @staticmethod
    def distance(a,b):
        return np.sqrt(((a-b)**2).sum())
    
    def find_nearest(self,x,finded):
        nearest_point = None
        nearest_dis = np.inf
        nearest_label = None
        def travel(kdtree,x):
            nonlocal nearest_dis,nearest_point,nearest_label
            if kdtree == None:
                return

            #如果根节点到目标点的距离小于最近距离,则更新nearest_point和nearest_dis
            if KDTreeKNN.distance(kdtree.data,x) < nearest_dis and not self._isin(kdtree.data,finded) :
                nearest_dis = KDTreeKNN.distance(kdtree.data,x)
                nearest_point = kdtree.data
                nearest_label = kdtree.label

            if kdtree.fi == None or kdtree.fv == None:
                return

            #进入下一个相应的子节点
            if x[kdtree.fi] < kdtree.fv:
                travel(kdtree.left,x)
                if x[kdtree.fi] + nearest_dis > kdtree.fv:
                    travel(kdtree.right,x)
            elif x[kdtree.fi] > kdtree.fv:
                travel(kdtree.right,x)
                if x[kdtree.fi] - nearest_dis < kdtree.fv:
                    travel(kdtree.left,x)
            else:
                travel(kdtree.left,x)
                travel(kdtree.right,x)
        travel(self.tree,x)
        return nearest_point,nearest_dis,nearest_label
 

我们可以使用 iris 数据集或者其他数据集测试该KNN算法的准确率等。

完整的实验代码在我的github上👉 QYHcrossover/ML-numpy: 机器学习算法numpy实现 (github.com) 欢迎star⭐

查看更多关于KDTree实现KNN算法的详细内容...

  阅读:46次