// file: $isip/class/pr/RegressionDecisionTreeNode/RegressionDecisionTreeNode.h
// version: $Id: RegressionDecisionTreeNode.h 9470 2004-05-10 15:40:08Z gao $
//

// make sure definitions are only made once
//
#ifndef ISIP_REGRESSION_DECISION_TREE_NODE
#define ISIP_REGRESSION_DECISION_TREE_NODE

// NOTE(review): the header names after each #include below are missing
// in this copy of the file (they look stripped during extraction).
// Following the guard-macro naming convention they are presumably
// <SingleLinkedList.h>, <Triple.h>, <Long.h>, <Boolean.h>,
// <StatisticalModel.h>, <GaussianModel.h>, <HashTable.h>, <String.h>
// and <DebugLevel.h> -- confirm against the repository.
//
#ifndef ISIP_SINGLE_LINKED_LIST
#include 
#endif

#ifndef ISIP_TRIPLE
#include 
#endif

#ifndef ISIP_LONG
#include 
#endif

#ifndef ISIP_BOOLEAN
#include 
#endif

#ifndef ISIP_STATISTICAL_MODEL
#include 
#endif

#ifndef ISIP_GAUSSIAN_MODEL
#include 
#endif

#ifndef ISIP_HASH_TABLE
#include 
#endif

#ifndef ISIP_STRING
#include 
#endif

#ifndef ISIP_DEBUG_LEVEL
#include 
#endif

// RegressionDecisionTreeNode: one node of a RegressionDecisionTree.
// A node stores:
//  - the list of Gaussian components assigned to it,
//  - cluster statistics (average mean and covariance, score,
//    accumulation value, number of components),
//  - bookkeeping data (node index, parent node index, and the
//    speech / split / transform flags),
//  - the transformation data for the node (W, Z and G).
//
// Memory for instances comes from a class-specific static
// MemoryManager (see operator new / delete below).
//
class RegressionDecisionTreeNode {

  //---------------------------------------------------------------------------
  //
  // public constants
  //
  //---------------------------------------------------------------------------
public:

  // define the class name
  //
  static const String CLASS_NAME;

  //----------------------------------------
  //
  // i/o related constants
  //
  //----------------------------------------

  // parameter names for Sof i/o. DEF_PARAM is the default pname of
  // readData/writeData. The PARAM_* names presumably label the data
  // members of the same names -- confirm in the implementation file,
  // where these static members are defined.
  //
  static const String DEF_PARAM;
  static const String PARAM_DATAPOINTS;
  static const String PARAM_AVERAGE_MEAN;
  static const String PARAM_CLUSTER_SCORE;
  static const String PARAM_CLUSTER_ACCUMULATE;
  static const String PARAM_NUMBER_COMPONENTS;
  static const String PARAM_NODE_INDEX;
  static const String PARAM_AVERAGE_COVARIANCE;
  static const String PARAM_SPEECH_FLAG;
  static const String PARAM_SPLIT_FLAG;
  static const String PARAM_PARENT_NODE_INDEX;
  static const String PARAM_TRANSFORM_FLAG;
  static const String PARAM_STAT_MODELS;
  static const String PARAM_W_TRANSFORM;
  static const String PARAM_DBGL;

  //----------------------------------------
  //
  // other important constants
  //
  //----------------------------------------

  //----------------------------------------
  //
  // default values and arguments
  //
  //----------------------------------------

  // define default values for the typical statistical-model index
  // (-1; not referenced anywhere in this header)
  //
  static const int32 DEF_TYPICAL_INDEX = -1;

  // define default values for the actual statistical-model index
  // (-1; not referenced anywhere in this header)
  //
  static const int32 DEF_ACTUAL_INDEX = -1;

  // define default values for this flag
  // (not referenced anywhere in this header)
  //
  static const bool DEF_FLAG_EXISTS = true;

  //---------------------------------------
  //
  // error codes
  //
  //---------------------------------------

  // NOTE(review): the leading zeros make these octal integer literals
  // in C++. 00100600 == 33152 and 00100610 == 33160 in decimal.
  // Confirm whether decimal 100600 / 100610 was intended. Changing the
  // values would change the codes that are reported.
  //
  static const int32 ERR = 00100600;
  static const int32 ERR_ADAPT_NO_GAUSSIAN = 00100610;

