// file: $isip/class/pr/PhoneticDecisionTree/pdt_05.cc
// version: $Id: pdt_05.cc 9263 2003-07-08 13:50:41Z alphonso $
//

// isip include files
//
#include "PhoneticDecisionTree.h"

// method: runDecisionTree
//
// arguments:
//  none
//
// return: a bool8 value indicating status
//
// this method dispatches on the configured run mode and stop mode:
// TRAIN/THRESH seeds the graph with the root node and grows the tree;
// TEST/THRESH is currently a stub that reports success; any other
// combination is an error.
//
bool8 PhoneticDecisionTree::runDecisionTree() {

  // local variables
  //
  bool8 res;

  // runmode: TRAIN && stopmode: THRESH
  //
  if ((runmode_d == TRAIN) && (stopmode_d == THRESH)) {

    // construct the root node (which holds all the training data
    // before any question is asked) and insert it in the graph
    //
    BiGraphVertex* rootnode = insertVertex(&pdt_rootnode_d);

    // connect the start node to the root node
    //
    insertArc(getStart(), rootnode, false, 0);

    // train the tree
    //
    res = trainDecisionTree();
  }

  // runmode: TEST && stopmode: THRESH
  //
  else if ((runmode_d == TEST) && (stopmode_d == THRESH)){

    // classify the data in test mode
    // NOTE(review): no test-mode classification is implemented here;
    // this branch only reports success
    //
    res = true;
  }

  // error: unknown mode
  //
  else {
    return Error::handle(name(), L"runDecisionTree",
			 ERR, __FILE__, __LINE__);
  }

  // exit gracefully
  //
  return res;
}

// method: trainDecisionTree
//
// arguments:
//  none
//
// return: a bool8 value indicating status
//
// this method grows the decision tree for the configured algorithm
// and implementation. for ML/DEFAULT it: (1) splits the root on the
// first attribute (central monophone), (2) splits each resulting
// leaf on the second attribute (state position), then (3) for each
// leaf runs split / reindex / merge to build, number and prune one
// sub-tree at a time.
//
bool8 PhoneticDecisionTree::trainDecisionTree() {

  // local variables
  //
  bool8 res;

  // algorithm: ML && implementation: DEFAULT
  //
  if ((algorithm_d == ML) && (implementation_d == DEFAULT)) {

    // local variables
    // (USER mode: the list only references nodes owned by the graph)
    //
    TreeNode* root_node = (TreeNode*)NULL;
    SingleLinkedList leaf_nodes(DstrBase::USER);

    // first classify the root node into the children on the basis of
    // central symbol (first-attribute is central-monophone in the
    // model)
    //
    root_node = getFirst();

    // check the node
    //
    if (root_node == (TreeNode*)NULL) {
      return Error::handle(name(), L"trainDecisionTree - NULL VERTEX",
			   Error::ARG, __FILE__, __LINE__);
    }

    // the first attribute in attributes_d drives the root split
    //
    attributes_d.gotoFirst();
    Attribute* attribute = attributes_d.getCurr();
    res = classifyData(root_node, *attribute);

    // classify the leaf-nodes of the tree to the children on the basis
    // of position of the state (datapoint) in the model-topology
    // (second-attribute is state-position)
    //
    // get all the leaf nodes below the node
    //
    res = getLeafNodes(*root_node, leaf_nodes);

    // loop over all the leaf nodes and classify them on the basis of
    // state-position.
    // NOTE: each assignment to res overwrites the previous status
    //
    attributes_d.gotoNext();
    attribute = attributes_d.getCurr();
    for (bool8 more = leaf_nodes.gotoFirst(); more;
	 more = leaf_nodes.gotoNext()) {

      // local variables
      //
      TreeNode* temp_node = (TreeNode*)NULL;

      // get the leaf-node
      //
      temp_node = leaf_nodes.getCurr();

      // check the node
      //
      if (temp_node == (TreeNode*)NULL) {
	return Error::handle(name(), L"trainDecisionTree - NULL VERTEX",
			     Error::ARG, __FILE__, __LINE__);
      }
      res = classifyData(temp_node, *attribute);
    }

    // first split each leaf node as one decision-tree, reindex the
    // leaf-nodes and then merge it as a sub-tree. loop over all the nodes
    // one-by-one and split each tree at a time.
    // note: index is shared across sub-trees so the numbering of the
    // tied states is global over the whole tree
    //
    int32 index = 0;
    leaf_nodes.clear(Integral::RESET);
    res = getLeafNodes(*root_node, leaf_nodes);
    for (bool8 more = leaf_nodes.gotoFirst(); more;
	 more = leaf_nodes.gotoNext()) {

      // local variables
      //
      TreeNode* temp_node = (TreeNode*)NULL;

      // get the leaf-node
      //
      temp_node = leaf_nodes.getCurr();

      // check the node
      //
      if (temp_node == (TreeNode*)NULL) {
	return Error::handle(name(), L"trainDecisionTree - NULL VERTEX",
			     Error::ARG, __FILE__, __LINE__);
      }

      // split the sub-tree
      //
      res = splitSubTree(temp_node);

      // reindex the statistical-models on the leaf-nodes of the
      // sub-tree
      //
      res = reindexSubTree(temp_node, index);

      // merge the sub-tree
      //
      res = mergeSubTree(temp_node);
    }
  }

  // error: unknown mode
  //
  else {
    return Error::handle(name(), L"trainDecisionTree",
			 ERR, __FILE__, __LINE__);
  }

  // exit gracefully
  //
  return res;
}

// method: splitSubTree
//
// arguments:
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status
//
// this
method grows a sub-tree below the input node: it repeatedly
// finds the single leaf whose best question yields the largest
// likelihood gain and splits it, stopping when no leaf has a
// question that satisfies the threshold conditions
//
bool8 PhoneticDecisionTree::splitSubTree(TreeNode* node_a) {

  // check the node before doing anything else
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"splitSubTree - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // define local variable
  //
  bool8 status = false;
  bool8 keep_splitting = true;

  // keep splitting until no leaf below node_a offers a split that
  // meets the threshold conditions
  //
  while (keep_splitting) {

    // per-pass state: the winning leaf, its question, and the best
    // likelihood gain seen so far on this pass
    //
    Attribute winning_attribute;
    TreeNode* split_candidate = (TreeNode*)NULL;
    float32 best_gain = (float32)0;
    keep_splitting = false;

    // collect the current leaves below the node (USER mode: the list
    // only references nodes owned by the graph)
    //
    SingleLinkedList leaf_nodes(DstrBase::USER);
    status = getLeafNodes(*node_a, leaf_nodes);

    // scan every leaf for its best question and keep the leaf whose
    // question gives the largest gain
    //
    for (bool8 more = leaf_nodes.gotoFirst(); more;
	 more = leaf_nodes.gotoNext()) {

      // fetch and validate the current leaf
      //
      TreeNode* leaf = leaf_nodes.getCurr();
      if (leaf == (TreeNode*)NULL) {
	return Error::handle(name(), L"splitSubTree - NULL VERTEX",
			     Error::ARG, __FILE__, __LINE__);
      }

      // find the best question at this leaf and the gain it buys.
      // the gain accumulator must start at zero: findBestAttribute
      // compares candidate gains against its incoming value
      //
      Attribute leaf_best_attribute;
      float32 gain = (float32)0;
      bool8 found = findBestAttribute(leaf, leaf_best_attribute, gain);

      // remember this leaf if it beats the best gain so far
      //
      if (found && (gain > best_gain)) {
	best_gain = gain;
	split_candidate = leaf;
	winning_attribute = leaf_best_attribute;
	keep_splitting = true;
	status = true;
      }
    }

    // split the winning leaf only when some question satisfied the
    // threshold conditions
    //
    if (keep_splitting) {
      status = classifyData(split_candidate, winning_attribute);
    }
  }

  // exit gracefully
  //
  return status;
}

