// file: $isip/class/stat/StatisticalModel/StatisticalModel.h // version: $Id: StatisticalModel.h 9378 2003-12-23 20:56:11Z alphonso $ // // make sure definitions are only made once // #ifndef ISIP_STATISTICAL_MODEL #define ISIP_STATISTICAL_MODEL // isip include files // #ifndef ISIP_STATISTICAL_MODEL_BASE #include #endif #ifndef ISIP_STATISTICAL_MODEL_UNDEFINED #include #endif // forward class definitions // class VectorDouble; class MatrixFloat; class NameMap; class MixtureModel; class GaussianModel; class SupportVectorModel; class RelevanceVectorModel; class UniformModel; // StatisticalModel: a class whose chief goal is to evaluate the likelihood // of an input vector with respect to a statistical model. this class provides // a virtual interface to any class that implements the StatisticalModelBase // interface contract. // class StatisticalModel : public StatisticalModelBase { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // other important constants // //---------------------------------------- // define the currently available model types // enum TYPE { UNKNOWN = 0, GAUSSIAN_MODEL, MIXTURE_MODEL, UNIFORM_MODEL, SUPPORT_VECTOR_MODEL, RELEVANCE_VECTOR_MODEL, DEF_TYPE = UNKNOWN }; // define the implementation choices // enum ALGORITHM {MIXTURE_SPLITTING = 0, DEF_ALGORITHM = MIXTURE_SPLITTING}; // define the implementation choices // enum IMPLEMENTATION {VARIANCE_SPLITTING = 0, DEF_IMPLEMENTATION = VARIANCE_SPLITTING}; // a static name map // static const NameMap TYPE_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String PARAM_TYPE; static const String DEF_PARAM; //---------------------------------------- // // default values and arguments // //---------------------------------------- // define the default value(s) of the class data // static const StatisticalModelUndefined NO_STAT_MODEL; // define the perturb factor for splitting the models // static const float32 DEF_PERTURB_FACTOR = 0.2; //---------------------------------------- // // error codes // //---------------------------------------- static const int32 ERR = 60300; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // algorithm name // ALGORITHM algorithm_d; // implementation name // IMPLEMENTATION implementation_d; // a virtual pointer to a statistical model // StatisticalModelBase* virtual_model_d; // memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // method: debug // setDebug is inherited from base class // bool8 debug(const unichar* message) const { return virtual_model_d->debug(message); } // method: destructor // ~StatisticalModel() { if (virtual_model_d != (StatisticalModelBase*)&NO_STAT_MODEL) { delete virtual_model_d; } } // method: default constructor // StatisticalModel(TYPE type = DEF_TYPE) { algorithm_d = DEF_ALGORITHM; implementation_d = DEF_IMPLEMENTATION; virtual_model_d = (StatisticalModelBase*)&NO_STAT_MODEL; setType(type); } // method: copy constructor // StatisticalModel(const StatisticalModel& arg) { virtual_model_d = (StatisticalModelBase*)&NO_STAT_MODEL; assign(arg); } // method: assign // bool8 assign(const StatisticalModel& arg) { algorithm_d = arg.algorithm_d; implementation_d = arg.implementation_d; return assign(*arg.virtual_model_d); } // method: operator= // StatisticalModel& operator=(const StatisticalModel& arg) { assign(arg); return *this; } // method: eq // bool8 eq(const StatisticalModel& arg) const { return virtual_model_d->eq(*arg.virtual_model_d); } // method: sofSize // int32 sofSize() const { return (TYPE_MAP.elementSofSize() + virtual_model_d->sofSize()); } // method: sofAccumulatorSize // int32 sofAccumulatorSize() const { return (TYPE_MAP.elementSofSize() + virtual_model_d->sofAccumulatorSize()); } // method: sofOccupanciesSize // int32 sofOccupanciesSize() const { return (TYPE_MAP.elementSofSize() + virtual_model_d->sofOccupanciesSize()); } // other i/o methods // bool8 read(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 write(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; bool8 readAccumulator(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 writeAccumulator(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readAccumulatorData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeAccumulatorData(Sof& sof, const String& pname = DEF_PARAM) const; bool8 readOccupancies(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 writeOccupancies(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readOccupanciesData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeOccupanciesData(Sof& sof, const String& pname = DEF_PARAM) const; // method: operator new // static void* operator new(size_t size) { return mgr_d.get(); } // method: operator new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: operator delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: operator delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 size) { return mgr_d.setGrow(size); } // other memory-management methods // bool8 clear(Integral::CMODE cmode = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods: // methods needed for training models // //--------------------------------------------------------------------------- // method: resetAccumulators // bool8 resetAccumulators() { return virtual_model_d->resetAccumulators(); } // method: getOccupancy // float64 getOccupancy() { return virtual_model_d->getOccupancy(); } // method: setOccupancy // bool8 setOccupancy(float64 arg) { return virtual_model_d->setOccupancy(arg); } // method: getAccessCount // int32 getAccessCount() { return virtual_model_d->getAccessCount(); } // method: setAccessCount // bool8 setAccessCount(int32 arg) { return virtual_model_d->setAccessCount(arg); } // method: initialize // bool8 initialize(VectorFloat& param) { return virtual_model_d->initialize(param); } // method: accumulate // bool8 accumulate(VectorFloat& data) { return virtual_model_d->accumulate(data); } // method: accumulate // bool8 accumulate(VectorDouble& param, VectorFloat& data, bool8 precomp) { return virtual_model_d->accumulate(param, data, precomp); } // method: update // bool8 update(VectorFloat& varfloor, int32 min_count) { return virtual_model_d->update(varfloor, min_count); } // method: addModelToMixture // static bool8 addModelToMixture(StatisticalModel& model_a, StatisticalModel& mixture_a); // method: splitMixtureModel // bool8 splitMixtureModel(int32 arg); //--------------------------------------------------------------------------- // // class-specific public methods: // additional methods needed to facilitate base class manipulations // //--------------------------------------------------------------------------- // method: getAlgorithm // ALGORITHM getAlgorithm() { return algorithm_d; } // method: getImplementation // IMPLEMENTATION setImplementation() { return implementation_d; } // method: setAlgorithm // bool8 setAlgorithm(ALGORITHM algorithm) { algorithm_d = algorithm; return true; } // method: setImplementation // bool8 setImplementation(IMPLEMENTATION implementation) { implementation_d = implementation; return true; } // configuration methods // bool8 setType(TYPE type); // method: getType // TYPE getType() const { return (TYPE)TYPE_MAP(virtual_model_d->className()); } //--------------------------------------------------------------------------- // // class-specific public methods: // these functions interface to the base class interface contract. // //--------------------------------------------------------------------------- // method: constructor // StatisticalModel(const StatisticalModelBase& arg) { algorithm_d = DEF_ALGORITHM; implementation_d = DEF_IMPLEMENTATION; virtual_model_d = (StatisticalModelBase*)&NO_STAT_MODEL; assign(arg); } // StatisticalModelBase required methods // bool8 assign(const StatisticalModelBase& arg); // method: eq // bool8 eq(const StatisticalModelBase& arg) const { return virtual_model_d->eq(arg); } // method: setMode // bool8 setMode(MODE mode) { return virtual_model_d->setMode(mode); } // method: getMode // MODE getMode() const { return virtual_model_d->getMode(); } // method: className // const String& className() const { return virtual_model_d->className(); } // method: init // bool8 init() { return virtual_model_d->init(); } // method: getMean // bool8 getMean(VectorFloat& mean) { return virtual_model_d->getMean(mean); } // method: getCovariance // bool8 getCovariance(MatrixFloat& cov) { return virtual_model_d->getCovariance(cov); } // method: getLikelihood // float32 getLikelihood(const VectorFloat& input) { return virtual_model_d->getLikelihood(input); } // method: getLogLikelihood // float32 getLogLikelihood(const VectorFloat& input) { return virtual_model_d->getLogLikelihood(input); } // method: getMixtureModel // MixtureModel& getMixtureModel() { return *((MixtureModel*)virtual_model_d); } // method: getGaussianModel // GaussianModel& getGaussianModel() { return *((GaussianModel*)virtual_model_d); } // method: getSupportVectorModel // SupportVectorModel& getSupportVectorModel() { return *((SupportVectorModel*)virtual_model_d); } // method: getRelevanceVectorModel // RelevanceVectorModel& getRelevanceVectorModel() { return *((RelevanceVectorModel*)virtual_model_d); } // method: getUniformModel // UniformModel& getUniformModel() { return *((UniformModel*)virtual_model_d); } //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: }; // end of include file // #endif