// file: $isip/class/pr/RegressionDecisionTree/RegressionDecisionTree.h // version: $Id: RegressionDecisionTree.h 9470 2004-05-10 15:40:08Z gao $ // // make sure definitions are only made once // #ifndef ISIP_REGRESSION_DECISION_TREE #define ISIP_REGRESSION_DECISION_TREE // isip include files // #ifndef ISIP_DECISION_TREE_BASE #include #endif #ifndef ISIP_REGRESSION_DECISION_TREE_NODE #include #endif #ifndef ISIP_STATISTICAL_MODEL #include #endif #ifndef ISIP_GAUSSIAN_MODEL #include #endif #ifndef ISIP_MIXTURE_MODEL #include #endif #ifndef ISIP_BIGRAPH_ARC #include #endif #ifndef ISIP_BIGRAPH_VERTEX #include #endif // forward class definitions: // we must define the RegressionDecisionTreeNode class here first // because the header files might be short-circuited by the ifndef. // class RegressionDecisionTreeNode; // RegressionDecisionTree: a class that build a RegressionDecisionTree // by giving statistical models and speech/nonspeech tag for // statistical models. // class RegressionDecisionTree: public DecisionTreeBase { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // other important constants // //---------------------------------------- // define the algorithm choices // enum ALGORITHM { MLLR = 0, DEF_ALGORITHM = MLLR }; // define the implementation choices // enum IMPLEMENTATION { MEAN = 0, VARIANCE, COMBINED, DEF_IMPLEMENTATION = MEAN }; // define the adaptation supervision choices // enum SUPERVISION_MODE { SUPERVISED = 0, UNSUPERVISED, DEF_SUPERVISION_MODE = SUPERVISED }; // define the adaptation sequence choices // enum SEQUENCE_MODE { BATCH = 0, INCREMENTAL, DEF_SEQUENCE_MODE = BATCH }; // define the static NameMap objects // static const NameMap ALGO_MAP; static const NameMap IMPL_MAP; static const NameMap SUP_MODE_MAP; static const NameMap SEQ_MODE_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_ALGORITHM; static const String PARAM_IMPLEMENTATION; static const String PARAM_SUPERVISION_MODE; static const String PARAM_SEQUENCE_MODE; static const String PARAM_SPEECH_FLAG; static const String PARAM_NUM_TERMINALS; static const String PARAM_PERTURB_DEPTH; static const String PARAM_OCC_THRESHOLD; static const String PARAM_COMPONENTS_THRESHOLD; static const String PARAM_MIXTURE_OFFSET; static const String PARAM_MAP_STAT_TO_TRANS; static const String PARAM_SPLIT_THRESHOLD; static const String PARAM_MERGE_THRESHOLD; static const String PARAM_NUM_OCC_THRESHOLD; static const String PARAM_BDT; //---------------------------------------- // // default values and arguments // //---------------------------------------- // define default values for the thresholds // static const int32 DEF_NUM_TERMINALS = 1; static const bool8 DEF_SPEECH_FLAG = true; static const float64 DEF_PERTURB_DEPTH = 0.2; static const float64 DEF_OCC_THRESHOLD = 1000.0; static const int32 DEF_COMPONENTS_THRESHOLD = 100; static const float32 DEF_MERGE_THRESHOLD = 5; static const float32 DEF_NUM_OCC_THRESHOLD = 100; //---------------------------------------- // // error codes // //---------------------------------------- static const int32 ERR = 00100500; static const int32 ERR_ADAPT_NO_GAUSSIAN = 00100501; static const int32 ERR_UNALG_UNIMP = 00100502; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // define the structures // typedef Triple RDataPoint; typedef SingleLinkedList RData; typedef BiGraphVertex RTreeNode; // algorithm name // ALGORITHM algorithm_d; // implementation name // IMPLEMENTATION implementation_d; // adaptation supervision mode // SUPERVISION_MODE supervision_mode_d; // adaptation sequence mode // SEQUENCE_MODE sequence_mode_d; // data on the root node // RegressionDecisionTreeNode rdt_rootnode_d; // flag to indicate if adapt non-speech sound // Boolean speech_flag_d; // total number of terminals // Long num_terminals_d; // perturb depth // Double perturb_depth_d; // minimum occupation count for adaptation // Double occ_threshold_d; // minimal components for each node to split // Long components_threshold_d; // variables use to get the transform id from statistical model // index and gaussian model index // VectorLong mixture_offset_d; VectorLong map_stat_to_trans_d; // local copy of the statistical model // Vector stat_models_d; // static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // debug methods: // setDebug is inherited from the base class // bool8 debug(const unichar* msg) const; // method: destructor // ~RegressionDecisionTree() { } // method: default constructor // RegressionDecisionTree(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION) { algorithm_d = algorithm; implementation_d = implementation; num_terminals_d = DEF_NUM_TERMINALS; } // method: copy constructor // RegressionDecisionTree(const RegressionDecisionTree& arg) { assign(arg); } // assign methods // bool8 assign(const RegressionDecisionTree& arg); // method: operator= // RegressionDecisionTree& operator= (const RegressionDecisionTree& arg) { assign(arg); return *this; } // i/o methods // int32 sofSize() const; // method: read // bool8 read(Sof& sof, int32 tag) { return read(sof, tag, name()); } bool8 read(Sof& sof, int32 tag, const String& name); // method: write // bool8 write(Sof& sof, int32 tag) const { return write(sof, tag, name()); } bool8 write(Sof& sof, int32 tag, const String& name) const; // method: readData // bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); // method: writeData // bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; // equality methods // bool8 eq(const RegressionDecisionTree& arg) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // other memory management methods // bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods: // set methods // //--------------------------------------------------------------------------- // method: setAlgorithm // bool8 setAlgorithm(ALGORITHM algorithm) { algorithm_d = algorithm; return true; } // method: setImplementation // bool8 setImplementation(IMPLEMENTATION implementation) { implementation_d = implementation; return true; } // method: set // bool8 set(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION) { algorithm_d = algorithm; implementation_d = implementation; return true; } //--------------------------------------------------------------------------- // // class-specific public methods: // get methods // //--------------------------------------------------------------------------- // method: getAlgorithm // ALGORITHM getAlgorithm() const { return algorithm_d; } // method: getImplementation // IMPLEMENTATION getImplementation() const { return implementation_d; } // method: get // bool8 get(ALGORITHM& algorithm, IMPLEMENTATION& implementation) { algorithm = algorithm_d; implementation = implementation_d; return true; } // method: getMixtureOffset // bool8 getMixtureOffset(VectorLong& mixture_offset) const { mixture_offset.assign(mixture_offset_d); return true; } // method: getStatToTrans // bool8 getStatToTrans(VectorLong& stat_to_trans) const { stat_to_trans.assign(map_stat_to_trans_d); return true; } // method to init the regression tree // bool8 initRegressionTree(Vector& stat_models_a, Vector & speech_tag_a); // method to create transform // bool8 createTransform(Vector& stat_models_a); // method to create transform // bool8 createTransforms(RTreeNode*& root_node_a, Vector& stat_models_a); // method to get the mean value for the gaussian model // bool8 getMean(int32 sm_index_a, int32 gm_index_a, VectorFloat& mean_a); // method to get the GaussianModel // bool8 getGaussianModel(int32 sm_index_a, int32 gm_index_a, GaussianModel& gm_a); // updateTransformID // bool8 updateTransformID(Vector& stat_models_a); //--------------------------------------------------------------------------- // // class-specific public methods: // computational methods // //--------------------------------------------------------------------------- // runDecisionTree method // bool8 runDecisionTree(); // method to set the parser // bool8 setParser(SofParser* parser); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: // method to find best terminal // RTreeNode* findBestTerminal(RTreeNode*& node_a, RTreeNode*& best_a, float32& score_a); RTreeNode* findBestTerminal(RTreeNode*& node_a, RTreeNode*& best_a, int32& components_a); BiGraphVertex* findTerminal(RTreeNode*& root_node_a, int32 index_a); // method to build regression cluster // bool8 buildRegressionCluster(RTreeNode*& root_node); // method to perturb mean // bool8 perturbMean(RegressionDecisionTreeNode*& child_node_a, float64 perturbDepth); // method to cluster the children // bool8 clusterChildren(RegressionDecisionTreeNode*& root_node_a, RegressionDecisionTreeNode*& left_child_node_a, RegressionDecisionTreeNode*& right_child_a); // method to calculate the distance // bool8 calculateDistance(RegressionDecisionTreeNode*& root_node_a, RegressionDecisionTreeNode*& left_child_node_a, RegressionDecisionTreeNode*& right_child_a); // method to create the childre node // bool8 createChildNode(RegressionDecisionTreeNode*& root_node_a, RegressionDecisionTreeNode*& left_child_node_a, RegressionDecisionTreeNode*& right_child_a); }; // end of include file // #endif