// file: $isip/class/pr/StatisticalModelAdaptation/smadp_05.cc // version: $Id: smadp_05.cc 9459 2004-04-19 16:06:08Z gao $ // // isip include files // #include "StatisticalModelAdaptation.h" // method: adapt // // arguments: // RegressionDecisionTree& rdt: (input) tree // Vector& stat_models: (input) vector of statistical model // // return: a bool8 value indicating status // // this method computes contribution of one Gaussian model to // cumulative StatisticalModelAdaptationation matrices G and Z // bool8 StatisticalModelAdaptation::adapt(RegressionDecisionTree& rdt_a, Vector& stat_models_a) { // local variables // RTreeNode* root_node = (RTreeNode*)NULL; // check the algorithm and implementation // if (algorithm_d != MLLR && implementation_d != MEAN) { return Error::handle(name(), L"adapt", ERR_UNSUPM, __FILE__, __LINE__); } // get the root node // root_node = rdt_a.getFirst(); // check the node // if (root_node == (RTreeNode*)NULL) { return Error::handle(name(), L"adapt - NULL VERTEX", Error::ARG, __FILE__, __LINE__); } // local variables // RegressionDecisionTreeNode* rdt_node = (RegressionDecisionTreeNode*)NULL; // local variables used to get the mixture offset and to get the // transform id from statistical model index and gaussian model // index // VectorLong mixture_offset; VectorLong stat_to_trans; rdt_a.getMixtureOffset(mixture_offset); rdt_a.getStatToTrans(stat_to_trans); // get the data on this node // rdt_node = root_node->getItem(); // check the node // if (rdt_node == (RegressionDecisionTreeNode*)NULL) { return Error::handle(name(), L"classifyData - NULL RDTNODE", Error::ARG, __FILE__, __LINE__); } // output some debug messages // if (debug_level_d >= Integral::ALL) { stat_to_trans.debug(L"stat_to_trans"); mixture_offset.debug(L"mixture_offset"); } // loop over all statistical models and adapt their parameters // for (int i = 0; i < stat_models_a.length(); i++) { // if the model to adapt is a single Gaussian model, adapt its // parameters using the transformation matrix W // if (stat_models_a(i).getType() == StatisticalModel::GAUSSIAN_MODEL) { int32 temp_index = (int32)mixture_offset(i) + 1; BiGraphVertex* child_node; child_node = findTerminal(root_node, stat_to_trans(temp_index)); if (child_node->getItem()->getTransformation(w_transform_d)) { stat_models_a(i).getGaussianModel().adapt(w_transform_d); } } // if the model to adapt is a Gaussian mixture model, loop over // all mixture components and adapt their parameters // else if (stat_models_a(i).getType() == StatisticalModel::MIXTURE_MODEL) { // get the linked list of mixtures // SingleLinkedList& models = const_cast(stat_models_a(i)).getMixtureModel().getModels(); int32 w_index = 0; bool8 speech_flag = true; // loop over the linked list of mixtures // bool8 more = models.gotoFirst(); while (more) { // get the current mixture component // StatisticalModel* model = models.getCurr(); speech_flag = true; int32 temp_index = (int32)mixture_offset(i) + w_index; BiGraphVertex* child_node; child_node = findTerminal(root_node, stat_to_trans(temp_index)); if (!child_node->getItem()->getTransformFlag()) { speech_flag = child_node->getItem()->getSpeechFlag(); int32 parent_index = child_node->getItem()->getParentNodeIndex(); child_node = findTerminal(root_node, parent_index); if (debug_level_d >= Integral::ALL) { Long(parent_index).debug(L"using the parentnode transfromation"); } } else { speech_flag = child_node->getItem()->getSpeechFlag(); if (debug_level_d >= Integral::ALL) { stat_to_trans(temp_index).debug(L"node transfromation"); } } if (child_node->getItem()->getTransformation(w_transform_d)) { // since it was checked to be Gaussian already, it can be adapted // if (speech_flag) model->getGaussianModel().adapt(w_transform_d); } if (debug_level_d >= Integral::ALL) { Boolean(speech_flag).debug(L"speech_flag"); Long(i).debug(L"transformation node"); w_transform_d.debug(L"transfromation matrix"); } // move to the next mixture component // more = models.gotoNext(); } // end of loop over all mixtures } // end of else if mixture model } // end of loop over all statistical models // exit gracefully // return true; } // method: findTerminal // // arguments: // RTreeNode* root_node: (input) input node // int32 index: (input) the best score upto now // // return: a TreeNode point for the node index node represented // // this method return the node represented by this index number. If // the node is not found, return null. // BiGraphVertex* StatisticalModelAdaptation::findTerminal(RTreeNode*& root_node_a, int32 index_a) { // local variables // float32 score = (float32)0; DoubleLinkedList >* children; BiGraphArc* child; BiGraphVertex* child_node; RTreeNode* best = (RTreeNode*)NULL; // check the node // if (root_node_a != (RTreeNode*)NULL) { // get all the child nodes of this node // children = root_node_a->getChildren(); // if this node has child nodes, accumulate the child nodes // if (root_node_a->gotoFirstChild()) { // loop for the left child // children->gotoFirst(); child_node = children->getCurr()->getVertex(); best = findTerminal(child_node, index_a); if(best != (RTreeNode*)NULL) { return best; } if (child_node == (RTreeNode*)NULL) { child = children->getCurr(); child_node = child->getVertex(); // local variables // RegressionDecisionTreeNode* rdt_node = (RegressionDecisionTreeNode*)NULL; // get the data on this node // rdt_node = child_node->getItem(); score = rdt_node->getClusterScore(); return best; } if( children->gotoNext()) { // loop for the right child // child_node = children->getCurr()->getVertex(); best = findTerminal(child_node, index_a); if(best != (RTreeNode*)NULL) { return best; } } } } RegressionDecisionTreeNode* root_node = (RegressionDecisionTreeNode*)NULL; root_node = root_node_a->getItem(); if ( root_node->getNodeIndex() == index_a) best = root_node_a; // return the best node // return best; }