[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

random_forest.hxx VIGRA

1 /************************************************************************/
2 /* */
3 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */
4 /* */
5 /* This file is part of the VIGRA computer vision library. */
6 /* The VIGRA Website is */
7 /* http://hci.iwr.uni-heidelberg.de/vigra/ */
8 /* Please direct questions, bug reports, and contributions to */
9 /* ullrich.koethe@iwr.uni-heidelberg.de or */
10 /* vigra@informatik.uni-hamburg.de */
11 /* */
12 /* Permission is hereby granted, free of charge, to any person */
13 /* obtaining a copy of this software and associated documentation */
14 /* files (the "Software"), to deal in the Software without */
15 /* restriction, including without limitation the rights to use, */
16 /* copy, modify, merge, publish, distribute, sublicense, and/or */
17 /* sell copies of the Software, and to permit persons to whom the */
18 /* Software is furnished to do so, subject to the following */
19 /* conditions: */
20 /* */
21 /* The above copyright notice and this permission notice shall be */
22 /* included in all copies or substantial portions of the */
23 /* Software. */
24 /* */
25 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */
26 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */
27 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */
28 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */
29 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */
30 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */
31 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */
32 /* OTHER DEALINGS IN THE SOFTWARE. */
33 /* */
34 /************************************************************************/
35 
36 
37 #ifndef VIGRA_RANDOM_FOREST_HXX
38 #define VIGRA_RANDOM_FOREST_HXX
39 
40 #include <iostream>
41 #include <algorithm>
42 #include <map>
43 #include <set>
44 #include <list>
45 #include <numeric>
46 #include "mathutil.hxx"
47 #include "array_vector.hxx"
48 #include "sized_int.hxx"
49 #include "matrix.hxx"
50 #include "random.hxx"
51 #include "functorexpression.hxx"
52 #include "random_forest/rf_common.hxx"
53 #include "random_forest/rf_nodeproxy.hxx"
54 #include "random_forest/rf_split.hxx"
55 #include "random_forest/rf_decisionTree.hxx"
56 #include "random_forest/rf_visitors.hxx"
57 #include "random_forest/rf_region.hxx"
58 #include "sampling.hxx"
59 #include "random_forest/rf_preprocessing.hxx"
60 #include "random_forest/rf_online_prediction_set.hxx"
61 #include "random_forest/rf_earlystopping.hxx"
62 #include "random_forest/rf_ridge_split.hxx"
63 namespace vigra
64 {
65 
66 /** \addtogroup MachineLearning Machine Learning
67 
68  This module provides classification algorithms that map
69  features to labels or label probabilities.
70  Look at the RandomForest class first for an overview of most of the
71  functionality provided as well as use cases.
72 **/
73 //@{
74 
75 namespace detail
76 {
77 
78 
79 
80 /* \brief sampling option factory function
81  */
82 inline SamplerOptions make_sampler_opt ( RandomForestOptions & RF_opt)
83 {
84  SamplerOptions return_opt;
85  return_opt.withReplacement(RF_opt.sample_with_replacement_);
86  return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
87  return return_opt;
88 }
89 }//namespace detail
90 
91 /** Random Forest class
92  *
93  * \tparam <LabelType = double> Type used for predicted labels.
94  * \tparam <PreprocessorTag = ClassificationTag> Class used to preprocess
95  * the input while learning and predicting. Currently Available:
96  * ClassificationTag and RegressionTag. It is recommended to use
97  * Splitfunctor::Preprocessor_t while using custom splitfunctors
98  * as they may need the data to be in a different format.
99  * \sa Preprocessor
100  *
101  * Simple usage for classification (regression is not yet supported):
102  * look at RandomForest::learn() as well as RandomForestOptions() for additional
103  * options.
104  *
105  * \code
106  * using namespace vigra;
107  * using namespace rf;
108  * typedef xxx feature_t; \\ replace xxx with whichever type
109  * typedef yyy label_t; \\ likewise
110  *
111  * // allocate the training data
112  * MultiArrayView<2, feature_t> f = get_training_features();
113  * MultiArrayView<2, label_t> l = get_training_labels();
114  *
115  * RandomForest<label_t> rf;
116  *
117  * // construct visitor to calculate out-of-bag error
118  * visitors::OOB_Error oob_v;
119  *
120  * // perform training
121  * rf.learn(f, l, visitors::create_visitor(oob_v));
122  *
123  * std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
124  *
125  * // get features for new data to be used for prediction
126  * MultiArrayView<2, feature_t> pf = get_features();
127  *
128  * // allocate space for the response (pf.shape(0) is the number of samples)
129  * MultiArrayView<2, label_t> prediction(pf.shape(0), 1);
130  * MultiArrayView<2, double> prob(pf.shape(0), rf.class_count());
131  *
132  * // perform prediction on new data
133  * rf.predictLabels(pf, prediction);
134  * rf.predictProbabilities(pf, prob);
135  *
136  * \endcode
137  *
138  * Additional information such as Variable Importance measures are accessed
139  * via Visitors defined in rf::visitors.
140  * Have a look at rf::split for other splitting methods.
141  *
142 */
143 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
// NOTE(review): the 'class RandomForest' declaration line (original line 144)
// and several typedef lines (original 149, 151, 153-155: Options_t,
// ProblemSpec_t, Default_Stop_t, the StackEntry typedef head) plus the tree
// container member (original 164, presumably 'ArrayVector<DecisionTree_t>
// trees_;' -- TODO confirm) were lost when this listing was extracted.
// Restore them from the upstream VIGRA header before compiling.
145 {
146 
147  public:
148  //public typedefs
150  typedef detail::DecisionTree DecisionTree_t;
152  typedef GiniSplit Default_Split_t;
156  StackEntry_t;
157  typedef LabelType LabelT;
158 
159  //problem independent data.
160  Options_t options_;
161  //problem dependent data members - is only set if
162  //a copy constructor, some sort of import
163  //function or the learn function is called
165  ProblemSpec_t ext_param_;
166  /*mutable ArrayVector<int> tree_indices_;*/
167  rf::visitors::OnlineLearnVisitor online_visitor_;
168 
169 
// Discard all trained state: clears the extrinsic problem parameters and
// the tree container, so the forest can be re-parameterized and retrained.
170  void reset()
171  {
172  ext_param_.clear();
173  trees_.clear();
174  }
175 
176  public:
177 
178  /** \name Constructors
179  * Note: No copy Constructor specified as no pointers are manipulated
180  * in this class
181  */
182  /*\{*/
183  /**\brief default constructor
184  *
185  * \param options general options to the Random Forest. Must be of Type
186  * Options_t
187  * \param ext_param problem specific values that can be supplied
188  * additionally. (class weights , labels etc)
189  * \sa RandomForestOptions, ProblemSpec
190  *
191  */
192  RandomForest(Options_t const & options = Options_t(),
193  ProblemSpec_t const & ext_param = ProblemSpec_t())
194  :
195  options_(options),
// The commented-out tree_indices_ initialization below is retained legacy
// code (the member itself is commented out above); left as-is on purpose.
196  ext_param_(ext_param)/*,
197  tree_indices_(options.tree_count_,0)*/
198  {
199  /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
200  tree_indices_[ii] = ii;*/
201  }
202 
203  /**\brief Create RF from external source
204  * \param treeCount Number of trees to add.
205  * \param topology_begin
206  * Iterator to a Container where the topology_ data
207  * of the trees are stored.
208  * Iterator should support at least treeCount forward
209  * iterations. (i.e. topology_end - topology_begin >= treeCount
210  * \param parameter_begin
211  * iterator to a Container where the parameters_ data
212  * of the trees are stored. Iterator should support at
213  * least treeCount forward iterations.
214  * \param problem_spec
215  * Extrinsic parameters that specify the problem e.g.
216  * ClassCount, featureCount etc.
217  * \param options (optional) specify options used to train the original
218  * Random forest. This parameter is not used anywhere
219  * during prediction and thus is optional.
220  *
221  */
222  /* TODO: This constructor may be replaced by a Constructor using
223  * NodeProxy iterators to encapsulate the underlying data type.
224  */
225  template<class TopologyIterator, class ParameterIterator>
226  RandomForest(int treeCount,
227  TopologyIterator topology_begin,
228  ParameterIterator parameter_begin,
229  ProblemSpec_t const & problem_spec,
230  Options_t const & options = Options_t())
231  :
232  trees_(treeCount, DecisionTree_t(problem_spec)),
233  ext_param_(problem_spec),
234  options_(options)
235  {
236  for(unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
237  {
238  trees_[k].topology_ = *topology_begin;
239  trees_[k].parameters_ = *parameter_begin;
240  }
241  }
242 
243  /*\}*/
244 
245 
246  /** \name Data Access
247  * data access interface - usage of member variables is deprecated
248  */
249 
250  /*\{*/
251 
252 
253  /**\brief return external parameters for viewing
254  * \return const reference to the extrinsic problem parameters;
255  */
// Precondition: the forest must have been trained (or the parameters
// imported) -- otherwise vigra_precondition fires below.
256  ProblemSpec_t const & ext_param() const
257  {
258  vigra_precondition(ext_param_.used() == true,
259  "RandomForest::ext_param(): "
260  "Random forest has not been trained yet.");
261  return ext_param_;
262  }
263 
264  /**\brief set external parameters
265  *
266  * \param in external parameters to be set
267  *
268  * set external parameters explicitly.
269  * If Random Forest has not been trained the preprocessor will
270  * either ignore filling values set this way or will throw an exception
271  * if values specified manually do not match the value calculated
272  * during the preparation step.
273  */
274  void set_ext_param(ProblemSpec_t const & in)
275  {
276  vigra_precondition(ext_param_.used() == false,
277  "RandomForest::set_ext_param():"
278  "Random forest has been trained! Call reset()"
279  "before specifying new extrinsic parameters.");
280  }
281 
282  /**\brief access random forest options
283  *
284  * \return mutable reference to the option set; changes affect
285  */
// subsequent learn() calls, not trees that are already trained.
286  Options_t & set_options()
287  {
288  return options_;
289  }
290 
291 
292  /**\brief access const random forest options
293  *
294  * \return const reference to the forest's Options_t
295  */
296  Options_t const & options() const
297  {
298  return options_;
299  }
300 
301  /**\brief access const trees
302  */
// NOTE(review): 'index' is not range-checked; caller must guarantee
// 0 <= index < number of trees.
303  DecisionTree_t const & tree(int index) const
304  {
305  return trees_[index];
306  }
307 
308  /**\brief access trees
309  */
// Mutable counterpart of the const accessor above; likewise unchecked.
310  DecisionTree_t & tree(int index)
311  {
312  return trees_[index];
313  }
314 
315  /*\}*/
316 
317  /**\brief return number of features used while
318  * training.
319  */
// Only meaningful after ext_param_ has been filled (training or import).
320  int feature_count() const
321  {
322  return ext_param_.column_count_;
323  }
324 
325 
326  /**\brief return number of features used while
327  * training.
328  *
329  * deprecated. Use feature_count() instead.
330  */
// Kept only for backward compatibility; identical to feature_count().
331  int column_count() const
332  {
333  return ext_param_.column_count_;
334  }
335 
336  /**\brief return number of classes used while
337  * training.
338  */
// Only meaningful after ext_param_ has been filled (training or import).
339  int class_count() const
340  {
341  return ext_param_.class_count_;
342  }
343 
344  /**\brief return number of trees
345  */
// NOTE(review): this reports the configured option value, not
// trees_.size(); for a forest built via the external-source constructor
// with a default-constructed Options_t the two can differ -- confirm
// intended behaviour upstream.
346  int tree_count() const
347  {
348  return options_.tree_count_;
349  }
350 
351 
352 
// Incrementally train the forest on the additional samples stored in rows
// [new_start_index, end) of 'features'/'response'. Requires that online
// learning was enabled via the options (checked in the out-of-line
// definition below).
353  template<class U,class C1,
354  class U2, class C2,
355  class Split_t,
356  class Stop_t,
357  class Visitor_t,
358  class Random_t>
359  void onlineLearn( MultiArrayView<2,U,C1> const & features,
360  MultiArrayView<2,U2,C2> const & response,
361  int new_start_index,
362  Visitor_t visitor_,
363  Split_t split_,
364  Stop_t stop_,
365  Random_t & random,
366  bool adjust_thresholds=false);
367 
// Convenience overload: defaults for visitor, split and stop.
368  template <class U, class C1, class U2,class C2>
369  void onlineLearn( MultiArrayView<2, U, C1> const & features,
370  MultiArrayView<2, U2,C2> const & labels,int new_start_index,bool adjust_thresholds=false)
371  {
// NOTE(review): original line 372 declaring the random generator 'rnd'
// (presumably 'RandomNumberGenerator<> rnd(RandomSeed);' -- TODO confirm)
// is missing from this listing.
373  onlineLearn(features,
374  labels,
375  new_start_index,
376  rf_default(),
377  rf_default(),
378  rf_default(),
379  rnd,
380  adjust_thresholds);
381  }
382 
// Discard tree 'treeId' and retrain it from scratch on the given data.
// Defined out of line below; requires online learning to be enabled.
383  template<class U,class C1,
384  class U2, class C2,
385  class Split_t,
386  class Stop_t,
387  class Visitor_t,
388  class Random_t>
389  void reLearnTree(MultiArrayView<2,U,C1> const & features,
390  MultiArrayView<2,U2,C2> const & response,
391  int treeId,
392  Visitor_t visitor_,
393  Split_t split_,
394  Stop_t stop_,
395  Random_t & random);
396 
// Convenience overload: defaults for visitor, split and stop.
397  template<class U, class C1, class U2, class C2>
398  void reLearnTree(MultiArrayView<2, U, C1> const & features,
399  MultiArrayView<2, U2, C2> const & labels,
400  int treeId)
401  {
// NOTE(review): original line 402 declaring the random generator 'rnd'
// is missing from this listing -- restore from upstream.
403  reLearnTree(features,
404  labels,
405  treeId,
406  rf_default(),
407  rf_default(),
408  rf_default(),
409  rnd);
410  }
411 
412 
413  /**\name Learning
414  * Following functions differ in the degree of customization
415  * allowed
416  */
417  /*\{*/
418  /**\brief learn on data with custom config and random number generator
419  *
420  * \param features a N x M matrix containing N samples with M
421  * features
422  * \param response a N x D matrix containing the corresponding
423  * response. Current split functors assume D to
424  * be 1 and ignore any additional columns.
425  * This is not enforced to allow future support
426  * for uncertain labels, label independent strata etc.
427  * The Preprocessor specified during construction
428  * should be able to handle features and labels
429  * features and the labels.
430  * see also: SplitFunctor, Preprocessing
431  *
432  * \param visitor visitor which is to be applied after each split,
433  * tree and at the end. Use rf_default() for using
434  * default value. (No Visitors)
435  * see also: rf::visitors
436  * \param split split functor to be used to calculate each split
437  * use rf_default() for using default value. (GiniSplit)
438  * see also: rf::split
439  * \param stop
440  * predicate to be used to calculate each split
441  * use rf_default() for using default value. (EarlyStoppStd)
442  * \param random RandomNumberGenerator to be used. Use
443  * rf_default() to use default value.(RandomMT19337)
444  *
445  *
446  */
// Master learn() overload (fully parameterized); declared here,
// defined out of line below. See the doxygen comment above for the
// parameter contract.
447  template <class U, class C1,
448  class U2,class C2,
449  class Split_t,
450  class Stop_t,
451  class Visitor_t,
452  class Random_t>
453  void learn( MultiArrayView<2, U, C1> const & features,
454  MultiArrayView<2, U2,C2> const & response,
455  Visitor_t visitor,
456  Split_t split,
457  Stop_t stop,
458  Random_t const & random);
459 
// Overload without an explicit random generator: forwards to the master
// overload with a locally constructed generator.
460  template <class U, class C1,
461  class U2,class C2,
462  class Split_t,
463  class Stop_t,
464  class Visitor_t>
465  void learn( MultiArrayView<2, U, C1> const & features,
466  MultiArrayView<2, U2,C2> const & response,
467  Visitor_t visitor,
468  Split_t split,
469  Stop_t stop)
470 
471  {
// NOTE(review): original line 472 declaring the random generator 'rnd'
// is missing from this listing -- restore from upstream.
473  learn( features,
474  response,
475  visitor,
476  split,
477  stop,
478  rnd);
479  }
480 
481  template <class U, class C1, class U2,class C2, class Visitor_t>
482  void learn( MultiArrayView<2, U, C1> const & features,
483  MultiArrayView<2, U2,C2> const & labels,
484  Visitor_t visitor)
485  {
486  learn( features,
487  labels,
488  visitor,
489  rf_default(),
490  rf_default());
491  }
492 
493  template <class U, class C1, class U2,class C2,
494  class Visitor_t, class Split_t>
495  void learn( MultiArrayView<2, U, C1> const & features,
496  MultiArrayView<2, U2,C2> const & labels,
497  Visitor_t visitor,
498  Split_t split)
499  {
500  learn( features,
501  labels,
502  visitor,
503  split,
504  rf_default());
505  }
506 
507  /**\brief learn on data with default configuration
508  *
509  * \param features a N x M matrix containing N samples with M
510  * features
511  * \param labels a N x D matrix containing the corresponding
512  * N labels. Current split functors assume D to
513  * be 1 and ignore any additional columns.
514  * this is not enforced to allow future support
515  * for uncertain labels.
516  *
517  * learning is done with:
518  *
519  * \sa rf::split, EarlyStoppStd
520  *
521  * - Randomly seeded random number generator
522  * - default gini split functor as described by Breiman
523  * - default The standard early stopping criterion
524  */
525  template <class U, class C1, class U2,class C2>
526  void learn( MultiArrayView<2, U, C1> const & features,
527  MultiArrayView<2, U2,C2> const & labels)
528  {
529  learn( features,
530  labels,
531  rf_default(),
532  rf_default(),
533  rf_default());
534  }
535  /*\}*/
536 
537 
538 
539  /**\name prediction
540  */
541  /*\{*/
542  /** \brief predict a label given a feature.
543  *
544  * \param features: a 1 by featureCount matrix containing
545  * data point to be predicted (this only works in
546  * classification setting)
547  * \param stop: early stopping criterion
548  * \return double value representing class. You can use the
549  * predictLabels() function together with the
550  * rf.external_parameter().class_type_ attribute
551  * to get back the same type used during learning.
552  */
// Single-sample label prediction with an explicit early-stopping
// criterion; defined out of line below.
553  template <class U, class C, class Stop>
554  LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
555 
// Convenience overload using the default stopping criterion.
// NOTE(review): unlike the Stop version, this overload is not declared
// const -- confirm whether that is intentional upstream.
556  template <class U, class C>
557  LabelType predictLabel(MultiArrayView<2, U, C>const & features)
558  {
559  return predictLabel(features, rf_default());
560  }
561  /** \brief predict a label with features and class priors
562  *
563  * \param features: same as above.
564  * \param prior: iterator to prior weighting of classes
565  * \return same as above.
566  */
// Declared here, defined out of line (definition continues past this
// listing).
567  template <class U, class C>
568  LabelType predictLabel(MultiArrayView<2, U, C> const & features,
569  ArrayVectorView<double> prior) const;
570 
571  /** \brief predict multiple labels with given features
572  *
573  * \param features: a n by featureCount matrix containing
574  * data point to be predicted (this only works in
575  * classification setting)
576  * \param labels: a n by 1 matrix passed by reference to store
577  * output.
578  *
579  * If the input contains an NaN value, an precondition exception is thrown.
580  */
// Row-wise label prediction; rejects NaN inputs with a precondition.
581  template <class U, class C1, class T, class C2>
// NOTE(review): the first signature line (original 582,
// 'void predictLabels(MultiArrayView<2, U, C1>const & features,') is
// missing from this listing -- restore from upstream.
583  MultiArrayView<2, T, C2> & labels) const
584  {
585  vigra_precondition(features.shape(0) == labels.shape(0),
586  "RandomForest::predictLabels(): Label array has wrong size.");
587  for(int k=0; k<features.shape(0); ++k)
588  {
589  vigra_precondition(!detail::contains_nan(rowVector(features, k)),
590  "RandomForest::predictLabels(): NaN in feature matrix.");
591  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
592  }
593  }
594 
595  /** \brief predict multiple labels with given features
596  *
597  * \param features: a n by featureCount matrix containing
598  * data point to be predicted (this only works in
599  * classification setting)
600  * \param labels: a n by 1 matrix passed by reference to store
601  * output.
602  * \param nanLabel: label to be returned for the row of the input that
603  * contain an NaN value.
604  */
// Row-wise label prediction; rows containing NaN receive 'nanLabel'
// instead of triggering a precondition failure.
605  template <class U, class C1, class T, class C2>
// NOTE(review): the first signature line (original 606,
// 'void predictLabels(MultiArrayView<2, U, C1>const & features,') is
// missing from this listing -- restore from upstream.
607  MultiArrayView<2, T, C2> & labels,
608  LabelType nanLabel) const
609  {
610  vigra_precondition(features.shape(0) == labels.shape(0),
611  "RandomForest::predictLabels(): Label array has wrong size.");
612  for(int k=0; k<features.shape(0); ++k)
613  {
614  if(detail::contains_nan(rowVector(features, k)))
615  labels(k,0) = nanLabel;
616  else
617  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
618  }
619  }
620 
621  /** \brief predict multiple labels with given features
622  *
623  * \param features: a n by featureCount matrix containing
624  * data point to be predicted (this only works in
625  * classification setting)
626  * \param labels: a n by 1 matrix passed by reference to store
627  * output.
628  * \param stop: an early stopping criterion.
629  */
// Row-wise label prediction with a caller-supplied early-stopping
// criterion.
// NOTE(review): unlike the default-stop overload above, this variant does
// not check for NaN in the input -- confirm whether intended.
630  template <class U, class C1, class T, class C2, class Stop>
// NOTE(review): the first signature line (original 631) is missing from
// this listing -- restore from upstream.
632  MultiArrayView<2, T, C2> & labels,
633  Stop & stop) const
634  {
635  vigra_precondition(features.shape(0) == labels.shape(0),
636  "RandomForest::predictLabels(): Label array has wrong size.");
637  for(int k=0; k<features.shape(0); ++k)
638  labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
639  }
640  /** \brief predict the class probabilities for multiple labels
641  *
642  * \param features same as above
643  * \param prob a n x class_count_ matrix. passed by reference to
644  * save class probabilities
645  * \param stop earlystopping criterion
646  * \sa EarlyStopping
647 
648  When a row of the feature array contains an NaN, the corresponding instance
649  cannot belong to any of the classes. The corresponding row in the probability
650  array will therefore contain all zeros.
651  */
// Class-probability prediction with early stopping; defined out of line.
652  template <class U, class C1, class T, class C2, class Stop>
653  void predictProbabilities(MultiArrayView<2, U, C1>const & features,
// NOTE(review): the 'prob' parameter line (original 654) is missing from
// this listing -- restore from upstream.
655  Stop & stop) const;
// Variant operating on an OnlinePredictionSet (online prediction support).
656  template <class T1,class T2, class C>
657  void predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
658  MultiArrayView<2, T2, C> & prob);
659 
660  /** \brief predict the class probabilities for multiple labels
661  *
662  * \param features same as above
663  * \param prob a n x class_count_ matrix. passed by reference to
664  * save class probabilities
665  */
// Convenience overload: probability prediction with the default stop.
666  template <class U, class C1, class T, class C2>
// NOTE(review): the first signature line (original 667,
// 'void predictProbabilities(MultiArrayView<2, U, C1>const & features,')
// is missing from this listing -- restore from upstream.
668  MultiArrayView<2, T, C2> & prob) const
669  {
670  predictProbabilities(features, prob, rf_default());
671  }
672 
// Raw (unnormalized) per-tree vote accumulation; defined out of line.
673  template <class U, class C1, class T, class C2>
674  void predictRaw(MultiArrayView<2, U, C1>const & features,
675  MultiArrayView<2, T, C2> & prob) const;
676 
677 
678  /*\}*/
679 
680 };
681 
682 
/* Out-of-line definition of RandomForest::onlineLearn().
 * Incrementally extends already-trained trees with the samples in rows
 * [new_start_index, end): each tree draws a Poisson resample of the new
 * data, routes each sample to its leaf, and re-splits impure leaves via
 * continueLearn(). Requires options_.prepare_online_learning_.
 */
