// file: $isip/class/pr/RegressionDecisionTree/rdt_06.cc // version: $Id: rdt_06.cc 9451 2004-04-13 19:27:43Z gao $ // // isip include files // #include "RegressionDecisionTree.h" // method: createTransform // // arguments: // Vector& stat_models: (input) vector of statistical model // // return: a bool8 value indicating status // // this method create the transform for thr regression decision tree // by using the data from the statistical models // bool8 RegressionDecisionTree::createTransform(Vector& stat_models_a) { // local variables // RTreeNode* root_node = (RTreeNode*)NULL; // get the root node // root_node = getFirst(); // check the node // if (root_node == (RTreeNode*)NULL) { return Error::handle(name(), L"createTransform - NULL VERTEX", Error::ARG, __FILE__, __LINE__); } // create transformation for each node in the regression tree // createTransforms(root_node, stat_models_a); updateTransformID(stat_models_a); // exit gracefully // return true; } // method: createTransforms // // arguments: // RTreeNode*& root_node: (input) root node // Vector& stat_models: (input) vector of statistical model // // return: a bool8 value indicating status // // this method create transformation for each node starting from the // root node by using the data from the statistical model // bool8 RegressionDecisionTree::createTransforms(RTreeNode*& root_node_a, Vector& stat_models_a) { DoubleLinkedList >* children; BiGraphVertex* child_node; // check the node // if (root_node_a != (RTreeNode*)NULL) { // get all the child nodes of this node // children = root_node_a->getChildren(); // if this node has child nodes, accumulate the child nodes // if (root_node_a->gotoFirstChild()) { // loop for the left child // children->gotoFirst(); child_node = children->getCurr()->getVertex(); createTransforms(child_node, stat_models_a); } if (children->gotoNext()) { // loop for the right child // child_node = children->getCurr()->getVertex(); createTransforms(child_node, stat_models_a); } // local variables // RegressionDecisionTreeNode* rdt_node = (RegressionDecisionTreeNode*)NULL; // get the data on this node // rdt_node = root_node_a->getItem(); rdt_node->createTransform(stat_models_a); if (debug_level_d >= Integral::ALL) { rdt_node->getNodeIndex().debug(L"node index -- transformation"); } } // exit gracefully // return true; } // method: updateTransformID // // arguments: // Vector & stat_models: (input) input node // // return: a bool8 value indicating status // // this method creates a relationship between eachstatistical model // with regression class index // bool8 RegressionDecisionTree::updateTransformID(Vector& stat_models_a) { // local variables // SingleLinkedList leaf_nodes(DstrBase::USER); RTreeNode* root_node = (RTreeNode*)NULL; // get the root node and all the leaf nodes // root_node = getFirst(); getAllLeafNodes(*root_node, leaf_nodes); RTreeNode* temp_node = (RTreeNode*)NULL; int32 w_index = 0; // loop over all statistical models // for (int i = 0; i < stat_models_a.length(); i++) { // if the model to adapt is a single Gaussian model, add to data // strucutre directly // if (stat_models_d(i).getType() == StatisticalModel::GAUSSIAN_MODEL) { } // if the model to adapt is a mixture of Gaussian models, add each // mixture component to the data structure // else if (stat_models_d(i).getType() == StatisticalModel::MIXTURE_MODEL) { // get the linked list of mixture components // SingleLinkedList& models = const_cast(stat_models_a(i)).getMixtureModel(). getModels(); bool8 more = models.gotoFirst(); // Gaussian Model index // w_index = 0; while (more) { // get the current mixture component // StatisticalModel* model = models.getCurr(); // check whether it is Gaussian // if (model->getType() == StatisticalModel::GAUSSIAN_MODEL) { // w_index++; } else { return Error::handle(name(), L"updateTransformID", RegressionDecisionTree::ERR_ADAPT_NO_GAUSSIAN, __FILE__, __LINE__); } // move to the next mixture component // more = models.gotoNext(); // decide which node it belongs to // for (bool8 more_leaf = leaf_nodes.gotoFirst(); more_leaf; more_leaf = leaf_nodes.gotoNext()) { temp_node = leaf_nodes.getCurr(); // local variables // RegressionDecisionTreeNode* rdt_node = (RegressionDecisionTreeNode*)NULL; // get the data on this node // rdt_node = temp_node->getItem(); if ( rdt_node ->containModel(i, w_index)) { int32 temp_index = mixture_offset_d(i) + w_index; map_stat_to_trans_d(temp_index) = rdt_node->getNodeIndex(); break; } } w_index++; } // end of loop over all mixtures } // end of else if mixture model // it is neither single Gaussian, nor mixture model // else { return Error::handle(name(), L"updateTransformID", RegressionDecisionTree::ERR_ADAPT_NO_GAUSSIAN, __FILE__, __LINE__); } // end of else unsupported model } // end of loop over all statistical models // exit gracefully // return true; } // method: findTerminal // // arguments: // RTreeNode* root_node: (input) input root node // int32 index_a: (input) the best score upto now // // return: a TreeNode point for the node index node represented // // this method return the node represented by this index number. If // the node is not found, return null. // BiGraphVertex* RegressionDecisionTree::findTerminal(RTreeNode*& root_node_a, int32 index_a) { // local variables // float32 score = (float32)0; DoubleLinkedList >* children; BiGraphArc* child; BiGraphVertex* child_node; RTreeNode* best = (RTreeNode*)NULL; // check the node // if (root_node_a != (RTreeNode*)NULL) { // get all the child nodes of this node // children = root_node_a->getChildren(); // if this node has child nodes, accumulate the child nodes // if (root_node_a->gotoFirstChild()) { // loop for the left child // children->gotoFirst(); child_node = children->getCurr()->getVertex(); best = findTerminal(child_node, index_a); if(best != (RTreeNode*)NULL) { return best; } if (child_node == (RTreeNode*)NULL) { child = children->getCurr(); child_node = child->getVertex(); // local variables // RegressionDecisionTreeNode* rdt_node = (RegressionDecisionTreeNode*)NULL; // get the data on this node // rdt_node = child_node->getItem(); score = rdt_node->getClusterScore(); return best; } if( children->gotoNext()) { // loop for the right child // child_node = children->getCurr()->getVertex(); best = findTerminal(child_node, index_a); if(best != (RTreeNode*)NULL) { return best; } } } } RegressionDecisionTreeNode* root_node = (RegressionDecisionTreeNode*)NULL; root_node = root_node_a->getItem(); if ( root_node->getNodeIndex() == index_a) best = root_node_a; // return the best node // return best; }