// file: $isip_ifc/class/pr/RelevanceVectorMachine/RVMTrainData.h
// version: $Id: RVMTrainData.h 10424 2006-02-10 20:37:45Z raghavan $
//

// make sure definitions are only made once
//
#ifndef ISIP_RVM_TRAIN_DATA
#define ISIP_RVM_TRAIN_DATA

// NOTE(review): every #include directive below is missing its header
// name in this copy of the file. presumably each guard macro pairs with
// the matching ISIP header - restore the targets from the original
// source before compiling.
//
#ifndef ISIP_VECTOR
#include
#endif

#ifndef ISIP_FILENAME
#include
#endif

#ifndef ISIP_MATRIX_FLOAT
#include
#endif

#ifndef ISIP_VECTOR_FLOAT
#include
#endif

#ifndef ISIP_BOOLEAN
#include
#endif

// RVMTrainData: a utility class to hold training quantities. Rather
// than include all of this data in the RVM class itself
// (memory-consumptive), a helper class can be passed around between
// functions during training and is not present during prediction. All
// member data is public so that superfluous calls to getXYZ are not
// necessary in the training code
//
class RVMTrainData {

  //---------------------------------------------------------------------------
  //
  // public constants
  //
  //---------------------------------------------------------------------------
public:

  // define the class name
  //
  static const String CLASS_NAME;

  //----------------------------------------
  //
  // default values and arguments
  //
  //----------------------------------------

  // default values for the tuning parameters declared further below;
  // each DEF_XYZ constant corresponds by name to the xyz_d data member.
  //
  // NOTE(review): the float32 constants are initialized inside the class
  // body. in-class initializers on non-integral static const members are
  // a compiler extension rather than standard C++ - confirm this builds
  // with the project's compiler and flags.
  //
  static const float32 DEF_ALPHA_THRESH = 1e12;
  static const float32 DEF_MIN_ALLOWED_WEIGHT = 1e-8;
  static const int32 DEF_MAX_RVM_ITS = (int32)(5000);
  static const int32 DEF_MAX_UPDATE_ITS = (int32)(1000);
  static const float32 DEF_MIN_THETA = 1e-8;
  static const float32 DEF_MOMENTUM = 0.85;
  static const int32 DEF_MAX_ADDITIONS = 1;
  static const bool8 DEF_SAVE_RESTART = false;
  static const bool8 DEF_LOAD_RESTART = false;
  static const int32 DEF_MAX_MEMORY = 500000000;  // ~ 1/2 Gigabyte
  static const bool8 DEF_USE_SUBSETS = true;
  static const int32 DEF_SUBSET_SIZE = 500;

  //---------------------------------------------------------------------------
  //
  // protected data
  //
  //---------------------------------------------------------------------------
protected:

  // a static memory manager: the class-specific new/delete operators
  // declared below forward to this object
  //
  static MemoryManager mgr_d;

  //---------------------------------------------------------------------------
  //
  // public data
  //
  //---------------------------------------------------------------------------
public:

  // tuning parameters: These are the only parameters that a user need
  // worry about prior to training. The default quantities are usually
  // sufficient. However, run-time performance and accuracy can be
  // influenced by appropriately tuning these parameters.
  //

  // maximum hyperparameter value allowed before pruning. decreasing this
  // value can speed up convergence of the model but may yield overpruning
  // and poor generalization. the value should always be greater than zero
  //
  float32 alpha_thresh_d;

  // minimum value of a weight allowed in the model. typically as the weight
  // decreases toward zero, it should be pruned.
  //
  float32 min_allowed_weight_d;

  // maximum number of training iterations to carry out before stopping.
  // adjusting this parameter can result in sub-optimal results
  //
  int32 max_rvm_its_d;

  // maximum number of iterations that are allowed to pass between
  // model updates (adding or pruning of a hyperparameter) before training
  // is terminated (for the full mode of training) or a vector is manually
  // added (for the incremental mode of training)
  //
  int32 max_update_its_d;