  //---------------------------------------------------------------------------
  //
  // protected data
  //
  //---------------------------------------------------------------------------
protected:

  // one entry of the component list (RDataPoint) and the list itself
  // (RData).
  //
  // NOTE(review): the template arguments of Triple and SingleLinkedList
  // (and of every Vector below) are missing in this copy of the file.
  // RData is presumably SingleLinkedList<RDataPoint>. Confirm against
  // the repository.
  //
  typedef Triple RDataPoint;
  typedef SingleLinkedList RData;

  // list of gaussian components for this node. setComponents() keeps
  // number_components_d equal to this list's length.
  //
  RData gaussian_models_d;

  // node property variables
  //
  VectorFloat average_mean_d;        // node cluster mean
  Float cluster_score_d;             // node cluster score
  Float cluster_accumulate_d;        // accumulation value in this cluster
  Long number_components_d;          // number of components in this cluster
  Long node_index_d;                 // node index number
  MatrixFloat average_covariance_d;  // node cluster variance

  // flag indicating whether this node is a speech or non-speech node.
  // It also gates getClusterScore() and getTransformation().
  //
  Boolean speech_flag_d;

  // flag indicating whether this node can be split. It also gates
  // getClusterScore().
  //
  Boolean split_flag_d;

  // parent node index, used to look up the transformation index
  //
  Long parent_node_index_d;

  // flag: is there enough data for a transformation at this node?
  // (this header only stores and returns it)
  //
  Boolean transform_flag_d;

  // local copy of the statistical model
  //
  Vector stat_models_d;

  // define transformation matrix W (one row for each Gaussian
  // dimension)
  //
  Vector w_transform_d;

  // define Z matrix
  // (presumably the z_a argument handed to adaptPart())
  //
  MatrixFloat z_glob_d;

  // define G matrix, stored as a Vector
  // (presumably the g_a argument handed to adaptPart())
  //
  Vector g_glob_d;

  // define a debug level
  //
  DebugLevel debug_level_d;

  // define a static memory manager; backs the class-specific
  // operator new / new[] / delete / delete[] below
  //
  static MemoryManager mgr_d;

  //---------------------------------------------------------------------------
  //
  // required public methods
  //
  //---------------------------------------------------------------------------
public:

  // method: name
  //
  // return: the class name (CLASS_NAME)
  //
  static const String& name() {
    return CLASS_NAME;
  }

  // method: diagnose
  //
  // presumably the class self-test at the given debug level; defined in
  // the implementation file
  //
  static bool8 diagnose(Integral::DEBUG debug_level);

  // method: debug
  //
  // presumably dumps the node's state, tagged with message; defined in
  // the implementation file
  //
  bool8 debug(const unichar* message) const;

  // method: setDebug
  //
  // stores debug_level in debug_level_d; always returns true
  //
  bool8 setDebug(Integral::DEBUG debug_level) {
    debug_level_d = debug_level;
    return true;
  }

  // method: destructor
  //
  // empty: every data member is a value object destroyed automatically
  //
  ~RegressionDecisionTreeNode() { }

  // constructor(s)
  //
  RegressionDecisionTreeNode();
  RegressionDecisionTreeNode(const RegressionDecisionTreeNode& copy_node);

  // assign methods
  //
  // presumably copies the contents of copy_node into this node; defined
  // in the implementation file
  //
  bool8 assign(const RegressionDecisionTreeNode& copy_node);

  // method: sofSize
  //
  int32 sofSize() const;

  // method: read
  //
  // cname defaults to CLASS_NAME
  //
  bool8 read(Sof& sof, int32 tag, const String& cname = CLASS_NAME);

  // method: write
  //
  // cname defaults to CLASS_NAME
  //
  bool8 write(Sof& sof, int32 tag, const String& cname = CLASS_NAME) const;

