#ifndef VIGRA_RF_COMMON_HXX
#define VIGRA_RF_COMMON_HXX

// ...

namespace vigra
{

struct ClassificationTag {};

// ...

namespace detail
{

class RF_DEFAULT
{
    // ...
    friend RF_DEFAULT & ::vigra::rf_default();
};

} // namespace detail
/** Chooses between a user-supplied value and a default: the primary template
    hands back the caller's value, the specialisation for the rf_default() tag
    hands back the supplied fallback. */
template<class T, class C>
class Value_Chooser
{
  public:
    static T & choose(T & t, C &)
    {
        return t;
    }
};

template<class C>
class Value_Chooser<detail::RF_DEFAULT, C>
{
  public:
    static C & choose(detail::RF_DEFAULT &, C & c)
    {
        return c;
    }
};

/** factory function to return a RF_DEFAULT tag */
inline detail::RF_DEFAULT & rf_default()
{
    static detail::RF_DEFAULT result;
    return result;
}

// ... (enum RF_OptionTag, with tags such as RF_EQUAL, RF_PROPORTIONAL,
//      RF_EXTERNAL, RF_NONE, RF_LOG, RF_SQRT, RF_CONST and RF_FUNCTION,
//      is defined here)
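// --- usage sketch (not part of the header) -----------------------------------
// A minimal illustration of the dispatch above, assuming choose() is publicly
// accessible as shown: the caller's value wins unless the caller passed the
// rf_default() tag, in which case the fallback is selected. Names below are
// illustrative only.
inline void value_chooser_sketch()
{
    int user_value = 5;
    int fallback   = 42;

    // T is an ordinary type -> the user's value is chosen
    int & chosen  = Value_Chooser<int, int>::choose(user_value, fallback);

    // T is the default tag -> the fallback is chosen
    int & chosen2 = Value_Chooser<detail::RF_DEFAULT, int>::choose(rf_default(), fallback);

    (void)chosen; (void)chosen2;   // silence unused-variable warnings
}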
/** Options object for the random forest. */
class RandomForestOptions
{
  public:
    double        training_set_proportion_;
    int           training_set_size_;
    int         (*training_set_func_)(int);
    RF_OptionTag  training_set_calc_switch_;

    bool          sample_with_replacement_;
    RF_OptionTag  stratification_method_;

    // ...
    RF_OptionTag  mtry_switch_;
    int           mtry_;
    int         (*mtry_func_)(int);

    int           tree_count_;
    bool          predict_weighted_;
    int           min_split_node_size_;
    bool          prepare_online_learning_;

    // ...
    typedef ArrayVector<double>                  double_array;
    typedef std::map<std::string, double_array>  map_type;
    int serialized_size() const
    {
        // ... (number of double values written by serialize())
    }

    bool operator==(RandomForestOptions & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(training_set_proportion_);
        COMPARE(training_set_size_);
        COMPARE(training_set_calc_switch_);
        COMPARE(sample_with_replacement_);
        COMPARE(stratification_method_);
        COMPARE(mtry_switch_);
        // ...
        COMPARE(tree_count_);
        COMPARE(min_split_node_size_);
        COMPARE(predict_weighted_);
        #undef COMPARE
        return result;
    }

    bool operator!=(RandomForestOptions & rhs_) const
    {
        return !(*this == rhs_);
    }
    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::unserialize():"
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        // ...
        PULL(sample_with_replacement_, 0 !=);
        // ...
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULL(predict_weighted_, 0 !=);
        #undef PULL
    }
    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::serialize():"
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(training_set_proportion_);
        PUSH(training_set_size_);
        if(training_set_func_ != 0)
        {
            // ...
        }
        // ...
        PUSH(training_set_calc_switch_);
        PUSH(sample_with_replacement_);
        PUSH(stratification_method_);
        // ...
        PUSH(min_split_node_size_);
        PUSH(predict_weighted_);
        #undef PUSH
    }
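    // --- usage sketch (not part of the header) -------------------------------
    // serialize()/unserialize() flatten the options into a buffer of doubles
    // whose length must equal serialized_size(). Illustration only;
    // std::vector stands in for any random access container of double:
    //
    //     RandomForestOptions opt;
    //     std::vector<double> buffer(opt.serialized_size());
    //     opt.serialize(buffer.begin(), buffer.end());
    //
    //     RandomForestOptions restored;
    //     restored.unserialize(buffer.begin(), buffer.end());
    //     // restored now compares equal to opt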
    void make_from_map(map_type & in)
    {
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        // ...
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULLBOOL(sample_with_replacement_, bool);
        PULLBOOL(prepare_online_learning_, bool);
        PULLBOOL(predict_weighted_, bool);
        // ...
        #undef PULL
        #undef PULLBOOL
    }
    void make_map(map_type & in) const
    {
        #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
        #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_ != 0));
        PUSH(training_set_proportion_, double);
        PUSH(training_set_size_, int);
        // ...
        PUSH(tree_count_, int);
        PUSH(min_split_node_size_, int);
        PUSH(sample_with_replacement_, bool);
        PUSH(prepare_online_learning_, bool);
        PUSH(predict_weighted_, bool);
        // ...
        PUSHFUNC(mtry_func_, int);
        PUSHFUNC(training_set_func_, int);
        #undef PUSH
        #undef PUSHFUNC
    }
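    // --- usage sketch (not part of the header) -------------------------------
    // make_map()/make_from_map() exchange the options with a map keyed by the
    // member names, each entry being a one-element double_array. Illustration
    // only, assuming map_type is publicly accessible as declared above:
    //
    //     RandomForestOptions opt;
    //     RandomForestOptions::map_type fields;
    //     opt.make_map(fields);
    //     double n_trees = fields["tree_count_"][0];
    //
    //     fields["min_split_node_size_"][0] = 5;
    //     opt.make_from_map(fields);          // copy the edited value back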
    /** create a RandomForestOptions object with default initialisation. */
    RandomForestOptions()
    :
        training_set_proportion_(1.0),
        training_set_size_(0),
        training_set_func_(0),
        training_set_calc_switch_(RF_PROPORTIONAL),
        sample_with_replacement_(true),
        stratification_method_(RF_NONE),
        mtry_switch_(RF_SQRT),
        // ...
        predict_weighted_(false),
        // ...
        min_split_node_size_(1),
        prepare_online_learning_(false)
    {}
    /** specify stratification strategy */
    RandomForestOptions & use_stratification(RF_OptionTag in)
    {
        vigra_precondition(in == RF_EQUAL ||
                           in == RF_PROPORTIONAL ||
                           in == RF_EXTERNAL ||
                           in == RF_NONE,
                           "RandomForestOptions::use_stratification()"
                           "input must be RF_EQUAL, RF_PROPORTIONAL,"
                           "RF_EXTERNAL or RF_NONE");
        stratification_method_ = in;
        return *this;
    }

    RandomForestOptions & prepare_online_learning(bool in)
    { prepare_online_learning_ = in; return *this; }

    /** sample from training population with or without replacement? */
    RandomForestOptions & sample_with_replacement(bool in)
    { sample_with_replacement_ = in; return *this; }

    /** specify the fraction of the total number of samples used per tree for learning. */
    RandomForestOptions & samples_per_tree(double in)
    { training_set_proportion_ = in; training_set_calc_switch_ = RF_PROPORTIONAL; return *this; }

    /** directly specify the number of samples per tree */
    RandomForestOptions & samples_per_tree(int in)
    { training_set_size_ = in; training_set_calc_switch_ = RF_CONST; return *this; }

    /** use an external function to calculate the number of samples each tree should be learnt with */
    RandomForestOptions & samples_per_tree(int (*in)(int))
    { training_set_func_ = in; training_set_calc_switch_ = RF_FUNCTION; return *this; }

    /** weight each tree with the number of samples in that node */
    RandomForestOptions & predict_weighted()
    { predict_weighted_ = true; return *this; }

    /** use the built-in mapping to calculate mtry */
    RandomForestOptions & features_per_node(RF_OptionTag in)
    {
        vigra_precondition(in == RF_LOG ||
                           in == RF_SQRT,
                           "RandomForestOptions()::features_per_node():"
                           "input must be of type RF_LOG or RF_SQRT");
        mtry_switch_ = in;
        return *this;
    }

    /** set mtry to a constant value */
    RandomForestOptions & features_per_node(int in)
    { mtry_ = in; mtry_switch_ = RF_CONST; return *this; }

    /** use an external function to calculate mtry */
    RandomForestOptions & features_per_node(int (*in)(int))
    { mtry_func_ = in; mtry_switch_ = RF_FUNCTION; return *this; }

    /** number of trees in the forest */
    RandomForestOptions & tree_count(int in)
    { tree_count_ = in; return *this; }

    /** number of examples required for a node to be split */
    RandomForestOptions & min_split_node_size(int in)
    { min_split_node_size_ = in; return *this; }

    // ...
};
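// --- usage sketch (not part of the header) -----------------------------------
// The setters above return *this, so an options object is typically built by
// chaining; the values below are arbitrary examples.
inline RandomForestOptions options_sketch()
{
    return RandomForestOptions()
               .tree_count(255)
               .features_per_node(RF_SQRT)
               .samples_per_tree(0.8)
               .sample_with_replacement(true)
               .min_split_node_size(1);
}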
// ... (enum Problem_t with values such as CHECKLATER describes the problem type)

/** problem specification class for the random forest. */
template<class LabelType = double>
class ProblemSpec
{
  public:
    typedef LabelType Label_t;
    // ...
    typedef std::map<std::string, double_array> map_type;

    // ... (data members such as column_count_, class_count_, row_count_,
    //      actual_mtry_, actual_msample_, problem_type_, is_weighted_,
    //      precision_, response_size_, class_weights_ and classes live here)

    template<class T>
    void to_classlabel(int index, T & out) const
    {
        out = T(classes[index]);
    }

    template<class T>
    int to_classIndex(T index) const
    {
        return std::find(classes.begin(), classes.end(), index) - classes.begin();
    }
    #define EQUALS(field) field(rhs.field)
    /** copy constructor */
    ProblemSpec(ProblemSpec const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        // ...
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        // ...
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        // ...
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> > iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS
    #define EQUALS(field) field(rhs.field)
    /** copy constructor from a ProblemSpec with a different label type */
    template<class T>
    ProblemSpec(ProblemSpec<T> const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        // ...
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        // ...
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        // ...
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> > iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS
    #define EQUALS(field) (this->field = rhs.field);
    ProblemSpec & operator=(ProblemSpec const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        // ...
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(is_weighted_);
        // ...
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> > iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> > iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }
    template<class T>
    ProblemSpec & operator=(ProblemSpec<T> const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        // ...
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(is_weighted_);
        // ...
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> > iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> > iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }
    #undef EQUALS
    bool operator==(ProblemSpec const & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(column_count_);
        COMPARE(class_count_);
        // ...
        COMPARE(actual_mtry_);
        COMPARE(actual_msample_);
        COMPARE(problem_type_);
        COMPARE(is_weighted_);
        // ...
        COMPARE(class_weights_);
        // ...
        COMPARE(response_size_)
        #undef COMPARE
        return result;
    }

    bool operator!=(ProblemSpec const & rhs) const
    {
        return !(*this == rhs);
    }
    size_t serialized_size() const
    {
        return 10 + class_count_ * int(is_weighted_ + 1);
    }

    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(end - begin >= 10,
                           "ProblemSpec::unserialize():"
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(column_count_, int);
        PULL(class_count_, int);
        // ...
        vigra_precondition(end - begin >= 10 + class_count_,
                           "ProblemSpec::unserialize(): 1");
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        // ...
        PULL(is_weighted_, int);
        // ...
        PULL(precision_, double);
        PULL(response_size_, int);
        if(is_weighted_)
        {
            vigra_precondition(end - begin == 10 + 2*class_count_,
                               "ProblemSpec::unserialize(): 2");
            class_weights_.insert(class_weights_.end(),
                                  iter,
                                  iter + class_count_);
            iter += class_count_;
        }
        classes.insert(classes.end(), iter, end);
        #undef PULL
    }
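    // --- buffer layout sketch (not part of the header) -----------------------
    // unserialize() above expects the flat buffer produced by serialize():
    //
    //     [ 10 scalar fields | class_count_ weights (only if weighted) | class labels ]
    //
    // which is exactly the 10 + class_count_ * (is_weighted_ + 1) doubles
    // reported by serialized_size().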
    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(end - begin == serialized_size(),
                           "ProblemSpec::serialize():"
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        // ...
        PUSH(actual_msample_);
        // ...
        PUSH(response_size_);
        if(is_weighted_)
        {
            std::copy(class_weights_.begin(),
                      class_weights_.end(),
                      iter);
            iter += class_count_;
        }
        std::copy(classes.begin(),
                  classes.end(),
                  iter);
        #undef PUSH
    }
    void make_from_map(map_type & in)
    {
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        PULL(column_count_, int);
        PULL(class_count_, int);
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        // ...
        PULL(is_weighted_, int);
        // ...
        PULL(precision_, double);
        PULL(response_size_, int);
        class_weights_ = in["class_weights_"];
        #undef PULL
    }
    void make_map(map_type & in) const
    {
        #define PUSH(item_) in[#item_] = double_array(1, double(item_));
        // ...
        PUSH(actual_msample_);
        // ...
        PUSH(response_size_);
        in["class_weights_"] = class_weights_;
        #undef PUSH
    }
    /** set default values (-> values not set) */
    ProblemSpec()
    :
        // ...
        problem_type_(CHECKLATER)
        // ...
    {}

    ProblemSpec & column_count(int in)
    {
        column_count_ = in;
        return *this;
    }

    /** supply with class labels */
    template<class C_Iter>
    ProblemSpec & classes_(C_Iter begin, C_Iter end)
    {
        int size = end - begin;
        for(int k = 0; k < size; ++k, ++begin)
            classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
        // ...
        return *this;
    }

    /** supply with class weights */
    template<class W_Iter>
    ProblemSpec & class_weights(W_Iter begin, W_Iter end)
    {
        class_weights_.clear();
        class_weights_.insert(class_weights_.end(), begin, end);
        // ...
        return *this;
    }

    void clear()
    {
        class_weights_.clear();
        // ...
        problem_type_ = CHECKLATER;
        is_weighted_ = false;
        // ...
    }
    // ...
};
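// --- usage sketch (not part of the header) -----------------------------------
// Building a problem specification and translating between class labels and
// class indices; the labels and weights are arbitrary examples.
inline void problem_spec_sketch()
{
    int    labels[]  = { 3, 8 };
    double weights[] = { 0.25, 0.75 };

    ProblemSpec<int> spec;
    spec.classes_(labels, labels + 2)
        .class_weights(weights, weights + 2)
        .column_count(10);

    int index = spec.to_classIndex(8);   // -> 1, position of label 8 in 'classes'
    int label = 0;
    spec.to_classlabel(index, label);    // -> 8
    (void)label;
}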
/** Standard early stopping criterion: a node is not split further once it
    holds fewer samples than min_split_node_size_. */
class EarlyStoppStd
{
  public:
    int min_split_node_size_;
    // ...

    template<class Opt>
    EarlyStoppStd(Opt opt)
    :   min_split_node_size_(opt.min_split_node_size_)
    {}

    template<class T>
    void set_external_parameters(ProblemSpec<T> const &, int = 0, bool = false)
    {}

    template<class Region>
    bool operator()(Region & region)
    {
        return region.size() < min_split_node_size_;
    }

    template<class WeightIter, class T, class C>
    // ...
};
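// --- usage sketch (not part of the header) -----------------------------------
// Illustration of the stopping predicate above, assuming the option fields and
// min_split_node_size_ are publicly accessible as shown in this excerpt.
// MiniRegion is a hypothetical stand-in for the library's node/region type.
struct MiniRegion
{
    int n;
    int size() const { return n; }
};

inline void early_stopping_sketch()
{
    RandomForestOptions opt;
    EarlyStoppStd stop(opt.min_split_node_size(4));

    MiniRegion small_node = { 3 };
    MiniRegion large_node = { 50 };

    bool stop_small = stop(small_node);   // true:  3 < 4   -> make a leaf
    bool stop_large = stop(large_node);   // false: 50 >= 4 -> keep splitting
    (void)stop_small; (void)stop_large;
}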
} // namespace vigra

#endif // VIGRA_RF_COMMON_HXX