683 template <class LabelType, class PreprocessorTag>
684 template<class U,class C1,
685  class U2, class C2,
686  class Split_t,
687  class Stop_t,
688  class Visitor_t,
689  class Random_t>
// NOTE(review): the first signature line (original 690,
// 'void RandomForest<...>::onlineLearn( MultiArrayView<2,U,C1> const &
// features,') is missing from this listing -- restore from upstream.
691  MultiArrayView<2,U2,C2> const & response,
692  int new_start_index,
693  Visitor_t visitor_,
694  Split_t split_,
695  Stop_t stop_,
696  Random_t & random,
697  bool adjust_thresholds)
698 {
699  online_visitor_.activate();
700  online_visitor_.adjust_thresholds=adjust_thresholds;
701 
702  using namespace rf;
703  //typedefs
// NOTE(review): the typedef lines for RandFunctor_t / Preprocessor_t
// (original 704-705) are missing from this listing.
706  RandFunctor_t;
707  // default values and initialization
708  // Value Chooser chooses second argument as value if first argument
709  // is of type RF_DEFAULT. (thanks to template magic - don't care about
710  // it - just smile and wave.
711 
712  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
713  Default_Stop_t default_stop(options_);
714  typename RF_CHOOSER(Stop_t)::type stop
715  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
716  Default_Split_t default_split;
717  typename RF_CHOOSER(Split_t)::type split
718  = RF_CHOOSER(Split_t)::choose(split_, default_split);
719  rf::visitors::StopVisiting stopvisiting;
// NOTE(review): the first lines of the visitor typedef (original 720-721)
// are missing from this listing.
722  typename RF_CHOOSER(Visitor_t)::type>
723  IntermedVis;
724  IntermedVis
725  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
726  #undef RF_CHOOSER
727  vigra_precondition(options_.prepare_online_learning_,"onlineLearn: online learning must be enabled on RandomForest construction");
728 
729  // Preprocess the data to get something the split functor can work
730  // with. Also fill the ext_param structure by preprocessing
731  // option parameters that could only be completely evaluated
732  // when the training data is known.
733  ext_param_.class_count_=0;
734  Preprocessor_t preprocessor( features, response,
735  options_, ext_param_);
736 
737  // Make stl compatible random functor.
738  RandFunctor_t randint ( random);
739 
740  // Give the Split functor information about the data.
741  split.set_external_parameters(ext_param_);
742  stop.set_external_parameters(ext_param_);
743 
744 
745  //Create poisson samples
746  PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
747 
748  //TODO: visitors for online learning
749  //visitor.visit_at_beginning(*this, preprocessor);
750 
751  // THE MAIN EFFING RF LOOP - YEAY DUDE!
752  for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
753  {
754  online_visitor_.tree_id=ii;
755  poisson_sampler.sample();
756  std::map<int,int> leaf_parents;
757  leaf_parents.clear();
758  //Get all the leaf nodes for that sample
759  for(int s=0;s<poisson_sampler.numOfSamples();++s)
760  {
761  int sample=poisson_sampler[s];
762  online_visitor_.current_label=preprocessor.response()(sample,0);
763  online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
764  int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
765 
766 
767  //Add to the list for that leaf
768  online_visitor_.add_to_index_list(ii,leaf,sample);
769  //TODO: Class count?
770  //Store parent
// Only impure leaves (probability of the sample's class != 1) need to be
// re-split; pure leaves already classify the new sample correctly.
771  if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
772  {
773  leaf_parents[leaf]=online_visitor_.last_node_id;
774  }
775  }
776 
777 
778  std::map<int,int>::iterator leaf_iterator;
779  for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
780  {
781  int leaf=leaf_iterator->first;
782  int parent=leaf_iterator->second;
783  int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
784  ArrayVector<Int32> indeces;
785  indeces.clear();
// swap() takes ownership of the leaf's accumulated sample indices and
// empties the visitor's list in one O(1) step.
786  indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
787  StackEntry_t stack_entry(indeces.begin(),
788  indeces.end(),
789  ext_param_.class_count_);
790 
791 
792  if(parent!=-1)
793  {
794  if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
795  {
796  stack_entry.leftParent=parent;
797  }
798  else
799  {
800  vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
801  stack_entry.rightParent=parent;
802  }
803  }
804  //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
805  trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
806  //Now, the last one moved onto leaf
807  online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
808  //Now it should be classified correctly!
809  }
810 
811  /*visitor
812  .visit_after_tree( *this,
813  preprocessor,
814  poisson_sampler,
815  stack_entry,
816  ii);*/
817  }
818 
819  //visitor.visit_at_end(*this, preprocessor);
820  online_visitor_.deactivate();
821 }
822 
/* Out-of-line definition of RandomForest::reLearnTree().
 * Resets tree 'treeId' and retrains it from a fresh stratified bootstrap
 * sample of the given data; the online-learning bookkeeping for that tree
 * is reset as well. Requires options_.prepare_online_learning_.
 */
