// file: $isip/class/search/TrainNode/TrainNode.h // // make sure definitions are only made once // #ifndef ISIP_TRAIN_NODE #define ISIP_TRAIN_NODE #ifndef ISIP_LONG #include #endif #ifndef ISIP_VECTOR #include #endif #ifndef ISIP_VECTOR_FLOAT #include #endif #ifndef ISIP_VECTOR_DOUBLE #include #endif #ifndef ISIP_CONTEXT #include #endif #ifndef ISIP_STATISTICAL_MODEL #include #endif // TrainNode: A class to store training information for each state in // the search space. // class TrainNode { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; //---------------------------------------- // // other important constants // //---------------------------------------- //---------------------------------------- // // default values and arguments // //---------------------------------------- // define the default time stamp // static const int32 DEF_TIMESTAMP = -1; //--------------------------------------- // // error codes // //--------------------------------------- //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // define the time instance (t) // int32 frame_d; // define the backward probability for the state at time (t) // float64 beta_d; // define the forward probability for the state at time (t) // float64 alpha_d; // define the score associated with the state at time (t) // float64 score_d; // define a flag that tells us if the train node is reachable form // the valid hypothesis // bool8 is_valid_d; bool8 is_alpha_valid_d; bool8 is_beta_valid_d; bool8 is_accum_valid_d; // define the trace/instance reference pointer // Context* reference_d; // define the statistical model pointer // StatisticalModel* stat_model_d; // define a static debug level // static Integral::DEBUG debug_level_d; // define a static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // method: diagnose // static bool8 diagnose(Integral::DEBUG debug_level); // method: debug // bool8 debug(const unichar* message) const; // method: setDebug // static bool8 setDebug(Integral::DEBUG debug_level) { debug_level_d = debug_level; return true; } // method: destructor // ~TrainNode() { if (debug_level_d >= Integral::ALL) { fprintf(stdout, "Destructor of train_node: %p\n", this); fflush(stdout); } } // constructor(s) // TrainNode(); TrainNode(const TrainNode& copy_node); // assign methods // bool8 assign(const TrainNode& copy_node); // method: sofSize // int32 sofSize() const { return Error::handle(name(), L"sofSize", Error::ARG, __FILE__, __LINE__); } // method: read // bool8 read(Sof& sof, int32 tag, const String& cname = CLASS_NAME) { return Error::handle(name(), L"read", Error::ARG, __FILE__, __LINE__); } // method: write // bool8 write(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const { return Error::handle(name(), L"write", Error::ARG, __FILE__, __LINE__); } // method: readData // bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false) { return Error::handle(name(), L"readData", Error::ARG, __FILE__, __LINE__); } // method: writeData // bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const { return Error::handle(name(), L"writeData", Error::ARG, __FILE__, __LINE__); } // equality method // bool8 eq(const TrainNode& compare_node) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // clear methods // bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods // //--------------------------------------------------------------------------- // method: setAlpha // bool8 setAlpha(float64 arg) { return (alpha_d = arg); } // method: getAlpha // float64 getAlpha() const { return alpha_d; } // method: setBeta // bool8 setBeta(float64 arg) { return (beta_d = arg); } // method: getBeta // float64 getBeta() const { return beta_d; } // method: setFrame // bool8 setFrame(int32 arg) { return (frame_d = arg); } // method: getFrame // int32 getFrame() const { return frame_d; } // method: setScore // bool8 setScore(float64 arg) { return (score_d = arg); } // method: getScore // float64 getScore() const { return score_d; } // method: setReference // bool8 setReference(Context* arg) { return (reference_d = arg); } // method: getReference // Context* getReference() const { return reference_d; } // method: getStatisticalModel // StatisticalModel* getStatisticalModel() { return stat_model_d; } // method: setStatisticalModel // bool8 setStatisticalModel(StatisticalModel* arg) { return (stat_model_d = arg); } // method: getValidModel // bool8 getValidModel() const { if (stat_model_d != (StatisticalModel*)NULL) { return true; } return false; } // method: setValidNode // bool8 setValidNode(bool8 arg) { return (is_valid_d = arg); } // method: getValidNode // bool8 getValidNode() const { return is_valid_d; } // method: setAlphaValid // bool8 setAlphaValid(bool8 arg = true) { return (is_alpha_valid_d = arg); } // method: isAlphaValid // bool8 isAlphaValid() const { return is_alpha_valid_d; } // method: setAccumulatorValid // bool8 setAccumulatorValid(bool8 arg = true) { return (is_accum_valid_d = arg); } // method: isAccumulatorValid // bool8 isAccumulatorValid() const { return is_accum_valid_d; } // method: setBetaValid // bool8 setBetaValid(bool8 arg = true) { return (is_beta_valid_d = arg); } // method: isBetaValid // bool8 isBetaValid() const { return is_beta_valid_d; } //--------------------------------------------------------------------------- // // class-specific public methods: // accumulate and update methods needed for training models // //--------------------------------------------------------------------------- // method to update the models using the accumulators generated // during training // bool8 update(VectorFloat& varfloor, int32 min_model); // method to accumulate the statistics in training which are // used to update the model // bool8 accumulate(float64 utter_prob, Vector& data, float32 min_mpd, float32 min_occupancy); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: }; // end of include file // #endif