// file: $isip_ifc/class/stat/RelevanceVectorModel/RelevanceVectorModel.h // version: $Id: RelevanceVectorModel.h 10423 2006-02-09 00:52:30Z raghavan $ // // make sure definitions are only made once // #ifndef ISIP_RELEVANCE_VECTOR_MODEL #define ISIP_RELEVANCE_VECTOR_MODEL // isip include files // #ifndef ISIP_VECTOR #include #endif #ifndef ISIP_KERNEL #include #endif #ifndef ISIP_VECTOR_FLOAT #include #endif #ifndef ISIP_VECTOR_BYTE #include #endif #ifndef ISIP_MATRIX_FLOAT #include #endif #ifndef ISIP_MEMORY_MANAGER #include #endif #ifndef ISIP_STATISTICAL_MODEL_BASE #include #endif // RelevanceVectorModel: a class to compute the probability of the test vector // from a hyperplane in the transformed space // // N // ---- // score = bias + \ [ weight(i) * kernel (x, sv[i]) ] // /___ i // i // // g[i] = kernel (x, sv[i]), where i = 0 to N // s_squared = g * inv_hessian * g' // kappa = 1.0 / sqrt(1 + PI * s_squared / 8.0) // probability = 1.0 / (1.0 + exp(-(kappa * score))) // // 'sv' is the support vector // 'N' is the number of relevance vectors // 'weight(i)' is the weigth for i-th relevance vector // 'bias' is the distance of hyperplane from the origin // // resulting distance (positive or negative) can be transformed to // posterior probability using e.g. sigmoid transformation // // probability = sigmoid (distance) // class RelevanceVectorModel : public StatisticalModelBase { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // other important constants // //---------------------------------------- enum ALGORITHM { SINGLE_KERNEL = 0, MULTIPLE_KERNEL, DEF_ALGORITHM = SINGLE_KERNEL }; enum IMPLEMENTATION { LOGISTIC_SIGMOID_LINK = 0, DEF_IMPLEMENTATION = LOGISTIC_SIGMOID_LINK }; // define static NameMap objects for the enumerated values // static const NameMap ALGO_MAP; static const NameMap IMPL_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_ALGORITHM; static const String PARAM_IMPLEMENTATION; static const String PARAM_BIAS; static const String PARAM_WEIGHTS; static const String PARAM_INV_HESSIAN; static const String PARAM_TARGETS; static const String PARAM_VECTORS; static const String PARAM_KERNELS; static const String DEF_COMMENT_TAG; //---------------------------------------- // // error codes // //---------------------------------------- //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // define algorithm and implementation // ALGORITHM algorithm_d; IMPLEMENTATION implementation_d; // relevance vector weights, vectors, and labels // MatrixFloat inv_hessian_d; VectorFloat weights_d; VectorByte targets_d; Vector vectors_d; Float bias_d; // kernel function: we assume that the same kernel is used for // all training points. this only makes sense though since anything else // would mean different training points operating in different spaces. // Vector kernels_d; // memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // debug methods // setDebug is inherited from base class // bool8 debug(const unichar* msg) const; // method: destructor // ~RelevanceVectorModel() {} // method: default constructor // RelevanceVectorModel() { algorithm_d = DEF_ALGORITHM; implementation_d = DEF_IMPLEMENTATION; } // method: copy constructor // RelevanceVectorModel(const RelevanceVectorModel& arg) { assign(arg); } // assign methods // bool8 assign(const RelevanceVectorModel& arg); // method: operator= // RelevanceVectorModel& operator=(const RelevanceVectorModel& arg) { assign(arg); return *this; } // sofSize method // int32 sofSize() const; // method: sofAccumulatorSize // int32 sofAccumulatorSize() const { return Error::handle(name(), L"sofAccumulatorSize", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: sofOccupanciesSize // int32 sofOccupanciesSize() const { return Error::handle(name(), L"sofOccupanciesSize", Error::NOT_IMPLEM, __FILE__, __LINE__); } // other i/o methods // bool8 read(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 write(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; // method: readAccumulator // bool8 readAccumulator(Sof& sof, int32 tag, const String& cname = CLASS_NAME) { return Error::handle(name(), L"readAccumulator", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: writeAccumulator // bool8 writeAccumulator(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const { return Error::handle(name(), L"writeAccumulator", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: readAccumulatorData // bool8 readAccumulatorData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false) { return Error::handle(name(), L"readAccumulatorData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: writeAccumulatorData // bool8 writeAccumulatorData(Sof& sof, const String& pname = DEF_PARAM) const { return Error::handle(name(), L"writeAccumulatorData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: readOccupancies // bool8 readOccupancies(Sof& sof, int32 tag, const String& cname = CLASS_NAME) { return Error::handle(name(), L"readOccupancies", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: writeOccupancies // bool8 writeOccupancies(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const { return Error::handle(name(), L"writeOccupancies", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: readOccupanciesData // bool8 readOccupanciesData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false) { return Error::handle(name(), L"readOccupanciesData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: writeOccupanciesData // bool8 writeOccupanciesData(Sof& sof, const String& pname = DEF_PARAM) const { return Error::handle(name(), L"writeOccupanciesData", Error::NOT_IMPLEM, __FILE__, __LINE__); } // equality methods // bool8 eq(const RelevanceVectorModel& arg) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // method: clear // bool8 clear(Integral::CMODE cmode = Integral::DEF_CMODE) { inv_hessian_d.clear(cmode); weights_d.clear(cmode); targets_d.clear(cmode); vectors_d.clear(cmode); bias_d.clear(cmode); kernels_d.clear(cmode); return true; } //--------------------------------------------------------------------------- // // class-specific public methods: // additional methods unique to this class // //-------------------------------------------------------------------------- // method: setAlgorithm // bool8 setAlgorithm(ALGORITHM algo) { algorithm_d = algo; return true; } // method: getAlgorithm // ALGORITHM getAlgorithm() { return algorithm_d; } // method: setImplementation // bool8 setImplementation(IMPLEMENTATION impl) { implementation_d = impl; return true; } // method: getImplemenation // IMPLEMENTATION getImplemenation() { return implementation_d; } // method: getInvHessian // MatrixFloat& getInvHessian() { return inv_hessian_d; } // method: getWeights // VectorFloat& getWeights() { return weights_d; } // method: getTargets // VectorByte& getTargets() { return targets_d; } // method: getRelevanceVectors // Vector& getRelevanceVectors() { return vectors_d; } // method: getBias // Float& getBias() { return bias_d; } // method: getKernel // Kernel& getKernel(int32 index) { if (algorithm_d == SINGLE_KERNEL) { return kernels_d(index); } else { return kernels_d(0); } } // method: getKernels // Vector& getKernels() { return kernels_d; } // get the SVM distance (this method will check dimensions first) // float32 getDistanceProb(const VectorFloat& input); //--------------------------------------------------------------------------- // // class-specific public methods: // required for the base class interface contract // //-------------------------------------------------------------------------- // StatisticalModelBase required methods // bool8 assign(const StatisticalModelBase& arg); bool8 eq(const StatisticalModelBase& arg) const; // set methods // bool8 setMode(MODE arg); // method: className // const String& className() const { return CLASS_NAME; } // initialization methods // bool8 init(); // method: getLikelihood // float32 getLikelihood(const VectorFloat& input) { return Integral::exp(getLogLikelihood(input)); } // computational methods // float32 getLogLikelihood(const VectorFloat& input); //--------------------------------------------------------------------------- // // class-specific public methods: // accumulate and update methods needed for training models // //--------------------------------------------------------------------------- // method: getMean // bool8 getMean(VectorFloat& mean) { return Error::handle(name(), L"getMean", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: getCovariance // bool8 getCovariance(MatrixFloat& cov) { return Error::handle(name(), L"getCovariance", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: resetAccumulators // bool8 resetAccumulators() { return Error::handle(name(), L"resetAccumulators", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: getOccupancy // float64 getOccupancy() { return Error::handle(name(), L"getOccupancy", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: setOccupancy // bool8 setOccupancy(float64 arg) { return Error::handle(name(), L"setOccupancy", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: getAccessCount // int32 getAccessCount() { return Error::handle(name(), L"getAccessCount", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method: setAccessCount // bool8 setAccessCount(int32 arg) { return Error::handle(name(), L"setAccessCount", Error::NOT_IMPLEM, __FILE__, __LINE__); } // method that accumulates the statistics for the model which are // needed to update the model parameters during training // bool8 accumulate(VectorDouble& param, VectorFloat& data, bool8 precomp); // method that updates the statistical model parameters using the // accumulated statistics during training // bool8 update(VectorFloat& varfloor, int32 min_count); // methods that initializes the statistical model parameters using // accumulated feature vectors // bool8 accumulate(VectorFloat& data); bool8 initialize(VectorFloat& param); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: }; // end of include file // #endif