  // minimum value of the theta calculation (the divisor of equation 17 in [3])
  // that will trigger a model addition (in the incremental training mode).
  //
  // NOTE(review): reference [3] is not listed anywhere in this header;
  // presumably the bibliography lives with the RVM class - confirm.
  //
  float32 min_theta_d;

  // hyperparameter update momentum term. a larger value for this term
  // can lead to faster convergence, while too large a value can cause
  // oscillation. the value is typically on the range [0,1]
  //
  float32 momentum_d;

  // number of hyperparameters to add at a time. adding a small number of
  // hyperparameters at a time will yield a smoother movement through the
  // model space, but may increase the total convergence time.
  //
  int32 max_additions_d;

  // whether or not to create backup copies of training data. if true then
  // data will be occasionally saved to disk in the file provided. that
  // file can later be used to restart training in the middle of the
  // convergence process. *** the restart facility currently is available
  // only for incremental training ***
  //
  bool8 save_restart_d;
  Filename restart_save_file_d;

  // whether or not to bootstrap training from a restart file. if true then
  // the given restart file is read and training is continued from that point
  // forward. *** the restart facility currently is available
  // only for incremental training ***
  //
  bool8 load_restart_d;
  Filename restart_load_file_d;

  // maximum amount of memory to use for kernel computations. this will be
  // used in computing the number of subsets when using subset training.
  // it could also be used if we move to some sort of kernel caching
  // mechanism
  //
  int32 max_memory_d;

  // whether or not to use subset training. if true, then the data will be
  // divided into subsets and each subset will be trained separately until
  // they all simultaneously converge.
  //
  // NOTE(review): subset_size_d carries no comment of its own in the
  // original; presumably it is the number of samples per subset - confirm.
  //
  bool8 use_subsets_d;
  int32 subset_size_d;

  //----------------------------------------
  //
  // run-time training quantities:
  //  typically, these are not modified by
  //  user programs
  //
  //----------------------------------------

  // model data
  //
  // NOTE(review): the comment on num_samples_d ("number of remaining
  // RVs") does not obviously match the member name - confirm which of
  // the two is accurate.
  //
  int32 num_samples_d;                 // number of remaining RVs
  MatrixFloat A_d;                     // hyperparameter matrix
  int32 dimA_d;                        // number of non-pruned params
  MatrixFloat phi_d;                   // working design matrix
  VectorFloat curr_weights_d;          // updated weights
  VectorFloat last_rvm_weights_d;      // stored weights for rvm pass
  VectorFloat targets_d;

  // IRLS training quantities
  //
  VectorFloat sigma_d;                 // error vector
  MatrixFloat B_d;                     // data-dependent "noise"
  VectorFloat gradient_d;              // gradient w.r.t. weights
  MatrixFloat hessian_d;               // hessian w.r.t. weights
  MatrixFloat covar_cholesky_d;        // cholesky decomposition of covar
  VectorFloat old_irls_weights_d;      // stored weights for irls pass
  int32 last_changed_d;                // counter for last time model
                                       // changed

  // incremental training quantities
  //
  // NOTE(review): the three members declared with a bare "Vector" type
  // (active_params_d, subsets_d, phi_cache_d) have no template arguments
  // in this copy of the file; the element types are not visible here and
  // must be restored from the original source.
  //
  // NOTE(review): the original comment on weights_d reads "current
  // hyperparameters", duplicating the one on hyperparams_d; presumably
  // it should read "current weights" - confirm.
  //
  VectorFloat S_d;                     // updates for incremental train
  VectorFloat Q_d;                     // updates for incremental train
  VectorFloat hyperparams_d;           // current hyperparameters
  Vector active_params_d;              // parameters that can be modified
  VectorFloat weights_d;               // current hyperparameters
  VectorFloat last_hyperparams_d;      // previous iterations hyperparams
  VectorFloat twoback_hyperparams_d;   // hyperparameters from two
                                       // iterations ago - used to
                                       // detect cycles in training
  int32 num_subsets_d;                 // number of training subsets
  int32 active_subset_d;               // currently active subset
  Vector subsets_d;                    // subset indices
  VectorLong active_rvs_d;