  // method: readData
  //
  // pname defaults to DEF_PARAM and size to SofParser::FULL_OBJECT
  //
  bool8 readData(Sof& sof, const String& pname = DEF_PARAM,
                 int32 size = SofParser::FULL_OBJECT,
                 bool8 param = true,
                 bool8 nested = false);

  // method: writeData
  //
  // pname defaults to DEF_PARAM
  //
  bool8 writeData(Sof& sof, const String& pname = DEF_PARAM) const;

  // equality method
  //
  // compares this node with compare_node; defined in the implementation
  // file
  //
  bool8 eq(const RegressionDecisionTreeNode& compare_node) const;

  // method: new
  //
  // allocates one object from mgr_d. The requested size is not passed
  // to mgr_d.get(), so this assumes the manager's element size matches
  // sizeof(RegressionDecisionTreeNode). Revisit before deriving from
  // this class.
  //
  static void* operator new(size_t size) {
    return mgr_d.get();
  }

  // method: new[]
  //
  // allocates an array block of size bytes from mgr_d
  //
  static void* operator new[](size_t size) {
    return mgr_d.getBlock(size);
  }

  // method: delete
  //
  // returns ptr to mgr_d
  //
  static void operator delete(void* ptr) {
    mgr_d.release(ptr);
  }

  // method: delete[]
  //
  // returns an array block obtained from operator new[] to mgr_d
  //
  static void operator delete[](void* ptr) {
    mgr_d.releaseBlock(ptr);
  }

  // method: setGrowSize
  //
  // forwards grow_size to mgr_d.setGrow() and returns its result
  //
  static bool8 setGrowSize(int32 grow_size) {
    return mgr_d.setGrow(grow_size);
  }

  // clear methods
  //
  // ctype defaults to Integral::DEF_CMODE
  //
  bool8 clear(Integral::CMODE ctype = Integral::DEF_CMODE);

  //---------------------------------------------------------------------------
  //
  // class-specific public methods
  //
  //---------------------------------------------------------------------------

  // method: setComponents
  //
  // copies arg into gaussian_models_d and sets number_components_d to
  // the new list length; always returns true
  //
  bool8 setComponents(RData& arg) {
    gaussian_models_d.assign(arg);
    number_components_d = gaussian_models_d.length();
    return true;
  }

  // method: getComponents
  //
  // returns a mutable reference to the component list even though the
  // method is const (const is cast away), so a const node can be
  // modified through it.
  //
  // NOTE(review): the const_cast target type is missing in this copy;
  // given the return type it is presumably <RData&>.
  //
  RData& getComponents() const {
    return const_cast(gaussian_models_d);
  }

  // method: setAverageMean
  //
  // copies arg into average_mean_d; always returns true
  //
  bool8 setAverageMean(VectorFloat& arg) {
    average_mean_d.assign(arg);
    return true;
  }

  // method: getAverageMean
  //
  // copies average_mean_d into arg (an output parameter); always
  // returns true
  //
  bool8 getAverageMean(VectorFloat& arg) {
    arg.assign(average_mean_d);
    return true;
  }

  // method: setAverageCov
  //
  // copies arg into average_covariance_d; always returns true
  //
  bool8 setAverageCov(MatrixFloat& arg) {
    average_covariance_d.assign(arg);
    return true;
  }

  // method: getAverageCov
  //
  // copies average_covariance_d into arg (an output parameter); always
  // returns true
  //
  bool8 getAverageCov(MatrixFloat& arg) {
    arg.assign(average_covariance_d);
    return true;
  }

  // method: setClusterScore
  //
  bool8 setClusterScore(Float arg) {
    cluster_score_d = arg;
    return true;
  }

  // method: getClusterScore
  //
  // NOT a read-only getter. If the node is a non-speech node or is not
  // splittable, cluster_score_d is first reset to 0 (permanently).
  // Returns a reference to the stored score.
  //
  Float& getClusterScore() {
    if (!speech_flag_d || !split_flag_d)
      cluster_score_d = 0;
    return cluster_score_d;
  }

  // method: setNumComponents
  //
  // overrides the component count. Note this is not checked against
  // the length of gaussian_models_d.
  //
  bool8 setNumComponents(int32 arg) {
    number_components_d = arg;
    return true;
  }