// method: findBestAttribute
//
// arguments:
//  Attribute& best_attribute: (output) best attribute to split the node
//  float32&
inc_likelihood: (output) increase in likelihood if the node is split
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status (true iff the best gain
//  exceeds the split threshold)
//
// this method computes the best attribute and its corresponding
// increase in the likelihood in order to split the input node.
//
// fix: inc_likelihood_a is now zeroed at entry. the original read
// the output parameter's incoming value in the comparison below, so
// correctness silently depended on every caller pre-initializing it.
//
bool8 PhoneticDecisionTree::findBestAttribute(TreeNode* node_a,
					      Attribute& best_attribute_a,
					      float32& inc_likelihood_a) {

  // local variables
  //
  float32 likelihood = (float32)0;
  bool8 res;

  // initialize the output gain: the search below keeps the running
  // maximum in inc_likelihood_a, so it must start at zero rather
  // than at whatever value the caller happened to pass in
  //
  inc_likelihood_a = (float32)0;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"findBestAttribute - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // compute the likelihood of the node before any split; this is the
  // baseline the split likelihoods are compared against
  //
  res = computeLikelihoodNode(node_a, likelihood);

  // loop over all the attributes and find the one with maximum
  // likelihood increase
  //
  for (bool8 k = attributes_d.gotoFirst(); k; k = attributes_d.gotoNext()) {

    // local variables
    //
    Attribute* attribute = (Attribute*)NULL;
    Attribute temp_attribute;
    bool8 att;
    float32 split_likelihood = (float32)0;
    float32 inc_likelihood = (float32)0;

    // get the attribute
    //
    attribute = attributes_d.getCurr();

    // check the attribute
    //
    if (attribute == (Attribute*)NULL) {
      return Error::handle(name(), L"findBestAttribute - NULL ATTRIBUTE",
			   Error::ARG, __FILE__, __LINE__);
    }
    temp_attribute = *attribute;

    // compute the likelihood after splitting this node on the
    // current attribute. the return flag will be false if the
    // node is a pure node given the attribute
    //
    att = computeLikelihoodSplitNode(node_a, temp_attribute,
				     split_likelihood);

    // compute increase in likelihood due to splitting the node on the
    // current attribute
    //
    inc_likelihood = split_likelihood - likelihood;

    // the best attribute is valid only when
    //  1) split is valid (input node is not pure given this attribute)
    //  2) increase in likelihood is greater than the maximum
    //     increase seen so far
    //  3) state-occupancies of all split nodes are greater than the
    //     num_occ_threshold. NOTE(review): despite its name,
    //     isSplitOccupancyBelowThreshold returns true when every
    //     child occupancy is ABOVE the threshold
    //
    if (att && (inc_likelihood > inc_likelihood_a) &&
	isSplitOccupancyBelowThreshold(node_a, temp_attribute)) {
      inc_likelihood_a = inc_likelihood;
      best_attribute_a = temp_attribute;
    }
  }

  // the best attribute is valid only when the split meets the
  // threshold conditions: the increase in likelihood must exceed the
  // split_threshold
  //
  if (inc_likelihood_a > split_threshold_d)
    res = true;
  else
    res = false;

  // exit gracefully
  //
  return res;
}

// method: computeLikelihoodNode
//
// arguments:
//  float32& likelihood: (output) likelihood at the input node
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status
//
// this method computes the likelihood at the input node.
// bool8 PhoneticDecisionTree::computeLikelihoodNode(TreeNode* node_a, float32& likelihood_a) { // local variables // float32 likelihood = (float32)0; float32 sum_num_occ; float32 det_pooled_covar; bool8 res; // check the node // if (node_a == (TreeNode*)NULL) { return Error::handle(name(), L"computeLikelihoodNode - NULL VERTEX", Error::ARG, __FILE__, __LINE__); } // get the sum of occupancies at this node // res = computeSumOccupancy(node_a, sum_num_occ); // get the data in singlelinked list // PhoneticDecisionTreeNode* pdt_node = node_a->getItem(); Data& data = pdt_node->getDataPoints(); // get the first datapoint in triple // DataPoint* datapoint = data.getFirst(); // get datapoint statistical model // StatisticalModel& datapoint_stat_model = datapoint->second(); // compute likelihood only when the sum_num_occ is non-zero. this is // necessary for the liklihood computation to have valid division // assuming gaussian distribution. note that the likelihood // computation may be different for any other distribution // if (sum_num_occ != (float32) 0) { // compute likeihood only for single-mixture gaussian // distribution, else error // // reference Eq 6 // J. Zhao, et al, "Tutorial for Decision Tree-Based // State-Tying For Acoustic Modeling", pp. 6, June, 1999. 
// // L = -0.5 * (n * (1 + ln(2*pi)) + ln(|C|)) * sum_num_occ // // where n = number of features // sum_num_occ = sigma ( state_num_occ(s) ) // s // where s = statistical models at this node // // // compute only if the underlying model is MixtureModel, else // error // if (datapoint_stat_model.getType() == StatisticalModel::MIXTURE_MODEL) { // local variables // StatisticalModel* mixture; // get the mixtures as SingleLinkedList of StatisticalModels // MixtureModel& mixture_model = datapoint_stat_model.getMixtureModel(); SingleLinkedList& mixtures = mixture_model.getModels(); // get the first mixture as StatisticalModel // mixture = mixtures.getFirst(); // check if the distribution is gaussian and single mixture // if ((mixture->getType() == StatisticalModel::GAUSSIAN_MODEL) && (mixtures.length() == (int32)1) ) { // compute determinant of the pooled covariance // res = computeDeterminantPooledCovariance(node_a, det_pooled_covar); // get number of features from the mean dimensions last // datapoint statistical node, set before this if statement // // local variables // VectorFloat mean; int32 num_features; // get the mean of the statistical model. note that the // underlying distribution are actually gaussian in this case // res = datapoint_stat_model.getMean(mean); // get the dimensionality of the mean. 
it is equal to the num // of features // num_features = mean.length(); // compute likelihood assuming gaussian distribution // float32 log_det = Integral::log(det_pooled_covar); float64 temp = Integral::log(Integral::TWO_PI); likelihood = -0.5 * (num_features * (1 + temp) + log_det) * sum_num_occ; } // error: unknown distribution and multiple mixture models // else { return Error::handle(name(), L"computeLikelihoodNode", ERR, __FILE__, __LINE__); } } // error: only MixtureModel supported // else { return Error::handle(name(), L"computeLikelihoodNode", ERR, __FILE__, __LINE__); } } // end if sum_num_occ != 0 likelihood_a = likelihood; // exit gracefully // return res; } // method: computeDeterminantPooledCovariance // // arguments: // float32& det_pooled_covariance: (output) dereminant of the pooled covariance // at the input node // TreeNode* node: (input) input node // // return: a bool8 value indicating status // // this method computes the likelihood at the input node. // // reference Eq 7 // J. Zhao, et al, "Tutorial for Decision Tree-Based // State-Tying For Acoustic Modeling", pp. 6, June, 1999. // // note that in this implementation we assume the mean as a row // vector while the reference equation assumes the mean to be a column // vector. 
//
//  |C| = determinant(C)
//
//  C = (a / sum_num_occ) - b
//
//  where a = sigma ( num_occ(s) ( cov(s) + transpose(mean(s)) * mean(s) ) )
//              s
//
//        sum_num_occ = sigma ( state_num_occ(s) )
//                        s
//
//        b = transpose(c) * c
//
//        c = (sigma ( num_occ(s) * mean(s) )) / sum_num_occ ;
//               s
//
//  where s = statistical models at this node
//
bool8 PhoneticDecisionTree::
computeDeterminantPooledCovariance(TreeNode* node_a,
				   float32& det_pooled_covariance_a) {

  // local variables
  //
  float32 sum_num_occ;
  MatrixFloat a;
  MatrixFloat b;
  VectorFloat c;
  MatrixFloat pooled_covar;
  bool8 diagonal = true;
  bool8 res = true;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(),
			 L"computeDeterminantPooledCovariance - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // get the data in singlelinked list
  //
  PhoneticDecisionTreeNode* pdt_node = node_a->getItem();
  Data& data = pdt_node->getDataPoints();

  // loop over datapoints and accumulate the terms of the
  // pooled-covariance equation ("a" and "c" above)
  //
  for (bool8 i = data.gotoFirst(); i; i = data.gotoNext()) {

    // local variables
    // NOTE(review): mprod2 is never used below
    //
    float64 datapoint_num_occ;
    VectorFloat mean;
    MatrixFloat covar;
    MatrixFloat mprod1;
    VectorFloat vprod1;
    MatrixFloat mprod2;
    MatrixFloat msum1;

    // get the datapoint in triple
    //
    DataPoint* datapoint = data.getCurr();

    // get the statistical model and the occupancy
    //
    StatisticalModel& datapoint_stat_model = datapoint->second();
    datapoint_num_occ = datapoint_stat_model.getOccupancy();

    // get the mean and covariance of this mixture
    //
    res = datapoint_stat_model.getMean(mean);
    res = datapoint_stat_model.getCovariance(covar);

    // check if the covariance is non-diagonal
    //
    if (!covar.isDiagonal()) {
      diagonal = false;
    }

    // intermediate computations to compute matrices "a" and "c":
    //   msum1 = num_occ(s) * (cov(s) + mean(s)'*mean(s)), added into a
    //   vprod1 = num_occ(s) * mean(s), added into c
    // NOTE(review): a.setDimensions / c.setLength are called on every
    // iteration of the accumulation loop; this is only correct if
    // they preserve existing contents when the dimensions are
    // unchanged -- confirm against the ISIP Matrix/Vector classes
    //
    res = mprod1.outerProduct(mean, mean);
    res = msum1.add(covar, mprod1);
    res = msum1.mult(datapoint_num_occ);
    a.setDimensions(msum1);
    res = a.add(msum1);
    res = vprod1.assign(mean);
    res = vprod1.mult(datapoint_num_occ);
    c.setLength(vprod1.length());
    res = c.add(vprod1);
  } // end of for loop for states on this node

  // get the sum of occupancies at this node
  //
  res = computeSumOccupancy(node_a, sum_num_occ);

  // compute the pooled covariance matrix:
  //   C = a / sum_num_occ - c'*c  (with c already divided)
  // NOTE(review): sum_num_occ is not checked for zero before the
  // divisions -- an empty/zero-occupancy node would divide by zero
  //
  res = a.div(sum_num_occ);
  res = c.div(sum_num_occ);
  res = b.outerProduct(c, c);
  res = pooled_covar.sub(a,b);

  // set the pooled covariance to diagonal if all the datapoints had
  // diagonal covariance. this is an assumption though the equations
  // don't generate a diagonal even if all the input data points have
  // diagonal covariances
  //
  if (diagonal) {

    // local variables
    //
    MatrixFloat temp;
    temp.setDimensions(pooled_covar);
    temp.setDiagonal(pooled_covar);
    pooled_covar.assign(temp);
  }

  // compute the determinant of the pooled covariance matrix
  //
  det_pooled_covariance_a = pooled_covar.determinant();

  // exit gracefully
  //
  return res;
}

// method: computeLikelihoodSplitNode
//
// arguments:
//  float32& split_likelihood: (output) likelihood if the node is split
//  TreeNode* node: (input) input node
//  Attribute& attribute: (input) input attribute that's used for
//   splitting the input node
//
// return: a bool8 value indicating status
//
// this method computes the likelihood if the input node is split
// using the input attribute.
//
bool8 PhoneticDecisionTree::
computeLikelihoodSplitNode(TreeNode* node_a, Attribute& attribute_a,
			   float32& split_likelihood_a) {

  // local variables
  //
  bool8 res = true;
  split_likelihood_a = (float32)0;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"computeLikelihoodSplitNode - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // get the name and values of this attribute
  //
  String& attr_name = attribute_a.first();
  SingleLinkedList& attr_values = attribute_a.second();

  // get the data points on the node
  //
  PhoneticDecisionTreeNode* pdt_node = node_a->getItem();
  Data& data = pdt_node->getDataPoints();

  // get the first datapoint from the singlelinked list.
  // NOTE(review): datapoint is not NULL-checked before third() is
  // called -- an empty node would crash here; confirm callers
  // guarantee non-empty data
  //
  DataPoint* datapoint = data.getFirst();

  // check if this attribute is on the first item of the current node
  //
  HashTable& datapoint_attr = datapoint->third();

  // return error if the attribute is missing on a datapoint
  //
  if (!datapoint_attr.containsKey(attr_name)) {
    return Error::handle(name(), L"computeLikelihoodSplitNode",
			 ERR, __FILE__, __LINE__);
  }

  // loop over all the values for this attribute, compute likelihood
  // for each hypothetical child and accumulate them into the split
  // likelihood
  //
  for (bool8 l = attr_values.gotoFirst(); l; l = attr_values.gotoNext()) {

    // local variables: a temporary (stack-local) tree node wrapping
    // the would-be child's data; USER mode so the child list only
    // references the parent's datapoints
    //
    TreeNode node;
    PhoneticDecisionTreeNode child_pdt_node;
    float32 likelihood = (float32)0;
    Data& child_data = child_pdt_node.getDataPoints();
    child_data.setAllocationMode(DstrBase::USER);

    // get the value of the current attribute
    //
    String* value = attr_values.getCurr();

    // loop over data and gather the datapoints carrying this value
    //
    for (bool8 j = data.gotoFirst(); j; j = data.gotoNext()) {

      // get the data point in triple
      //
      datapoint = data.getCurr();

      // get the attribute value in hashtable
      //
      HashTable& datapoint_attr = datapoint->third();

      // check if this datapoint has this value for the current
      // attribute and add this to the singlelinked list
      //
      if(value->eq(*datapoint_attr.get(attr_name))) {
	child_data.insert(datapoint);
      }
    }

    // set the singlelinked list of data at this child node
    //
    node.setItem(&child_pdt_node);

    // compute the likelihood for the child node only if there is
    // data on the node; an empty child marks the split invalid
    // (res = false) but contributes zero likelihood
    //
    if (!child_data.isEmpty()) {
      if (!computeLikelihoodNode(&node, likelihood)) {
	return Error::handle(name(), L"computeLikelihoodSplitNode",
			     ERR, __FILE__, __LINE__);
      }
    }
    else
      res = false;

    // accumulate the split likelihood over the attribute values
    //
    split_likelihood_a += likelihood;
  }

  // exit gracefully
  //
  return res;
}

// method: classifyData
//
// arguments:
//  TreeNode* node: (input) input node
//  Attribute& attribute: (input) input attribute that's used for
//   splitting the input node
//
// return: a bool8 value indicating status
//
// this method classifies the input node using the input attribute and
// adds the split nodes as the children nodes to the input node in
// the decisiontree. this method also records the best question at the
// parent-node
//
bool8 PhoneticDecisionTree::classifyData(TreeNode* node_a,
					 Attribute& attribute_a) {

  // local variables
  //
  PhoneticDecisionTreeNode* pdt_node = (PhoneticDecisionTreeNode*)NULL;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"classifyData - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // get the name and values for the attribute
  //
  String& attr_name = attribute_a.first();
  SingleLinkedList& attr_values = attribute_a.second();

  // get the data on this node
  //
  pdt_node = node_a->getItem();

  // check the node
  //
  if (pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"classifyData - NULL PDTNODE",
			 Error::ARG, __FILE__, __LINE__);
  }
  Data& data = pdt_node->getDataPoints();

  // save the best attribute at this node
  //
  pdt_node->setBestAttribute(attr_name);

  // loop over attribute values and set the data in the child node
  // for each attribute value
  //
  for(bool8 more = attr_values.gotoFirst(); more;
      more = attr_values.gotoNext()) {

    // local variables
    //
    PhoneticDecisionTreeNode child_pdt_node;
    Data data_child;

    // get the value of the current attribute
    //
    String* attr_value = attr_values.getCurr();

    // loop over the data and accumulate the data with this attribute
    // value
    //
    for (bool8 j = data.gotoFirst(); j; j = data.gotoNext()) {

      // get the data point in triple
      //
      DataPoint* datapoint = data.getCurr();

      // get the hash table of attributes for this datapoint
      //
      HashTable& datapoint_attr = datapoint->third();
      if (attr_value->eq(*datapoint_attr.get(attr_name))) {

	// copy the datapoint and insert into the corresponding child
	// node
	//
	data_child.insert(datapoint);
      }
    }

    // return error if the input node is PURE
    //
    if (data_child.length() <= 0) {
      return Error::handle(name(),
			   L"classifyData - PURE-NODE CANNOT BE CLASSIFIED",
			   Error::ARG, __FILE__, __LINE__);
    }

    // add this node to the graph and make connections.
    // NOTE(review): child_pdt_node is a loop-local; this is only safe
    // if insertVertex copies the item into the graph rather than
    // storing the pointer -- confirm against the Graph allocation mode
    //
    child_pdt_node.setDataPoints(data_child);
    TreeNode* node_child = insertVertex(&child_pdt_node);
    insertArc(node_a, node_child, true);
  }

  // exit gracefully
  //
  return true;
}

