Package Bio :: Module kNN
[hide private]
[frames | no frames]

Source Code for Module Bio.kNN

  1  #!/usr/bin/env python 
  2  # This code is part of the Biopython distribution and governed by its 
  3  # license.  Please see the LICENSE file that should have been included 
  4  # as part of this package. 
  5  """Code for doing k-nearest-neighbors classification. 
  6   
  7  k Nearest Neighbors is a supervised learning algorithm that classifies 
  8  a new observation based on the classes in its surrounding neighborhood. 
  9   
 10  Glossary: 
 11   - distance   The distance between two points in the feature space. 
 12   - weight     The importance given to each point for classification. 
 13   
 14  Classes: 
 15   - kNN           Holds information for a nearest neighbors classifier. 
 16   
 17   
 18  Functions: 
 19   - train        Train a new kNN classifier. 
 20   - calculate    Calculate the probabilities of each class, given an observation. 
 21   - classify     Classify an observation into a class. 
 22   
 23  Weighting Functions: 
 24   - equal_weight    Every example is given a weight of 1. 
 25   
 26  """ 
 27   
 28  import numpy 
 29   
 30   
class kNN(object):
    """Container for the state of a nearest neighbors classifier.

    Attributes:
     - classes  Set of the possible classes.
     - xs       List of the neighbors.
     - ys       List of the classes that the neighbors belong to.
     - k        Number of neighbors to look at.

    """

    def __init__(self):
        """Create an empty classifier: no classes, no neighbors, k unset."""
        self.classes = set()
        self.xs, self.ys = [], []
        # k stays None until a neighbor count is assigned.
        self.k = None
48 49
def equal_weight(x, y):
    """Return a weight of 1 for any pair of points.

    Both arguments are ignored, so every example contributes exactly
    one vote.
    """
    return 1
54 55
def train(xs, ys, k, typecode=None):
    """Build a kNN classifier from a training set.

    xs is a list of observations and ys is a list of the class
    assignments, so xs and ys should contain the same number of
    elements (this is not checked here).  k is the number of neighbors
    that should be examined when doing the classification.  typecode is
    handed to numpy.asarray as the dtype used to store xs.

    Returns a kNN object holding the set of distinct classes, xs
    converted to a numpy array, ys as given (not copied), and k.
    """
    model = kNN()
    model.classes = set(ys)
    model.xs = numpy.asarray(xs, dtype=typecode)
    model.ys = ys
    model.k = k
    return model
71 72
def calculate(knn, x, weight_fn=equal_weight, distance_fn=None):
    """Compute the total neighbor weight for each class.

    knn is a kNN object and x is the observed data.  weight_fn is an
    optional function that takes x and a training example and returns
    a weight.  distance_fn is an optional function that takes two
    points and returns the distance between them; when it is not given
    (None, the default), the Euclidean distance is used.

    Returns a dictionary mapping every class in knn.classes to the
    summed weight of the knn.k nearest training examples belonging to
    that class (0.0 for classes with no example among them).
    """
    x = numpy.asarray(x)

    # (distance, training index) pairs, one per stored neighbor.
    ranked = []
    if distance_fn:
        for idx, neighbor in enumerate(knn.xs):
            ranked.append((distance_fn(x, neighbor), idx))
    else:
        # Default Euclidean distance.  The difference vector is written
        # into one preallocated scratch array that is reused on every
        # iteration (kept from the original implementation, which
        # reports this as a significant speed-up).
        scratch = numpy.zeros(len(x))
        for idx, neighbor in enumerate(knn.xs):
            scratch[:] = x - neighbor
            ranked.append((numpy.sqrt(numpy.dot(scratch, scratch)), idx))
    # Tuples sort by distance first, then by training index.
    ranked.sort()

    # Every known class starts with zero weight.
    weights = dict.fromkeys(knn.classes, 0.0)
    # Only the first knn.k entries (the closest ones) get to vote.
    for _, idx in ranked[:knn.k]:
        label = knn.ys[idx]
        weights[label] = weights[label] + weight_fn(x, knn.xs[idx])

    return weights
111 112
def classify(knn, x, weight_fn=equal_weight, distance_fn=None):
    """Classify an observation into a class.

    knn is a kNN object and x is the observed data.  If not specified,
    weight_fn will give all neighbors equal weight.  distance_fn is an
    optional function that takes two points and returns the distance
    between them.  If distance_fn is None (the default), the Euclidean
    distance is used.

    Returns the class with the largest total weight.  When several
    classes share the largest weight, the one encountered first while
    iterating over the weights is returned.  Returns None when there
    are no classes to choose from.
    """
    weights = calculate(
        knn, x, weight_fn=weight_fn, distance_fn=distance_fn)

    if not weights:
        return None
    # max() keeps the first maximal key, matching a strict ">" scan.
    # Selecting by weight alone (instead of using the chosen label as
    # the "nothing chosen yet" marker) keeps a class label of None from
    # being unconditionally displaced by the next class.
    return max(weights, key=weights.get)
131