// file: $isip/class/pr/PhoneticDecisionTree/PhoneticDecisionTree.h // version: $Id: PhoneticDecisionTree.h 9061 2003-04-01 02:06:52Z duncan $ // // make sure definitions are only made once // #ifndef ISIP_PHONETIC_DECISION_TREE #define ISIP_PHONETIC_DECISION_TREE // isip include files // #ifndef ISIP_DECISION_TREE_BASE #include #endif #ifndef ISIP_PHONETIC_DECISION_TREE_NODE #include #endif #ifndef ISIP_SEARCH_SYMBOL #include #endif #ifndef ISIP_STATISTICAL_MODEL #include #endif #ifndef ISIP_GAUSSIAN_MODEL #include #endif #ifndef ISIP_MIXTURE_MODEL #include #endif #ifndef ISIP_BIGRAPH_ARC #include #endif #ifndef ISIP_CONTEXT_MAP #include #endif #ifndef ISIP_SEARCH_NODE #include #endif // PhoneticDecisionTree: a class that computes the PhoneticDecisionTree. // currently. // class PhoneticDecisionTree: public DecisionTreeBase { //--------------------------------------------------------------------------- // // public constants // //--------------------------------------------------------------------------- public: // define the class name // static const String CLASS_NAME; //---------------------------------------- // // other important constants // //---------------------------------------- // define the algorithm choices // enum ALGORITHM { ML = 0, DEF_ALGORITHM = ML }; // define the implementation choices // enum IMPLEMENTATION { DEFAULT = 0, DEF_IMPLEMENTATION = DEFAULT }; // define the static NameMap objects // static const NameMap ALGO_MAP; static const NameMap IMPL_MAP; //---------------------------------------- // // i/o related constants // //---------------------------------------- static const String DEF_PARAM; static const String PARAM_ALGORITHM; static const String PARAM_IMPLEMENTATION; static const String PARAM_SPLIT_THRESHOLD; static const String PARAM_MERGE_THRESHOLD; static const String PARAM_NUM_OCC_THRESHOLD; static const String PARAM_BDT; //---------------------------------------- // // other static constants // //---------------------------------------- static const String YES; static const String NO; static const String CPH; static const String POS; //---------------------------------------- // // default values and arguments // //---------------------------------------- // define default values for the thresholds // static const float32 DEF_SPLIT_THRESHOLD = 10; static const float32 DEF_MERGE_THRESHOLD = 5; static const float32 DEF_NUM_OCC_THRESHOLD = 100; //---------------------------------------- // // error codes // //---------------------------------------- static const int32 ERR = 00100300; //--------------------------------------------------------------------------- // // protected data // //--------------------------------------------------------------------------- protected: // define the structures // typedef Triple< Pair, Float, Boolean> TopoTriple; typedef Triple > DataPoint; typedef SingleLinkedList Data; typedef BiGraphVertex TreeNode; // algorithm name // ALGORITHM algorithm_d; // implementation name // IMPLEMENTATION implementation_d; // data on the root node // PhoneticDecisionTreeNode pdt_rootnode_d; // thresholds for building the decision trees // Float split_threshold_d; Float merge_threshold_d; Float num_occ_threshold_d; // static memory manager // static MemoryManager mgr_d; //--------------------------------------------------------------------------- // // required public methods // //--------------------------------------------------------------------------- public: // method: name // static const String& name() { return CLASS_NAME; } // other static methods // static bool8 diagnose(Integral::DEBUG debug_level); // debug methods: // setDebug is inherited from the base class // bool8 debug(const unichar* msg) const; // method: destructor // ~PhoneticDecisionTree() { } // method: default constructor // PhoneticDecisionTree(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION, float32 split_threshold = DEF_SPLIT_THRESHOLD, float32 merge_threshold = DEF_MERGE_THRESHOLD, float32 num_occ_threshold = DEF_NUM_OCC_THRESHOLD) { algorithm_d = algorithm; implementation_d = implementation; split_threshold_d = split_threshold; merge_threshold_d = merge_threshold; num_occ_threshold_d = num_occ_threshold; } // method: copy constructor // PhoneticDecisionTree(const PhoneticDecisionTree& arg) { assign(arg); } // assign methods // bool8 assign(const PhoneticDecisionTree& arg); // method: operator= // PhoneticDecisionTree& operator= (const PhoneticDecisionTree& arg) { assign(arg); return *this; } // i/o methods // int32 sofSize() const; // method: read // bool8 read(Sof& sof, int32 tag) { return read(sof, tag, name()); } bool8 read(Sof& sof, int32 tag, const String& name); // method: write // bool8 write(Sof& sof, int32 tag) const { return write(sof, tag, name()); } bool8 write(Sof& sof, int32 tag, const String& name) const; // method: readData // bool8 readData(Sof& sof, const String& pname = DEF_PARAM, int32 size = SofParser::FULL_OBJECT, bool8 param = true, bool8 nested = false); // method: writeData // bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const; // equality methods // bool8 eq(const PhoneticDecisionTree& arg) const; // method: new // static void* operator new(size_t size) { return mgr_d.get(); } // method: new[] // static void* operator new[](size_t size) { return mgr_d.getBlock(size); } // method: delete // static void operator delete(void* ptr) { mgr_d.release(ptr); } // method: delete[] // static void operator delete[](void* ptr) { mgr_d.releaseBlock(ptr); } // method: setGrowSize // static bool8 setGrowSize(int32 grow_size) { return mgr_d.setGrow(grow_size); } // other memory management methods // bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE); //--------------------------------------------------------------------------- // // class-specific public methods: // set methods // //--------------------------------------------------------------------------- // method: setAlgorithm // bool8 setAlgorithm(ALGORITHM algorithm) { algorithm_d = algorithm; return true; } // method: setImplementation // bool8 setImplementation(IMPLEMENTATION implementation) { implementation_d = implementation; return true; } // method: setSplitThreshold // bool8 setSplitThreshold(float32 split_threshold) { split_threshold_d = split_threshold; return true; } // method: setMergeThreshold // bool8 setMergeThreshold(float32 merge_threshold) { merge_threshold_d = merge_threshold; return true; } // method: setNumOccThreshold // bool8 setNumOccThreshold(float32 num_occ_threshold) { num_occ_threshold_d = num_occ_threshold; return true; } // method: set // bool8 set(ALGORITHM algorithm = DEF_ALGORITHM, IMPLEMENTATION implementation = DEF_IMPLEMENTATION, float32 split_threshold = DEF_SPLIT_THRESHOLD, float32 merge_threshold = DEF_MERGE_THRESHOLD, float32 num_occ_threshold = DEF_NUM_OCC_THRESHOLD) { algorithm_d = algorithm; implementation_d = implementation; split_threshold_d = split_threshold; merge_threshold_d = merge_threshold; num_occ_threshold_d = num_occ_threshold; return true; } //--------------------------------------------------------------------------- // // class-specific public methods: // get methods // //--------------------------------------------------------------------------- // method: getAlgorithm // ALGORITHM getAlgorithm() const { return algorithm_d; } // method: getImplementation // IMPLEMENTATION getImplementation() const { return implementation_d; } // method: getSplitThreshold // float32 getSplitThreshold() const { return split_threshold_d; } // method: getMergeThreshold // float32 getMergeThreshold() const { return merge_threshold_d; } // method: getNumOccThreshold // float64 getNumOccThreshold() const { return num_occ_threshold_d; } // method: get // bool8 get(ALGORITHM& algorithm, IMPLEMENTATION& implementation, float32& split_threshold, float32& merge_threshold, float32& num_occ_threshold) { algorithm = algorithm_d; implementation = implementation_d; split_threshold = split_threshold_d; merge_threshold = merge_threshold_d; num_occ_threshold = num_occ_threshold_d; return true; } // method: getStatTrain // bool8 getStatTrain(Vector& context_map, Vector >& sub_graphs, Vector& symbol_table, Vector& contextless_symbol_table, HashTable& symbol_hash, Vector& stat_models, Filename& phonetic_dt_file, HashTable& tied_model_hash, Vector& tied_stat_models); // method: getStatTest // bool8 getStatTest(Vector& context_map, int32& left_context, int32& right_context, Vector& upper_symbol_table, Vector& upper_contextless_symbol_table, Vector >& sub_graphs, Vector& symbol_table, HashTable& symbol_hash, Filename& ques_ans_file); //--------------------------------------------------------------------------- // // class-specific public methods: // computational methods // //--------------------------------------------------------------------------- // method: runDecisionTree // bool8 runDecisionTree(); // method: trainDecisionTree // bool8 trainDecisionTree(); // method: load // bool8 load(const Attributes& attributes, PhoneticDecisionTreeNode& pdtnode); // method: loadTrain // bool8 loadTrain(Vector& context_map, int32& left_context, int32& right_context, Vector& upper_symbol_table, Vector& contextless_symbol_table, Vector >& sub_graphs, Vector& symbol_table, HashTable& symbol_hash, Vector& stat_models, Filename& ques_ans_file, HashTable& tied_symbol_hash, Vector& tied_stat_models); // method: loadTest method // bool8 loadTest(Filename& phonetic_dt_file); // method to set the parser // bool8 setParser(SofParser* parser); //--------------------------------------------------------------------------- // // private methods // //--------------------------------------------------------------------------- private: // method: classifyDataPoint // Long classifyDataPoint(DataPoint& datapoint); // classification and merging methods // bool8 classifyData(TreeNode* node, Attribute& attribute); bool8 mergeLeafNodes(TreeNode* start_node, TreeNode* best_node); // subtree manipulation methods // bool8 splitSubTree(TreeNode* node); bool8 mergeSubTree(TreeNode* node); bool8 reindexSubTree(TreeNode* node, int32& index); Long findClass(TreeNode* node, DataPoint& datapoint); // method to find best attribute // bool8 findBestAttribute(TreeNode* node, Attribute& best_attribute, float32& likelihood); // method to find the index of a typical StatisticalModel at a node // Long findTypicalIndex(TreeNode* node); // method to mark a node // bool8 markNode(TreeNode* node, bool8& flag); // method to update the typical-index of the best-node // bool8 updateTypicalIndex(TreeNode* start_node, TreeNode* best_node); // mathematical manipulation methods // bool8 computeSumOccupancy(TreeNode* node, float32& sum_num_occ); bool8 isSplitOccupancyBelowThreshold(TreeNode* node, Attribute& attribute); bool8 computeDeterminantPooledCovariance(TreeNode* node, float32& det_pooled_covariance); float64 computeScale(StatisticalModel& stat_model); // compute likelihood methods // bool8 computeLikelihoodNode(TreeNode* node, float32& likelihood); bool8 computeLikelihoodSplitNode(TreeNode* node, Attribute& attribute, float32& split_likelihood); bool8 computeLikelihoodMergeNodes(TreeNode* start_node, TreeNode* node, float32& merge_likelihood); // contexts generation methods // bool8 createContexts(Vector& symbols, int32& length, Vector& all_contexts); bool8 appendContextLevel(Vector& symbols, int32& level, Vector& all_contexts); bool8 validateContexts(Vector& contextless_symbol_table, Vector& all_contexts, Vector& valid_contexts); bool8 getUnseenContexts(Vector& seen_contexts, Vector& valid_contexts, Vector& unseen_contexts); // method: updateLowerLevel // bool8 updateLowerLevel(Vector& context_map, Vector& unseen_context_map, Vector >& sub_graphs, Vector& symbol_table, HashTable& symbol_hash); // method: getCentralSymbols // bool8 getCentralSymbols(Vector& symbol_table, Vector& contextless_symbol_table, SingleLinkedList& central_symbols); // method: readQuestionAnswer // bool8 readQuestionAnswer(Filename& ques_ans_file, SingleLinkedList >& questions, HashTable& answers); // method: poolStatisticalModel // bool8 poolStatisticalModel(Vector& context_map, Vector& contextless_symbol_table, Vector >& sub_graphs, Vector& symbol_table, HashTable& symbol_hash, Vector& stat_models, int32& context_len, SingleLinkedList >& questions, HashTable& answers, Data& data, HashTable& tied_symbol_hash, Vector& tied_stat_models); // method: isTiedSSymbol // bool8 isTiedSSymbol(SearchSymbol& search_symbol, HashTable& symbol_hash); }; // end of include file // #endif