// method: computeSumOccupancy
//
// arguments:
//  float32& sum_num_occ: (output) sum of the occupancies of all the
//   datapoints (statistical models) at the input node
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status
//
// this method sums the occupancies of all the datapoints.
// bool8 PhoneticDecisionTree::computeSumOccupancy(TreeNode* node_a, float32& sum_num_occ_a) { // local variables // float64 datapoint_num_occ; sum_num_occ_a = (float64)0; // check the node // if (node_a == (TreeNode*)NULL) { return Error::handle(name(), L"computeSumOccupancy - NULL VERTEX", Error::ARG, __FILE__, __LINE__); } // get the data in singlelinked list // PhoneticDecisionTreeNode* pdt_node = node_a->getItem(); Data& data = pdt_node->getDataPoints(); // loop over data and compute the sum of the occupancies of all the // datapoints on the node // for (bool8 j = data.gotoFirst(); j; j = data.gotoNext()) { // get the datapoint in triple // DataPoint* datapoint = data.getCurr(); // get datapoint statistical model // StatisticalModel& datapoint_stat_model = datapoint->second(); // get the occupancy // datapoint_num_occ = (float64)0; datapoint_num_occ = datapoint_stat_model.getOccupancy(); // add this occupancy to the sum // sum_num_occ_a += datapoint_num_occ; } // exit gracefully // return true; } // method: isSplitOccupancyBelowThreshold // // arguments: // TreeNode* node: (input) input node // Attribute& attribute: (input) input attribute used to split the input node // // return: a bool8 value indicating status // // this method checks if the occupancies for each of the child nodes // generated by classifying the input node meets the threshold // conditions. 
//
// NOTE(review): despite the name, this method returns true when every
// child's summed occupancy EXCEEDS num_occ_threshold_d (i.e. the split
// is acceptable), and false when any child falls at or below it.
// callers (findBestAttribute) rely on this inverted sense.
//
bool8 PhoneticDecisionTree::isSplitOccupancyBelowThreshold(TreeNode* node_a,
							   Attribute& attribute_a) {

  // local variables
  //
  bool8 res = true;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(),
			 L"isSplitOccupancyBelowThreshold - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // get the name and values of this attribute
  //
  String& attr_name = attribute_a.first();
  SingleLinkedList& attr_values = attribute_a.second();

  // get the data points on the node
  //
  PhoneticDecisionTreeNode* pdt_node = node_a->getItem();
  Data& data = pdt_node->getDataPoints();

  // get the first datapoint from the singlelinked list.
  // NOTE(review): not NULL-checked before third() is called --
  // assumes the node is non-empty
  //
  DataPoint* datapoint = data.getFirst();

  // check if this attribute is on the first item of the current node
  //
  HashTable& datapoint_attr = datapoint->third();

  // return error if the attribute is missing on a datapoint
  //
  if (!datapoint_attr.containsKey(attr_name)) {
    return Error::handle(name(), L"isSplitOccupancyBelowThreshold",
			 ERR, __FILE__, __LINE__);
  }

  // loop over all the values for this attribute, compute the summed
  // state-occupancy for each hypothetical child and check it against
  // the threshold
  //
  for (bool8 l = attr_values.gotoFirst(); l; l = attr_values.gotoNext()) {

    // local variables: a temporary child node referencing (USER mode)
    // the parent's datapoints carrying this value
    //
    TreeNode node;
    PhoneticDecisionTreeNode child_pdt_node;
    float32 sum_num_occ = (float32)0;
    Data& child_data = child_pdt_node.getDataPoints();
    child_data.setAllocationMode(DstrBase::USER);

    // get the value of the current attribute
    //
    String* value = attr_values.getCurr();

    // loop over data and gather the datapoints with this value
    //
    for (bool8 j = data.gotoFirst(); j; j = data.gotoNext()) {

      // get the data point in triple
      //
      datapoint = data.getCurr();

      // get the attribute value in hashtable
      //
      HashTable& datapoint_attr = datapoint->third();

      // check if this datapoint has this value for the current
      // attribute and add this to the singlelinked list
      //
      if (value->eq(*datapoint_attr.get(attr_name))) {
	child_data.insert(datapoint);
      }
    }

    // set the singlelinked list of data at this child node
    //
    node.setItem(&child_pdt_node);

    // compute the summed occupancy for the child node
    //
    computeSumOccupancy(&node, sum_num_occ);

    // any child at or below the occupancy threshold invalidates the
    // whole split
    //
    if ( sum_num_occ <= num_occ_threshold_d) {
      res = false;
    }
  }

  // exit gracefully
  //
  return res;
}

