// file: $isip_ifc/class/pr/RelevanceVectorMachine/RelevanceVectorMachine.h // version: $Id: RelevanceVectorMachine.h 10424 2006-02-10 20:37:45Z raghavan $ // // make sure definitions are only made once // #ifndef ISIP_RELEVANCE_VECTOR_MACHINE #define ISIP_RELEVANCE_VECTOR_MACHINE // isip include files // #ifndef ISIP_SDB #include #endif #ifndef ISIP_KERNEL #include #endif #ifndef ISIP_VECTOR #include #endif #ifndef ISIP_MATRIX_FLOAT #include #endif #ifndef ISIP_VECTOR_BYTE #include #endif #ifndef ISIP_AUDIO_DATABASE #include #endif #ifndef ISIP_FEATURE_FILE #include #endif #ifndef ISIP_STATISTICAL_MODEL #include #endif #ifndef ISIP_RVM_TRAIN_DATA #include #endif #ifndef ISIP_RELEVANCE_VECTOR_MODEL #include #endif #ifndef ISIP_MEMORY_MANAGER #include #endif // RelevanceVectorMachine: a class to train and test a relevance vector // machine (RVM) based dichotomous classifier. // // References: // // [1] M. E. Tipping, "Sparse Bayesian Learning and the Relevance Vector // Machine," Journal of Machine Learning Research, vol 1, pp. 211-244, // June 2001. // // [2] David J. C. MacKay, "Probable Networks and Plausible // Predictions - A Review of Practical Bayesian Methods for Supervised // Neural Networks," Network: Computation in Neural Systems, vol. 6, no. 3, // pp. 469-505, 1995. // // [3] A. Faul and M. E. Tipping, "Analysis of Sparse Bayesian Learning," // Proceedings of the 2001 Neural Information Processing Systems, 2001. // class RelevanceVectorMachine { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //--------------------------------------------------------------------------- // // other important constants // //--------------------------------------------------------------------------- // define the algorithm choices // enum ALGORITHM { FULL = 0, ITERATIVE_REFINEMENT, TIPPING_CONSTRUCTIVE, DEF_ALGORITHM = FULL}; // define the implementation choices // enum IMPLEMENTATION { LINEAR = 0, POLYNOMIAL, RBF, SIGMOID, DEF_IMPLEMENTATION = LINEAR }; // define output type choices // enum TYPE { TEXT = 0, BINARY, DEF_TYPE = BINARY }; // define the static NameMap objects // static const NameMap IMPL_MAP; static const NameMap ALGO_MAP; static const NameMap TYPE_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String DEF_COMMENT_TAG; static const String PARAM_BIAS; static const String PARAM_WEIGHTS; static const String PARAM_INV_HESSIAN; static const String PARAM_TARGETS; static const String PARAM_VECTORS; static const String PARAM_KERNEL; static const String PARAM_AUDIO_DB; static const String PARAM_OUTPUT_TYPE; static const String PARAM_OUTPUT_FILE; static const String PARAM_POLYNOMIAL_DEGREE; static const String PARAM_RBF_GAMMA; static const String PARAM_SIGMOID_KAPPA; static const String PARAM_SIGMOID_DELTA; static const String PARAM_ALGORITHM; static const String PARAM_IMPLEMENTATION; static const String PARAM_ALPHA_THRESH; static const String PARAM_MIN_ALLOWED_WEIGHT; static const String PARAM_MAX_RVM_ITS; static const String PARAM_MAX_UPDATE_ITS; static const String PARAM_MIN_THETA; static const String PARAM_MOMENTUM; static const String PARAM_MAX_ADDITIONS; //---------------------------------------- // // default values and arguments // //---------------------------------------- static const float32 DEF_ALPHA_THRESH = 1e8; static const float32 DEF_MIN_ALLOWED_WEIGHT = 1e-8; static const int32 DEF_CV_SETS = 1; static const int32 DEF_CV_PERCENT = 80; static const float32 DEF_BIAS = 0.0; static const float32 DEF_POLYNOMIAL_DEGREE = 3.0; static const float32 DEF_RBF_GAMMA = 0.5; static const float32 DEF_SIGMOID_KAPPA = 1.0; static const float32 DEF_SIGMOID_DELTA = 1.0; // control parameters // static const float32 MIN_CV_PERCENT = 5.0; static const float32 MAX_CV_PERCENT = 95.0; static const float32 MIN_DATA_VALUE = -1e30; static const float32 MAX_DATA_VALUE = 1e30; //---------------------------------------- // // error codes // //---------------------------------------- static const int32 ERR = 999999; static const int32 ERR_INIT = 999998; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // algorithm name // ALGORITHM algorithm_d; // implementation name // IMPLEMENTATION implementation_d; // output type // TYPE output_type_d; // output file // Filename output_file_d; // audio database file // Filename audio_db_file_d; // audio database // AudioDatabase audio_db_d; // maximum number of points // Long max_points_d; // default polynomial degree // Float polynomial_degree_d; // default rbf gamma // Float rbf_gamma_d; // default sigmoid kappa // Float sigmoid_kappa_d; // default sigmoid delta // Float sigmoid_delta_d; // error cache parameter // VectorFloat error_cache_d; // relevance vector weights, vectors, and labels // MatrixFloat inv_hessian_d; // A in [2] VectorFloat weights_d; // w in [1] VectorByte targets_d; // t in [1] Vector vectors_d; // x in [1] Float bias_d; // kernel function: we assume that the same kernel is used for // all training points. this only makes sense though since anything else // would mean different training points operating in different spaces. // Kernel kernel_d; // K in [1] // input observations and their labels // Vector in_class_data_d; Vector out_class_data_d; // verbosity // static Integral::DEBUG verbosity_d; // a static debug level // static Integral::DEBUG debug_level_d; // a static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // method: setDebug // static bool8 setDebug(Integral::DEBUG level) { debug_level_d = level; return true; } // method: debug // bool8 debug(const unichar* msg) const; // method: destructor // ~RelevanceVectorMachine() { } // constructor // RelevanceVectorMachine(); // method: copy constructor // RelevanceVectorMachine(const RelevanceVectorMachine& arg) { assign(arg); } // method: assign // bool8 assign(const RelevanceVectorMachine& copy); // method: operator= // RelevanceVectorMachine& operator= (const RelevanceVectorMachine& arg) { assign(arg); return *this; } // i/o methods // int32 sofSize() const; bool8 read(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 write(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; // method: eq // bool8 eq(const RelevanceVectorMachine& arg) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // method: clear // bool8 clear(Integral::CMODE cmode = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods: set and get methods // //--------------------------------------------------------------------------- // method: getAlgorithm // ALGORITHM getAlgorithm() { return algorithm_d; } // method: setAlgorithm // bool8 setAlgorithm(ALGORITHM arg) { return (algorithm_d = arg); } // method: getImplementation // IMPLEMENTATION getImplementation() { return implementation_d; } // method: setImplementation // bool8 setImplementation(IMPLEMENTATION arg) { return (implementation_d = arg); } // method: setDegree // bool8 setDegree(float32 arg) { if (arg != DEF_POLYNOMIAL_DEGREE) { polynomial_degree_d = arg; } return true; } // method: setKappa // bool8 setKappa(float32 arg) { if (arg != DEF_SIGMOID_KAPPA) { sigmoid_kappa_d = arg; } return true; } // method: setDelta // bool8 setDelta(float32 arg) { if (arg != DEF_SIGMOID_DELTA) { sigmoid_delta_d = arg; } return true; } // method: setGamma // bool8 setGamma(float32 arg) { if (arg != DEF_RBF_GAMMA) { rbf_gamma_d = arg; } return true; } // method: setVerbosity // bool8 setVerbosity(Integral::DEBUG verbosity) { verbosity_d = verbosity; return true; } // method: setOutputFile // bool8 setOutputFile(Filename& arg) { return output_file_d.assign(arg); } // method: setOutputType // bool8 setOutputType(TYPE arg) { output_type_d = arg; return true; } // there are currently no set methods for weights since these // should either be trained or read from a file of trained models. // const VectorFloat& getWeights() const { return weights_d; } // method: setKernel // bool8 setKernel(Kernel& kernel) { return kernel_d.assign(kernel); } // method: getKernel // const Kernel& getKernel() const { return kernel_d; } // method: setData // bool8 setData(Vector data) { return vectors_d.assign(data); } // method: getData // const Vector& getData() const { return vectors_d; } // method: setTargets // bool8 setTargets(VectorByte targets) { return targets_d.assign(targets); } // method: getTargets // const VectorByte& getTargets() const { return targets_d; } //--------------------------------------------------------------------------- // // class-specific public methods: core computational methods // //--------------------------------------------------------------------------- // this method initializes the kernel parameters // bool8 initKernel(); // methods to evaluate the RVM at a specified point. note that this returns a // probability // bool8 evaluate(float32& score, const VectorFloat& input); //bool8 evaluate(float32& score, const VectorFloat& input); bool8 evaluate(); // methods to train the model given a set of targets and training vectors // bool8 train(RVMTrainData& tdata, ALGORITHM mode = DEF_ALGORITHM); // this method load the feature vectors used used in the optimization process // bool8 loadFeatures(Sdb& in_sdb, Sdb& out_sdb); // this method writes the relevance vector model to file // bool8 writeModel(); // these methods are used for normalizing the data // bool8 normalizeStdNorm(); bool8 normalizeUnitRange(); bool8 normalizeUnitVector(); //--------------------------------------------------------------------------- // // private data // //--------------------------------------------------------------------------- private: // auxilliary methods used during training and optimization // bool8 trainFull(RVMTrainData& tdata, bool8 is_init = false); bool8 trainIterativeRefinement(RVMTrainData& tdata, bool8 is_init = false); bool8 trainTippingConstructive(RVMTrainData& tdata, bool8 is_init = false); bool8 initFull(RVMTrainData& tdata); bool8 initIterativeRefinement(RVMTrainData& tdata); bool8 initTippingConstructive(RVMTrainData& tdata); bool8 irlsTrain(RVMTrainData& tdata); bool8 computeSigma(RVMTrainData& tdata); bool8 updateHyperparameters(bool8& updated, RVMTrainData& tdata, ALGORITHM = FULL, bool8 use_mackay = true); bool8 updateHyperparametersFull(bool8& updated, RVMTrainData& tdata, bool8 use_mackay = true); bool8 updateHyperparametersIncremental(bool8& updated, RVMTrainData& tdata); bool8 pruneWeights(RVMTrainData& tdata); bool8 pruneAndUpdate(RVMTrainData& tdata); bool8 finalizeTraining(RVMTrainData& tdata, ALGORITHM mode); bool8 reinitialize(RVMTrainData& tdata); bool8 getPhiRow(VectorFloat& ovec, int32 index); float32 computeLikelihood(RVMTrainData& tdata) const; bool8 computeVarianceCholesky(RVMTrainData& tdata); // training restart facilities // bool8 readRestartData(RVMTrainData& tdata); bool8 writeRestartData(RVMTrainData& tdata); // facilities for data caching // bool8 rebuildPhiCache(RVMTrainData& tdata); bool8 getPhiRowFromCache(VectorFloat& ovec, int32 index, RVMTrainData& tdata); bool8 resetActiveParams(RVMTrainData& tdata); // incremental training facilities // bool8 initializeRVSubsets(RVMTrainData& tdata); bool8 initializeSubsetRVM(RelevanceVectorMachine& training_rvm, RVMTrainData& subset_data, RVMTrainData& global_data); bool8 updateSubsetRVM(RVMTrainData& global_data, RelevanceVectorMachine& training_rvm, RVMTrainData& subset_data); }; // end of include file // #endif