  // data cache for larger training sets
  //
  /*** initialize these variables***/
  VectorLong cache_index_table_d;      // index mapping table for cache
  Vector phi_cache_d;                  // cached rows of the phi matrix

  //---------------------------------------------------------------------------
  //
  // required public methods
  //
  //---------------------------------------------------------------------------
public:

  // method: name
  //  returns a reference to the static CLASS_NAME string
  //
  static const String& name() {
    return CLASS_NAME;
  }

  // other static methods: the diagnose method is a stub that always
  // returns true since there are no computational methods to test.
  // the setDebug method is empty because all data is public and thus
  // need not be accessed through a debug method; it ignores its
  // argument and always returns true.
  //
  static bool8 diagnose(Integral::DEBUG debug_level) {
    return true;
  }

  // method: setDebug
  //
  static bool8 setDebug(Integral::DEBUG level) {
    return true;
  }

  // other debug methods:
  //  defined outside this header
  //
  bool8 debug(const unichar* msg) const;

  // method: destructor
  //  empty body: no explicit cleanup is performed here
  //
  ~RVMTrainData() { }

  // method: constructor
  //  we need only initialize the tunable parameters
  //
  RVMTrainData();

  // method: copy constructor
  //  delegates to assign(); the boolean result of assign() is discarded
  //
  RVMTrainData(const RVMTrainData& arg) {
    assign(arg);
  }

  // assign methods:
  //  defined outside this header
  //
  bool8 assign(const RVMTrainData& copy);

  // method: operator=
  //  delegates to assign() and returns *this; the boolean result of
  //  assign() is discarded, so a failed assignment is not reported
  //
  RVMTrainData& operator= (const RVMTrainData& arg) {
    assign(arg);
    return *this;
  }

  // i/o methods: this class can neither be written nor read. every
  // method in this group ignores its arguments: sofSize reports zero
  // and the read/write methods return true without touching the Sof
  // object.
  //

  // method: sofSize
  //
  int32 sofSize() const {
    return 0;
  }

  // method: read
  //
  bool8 read(Sof& sof, int32 tag, const String& name = CLASS_NAME) {
    return true;
  }

  // method: write
  //
  bool8 write(Sof& sof, int32 tag, const String& name = CLASS_NAME) const {
    return true;
  }

  // method: readData
  //
  bool8 readData(Sof& sof, const String& pname = CLASS_NAME,
                 int32 size = SofParser::FULL_OBJECT,
                 bool8 param = true,
                 bool8 nested = false) {
    return true;
  }

  // method: writeData
  //
  bool8 writeData(Sof& sof, const String& pname = CLASS_NAME) const {
    return true;
  }

  // method: eq
  //  defined outside this header
  //
  bool8 eq(const RVMTrainData& arg) const;

  // memory management methods: the class-specific new/delete operators
  // forward to the static memory manager mgr_d.
  //

  // method: new
  //  the size argument is ignored; storage comes from mgr_d.get().
  //  NOTE(review): presumably mgr_d is configured for sizeof(RVMTrainData)
  //  where it is defined - confirm in the implementation file.
  //
  static void* operator new(size_t size) {
    return mgr_d.get();
  }

  // method: new[]
  //
  static void* operator new[](size_t size) {
    return mgr_d.getBlock(size);
  }

  // method: delete
  //
  static void operator delete(void* ptr) {
    mgr_d.release(ptr);
  }

  // method: delete[]
  //
  static void operator delete[](void* ptr) {
    mgr_d.releaseBlock(ptr);
  }

  // method: setGrowSize
  //  forwards grow_size to mgr_d.setGrow() and returns its result
  //
  static bool8 setGrowSize(int32 grow_size) {
    return mgr_d.setGrow(grow_size);
  }

  // method: clear
  //  defined outside this header
  //
  bool8 clear(Integral::CMODE cmode = Integral::DEF_CMODE);

  // method: assignParameters
  //  defined outside this header.
  //  NOTE(review): presumably copies only the tuning parameters (not the
  //  run-time training quantities) from the argument - confirm in the
  //  implementation file.
  //
  bool8 assignParameters(const RVMTrainData& copy);
};

// end of include file
//
#endif