// method: mergeSubTree
//
// arguments:
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status
//
// this method merges two leaf nodes at a time, below a given node,
// repeating while some pair's likelihood loss stays under the merge
// threshold.
//
bool8 PhoneticDecisionTree::mergeSubTree(TreeNode* node_a) {

  // define local variable
  //
  TreeNode* node = (TreeNode*)NULL;
  TreeNode* start_node = (TreeNode*)NULL;;
  bool8 res = true;
  bool8 merge = true;

  // check the node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"mergeSubTree - NULL VERTEX",
			 Error::ARG, __FILE__, __LINE__);
  }

  // merge the leaf nodes, two at a time, till the threshold
  // conditions are met. each pass re-collects the leaves, performs at
  // most one merge (note the break below) and then starts over
  //
  while (merge) {

    // define local variable
    //
    TreeNode* best_node = (TreeNode*)NULL;
    float32 first_likelihood = (float32)0;
    SingleLinkedList leaf_nodes(DstrBase::USER);
    merge = false;

    // get all the leaf nodes below the node
    //
    leaf_nodes.clear(Integral::RESET);
    node = node_a;
    res = getLeafNodes(*node, leaf_nodes);

    // outer loop: each leaf in turn acts as the merge "start" node
    //
    for (bool8 more = leaf_nodes.gotoFirst(); more;
	 more = leaf_nodes.gotoNext() ) {
      start_node = leaf_nodes.getCurr();

      // check the node
      //
      if (start_node == (TreeNode*)NULL) {
	return Error::handle(name(), L"mergeSubTree - NULL VERTEX",
			     Error::ARG, __FILE__, __LINE__);
      }

      // mark the current position so the outer iteration can resume
      // here after the inner scan
      //
      leaf_nodes.setMark();

      // initialize the min_dec_likelihood to the merge_threshold:
      // only pairs that lose less than the threshold are candidates
      //
      float32 min_dec_likelihood = merge_threshold_d;

      // compute the likelihood of the first node
      //
      computeLikelihoodNode(start_node, first_likelihood);

      // inner loop: scan the leaves after the mark for the partner
      // whose merge with start_node loses the least likelihood.
      // NOTE(review): the inner loop variable shadows the outer
      // "more" and advances the shared list cursor; gotoMark below
      // restores it
      //
      for (bool8 more = leaf_nodes.gotoNext(); more;
	   more = leaf_nodes.gotoNext() ) {
	node = leaf_nodes.getCurr();

	// check the node
	//
	if (node == (TreeNode*)NULL) {
	  return Error::handle(name(), L"mergeSubTree - NULL VERTEX",
			       Error::ARG, __FILE__, __LINE__);
	}

	// define local variable
	//
	float32 merge_likelihood = (float32)0;
	float32 dec_likelihood = (float32)0;
	float32 second_likelihood = (float32)0;
	bool8 att = false;

	// find the likelihood after merging the start and this
	// node. if any of the nodes is non-existing this function
	// returns false
	//
	att = computeLikelihoodMergeNodes(start_node, node,
					  merge_likelihood);

	// compute decrease in likelihood due to merging of the two
	// nodes
	//
	computeLikelihoodNode(node, second_likelihood);
	dec_likelihood = first_likelihood + second_likelihood
	  - merge_likelihood;

	// update the best node that will be merged with the start
	// node, if the decrease in likelihood is less than the
	// min_dec_likelihood (initialized to the merge threshold)
	//
	if ( att && (dec_likelihood < min_dec_likelihood)) {
	  min_dec_likelihood = dec_likelihood;
	  best_node = node;
	  merge = true;
	}
      }

      // merge the best_node with start_node only if the decrease in
      // likelihood is less than the merge_threshold, then mark
      // the best_node as non-existing and update its typical-index to
      // the typical-index of the start-node to which it is merged.
      // break so the while loop re-collects the leaves from scratch
      //
      if (merge) {
	res = mergeLeafNodes(start_node, best_node);

	// mark the best-node as non-existing since it is merged with
	// the start-node
	//
	bool8 flag = false;
	res = markNode(best_node, flag);
	res = updateTypicalIndex(start_node, best_node);
	break;
      }

      // go back to the marked node
      //
      leaf_nodes.gotoMark();
    }
  }

  // exit gracefully
  //
  return res;
}