  // method: getNumComponents
  //
  // returns a reference to the stored count
  //
  Long& getNumComponents() {
    return number_components_d;
  }

  // method: setSpeechFlag
  //
  bool8 setSpeechFlag(Boolean arg) {
    speech_flag_d = arg;
    return true;
  }

  // method: getSpeechFlag
  //
  Boolean& getSpeechFlag() {
    return speech_flag_d;
  }

  // method: setSplitFlag
  //
  bool8 setSplitFlag(Boolean arg) {
    split_flag_d = arg;
    return true;
  }

  // method: getSplitFlag
  //
  Boolean& getSplitFlag() {
    return split_flag_d;
  }

  // method: setTransformFlag
  //
  // takes a bool8, unlike setSpeechFlag/setSplitFlag which take a
  // Boolean
  //
  bool8 setTransformFlag(bool8 arg) {
    transform_flag_d = arg;
    return true;
  }

  // method: getTransformFlag
  //
  Boolean& getTransformFlag() {
    return transform_flag_d;
  }

  // method: setParentNodeIndex
  //
  bool8 setParentNodeIndex(int32 arg) {
    parent_node_index_d = arg;
    return true;
  }

  // method: getParentNodeIndex
  //
  Long& getParentNodeIndex() {
    return parent_node_index_d;
  }

  // method: setClusterAccumulate
  //
  bool8 setClusterAccumulate(Float arg) {
    cluster_accumulate_d = arg;
    return true;
  }

  // method: getClusterAccumulate
  //
  Float& getClusterAccumulate() {
    return cluster_accumulate_d;
  }

  // method: setTransformation
  //
  // copies arg into w_transform_d; always returns true
  //
  bool8 setTransformation(Vector & arg) {
    w_transform_d.assign(arg);
    return true;
  }

  // method: getTransformation
  //
  // copies w_transform_d into arg and returns true. For a non-speech
  // node it returns false and leaves arg untouched.
  //
  bool8 getTransformation(Vector & arg) const {
    if (!speech_flag_d)
      return false;
    arg.assign(w_transform_d);
    return true;
  }

  // method: setNodeIndex
  //
  bool8 setNodeIndex(int32 arg) {
    node_index_d = arg;
    return true;
  }

  // method: getNodeIndex
  //
  // returns a copy (by value), unlike the other getters, which return
  // references
  //
  Long getNodeIndex() const {
    return node_index_d;
  }

  // createTransform method
  //
  // presumably builds this node's transformation from stat_models_a;
  // defined in the implementation file
  //
  bool8 createTransform(Vector& stat_models_a);

  // updateDistribution method
  //
  // presumably updates model parameters in stat_models_a; defined in
  // the implementation file
  //
  bool8 updateDistribution(Vector& stat_models_a);

  // containModel method
  //
  // presumably tests whether the component identified by
  // (sm_index_a, gm_index_a) belongs to this node -- confirm in the
  // implementation file
  //
  bool8 containModel(int32 sm_index_a, int32 gm_index_a);

  //---------------------------------------------------------------------------
  //
  // class-specific public methods:
  //
  //---------------------------------------------------------------------------

  //---------------------------------------------------------------------------
  //
  // private methods
  //
  //---------------------------------------------------------------------------
private:

  // calcClusterDistribution method
  //
  // presumably computes the cluster statistics (average mean and
  // covariance) from stat_models_a
  //
  bool8 calcClusterDistribution(Vector& stat_models_a);

  // computeSumOccupancy method
  //
  // sum_num_occ_a is presumably an output parameter receiving the
  // total occupancy
  //
  bool8 computeSumOccupancy(Vector& stat_models_a, float32& sum_num_occ_a);

  // adaptPart method
  //
  // presumably accumulates the G and Z statistics (g_a, z_a) for the
  // Gaussian gm_a
  //
  bool8 adaptPart(Vector& g_a, MatrixFloat& z_a, GaussianModel& gm_a);

  // nodeScore method
  //
  // presumably computes this node's cluster score from stat_models_a
  //
  bool8 nodeScore(Vector& stat_models_a);
};

// end of include file
//
#endif