823 template<class LabelType, class PreprocessorTag>
824 template<class U,class C1,
825  class U2, class C2,
826  class Split_t,
827  class Stop_t,
828  class Visitor_t,
829  class Random_t>
// NOTE(review): the first signature line (original 830,
// 'void RandomForest<...>::reLearnTree(MultiArrayView<2,U,C1> const &
// features,') is missing from this listing -- restore from upstream.
831  MultiArrayView<2,U2,C2> const & response,
832  int treeId,
833  Visitor_t visitor_,
834  Split_t split_,
835  Stop_t stop_,
836  Random_t & random)
837 {
838  using namespace rf;
839 
840 
// NOTE(review): the first line of the RandFunctor_t typedef (original 841)
// is missing from this listing.
842  RandFunctor_t;
843 
844  // See rf_preprocessing.hxx for more info on this
// NOTE(review): the Preprocessor_t typedef line (original 846) is missing
// from this listing.
845  ext_param_.class_count_=0;
847 
848  // default values and initialization
849  // Value Chooser chooses second argument as value if first argument
850  // is of type RF_DEFAULT. (thanks to template magic - don't care about
851  // it - just smile and wave.
852 
853  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
854  Default_Stop_t default_stop(options_);
855  typename RF_CHOOSER(Stop_t)::type stop
856  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
857  Default_Split_t default_split;
858  typename RF_CHOOSER(Split_t)::type split
859  = RF_CHOOSER(Split_t)::choose(split_, default_split);
860  rf::visitors::StopVisiting stopvisiting;
// NOTE(review): the first lines of the visitor typedef (original 861-862)
// are missing from this listing.
863  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
864  IntermedVis
865  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
866  #undef RF_CHOOSER
867  vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
868  online_visitor_.activate();
869 
870  // Make stl compatible random functor.
871  RandFunctor_t randint ( random);
872 
873  // Preprocess the data to get something the split functor can work
874  // with. Also fill the ext_param structure by preprocessing
875  // option parameters that could only be completely evaluated
876  // when the training data is known.
877  Preprocessor_t preprocessor( features, response,
878  options_, ext_param_);
879 
880  // Give the Split functor information about the data.
881  split.set_external_parameters(ext_param_);
882  stop.set_external_parameters(ext_param_);
883 
884  /**\todo replace this crappy class out. It uses function pointers.
885  * and is making code slower according to me.
886  * Comment from Nathan: This is copied from Rahul, so me=Rahul
887  */
888  Sampler<Random_t > sampler(preprocessor.strata().begin(),
889  preprocessor.strata().end(),
890  detail::make_sampler_opt(options_)
891  .sampleSize(ext_param().actual_msample_),
892  &random);
893  //initialize First region/node/stack entry
894  sampler
895  .sample();
896 
897  StackEntry_t
898  first_stack_entry( sampler.sampledIndices().begin(),
899  sampler.sampledIndices().end(),
900  ext_param_.class_count_);
901  first_stack_entry
902  .set_oob_range( sampler.oobIndices().begin(),
903  sampler.oobIndices().end());
// Clear the online bookkeeping for this tree, wipe the tree itself, then
// grow it anew from the fresh bootstrap sample.
904  online_visitor_.reset_tree(treeId);
905  online_visitor_.tree_id=treeId;
906  trees_[treeId].reset();
907  trees_[treeId]
908  .learn( preprocessor.features(),
909  preprocessor.response(),
910  first_stack_entry,
911  split,
912  stop,
913  visitor,
914  randint);
915  visitor
916  .visit_after_tree( *this,
917  preprocessor,
918  sampler,
919  first_stack_entry,
920  treeId);
921 
922  online_visitor_.deactivate();
923 }
924 
/* Out-of-line definition of the master RandomForest::learn() overload.
 * Preprocesses the training data, then grows options_.tree_count_ trees,
 * each on an independent (optionally stratified) bootstrap sample, calling
 * the visitor chain after every tree and once at the end.
 */