// method: computeLikelihoodMergeNodes
//
// arguments:
//  TreeNode* start_node: (input) input start_node
//  TreeNode* node: (input) input node
//
// return: a bool8 value indicating status
//
// this method computes the likelihood if the two input nodes are
// merged.
//
bool8 PhoneticDecisionTree::
computeLikelihoodMergeNodes(TreeNode* start_node_a, TreeNode* node_a,
                            float32& merge_likelihood_a) {

  // local variables
  //
  TreeNode node;
  Data parent_data;
  bool8 res = true;
  bool8 start_flag = true;
  bool8 flag = true;

  // check the input nodes
  //
  if (start_node_a == (TreeNode*)NULL) {
    return Error::handle(name(),
                         L"computeLikelihoodMergeNodes - NULL INPUT-VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(),
                         L"computeLikelihoodMergeNodes - NULL INPUT-VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // check if this node exists
  //
  PhoneticDecisionTreeNode* pdt_start_node = start_node_a->getItem();
  start_flag = pdt_start_node->getFlagExists();

  // get the data points on the start_node
  //
  Data& data_start_node = pdt_start_node->getDataPoints();

  // loop over data and insert the datapoint into the parent
  //
  for (bool8 j = data_start_node.gotoFirst(); j;
       j = data_start_node.gotoNext()) {

    // get the data point in triple
    //
    DataPoint* datapoint_start_node = data_start_node.getCurr();

    // add this datapoint to the singlelinked list
    //
    parent_data.insert(datapoint_start_node);
  }

  // check if this node exists
  //
  PhoneticDecisionTreeNode* pdt_node = node_a->getItem();
  flag = pdt_node->getFlagExists();

  // exit gracefully if either node has been merged away already
  //
  if (!start_flag || !flag) {
    return false;
  }

  // get the data points on the node
  //
  Data& data_node = pdt_node->getDataPoints();

  // get the first datapoint from the singlelinked list
  //
  DataPoint* datapoint_node = data_node.getFirst();

  // loop over data and insert the datapoint into the parent
  //
  for (bool8 j = data_node.gotoFirst(); j; j = data_node.gotoNext()) {

    // get the data point in triple
    //
    datapoint_node = data_node.getCurr();

    // add this datapoint to the singlelinked list
    //
    parent_data.insert(datapoint_node);
  }

  // set the singlelinked list of data at this parent node
  //
  PhoneticDecisionTreeNode parent_pdt_node;
  parent_pdt_node.setDataPoints(parent_data);
  node.setItem(&parent_pdt_node);

  // compute the likelihood for the child node only if there is
  // data on the node
  //
  if (!parent_data.isEmpty()) {
    merge_likelihood_a = (float32)0;
    res = computeLikelihoodNode(&node, merge_likelihood_a);
  }

  // exit gracefully
  //
  return res;
}

// method: mergeLeafNodes
//
// arguments:
//  TreeNode* start_node: (input) input start_node
//  TreeNode* best_node: (input) input best candidate node that will
//                       be merged with the start_node
//
// return: a bool8 value indicating status
//
// this method appends the data on the best_node to the start_node and
// deletes the best_node
//
bool8 PhoneticDecisionTree::mergeLeafNodes(TreeNode* start_node_a,
                                           TreeNode* best_node_a) {

  // local variables
  //
  PhoneticDecisionTreeNode* start_pdt_node = (PhoneticDecisionTreeNode*)NULL;
  PhoneticDecisionTreeNode* best_pdt_node = (PhoneticDecisionTreeNode*)NULL;
  bool8 start_flag = true;
  bool8 best_flag = true;
  bool8 res = true;

  // check the nodes
  //
  if (start_node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"mergeLeafNodes - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  if (best_node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"mergeLeafNodes - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // get the data points on the start_node
  //
  start_pdt_node = start_node_a->getItem();

  // check the data
  //
  if (start_pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"mergeLeafNodes - NULL DATA",
                         Error::ARG, __FILE__, __LINE__);
  }

  // see if the start_node exists, it might have been merged
  //
  Data& start_data = start_pdt_node->getDataPoints();
  start_flag = start_pdt_node->getFlagExists();

  // see if the best_node exists, it might have been merged
  //
  best_pdt_node = best_node_a->getItem();

  // check the data
  //
  if (best_pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"mergeLeafNodes - NULL DATA",
                         Error::ARG, __FILE__, __LINE__);
  }
  best_flag = best_pdt_node->getFlagExists();

  // merge nodes only
// if both exist
//
if (start_flag && best_flag) {

  // get the data points on the best_node
  //
  Data& data = best_pdt_node->getDataPoints();

  // loop over data and insert the datapoint into the parent
  //
  for (bool8 j = data.gotoFirst(); j; j = data.gotoNext()) {

    // get the data point in triple
    //
    DataPoint* datapoint = data.getCurr();

    // add this datapoint to the singlelinked list
    //
    start_data.insert(datapoint);
  }

  // set the singlelinked list of data at the start_node.
  // NOTE(review): this call is commented out; start_data appears to be a
  // reference into the node so the inserts above already take effect --
  // confirm
  //
  // start_pdt_node->setDataPoints(start_data);
}
else
  res = false;

// exit gracefully
//
return res;
}

// method: classifyDataPoint
//
// arguments:
//  DataPoint& datapoint: (input) input data-point that will be classified
//
// return: a Long value (index) indicating the class
//
// this method classifies the data
//
Long PhoneticDecisionTree::classifyDataPoint(DataPoint& datapoint_a) {

  // local variables
  //
  Long index = -1;

  // runmode: TEST && stopmode: THRESH
  //
  if ((runmode_d == TEST) && (stopmode_d == THRESH)){

    // classify the datapoint in test mode
    //
    // classify the input datapoint on the basis of its attributes,
    // starting from the root-node, till it falls into one of the
    // leaf-nodes (classes)
    //
    // local variables
    //
    TreeNode* root_node = (TreeNode*)NULL;

    // get the root node
    //
    root_node = getFirst();

    // check the node
    //
    if (root_node == (TreeNode*)NULL) {
      return Error::handle(name(), L"classifyDataPoint - NULL VERTEX",
                           Error::ARG, __FILE__, __LINE__);
    }

    // get the index
    //
    index = findClass(root_node, datapoint_a);
  }

  // error: unknown mode
  //
  else {
    return Error::handle(name(), L"classifyDataPoint", ERR,
                         __FILE__, __LINE__);
  }

  // exit gracefully
  //
  return index;
}

// method: findClass
//
// arguments:
//  TreeNode* node: (input) input node, below which the input datapoint will
//                  be classified to any of the leaf-nodes
//  DataPoint& datapoint: (input) input data-point that will be classified
//
// return: a Long value (index) indicating the class
//
// this method classifies the data
//
Long
PhoneticDecisionTree::findClass(TreeNode* node_a, DataPoint& datapoint_a) {

  // local variables
  //
  Long index = -1;

  // get the PhoneticDecisionTreeNode
  //
  PhoneticDecisionTreeNode* pdt_node = node_a->getItem();

  // iterate though this function only if this is not a leaf-node
  //
  if (node_a->gotoFirstChild()) {

    // get the best-attribute-name at this node
    //
    String& best_attr_name = pdt_node->getBestAttribute();

    // get the value of the best-attribute-name at this datapoint
    //
    String* best_attr_value = datapoint_a.third().get(best_attr_name);

    // get all the attribute-values corresponding to this attribute-name
    //
    Attribute* attribute = (Attribute*)NULL;

    // loop over all the attributes and find the corresponding attribute
    //
    for (bool8 more = attributes_d.gotoFirst(); more;
         more = attributes_d.gotoNext()) {

      // get the attribute-values corresponding to the best-attribute-name
      //
      if (attributes_d.getCurr()->first().eq(best_attr_name)) {
        attribute = attributes_d.getCurr();
        break;
      }
    }

    // make sure the attribute is defined
    //
    if (attribute == (Attribute*)NULL) {
      return Error::handle(name(), L"findClass - null attribute",
                           Error::ARG, __FILE__, __LINE__);
    }

    // loop-over all the children-nodes of this node and classify the
    // datapoint to one of them that has the same
    // attribute-value. continue to iterate till we hit any of the
    // leaf-nodes
    //
    // get all the child nodes of this node
    //
    DoubleLinkedList >* children;
    BiGraphArc* child;
    BiGraphVertex* child_node;
    children = node_a->getChildren();

    // loop over all the children
    //
    attribute->second().gotoFirst();
    for (bool8 moreb = children->gotoFirst(); moreb;
         moreb = children->gotoNext()) {

      // get the attr-value corresponding to the child-node. note
      // that there is one-to-one correspondence between them
      //
      String* temp_attr_value;
      temp_attr_value = attribute->second().getCurr();
      attribute->second().gotoNext();

      // if the attr-value at this datapoint match, we have got the
      // corresponding child node
      //
      if (temp_attr_value->eq(*best_attr_value)) {
        child = children->getCurr();
        child_node = child->getVertex();

        // call this function iteratively.
        // NOTE(review): the loop does not break after a match, so the
        // last matching child determines the result -- confirm intended
        //
        index = findClass(child_node, datapoint_a);
      }
    }
  }

  // else this node is a leaf node, and so return the typical index
  //
  else {

    // return the typical index at this leaf-node
    //
    index = pdt_node->getTypicalIndex();

    // error: invalid statistical-model index
    //
    if (index == PhoneticDecisionTreeNode::DEF_TYPICAL_INDEX) {
      return Error::handle(name(), L"classifyData", ERR, __FILE__, __LINE__);
    }

    // exit gracefully
    //
    return index;
  }

  // exit gracefully
  //
  return index;
}

// method: findTypicalIndex
//
// arguments:
//  TreeNode* node: (input) input node
//
// return: a index that indicates the typical model at this node
//
Long PhoneticDecisionTree::findTypicalIndex(TreeNode* node_a) {

  // local variables
  //
  TreeNode* node = (TreeNode*)NULL;

  // get the node
  //
  node = node_a;

  // check the node
  //
  if (node == (TreeNode*)NULL) {
    return Error::handle(name(), L"findTypicalNode - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // local variables
  //
  Long typical_index = (Long)-1;
  PhoneticDecisionTreeNode pdt_node;
  Data data;
  DataPoint datapoint;
  float64 max_scale = (float64)-1000;

  // loop-over all the data at this leaf-node and find the typical
  // statistical-model to which all the rest of the
  // statistical-models at this node will be tied to.
// the model with
// the highest scale or lowest variance is a typical model
//
// get all the data points on the node
//
pdt_node = *(node->getItem());
data = pdt_node.getDataPoints();

for (bool8 morea = data.gotoFirst(); morea; morea = data.gotoNext()) {

  // local variables
  //
  float64 scale;
  StatisticalModel datapoint_stat_model;
  Long temp_index;

  // get the inverse scale for this datapoint
  //
  datapoint = *(data.getCurr());
  datapoint_stat_model = datapoint.second();
  scale = computeScale(datapoint_stat_model);
  scale = 2.0 * scale;

  // get the index of this statistical-model
  //
  temp_index = datapoint.first();

  // find the typical statistical model index
  //
  if ( scale > max_scale) {
    max_scale = scale;
    typical_index = temp_index;
  }
}

// exit gracefully
//
return typical_index;
}

// method: computeScale
//
// arguments:
//  StatisticalModel stat_model: (input) input statistical model
//
// return: this method returns the scale of the input StatisticalModel
// (GaussianModel) with single mixture
//
float64 PhoneticDecisionTree::computeScale(StatisticalModel& stat_model_a) {

  // temporary variables
  //
  float64 scale;
  VectorFloat mean;
  MatrixFloat covariance;

  // compute only if the underlying model is MixtureModel, else
  // error
  //
  if (stat_model_a.getType() == StatisticalModel::MIXTURE_MODEL) {

    // local variables
    //
    MixtureModel mixture_model;
    SingleLinkedList mixtures;
    StatisticalModel mixture;

    // get the mixtures as SingleLinkedList of StatisticalModels
    //
    mixture_model = stat_model_a.getMixtureModel();
    mixtures = mixture_model.getModels();

    // get the first mixture as StatisticalModel
    //
    mixture = *(mixtures.getFirst());

    // check if the distribution is gaussian and single mixture
    //
    if ((mixture.getType() == StatisticalModel::GAUSSIAN_MODEL) &&
        (mixtures.length() == (int32)1) ) {

      // get number of features from the mean dimensions last
      // datapoint statistical node, set before this if statement
      //
      // local variables
      //
      VectorFloat mean;
      MatrixFloat covariance;

      // get the mean and covariance of the statistical model. note
      // that the underlying distribution are actually gaussian in
      // this case
      //
      stat_model_a.getMean(mean);
      stat_model_a.getCovariance(covariance);

      // check the arguments.
      // NOTE(review): returning false from a float64-returning method
      // yields 0.0 to the caller -- confirm this is the intended
      // sentinel for a dimension mismatch
      //
      int32 len_mean = mean.length();
      int32 len_cov = covariance.getNumRows();
      if ((len_mean != len_cov) || (len_mean <= 0) || (len_cov <= 0)) {
        return false;
      }

      // compute the scale factor from its components.
      //
      float64 det = Integral::log(covariance.determinant());
      float64 tmp = Integral::log(Integral::TWO_PI);
      scale = (float64)0.5 * ((float64)len_mean * tmp + det);
    }

    // error: unknown distribution and multiple mixture models
    //
    else {
      return Error::handle(name(), L"computeScale", ERR, __FILE__, __LINE__);
    }
  }

  // error: only MixtureModel supported
  //
  else {
    return Error::handle(name(), L"computeScale", ERR, __FILE__, __LINE__);
  }

  // exit gracefully
  //
  return scale;
}

// method: createContexts
//
// arguments:
//  Vector& symbols: (input) input symbols
//  int32& length: (input) length of the contexts
//  Vector& all_contexts: (output) all contexts
//
// return: a bool8 value indicating status
//
// this method creates all possible contexts
//
bool8 PhoneticDecisionTree::createContexts(Vector& symbols_a,
                                           int32& length_a,
                                           Vector& all_contexts_a) {

  // local variable
  //
  bool8 status = true;

  // set the capacity of the all-context vector
  //
  int32 capacity = (int32)Integral::pow((float64)symbols_a.length(),
                                        length_a);
  all_contexts_a.setCapacity(capacity);

  // loop over the number of context-length.
// we'll add the symbol for
// a certain context length in each iteration
//
for (int32 i = 0; i < length_a; i++) {

  // call the function that appends the context at
  // each-context-level
  //
  status = appendContextLevel(symbols_a, i, all_contexts_a);
}

// exit gracefully
//
return status;
}

// method: appendContextLevel
//
// arguments:
//  Vector& symbols: (input) input symbols
//  int32& level: (input) level of the contexts
//  Vector& all_contexts: (output) appended contexts
//
// return: a bool8 value indicating status
//
// this method appends the symbols to the contexts at any given
// context-level
//
bool8 PhoneticDecisionTree::appendContextLevel(Vector& symbols_a,
                                               int32& level_a,
                                               Vector& all_contexts_a) {

  // local variable
  //
  bool8 status = true;
  Vector temp_all_contexts;

  // set the capacity of the temp-all-context vector
  //
  int32 capacity = (int32)Integral::pow((float64)symbols_a.length(),
                                        (level_a +1));
  temp_all_contexts.setCapacity(capacity);

  // loop-over all the existing contexts, remove the existing ones and
  // add the new ones with existing ones
  //
  int32 all_contexts_len = all_contexts_a.length();

  // increment the all_context_length to 1 if the level is zero. this
  // means that no contexts exist
  //
  if ( level_a == (int32)0) {
    all_contexts_len++;
  }

  for (int32 j = 0; j < all_contexts_len; j++) {

    // get the current context, only if the context exists
    //
    Vector vec_ss;
    if (level_a != (int32)0) {
      vec_ss = all_contexts_a(j).getContext();
    }

    // loop-over all the symbols and add these to each of the
    // context. this will increase the context-length
    //
    for (int32 i = 0; i < symbols_a.length(); i ++) {

      // append the context with the symbol
      //
      ContextMap context;
      Vector temp_vec_ss;
      temp_vec_ss.assign(vec_ss);
      temp_vec_ss.concat(symbols_a(i));
      context.setContext(temp_vec_ss);

      // add the context (partial or full) in all-context vector
      //
      int32 context_len = temp_all_contexts.length();
      temp_all_contexts.setLength(context_len + (int32)1);
      temp_all_contexts(context_len).assign(context);
    }
  }

  // assign the contexts
  //
  all_contexts_a.clear();
  all_contexts_a.assign(temp_all_contexts);

  // exit gracefully
  //
  return status;
}

// method: validateContexts
//
// arguments:
//  Vector& contextless_symbol_table: (input) contextless
//                                    symbol-table
//  Vector& all_contexts: (input) all contexts
//  Vector& valid_contexts: (output) all valid contexts
//
// return: a bool8 value indicating status
//
// this method removes all the contexts that are non-allowable. the
// non-allowable contexts are:
//  1. NO_LEFT_CONTEXT can't occur as right symbol in the context
//  2. NO_RIGHT_CONTEXT can't occur as left symbol in the context
//  3. both of these can't occur as central symbol
//  4. all contexts with central symbol as contextless symbols
//
bool8
PhoneticDecisionTree::validateContexts(Vector& contextless_symbol_table_a,
                                       Vector& all_contexts_a,
                                       Vector& valid_contexts_a) {

  // local variable
  //
  bool8 status = true;

  // set the capacity of the valid-contexts vector
  //
  int32 capacity = all_contexts_a.length();
  valid_contexts_a.setCapacity(capacity);

  // loop-over all the contexts and accumulate the valid-contexts
  //
  for (int32 i = 0; i < all_contexts_a.length(); i++) {

    // get the context and check of all the 4 invalid conditions
    //
    bool8 valid = true;
    Vector vec_ss = all_contexts_a(i).getContext();
    int32 vec_len = vec_ss.length();

    for (int32 j = 0; j < vec_len; j++) {

      // check for condition 1
      //
      if ( (j < (vec_len/2)) &&
           (vec_ss(j).eq(SearchSymbol::NO_RIGHT_CONTEXT)) ) {
        valid = false;
      }

      // check for condition 2
      //
      if ( (j > (vec_len/2)) &&
           (vec_ss(j).eq(SearchSymbol::NO_LEFT_CONTEXT)) ) {
        valid = false;
      }

      // check for condition 3
      //
      if ((j == (vec_len/2)) &&
          ((vec_ss(j).eq(SearchSymbol::NO_LEFT_CONTEXT)) ||
           (vec_ss(j).eq(SearchSymbol::NO_RIGHT_CONTEXT)))) {
        valid = false;
      }

      // check for condition 4
      //
      if ((j == (vec_len/2)) &&
          (contextless_symbol_table_a.contains(&vec_ss(j)))) {
        valid = false;
      }
    }

    // if valid context, accumulate the context
    //
    if (valid) {

      // add the valid-context
      //
      int32 context_len = valid_contexts_a.length();
      valid_contexts_a.setLength(context_len + (int32)1);
      valid_contexts_a(context_len).assign(all_contexts_a(i));
    }
  }

  // reset the capacity of the valid-contexts vector
  //
  capacity = valid_contexts_a.length();
  valid_contexts_a.setCapacity(capacity);

  // exit gracefully
  //
  return status;
}

// method: getUnseenContexts
//
// arguments:
//  Vector& seen_contexts: (input) contexts seen
//  Vector& valid_contexts: (input) all valid contexts
//  Vector& unseen_contexts: (output) contexts unseen
//
// return: a bool8 value indicating status
//
// this method gets all the contexts that are not in the input
// contexts
//
bool8
PhoneticDecisionTree::getUnseenContexts(Vector& seen_contexts_a,
                                        Vector& valid_contexts_a,
                                        Vector& unseen_contexts_a) {

  // local variables
  //
  bool8 status = true;

  // set the capacity of the unseen-contexts vector
  //
  int32 capacity = valid_contexts_a.length();
  unseen_contexts_a.setCapacity(capacity);

  // loop-over all the valid-contexts and accumulate the unseen-contexts
  //
  for (int32 i = 0; i < valid_contexts_a.length(); i++) {

    // local variables
    //
    bool8 seen = false;

    // get the context
    //
    Vector valid_vec_ss = valid_contexts_a(i).getContext();

    // add this valid context-map to unseen-context-maps if it doesn't
    // exist in seen-contexts
    //
    for (int32 j = 0; j < seen_contexts_a.length(); j++) {

      // get the context
      //
      Vector seen_vec_ss = seen_contexts_a(j).getContext();
      if (seen_vec_ss.eq(valid_vec_ss)) {
        seen = true;
      }
    }

    if (!seen) {
      int32 len = unseen_contexts_a.length();
      unseen_contexts_a.setLength(len + 1);
      unseen_contexts_a(len).assign(valid_contexts_a(i));
    }
  }

  // reset the capacity of the valid-contexts vector
  //
  capacity = unseen_contexts_a.length();
  unseen_contexts_a.setCapacity(capacity);

  // exit gracefully
  //
  return status;
}

// method: updateLowerLevel
//
// arguments:
//  Vector& context_map: (input) contexts seen
//  Vector& unseen_context_map: (input) contexts unseen
//  Vector >& sub_graphs: (input) sub-graphs at lowest-level
//  Vector& symbol_table: (input) symbol-table at lowest-level
//  HashTable& symbol_hash: (input) mapping-table at
//                          the lowest-level
//
// return: a bool8 value indicating status
//
// this method update the lowest level for unseen-contexts
//
bool8
PhoneticDecisionTree::updateLowerLevel(Vector& context_map_a,
                                       Vector& unseen_context_map_a,
                                       Vector >& sub_graphs_a,
                                       Vector& symbol_table_a,
                                       HashTable& symbol_hash_a) {

  // update the capacity of the context_map_out, sub-graphs
  //
  int32 len_context_map = context_map_a.length();
  int32 len_unseen_context_map = unseen_context_map_a.length();
  int32 cap_context_map = len_context_map + len_unseen_context_map;
  context_map_a.setCapacity(cap_context_map);

  int32 len_sub_graphs = sub_graphs_a.length();
  int32 cap_sub_graphs = len_sub_graphs + len_unseen_context_map;
  sub_graphs_a.setCapacity(cap_sub_graphs);

  // update the capacity of the symbol-table
  //
  DiGraph temp_graph_copy;
  sub_graphs_a(len_sub_graphs -1).setAllocationMode(DstrBase::SYSTEM);
  temp_graph_copy.assign(sub_graphs_a(len_sub_graphs - 1));
  sub_graphs_a(len_sub_graphs -1).setAllocationMode(DstrBase::USER);

  int32 len_symbols_per_graph = 0;
  for (bool8 morea = temp_graph_copy.gotoFirst(); morea;
       morea = temp_graph_copy.gotoNext()) {
    len_symbols_per_graph++;
  }
  int32 cap_symbol_table = symbol_table_a.length() +
    (len_symbols_per_graph * len_unseen_context_map);
  symbol_table_a.setCapacity(cap_symbol_table);

  // loop-over all the unseen context-maps
  //
  for (int32 i = 0; i < unseen_context_map_a.length(); i++) {

    // local variables
    //
    int32 curr_index = 0;
    int32 symbol_index = 0;
    int32 tmp_index = 0;
    int32 len = 0;
    Long val;
    ContextMap context;
    DiGraph graph_copy;

    // add the unseen context and its index to the context-maps
    //
    curr_index = sub_graphs_a.length();
    context = unseen_context_map_a(i);
    context.setContextIndex((ulong)curr_index);
    len = context_map_a.length();
    context_map_a.setLength(len + 1);
    context_map_a(len).assign(context);

    // update the context-indices of the unseen-contexts also
    //
    unseen_context_map_a(i).assign(context);

    // create a copy of the subgraph corresponding to the central context
    // symbol and append the copy to the vector of subgraphs
    //
    sub_graphs_a(curr_index -1).setAllocationMode(DstrBase::SYSTEM);
    graph_copy.assign(sub_graphs_a(curr_index - 1));
    graph_copy.setAllocationMode(DstrBase::USER);
    sub_graphs_a(curr_index -1).setAllocationMode(DstrBase::USER);
    sub_graphs_a.concat(graph_copy);

    // loop over each vertex in the subgraph of the appended copy
    //
    SearchSymbol ss;
    for (bool8 more = sub_graphs_a(curr_index).gotoFirst(); more;
         more = sub_graphs_a(curr_index).gotoNext()) {

      // retrieve the symbol index corresponding to the current vertex
      //
      symbol_index =
        sub_graphs_a(curr_index).getCurr()->getItem()->getSymbolId();

      // create a new search symbol and append it to the symbol table
      //
      ss.assign(L"S_");
      tmp_index = (int32)symbol_table_a.length() + 1;
      ss.concat(tmp_index);
      while (symbol_table_a.contains(&ss)) {
        ss.assign(L"S_");
        ss.concat(++tmp_index);
      }
      symbol_table_a.concat(ss);
      val.assign(symbol_table_a.length() - 1);
      symbol_hash_a.insert(ss, &val);
      sub_graphs_a(curr_index).getCurr()->getItem()->setSymbolId((int32)val);
    }
  }

  // exit gracefully
  //
  return true;
}

// method: markNode
//
// arguments:
//  TreeNode* node: (input) input node that will be marked as non-existing
//  bool8& flag_a: (input) flag that will be set at the node
//
// return: a bool8 value indicating status
//
// this method marks the input node
//
bool8 PhoneticDecisionTree::markNode(TreeNode* node_a, bool8& flag_a) {

  // local variables
  //
  PhoneticDecisionTreeNode* pdt_node = (PhoneticDecisionTreeNode*)NULL;
  Data data;
  DataPoint datapoint;
  bool8 res = true;

  // check the input node
  //
  if (node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"markNode - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // get the data points on the node
  //
  pdt_node = node_a->getItem();

  // check the data
  //
  if (pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"markNode - NULL DATA",
                         Error::ARG, __FILE__, __LINE__);
  }

  // mark this node as non-existing
  //
  res = pdt_node->setFlagExists(flag_a);

  // exit gracefully
  //
  return res;
}

// method: updateTypicalIndex
//
// arguments:
//  TreeNode* start_node: (input) input start_node
//  TreeNode* best_node: (input) input best candidate node
//                       whose typical-index will be updated
//
// return: a bool8 value indicating status
//
// this method updates the typical-index of the best-node to the
// typical-index of the start-node.
// this is needed for the test mode
//
bool8 PhoneticDecisionTree::updateTypicalIndex(TreeNode* start_node_a,
                                               TreeNode* best_node_a) {

  // local variables
  //
  PhoneticDecisionTreeNode* start_pdt_node = (PhoneticDecisionTreeNode*)NULL;
  PhoneticDecisionTreeNode* best_pdt_node = (PhoneticDecisionTreeNode*)NULL;
  Long typical_index = -1;
  DataPoint datapoint;
  bool8 res = true;

  // check the nodes
  //
  if (start_node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"updateTypicalIndex - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  if (best_node_a == (TreeNode*)NULL) {
    return Error::handle(name(), L"updateTypicalIndex - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // get the typical-index of the statistical-model at the start-node
  //
  start_pdt_node = start_node_a->getItem();

  // check the data
  //
  if (start_pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"updateTypicalIndex - NULL DATA",
                         Error::ARG, __FILE__, __LINE__);
  }
  typical_index = start_pdt_node->getTypicalIndex();

  // update the typical-index of the statistical-model at the best-node
  //
  best_pdt_node = best_node_a->getItem();

  // check the data
  //
  if (best_pdt_node == (PhoneticDecisionTreeNode*)NULL) {
    return Error::handle(name(), L"updateTypicalIndex - NULL DATA",
                         Error::ARG, __FILE__, __LINE__);
  }
  res = best_pdt_node->setTypicalIndex(typical_index);

  // exit gracefully
  //
  return res;
}

// method: reindexSubTree
//
// arguments:
//  Treenode* node: (input) input node
//  int32& index: (input/output) value of index
//
// return: a bool8 value indicating status
//
// this method reindexes the statistical-models at the leaf-nodes of
// the sub-tree under the input node
//
bool8 PhoneticDecisionTree::reindexSubTree(TreeNode* node_a, int32& index_a) {

  // define local variable
  //
  TreeNode* node = (TreeNode*)NULL;
  bool8 res = false;
  SingleLinkedList leaf_nodes(DstrBase::USER);

  // check the input node
  //
  node = node_a;
  if (node == (TreeNode*)NULL) {
    return Error::handle(name(), L"reindexSubTree - NULL VERTEX",
                         Error::ARG, __FILE__, __LINE__);
  }

  // get all the leaf-nodes below the input node
  //
  res = getLeafNodes(*node, leaf_nodes);

  // loop-over all the leaf-nodes and reindex the statistical models
  //
  for (bool8 more = leaf_nodes.gotoFirst(); more;
       more = leaf_nodes.gotoNext()) {

    // local variables
    //
    TreeNode* temp_node = (TreeNode*)NULL;

    // get the leaf-node
    //
    temp_node = leaf_nodes.getCurr();

    // check the node
    //
    if (temp_node == (TreeNode*)NULL) {
      return Error::handle(name(), L"reindexTrain - NULL VERTEX",
                           Error::ARG, __FILE__, __LINE__);
    }

    // local variables
    //
    PhoneticDecisionTreeNode* pdt_node = (PhoneticDecisionTreeNode*)NULL;;
    Data data;
    DataPoint datapoint;
    float64 max_scale = (float64)-1000;
    DataPoint typical_datapoint;
    StatisticalModel typical_stat_model;

    // loop-over all the data at this leaf-node and find the typical
    // statistical-model to which all the rest of the
    // statistical-models at this node will be tied to. the model with
    // the highest scale or lowest variance is a typical model
    //
    // get all the data points on the node
    //
    pdt_node = temp_node->getItem();

    // check the node
    //
    if (pdt_node == (PhoneticDecisionTreeNode*)NULL) {
      return Error::handle(name(), L"getStatTrain - NULL PDT VERTEX",
                           Error::ARG, __FILE__, __LINE__);
    }

    data = pdt_node->getDataPoints();
    for (bool8 morea = data.gotoFirst(); morea; morea = data.gotoNext()) {

      // local variables
      //
      float64 scale;
      StatisticalModel datapoint_stat_model;
      Long temp_index;

      // get the inverse scale for this datapoint
      //
      datapoint = *(data.getCurr());
      datapoint_stat_model = datapoint.second();
      scale = computeScale(datapoint_stat_model);
      scale = 2.0 * scale;

      // get the index of this statistical-model
      //
      temp_index = datapoint.first();

      // find the typical statistical model
      //
      if ( scale > max_scale) {
        max_scale = scale;
        typical_datapoint = datapoint;
        typical_stat_model = datapoint_stat_model;
      }
    }  // end of for-loop over all datapoints at a leafnode

    // save the new typical-index, actual-index and statistical-model
    // that represents this leaf-node to this leaf node so that we can
    // later retrieve it. typical-index gets updated during merging but
    // actual index remains the same
    //
    pdt_node->setTypicalIndex((Long)index_a);
    pdt_node->setActualIndex((Long)index_a);
    pdt_node->setTypicalStatModel(typical_stat_model);

    // increment the index
    //
    index_a++;
  }  // end of for loop over leaf-nodes

  // exit gracefully
  //
  return res;
}

// method: getCentralSymbols
//
// arguments:
//  Vector& symbol_table: (input) symbol-table
//  Vector& contextless_symbol_table: (input) contextless
//                                    symbol-table
//  SingleLinkedList& central_symbols: (output) central symbols
//
// return: a bool8 value indicating status
//
// this method get the central symbols excluding the contextless
// symbols, NO_LEFT_CONTEXT, and NO_RIGHT_CONTEXT
//
bool8
PhoneticDecisionTree::getCentralSymbols(Vector& symbol_table_a,
                                        Vector& contextless_symbol_table_a,
                                        SingleLinkedList& central_symbols_a) {

  // local variables
  //
  bool8 status = true;

  for (int32 k = 0; k < symbol_table_a.length(); k++) {

    // local variables
    //
    String temp_symbol;
    bool8 include = true;
    temp_symbol.assign(symbol_table_a(k));

    // don't include the contextless symbols
    //
    for (int32 kk = 0; kk < contextless_symbol_table_a.length(); kk++) {
      if (temp_symbol.eq(contextless_symbol_table_a(kk))) {
        include = false;
      }
    }

    // don't include NO_LEFT_CONTEXT and NO_RIGHT_CONTEXT
    //
    if ((temp_symbol.eq(SearchSymbol::NO_LEFT_CONTEXT)) ||
        (temp_symbol.eq(SearchSymbol::NO_RIGHT_CONTEXT))) {
      include = false;
    }

    // include only necessary symbols
    //
    if (include) {
      status = central_symbols_a.insert(&temp_symbol);
    }
  }

  // exit gracefully
  //
  return status;
}

// method: readQuestionAnswer
//
// arguments:
//  Filename& ques_ans_file: (input) file containing questions and answers
//  SingleLinkedList >& questions: (output) phonetic
//                                 questions
//  HashTable& answers: (output) phonetic answers
//
// return: a bool8 value indicating status
//
// this
// method reads the phonetic questions and answers from the input file
//
bool8 PhoneticDecisionTree::
readQuestionAnswer(Filename& ques_ans_file_a,
                   SingleLinkedList >& questions_a,
                   HashTable& answers_a) {

  // local variables
  //
  bool8 status = true;
  Vector > ques_ans;

  // check the question and answers
  //
  if (ques_ans_file_a.length() == 0) {
    return Error::handle(name(), L"loadTrain - invalid question-answer file",
                         ERR, __FILE__, __LINE__);
  }

  // print debugging information
  //
  if (debug_level_d >= Integral::DETAILED) {
    Console::increaseIndention();
    String output;
    output.assign(L"\nloading question and answers: ");
    output.concat(ques_ans_file_a);
    Console::put(output);
    Console::decreaseIndention();
  }

  // open the input sof file
  //
  Sof input_sof;
  if(!input_sof.open(ques_ans_file_a, File::READ_ONLY)) {
    return Error::handle(ques_ans_file_a, L"open", ERR, __FILE__, __LINE__);
  }

  // read the question and answers from an sof file
  //
  if (!ques_ans.read(input_sof, int32(0))) {
    return Error::handle(name(),
                         L"readQuestionAnswer - error reading the question and answer file",
                         Error::ARG, __FILE__, __LINE__);
  }

  // close the input questions and answers file
  //
  input_sof.close();

  // local variables
  //
  SingleLinkedList temp_all_ans;
  String tmp_yes;
  tmp_yes.assign(YES);
  String tmp_no;
  tmp_no.assign(NO);
  temp_all_ans.insert(&tmp_yes);
  temp_all_ans.insert(&tmp_no);

  // read questions and then answers separately that will be used
  // later
  //
  for (int32 k = 0; k < ques_ans.length(); k++) {

    // local variables
    //
    Long temp_index;
    String temp_string;
    String temp_ques;
    String temp_symbol;
    Pair temp_question;

    temp_index = ques_ans(k).first();
    temp_ques = ques_ans(k).second();
    temp_symbol = ques_ans(k).third();

    // read the questions only if the question is not already read
    // before
    //
    temp_question.assign(temp_index, temp_ques);
    if (!questions_a.contains(&temp_question)) {
      questions_a.insert(&temp_question);

      // add the questions & possible-answers to the attributes_d
      //
      Attribute temp_attribute;
      String temp_question_string;
      temp_question_string.assign(temp_index);
      temp_question_string.concat(temp_ques);
      temp_attribute.assign(temp_question_string, temp_all_ans);
      attributes_d.insert(&temp_attribute);
    }

    // read the answers variable
    //
    temp_string.assign(temp_index);
    temp_string.concat(temp_ques);
    temp_string.concat(temp_symbol);
    String temp_yes;
    temp_yes.assign(YES);
    answers_a.insert(temp_string, &temp_yes);
  }

  // exit gracefully
  //
  return status;
}

// method: poolStatisticalModel
//
// arguments:
//  Vector& context_map: (input) ContextMaps
//  Vector& contextless_symbol_table: (input) symbol-table with
//                                    no context (phones)
//  Vector >& sub_graphs: (input) Subgraphs for the
//                        ContextMaps
//  Vector& symbol_table: (input) Symbol Table (states)
//  HashTable& symbol_hash: (input) Mapping from symbols
//                          to StatisticalModels
//  Vector& stat_models: (input) Pool of StatisticalModels
//  int32& context_len: (input) length of the contexts
//  SingleLinkedList >& questions: (input) phonetic
//                                 questions
//  HashTable& answers: (input) phonetic answers
//  Data& data: (output) output pooled StatisticalModels
//  HashTable& tied_symbol_hash: (output) HashTable of
//                               search-symbols whose corresponding
//                               statistical models are tied
//  Vector& tied_stat_models : (output) tied stat models
//
// return: a bool8 value indicating status
//
// this method pools the valid Statistical Models and associates attributes to
// be used in the state-tying process
//
bool8 PhoneticDecisionTree::
poolStatisticalModel(Vector& context_map_a,
                     Vector& contextless_symbol_table_a,
                     Vector >& sub_graphs_a,
                     Vector& symbol_table_a,
                     HashTable& symbol_hash_a,
                     Vector& stat_models_a,
                     int32& context_len_a,
                     SingleLinkedList >& questions_a,
                     HashTable& answers_a,
                     Data& data_a,
                     HashTable& tied_symbol_hash_a,
                     Vector& tied_stat_models_a) {

  // local variables
  //
  bool8 status = true;
  String temp_cph;
  temp_cph.assign(CPH);
  String temp_pos;
  temp_pos.assign(POS);
  Long tied_index = 0;
  HashTable temp_hash;

  // loop-over all the context-maps and pool the
  // statistical-models. check if all the context-maps have the same
  // length as expected
  //
  for (int32 i = 0; i < context_map_a.length(); i++) {

    // local variables
    //
    Vector context;
    Ulong context_index;
    DiGraph sub_graph;
    SingleLinkedList snodes;
    SingleLinkedList sarcs;
    HashTable attr_value;

    // get the context and the context-index
    //
    context = context_map_a(i).getContext();
    context_index = context_map_a(i).getContextIndex();

    // reset the flag if the middle symbol is a contextless-symbol
    //
    Boolean use_context = true;
    int32 temp_context_len = context.length();

    // check if the middle symbol is a contextless-symbol
    //
    for (int32 k = 0; k < contextless_symbol_table_a.length(); k++) {

      // is the central-symbol a contextless-symbol ?
      //
      if (context(temp_context_len/2).eq(contextless_symbol_table_a(k))) {
        use_context = false;
      }
    }

    // proceed further only if the context is valid
    //
    if (use_context) {

      // error if the length of the context is not the same as
      // expected
      //
      if (context.length() != context_len_a) {
        return Error::handle(name(),
                             L"loadTrain - ALL THE CONTEXTS DON'T HAVE SAME LENGTH",
                             Error::ARG, __FILE__, __LINE__);
      }

      // extract the attributes for this context by looping over all the
      // questions
      //
      for (bool8 morea = questions_a.gotoFirst(); morea;
           morea = questions_a.gotoNext()) {

        // local variables
        //
        Pair* question;
        Long direction;
        SearchSymbol symbol;
        String question_string;
        String extended_question;
        String direction_string;
        String* answer;
        String temp_no;
        temp_no.assign(NO);

        // get the direction (this also gives the position) from the
        // question
        //
        question = questions_a.getCurr();
        direction = question->first();

        // get the search symbol corresponding to the direction and
        // position
        //
        symbol = context((context.length()/2) + direction);

        // combine direction(position), question and symbol
        //
        direction_string.assign(direction);
        question_string.concat(direction_string);
        question_string.concat(question->second());
extended_question.assign(question_string); extended_question.concat(symbol); // get the answer for this extended questions from the answers // hashtable. if the hashtable contains the key, it has the // answer. else the answer is NO // if (answers_a.containsKey(extended_question)) { answer = answers_a.get(extended_question); } else answer = &temp_no; // use the question_string and answer to form attribute-value pair // that is added to the attr_value hashtable // attr_value.insert(question_string, answer); } // add the central-symbol (central-phone) to the attributes // String answer; SearchSymbol symbol = context((context.length() - 1)/2); answer.assign(symbol); attr_value.insert(temp_cph, &answer); // get the corresponding sub-graph for this context-index // sub_graph = sub_graphs_a(context_index); // loop-over all the SearchNode's in this sub-graph, get the // symbol id's (states) and then accumulate the statistical-models // corresponding to the symbols (states) // sub_graph.get(snodes, sarcs); // intialize the position of the symbol (state) // int32 position = (int32)0; for (bool8 more = snodes.gotoFirst(); more; more = snodes.gotoNext()) { // local variables // DataPoint data_point; SearchNode* snode; int32 symbol_id; Long* stat_model_index; StatisticalModel stat_model; HashTable temp_attr_value; // get the SearchNode // snode = snodes.getCurr(); // get the symbol id at this search node // symbol_id = snode->getSymbolId(); // proceed only if the symbol is not dummy or default, // I.e. 
symbol-id is non-negative // if(symbol_id >= (int32)0) { // get the index corresponding to this symbol (state) // SearchSymbol search_symbol = symbol_table_a(symbol_id); stat_model_index = symbol_hash_a.get(search_symbol); // get the statistical model // stat_model = stat_models_a(*stat_model_index); // check if this statistical model corresponding to this // search-symbol is tied // if (!isTiedSSymbol(search_symbol, symbol_hash_a)) { // add the symbol-position (state-position) to the // attributes, increment the position // temp_attr_value.assign(attr_value); String position_string; position_string.assign(position); temp_attr_value.insert(temp_pos, &position_string); position++; // get a datapoint by combining statistical-model index, // statistical-model and the hash-table of attribute-value // data_point.assign(*stat_model_index, stat_model, temp_attr_value); // add this data-point to the data // data_a.insert(&data_point); } // end of if tied-symbol // else add this search symbol to the hash table // else { // add the corresponding statistical model to the vector if // it doesn't already exists in the vector // if (!temp_hash.containsKey((Long)*stat_model_index)) { int32 temp_len = tied_stat_models_a.length(); tied_stat_models_a.setLength(temp_len + (int32)1); tied_stat_models_a(temp_len).assign(stat_model); if (!temp_hash.insert(*stat_model_index, &tied_index)) { return Error::handle(name(), L"poolStatisticalModel - nonunique search-symbols", ERR, __FILE__, __LINE__); } tied_index++; } Long tmp_index = 0; tmp_index = *(temp_hash.get(*stat_model_index)); if (!tied_symbol_hash_a.insert(search_symbol, &tmp_index)) { return Error::handle(name(), L"poolStatisticalModel - nonunique search-symbols", ERR, __FILE__, __LINE__); } } // end of else } // end of if statement } // end of loop-over of SearchNodes that contain search-symbol id's } // end of if statement for valid context // else, add the search-symbols and the corresponding // statistical-models corresponding 
to the contexts with // contextless symbols // else { // get the corresponding sub-graph for this context-index // sub_graph = sub_graphs_a(context_index); // loop-over all the SearchNode's in this sub-graph, get the // symbol id's (states) and then accumulate the statistical-models // corresponding to the symbols (states) // sub_graph.get(snodes, sarcs); for (bool8 more = snodes.gotoFirst(); more; more = snodes.gotoNext()) { // local variables // DataPoint data_point; SearchNode* snode; int32 symbol_id; Long* stat_model_index; StatisticalModel stat_model; HashTable temp_attr_value; // get the SearchNode // snode = snodes.getCurr(); // get the symbol id at this search node // symbol_id = snode->getSymbolId(); // proceed only if the symbol is not dummy or default, // I.e. symbol-id is non-negative // if(symbol_id >= (int32)0) { // get the index corresponding to this symbol (state) // SearchSymbol search_symbol = symbol_table_a(symbol_id); stat_model_index = symbol_hash_a.get(search_symbol); // get the statistical model // stat_model = stat_models_a(*stat_model_index); // add the corresponding statistical model to the vector if // it doesn't already exists in the vector // if (!temp_hash.containsKey((Long)*stat_model_index)) { int32 temp_len = tied_stat_models_a.length(); tied_stat_models_a.setLength(temp_len + (int32)1); tied_stat_models_a(temp_len).assign(stat_model); if (!temp_hash.insert(*stat_model_index, &tied_index)) { return Error::handle(name(), L"poolStatisticalModel - nonunique search-symbols", ERR, __FILE__, __LINE__); } tied_index++; } Long tmp_index = 0; tmp_index = *(temp_hash.get(*stat_model_index)); if (!tied_symbol_hash_a.insert(search_symbol, &tmp_index)) { return Error::handle(name(), L"poolStatisticalModel - nonunique search-symbols", ERR, __FILE__, __LINE__); } } } // end of if loop } //end of loop over else } // end of loop-over of context-maps // exit gracefully // return status; } // method: isTiedSSymbol // // arguments: // SearchSymbol& 
search_symbol: (input) a search symbol // HashTable& symbol_hash: (input) Mapping from symbols // to StatisticalModels // // return: a bool8 value indicating if the input search symbol is tied. // bool8 PhoneticDecisionTree::isTiedSSymbol(SearchSymbol& search_symbol_a, HashTable& symbol_hash_a) { // local variables // bool8 status = false; // get the statistical-model index corresponding to the input // search-symbol // Long stat_model_index = 0; stat_model_index = *(symbol_hash_a.get(search_symbol_a)); // loop over all the search-symbols and find if the statistical // model corresponding to this is tied to any other model // // get all the search symbols from the hash table // Vector symbols; if (!symbol_hash_a.keys(symbols)) { return Error::handle(name(), L"isTiedSSymbol", ERR, __FILE__, __LINE__); } // loop over all the symbols // for (int32 i = 0; i < symbols.length(); i++) { // avoid the input search symbol // SearchSymbol temp_symbol = symbols(i); if (!search_symbol_a.eq(temp_symbol)) { if (stat_model_index == (int64)symbol_hash_a.get(temp_symbol)) { status = true; } } } // exit gracefully // return status; }