// file: $isip/class/stats/MixtureModel/MixtureModel.h // version: $Id: MixtureModel.h 9742 2004-08-20 20:43:15Z may $ // // make sure definitions are only made once // #ifndef ISIP_MIXTURE_MODEL #define ISIP_MIXTURE_MODEL // isip include files // #ifndef ISIP_VECTOR_FLOAT #include #endif #ifndef ISIP_VECTOR_DOUBLE #include #endif #ifndef ISIP_MATRIX_FLOAT #include #endif #ifndef ISIP_SINGLE_LINKED_LIST #include #endif #ifndef ISIP_MEMORY_MANAGER #include #endif #ifndef ISIP_STATISTICAL_MODEL #include #endif // MixtureModel: a class to score test vectors according to a // distribution which is a linear combination of other StatisticalModels: // // L(x) = w1 * S1(x) + w2 * S2(x) + ... + wN * SN(x) // // 'L' is the output likelihood // 'N' is the number of models // 'w's are the weights // 'S's are the models // class MixtureModel : public StatisticalModelBase { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_WEIGHTS; static const String PARAM_MODELS; //---------------------------------------- // // default values and arguments // //---------------------------------------- // define the default value(s) of the class data // static const float32 MIN_SCORE = -1e100; // default arguments to methods // static const float32 DEF_NORM = 1.0; //---------------------------------------- // // error codes // //---------------------------------------- static const int32 ERR = 60200; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // list of StatisticalModels // SingleLinkedList models_d; // vector of weights // VectorFloat weights_d; // memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // debug methods // setDebug inherited from base class // bool8 debug(const unichar* msg) const; // method: destructor // ~MixtureModel() {} // method: default constructor // MixtureModel(MODE mode = DEF_MODE) { mode_d = mode; is_valid_d = false; } // method: copy constructor // MixtureModel(const MixtureModel& arg) { assign(arg); } // method: assign // bool8 assign(const MixtureModel& arg); // method: operator= // MixtureModel& operator=(const MixtureModel& arg) { assign(arg); return *this; } // method: sofSize // int32 sofSize() const { return (weights_d.sofSize() + models_d.sofSize()); } // method: sofAccumulatorSize // int32 sofAccumulatorSize() const { int32 size = 0; for (bool8 more = const_cast &>(models_d).gotoFirst(); more; more = const_cast &>(models_d).gotoNext()) { size += models_d.getCurr()->sofAccumulatorSize(); } return size; } // method: sofOccupanciesSize // int32 sofOccupanciesSize() const { int32 size = 0; for (bool8 more = const_cast &>(models_d).gotoFirst(); more; more = const_cast &>(models_d).gotoNext()) { size += models_d.getCurr()->sofOccupanciesSize(); } return size; } // other i/o methods // bool8 read(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 write(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; bool8 readAccumulator(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 writeAccumulator(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readAccumulatorData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeAccumulatorData(Sof& sof, const String& pname = DEF_PARAM) const; bool8 readOccupancies(Sof& sof, int32 tag, const String& name = CLASS_NAME); bool8 writeOccupancies(Sof& sof, int32 tag, const String& name = CLASS_NAME) const; bool8 readOccupanciesData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); bool8 writeOccupanciesData(Sof& sof, const String& pname = DEF_PARAM) const; // equality methods // bool8 eq(const MixtureModel& arg) const; // method: operator new // static void* operator new(size_t size) { return mgr_d.get(); } // method: operator new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: operator delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: operator delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // method: clear // bool8 clear(Integral::CMODE cmode = Integral::DEF_CMODE) { weights_d.clear(cmode); models_d.clear(cmode); is_valid_d = false; return true; } //--------------------------------------------------------------------------- // // class-specific public methods: // additional methods unique to this class needed to support the // interface contract // //-------------------------------------------------------------------------- // method: getModels // SingleLinkedList& getModels() { return models_d; } // method: setModels // bool8 setModels(SingleLinkedList& new_models); // methods to manipulate the weights // VectorFloat& getDaWeights() { return weights_d; } bool8 getWeights(VectorFloat& arg) const; bool8 setWeights(const VectorFloat& arg); // normalization methods // bool8 initializeWeights(); bool8 isNormalized(float32 norm = DEF_NORM) const; bool8 normalizeWeights(float32 norm = DEF_NORM); // methods to manipulate the models // bool8 add(StatisticalModelBase& ptr); //--------------------------------------------------------------------------- // // class-specific public methods: // required for the base class interface contract // //-------------------------------------------------------------------------- // StatisticalModelBase required methods // bool8 assign(const StatisticalModelBase& arg); bool8 eq(const StatisticalModelBase& arg) const; // set methods // bool8 setMode(MODE arg); // method: className // const String& className() const { return CLASS_NAME; } // initialization methods // bool8 init(); // method: getMean // bool8 getMean(VectorFloat& mean); // method: getCovariance // bool8 getCovariance(MatrixFloat& cov); // computational methods // float32 getLikelihood(const VectorFloat& input); float32 getLogLikelihood(const VectorFloat& input); //--------------------------------------------------------------------------- // // class-specific public methods: // mixture reestimation methods // //--------------------------------------------------------------------------- // method: getOccupancy // float64 getOccupancy() { Double occupancy = 0; for (bool8 more = models_d.gotoFirst(); more; more = models_d.gotoNext()) { occupancy += models_d.getCurr()->getOccupancy(); } return occupancy; } // method: setOccupancy // bool8 setOccupancy(float64 arg) { for (bool8 more = models_d.gotoFirst(); more; more = models_d.gotoNext()) { models_d.getCurr()->setOccupancy(arg); } return true; } // method: getAccessCount // int32 getAccessCount() { int32 num_access = 0; for (bool8 more = models_d.gotoFirst(); more; more = models_d.gotoNext()) { num_access += models_d.getCurr()->getAccessCount(); } return num_access; } // method: setAccessCount // bool8 setAccessCount(int32 arg) { for (bool8 more = models_d.gotoFirst(); more; more = models_d.gotoNext()) { models_d.getCurr()->setAccessCount(arg); } return true; } // method: resetAccumulators // bool8 resetAccumulators() { for (bool8 more = models_d.gotoFirst(); more; more = models_d.gotoNext()) { models_d.getCurr()->resetAccumulators(); } return true; } // method: initialize // bool8 initialize(VectorFloat& param); // method: accumulate // bool8 accumulate(VectorFloat& data); // method: accumulate // bool8 accumulate(VectorDouble& param, VectorFloat& data, bool8 precomp); // method: update // bool8 update(VectorFloat& varfloor, int32 min_count); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: }; // end of include file // #endif