925 template <class LabelType, class PreprocessorTag>
926 template <class U, class C1,
927  class U2,class C2,
928  class Split_t,
929  class Stop_t,
930  class Visitor_t,
931  class Random_t>
// NOTE(review): the signature lines (original 932-933,
// 'void RandomForest<LabelType, PreprocessorTag>::learn(
//  MultiArrayView<2, U, C1> const & features,') are missing from this
// listing -- restore from upstream.
934  MultiArrayView<2, U2,C2> const & response,
935  Visitor_t visitor_,
936  Split_t split_,
937  Stop_t stop_,
938  Random_t const & random)
939 {
940  using namespace rf;
941  //this->reset();
942  //typedefs
// NOTE(review): the first line of the RandFunctor_t typedef (original 943)
// is missing from this listing.
944  RandFunctor_t;
945 
946  // See rf_preprocessing.hxx for more info on this
// NOTE(review): the Preprocessor_t typedef line (original 947) is missing
// from this listing.
948 
949  vigra_precondition(features.shape(0) == response.shape(0),
950  "RandomForest::learn(): shape mismatch between features and response.");
951 
952  // default values and initialization
953  // Value Chooser chooses second argument as value if first argument
954  // is of type RF_DEFAULT. (thanks to template magic - don't care about
955  // it - just smile and wave).
956 
957  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
958  Default_Stop_t default_stop(options_);
959  typename RF_CHOOSER(Stop_t)::type stop
960  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
961  Default_Split_t default_split;
962  typename RF_CHOOSER(Split_t)::type split
963  = RF_CHOOSER(Split_t)::choose(split_, default_split);
964  rf::visitors::StopVisiting stopvisiting;
// NOTE(review): the first lines of the visitor typedef (original 965-966)
// are missing from this listing.
967  typename RF_CHOOSER(Visitor_t)::type> IntermedVis;
968  IntermedVis
969  visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
970  #undef RF_CHOOSER
971  if(options_.prepare_online_learning_)
972  online_visitor_.activate();
973  else
974  online_visitor_.deactivate();
975 
976 
977  // Make stl compatible random functor.
978  RandFunctor_t randint ( random);
979 
980 
981  // Preprocess the data to get something the split functor can work
982  // with. Also fill the ext_param structure by preprocessing
983  // option parameters that could only be completely evaluated
984  // when the training data is known.
985  Preprocessor_t preprocessor( features, response,
986  options_, ext_param_);
987 
988  // Give the Split functor information about the data.
989  split.set_external_parameters(ext_param_);
990  stop.set_external_parameters(ext_param_);
991 
992 
993  //initialize trees.
994  trees_.resize(options_.tree_count_ , DecisionTree_t(ext_param_));
995 
996  Sampler<Random_t > sampler(preprocessor.strata().begin(),
997  preprocessor.strata().end(),
998  detail::make_sampler_opt(options_)
999  .sampleSize(ext_param().actual_msample_),
1000  &random);
1001 
1002  visitor.visit_at_beginning(*this, preprocessor);
1003  // THE MAIN EFFING RF LOOP - YEAY DUDE!
1004 
1005  for(int ii = 0; ii < static_cast<int>(trees_.size()); ++ii)
1006  {
1007  //initialize First region/node/stack entry
// Each tree gets its own bootstrap sample; the out-of-bag indices are
// recorded on the stack entry for visitors (e.g. OOB error).
1008  sampler
1009  .sample();
1010  StackEntry_t
1011  first_stack_entry( sampler.sampledIndices().begin(),
1012  sampler.sampledIndices().end(),
1013  ext_param_.class_count_);
1014  first_stack_entry
1015  .set_oob_range( sampler.oobIndices().begin(),
1016  sampler.oobIndices().end());
1017  trees_[ii]
1018  .learn( preprocessor.features(),
1019  preprocessor.response(),
1020  first_stack_entry,
1021  split,
1022  stop,
1023  visitor,
1024  randint);
1025  visitor
1026  .visit_after_tree( *this,
1027  preprocessor,
1028  sampler,
1029  first_stack_entry,
1030  ii);
1031  }
1032 
1033  visitor.visit_at_end(*this, preprocessor);
1034  // Only for online learning?
1035  online_visitor_.deactivate();
1036 }
1037 
1038 
1039 
1040 
// Predict the class label of a single sample: compute the per-class
// probabilities with the given early-stopping strategy, then return the
// label corresponding to the most probable class.
1041 template <class LabelType, class Tag>
1042 template <class U, class C, class Stop>
// NOTE(review): the return-type line (original line 1043, presumably
// "LabelType RandomForest<LabelType, Tag>") is missing from this extraction.
1044  ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
1045 {
1046  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1047  "RandomForestn::predictLabel():"
1048  " Too few columns in feature matrix.");
// This overload handles exactly one sample (one feature row).
1049  vigra_precondition(rowCount(features) == 1,
1050  "RandomForestn::predictLabel():"
1051  " Feature matrix must have a singlerow.");
1052  MultiArray<2, double> probabilities(Shape2(1, ext_param_.class_count_), 0.0);
1053  LabelType d;
1054  predictProbabilities(features, probabilities, stop);
// argMax yields the index of the most probable class; to_classlabel
// converts that index to the user-visible label type.
1055  ext_param_.to_classlabel(argMax(probabilities), d);
1056  return d;
1057 }
1058 
1059 
1060 // Same as predictLabel() above, but the predicted class probabilities are
1060 // multiplied by user-supplied per-class priors before taking the argmax.
1061 template <class LabelType, class PreprocessorTag>
1062 template <class U, class C>
// NOTE(review): the signature lines (original lines 1063-1064) are missing
// from this extraction; only the trailing `priors` parameter is visible.
1065  ArrayVectorView<double> priors) const
1066 {
1067  using namespace functor;
1068  vigra_precondition(columnCount(features) >= ext_param_.column_count_,
1069  "RandomForestn::predictLabel(): Too few columns in feature matrix.");
1070  vigra_precondition(rowCount(features) == 1,
1071  "RandomForestn::predictLabel():"
1072  " Feature matrix must have a single row.");
1073  Matrix<double> prob(1,ext_param_.class_count_);
1074  predictProbabilities(features, prob);
// Element-wise in-place scaling: prob[l] *= priors[l] for every class l.
// (Arg1()*Arg2() is a vigra functor-expression binary operator.)
1075  std::transform( prob.begin(), prob.end(),
1076  priors.begin(), prob.begin(),
1077  Arg1()*Arg2());
1078  LabelType d;
1079  ext_param_.to_classlabel(argMax(prob), d);
1080  return d;
1081 }
1082 
// Predict class probabilities for every sample in an OnlinePredictionSet.
// Instead of routing each sample through each tree individually, whole
// *ranges* of samples (carrying per-feature min/max bounds) are pushed down
// the tree together; a range is only split when a threshold node actually
// separates its samples. Only i_ThresholdNode trees are supported.
1083 template<class LabelType,class PreprocessorTag>
1084 template <class T1,class T2, class C>
// NOTE(review): the return-type line (original line 1085) is missing from
// this extraction.
1086  ::predictProbabilities(OnlinePredictionSet<T1> & predictionSet,
1087  MultiArrayView<2, T2, C> & prob)
1088 {
1089  // Features are n x p
1090  // prob is n x NumOfLabel probability for each feature in each class
1091 
1092  vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
1093  "RandomFroest::predictProbabilities():"
1094  " Feature matrix and probability matrix size mismatch.");
1095  // num of features must be bigger than num of features in Random forest training
1096  // but why bigger?
1097  vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
1098  "RandomForestn::predictProbabilities():"
1099  " Too few columns in feature matrix.");
1100  vigra_precondition( columnCount(prob)
1101  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1102  "RandomForestn::predictProbabilities():"
1103  " Probability matrix must have as many columns as there are classes.");
1104  prob.init(0.0);
1105  //store total weights
1106  std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
1107  //Go through all trees
// Trees cycle round-robin through the available index sets of the
// prediction set (set_id wraps modulo the number of sets).
1108  int set_id=-1;
1109  for(int k=0; k<options_.tree_count_; ++k)
1110  {
1111  set_id=(set_id+1) % predictionSet.indices[0].size();
1112  typedef std::set<SampleRange<T1> > my_set;
1113  typedef typename my_set::iterator set_it;
1114  //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
1115  //Build a stack with all the ranges we have
// Each stack entry pairs a tree-node index with a sample range; index 2 is
// the root node's position in the topology array here.
1116  std::vector<std::pair<int,set_it> > stack;
1117  stack.clear();
1118  for(set_it i=predictionSet.ranges[set_id].begin();
1119  i!=predictionSet.ranges[set_id].end();++i)
1120  stack.push_back(std::pair<int,set_it>(2,i));
1121  //get weights predicted by single tree
// num_decisions counts node visits for this tree (stored as a crude
// per-tree prediction-cost measure below).
1122  int num_decisions=0;
1123  while(!stack.empty())
1124  {
1125  set_it range=stack.back().second;
1126  int index=stack.back().first;
1127  stack.pop_back();
1128  ++num_decisions;
1129 
// Leaf: add the leaf's per-class weights to every sample in the range.
1130  if(trees_[k].isLeafNode(trees_[k].topology_[index]))
1131  {
1132  ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
1133  trees_[k].parameters_,
1134  index).prob_begin();
1135  for(int i=range->start;i!=range->end;++i)
1136  {
1137  //update votecount.
1138  for(int l=0; l<ext_param_.class_count_; ++l)
1139  {
1140  prob(predictionSet.indices[set_id][i], l) += static_cast<T2>(weights[l]);
1141  //every weight in totalWeight.
1142  totalWeights[predictionSet.indices[set_id][i]] += static_cast<T1>(weights[l]);
1143  }
1144  }
1145  }
1146 
1147  else
1148  {
1149  if(trees_[k].topology_[index]!=i_ThresholdNode)
1150  {
1151  throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
1152  }
1153  Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
// If the range's bounding box lies entirely on one side of the split
// threshold, the whole range follows that child without splitting.
1154  if(range->min_boundaries[node.column()]>=node.threshold())
1155  {
1156  //Everything goes to right child
1157  stack.push_back(std::pair<int,set_it>(node.child(1),range));
1158  continue;
1159  }
1160  if(range->max_boundaries[node.column()]<node.threshold())
1161  {
1162  //Everything goes to the left child
1163  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1164  continue;
1165  }
1166  //We have to split at this node
// new_range will hold the samples >= threshold (right child); the split
// column's bounds are re-tightened as samples are partitioned below.
1167  SampleRange<T1> new_range=*range;
1168  new_range.min_boundaries[node.column()]=FLT_MAX;
1169  range->max_boundaries[node.column()]=-FLT_MAX;
1170  new_range.start=new_range.end=range->end;
// In-place partition of the index range: samples >= threshold are swapped
// to the back (becoming new_range), the rest stay in front.
1171  int i=range->start;
1172  while(i!=range->end)
1173  {
1174  //Decide for range->indices[i]
1175  if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
1176  {
1177  new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
1178  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1179  --range->end;
1180  --new_range.start;
1181  std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);
1182 
1183  }
1184  else
1185  {
1186  range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
1187  predictionSet.features(predictionSet.indices[set_id][i],node.column()));
1188  ++i;
1189  }
1190  }
1191  //The old one ...
// An emptied left part is removed from the range set; otherwise it
// continues down the left child.
1192  if(range->start==range->end)
1193  {
1194  predictionSet.ranges[set_id].erase(range);
1195  }
1196  else
1197  {
1198  stack.push_back(std::pair<int,set_it>(node.child(0),range));
1199  }
1200  //And the new one ...
1201  if(new_range.start!=new_range.end)
1202  {
1203  std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
1204  stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
1205  }
1206  }
1207  }
1208  predictionSet.cumulativePredTime[k]=num_decisions;
1209  }
1210  for(unsigned int i=0;i<totalWeights.size();++i)
1211  {
1212  double test=0.0;
1213  // Normalise votes in each row by the total vote count (totalWeights[i]).
1214  for(int l=0; l<ext_param_.class_count_; ++l)
1215  {
1216  test+=prob(i,l);
1217  prob(i, l) /= totalWeights[i];
1218  }
// NOTE(review): exact floating-point equality — this only holds because
// both sums accumulate the same terms; fragile if summation order changes.
1219  assert(test==totalWeights[i]);
1220  assert(totalWeights[i]>0.0);
1221  }
1222 }
1223 
// Predict per-class probabilities for every row of `features`, accumulating
// the (possibly weighted) votes of all trees and normalising per row. The
// Stop_t functor may terminate voting early after any tree's prediction.
1224 template <class LabelType, class PreprocessorTag>
1225 template <class U, class C1, class T, class C2, class Stop_t>
// NOTE(review): the signature lines (original lines 1226-1227) are missing
// from this extraction.
1228  MultiArrayView<2, T, C2> & prob,
1229  Stop_t & stop_) const
1230 {
1231  // Features are n x p
1232  // prob is n x NumOfLabel probability for each feature in each class
1233 
1234  vigra_precondition(rowCount(features) == rowCount(prob),
1235  "RandomForestn::predictProbabilities():"
1236  " Feature matrix and probability matrix size mismatch.");
1237 
1238  // num of features must be bigger than num of features in Random forest training
1239  // but why bigger?
1240  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1241  "RandomForestn::predictProbabilities():"
1242  " Too few columns in feature matrix.");
1243  vigra_precondition( columnCount(prob)
1244  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1245  "RandomForestn::predictProbabilities():"
1246  " Probability matrix must have as many columns as there are classes.");
1247 
// Substitute the default early-stopping strategy when the caller passed
// rf_default() (same RF_CHOOSER mechanism as in learn()).
1248  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1249  Default_Stop_t default_stop(options_);
1250  typename RF_CHOOSER(Stop_t)::type & stop
1251  = RF_CHOOSER(Stop_t)::choose(stop_, default_stop);
1252  #undef RF_CHOOSER
1253  stop.set_external_parameters(ext_param_, tree_count());
1254  prob.init(NumericTraits<T>::zero());
1255  /* This code was originally there for testing early stopping
1256  * - we wanted the order of the trees to be randomized
1257  if(tree_indices_.size() != 0)
1258  {
1259  std::random_shuffle(tree_indices_.begin(),
1260  tree_indices_.end());
1261  }
1262  */
1263  //Classify for each row.
1264  for(int row=0; row < rowCount(features); ++row)
1265  {
1266  MultiArrayView<2, U, StridedArrayTag> currentRow(rowVector(features, row));
1267 
1268  // when the features contain an NaN, the instance doesn't belong to any class
1269  // => indicate this by returning a zero probability array.
1270  if(detail::contains_nan(currentRow))
1271  {
1272  rowVector(prob, row).init(0.0);
1273  continue;
1274  }
1275 
// NOTE(review): the declaration of `weights` (original line 1276,
// presumably an ArrayVector<double>::const_iterator) is missing from this
// extraction; it is assigned from trees_[k].predict() below.
1277 
1278  //totalWeight == totalVoteCount!
1279  double totalWeight = 0.0;
1280 
1281  //Let each tree classify...
1282  for(int k=0; k<options_.tree_count_; ++k)
1283  {
1284  //get weights predicted by single tree
1285  weights = trees_[k /*tree_indices_[k]*/].predict(currentRow);
1286 
1287  //update votecount.
// When predict_weighted_ is set, each class weight is scaled by
// *(weights-1) — presumably the leaf's total weight stored immediately
// before the probability array; otherwise the raw weight is used.
// TODO(review): confirm the leaf parameter layout.
1288  int weighted = options_.predict_weighted_;
1289  for(int l=0; l<ext_param_.class_count_; ++l)
1290  {
1291  double cur_w = weights[l] * (weighted * (*(weights-1))
1292  + (1-weighted));
1293  prob(row, l) += static_cast<T>(cur_w);
1294  //every weight in totalWeight.
1295  totalWeight += cur_w;
1296  }
// The early-stopping functor may declare the vote decided and skip the
// remaining trees for this row.
1297  if(stop.after_prediction(weights,
1298  k,
1299  rowVector(prob, row),
1300  totalWeight))
1301  {
1302  break;
1303  }
1304  }
1305 
1306  // Normalise votes in each row by the total vote count (totalWeight).
1307  for(int l=0; l< ext_param_.class_count_; ++l)
1308  {
1309  prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
1310  }
1311  }
1312 
1313 }
1314 
// Accumulate the raw (per-tree) class votes for every row of `features`
// without per-row weight normalisation; the result is averaged over the
// number of trees at the end. No NaN handling and no early stopping here,
// unlike predictProbabilities() above.
1315 template <class LabelType, class PreprocessorTag>
1316 template <class U, class C1, class T, class C2>
// NOTE(review): the return-type line (original line 1317) is missing from
// this extraction.
1318  ::predictRaw(MultiArrayView<2, U, C1>const & features,
1319  MultiArrayView<2, T, C2> & prob) const
1320 {
1321  // Features are n x p
1322  // prob is n x NumOfLabel probability for each feature in each class
1323 
1324  vigra_precondition(rowCount(features) == rowCount(prob),
1325  "RandomForestn::predictProbabilities():"
1326  " Feature matrix and probability matrix size mismatch.");
1327 
1328  // num of features must be bigger than num of features in Random forest training
1329  // but why bigger?
1330  vigra_precondition( columnCount(features) >= ext_param_.column_count_,
1331  "RandomForestn::predictProbabilities():"
1332  " Too few columns in feature matrix.");
1333  vigra_precondition( columnCount(prob)
1334  == static_cast<MultiArrayIndex>(ext_param_.class_count_),
1335  "RandomForestn::predictProbabilities():"
1336  " Probability matrix must have as many columns as there are classes.");
1337 
// NOTE(review): RF_CHOOSER is defined here but never used and never
// #undef'd in this function — the macro leaks into subsequent code.
1338  #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_>
1339  prob.init(NumericTraits<T>::zero());
1340  /* This code was originally there for testing early stopping
1341  * - we wanted the order of the trees to be randomized
1342  if(tree_indices_.size() != 0)
1343  {
1344  std::random_shuffle(tree_indices_.begin(),
1345  tree_indices_.end());
1346  }
1347  */
1348  //Classify for each row.
1349  for(int row=0; row < rowCount(features); ++row)
1350  {
// NOTE(review): the declaration of `weights` (original line 1351) is
// missing from this extraction; it is assigned from trees_[k].predict().
1352 
1353  //totalWeight == totalVoteCount!
// NOTE(review): totalWeight is accumulated but never used in this
// function — the final normalisation divides by tree_count_ instead.
1354  double totalWeight = 0.0;
1355 
1356  //Let each tree classify...
1357  for(int k=0; k<options_.tree_count_; ++k)
1358  {
1359  //get weights predicted by single tree
1360  weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
1361 
1362  //update votecount.
// Same optional leaf-weight scaling as in predictProbabilities():
// *(weights-1) presumably holds the leaf's total weight — TODO confirm.
1363  int weighted = options_.predict_weighted_;
1364  for(int l=0; l<ext_param_.class_count_; ++l)
1365  {
1366  double cur_w = weights[l] * (weighted * (*(weights-1))
1367  + (1-weighted));
1368  prob(row, l) += static_cast<T>(cur_w);
1369  //every weight in totalWeight.
1370  totalWeight += cur_w;
1371  }
1372  }
1373  }
// Average the accumulated votes over all trees.
1374  prob/= options_.tree_count_;
1375 
1376 }
1377 
1378 //@}
1379 
1380 } // namespace vigra
1381 
1382 #include "random_forest/rf_algorithm.hxx"
1383 #endif // VIGRA_RANDOM_FOREST_HXX
int feature_count() const
return number of features used while training.
Definition: random_forest.hxx:320
Definition: rf_region.hxx:57
void set_ext_param(ProblemSpec_t const &in)
set external parameters
Definition: random_forest.hxx:274
Definition: rf_nodeproxy.hxx:626
detail::RF_DEFAULT & rf_default()
factory function to return a RF_DEFAULT tag
Definition: rf_common.hxx:131
MultiArrayIndex rowCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:669
Definition: rf_preprocessing.hxx:63
int class_count() const
return number of classes used while training.
Definition: random_forest.hxx:339
int column_count() const
return number of features used while training.
Definition: random_forest.hxx:331
Create random samples from a sequence of indices.
Definition: sampling.hxx:233
Int32 leftParent
Definition: rf_region.hxx:69
Definition: rf_split.hxx:993
Definition: matrix.hxx:121
problem specification class for the random forest.
Definition: rf_common.hxx:533
RandomForest(Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t())
default constructor
Definition: random_forest.hxx:192
void sample()
Definition: sampling.hxx:468
INT & child(Int32 l)
Definition: rf_nodeproxy.hxx:224
LabelType predictLabel(MultiArrayView< 2, U, C >const &features, Stop &stop) const
predict a label given a feature.
Definition: random_forest.hxx:1044
ProblemSpec_t const & ext_param() const
return external parameters for viewing
Definition: random_forest.hxx:256
Definition: accessor.hxx:43
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const
predict multiple labels with given features
Definition: random_forest.hxx:582
Standard early stopping criterion.
Definition: rf_common.hxx:880
void reset_tree(int tree_id)
Definition: rf_visitors.hxx:625
Definition: random.hxx:669
DecisionTree_t & tree(int index)
access trees
Definition: random_forest.hxx:310
Definition: rf_nodeproxy.hxx:87
Options_t & set_options()
access random forest options
Definition: random_forest.hxx:286
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random)
learn on data with custom config and random number generator
Definition: random_forest.hxx:933
void reLearnTree(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random)
Definition: random_forest.hxx:830
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const
predict the class probabilities for multiple labels
Definition: random_forest.hxx:667
Definition: random_forest.hxx:144
const difference_type & shape() const
Definition: multi_array.hxx:1596
Definition: rf_visitors.hxx:244
detail::SelectIntegerType< 32, detail::SignedIntTypes >::type Int32
32-bit signed int
Definition: sized_int.hxx:175
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, LabelType nanLabel) const
predict multiple labels with given features
Definition: random_forest.hxx:606
Iterator argMax(Iterator first, Iterator last)
Find the maximum element in a sequence.
Definition: algorithm.hxx:96
Definition: rf_visitors.hxx:573
MultiArrayShape< 2 >::type Shape2
shape type for MultiArray<2, T>
Definition: multi_shape.hxx:254
SamplerOptions & stratified(bool in=true)
Draw equally many samples from each "stratum". A stratum is a group of like entities, e.g. pixels belonging to the same object class. This is useful to create balanced samples when the class probabilities are very unbalanced (e.g. when there are many background and few foreground pixels). Stratified sampling thus avoids that a trained classifier is biased towards the majority class.
Definition: sampling.hxx:144
SamplerOptions & withReplacement(bool in=true)
Sample from training population with replacement.
Definition: sampling.hxx:86
MultiArrayView< 2, T, C > rowVector(MultiArrayView< 2, T, C > const &m, MultiArrayIndex d)
Definition: matrix.hxx:695
const_iterator begin() const
Definition: array_vector.hxx:223
MultiArrayIndex columnCount(const MultiArrayView< 2, T, C > &x)
Definition: matrix.hxx:682
RandomForest(int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t())
Create RF from external source.
Definition: random_forest.hxx:226
int tree_count() const
return number of trees
Definition: random_forest.hxx:346
Definition: random.hxx:336
Base class for, and view to, vigra::MultiArray.
Definition: multi_array.hxx:652
Options object for the random forest.
Definition: rf_common.hxx:170
MultiArrayView & init(const U &init)
Definition: multi_array.hxx:1154
size_type size() const
Definition: array_vector.hxx:358
void predictProbabilities(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const
predict the class probabilities for multiple labels
DecisionTree_t const & tree(int index) const
access const trees
Definition: random_forest.hxx:303
const_iterator end() const
Definition: array_vector.hxx:237
void learn(MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels)
learn on data with default configuration
Definition: random_forest.hxx:526
Definition: rf_visitors.hxx:224
void predictLabels(MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const
predict multiple labels with given features
Definition: random_forest.hxx:631
Options_t const & options() const
access const random forest options
Definition: random_forest.hxx:296

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.11.0