Differences between revisions 4 and 5

Deletions are marked like this. Additions are marked like this.
Line 12: Line 12:
{{{ {{{#!python
Line 19: Line 19:
{{{ {{{#!python
Line 26: Line 26:
An example of such a model is the class LinearLeastSquaresModel, as seen in the file source (below).

inline:ransac.py

The attached file (ransac.py) implements the RANSAC algorithm. An example image:

To run the file, save it to your computer and start IPython:

ipython -wthread

Import the module and run the test program:

import ransac
ransac.test()  # demo: synthetic linear data with outliers, RANSAC vs. plain least squares, plotted with pylab

To use the module, you need to create a model class with two methods:

def fit(self, data):
  """Given the data (one data point per row), fit your model to it and return the model parameters.

  ransac() does not inspect the returned value; it only passes it back into get_error()."""
def get_error(self, data, model):
  """Given a set of data and a model, return the error of using this model to estimate each row of the data.

  ransac() counts rows whose error is below its threshold t as inliers, and compares
  candidate models by the mean of these per-row errors."""

An example of such a model is the class LinearLeastSquaresModel, as seen in the file source (below).

   1 import numpy
   2 import scipy # use numpy if scipy unavailable
   3 import scipy.linalg # use numpy if scipy unavailable
   4 
   5 ## Copyright (c) 2004-2007, Andrew D. Straw. All rights reserved.
   6 
   7 ## Redistribution and use in source and binary forms, with or without
   8 ## modification, are permitted provided that the following conditions are
   9 ## met:
  10 
  11 ##     * Redistributions of source code must retain the above copyright
  12 ##       notice, this list of conditions and the following disclaimer.
  13 
  14 ##     * Redistributions in binary form must reproduce the above
  15 ##       copyright notice, this list of conditions and the following
  16 ##       disclaimer in the documentation and/or other materials provided
  17 ##       with the distribution.
  18 
  19 ##     * Neither the name of the Andrew D. Straw nor the names of its
  20 ##       contributors may be used to endorse or promote products derived
  21 ##       from this software without specific prior written permission.
  22 
  23 ## THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
  24 ## "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
  25 ## LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
  26 ## A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
  27 ## OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
  28 ## SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
  29 ## LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  30 ## DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  31 ## THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  32 ## (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
  33 ## OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  34 
  35 def ransac(data,model,n,k,t,d,debug=False,return_all=False):
  36     """fit model parameters to data using the RANSAC algorithm
  37     
  38 This implementation written from pseudocode found at
  39 http://en.wikipedia.org/w/index.php?title=RANSAC&oldid=116358182
  40 
  41 {{{
  42 Given:
  43     data - a set of observed data points
  44     model - a model that can be fitted to data points
  45     n - the minimum number of data values required to fit the model
  46     k - the maximum number of iterations allowed in the algorithm
  47     t - a threshold value for determining when a data point fits a model
  48     d - the number of close data values required to assert that a model fits well to data
  49 Return:
  50     bestfit - model parameters which best fit the data (or nil if no good model is found)
  51 iterations = 0
  52 bestfit = nil
  53 besterr = something really large
  54 while iterations < k {
  55     maybeinliers = n randomly selected values from data
  56     maybemodel = model parameters fitted to maybeinliers
  57     alsoinliers = empty set
  58     for every point in data not in maybeinliers {
  59         if point fits maybemodel with an error smaller than t
  60              add point to alsoinliers
  61     }
  62     if the number of elements in alsoinliers is > d {
  63         % this implies that we may have found a good model
  64         % now test how good it is
  65         bettermodel = model parameters fitted to all points in maybeinliers and alsoinliers
  66         thiserr = a measure of how well model fits these points
  67         if thiserr < besterr {
  68             bestfit = bettermodel
  69             besterr = thiserr
  70         }
  71     }
  72     increment iterations
  73 }
  74 return bestfit
  75 }}}
  76 """
  77     iterations = 0
  78     bestfit = None
  79     besterr = numpy.inf
  80     best_inlier_idxs = None
  81     while iterations < k:
  82         maybe_idxs, test_idxs = random_partition(n,data.shape[0])
  83         maybeinliers = data[maybe_idxs,:]
  84         test_points = data[test_idxs]
  85         maybemodel = model.fit(maybeinliers)
  86         test_err = model.get_error( test_points, maybemodel)
  87         also_idxs = test_idxs[test_err < t] # select indices of rows with accepted points
  88         alsoinliers = data[also_idxs,:]
  89         if debug:
  90             print 'test_err.min()',test_err.min()
  91             print 'test_err.max()',test_err.max()
  92             print 'numpy.mean(test_err)',numpy.mean(test_err)
  93             print 'iteration %d:len(alsoinliers) = %d'%(
  94                 iterations,len(alsoinliers))
  95         if len(alsoinliers) > d:
  96             betterdata = numpy.concatenate( (maybeinliers, alsoinliers) )
  97             bettermodel = model.fit(betterdata)
  98             better_errs = model.get_error( betterdata, bettermodel)
  99             thiserr = numpy.mean( better_errs )
 100             if thiserr < besterr:
 101                 bestfit = bettermodel
 102                 besterr = thiserr
 103                 best_inlier_idxs = numpy.concatenate( (maybe_idxs, also_idxs) )
 104         iterations+=1
 105     if bestfit is None:
 106         raise ValueError("did not meet fit acceptance criteria")
 107     if return_all:
 108         return bestfit, {'inliers':best_inlier_idxs}
 109     else:
 110         return bestfit
 111 
 112 def random_partition(n,n_data):
 113     """return n random rows of data (and also the other len(data)-n rows)"""
 114     all_idxs = numpy.arange( n_data )
 115     numpy.random.shuffle(all_idxs)
 116     idxs1 = all_idxs[:n]
 117     idxs2 = all_idxs[n:]
 118     return idxs1, idxs2
 119 
 120 class LinearLeastSquaresModel:
 121     """linear system solved using linear least squares
 122 
 123     This class serves as an example that fulfills the model interface
 124     needed by the ransac() function.
 125     
 126     """
 127     def __init__(self,input_columns,output_columns,debug=False):
 128         self.input_columns = input_columns
 129         self.output_columns = output_columns
 130         self.debug = debug
 131     def fit(self, data):
 132         A = numpy.vstack([data[:,i] for i in self.input_columns]).T
 133         B = numpy.vstack([data[:,i] for i in self.output_columns]).T
 134         x,resids,rank,s = scipy.linalg.lstsq(A,B)
 135         return x
 136     def get_error( self, data, model):
 137         A = numpy.vstack([data[:,i] for i in self.input_columns]).T
 138         B = numpy.vstack([data[:,i] for i in self.output_columns]).T
 139         B_fit = scipy.dot(A,model)
 140         err_per_point = numpy.sum((B-B_fit)**2,axis=1) # sum squared error per row
 141         return err_per_point
 142 
 143 def test():
 144     # generate perfect input data
 145 
 146     n_samples = 500
 147     n_inputs = 1
 148     n_outputs = 1
 149     A_exact = 20*numpy.random.random((n_samples,n_inputs) )
 150     perfect_fit = 60*numpy.random.normal(size=(n_inputs,n_outputs) ) # the model
 151     B_exact = scipy.dot(A_exact,perfect_fit)
 152     assert B_exact.shape == (n_samples,n_outputs)
 153 
 154     # add a little gaussian noise (linear least squares alone should handle this well)
 155     A_noisy = A_exact + numpy.random.normal(size=A_exact.shape )
 156     B_noisy = B_exact + numpy.random.normal(size=B_exact.shape )
 157 
 158     if 1:
 159         # add some outliers
 160         n_outliers = 100
 161         all_idxs = numpy.arange( A_noisy.shape[0] )
 162         numpy.random.shuffle(all_idxs)
 163         outlier_idxs = all_idxs[:n_outliers]
 164         non_outlier_idxs = all_idxs[n_outliers:]
 165         A_noisy[outlier_idxs] =  20*numpy.random.random((n_outliers,n_inputs) )
 166         B_noisy[outlier_idxs] = 50*numpy.random.normal(size=(n_outliers,n_outputs) )
 167 
 168     # setup model
 169 
 170     all_data = numpy.hstack( (A_noisy,B_noisy) )
 171     input_columns = range(n_inputs) # the first columns of the array
 172     output_columns = [n_inputs+i for i in range(n_outputs)] # the last columns of the array
 173     debug = False
 174     model = LinearLeastSquaresModel(input_columns,output_columns,debug=debug)
 175 
 176     linear_fit,resids,rank,s = scipy.linalg.lstsq(all_data[:,input_columns],
 177                                                   all_data[:,output_columns])
 178 
 179     # run RANSAC algorithm
 180     ransac_fit, ransac_data = ransac(all_data,model,
 181                                      50, 1000, 7e3, 300, # misc. parameters
 182                                      debug=debug,return_all=True)
 183     if 1:
 184         import pylab
 185 
 186         sort_idxs = numpy.argsort(A_exact[:,0])
 187         A_col0_sorted = A_exact[sort_idxs] # maintain as rank-2 array
 188 
 189         if 1:
 190             pylab.plot( A_noisy[:,0], B_noisy[:,0], 'k.', label='data' )
 191             pylab.plot( A_noisy[ransac_data['inliers'],0], B_noisy[ransac_data['inliers'],0], 'bx', label='RANSAC data' )
 192         else:
 193             pylab.plot( A_noisy[non_outlier_idxs,0], B_noisy[non_outlier_idxs,0], 'k.', label='noisy data' )
 194             pylab.plot( A_noisy[outlier_idxs,0], B_noisy[outlier_idxs,0], 'r.', label='outlier data' )
 195         pylab.plot( A_col0_sorted[:,0],
 196                     numpy.dot(A_col0_sorted,ransac_fit)[:,0],
 197                     label='RANSAC fit' )
 198         pylab.plot( A_col0_sorted[:,0],
 199                     numpy.dot(A_col0_sorted,perfect_fit)[:,0],
 200                     label='exact system' )
 201         pylab.plot( A_col0_sorted[:,0],
 202                     numpy.dot(A_col0_sorted,linear_fit)[:,0],
 203                     label='linear fit' )
 204         pylab.legend()
 205         pylab.show()
 206 
 207 if __name__=='__main__':
 208     test()
ransac.py


Cookbook/RANSAC (last edited 2009-05-09 19:44:22 by DatChu)